ESM2 implemented in Equinox+JAX.
ESM2quinox
An implementation of ESM2 in Equinox+JAX
Installation
```
pip install esm2quinox
```
Public API
See their docstrings for details:
```
esm2quinox
    .ESM2
        .__init__(self, num_layers: int, embed_size: int, num_heads: int, token_dropout: bool, key: PRNGKeyArray)
        .__call__(self, tokens: Int[np.ndarray | jax.Array, " length"]) -> esm2quinox.ESM2Result
    .ESM2Result
        .hidden: Float[Array, "length embed_size"]
        .logits: Float[Array, "length alphabet_size"]
    .tokenise(proteins: list[str], length: None | int = None, key: None | PRNGKeyArray = None)
    .from_torch(torch_esm2: esm.ESM2) -> esm2quinox.ESM2
```
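Since `ESM2Result.logits` has shape `(length, alphabet_size)`, turning it into per-position predictions is plain array code. The sketch below uses a random array as a stand-in for `out.logits` from a single unbatched call; the sizes and the `masked` boolean array are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Stand-in for `out.logits` from one unbatched forward pass, with shape
# (length, alphabet_size). Sizes are hypothetical; values are random.
length, alphabet_size = 6, 8
logits = jr.normal(jr.key(0), (length, alphabet_size))

# Per-position probability distribution over the alphabet.
probs = jax.nn.softmax(logits, axis=-1)

# Most likely token index at each position.
predicted = jnp.argmax(logits, axis=-1)

# Hypothetical mask marking which input positions were masked out.
masked = jnp.array([False, True, False, False, True, False])

# Predictions at the masked positions only (boolean indexing works outside jit).
masked_predictions = predicted[masked]
```

Mapping the indices back to amino-acid letters needs the model's alphabet, which is not covered here.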
Quick examples
Load an equivalent pretrained model from PyTorch:
```python
import esm  # pip install fair-esm==2.0.0
import esm2quinox

torch_model, _ = esm.pretrained.esm2_t6_8M_UR50D()
model = esm2quinox.from_torch(torch_model)
```
Create a randomly-initialised model:
```python
import esm2quinox
import jax.random as jr

key = jr.key(1337)
model = esm2quinox.ESM2(num_layers=3, embed_size=32, num_heads=2, token_dropout=False, key=key)
```
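Equinox modules are PyTrees, so one quick way to check the size of a model (random or pretrained) is to sum the sizes of its array leaves. `count_params` below is a hypothetical helper, not part of esm2quinox; it assumes the model's weights are stored as JAX or NumPy arrays.

```python
import jax
import jax.numpy as jnp
import numpy as np


def count_params(tree):
    """Total number of array elements in a PyTree, e.g. an Equinox module.

    Non-array leaves (Python ints, strings, ...) are skipped.
    """
    return sum(
        leaf.size
        for leaf in jax.tree_util.tree_leaves(tree)
        if isinstance(leaf, (jax.Array, np.ndarray))
    )
```

With the model above, `count_params(model)` returns the number of trainable scalars.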
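Forward pass. The model operates on unbatched data, so use `jax.vmap` to map it over a batch of tokenised proteins:
```python
import jax

proteins = esm2quinox.tokenise(["SPIDERMAN", "FOO"])
out = jax.vmap(model)(proteins)
out.hidden  # hidden representation from the last layer, shape (batch, length, embed_size)
out.logits  # logits used to predict masked positions, shape (batch, length, alphabet_size)
```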
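To see why `jax.vmap` is needed, here is a toy function that, like `ESM2.__call__`, is written for a single sequence of token ids with shape `(length,)`. It computes a single-head self-attention without learned projections, so it mixes information across the length axis and would not broadcast correctly over an extra batch axis. Everything in it (`toy_model`, `VOCAB_SIZE`, `EMBED_SIZE`) is hypothetical and stands in for the real model.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

VOCAB_SIZE = 8  # hypothetical alphabet size
EMBED_SIZE = 4  # hypothetical embedding size
table = jr.normal(jr.key(0), (VOCAB_SIZE, EMBED_SIZE))


def toy_model(tokens):
    """Unbatched: `tokens` has shape (length,); returns (length, EMBED_SIZE)."""
    h = table[tokens]                           # (length, EMBED_SIZE) embeddings
    weights = jax.nn.softmax(h @ h.T, axis=-1)  # (length, length) attention weights
    return weights @ h                          # weighted average over positions


# A batch of two already-padded sequences, shape (batch=2, length=4).
batch = jnp.array([[1, 2, 3, 0], [4, 5, 0, 0]])

# vmap adds the batch axis: each row is processed independently.
hidden = jax.vmap(toy_model)(batch)             # shape (2, 4, EMBED_SIZE)
```

`jax.vmap(toy_model)` gives the same result as calling `toy_model` on each row separately, which is the behaviour the forward-pass example relies on.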
Download files
Source Distribution
esm2quinox-0.2.0.tar.gz (12.4 kB)
Built Distribution
esm2quinox-0.2.0-py3-none-any.whl (17.1 kB)
File details
Details for the file esm2quinox-0.2.0.tar.gz.
File metadata
- Download URL: esm2quinox-0.2.0.tar.gz
- Upload date:
- Size: 12.4 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 12c11632f0b3efd3e3a4f94b031efb3f50f010f8fcbddd9a95b38da716bbb983 |
| MD5 | 16c88d7c4ed6331ab339526b62352d24 |
| BLAKE2b-256 | e487d48d18a1faa94ee4560f14cfc559413fcfc610854c29af897bcf9fb91cf9 |
File details
Details for the file esm2quinox-0.2.0-py3-none-any.whl.
File metadata
- Download URL: esm2quinox-0.2.0-py3-none-any.whl
- Upload date:
- Size: 17.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.12
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ac89d44ba1e141a60de1a76e10b345fab70d629d7c438031a76f53bd405fc244 |
| MD5 | e0c7fc48b3514f56e7e3ad2035fc3580 |
| BLAKE2b-256 | bc34b72aeafa3521e42e300117c1d1cb749a6db534ae88d9c5b28c1dd054001a |
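A downloaded file can be checked against the digests listed above with the standard library. This is a minimal sketch: `file_digests` and `verify_sha256` are hypothetical helpers, and BLAKE2b-256 is computed as BLAKE2b with a 32-byte (256-bit) digest.

```python
import hashlib
from pathlib import Path


def file_digests(path, chunk_size=1 << 16):
    """Return the SHA256, MD5 and BLAKE2b-256 hex digests of a file."""
    sha256 = hashlib.sha256()
    md5 = hashlib.md5()
    blake2b_256 = hashlib.blake2b(digest_size=32)  # BLAKE2b with a 256-bit output
    with open(path, "rb") as f:
        # Read in chunks so large files need not fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            for h in (sha256, md5, blake2b_256):
                h.update(chunk)
    return {
        "SHA256": sha256.hexdigest(),
        "MD5": md5.hexdigest(),
        "BLAKE2b-256": blake2b_256.hexdigest(),
    }


def verify_sha256(path, expected_hex):
    """True if the file's SHA256 digest matches `expected_hex` (case-insensitive)."""
    return file_digests(path)["SHA256"] == expected_hex.lower()


if __name__ == "__main__":
    # Assumes the sdist has been downloaded into the current directory.
    sdist = Path("esm2quinox-0.2.0.tar.gz")
    if sdist.exists():
        print(verify_sha256(
            sdist, "12c11632f0b3efd3e3a4f94b031efb3f50f010f8fcbddd9a95b38da716bbb983"
        ))
```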