Skip to main content

Discrete Continuous Embed Readout

Project description

Discrete Continuous Embed Readout

Embedding and readout for simple categorical and gaussian distributions, from the language model to sophisticated robotic action spaces

Install

pip install discrete-continuous-embed-readout

Usage

Discrete

For standard autoregressive language modeling or discrete action spaces.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 20000              # vocabulary size
)

embed, readout = embed_readout

# 2. Embed

ids = torch.randint(0, 20000, (2, 1024))

embeds = embed(ids) # (2, 1024, 512)

# ... pass through your transformer / network ...

# 3. Readout

logits = readout(embeds) # (2, 1024, 20000)

# Calculate loss (automatically handles cross entropy)

labels = torch.randint(0, 20000, (2, 1024))

loss = readout(embeds, labels, return_loss = True)
loss.backward()

# Sampling and other utilities

sampled = readout.sample(logits)                # (2, 1024)
log_probs = readout.log_prob(logits, sampled)   # (2, 1024)
entropy = readout.entropy(logits)               # (2, 1024)

Continuous

For continuous control or regression tasks.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize

embed_readout = EmbedAndReadout(
    dim = 512,
    num_continuous = 4,              # 4 continuous dimensions
    continuous_mean_std = torch.ones(4, 2) # optional mean and std for normalization
)

embed, readout = embed_readout

# 2. Embed

values = torch.randn(2, 1024, 4)

embeds = embed(values) # (2, 1024, 512)

# ... pass through network ...

# 3. Readout (returns distinct Gaussian parameters)

dist_params = readout(embeds) # (2, 1024, 4, 2) - mean and log var

# Loss (Gaussian NLL)

targets = torch.randn(2, 1024, 4)

loss = readout(embeds, targets, return_loss = True)
loss.backward()

# Sampling

sampled = readout.sample(dist_params)               # (2, 1024, 4)

Mixed Discrete and Continuous

For complex environments with both discrete and continuous action spaces.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 100,
    num_continuous = 4
)

embed, readout = embed_readout

# 2. Embed inputs (passed as tuple)

discrete_in = torch.randint(0, 100, (2, 32))
continuous_in = torch.randn(2, 32, 4)

embeds = embed((discrete_in, continuous_in)) # (2, 32, 512)

# ... network ...

# 3. Readout

output = readout(embeds)

# Access individual logits/params
print(output.discrete.shape)   # (2, 32, 100)
print(output.continuous.shape) # (2, 32, 4, 2)

# Sampling returns tuple

sampled_discrete, sampled_continuous = readout.sample(output)

Multi-Discrete

For action spaces with multiple independent discrete actions.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = (10, 5, 8),    # 3 independent discrete actions
    use_parallel_multi_discrete = True # optimized parallel processing
)

embed, readout = embed_readout

# Input shape: (batch, seq, 3)
action_indices = torch.randint(0, 5, (2, 16, 3))

embeds = embed(action_indices)

# Readout returns list of logits if not using parallel optimization, or a special structure if so.
# However, the wrapper handles it seamlessly.

logits = readout(embeds)
sampled = readout.sample(logits) # (2, 16, 3)

Runtime Selectors

You can also define inputs dynamically at runtime if your architecture shares embeddings across different modalities.

import torch
from discrete_continuous_embed_readout import EmbedAndReadout

# 1. Initialize with the total capacity of the system
#    e.g. 10 discrete embeddings total, 5 continuous dimensions total

embed_readout = EmbedAndReadout(
    dim = 512,
    num_discrete = 10,
    num_continuous = 5,
    continuous_mean_std = torch.ones(5, 2) # normalization for continuous
)

embed, readout = embed_readout

# 2. Define a Runtime Schema (Selector Config)
#    This defines which specific embeddings this particular input uses.
#    For example, this input uses discrete indices 0, 1, 2 and continuous indices 0, 1.

discrete_config = [[0, 1, 2]] # List of lists (for potentially multiple discrete groups)
continuous_config = [0, 1]    # List of indices
selector_config = (discrete_config, continuous_config)

# 3. Create Inputs that match the schema
#    Discrete: 3 values (ranges matching the config is handled by index looking up the config)
#    Continuous: 2 values

# (Batch, Seq) - values must be valid for the local schema size (3)
discrete_input = torch.randint(0, 3, (2, 32))

# (Batch, Seq, 2)
continuous_input = torch.randn(2, 32, 2)

# 4. Embed with the specific selector config

embeds = embed(
    (discrete_input, continuous_input),
    selector_config = selector_config
)

# 5. Readout with the same selector config

logits = readout(
    embeds,
    selector_config = selector_config
)

# logits will be a NamedTuple with .discrete and .continuous matching the config
print(logits.discrete.shape)   # (2, 32, 3) - matches discrete_config size
print(logits.continuous.shape) # (2, 32, 2, 2) - matches continuous_config size

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

discrete_continuous_embed_readout-0.1.11.tar.gz (18.4 kB view details)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

File details

Details for the file discrete_continuous_embed_readout-0.1.11.tar.gz.

File metadata

File hashes

Hashes for discrete_continuous_embed_readout-0.1.11.tar.gz
Algorithm Hash digest
SHA256 b905180e50d03dd21dce64119e716dafa07d272df467652328262d26a63ccecc
MD5 91fdb1cd2a85f4027dda12bc14782a24
BLAKE2b-256 fa456411244a6c124859dec9182a8a77a67698da9cd097c485692cbccb937faa

See more details on using hashes here.

File details

Details for the file discrete_continuous_embed_readout-0.1.11-py3-none-any.whl.

File metadata

File hashes

Hashes for discrete_continuous_embed_readout-0.1.11-py3-none-any.whl
Algorithm Hash digest
SHA256 d16e04fd4919084df53bd2dc8d7cb81c0a92e0772d97d330c070b2504a1dd97a
MD5 73811fe2486668a0b6d14f6f05f40407
BLAKE2b-256 bcb96f876405a41fe3ef7e314c5cbbbae9540abb735ceb26618e15608effe59b

See more details on using hashes here.

Supported by

AWS Cloud computing and Security Sponsor Datadog Monitoring Depot Continuous Integration Fastly CDN Google Download Analytics Pingdom Monitoring Sentry Error logging StatusPage Status page