Skip to main content

Discrete Continuous Embed Readout

Project description

Discrete Continuous Embed Readout

Embedding and readout for simple categorical and gaussian distributions, from the language model to sophisticated robotic action spaces

Install

pip install discrete-continuous-embed-readout

Usage

Discrete

For standard autoregressive language modeling or discrete action spaces.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 20000              # vocabulary size
)

embed, readout = embed_readout

# 2. Embed

ids = torch.randint(0, 20000, (2, 1024))

embeds = embed(ids) # (2, 1024, 512)

# ... pass through your transformer / network ...

# 3. Readout

logits = readout(embeds) # (2, 1024, 20000)

# Calculate loss (automatically handles cross entropy)

labels = torch.randint(0, 20000, (2, 1024))

loss = readout(embeds, labels, return_loss = True)
loss.backward()

# Sampling and other utilities

sampled = readout.sample(logits)                # (2, 1024)
log_probs = readout.log_prob(logits, sampled)   # (2, 1024)
entropy = readout.entropy(logits)               # (2, 1024)

Continuous

For continuous control or regression tasks.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize

embed_readout = EmbedAndReadout(
    dim = 512,
    num_continuous = 4,              # 4 continuous dimensions
    continuous_mean_std = torch.ones(4, 2) # optional mean and std for normalization
)

embed, readout = embed_readout

# 2. Embed

values = torch.randn(2, 1024, 4)

embeds = embed(values) # (2, 1024, 512)

# ... pass through network ...

# 3. Readout (returns distinct Gaussian parameters)

dist_params = readout(embeds) # (2, 1024, 4, 2) - mean and log var

# Loss (Gaussian NLL)

targets = torch.randn(2, 1024, 4)

loss = readout(embeds, targets, return_loss = True)
loss.backward()

# Sampling

sampled = readout.sample(dist_params)               # (2, 1024, 4)

Mixed Discrete and Continuous

For complex environments with both discrete and continuous action spaces.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 100,
    num_continuous = 4
)

embed, readout = embed_readout

# 2. Embed inputs (passed as tuple)

discrete_in = torch.randint(0, 100, (2, 32))
continuous_in = torch.randn(2, 32, 4)

embeds = embed((discrete_in, continuous_in)) # (2, 32, 512)

# ... network ...

# 3. Readout

output = readout(embeds)

# Access individual logits/params
print(output.discrete.shape)   # (2, 32, 100)
print(output.continuous.shape) # (2, 32, 4, 2)

# Sampling returns tuple

sampled_discrete, sampled_continuous = readout.sample(output)

Multi-Discrete

For action spaces with multiple independent discrete actions.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = (10, 5, 8),    # 3 independent discrete actions
    use_parallel_multi_discrete = True # optimized parallel processing
)

embed, readout = embed_readout

# Input shape: (batch, seq, 3)
action_indices = torch.randint(0, 5, (2, 16, 3))

embeds = embed(action_indices)

# Readout returns list of logits if not using parallel optimization, or a special structure if so.
# However, the wrapper handles it seamlessly.

logits = readout(embeds)
sampled = readout.sample(logits) # (2, 16, 3)

Runtime Selectors

You can also define inputs dynamically at runtime if your architecture shares embeddings across different modalities.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize with the total capacity of the system
#    e.g. 10 discrete embeddings total, 5 continuous dimensions total

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 10,
    num_continuous = 5,
    continuous_mean_std = torch.ones(5, 2) # normalization for continuous
)

embed, readout = embed_readout

# 2. Define a Runtime Schema (Selector Config)
#    This defines which specific embeddings this particular input uses.
#    For example, this input uses discrete indices 0, 1, 2 and continuous indices 0, 1.

discrete_config = [[0, 1, 2]] # List of lists (for potentially multiple discrete groups)
continuous_config = [0, 1]    # List of indices
selector_config = (discrete_config, continuous_config)

# 3. Create Inputs that match the schema
#    Discrete: 3 values (ranges matching the config is handled by index looking up the config)
#    Continuous: 2 values

# (Batch, Seq) - values must be valid for the local schema size (3)
discrete_input = torch.randint(0, 3, (2, 32))

# (Batch, Seq, 2)
continuous_input = torch.randn(2, 32, 2)

# 4. Embed with the specific selector config

embeds = embed(
    (discrete_input, continuous_input),
    selector_config = selector_config
)

# 5. Readout with the same selector config

logits = readout(
    embeds,
    selector_config = selector_config
)

# logits will be a NamedTuple with .discrete and .continuous matching the config
print(logits.discrete.shape)   # (2, 32, 3) - matches discrete_config size
print(logits.continuous.shape) # (2, 32, 2, 2) - matches continuous_config size

Citations

@article{lakshminarayanan2017simple,
    title     = {Simple and scalable predictive uncertainty estimation using deep ensembles},
    author    = {Lakshminarayanan, Balaji and Pritzel, Alexander and Blundell, Charles},
    journal   = {Advances in neural information processing systems},
    volume    = {30},
    year      = {2017}
}

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

discrete_continuous_embed_readout-0.2.3.tar.gz (15.3 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

File details

Details for the file discrete_continuous_embed_readout-0.2.3.tar.gz.

File metadata

File hashes

Hashes for discrete_continuous_embed_readout-0.2.3.tar.gz
Algorithm Hash digest
SHA256 d3473de02321ef4911ce457a7bbc96f6466ec0a37f808467e5a541c3d8fc2064
MD5 39c2d2bd52f748cfee766895834f4215
BLAKE2b-256 c1524181b36002220fff1929d4251ad9987770fafea2ccaf978550bdf67a577b

See more details on using hashes here.

File details

Details for the file discrete_continuous_embed_readout-0.2.3-py3-none-any.whl.

File metadata

File hashes

Hashes for discrete_continuous_embed_readout-0.2.3-py3-none-any.whl
Algorithm Hash digest
SHA256 90d306ec2c75c4f91eb087c7f04897da80a42a21a2dad2d23e38ce871638eff1
MD5 96386f1514d3916735d5d4f22d159f39
BLAKE2b-256 888a8051bf00026e60dc14e6b6125017174baef5a3d61574ace5cdbb9a18b745

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page