Python library for RAD Embeddings, provably correct latent DFA representations.
This repo contains a JAX implementation of RAD embeddings. See the project webpage for more information.
Installation
Install using pip.
pip install rad-embeddings # CPU-only
pip install rad-embeddings[cuda] # With CUDA
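To confirm that the install succeeded, one option is to query the installed distribution's version with the standard library. The snippet below is a minimal sketch that relies only on `importlib.metadata`; the helper name `installed_version` is hypothetical and not part of rad-embeddings.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# "rad-embeddings" is the distribution name used in the pip commands above.
found = installed_version("rad-embeddings")
print(found if found is not None else "rad-embeddings is not installed")
```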
Usage
Instantiate a pretrained encoder.
from rad_embeddings import Encoder
encoder = Encoder() # Loads a pretrained DFA encoder with default parameters: handles at most 10-state DFAs with 10-token alphabets
Use DFAx to sample DFAs.
import jax
from dfax.samplers import ReachSampler, ReachAvoidSampler, RADSampler
sampler = RADSampler()
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
dfax = sampler.sample(subkey)
Pass the DFA to the encoder to get its embedding. Both DFAx and DFA objects are supported.
from dfax import dfax2dfa
rad_embed_from_dfax = encoder(dfax)
rad_embed_from_dfa = encoder(dfax2dfa(dfax)) # Returns the same embedding
Compute bisimulation distance between two DFA embeddings.
key, subkey = jax.random.split(key)
dfax_l = sampler.sample(subkey)
rad_l = encoder(dfax_l)
key, subkey = jax.random.split(key)
dfax_r = sampler.sample(subkey)
rad_r = encoder(dfax_r)
distance = encoder.distance(rad_l, rad_r)
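For intuition about what such a distance is meant to capture, it can help to compare against an exact check on small DFAs. The sketch below is plain Python and independent of rad_embeddings and DFAx. It assumes a DFA is given as a start state, a transition table indexed as `transitions[state][token]`, and a list of accepting labels, and it decides whether two DFAs accept the same language by exploring the reachable pairs of states. It is not the library's distance, only a reference notion of "equivalent"; the helper name `same_language` is hypothetical.

```python
from collections import deque


def same_language(start_l, trans_l, labels_l, start_r, trans_r, labels_r):
    """Return True iff both DFAs accept exactly the same token sequences."""
    n_tokens = len(trans_l[0])
    seen = {(start_l, start_r)}
    queue = deque(seen)
    while queue:
        s_l, s_r = queue.popleft()
        # A reachable pair with different labels witnesses a distinguishing word.
        if labels_l[s_l] != labels_r[s_r]:
            return False
        for token in range(n_tokens):
            nxt = (trans_l[s_l][token], trans_r[s_r][token])
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return True


# Two toy DFAs over a 2-token alphabet: "reach token 1", written with 2 and 3 states.
reach_a = (0, [[0, 1], [1, 1]], [False, True])
reach_b = (0, [[0, 2], [1, 1], [2, 2]], [False, False, True])
# "Reach token 0" accepts a different language.
reach_c = (0, [[1, 0], [1, 1]], [False, True])

print(same_language(*reach_a, *reach_b))  # True
print(same_language(*reach_a, *reach_c))  # False
```

Two DFAs that this check reports as equivalent describe the same task, even when their state counts differ.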
Solve a one-step bisimulation problem.
from dfax import DFAx
import jax.numpy as jnp
# Reach token 1 and then token 2 while avoiding token 9
dfa_l = DFAx.create(
    start=0,
    transitions=jnp.array([
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 3],
        [1, 1, 2, 1, 1, 1, 1, 1, 1, 3],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
    ]),
    labels=jnp.array([False, False, True, False, False])
)
# Reach token 9
dfa_r = DFAx.create(
    start=0,
    transitions=jnp.array([
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
    ]),
    labels=jnp.array([False, True, False, False, False])
)
distinguishing_action = encoder.solve(dfa_l, dfa_r) # Returns 1 as token 1 is the one-step distinguishing action
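To see how the two transition tables above encode their tasks, they can be run by hand on a few token sequences. The sketch below uses plain Python lists instead of DFAx and assumes that each row of a table is a state and each column is a token, so `transitions[state][token]` is the next state. The helper `accepts` is a hypothetical name for illustration, not part of DFAx or rad_embeddings.

```python
def accepts(start, transitions, labels, tokens):
    """Run a DFA on a token sequence and return the label of the final state."""
    state = start
    for token in tokens:
        state = transitions[state][token]
    return labels[state]


# Same tables as dfa_l and dfa_r above, copied as plain lists.
left_table = [
    [0, 1, 0, 0, 0, 0, 0, 0, 0, 3],
    [1, 1, 2, 1, 1, 1, 1, 1, 1, 3],
    [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
    [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
    [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
]
left_labels = [False, False, True, False, False]
right_table = [
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
    [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
    [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
]
right_labels = [False, True, False, False, False]

print(accepts(0, left_table, left_labels, [1, 2]))     # True: token 1, then token 2
print(accepts(0, left_table, left_labels, [1, 9, 2]))  # False: token 9 leads to the rejecting sink
print(accepts(0, left_table, left_labels, [2, 1]))     # False: token 2 came before token 1
print(accepts(0, right_table, right_labels, [9]))      # True: token 9 reached
print(accepts(0, right_table, right_labels, [1, 2]))   # False: token 9 never seen
```

Under this reading, states 3 and 4 of the right-hand table and state 4 of the left-hand table are unreachable from the start state; they only pad the tables to a common size.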
Train a new encoder. See encoder.py for the default arguments of the EncoderModule.train class method.
from rad_embeddings import EncoderModule
EncoderModule.train(max_size=5, n_tokens=5, debug=True, save_dir="my_storage")
Load your trained parameters.
from rad_embeddings import Encoder
encoder = Encoder(max_size=5, n_tokens=5, storage_dir="my_storage")
Citation
Please cite the following papers if you use RAD Embeddings in your work.
@inproceedings{DBLP:conf/nips/YalcinkayaLVS24,
  author    = {Beyazit Yalcinkaya and
               Niklas Lauffer and
               Marcell Vazquez{-}Chanlatte and
               Sanjit A. Seshia},
  title     = {Compositional Automata Embeddings for Goal-Conditioned Reinforcement
               Learning},
  booktitle = {NeurIPS},
  year      = {2024}
}
@inproceedings{DBLP:conf/neus/YalcinkayaLVS25,
  author    = {Beyazit Yalcinkaya and
               Niklas Lauffer and
               Marcell Vazquez{-}Chanlatte and
               Sanjit A. Seshia},
  title     = {Provably Correct Automata Embeddings for Optimal Automata-Conditioned
               Reinforcement Learning},
  booktitle = {NeuS},
  series    = {Proceedings of Machine Learning Research},
  volume    = {288},
  pages     = {661--675},
  publisher = {{PMLR}},
  year      = {2025}
}
File details
Details for the file rad_embeddings-0.2.8.tar.gz.
File metadata
- Download URL: rad_embeddings-0.2.8.tar.gz
- Upload date:
- Size: 17.0 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.5.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | d86bbdeb4caea3018718fd31a21282cae839678dc0d2b72792c53343de2fef31 |
| MD5 | 32dd6e2d40763e5d19342cf002de2683 |
| BLAKE2b-256 | 8fe9d661bc3c703cc86a1363d56a2d13f6cd9ed821fcfe6bf036dfb74cac3574 |
File details
Details for the file rad_embeddings-0.2.8-py3-none-any.whl.
File metadata
- Download URL: rad_embeddings-0.2.8-py3-none-any.whl
- Upload date:
- Size: 16.9 MB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: uv/0.5.18
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 10d06b1a607f8d134a82e67c5bf6b754b8bb664db83a53dab257410f6ef0d41e |
| MD5 | e1d8615e620122316443e3c4bb14c40c |
| BLAKE2b-256 | eb4e5324ce9b889844c95afdd85483ee0a77da9706fde0e99ac38f4bf7106b39 |