Guided latent subspace learning for deep neural networks
SPACE (Subspace Partitioning for Accessible Controlled Encodings)
SPACE-Learn is a utility library built on PyTorch that provides functions for training deep-learning models in the latent-SPACE regime. The SPACE method helps you create structured, easy-to-decode embedding spaces, or explicitly guide a model's internal latent space using side information available during training. It encodes features that are known at training time, or intermediate results, directly into the model's latent space.
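One way to picture the idea is as follows. This is an assumption, because this README does not state the exact formulation. Each quantity known at training time owns a k-dimensional subspace of the D-dimensional latent vector z. That subspace is spanned by orthonormal basis vectors W, and a linear read-out A decodes the quantity from the subspace coordinates Wᵀz. The plain-Python sketch below computes that decoding error for a single latent vector. `project`, `decode` and `subspace_mse` are hypothetical helpers, not spacelearn functions.

```python
# Hypothetical illustration of decoding a known quantity from a latent subspace.
# Not spacelearn's implementation; plain Python lists stand in for tensors.

def project(z, basis):
    """Coordinates of z along each orthonormal basis vector (W^T z)."""
    return [sum(zi * bi for zi, bi in zip(z, b)) for b in basis]


def decode(coords, readout):
    """Linear read-out y_hat = A @ coords, with A given as a list of rows."""
    return [sum(a * c for a, c in zip(row, coords)) for row in readout]


def subspace_mse(z, basis, readout, target):
    """Mean squared error between the decoded quantity and its known value."""
    y_hat = decode(project(z, basis), readout)
    return sum((p - t) ** 2 for p, t in zip(y_hat, target)) / len(target)


# D = 4 latent dimensions; quantity "A" is assigned the first two axes (k = 2),
# leaving the last two dimensions free for the task.
z = [0.5, -1.0, 3.0, 2.0]
basis_a = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
readout_a = [[1.0, 0.0], [0.0, 1.0]]
target_a = [0.5, -1.0]

print(subspace_mse(z, basis_a, readout_a, target_a))  # 0.0
```

Minimizing an error of this kind pushes the model to place the quantity in its assigned subspace, where a single linear map can read it out later. The free dimensions (here `z[2:]`) do not affect the error.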
Example Setup
1. Install
```bash
pip install spacelearn
```
2. Example Training Setup
```python
import torch
from torch.optim import Adam
from collections import defaultdict

from spacelearn import (
    solve_dims,
    solve_subspaces,
    combined_loss,
)
from spacelearn.data import input_to_quantity
from spacelearn.optim import minimal_latent_dim

# Your own model and data modules (placeholders for this example)
from my_model import Model
from my_data import (
    DataLoader,
    quantity_helper_a,
    quantity_helper_b,
)

EPOCHS = 5
LR = 1e-3
# maximum number of pooling bins
MAX_BINS = 10

# target variance retained by pooling
BIN_THRESHOLDS = 0.80

# target variance retained by subspaces, per quantity
K_THRESHOLDS = {
    "A": 0.90,
    "B": 0.85,
}

# recompute the subspaces every REEVAL_EVERY epochs
REEVAL_EVERY = 2
# number of samples used to estimate dimensions and reference subspaces
INIT_SAMPLES = 100

dataloader = DataLoader()
```
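Both `BIN_THRESHOLDS` and `K_THRESHOLDS` are described only as a target fraction of variance to retain. A common way to turn such a target into a dimension is to keep the smallest number of leading principal components whose cumulative explained variance reaches the threshold. Whether `solve_dims` uses exactly this rule is an assumption. `smallest_k` is a hypothetical helper, not part of spacelearn.

```python
# Hypothetical sketch: turn a "target variance retained" threshold into a
# dimension, as PCA-style methods commonly do. Whether solve_dims uses exactly
# this rule is an assumption.

def smallest_k(variances, threshold):
    """Smallest number of leading components whose cumulative share of the
    total variance reaches `threshold` (e.g. 0.90)."""
    ordered = sorted(variances, reverse=True)
    total = sum(ordered)
    if total <= 0:
        return 0
    running = 0.0
    for k, v in enumerate(ordered, start=1):
        running += v
        if running / total >= threshold:
            return k
    return len(ordered)


# Variance carried by each principal component of one quantity (total = 10)
component_variances = [5.0, 3.0, 1.0, 0.5, 0.5]
print(smallest_k(component_variances, 0.75))  # 2  (8 / 10 = 0.80)
print(smallest_k(component_variances, 0.85))  # 3  (9 / 10 = 0.90)
```

Under this reading, a higher threshold yields a larger subspace. The per-quantity dictionary therefore lets quantity "A" (0.90) keep at least as many dimensions as it would at the 0.85 used for "B".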
```python
# ------------------------------------------------------------------
# 1. Collect data for dimension estimation
# ------------------------------------------------------------------
test_data = defaultdict(list)
samples = []
for i in range(INIT_SAMPLES):
    sample = dataloader[i]
    test_data["A"].append(
        quantity_helper_a(samples=sample)
    )
    test_data["B"].append(
        quantity_helper_b(samples=sample)
    )
    samples.append(sample)

test_data["A"] = torch.stack(test_data["A"])
test_data["B"] = torch.stack(test_data["B"])
samples = torch.stack(samples)
```
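Step 1 hard-codes one list per quantity. If you track more quantities, a name-to-helper mapping keeps the collection loop unchanged. This mirrors the keyword arguments that are later passed to `input_to_quantity`. The sketch assumes each helper takes a single sample and returns a flat row of numbers. The lambdas are hypothetical stand-ins for `quantity_helper_a` and `quantity_helper_b`, which the example calls with a `samples=` keyword instead. It also rejects ragged rows early, because `torch.stack` requires every row to have the same shape.

```python
# Hypothetical generalization of step 1: map quantity names to helper
# functions so new quantities need no extra code. Plain lists stand in for
# tensors; torch.stack would be applied to each list afterwards.
from collections import defaultdict


def collect_quantities(samples, helpers):
    """Apply every helper to every sample; return {name: list of rows}."""
    table = defaultdict(list)
    for sample in samples:
        for name, helper in helpers.items():
            table[name].append(helper(sample))
    # torch.stack requires equal shapes, so reject ragged rows early
    for name, rows in table.items():
        if len({len(row) for row in rows}) > 1:
            raise ValueError(f"quantity {name!r} produced rows of different lengths")
    return dict(table)


# Hypothetical stand-ins for quantity_helper_a / quantity_helper_b
helpers = {
    "A": lambda s: [sum(s) / len(s)],  # mean of the sample
    "B": lambda s: [min(s), max(s)],   # range of the sample
}
print(collect_quantities([[1, 2, 3], [4, 5, 6]], helpers))
# {'A': [[2.0], [5.0]], 'B': [[1, 3], [4, 6]]}
```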
```python
# ------------------------------------------------------------------
# 2. Estimate n and k
# ------------------------------------------------------------------
dims = solve_dims(
    test_data,
    max_bins=MAX_BINS,
    bin_thresholds=BIN_THRESHOLDS,
    k_thresholds=K_THRESHOLDS,
    k_per_quantity=True,
)

# shared pooling resolution
N = next(iter(dims.values()))[0]
# per-quantity subspace dimensions
K = {q: k for q, (_, k) in dims.items()}
# latent dimension
D = minimal_latent_dim(
    K,
    free_frac=0.10,
)
```
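Step 2 reads the shared pooling resolution `N` from the first entry of `dims` and collects one subspace size per quantity. The sketch below is a defensive variant that also checks that every quantity really reports the same `n`. It also shows one possible reading of `free_frac=0.10`: the latent size `D` is chosen so that at least 10 % of its dimensions stay outside the guided subspaces. That rule is an assumption, not `minimal_latent_dim`'s documented behavior. Both helpers are hypothetical.

```python
# Hypothetical sketch of step 2's bookkeeping. The (n, k) tuple layout follows
# the README; the latent-size rule below is an assumed interpretation of
# free_frac, not minimal_latent_dim's documented behavior.
import math
from fractions import Fraction


def unpack_dims(dims):
    """Split {quantity: (n, k)} into one shared pooling size N and {quantity: k}."""
    ns = {n for n, _ in dims.values()}
    if len(ns) != 1:
        raise ValueError(f"expected a single shared pooling size, got {sorted(ns)}")
    return ns.pop(), {q: k for q, (_, k) in dims.items()}


def latent_dim_with_free_fraction(K, free_frac):
    """Smallest D such that at least `free_frac` of D is left outside the
    guided subspaces (assumes the subspaces do not overlap)."""
    guided = sum(K.values())
    # Fraction keeps the division exact so ceil() is not nudged up by rounding
    return math.ceil(Fraction(guided) / (1 - Fraction(str(free_frac))))


N, K = unpack_dims({"A": (8, 6), "B": (8, 3)})
print(N, K)                                    # 8 {'A': 6, 'B': 3}
print(latent_dim_with_free_fraction(K, 0.10))  # 10
```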
```python
# ------------------------------------------------------------------
# 3. Build quantity extraction helper
# ------------------------------------------------------------------
pool_helper = input_to_quantity(
    N,
    "avg",
    A=quantity_helper_a,
    B=quantity_helper_b,
)
```
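`input_to_quantity` is called with the pooling resolution `N` and the mode `"avg"`. The README does not define the exact binning. Assuming the mode averages each quantity down to `N` contiguous, non-overlapping bins, the operation looks like the sketch below. `avg_pool_bins` is a hypothetical helper.

```python
# Hypothetical sketch of "avg" pooling to a fixed resolution; the exact
# binning used by input_to_quantity is an assumption.

def avg_pool_bins(values, n_bins):
    """Average a 1-D sequence into n_bins contiguous, non-overlapping bins."""
    if not 0 < n_bins <= len(values):
        raise ValueError("n_bins must be between 1 and len(values)")
    length = len(values)
    pooled = []
    for b in range(n_bins):
        # integer bin edges; every bin is non-empty because n_bins <= length
        start = b * length // n_bins
        stop = (b + 1) * length // n_bins
        chunk = values[start:stop]
        pooled.append(sum(chunk) / len(chunk))
    return pooled


print(avg_pool_bins([1, 2, 3, 4, 5, 6], 3))  # [1.5, 3.5, 5.5]
print(avg_pool_bins([1, 2, 3, 4, 5], 2))     # [1.5, 4.0]
```

`torch.nn.functional.adaptive_avg_pool1d` uses overlapping bin edges when the input length is not a multiple of the output size. For such lengths its results can differ from this split.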
```python
# ------------------------------------------------------------------
# 4. Initialize model
# ------------------------------------------------------------------
model = Model(D)
optimizer = Adam(model.parameters(), lr=LR)

# ------------------------------------------------------------------
# 5. Training loop
# ------------------------------------------------------------------
W_prev = None
for epoch in range(EPOCHS):
    # periodically recompute subspaces from the reference samples
    if epoch % REEVAL_EVERY == 0:
        with torch.no_grad():
            Z_ref = model(samples)
            Y_ref = pool_helper(samples=samples)
            WAV = solve_subspaces(
                Z_ref,
                Y_ref,
                k=K,  # per-quantity subspace dimensions from step 2
            )
        W = WAV.W
        A = WAV.A
        if W_prev is None:
            W_prev = {
                q: w.clone()
                for q, w in W.items()
            }

    for batch in dataloader:
        Y = pool_helper(samples=batch)
        Z = model(batch)
        space_loss = combined_loss(
            Z,
            W,
            A,
            Y,
            W_prev,
        )
        # task_loss = ...
        # loss = task_loss + space_loss
        # (when adding a task loss, call loss.backward() instead)
        optimizer.zero_grad()
        space_loss.backward()
        optimizer.step()

    # keep this epoch's subspaces for the next epoch's combined_loss
    W_prev = {
        q: w.clone()
        for q, w in W.items()
    }
```
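With `EPOCHS = 5` and `REEVAL_EVERY = 2`, the condition `epoch % REEVAL_EVERY == 0` recomputes the subspaces at epochs 0, 2 and 4, and `W_prev` holds the bases from the previous epoch. The README does not say how `combined_loss` uses `W_prev`. One plausible choice, shown here purely as an assumption, is a drift penalty between the old and new subspaces. It compares the projectors W Wᵀ rather than the raw bases, so reordering or rotating basis vectors within the same subspace costs nothing. All names below are hypothetical.

```python
# Hypothetical sketch: re-evaluation schedule and a rotation-invariant
# subspace drift measure. Not necessarily what combined_loss computes.

def reeval_epochs(epochs, every):
    """Epochs at which the training loop recomputes the subspaces."""
    return [e for e in range(epochs) if e % every == 0]


def projector(basis):
    """P = W W^T for W given as a list of orthonormal basis vectors."""
    dim = len(basis[0])
    return [[sum(b[i] * b[j] for b in basis) for j in range(dim)] for i in range(dim)]


def subspace_drift(basis_new, basis_old):
    """Squared Frobenius distance between projectors; 0 for the same subspace."""
    P, Q = projector(basis_new), projector(basis_old)
    return sum((p - q) ** 2 for row_p, row_q in zip(P, Q) for p, q in zip(row_p, row_q))


print(reeval_epochs(5, 2))  # [0, 2, 4]

e1, e2, e3 = [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]
print(subspace_drift([e1, e2], [e2, e1]))  # 0.0  (same plane, reordered basis)
print(subspace_drift([e1, e2], [e1, e3]))  # 2.0  (one direction swapped)
```

Comparing projectors instead of basis matrices matters because a subspace solver may return a different but equivalent basis each time it is rerun.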
DOCS:
Download files
File details
Details for the file spacelearn-0.1.0.tar.gz.
File metadata
- Download URL: spacelearn-0.1.0.tar.gz
- Upload date:
- Size: 95.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 277b7725a6db75087b41cd86670a7b79e67f752fa9a2554cdf2dcc7ac6cb5d49 |
| MD5 | b195be6f30eb1eb7f793d245f8de470a |
| BLAKE2b-256 | 7aa9f5a0afe2f21a4dc83462a8bcb618078c7fed525bd774974f7dcdfc98cd26 |
File details
Details for the file spacelearn-0.1.0-py3-none-any.whl.
File metadata
- Download URL: spacelearn-0.1.0-py3-none-any.whl
- Upload date:
- Size: 31.3 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.14.4
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 807abc2e3ef22695c59383027a6a281c18a864753ab44c7da6fc58c10607a266 |
| MD5 | 251bf12b346195ccfa585e35103c2a3f |
| BLAKE2b-256 | e35a302e58fa887f2033bd1922efb30c9fa5da8b6d09d2baba424c1d690dadd3 |