
Discrete Continuous Embed Readout

Embedding and readout for simple categorical and Gaussian distributions, from language models to sophisticated robotic action spaces.

Install

pip install discrete-continuous-embed-readout

Usage

Discrete

For standard autoregressive language modeling or discrete action spaces.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 20000              # vocabulary size
)

embed, readout = embed_readout

# 2. Embed

ids = torch.randint(0, 20000, (2, 1024))

embeds = embed(ids) # (2, 1024, 512)

# ... pass through your transformer / network ...

# 3. Readout

logits = readout(embeds) # (2, 1024, 20000)

# Calculate the loss: passing labels with return_loss = True computes cross entropy for you

labels = torch.randint(0, 20000, (2, 1024))

loss = readout(embeds, labels, return_loss = True)
loss.backward()

# Sampling and other utilities

sampled = readout.sample(logits)                # (2, 1024)
log_probs = readout.log_prob(logits, sampled)   # (2, 1024)
entropy = readout.entropy(logits)               # (2, 1024)
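The readout utilities above return per-position categorical quantities. As a rough illustration of what they compute, here is a minimal stdlib-only sketch for a single position. The helper names mirror the README methods but are hypothetical stand-ins, not the library's implementation.

```python
import math
import random

# Hypothetical stdlib-only sketch of the categorical quantities behind
# readout.sample, readout.log_prob and readout.entropy at one position.

def log_softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def log_prob(logits, index):
    # Log-probability of one token / action id under the categorical distribution
    return log_softmax(logits)[index]

def entropy(logits):
    # H = -sum_i p_i * log p_i
    lp = log_softmax(logits)
    return -sum(math.exp(l) * l for l in lp)

def sample(logits, rng=random):
    # Draw one index with probability softmax(logits)[index]
    probs = [math.exp(l) for l in log_softmax(logits)]
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

logits = [2.0, 0.5, -1.0]
token = sample(logits, random.Random(0))
token_log_prob = log_prob(logits, token)
h = entropy(logits)
```

Uniform logits give the maximum entropy, `log(vocab_size)`; sharper logits lower it.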

Continuous

For continuous control or regression tasks.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize

embed_readout = EmbedAndReadout(
    dim = 512,
    num_continuous = 4,              # 4 continuous dimensions
    continuous_mean_std = torch.ones(4, 2) # optional mean and std for normalization
)

embed, readout = embed_readout

# 2. Embed

values = torch.randn(2, 1024, 4)

embeds = embed(values) # (2, 1024, 512)

# ... pass through network ...

# 3. Readout (returns Gaussian parameters for each continuous dimension)

dist_params = readout(embeds) # (2, 1024, 4, 2) - mean and log var

# Loss (Gaussian NLL)

targets = torch.randn(2, 1024, 4)

loss = readout(embeds, targets, return_loss = True)
loss.backward()

# Sampling

sampled = readout.sample(dist_params)               # (2, 1024, 4)
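For the continuous head, the README states that the readout emits a mean and log variance per dimension and trains with a Gaussian negative log-likelihood. A minimal stdlib sketch of those formulas follows. It assumes, without confirmation from the library, that `continuous_mean_std` z-scores each dimension as `(x - mean) / std` and that samples are drawn as `mean + exp(0.5 * log_var) * eps`; all helper names are hypothetical.

```python
import math
import random

# Hypothetical stdlib-only sketch of a per-dimension Gaussian readout.
# Assumption: continuous_mean_std z-scores each dimension as (x - mean) / std.

def normalize(x, mean, std):
    return (x - mean) / std

def denormalize(z, mean, std):
    return z * std + mean

def gaussian_nll(target, mean, log_var):
    # -log N(target; mean, exp(log_var))
    return 0.5 * (math.log(2 * math.pi) + log_var + (target - mean) ** 2 / math.exp(log_var))

def sample(mean, log_var, rng=random):
    # Assumed reparameterized draw: mean + std * eps, with eps ~ N(0, 1)
    return mean + math.exp(0.5 * log_var) * rng.gauss(0.0, 1.0)

# One position with 2 continuous dimensions
stats = [(10.0, 2.0), (0.0, 1.0)]   # hypothetical per-dimension (mean, std)
params = [(0.0, 0.0), (1.5, -2.0)]  # predicted (mean, log_var) in normalized space
raw_targets = [10.6, 1.4]

# Normalize targets, then average the NLL over dimensions
targets = [normalize(x, m, s) for x, (m, s) in zip(raw_targets, stats)]
loss = sum(gaussian_nll(t, m, lv) for t, (m, lv) in zip(targets, params)) / len(targets)

# Sample in normalized space, then map back to the original scale
rng = random.Random(0)
draws = [denormalize(sample(m, lv, rng), mu, sd) for (m, lv), (mu, sd) in zip(params, stats)]
```

Predicting the log variance rather than the variance keeps the variance positive without any extra constraint on the network output.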

Mixed Discrete and Continuous

For complex environments with both discrete and continuous action spaces.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 100,
    num_continuous = 4
)

embed, readout = embed_readout

# 2. Embed inputs (passed as tuple)

discrete_in = torch.randint(0, 100, (2, 32))
continuous_in = torch.randn(2, 32, 4)

embeds = embed((discrete_in, continuous_in)) # (2, 32, 512)

# ... network ...

# 3. Readout

output = readout(embeds)

# Access individual logits/params
print(output.discrete.shape)   # (2, 32, 100)
print(output.continuous.shape) # (2, 32, 4, 2)

# Sampling returns a (discrete, continuous) tuple

sampled_discrete, sampled_continuous = readout.sample(output)
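One way to score a mixed action, for example in a policy-gradient loss, is to treat the discrete and continuous heads as independent given the shared embedding, so that their log-likelihoods add. The sketch below assumes that independence; it illustrates the idea and does not describe how the library itself combines its heads. Function names are hypothetical.

```python
import math

# Hypothetical sketch: assuming independent heads, the joint log-likelihood of a
# mixed action is the categorical log-prob plus the Gaussian log-densities.

def categorical_log_prob(logits, index):
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[index] - log_z

def gaussian_log_prob(x, mean, log_var):
    return -0.5 * (math.log(2 * math.pi) + log_var + (x - mean) ** 2 / math.exp(log_var))

def mixed_log_prob(discrete_logits, discrete_action, continuous_params, continuous_action):
    lp = categorical_log_prob(discrete_logits, discrete_action)
    # continuous_params holds one (mean, log_var) pair per continuous dimension
    lp += sum(
        gaussian_log_prob(x, m, lv)
        for x, (m, lv) in zip(continuous_action, continuous_params)
    )
    return lp

# Example: 3 discrete choices plus 2 continuous dimensions at one position
lp = mixed_log_prob([0.2, 1.0, -0.5], 1, [(0.0, 0.0), (0.5, -1.0)], [0.1, 0.4])
```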

Multi-Discrete

For action spaces with multiple independent discrete actions.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = (10, 5, 8),    # 3 independent discrete actions
    use_parallel_multi_discrete = True # optimized parallel processing
)

embed, readout = embed_readout

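# Input shape: (batch, seq, 3)
# Values are drawn below 5, the smallest vocabulary size, so every index is valid for all three actions
action_indices = torch.randint(0, 5, (2, 16, 3))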

embeds = embed(action_indices)

# Without the parallel optimization, readout returns a list of logits (one per action);
# with it, readout returns a specialized structure. Either way, readout.sample handles it.

logits = readout(embeds)
sampled = readout.sample(logits) # (2, 16, 3)
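A multi-discrete head with `num_discrete = (10, 5, 8)` can be viewed as three independent categoricals over a 23-way logit vector. The sketch below assumes a flat, concatenated layout split by group size; whether `use_parallel_multi_discrete` uses exactly this layout is an assumption, and every name here is hypothetical.

```python
import math
import random
from itertools import accumulate

# Hypothetical sketch: one flat logit vector of length sum(NUM_DISCRETE),
# split into independent categorical groups.

NUM_DISCRETE = (10, 5, 8)

def split_logits(flat, sizes=NUM_DISCRETE):
    # Bounds (0, 10, 15, 23) delimit each action's slice of the flat vector
    bounds = [0, *accumulate(sizes)]
    return [flat[start:end] for start, end in zip(bounds, bounds[1:])]

def log_softmax(logits):
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def sample(flat, rng=random):
    # One independent draw per action group, returned as (a0, a1, a2)
    out = []
    for group in split_logits(flat):
        probs = [math.exp(l) for l in log_softmax(group)]
        out.append(rng.choices(range(len(group)), weights=probs, k=1)[0])
    return tuple(out)

def log_prob(flat, actions):
    # Independence means the joint log-probability is a sum over groups
    return sum(log_softmax(g)[a] for g, a in zip(split_logits(flat), actions))

flat_logits = [0.0] * sum(NUM_DISCRETE)  # 23 logits for one position
actions = sample(flat_logits, random.Random(0))
```

Keeping every group in one vector lets a single linear layer produce all actions' logits at once, which is presumably the appeal of a parallel layout.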

Runtime Selectors

If your architecture shares one set of embeddings across different modalities, you can also choose at runtime which subset of those embeddings each input uses.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize with the total capacity of the system
#    e.g. 10 discrete embeddings total, 5 continuous dimensions total

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 10,
    num_continuous = 5,
    continuous_mean_std = torch.ones(5, 2) # normalization for continuous
)

embed, readout = embed_readout

# 2. Define a Runtime Schema (Selector Config)
#    This defines which specific embeddings this particular input uses.
#    For example, this input uses discrete indices 0, 1, 2 and continuous indices 0, 1.

discrete_config = [[0, 1, 2]] # List of lists (for potentially multiple discrete groups)
continuous_config = [0, 1]    # List of indices
selector_config = (discrete_config, continuous_config)

# 3. Create inputs that match the schema
#    Discrete: local indices in [0, 3); the config maps each local index to its global embedding
#    Continuous: 2 values, one per selected continuous dimension

# (Batch, Seq) - values must be valid for the local schema size (3)
discrete_input = torch.randint(0, 3, (2, 32))

# (Batch, Seq, 2)
continuous_input = torch.randn(2, 32, 2)

# 4. Embed with the specific selector config

embeds = embed(
    (discrete_input, continuous_input),
    selector_config = selector_config
)

# 5. Readout with the same selector config

logits = readout(
    embeds,
    selector_config = selector_config
)

# logits will be a NamedTuple with .discrete and .continuous matching the config
print(logits.discrete.shape)   # (2, 32, 3) - matches discrete_config size
print(logits.continuous.shape) # (2, 32, 2, 2) - matches continuous_config size
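The selector mechanism amounts to an index indirection: local ids in an input are translated through the config into rows of a shared table. The minimal sketch below uses a single discrete group and assumes the readout scores only the selected rows of its own weight table. `embed_local`, `readout_local`, and both tables are hypothetical, not the library's internals.

```python
import random

# Hypothetical sketch of a runtime selector over a shared table of 10 discrete
# embeddings. Assumption: the readout scores only the selected rows.

DIM = 4
NUM_DISCRETE = 10
rng = random.Random(0)
embed_table = [[rng.gauss(0.0, 1.0) for _ in range(DIM)] for _ in range(NUM_DISCRETE)]
readout_table = [[rng.gauss(0.0, 1.0) for _ in range(DIM)] for _ in range(NUM_DISCRETE)]

def embed_local(local_ids, selector):
    # Local id i maps to global row selector[i]
    return [embed_table[selector[i]] for i in local_ids]

def readout_local(hidden, selector):
    # Dot product against the selected readout rows only: len(selector) logits
    return [sum(h * w for h, w in zip(hidden, readout_table[g])) for g in selector]

selector = [4, 7, 9]   # hypothetical: this input uses global embeddings 4, 7 and 9
local_ids = [0, 2, 1]  # valid local ids lie in [0, len(selector))
embeds = embed_local(local_ids, selector)
logits = readout_local(embeds[0], selector)
```

Because the readout only scores the selected rows, the logits have one entry per selected embedding, which matches the `(2, 32, 3)` discrete shape in the example above.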

Download files

Source Distribution

discrete_continuous_embed_readout-0.1.12.tar.gz (18.6 kB)

Built Distribution

discrete_continuous_embed_readout-0.1.12-py3-none-any.whl

File details

Details for the file discrete_continuous_embed_readout-0.1.12.tar.gz.

File hashes

Hashes for discrete_continuous_embed_readout-0.1.12.tar.gz
Algorithm Hash digest
SHA256 b0176307599fed14d1ea0fa997542bcea96cc32d80a7566a38b4cc3094a96654
MD5 4ad34c9780900ef9aa8428dbbb278a38
BLAKE2b-256 e1521c99548b1fa131026c50e4bd8c5947ad4beb506439be93e07ac3ad3b952a

File details

Details for the file discrete_continuous_embed_readout-0.1.12-py3-none-any.whl.

File hashes

Hashes for discrete_continuous_embed_readout-0.1.12-py3-none-any.whl
Algorithm Hash digest
SHA256 f704a703b2ea865495b3f0defd6ceffbd4a8603533ddad9fcd96eb512eeacd2d
MD5 9def7fcdc2f7ae165c016b36de450f37
BLAKE2b-256 6a68b75337fe80ffa7c8ef6c403a526013c26d4922a57bf5228fb89766d4ea25