MambaTransformer - Pytorch

Mamba Transformer

Integrating Mamba/SSMs with Transformer for Enhanced Long Context and High-Quality Sequence Modeling.

This is a novel architecture I designed to combine the strengths of SSMs and Attention in a single model, with the goal of pushing past the limits of either alone: faster processing, longer context lengths, lower perplexity over long sequences, and stronger reasoning, all while remaining small and compact.

The architecture is essentially: x -> norm -> mamba -> norm -> transformer -> norm -> ffn -> norm -> out.

I added many normalizations because I believe training stability would otherwise be severely degraded when two foreign architectures are integrated with one another.
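
As a rough illustration of that block ordering, here is a minimal PyTorch sketch. It is not the implementation in this package: the residual connections are my assumption, and the Mamba/SSM layer is left as a pluggable stand-in.

import torch
from torch import nn

class HybridBlock(nn.Module):
    """One block of the layout x -> norm -> mamba -> norm -> attention -> norm -> ffn -> norm -> out."""

    def __init__(self, dim: int, heads: int, ff_mult: int, mamba: nn.Module):
        super().__init__()
        self.mamba = mamba  # any sequence mixer mapping (B, L, D) -> (B, L, D); a real SSM/Mamba layer in practice
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * ff_mult),
            nn.GELU(),
            nn.Linear(dim * ff_mult, dim),
        )
        # One norm before each sub-module plus a final one, mirroring the
        # "many normalizations" described above.
        self.norms = nn.ModuleList([nn.LayerNorm(dim) for _ in range(4)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.mamba(self.norms[0](x))                # norm -> mamba (residual assumed)
        h = self.norms[1](x)
        x = x + self.attn(h, h, h, need_weights=False)[0]   # norm -> attention (residual assumed)
        x = x + self.ffn(self.norms[2](x))                  # norm -> ffn (residual assumed)
        return self.norms[3](x)                             # final norm -> out

# Quick shape check with a trivial stand-in for the Mamba/SSM layer.
block = HybridBlock(dim=512, heads=8, ff_mult=4, mamba=nn.Linear(512, 512))
print(block(torch.randn(1, 10, 512)).shape)  # torch.Size([1, 10, 512])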

Install

pip3 install mambatransformer

Usage

import torch
from mamba_transformer import MambaTransformer

# Generate a random tensor of shape (1, 10) with values between 0 and 99
x = torch.randint(0, 100, (1, 10))

# Create an instance of the MambaTransformer model
model = MambaTransformer(
    num_tokens=100,  # Size of the vocabulary (token ids range from 0 to 99 here)
    dim=512,  # Dimension of the model
    heads=8,  # Number of attention heads
    depth=4,  # Depth of the model (number of stacked blocks)
    dim_head=64,  # Dimension of each attention head
    d_state=512,  # Dimension of the SSM state
    dropout=0.1,  # Dropout rate
    ff_mult=4,  # Multiplier for the feed-forward layer dimension
    return_embeddings=False,  # Whether to return the embeddings
    transformer_depth=2,  # Number of transformer blocks
    mamba_depth=10,  # Number of Mamba blocks
    use_linear_attn=True,  # Whether to use linear attention
)

# Pass the input tensor through the model and print the output shape
out = model(x)

print(out.shape)
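# With return_embeddings=False the output is expected to be per-token logits,
# i.e. shape (batch, seq_len, num_tokens) = (1, 10, 100) for this example.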


# Switch to evaluation mode for inference (use model.train() when training)
model.eval()

# Would you like to train this model? Zeta Corporation offers unmatchable GPU clusters at unbeatable prices, let's partner!

# Generation (the prompt `text` must be tokenized first)
model.generate(text)
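
The package does not ship a training loop, but since the model is expected to output per-token logits (with return_embeddings=False), a standard next-token-prediction loop can be sketched as follows. This reuses the model from above; the optimizer, hyperparameters, and random batch are placeholders, not part of the library.

import torch.nn.functional as F

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

model.train()
for step in range(100):
    tokens = torch.randint(0, 100, (8, 64))          # placeholder batch of token ids; use a real dataloader
    inputs, targets = tokens[:, :-1], tokens[:, 1:]  # shift by one for next-token prediction

    logits = model(inputs)                           # expected shape: (batch, seq_len - 1, num_tokens)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

model.eval()  # back to inference mode when done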

License

MIT

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

mambatransformer-0.0.4.tar.gz (5.8 kB)

Uploaded Source

Built Distribution

mambatransformer-0.0.4-py3-none-any.whl (6.1 kB)

Uploaded Python 3

File details

Details for the file mambatransformer-0.0.4.tar.gz.

File metadata

  • Download URL: mambatransformer-0.0.4.tar.gz
  • Upload date:
  • Size: 5.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.3.2 CPython/3.11.0 Darwin/22.4.0

File hashes

Hashes for mambatransformer-0.0.4.tar.gz

  • SHA256: 0e184403a7b76210f5b1352244a7e907d2a4adf950ce63e72bba4660a5300a30
  • MD5: 9b29ce77860fe92ccd09e86dda0187ae
  • BLAKE2b-256: c89992c39da8c4038b2ebb78628ae821c3f1f63a654bf6854453fa2e924a207e

See more details on using hashes here.
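
As an illustration, the SHA256 digest above can be checked against a locally downloaded archive with the Python standard library; the file path below is a placeholder for wherever the file was saved.

import hashlib

expected = "0e184403a7b76210f5b1352244a7e907d2a4adf950ce63e72bba4660a5300a30"

with open("mambatransformer-0.0.4.tar.gz", "rb") as f:  # placeholder path to the downloaded sdist
    digest = hashlib.sha256(f.read()).hexdigest()

print("OK" if digest == expected else "hash mismatch")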

File details

Details for the file mambatransformer-0.0.4-py3-none-any.whl.

File metadata

File hashes

Hashes for mambatransformer-0.0.4-py3-none-any.whl

  • SHA256: 83067b44708ab0f56fd6a02e03acde2df87f929061a7e022160e2c8466f3c673
  • MD5: 8544846ff9a4958529f9ca98de2623c5
  • BLAKE2b-256: fc24fd3afd92f98e8c2f084ef00e70a4bdc2c3bb5467d3578d4f76e76396096a

See more details on using hashes here.
