MambaTransformer - PyTorch
Project description
Mamba Transformer
Integrating Mamba/SSMs with Transformer for Enhanced Long Context and High-Quality Sequence Modeling.
This is a 100% novel architecture that I have designed to combine the strengths of SSMs and Attention while offsetting their weaknesses, with the goal of surpassing our old limits: faster processing, longer context lengths, lower perplexity over long sequences, and stronger reasoning, all while remaining small and compact.
The architecture is essentially: x -> norm -> mamba -> norm -> transformer -> norm -> ffn -> norm -> out.
I added in many normalizations because I believe training stability would otherwise be severely degraded by two foreign architectures being integrated with one another.
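As a rough illustration of that flow, one block could be wired as in the sketch below. This is only my own sketch, not the package's actual implementation; it assumes pre-norm residual connections, LayerNorm, and a GELU feed-forward, and any modules mapping (batch, seq_len, dim) -> (batch, seq_len, dim) can be passed in for the Mamba and attention sub-blocks.

import torch
from torch import nn

class MambaTransformerBlockSketch(nn.Module):
    # Illustrative only: x -> norm -> mamba -> norm -> attention -> norm -> ffn -> norm -> out
    def __init__(self, dim, mamba_block, attention_block, ff_mult=4):
        super().__init__()
        self.norm_mamba = nn.LayerNorm(dim)
        self.mamba = mamba_block          # any SSM module: (B, L, D) -> (B, L, D)
        self.norm_attn = nn.LayerNorm(dim)
        self.attn = attention_block       # any attention module: (B, L, D) -> (B, L, D)
        self.norm_ffn = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * ff_mult),
            nn.GELU(),
            nn.Linear(dim * ff_mult, dim),
        )
        self.norm_out = nn.LayerNorm(dim)

    def forward(self, x):
        x = x + self.mamba(self.norm_mamba(x))   # norm -> mamba (residual)
        x = x + self.attn(self.norm_attn(x))     # norm -> transformer/attention (residual)
        x = x + self.ffn(self.norm_ffn(x))       # norm -> ffn (residual)
        return self.norm_out(x)                  # final norm -> out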
Install
pip3 install mambatransformer
Usage
import torch
from mamba_transformer import MambaTransformer
# Generate a random tensor of shape (1, 10) with values between 0 and 99
x = torch.randint(0, 100, (1, 10))
# Create an instance of the MambaTransformer model
model = MambaTransformer(
    num_tokens=100,  # Vocabulary size (number of distinct token IDs)
    dim=512,  # Dimension of the model
    heads=8,  # Number of attention heads
    depth=4,  # Number of layers in the model
    dim_head=64,  # Dimension of each attention head
    d_state=512,  # Dimension of the SSM state
    dropout=0.1,  # Dropout rate
    ff_mult=4,  # Multiplier for the feed-forward layer dimension
    return_embeddings=False,  # Whether to return the embeddings
    transformer_depth=2,  # Number of transformer blocks
    mamba_depth=10,  # Number of Mamba blocks
    use_linear_attn=True,  # Whether to use linear attention
)
# Pass the input tensor through the model and print the output shape
out = model(x)
print(out.shape)
# Switch to evaluation mode for inference
model.eval()

# Would you like to train this model? Zeta Corporation offers unmatchable GPU clusters at unbeatable prices; let's partner!

# Generation: `text` is assumed to be a tensor of input token IDs (not defined above)
model.generate(text)
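If you do want to train it yourself, a generic next-token-prediction loop might look like the sketch below. This is a standard PyTorch example, not an API provided by the package; it assumes the model returns logits of shape (batch, seq_len, num_tokens) when return_embeddings=False, and the random batch stands in for real tokenized data.

import torch
import torch.nn.functional as F

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

model.train()
for step in range(100):
    batch = torch.randint(0, 100, (4, 11))         # toy batch of token IDs
    inputs, targets = batch[:, :-1], batch[:, 1:]  # shift by one for next-token prediction

    logits = model(inputs)                         # assumed shape: (batch, seq_len, num_tokens)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()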
License
MIT
File details
Details for the file mambatransformer-0.0.4.tar.gz.
File metadata
- Download URL: mambatransformer-0.0.4.tar.gz
- Upload date:
- Size: 5.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.3.2 CPython/3.11.0 Darwin/22.4.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 0e184403a7b76210f5b1352244a7e907d2a4adf950ce63e72bba4660a5300a30
MD5 | 9b29ce77860fe92ccd09e86dda0187ae
BLAKE2b-256 | c89992c39da8c4038b2ebb78628ae821c3f1f63a654bf6854453fa2e924a207e
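To check a downloaded archive against the SHA256 digest above, an illustrative snippet using Python's standard hashlib module (it assumes the file is in the current directory):

import hashlib

expected = "0e184403a7b76210f5b1352244a7e907d2a4adf950ce63e72bba4660a5300a30"

with open("mambatransformer-0.0.4.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print("OK" if digest == expected else "hash mismatch")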
File details
Details for the file mambatransformer-0.0.4-py3-none-any.whl.
File metadata
- Download URL: mambatransformer-0.0.4-py3-none-any.whl
- Upload date:
- Size: 6.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.3.2 CPython/3.11.0 Darwin/22.4.0
File hashes
Algorithm | Hash digest
---|---
SHA256 | 83067b44708ab0f56fd6a02e03acde2df87f929061a7e022160e2c8466f3c673
MD5 | 8544846ff9a4958529f9ca98de2623c5
BLAKE2b-256 | fc24fd3afd92f98e8c2f084ef00e70a4bdc2c3bb5467d3578d4f76e76396096a