# Discrete Continuous Embed Readout
Embedding and readout modules for categorical and Gaussian distributions, scaling from standard language modeling to sophisticated robotic action spaces.
## Install

```bash
pip install discrete-continuous-embed-readout
```
## Usage

### Discrete

For standard autoregressive language modeling or discrete action spaces.
```python
import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize
embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 20000  # vocabulary size
)

embed, readout = embed_readout

# 2. Embed
ids = torch.randint(0, 20000, (2, 1024))
embeds = embed(ids)  # (2, 1024, 512)

# ... pass through your transformer / network ...

# 3. Readout
logits = readout(embeds)  # (2, 1024, 20000)

# Calculate loss (cross entropy is handled automatically)
labels = torch.randint(0, 20000, (2, 1024))
loss = readout(embeds, labels, return_loss = True)
loss.backward()

# Sampling and other utilities
sampled = readout.sample(logits)               # (2, 1024)
log_probs = readout.log_prob(logits, sampled)  # (2, 1024)
entropy = readout.entropy(logits)              # (2, 1024)
```
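Under the hood, the discrete utilities reduce to standard operations on logits: cross entropy is the negative log-softmax at the target index, and entropy is computed from the softmax probabilities. A minimal pure-Python sketch of those formulas (hypothetical helper names, not the library's internals):

```python
import math

def log_softmax(logits):
    # numerically stable log-softmax over a 1D list of logits
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def cross_entropy(logits, label):
    # negative log-likelihood of the target index,
    # the quantity minimized by return_loss = True
    return -log_softmax(logits)[label]

def entropy(logits):
    # H = -sum_i p_i * log p_i, computed from the logits
    lp = log_softmax(logits)
    return -sum(math.exp(l) * l for l in lp)
```

For example, `cross_entropy([2.0, 0.5, -1.0], 0)` is small because index 0 already has the largest logit, while `entropy` peaks at `log(vocab_size)` for uniform logits.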
### Continuous

For continuous control or regression tasks.
```python
import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize
embed_readout = EmbedAndReadout(
    dim = 512,
    num_continuous = 4,  # 4 continuous dimensions
    continuous_mean_std = torch.ones(4, 2)  # optional mean and std for normalization
)

embed, readout = embed_readout

# 2. Embed
values = torch.randn(2, 1024, 4)
embeds = embed(values)  # (2, 1024, 512)

# ... pass through network ...

# 3. Readout (returns Gaussian distribution parameters)
dist_params = readout(embeds)  # (2, 1024, 4, 2) - mean and log variance

# Loss (Gaussian NLL)
targets = torch.randn(2, 1024, 4)
loss = readout(embeds, targets, return_loss = True)
loss.backward()

# Sampling
sampled = readout.sample(dist_params)  # (2, 1024, 4)
```
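The trailing `(mean, log var)` pair per dimension corresponds to the standard Gaussian negative log-likelihood and reparameterized sampling. A scalar sketch of both, assuming that parameterization (illustrative helpers, not the library's code):

```python
import math

def gaussian_nll(mean, log_var, target):
    # negative log-likelihood of `target` under N(mean, exp(log_var)):
    # 0.5 * (log(2*pi) + log_var + (target - mean)^2 / var)
    var = math.exp(log_var)
    return 0.5 * (math.log(2 * math.pi) + log_var + (target - mean) ** 2 / var)

def sample_gaussian(mean, log_var, eps):
    # reparameterized draw: mean + std * eps, with eps ~ N(0, 1)
    return mean + math.exp(0.5 * log_var) * eps
```

The NLL is minimized when the predicted mean matches the target, and predicting log-variance (rather than variance) keeps the standard deviation positive without any constraint on the network output.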
### Mixed Discrete and Continuous

For complex environments with both discrete and continuous action spaces.
```python
import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize
embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 100,
    num_continuous = 4
)

embed, readout = embed_readout

# 2. Embed inputs (passed as a tuple)
discrete_in = torch.randint(0, 100, (2, 32))
continuous_in = torch.randn(2, 32, 4)
embeds = embed((discrete_in, continuous_in))  # (2, 32, 512)

# ... network ...

# 3. Readout
output = readout(embeds)

# Access individual logits / params
print(output.discrete.shape)    # (2, 32, 100)
print(output.continuous.shape)  # (2, 32, 4, 2)

# Sampling returns a tuple
sampled_discrete, sampled_continuous = readout.sample(output)
```
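With both heads present, a natural training objective is simply the sum of the categorical cross entropy and the per-dimension Gaussian NLL. A sketch under that assumption (the helper names are hypothetical):

```python
import math

def categorical_nll(logits, label):
    # cross entropy at the target index, from raw logits
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[label]

def gaussian_nll(mean, log_var, target):
    # negative log-likelihood under N(mean, exp(log_var))
    return 0.5 * (math.log(2 * math.pi) + log_var
                  + (target - mean) ** 2 / math.exp(log_var))

def mixed_loss(discrete_logits, label, cont_params, cont_targets):
    # one scalar loss: discrete CE plus Gaussian NLL summed over dimensions
    loss = categorical_nll(discrete_logits, label)
    for (mean, log_var), t in zip(cont_params, cont_targets):
        loss += gaussian_nll(mean, log_var, t)
    return loss
```

In practice the two terms may be weighted differently when their scales diverge; the equal-weight sum above is the simplest choice.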
### Multi-Discrete

For action spaces with multiple independent discrete actions.
```python
import torch
from discrete_continuous_embed_readout import EmbedAndReadout

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = (10, 5, 8),  # 3 independent discrete actions
    use_parallel_multi_discrete = True  # optimized parallel processing
)

embed, readout = embed_readout

# Input shape: (batch, seq, 3) - one index per group, each within its own vocabulary
action_indices = torch.randint(0, 5, (2, 16, 3))
embeds = embed(action_indices)

# The readout returns per-group logits (a list, or a packed structure when the
# parallel optimization is enabled) - the wrapper handles either case transparently
logits = readout(embeds)
sampled = readout.sample(logits)  # (2, 16, 3)
```
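Multi-discrete sampling factorizes into one independent categorical draw per group, each over its own vocabulary size. A pure-Python sketch of that factorized sampling (illustrative only):

```python
import math
import random

def sample_multi_discrete(logit_groups, rng):
    # one independent categorical draw per group; groups may have
    # different vocabulary sizes, e.g. (10, 5, 8)
    actions = []
    for logits in logit_groups:
        m = max(logits)
        weights = [math.exp(x - m) for x in logits]  # unnormalized probabilities
        r = rng.random() * sum(weights)
        acc = 0.0
        for i, w in enumerate(weights):
            acc += w
            if r <= acc:
                actions.append(i)
                break
    return actions
```

Because the groups are independent, the joint log-probability of a multi-discrete action is just the sum of the per-group categorical log-probabilities.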
### Runtime Selectors

You can also define inputs dynamically at runtime if your architecture shares embeddings across different modalities.
```python
import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize with the total capacity of the system
# e.g. 10 discrete embeddings total, 5 continuous dimensions total
embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 10,
    num_continuous = 5,
    continuous_mean_std = torch.ones(5, 2)  # normalization for continuous
)

embed, readout = embed_readout

# 2. Define a runtime schema (selector config)
# This defines which specific embeddings this particular input uses.
# Here the input uses discrete indices 0, 1, 2 and continuous indices 0, 1.
discrete_config = [[0, 1, 2]]  # list of lists (allows multiple discrete groups)
continuous_config = [0, 1]     # list of indices
selector_config = (discrete_config, continuous_config)

# 3. Create inputs that match the schema
# Discrete: local indices in [0, 3); the selector remaps them to the global embeddings
discrete_input = torch.randint(0, 3, (2, 32))  # (batch, seq)

# Continuous: 2 values per position
continuous_input = torch.randn(2, 32, 2)  # (batch, seq, 2)

# 4. Embed with the specific selector config
embeds = embed(
    (discrete_input, continuous_input),
    selector_config = selector_config
)

# 5. Readout with the same selector config
logits = readout(
    embeds,
    selector_config = selector_config
)

# logits is a NamedTuple with .discrete and .continuous matching the config
print(logits.discrete.shape)    # (2, 32, 3) - matches discrete_config size
print(logits.continuous.shape)  # (2, 32, 2, 2) - matches continuous_config size
```
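Conceptually, a selector config just indexes into the shared tables: local discrete ids are remapped to global rows of the embedding table, and continuous indices pick out the matching normalization stats. A minimal sketch of that remapping (hypothetical helpers, not the library's internals):

```python
def remap_discrete(local_ids, group_indices):
    # local_ids lie in [0, len(group_indices)); the selector maps each one
    # to its global row in the shared embedding table
    return [group_indices[i] for i in local_ids]

def select_continuous(values, dim_indices, mean_std):
    # normalize each value using the stats of its selected global dimension
    out = []
    for v, d in zip(values, dim_indices):
        mean, std = mean_std[d]
        out.append((v - mean) / std)
    return out
```

This is why the model can be initialized once with the total capacity (10 discrete, 5 continuous above) while each input only exercises the subset its schema names.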
## File details

Details for the file `discrete_continuous_embed_readout-0.1.2.tar.gz`.

- Download URL: discrete_continuous_embed_readout-0.1.2.tar.gz
- Upload date:
- Size: 17.9 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.25

| Algorithm | Hash digest |
|---|---|
| SHA256 | `0d41e69db97171c61b591761e23b5b69841deecd0eb028a56a179c527f087dd5` |
| MD5 | `39a451f442d074f4fe5d8a40c4b51ace` |
| BLAKE2b-256 | `9e70dc2f3ba853fc4c0e59e86b3b0afd585e80a05dbf4bd7747c7fe0813abf36` |
## File details

Details for the file `discrete_continuous_embed_readout-0.1.2-py3-none-any.whl`.

- Download URL: discrete_continuous_embed_readout-0.1.2-py3-none-any.whl
- Upload date:
- Size: 13.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.25

| Algorithm | Hash digest |
|---|---|
| SHA256 | `0ca3aad0902154f2a658097dea4d0ed14b60083b04c0a53371eb6bb9cc462105` |
| MD5 | `5a060d82abd2bd395f4ac4981f5ca309` |
| BLAKE2b-256 | `7416f39e9ee6d87503a19ff7f47d0cc39fcf9978cd9d9f427cbf35ee44e3f650` |