Simple Flax implementations of long-range sequence models
Long Range Models
A collection of simple implementations of long-range sequence models, including LRU, S5, and S4. More implementations to come.
Install
$ pip install long_range_models
Usage
Every module and layer in this library is documented. A model is built by composing pieces: a sequence layer (such as LRULayer), an architecture module (such as S4Module), and a model wrapper that matches the input type. Check out the examples below.
Discrete sequence data
Consider a language model built with an LRU sequence layer and the architecture proposed in the S4 paper:
from functools import partial
import jax.random as jrandom
from long_range_models import SequenceModel, S4Module, LRULayer

rng = jrandom.PRNGKey(0)

model = SequenceModel(
    num_tokens=1000,
    module=S4Module(
        sequence_layer=partial(LRULayer, state_dim=256),
        dim=128,
        depth=6,
    ),
)

x = jrandom.randint(rng, (1, 1024), 0, 1000)
variables = model.init(rng, x)
model.apply(variables, x)  # (1, 1024, 1000)
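The output above has one score per vocabulary entry at every position. As a minimal sketch of how such an output could be turned into a training signal, the following computes a next-token cross-entropy loss in plain JAX. It assumes the model returns unnormalized logits and is used causally, neither of which is confirmed by this page; `next_token_loss` is a hypothetical helper, not part of the library, and random logits stand in for the model output so the snippet runs on its own.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def next_token_loss(logits, tokens):
    """Mean cross-entropy of predicting tokens[:, t + 1] from logits[:, t]."""
    # The last position has no target, and the first token is never a target.
    log_probs = jax.nn.log_softmax(logits[:, :-1], axis=-1)  # (batch, len - 1, vocab)
    targets = tokens[:, 1:]  # (batch, len - 1)
    # Select the log-probability assigned to each target token.
    picked = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)
    return -picked.mean()


# Stand-in data with the shapes used in the example; in practice `logits`
# would be the result of `model.apply(variables, x)` (assumed to be logits).
rng_x, rng_logits = jrandom.split(jrandom.PRNGKey(0))
x = jrandom.randint(rng_x, (1, 1024), 0, 1000)
logits = jrandom.normal(rng_logits, (1, 1024, 1000))
loss = next_token_loss(logits, x)
```

With a real model, the same function could be wrapped in `jax.value_and_grad` over `variables` to obtain gradients.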
Continuous sequence data
For sequences with continuous values, the setup looks as follows:
from functools import partial
import jax.random as jrandom
from long_range_models import ContinuousSequenceModel, S4Module, LRULayer

rng = jrandom.PRNGKey(0)

model = ContinuousSequenceModel(
    out_dim=10,
    module=S4Module(
        sequence_layer=partial(LRULayer, state_dim=256),
        dim=128,
        depth=6,
    ),
)

x = jrandom.normal(rng, (1, 1024, 32))
variables = model.init(rng, x)
model.apply(variables, x)  # (1, 1024, 10)
Note: both model types offer several customization options. Make sure to check out their documentation.
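Both examples pass the sequence layer as `partial(LRULayer, state_dim=256)` rather than as a constructed layer. One plausible reason, sketched below with standard-library code only, is that the enclosing module can then call the factory once per block and supply arguments that only it knows. `ToyLayer` and `ToyModule` are hypothetical stand-ins invented for this illustration; they do not describe the library's internals, and whether `LRULayer` actually receives a `dim` argument this way is an assumption.

```python
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, List


@dataclass
class ToyLayer:
    """Hypothetical stand-in for a sequence layer such as LRULayer."""
    state_dim: int
    dim: int  # supplied later by the enclosing module (assumption)


@dataclass
class ToyModule:
    """Hypothetical stand-in for an architecture such as S4Module."""
    sequence_layer: Callable[..., ToyLayer]  # a factory, not an instance
    dim: int
    depth: int
    layers: List[ToyLayer] = field(init=False)

    def __post_init__(self):
        # Call the factory once per block so that every block owns its own
        # layer, filling in the argument the user did not bind (`dim`).
        self.layers = [self.sequence_layer(dim=self.dim) for _ in range(self.depth)]


module = ToyModule(sequence_layer=partial(ToyLayer, state_dim=256), dim=128, depth=6)
```

Under this pattern, swapping in a different sequence layer only changes the `partial(...)` argument; the architecture code stays the same.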
Upcoming features
- More implementations: Extend the library with models like S4D, S4Liquid, BiGS, Hyena, RetNet, SGConv, H3, and others.
- Customization: Allow users to better customize currently implemented layers and architectures (e.g., activation functions, initialization, etc.).
- Sequential API: Let recurrent models run one step at a time, enabling efficient inference.
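To illustrate what a sequential mode means for a recurrent layer, here is a minimal sketch of a diagonal linear recurrence, h_t = a * h_{t-1} + b_t, computed two ways in JAX: all at once with a parallel scan, and one step at a time while carrying the state. This is a simplified, real-valued illustration and not the library's implementation (the LRU paper uses a complex diagonal); the function names are hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def parallel_recurrence(a, b):
    """All states of h_t = a * h_{t-1} + b_t (zero initial state), via a parallel scan."""
    # Each step is the affine map h -> a_t * h + b_t; composing such maps is associative.
    def combine(left, right):
        a_l, b_l = left
        a_r, b_r = right
        return a_l * a_r, a_r * b_l + b_r

    a_seq = jnp.broadcast_to(a, b.shape)  # same decay at every time step
    _, h = jax.lax.associative_scan(combine, (a_seq, b))
    return h


def sequential_recurrence(a, b):
    """The same states, computed one step at a time while carrying the state."""
    def step(h, b_t):
        h = a * h + b_t
        return h, h

    _, h = jax.lax.scan(step, jnp.zeros_like(a), b)
    return h


key_a, key_b = jrandom.split(jrandom.PRNGKey(0))
a = jrandom.uniform(key_a, (4,), minval=0.5, maxval=0.99)  # per-channel decay
b = jrandom.normal(key_b, (64, 4))  # (length, state_dim) inputs
h_par = parallel_recurrence(a, b)
h_seq = sequential_recurrence(a, b)
```

The parallel form suits training on full sequences, while the `step` function is what autoregressive inference would need: constant work per new input, with only the current state kept in memory.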
Download files
File details
Details for the file long_range_models-0.0.1.tar.gz.
File metadata
- Download URL: long_range_models-0.0.1.tar.gz
- Upload date:
- Size: 9.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.18
File hashes
Algorithm | Hash digest
---|---
SHA256 | 8e58431831f8c08ea4d6c7f75a6fccd1274cdb046f28568787535ec140904f42
MD5 | 3c38eab5b1d81ebca87865c2e4f2d16e
BLAKE2b-256 | f0c909c3991abdad63eae9648a07c04bfa9fcf89f223670e9215a95d214cb602
File details
Details for the file long_range_models-0.0.1-py3-none-any.whl.
File metadata
- Download URL: long_range_models-0.0.1-py3-none-any.whl
- Upload date:
- Size: 10.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.18
File hashes
Algorithm | Hash digest
---|---
SHA256 | 0a181dc51f44b79e6d375e5d758c650ae4660ab90b9300f053d19ebe2f07e0d6
MD5 | c90f6b9c4230140ba6a9f18b73a88dfc
BLAKE2b-256 | c2ffd0ab4547344bb3d6fcf90cad0e96a1c820d421ea0feb9b7afd288ff0b493