
Simple Flax implementations of long-range sequence models


Long Range Models


A collection of simple implementations of long-range sequence models, including LRU, S5, and S4. More implementations to come.
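LRU and S5 are both built around a diagonal linear recurrence, h_t = λ ⊙ h_{t-1} + B x_t. Because composing two steps of such a recurrence is associative, all time steps can be computed at once with a parallel scan. The sketch below shows only that core step in plain JAX. It is not this library's implementation: it assumes the input projection B x_t has already been applied, and it omits the output readout.

```python
import jax
import jax.numpy as jnp


def combine(left, right):
    """Compose two affine maps h -> a * h + b, applying `left` first."""
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r


def linear_recurrence(lam, u):
    """Compute h_t = lam * h_{t-1} + u_t for every t, with h_{-1} = 0.

    lam: (state_dim,) complex diagonal of the transition matrix.
    u:   (length, state_dim) complex inputs, already projected (B @ x_t).
    Returns the (length, state_dim) hidden states.
    """
    a = jnp.broadcast_to(lam, u.shape)  # same transition at every step
    # Prefix-compose the per-step maps in parallel; the `b` part of the
    # composed map applied to h_{-1} = 0 is exactly h_t.
    _, h = jax.lax.associative_scan(combine, (a, u))
    return h
```

The parallel scan performs the same computation as a step-by-step loop. It can finish in a logarithmic number of sequential stages instead of a linear number, which is what makes these models fast to train on long sequences.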

Install

$ pip install long_range_models

Usage

Every implemented module and layer is documented in detail. Models are built by composing these pieces: a top-level model wraps an architecture module, which in turn receives a sequence layer, as the examples below show.

Discrete sequence data

Consider a language model built with an LRU sequence layer and the architecture proposed in the S4 paper:

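from functools import partial
import jax.random as jrandom
from long_range_models import SequenceModel, S4Module, LRULayer

# Use separate PRNG keys for parameter initialization and for sampling data.
init_rng, data_rng = jrandom.split(jrandom.PRNGKey(0))

model = SequenceModel(
  num_tokens=1000,  # vocabulary size
  module=S4Module(
    sequence_layer=partial(LRULayer, state_dim=256),
    dim=128,
    depth=6,
  ),
)

# A batch of one sequence of 1024 token ids in [0, 1000).
x = jrandom.randint(data_rng, (1, 1024), 0, 1000)

variables = model.init(init_rng, x)
logits = model.apply(variables, x)  # shape (1, 1024, 1000)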
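These logits can be turned into a next-token training loss. The helper below is a hypothetical sketch and is not part of long_range_models. It assumes the model is causal, meaning position t only sees tokens up to t. That holds for a unidirectional recurrence like LRU, but check it against the S4Module documentation for your configuration.

```python
import jax
import jax.numpy as jnp


def next_token_loss(logits, tokens):
    """Mean cross-entropy for predicting tokens[:, t + 1] from logits[:, t].

    logits: (batch, length, num_tokens), e.g. the output of model.apply.
    tokens: (batch, length) integer token ids that were fed to the model.
    """
    log_probs = jax.nn.log_softmax(logits[:, :-1], axis=-1)  # drop the last step
    targets = tokens[:, 1:]                                   # shift left by one
    picked = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)
    return -jnp.mean(picked)
```

With the example above, this would be called as `next_token_loss(logits, x)`. Uniform logits give a loss of log(num_tokens), which is a useful sanity check at initialization.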

Continuous sequence data

For sequences with continuous values, the setup looks as follows:

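from functools import partial
import jax.random as jrandom
from long_range_models import ContinuousSequenceModel, S4Module, LRULayer

# Use separate PRNG keys for parameter initialization and for sampling data.
init_rng, data_rng = jrandom.split(jrandom.PRNGKey(0))

model = ContinuousSequenceModel(
  out_dim=10,  # output features per time step
  module=S4Module(
    sequence_layer=partial(LRULayer, state_dim=256),
    dim=128,
    depth=6,
  ),
)

# A batch of one sequence: 1024 time steps with 32 features each.
x = jrandom.normal(data_rng, (1, 1024, 32))

variables = model.init(init_rng, x)
outputs = model.apply(variables, x)  # shape (1, 1024, 10)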
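To fit the continuous model to targets, its output can be fed into a regression loss and differentiated with jax.value_and_grad. The functions below are a hypothetical sketch. They use plain gradient descent rather than an optimizer library such as Optax, and they assume every leaf in `variables` is a trainable parameter.

```python
import jax
import jax.numpy as jnp


def mse_loss(params, apply_fn, x, targets):
    """Mean squared error between apply_fn(params, x) and targets."""
    preds = apply_fn(params, x)  # e.g. (batch, length, out_dim)
    return jnp.mean((preds - targets) ** 2)


def sgd_step(params, apply_fn, x, targets, lr=1e-3):
    """One plain gradient-descent step; returns (new_params, loss_before_step)."""
    # Only `params` (argument 0) is differentiated; apply_fn is passed through.
    loss, grads = jax.value_and_grad(mse_loss)(params, apply_fn, x, targets)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss
```

With the example above, one step would be `variables, loss = sgd_step(variables, model.apply, x, targets)`, where `targets` is a hypothetical array of shape (1, 1024, 10).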

Note: both model types offer several customization options. Make sure to check out their documentation.

Upcoming features

  • More implementations: Extend the library with models like S4D, S4Liquid, BiGS, Hyena, RetNet, SGConv, H3, and others.
  • Customization: Allow users to better customize the currently implemented layers and architectures (e.g., activation functions and initialization).
  • Sequential API: Allow recurrent models to run step by step, enabling efficient inference.
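For recurrent layers, a sequential mode carries a fixed-size state from one step to the next. Each new input therefore costs the same amount of work, no matter how long the sequence already is. The sketch below illustrates this for a generic diagonal linear recurrence of the kind LRU uses. It is not an existing API of this library (the feature is listed as upcoming), and both function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def recurrent_step(lam, state, u_t):
    """Advance h_t = lam * h_{t-1} + u_t by a single step."""
    new_state = lam * state + u_t
    return new_state, new_state  # (carry, per-step output)


def run_sequentially(lam, u):
    """Process a (length, state_dim) input one step at a time with lax.scan."""
    init = jnp.zeros(u.shape[-1], dtype=u.dtype)
    final_state, states = jax.lax.scan(
        lambda h, u_t: recurrent_step(lam, h, u_t), init, u
    )
    return final_state, states
```

Because only `final_state` needs to be kept, a prefix can be processed once and new inputs streamed in afterwards with `recurrent_step`. This matches running the whole sequence in one pass.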



Download files

Source distribution: long_range_models-0.0.1.tar.gz (9.5 kB)
Built distribution: long_range_models-0.0.1-py3-none-any.whl (10.9 kB, Python 3)
