
Discrete Continuous Embed Readout


Embedding and readout modules for categorical and Gaussian distributions, scaling from plain language-model vocabularies to sophisticated robotic action spaces

Install

pip install discrete-continuous-embed-readout

Usage

Discrete

For standard autoregressive language modeling or discrete action spaces.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 20000              # vocabulary size
)

embed, readout = embed_readout

# 2. Embed

ids = torch.randint(0, 20000, (2, 1024))

embeds = embed(ids) # (2, 1024, 512)

# ... pass through your transformer / network ...

# 3. Readout

logits = readout(embeds) # (2, 1024, 20000)

# Calculate loss (automatically handles cross entropy)

labels = torch.randint(0, 20000, (2, 1024))

loss = readout(embeds, labels, return_loss = True)
loss.backward()

# Sampling and other utilities

sampled = readout.sample(logits)                # (2, 1024)
log_probs = readout.log_prob(logits, sampled)   # (2, 1024)
entropy = readout.entropy(logits)               # (2, 1024)
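
As a sanity check on what `return_loss = True` computes in the discrete case, here is the cross-entropy of a single position written out in plain Python. This is a conceptual sketch of the standard formula, not the library's implementation (which works on batched tensors and averages over positions):

```python
import math

def cross_entropy(logits, label):
    # softmax over the vocabulary, then negative log-likelihood of the label
    m = max(logits)                           # subtract max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return -math.log(exps[label] / total)

# a 4-way "vocabulary" with one clearly favored class
logits = [2.0, 0.0, 0.0, 0.0]
print(cross_entropy(logits, 0))  # small loss: the label matches the peak
print(cross_entropy(logits, 1))  # larger loss: the label is off-peak
```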

Continuous

For continuous control or regression tasks.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize

embed_readout = EmbedAndReadout(
    dim = 512,
    num_continuous = 4,              # 4 continuous dimensions
    continuous_mean_std = torch.ones(4, 2) # optional mean and std for normalization
)

embed, readout = embed_readout

# 2. Embed

values = torch.randn(2, 1024, 4)

embeds = embed(values) # (2, 1024, 512)

# ... pass through network ...

# 3. Readout (returns per-dimension Gaussian parameters)

dist_params = readout(embeds) # (2, 1024, 4, 2) - mean and log var

# Loss (Gaussian NLL)

targets = torch.randn(2, 1024, 4)

loss = readout(embeds, targets, return_loss = True)
loss.backward()

# Sampling

sampled = readout.sample(dist_params)               # (2, 1024, 4)
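
One common parameterization matching the trailing `(..., 2)` dimension above is mean and log-variance per continuous dimension. Under that assumption, the per-dimension Gaussian negative log-likelihood looks like the sketch below; this is the textbook formula, not necessarily the library's exact reduction:

```python
import math

def gaussian_nll(mean, log_var, target):
    # negative log-likelihood of `target` under N(mean, exp(log_var))
    var = math.exp(log_var)
    return 0.5 * (math.log(2 * math.pi) + log_var + (target - mean) ** 2 / var)

# loss is lowest when the target sits at the predicted mean
print(gaussian_nll(0.0, 0.0, 0.0))  # unit-variance Gaussian, target at the mean
print(gaussian_nll(0.0, 0.0, 2.0))  # same Gaussian, target two stds away
```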

Mixed Discrete and Continuous

For complex environments with both discrete and continuous action spaces.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 100,
    num_continuous = 4
)

embed, readout = embed_readout

# 2. Embed inputs (passed as tuple)

discrete_in = torch.randint(0, 100, (2, 32))
continuous_in = torch.randn(2, 32, 4)

embeds = embed((discrete_in, continuous_in)) # (2, 32, 512)

# ... network ...

# 3. Readout

output = readout(embeds)

# Access individual logits/params
print(output.discrete.shape)   # (2, 32, 100)
print(output.continuous.shape) # (2, 32, 4, 2)

# Sampling returns tuple

sampled_discrete, sampled_continuous = readout.sample(output)
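
With independent discrete and continuous heads, the joint log-probability of a mixed action factorizes into a sum of the two parts. A pure-Python sketch under that independence assumption, with illustrative distributions (a uniform categorical and a standard Gaussian) standing in for the model's predicted ones:

```python
import math
import random

random.seed(0)

# sample a mixed action: one categorical id plus a 4-dim Gaussian vector,
# mirroring the (num_discrete = 100, num_continuous = 4) setup above
probs = [1 / 100] * 100                       # uniform categorical for illustration
discrete_action = random.choices(range(100), weights = probs)[0]
continuous_action = [random.gauss(0.0, 1.0) for _ in range(4)]

# independence => the joint log-prob is the sum of the head log-probs
joint_log_prob = math.log(probs[discrete_action]) + sum(
    -0.5 * (math.log(2 * math.pi) + x * x) for x in continuous_action
)
```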

Multi-Discrete

For action spaces with multiple independent discrete actions.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = (10, 5, 8),    # 3 independent discrete actions
    use_parallel_multi_discrete = True # optimized parallel processing
)

embed, readout = embed_readout

# Input shape: (batch, seq, 3) - one index per discrete action
# each slot has its own range (10, 5, 8), so indices 0-4 are valid for all three
action_indices = torch.randint(0, 5, (2, 16, 3))

embeds = embed(action_indices)

# The readout returns a list of per-action logits without the parallel
# optimization, or a packed structure with it - the wrapper handles both seamlessly.

logits = readout(embeds)
sampled = readout.sample(logits) # (2, 16, 3)
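
The multi-discrete case treats each slot as its own independent categorical. A minimal sketch of sampling under that factorization, with uniform distributions standing in for the predicted logits (illustrative only, not the library's internals):

```python
import random

random.seed(0)

# three independent discrete actions with sizes (10, 5, 8), as configured above
sizes = (10, 5, 8)

def sample_multi_discrete(sizes):
    # each slot is its own categorical; sample the slots independently
    return [random.randrange(n) for n in sizes]

action = sample_multi_discrete(sizes)
print(action)  # one valid index per slot, each within its own range
```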

Runtime Selectors

You can also define inputs dynamically at runtime if your architecture shares embeddings across different modalities.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize with the total capacity of the system
#    e.g. 10 discrete embeddings total, 5 continuous dimensions total

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 10,
    num_continuous = 5,
    continuous_mean_std = torch.ones(5, 2) # normalization for continuous
)

embed, readout = embed_readout

# 2. Define a Runtime Schema (Selector Config)
#    This defines which specific embeddings this particular input uses.
#    For example, this input uses discrete indices 0, 1, 2 and continuous indices 0, 1.

discrete_config = [[0, 1, 2]] # List of lists (for potentially multiple discrete groups)
continuous_config = [0, 1]    # List of indices
selector_config = (discrete_config, continuous_config)

# 3. Create inputs that match the schema
#    Discrete: 3 local indices (the selector maps each one to its global embedding id)
#    Continuous: 2 values

# (Batch, Seq) - values must be valid for the local schema size (3)
discrete_input = torch.randint(0, 3, (2, 32))

# (Batch, Seq, 2)
continuous_input = torch.randn(2, 32, 2)

# 4. Embed with the specific selector config

embeds = embed(
    (discrete_input, continuous_input),
    selector_config = selector_config
)

# 5. Readout with the same selector config

logits = readout(
    embeds,
    selector_config = selector_config
)

# logits will be a NamedTuple with .discrete and .continuous matching the config
print(logits.discrete.shape)   # (2, 32, 3) - matches discrete_config size
print(logits.continuous.shape) # (2, 32, 2, 2) - matches continuous_config size
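
The indirection a selector config implies can be pictured as follows: a local index is a position within the config's list of global ids, so the same shared table serves inputs with different schemas. A conceptual sketch of that mapping, not the library's internals:

```python
# hypothetical illustration of selector indirection
discrete_config = [[0, 1, 2]]   # one discrete group using global embedding ids 0, 1, 2
continuous_config = [0, 1]      # continuous dims 0 and 1 of the 5 total

def local_to_global(local_index, group):
    # a local index is a position within the group's list of global ids
    return group[local_index]

print(local_to_global(2, discrete_config[0]))   # local 2 -> global id 2
print(local_to_global(1, [5, 6, 7]))            # in another group, local 1 -> global id 6
```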
