Simple Flax implementations of long-range sequence models

Long Range Models

A collection of simple implementations of long-range sequence models, including LRU, S5, and S4. More implementations to come.

Install

$ pip install long_range_models

Usage

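Every module and layer in this library is documented. A model is built by composing three pieces: a sequence layer (e.g., LRULayer), a module that defines the architecture (e.g., S4Module), and a model wrapper (SequenceModel or ContinuousSequenceModel). The examples below show both wrappers.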

Discrete sequence data

Consider a language model built with an LRU sequence layer and the architecture proposed in the S4 paper:

from functools import partial
import jax.random as jrandom
from long_range_models import SequenceModel, S4Module, LRULayer

# Use separate keys for the dummy input and for parameter initialization
data_rng, init_rng = jrandom.split(jrandom.PRNGKey(0))

model = SequenceModel(
  num_tokens=1000,
  module=S4Module(
    sequence_layer=partial(LRULayer, state_dim=256),
    dim=128,
    depth=6,
  ),
)

# One sequence of 1024 token ids drawn from [0, 1000)
x = jrandom.randint(data_rng, (1, 1024), 0, 1000)

variables = model.init(init_rng, x)
model.apply(variables, x)  # (1, 1024, 1000)
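
To train such a language model, the per-position outputs are typically compared against the input shifted by one token. The following is a minimal sketch, assuming the model is causal and that its output of shape (1, 1024, 1000) holds unnormalized logits over the vocabulary. The helper next_token_loss is hypothetical and not part of long_range_models, and random logits stand in for the result of model.apply so the snippet runs on its own.

```python
import jax
import jax.numpy as jnp

def next_token_loss(logits, tokens):
    """Mean cross-entropy of predicting tokens[:, t + 1] from logits[:, t].

    Hypothetical helper (not part of long_range_models).
    logits: (batch, length, vocab), tokens: (batch, length) integer ids.
    """
    # The last position has no target and the first token has no context.
    log_probs = jax.nn.log_softmax(logits[:, :-1], axis=-1)  # (batch, length-1, vocab)
    targets = tokens[:, 1:]                                   # (batch, length-1)
    # Select the log-probability assigned to each target token.
    picked = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
    return -picked.mean()

# Stand-in data with the same shapes as the example above.
key_tokens, key_logits = jax.random.split(jax.random.PRNGKey(0))
tokens = jax.random.randint(key_tokens, (1, 1024), 0, 1000)
# Assumption: in real use this would be model.apply(variables, tokens).
logits = jax.random.normal(key_logits, (1, 1024, 1000))

loss = next_token_loss(logits, tokens)
```

A scalar loss like this can be differentiated with respect to the parameters with jax.grad once the placeholder logits are replaced by the model call.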

Continuous sequence data

For sequences with continuous values, the setup looks as follows:

from functools import partial
import jax.random as jrandom
from long_range_models import ContinuousSequenceModel, S4Module, LRULayer

# Use separate keys for the dummy input and for parameter initialization
data_rng, init_rng = jrandom.split(jrandom.PRNGKey(0))

model = ContinuousSequenceModel(
  out_dim=10,
  module=S4Module(
    sequence_layer=partial(LRULayer, state_dim=256),
    dim=128,
    depth=6,
  ),
)

# One sequence of 1024 steps with 32 real-valued features per step
x = jrandom.normal(data_rng, (1, 1024, 32))

variables = model.init(init_rng, x)
model.apply(variables, x)  # (1, 1024, 10)
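
The output above has one out_dim-sized vector per time step. If a single prediction per sequence is needed (for example, a 10-way classification), one common option is to average over the time axis. The sketch below assumes that use case. The helper pool_over_time and its mask argument are hypothetical and not part of long_range_models, and random values stand in for the model output.

```python
import jax
import jax.numpy as jnp

def pool_over_time(outputs, mask=None):
    """Average per-step outputs (batch, length, out_dim) into (batch, out_dim).

    Hypothetical helper. `mask` is an optional (batch, length) array with 1 for
    real steps and 0 for padding.
    """
    if mask is None:
        return outputs.mean(axis=1)
    mask = mask[..., None].astype(outputs.dtype)
    # Guard against division by zero for fully padded sequences.
    return (outputs * mask).sum(axis=1) / jnp.maximum(mask.sum(axis=1), 1.0)

# Assumption: in real use this would be model.apply(variables, x).
outputs = jax.random.normal(jax.random.PRNGKey(0), (1, 1024, 10))

pooled = pool_over_time(outputs)           # (1, 10)
predicted_class = pooled.argmax(axis=-1)   # (1,)
```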

Note: both model types offer several customization options. Make sure to check out their documentation.

Upcoming features

  • More implementations: Extend the library with models like S4D, S4Liquid, BiGS, Hyena, RetNet, SGConv, H3, and others.
  • Customization: Give users finer control over the currently implemented layers and architectures (e.g., activation functions and initialization).
  • Sequential API: Let recurrent models process inputs one step at a time, which enables efficient inference.
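
To illustrate what a sequential mode means for a recurrent layer, here is a minimal sketch of a diagonal linear recurrence h_k = lam * h_(k-1) + bu_k, the kind of recurrence at the core of LRU. It is computed once in parallel over the whole sequence and once step by step. All names (binary_op, run_parallel, step, run_sequential) are hypothetical, the input projection bu_k = B u_k is assumed to be precomputed, and none of this reflects the library's actual API.

```python
import jax
import jax.numpy as jnp

def binary_op(left, right):
    # Compose two affine maps h -> a * h + b, applying `left` first.
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r

def run_parallel(lam, bu):
    """All states at once via an associative scan (useful for training)."""
    a = jnp.broadcast_to(lam, bu.shape)
    _, h = jax.lax.associative_scan(binary_op, (a, bu))
    return h

def step(lam, h_prev, bu_k):
    """One recurrent update; only the previous state is needed."""
    return lam * h_prev + bu_k

def run_sequential(lam, bu):
    """Same states, produced one step at a time (useful for inference)."""
    def body(h, bu_k):
        h = step(lam, h, bu_k)
        return h, h
    h0 = jnp.zeros_like(bu[0])
    _, hs = jax.lax.scan(body, h0, bu)
    return hs

state_dim, length = 8, 64
# Stable complex eigenvalues with magnitude 0.9.
lam = 0.9 * jnp.exp(1j * jnp.linspace(0.0, jnp.pi, state_dim))
bu = jax.random.normal(jax.random.PRNGKey(0), (length, state_dim)).astype(jnp.complex64)

h_par = run_parallel(lam, bu)
h_seq = run_sequential(lam, bu)
```

Both functions return the same states. The difference is that step needs only the previous state and the newest input, so generating one more output does not require reprocessing the whole sequence.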
