Chai-1


A free, open-source community implementation of Chai-1 in PyTorch. Paper is here

Join our discord to help us implement this paper!

Installation

pip3 install chai-one
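After installing, you can confirm that the package is importable with a quick check like the one below. This is a minimal sketch that uses only the standard library. It assumes the import name is `chai_one`, as in the usage example in this README (the PyPI distribution name is `chai-one`). `is_installed` is a hypothetical helper, not part of the package.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if the top-level module can be found on the current Python path.

    Hypothetical helper for illustration; it is not part of chai-one.
    """
    # find_spec returns None when no finder can locate the module.
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    # Assumption: the import name is "chai_one" (distribution name "chai-one").
    print("chai_one installed:", is_installed("chai_one"))
```

Only pass a top-level name here. For a dotted name such as `chai_one.model`, `find_spec` imports the parent package first and raises `ModuleNotFoundError` if that parent is missing.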

Usage

######### example.py
import torch
from loguru import logger
from chai_one.model import ChaiOne

# Set up model parameters
dim_single = 128
dim_pairwise = 128
dim_msa = 128
dim_msa_input = 134  # Feature size of each MSA position; must match msa.shape[-1] below
dim_additional_msa_feats = 0  # No additional MSA features are passed in this example
window_size = 25

# Initialize the model
logger.info("Initializing ChaiOne model")
model = ChaiOne(
    dim_single=dim_single,
    dim_pairwise=dim_pairwise,
    msa_depth=4,
    dim_msa=dim_msa,
    dim_msa_input=dim_msa_input,  # 134, matching the msa tensor created below
    dim_additional_msa_feats=dim_additional_msa_feats,
    msa_pwa_heads=8,
    msa_pwa_dim_head=32,
    layerscale_output=False,
    heads=8,
    window_size=window_size,
    num_memory_kv=0,
    attn_layers=48,
)

# Create dummy input tensors
batch_size = 1
seq_length = 100
num_msa = 4

logger.info(
    f"Creating input tensors with shape: batch_size={batch_size}, seq_length={seq_length}, num_msa={num_msa}"
)
single_repr = torch.randn(batch_size, seq_length, dim_single)
pairwise_repr = torch.randn(
    batch_size, seq_length, seq_length, dim_pairwise
)

# Create msa tensor with matching input size for msa_init_proj (134 features)
msa = torch.randn(
    batch_size, num_msa, seq_length, dim_msa_input
)  # Adjusted to 134

# Forward pass
logger.info("Performing forward pass")
output = model(
    single_repr=single_repr,
    pairwise_repr=pairwise_repr,
    msa=msa,
)

logger.info(f"Output shape: {output.shape}")
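The example only runs if the dummy tensors agree with the model parameters. The comments about adjusting `dim_msa_input` to 134 suggest that mismatches are easy to hit. The sketch below records the shapes that `example.py` builds, checks them before the forward pass, and estimates their float32 memory. It shows that `pairwise_repr` grows quadratically with `seq_length`. The helper names (`expected_input_shapes`, `check_shapes`, `float32_bytes`) are hypothetical. The shapes come from the tensors constructed in the example, not from a documented `ChaiOne` API. The memory estimate assumes the default float32 dtype of `torch.randn`.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def expected_input_shapes(
    batch_size: int,
    seq_length: int,
    num_msa: int,
    dim_single: int,
    dim_pairwise: int,
    dim_msa_input: int,
) -> Dict[str, Shape]:
    """Shapes of the dummy inputs built in example.py (hypothetical helper)."""
    return {
        "single_repr": (batch_size, seq_length, dim_single),
        "pairwise_repr": (batch_size, seq_length, seq_length, dim_pairwise),
        "msa": (batch_size, num_msa, seq_length, dim_msa_input),
    }


def check_shapes(actual: Dict[str, Shape], expected: Dict[str, Shape]) -> None:
    """Raise ValueError for the first input whose shape differs from the expectation."""
    for name, want in expected.items():
        got = tuple(actual[name])  # torch.Size is a tuple subclass, so this also works
        if got != want:
            raise ValueError(f"{name}: expected shape {want}, got {got}")


def float32_bytes(shape: Shape) -> int:
    """Memory of a dense float32 tensor with this shape (4 bytes per element)."""
    n = 1
    for d in shape:
        n *= d
    return 4 * n


if __name__ == "__main__":
    # Same values as example.py: batch 1, 100 residues, 4 MSA rows.
    shapes = expected_input_shapes(1, 100, 4, 128, 128, 134)
    for name, shape in shapes.items():
        print(name, shape, float32_bytes(shape), "bytes")
```

With the example's settings, `pairwise_repr` alone takes 5,120,000 bytes (about 4.9 MiB), and doubling `seq_length` quadruples that figure. In `example.py` you could call `check_shapes({"single_repr": single_repr.shape, "pairwise_repr": pairwise_repr.shape, "msa": msa.shape}, shapes)` just before `model(...)`.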

License

MIT

Download files

Source distribution: chai_one-0.0.2.tar.gz (3.8 kB)

Built distribution: chai_one-0.0.2-py3-none-any.whl (4.1 kB, Python 3)
