S5 - Simplified State Space Layers for Sequence Modeling - Pytorch

S5: Simplified State Space Layers for Sequence Modeling

This is a PyTorch port derived from https://github.com/lindermanlab/S5 and https://github.com/kavorite/S5. It includes a number of functions ported from jax/lax/flax, since PyTorch had no equivalents for them yet.

Before version 0.2.0, jax was required because the code relied on its pytree structure, although jax was not used for any computation. Since version 0.2.0 jax is no longer required: the package uses PyTorch's native torch.utils._pytree instead (this is a private module, so it may be incompatible with future PyTorch versions). PyTorch 2 or later is required because the code makes heavy use of torch.vmap and torch.utils._pytree as substitutes for their jax counterparts. Python 3.10 or later is required due to use of the match keyword.
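As a rough illustration of the two PyTorch features named above, the sketch below uses torch.vmap to batch a per-sample function (the role jax.vmap plays in the original code) and torch.utils._pytree.tree_map to transform every tensor in a nested container. The affine function and the params dict are made up for this example and are not part of s5-pytorch.

```python
import torch
from torch.utils._pytree import tree_map

def affine(params, x):
    # params is a pytree (here a plain dict of tensors); x is one unbatched vector
    return params["w"] @ x + params["b"]

params = {"w": torch.eye(3) * 2.0, "b": torch.ones(3)}
xs = torch.arange(12, dtype=torch.float32).reshape(4, 3)  # batch of 4 vectors

# Map over dim 0 of xs while sharing params across the batch (in_dims=None)
batched = torch.vmap(affine, in_dims=(None, 0))(params, xs)  # shape [4, 3]

# Apply a function to every leaf tensor of the pytree, keeping its structure
doubled = tree_map(lambda t: t * 2, params)
```

Because torch.utils._pytree is underscore-prefixed, pinning the PyTorch version is a reasonable precaution if you depend on this behavior.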

---

Update:

In my experiments, the results agree with the findings of the Hyena Hierarchy (& H3) paper: state spaces alone lack the recall capabilities required for LLMs, but they seem to work well for regular sequence feature extraction and offer linear complexity.

You can use a variable step size, as described in the paper, by passing a 1D tensor for step_scale. However, this takes a lot of memory because many intermediate values need to be held (I believe this is also true for the official S5 repo, but it is not mentioned in the paper unless I missed it).
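To give a feel for where the extra memory goes, here is a minimal standalone sketch of zero-order-hold discretization for a diagonal state space model, written for this README rather than taken from the package. It assumes the step scale simply multiplies the learned timescale Delta at each position. All names (discretize_zoh, Lambda, B_tilde, Delta) and sizes are illustrative and do not reflect the s5-pytorch internals.

```python
import torch

def discretize_zoh(Lambda, B_tilde, Delta):
    """Zero-order-hold discretization of a diagonal SSM.

    Lambda: (P,) complex diagonal state matrix
    B_tilde: (P, H) complex input matrix
    Delta: (..., P) real step sizes; leading dims broadcast
    """
    Lambda_bar = torch.exp(Lambda * Delta)
    B_bar = ((Lambda_bar - 1) / Lambda)[..., None] * B_tilde
    return Lambda_bar, B_bar

torch.manual_seed(0)
P, H, L = 4, 3, 16  # state size, input features, sequence length (illustrative)
Lambda = torch.complex(-torch.rand(P) - 0.1, torch.rand(P))  # stable: Re < 0
B_tilde = torch.randn(P, H, dtype=torch.cfloat)
Delta = torch.full((P,), 0.1)

# Fixed step: one set of discretized parameters shared by every position
Lambda_bar, B_bar = discretize_zoh(Lambda, B_tilde, Delta)  # (P,), (P, H)

# Variable step: a 1D scale per position gives one parameter set per position
step_scale = torch.rand(L) + 0.5  # (L,)
Lambda_bar_t, B_bar_t = discretize_zoh(
    Lambda, B_tilde, step_scale[:, None] * Delta
)  # (L, P), (L, P, H)
```

In this sketch the discretized parameters grow by a factor of L with a per-position step, and autograd has to keep them around for the backward pass, which is consistent with the memory cost described above.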

Install

pip install s5-pytorch 

Example

import torch
from s5 import S5, S5Block

# Raw S5 operator
x = torch.rand([2, 256, 32])
model = S5(32, 32)
model(x) # [2, 256, 32]

# S5-former block (S5+FFN-GLU w/ layernorm, dropout & residual)
model = S5Block(32, 32, False)
model(x) # [2, 256, 32]
