
ESM2 implemented in Equinox+JAX.


ESM2quinox

An implementation of ESM2 in Equinox+JAX

Installation

pip install esm2quinox

Public API

See their docstrings for details:

esm2quinox
    .ESM2
        .__init__(self, num_layers: int, embed_size: int, num_heads: int, token_dropout: bool, key: PRNGKeyArray)
        .__call__(self, tokens: Int[np.ndarray | jax.Array, " length"]) -> esm2quinox.ESM2Result

    .ESM2Result
        .hidden: Float[Array, "length embed_size"]
        .logits: Float[Array, "length alphabet_size"]

    .tokenise(proteins: list[str], length: None | int = None, key: None | PRNGKeyArray = None)

    .from_torch(torch_esm2: esm.ESM2) -> esm2quinox.ESM2

Quick examples

Load an equivalent pretrained model from PyTorch:

import esm  # pip install fair-esm==2.0.0
import esm2quinox

torch_model, _ = esm.pretrained.esm2_t6_8M_UR50D()
model = esm2quinox.from_torch(torch_model)

Create a randomly-initialised model:

import esm2quinox
import jax.random as jr

key = jr.key(1337)
model = esm2quinox.ESM2(num_layers=3, embed_size=32, num_heads=2, token_dropout=False, key=key)

Forward pass (note the model operates on unbatched data):

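import jax

proteins = esm2quinox.tokenise(["SPIDERMAN", "FOO"])
out = jax.vmap(model)(proteins)  # map the unbatched model over the batch axis
out.hidden  # hidden representation from the last layer, shape (batch, length, embed_size)
out.logits  # logits over the alphabet at every position, used to predict masked tokens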


Download files


Source Distribution

esm2quinox-0.2.0.tar.gz (12.4 kB)

Uploaded Source

Built Distribution


esm2quinox-0.2.0-py3-none-any.whl (17.1 kB)

Uploaded Python 3

File details

Details for the file esm2quinox-0.2.0.tar.gz.

File metadata

  • Download URL: esm2quinox-0.2.0.tar.gz
  • Size: 12.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.12

File hashes

Hashes for esm2quinox-0.2.0.tar.gz
Algorithm     Hash digest
SHA256        12c11632f0b3efd3e3a4f94b031efb3f50f010f8fcbddd9a95b38da716bbb983
MD5           16c88d7c4ed6331ab339526b62352d24
BLAKE2b-256   e487d48d18a1faa94ee4560f14cfc559413fcfc610854c29af897bcf9fb91cf9


File details

Details for the file esm2quinox-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: esm2quinox-0.2.0-py3-none-any.whl
  • Size: 17.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.12

File hashes

Hashes for esm2quinox-0.2.0-py3-none-any.whl
Algorithm     Hash digest
SHA256        ac89d44ba1e141a60de1a76e10b345fab70d629d7c438031a76f53bd405fc244
MD5           e0c7fc48b3514f56e7e3ad2035fc3580
BLAKE2b-256   bc34b72aeafa3521e42e300117c1d1cb749a6db534ae88d9c5b28c1dd054001a

