Guided latent subspace learning for deep neural networks

SPACE (Subspace Partitioning for Accessible Controlled Encodings)

SPACE-Learn is a utility library built on PyTorch that provides functions for training deep learning models in the Latent-SPACE regime. The SPACE method helps create structured, easy-to-decode embedding spaces, or explicitly guides a model's internal latent space using side information available during training. It encodes features that are known at training time, or intermediate results, directly into the model's latent space.
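To make the idea of guiding a latent space with side information concrete, here is a minimal sketch in plain PyTorch. It does not use spacelearn and is not the library's algorithm; the toy tensors, the names `basis` and `readout`, and the least-squares fitting are all assumptions chosen for illustration. It assigns a quantity to a low-dimensional subspace of the latent space and penalizes how badly the quantity can be decoded linearly from that subspace.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy setup: 64 samples, latent dimension 8, and one side
# quantity with 2 components per sample that is known during training.
z = torch.randn(64, 8, requires_grad=True)  # stand-in for encoder outputs
y = torch.randn(64, 2)                      # side information

# Fit a linear map from the (detached) latents to the quantity, then
# orthonormalize its columns to get a basis for a 2-dimensional subspace.
linear_map = torch.linalg.lstsq(z.detach(), y).solution   # shape (8, 2)
basis, _ = torch.linalg.qr(linear_map)                    # shape (8, 2)

# Project the latents onto the subspace and fit a small readout that maps
# subspace coordinates back to the quantity.
coords = z @ basis                                        # shape (64, 2)
readout = torch.linalg.lstsq(coords.detach(), y).solution # shape (2, 2)

# Guidance loss: mean squared decoding error inside the subspace.
# Its gradient flows back into the latents (and, in a real model, the encoder).
guidance_loss = torch.mean((coords @ readout - y) ** 2)
guidance_loss.backward()
```

In a real training run such a term would be added to the task loss, so the encoder is nudged to keep the quantity linearly readable from its assigned subspace.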

Example Setup

1. Install

pip install spacelearn

2. Example Training Setup

import torch
from torch.optim import Adam
from collections import defaultdict

from spacelearn import (
    solve_dims,
    solve_subspaces,
    combined_loss,
)
from spacelearn.data import input_to_quantity
from spacelearn.optim import minimal_latent_dim

from my_model import Model
from my_data import (
    DataLoader,
    quantity_helper_a,
    quantity_helper_b,
)

EPOCHS = 5
LR = 1e-3

MAX_BINS = 10

# target variance retained by pooling
BIN_THRESHOLDS = 0.80

# target variance retained by subspaces
K_THRESHOLDS = {
    "A": 0.90,
    "B": 0.85,
}

REEVAL_EVERY = 2
INIT_SAMPLES = 100

dataloader = DataLoader()

# ------------------------------------------------------------------
# 1. Collect data for dimension estimation
# ------------------------------------------------------------------

test_data = defaultdict(list)
samples = []

for i in range(INIT_SAMPLES):
    sample = dataloader[i]

    test_data["A"].append(
        quantity_helper_a(samples=sample)
    )
    test_data["B"].append(
        quantity_helper_b(samples=sample)
    )

    samples.append(sample)

test_data["A"] = torch.stack(test_data["A"])
test_data["B"] = torch.stack(test_data["B"])

samples = torch.stack(samples)

# ------------------------------------------------------------------
# 2. Estimate n and k
# ------------------------------------------------------------------

dims = solve_dims(
    test_data,
    max_bins=MAX_BINS,
    bin_thresholds=BIN_THRESHOLDS,
    k_thresholds=K_THRESHOLDS,
    k_per_quantity=True,
)

# shared pooling resolution
N = next(iter(dims.values()))[0]

# per-quantity subspace dimensions
K = {q: k for q, (_, k) in dims.items()}

# latent dimension
D = minimal_latent_dim(
    K,
    free_frac=0.10,
)

# ------------------------------------------------------------------
# 3. Build quantity extraction helper
# ------------------------------------------------------------------

pool_helper = input_to_quantity(
    N,
    "avg",
    A=quantity_helper_a,
    B=quantity_helper_b,
)

# ------------------------------------------------------------------
# 4. Initialize model
# ------------------------------------------------------------------

model = Model(D)
optimizer = Adam(model.parameters(), lr=LR)

# ------------------------------------------------------------------
# 5. Training loop
# ------------------------------------------------------------------

W_prev = None

for epoch in range(EPOCHS):

    # periodically recompute subspaces
    if epoch % REEVAL_EVERY == 0:

        with torch.no_grad():
            Z_ref = model(samples)

        Y_ref = pool_helper(samples=samples)

        WAV = solve_subspaces(
            Z_ref,
            Y_ref,
            k=K,
        )

        W = WAV.W
        A = WAV.A

        if W_prev is None:
            W_prev = {
                q: w.clone()
                for q, w in W.items()
            }

    for batch in dataloader:

        Y = pool_helper(samples=batch)
        Z = model(batch)

        space_loss = combined_loss(
            Z,
            W,
            A,
            Y,
            W_prev,
        )

        # task_loss = ...
        # loss = task_loss + space_loss

        optimizer.zero_grad()
        space_loss.backward()
        optimizer.step()

    W_prev = {
        q: w.clone()
        for q, w in W.items()
    }
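The example above imports `Model`, `DataLoader`, `quantity_helper_a`, and `quantity_helper_b` from `my_model` and `my_data`, which are user-supplied placeholders rather than part of spacelearn. The sketch below shows one hypothetical way to write them so that the calls in the example have something to act on. The interfaces are inferred only from how the example uses them (integer indexing and batch iteration on the loader, a `samples=` keyword on the helpers, a model that maps inputs to a `D`-dimensional latent); the feature count, layer sizes, and the quantities themselves are made up. Whether these output shapes suit the pooling done by `input_to_quantity` depends on the library, which this README does not specify.

```python
import torch
from torch import nn


class DataLoader:
    """Hypothetical stand-in for my_data.DataLoader: random feature vectors."""

    def __init__(self, n_samples=256, n_features=16, batch_size=32, seed=0):
        gen = torch.Generator().manual_seed(seed)
        self.data = torch.randn(n_samples, n_features, generator=gen)
        self.batch_size = batch_size

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, i):
        # dataloader[i] returns a single sample of shape (n_features,)
        return self.data[i]

    def __iter__(self):
        # iterating yields batches of shape (batch_size, n_features)
        for start in range(0, len(self), self.batch_size):
            yield self.data[start:start + self.batch_size]


def quantity_helper_a(samples):
    # Made-up quantity A: the first four input features.
    return samples[..., :4]


def quantity_helper_b(samples):
    # Made-up quantity B: mean absolute value of the remaining features.
    return samples[..., 4:].abs().mean(dim=-1, keepdim=True)


class Model(nn.Module):
    """Hypothetical stand-in for my_model.Model: a small MLP encoder."""

    def __init__(self, latent_dim, n_features=16):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(n_features, 32),
            nn.ReLU(),
            nn.Linear(32, latent_dim),
        )

    def forward(self, x):
        return self.encoder(x)
```

With these definitions, `dataloader[i]` returns one sample, `for batch in dataloader` yields batches, and `Model(D)(batch)` returns a tensor with `D` latent features per sample.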

DOCS:

spacelearn

spacelearn.settings

spacelearn.data

spacelearn.loss

spacelearn.optim
