Discrete Continuous Embed Readout

Embedding and readout for simple categorical and Gaussian distributions, from language models to sophisticated robotic action spaces.

Install

pip install discrete-continuous-embed-readout

Usage

Discrete

For standard autoregressive language modeling or discrete action spaces.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 20000              # vocabulary size
)

embed, readout = embed_readout

# 2. Embed

ids = torch.randint(0, 20000, (2, 1024))

embeds = embed(ids) # (2, 1024, 512)

# ... pass through your transformer / network ...

# 3. Readout

logits = readout(embeds) # (2, 1024, 20000)

# Calculate loss (automatically handles cross entropy)

labels = torch.randint(0, 20000, (2, 1024))

loss = readout(embeds, labels, return_loss = True)
loss.backward()

# Sampling and other utilities

sampled = readout.sample(logits)                # (2, 1024)
log_probs = readout.log_prob(logits, sampled)   # (2, 1024)
entropy = readout.entropy(logits)               # (2, 1024)
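
The loss, log-probability, and entropy calls above correspond to standard categorical quantities computed from logits. As a rough illustration only, here is a dependency-free sketch for a single position in plain Python. The helper names (`log_softmax`, `cross_entropy`, `entropy`) are hypothetical and are not part of this package; the library's own implementation may differ.

```python
import math

def log_softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def cross_entropy(logits, label):
    # Negative log-probability assigned to the correct class.
    return -log_softmax(logits)[label]

def entropy(logits):
    # H = -sum_i p_i * log p_i, with p = softmax(logits).
    log_p = log_softmax(logits)
    return -sum(math.exp(lp) * lp for lp in log_p)

# Hypothetical logits for a 3-way choice at one position.
logits = [2.0, 0.5, -1.0]
loss = cross_entropy(logits, 0)

# A uniform distribution over 4 classes has entropy log(4).
uniform_entropy = entropy([0.0, 0.0, 0.0, 0.0])
```

Averaging `cross_entropy` over all positions gives the usual language modeling loss.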

Continuous

For continuous control or regression tasks.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize

embed_readout = EmbedAndReadout(
    dim = 512,
    num_continuous = 4,              # 4 continuous dimensions
    continuous_mean_std = torch.ones(4, 2) # optional mean and std for normalization
)

embed, readout = embed_readout

# 2. Embed

values = torch.randn(2, 1024, 4)

embeds = embed(values) # (2, 1024, 512)

# ... pass through network ...

# 3. Readout (returns distinct Gaussian parameters)

dist_params = readout(embeds) # (2, 1024, 4, 2) - mean and log var

# Loss (Gaussian NLL)

targets = torch.randn(2, 1024, 4)

loss = readout(embeds, targets, return_loss = True)
loss.backward()

# Sampling

sampled = readout.sample(dist_params)               # (2, 1024, 4)
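
The continuous readout above produces a mean and a log variance per dimension. As a minimal sketch of what a Gaussian negative log-likelihood and a sample look like under that parameterization, the following plain Python works on one scalar dimension. The function names are hypothetical, and whether the library keeps the constant term or normalizes exactly as `(value - mean) / std` is an assumption, not something the source confirms.

```python
import math
import random

def gaussian_nll(mean, log_var, target):
    # -log N(target; mean, exp(log_var)), including the constant term.
    return 0.5 * (math.log(2 * math.pi) + log_var
                  + (target - mean) ** 2 / math.exp(log_var))

def sample_gaussian(mean, log_var, rng):
    # Draw mean + std * eps with eps ~ N(0, 1).
    std = math.exp(0.5 * log_var)
    return mean + std * rng.gauss(0.0, 1.0)

def normalize(value, mean, std):
    # Assumed form of the normalization implied by a (mean, std) pair.
    return (value - mean) / std

rng = random.Random(0)
nll = gaussian_nll(mean=0.0, log_var=0.0, target=0.0)
draw = sample_gaussian(mean=1.5, log_var=0.0, rng=rng)
```

Predicting the log variance rather than the variance keeps the variance positive without any constraint on the network output.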

Mixed Discrete and Continuous

For complex environments with both discrete and continuous action spaces.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 100,
    num_continuous = 4
)

embed, readout = embed_readout

# 2. Embed inputs (passed as tuple)

discrete_in = torch.randint(0, 100, (2, 32))
continuous_in = torch.randn(2, 32, 4)

embeds = embed((discrete_in, continuous_in)) # (2, 32, 512)

# ... network ...

# 3. Readout

output = readout(embeds)

# Access individual logits/params
print(output.discrete.shape)   # (2, 32, 100)
print(output.continuous.shape) # (2, 32, 4, 2)

# Sampling returns tuple

sampled_discrete, sampled_continuous = readout.sample(output)
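
For a mixed action, one common way to score a sampled pair is to add the categorical log-probability to the per-dimension Gaussian log-probabilities. The sketch below assumes the discrete choice and each continuous dimension are independent given the embedding; that independence, and all function names, are assumptions for illustration rather than the package's documented behavior.

```python
import math

def categorical_log_prob(logits, index):
    # log softmax(logits)[index], computed stably.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[index] - log_z

def gaussian_log_prob(mean, log_var, value):
    # log N(value; mean, exp(log_var)).
    return -0.5 * (math.log(2 * math.pi) + log_var
                   + (value - mean) ** 2 / math.exp(log_var))

def mixed_log_prob(logits, gaussian_params, discrete_action, continuous_action):
    # Sum of independent terms: one categorical, one Gaussian per dimension.
    total = categorical_log_prob(logits, discrete_action)
    for (mean, log_var), value in zip(gaussian_params, continuous_action):
        total += gaussian_log_prob(mean, log_var, value)
    return total

# Hypothetical readout for one position: 2 discrete choices, 1 continuous dim.
joint = mixed_log_prob([0.0, 0.0], [(0.0, 0.0)], 0, [0.0])
```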

Multi-Discrete

For action spaces with multiple independent discrete actions.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = (10, 5, 8),    # 3 independent discrete actions
    use_parallel_multi_discrete = True # optimized parallel processing
)

embed, readout = embed_readout

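# Input shape: (batch, seq, 3), one index per action
# Indices below 5 are valid for all three actions, since the smallest action has 5 choices
action_indices = torch.randint(0, 5, (2, 16, 3))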

embeds = embed(action_indices)

# Without the parallel optimization, readout returns a list of logits, one per action;
# with it, readout returns a special structure. readout.sample accepts either form.

logits = readout(embeds)
sampled = readout.sample(logits) # (2, 16, 3)
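
The source does not describe how `use_parallel_multi_discrete` works internally. One common way to process several independent discrete actions in parallel is to pack them into a single shared table and shift each action's indices by an offset. The sketch below shows only that general idea; `to_flat_indices` and `split_logits` are hypothetical helpers, not this package's API.

```python
from itertools import accumulate

num_discrete = (10, 5, 8)

# Exclusive prefix sum: the starting row of each action's block in a shared table.
offsets = [0, *accumulate(num_discrete)][:-1]

def to_flat_indices(action):
    # Shift each per-action index into its own block of the shared table.
    return [index + offset for index, offset in zip(action, offsets)]

def split_logits(flat_logits):
    # Slice a flat vector of sum(num_discrete) logits back into per-action groups.
    return [flat_logits[off:off + n] for off, n in zip(offsets, num_discrete)]

flat = to_flat_indices([3, 4, 7])
groups = split_logits(list(range(sum(num_discrete))))
```

With this layout a single lookup and a single projection cover all three actions, and the per-action softmax is applied to each slice separately.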

Runtime Selectors

You can also define inputs dynamically at runtime if your architecture shares embeddings across different modalities.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize with the total capacity of the system
#    e.g. 10 discrete embeddings total, 5 continuous dimensions total

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 10,
    num_continuous = 5,
    continuous_mean_std = torch.ones(5, 2) # normalization for continuous
)

embed, readout = embed_readout

# 2. Define a Runtime Schema (Selector Config)
#    This defines which specific embeddings this particular input uses.
#    For example, this input uses discrete indices 0, 1, 2 and continuous indices 0, 1.

discrete_config = [[0, 1, 2]] # List of lists (for potentially multiple discrete groups)
continuous_config = [0, 1]    # List of indices
selector_config = (discrete_config, continuous_config)

# 3. Create Inputs that match the schema
#    Discrete: 3 possible values (each local index is mapped to its embedding by lookup in the config)
#    Continuous: 2 values

# (Batch, Seq) - values must be valid for the local schema size (3)
discrete_input = torch.randint(0, 3, (2, 32))

# (Batch, Seq, 2)
continuous_input = torch.randn(2, 32, 2)

# 4. Embed with the specific selector config

embeds = embed(
    (discrete_input, continuous_input),
    selector_config = selector_config
)

# 5. Readout with the same selector config

logits = readout(
    embeds,
    selector_config = selector_config
)

# logits will be a NamedTuple with .discrete and .continuous matching the config
print(logits.discrete.shape)   # (2, 32, 3) - matches discrete_config size
print(logits.continuous.shape) # (2, 32, 2, 2) - matches continuous_config size
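
One way to read the selector config is as a mapping from local indices (positions within the config) to global rows of the shared embeddings. The sketch below illustrates that interpretation in plain Python, using a different group than the example above so the remapping is visible. The helper names and the toy table are hypothetical, and the library's actual lookup may differ.

```python
def local_to_global(local_index, group):
    # Local index i refers to the i-th entry listed in the group.
    return group[local_index]

def select_rows(table, indices):
    # Pick the configured rows out of a shared table.
    return [table[i] for i in indices]

# Hypothetical shared table with 10 rows, each a 2-dimensional embedding.
table = [[float(i), float(i)] for i in range(10)]

# Hypothetical discrete group that uses global rows 4, 7 and 9.
group = [4, 7, 9]

global_index = local_to_global(1, group)
sub_table = select_rows(table, group)
```

Under this reading, an input in the range 0 to 2 never needs to know which global rows it occupies; only the config does.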
