jamba - PyTorch
Jamba
A PyTorch implementation of Jamba, from the paper "Jamba: A Hybrid Transformer-Mamba Language Model".
Install
$ pip install jamba
Usage
import torch
from jamba.model import JambaBlock

# Create a random input tensor of shape (batch, sequence length, dim)
x = torch.randn(1, 128, 512)

# Create an instance of the JambaBlock class
jamba = JambaBlock(
    512,  # input channels
    128,  # hidden channels
    128,  # key channels
    8,    # number of heads
    4,    # number of layers
)

# Pass the input tensor through the JambaBlock
output = jamba(x)

# Print the shape of the output tensor
print(output.shape)
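The "hybrid Transformer-Mamba" design in the title refers to interleaving attention layers with state-space (Mamba-style) sequence-mixing layers. The toy sketch below illustrates that alternation only; it is not the package's actual implementation, and `SimpleSSMLayer` is a deliberately simplified stand-in (a gated linear recurrence) rather than a real Mamba layer.

```python
import torch
from torch import nn

class SimpleSSMLayer(nn.Module):
    """Toy gated linear recurrence, standing in for a Mamba layer (illustrative only)."""
    def __init__(self, dim):
        super().__init__()
        self.in_proj = nn.Linear(dim, dim)
        self.gate = nn.Linear(dim, dim)
        self.decay = nn.Parameter(torch.full((dim,), 0.9))
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x):  # x: (batch, seq, dim)
        u = self.in_proj(x)
        g = torch.sigmoid(self.gate(x))
        h = torch.zeros_like(u[:, 0])  # recurrent state, (batch, dim)
        outs = []
        for t in range(u.size(1)):
            # Per-channel exponential decay plus the current input
            h = self.decay * h + u[:, t]
            outs.append(h)
        y = torch.stack(outs, dim=1) * g  # gate the recurrent output
        return self.out_proj(y)

class ToyHybridBlock(nn.Module):
    """Alternates attention and SSM-style mixing, each with a residual connection."""
    def __init__(self, dim, heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ssm = SimpleSSMLayer(dim)

    def forward(self, x):
        q = self.norm1(x)
        a, _ = self.attn(q, q, q)   # self-attention sub-layer
        x = x + a
        x = x + self.ssm(self.norm2(x))  # SSM-style sub-layer
        return x

x = torch.randn(1, 16, 64)
block = ToyHybridBlock(64, 8)
print(block(x).shape)
```

Both sub-layers preserve the `(batch, seq, dim)` shape, so such blocks can be stacked; the real Jamba architecture additionally mixes in MoE feed-forward layers at a fixed ratio.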
License
MIT
Project details
Release: jamba 0.0.1
Source distribution: jamba-0.0.1.tar.gz (7.2 kB)
Built distribution: jamba-0.0.1-py3-none-any.whl (7.3 kB)