
LEGO blocks for AI: Build state-of-the-art models faster with modular PyTorch components.


Zeta


Zeta is a modular PyTorch framework designed to simplify the development of AI models by providing reusable, high-performance building blocks. Think of it as a collection of LEGO blocks for AI: each component is carefully crafted, tested, and optimized, so you can quickly assemble state-of-the-art models without reinventing the wheel.



Overview

Zeta provides a comprehensive library of modular components commonly used in modern AI architectures, including:

  • Attention Mechanisms: Multi-query attention, sigmoid attention, flash attention, and more
  • Mixture of Experts (MoE): Efficient expert routing and gating mechanisms
  • Neural Network Modules: Feedforward networks, activation functions, normalization layers
  • Quantization: BitLinear, dynamic quantization, and other optimization techniques
  • Architectures: Transformers, encoders, decoders, vision transformers, and complete model implementations
  • Training Utilities: Optimization algorithms, logging, and performance monitoring

Each component is designed to be:

  • Modular: Drop-in replacements that work seamlessly with PyTorch
  • High-Performance: Optimized implementations with fused kernels where applicable
  • Well-Tested: Comprehensive test coverage ensuring reliability
  • Production-Ready: Used in hundreds of models across various domains

Installation

pip3 install -U zetascale

Quick Start

Multi-Query Attention

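Multi-query attention shares a single key and value projection across all query heads. Because only one key/value head has to be cached during decoding, the key/value cache shrinks by a factor equal to the number of heads. This reduces memory use, typically with only a small loss in model quality.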

import torch
from zeta import MultiQueryAttention

# Initialize the model
model = MultiQueryAttention(
    dim=512,
    heads=8,
)

# Forward pass
text = torch.randn(2, 4, 512)
output, _, _ = model(text)
print(output.shape)  # torch.Size([2, 4, 512])
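To put numbers on the memory saving, here is plain arithmetic for the key/value cache of standard multi-head attention versus multi-query attention. The helper kv_cache_bytes is hypothetical (not part of Zeta), and the settings are assumptions: fp16 storage (2 bytes per value), 6 layers, a 1,024-token context, and the dim=512, heads=8 configuration from the example above.

```python
def kv_cache_bytes(batch, seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    """Bytes needed to cache keys and values during autoregressive decoding.

    The leading factor of 2 accounts for storing both keys and values.
    """
    return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_elem


dim, heads = 512, 8
head_dim = dim // heads  # 64

# Multi-head attention: every query head has its own key/value head
mha = kv_cache_bytes(batch=1, seq_len=1024, n_layers=6, n_kv_heads=heads, head_dim=head_dim)

# Multi-query attention: one key/value head is shared by all query heads
mqa = kv_cache_bytes(batch=1, seq_len=1024, n_layers=6, n_kv_heads=1, head_dim=head_dim)

print(mha, mqa, mha // mqa)  # 12582912 1572864 8
```

The saving is exactly the number of heads, because only the n_kv_heads factor changes between the two calls.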

SwiGLU Activation

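SwiGLU is a gated linear unit that uses Swish (SiLU) as its gate: the input is projected twice, one projection is passed through Swish, and the result multiplies the other projection element-wise, so the gate controls how much information passes through the network.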

import torch
from zeta.nn import SwiGLUStacked

x = torch.randn(5, 10)
swiglu = SwiGLUStacked(10, 20)
output = swiglu(x)
print(output.shape)  # torch.Size([5, 20])
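A minimal sketch of the standard SwiGLU formulation, SiLU(xW) * (xV), followed by an output projection. SwiGLUStacked's internals are not shown here, so the class name SwiGLUSketch, the bias-free projections, and the hidden width of 32 are assumptions chosen only to reproduce the (5, 10) to (5, 20) shapes of the example.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SwiGLUSketch(nn.Module):
    """Illustrative SwiGLU: SiLU(x W) * (x V), then an output projection."""

    def __init__(self, dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.gate = nn.Linear(dim, hidden_dim, bias=False)   # W: gate branch
        self.value = nn.Linear(dim, hidden_dim, bias=False)  # V: value branch
        self.out = nn.Linear(hidden_dim, out_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SiLU(z) = z * sigmoid(z), i.e. Swish with beta = 1
        return self.out(F.silu(self.gate(x)) * self.value(x))


x = torch.randn(5, 10)
layer = SwiGLUSketch(dim=10, hidden_dim=32, out_dim=20)
print(layer(x).shape)  # torch.Size([5, 20])
```

If the gate branch outputs zeros, SiLU(0) = 0 closes the gate completely and nothing passes through.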

Relative Position Bias

Relative position bias quantizes the distance between a query position and a key position into buckets and looks up a learned embedding for each bucket, which is added to the attention scores as a position-aware bias.

import torch
from torch import nn
from zeta.nn import RelativePositionBias

# Initialize the module
rel_pos_bias = RelativePositionBias()

# Compute bias for attention mechanism
bias_matrix = rel_pos_bias(1, 10, 10)

# Use in custom attention
class CustomAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.rel_pos_bias = RelativePositionBias()

    def forward(self, queries, keys):
        bias = self.rel_pos_bias(queries.size(0), queries.size(1), keys.size(1))
        # Use bias in attention computation
        return None
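The bucketing step can be sketched in plain Python. This follows the T5-style scheme: exact buckets for short distances, logarithmically sized buckets up to max_distance, and half of the buckets reserved for each direction when bidirectional. The defaults of 32 buckets and a maximum distance of 128 are assumptions, and the function name and bias table are hypothetical rather than Zeta's API.

```python
import math


def relative_position_bucket(relative_position: int, bidirectional: bool = True,
                             num_buckets: int = 32, max_distance: int = 128) -> int:
    """Map a signed distance (key_pos - query_pos) to a bucket index, T5-style."""
    bucket = 0
    n = -relative_position
    if bidirectional:
        num_buckets //= 2            # half of the buckets for each direction
        if n < 0:                    # key comes after the query
            bucket += num_buckets
        n = abs(n)
    else:
        n = max(n, 0)                # causal: later positions collapse to bucket 0
    max_exact = num_buckets // 2
    if n < max_exact:
        return bucket + n            # one exact bucket per short distance
    # Logarithmically spaced buckets for longer distances, capped at the last one
    log_bucket = max_exact + int(
        math.log(n / max_exact) / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    )
    return bucket + min(log_bucket, num_buckets - 1)


# Bias matrix for one head: bias[q][k] = table[bucket(k - q)]
table = [0.1 * i for i in range(32)]  # hypothetical learned values, one per bucket
seq_len = 4
bias = [[table[relative_position_bucket(k - q)] for k in range(seq_len)]
        for q in range(seq_len)]
```

Every distance beyond max_distance shares the last bucket, so the table size stays fixed no matter how long the sequence is.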

FeedForward Network

A flexible feedforward module with optional GLU gating (glu=True) and an optional LayerNorm after the activation (post_act_ln=True), commonly used in transformer architectures.

import torch
from zeta.nn import FeedForward

model = FeedForward(256, 512, glu=True, post_act_ln=True, dropout=0.2)
x = torch.randn(1, 256)
output = model(x)
print(output.shape)  # torch.Size([1, 512])
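A sketch of the usual structure behind such a module: project up, apply an activation (gated when glu=True), optionally normalize, apply dropout, and project to the output dimension. The expansion factor mult=4, the GELU activation, and the class name FeedForwardSketch are assumptions for illustration, not a description of Zeta's FeedForward internals.

```python
import torch
import torch.nn.functional as F
from torch import nn


class FeedForwardSketch(nn.Module):
    """Illustrative transformer FFN: up-projection, (gated) activation,
    optional LayerNorm, dropout, and down-projection."""

    def __init__(self, dim, dim_out, mult=4, glu=False, post_act_ln=False, dropout=0.0):
        super().__init__()
        inner = dim * mult
        self.glu = glu
        # With GLU, one projection produces both the value and the gate halves
        self.proj_in = nn.Linear(dim, inner * 2 if glu else inner)
        self.norm = nn.LayerNorm(inner) if post_act_ln else nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.proj_out = nn.Linear(inner, dim_out)

    def forward(self, x):
        h = self.proj_in(x)
        if self.glu:
            value, gate = h.chunk(2, dim=-1)
            h = value * F.gelu(gate)  # GEGLU-style gating
        else:
            h = F.gelu(h)
        return self.proj_out(self.dropout(self.norm(h)))


model = FeedForwardSketch(256, 512, glu=True, post_act_ln=True, dropout=0.2)
print(model(torch.randn(1, 256)).shape)  # torch.Size([1, 512])
```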

BitLinear Quantization

BitLinear performs a linear transformation on quantized weights and activations and then dequantizes the output, reducing memory usage while maintaining performance. It is based on the paper "BitNet: Scaling 1-bit Transformers for Large Language Models".

import torch
from torch import nn
import zeta.quant as qt

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = qt.BitLinear(10, 20)

    def forward(self, x):
        return self.linear(x)

model = MyModel()
input = torch.randn(128, 10)
output = model(input)
print(output.size())  # torch.Size([128, 20])
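The core quantization steps described in the BitNet paper can be sketched with plain tensor operations. Weights are centered and binarized to +1/-1 with a scale beta equal to their mean absolute value. Activations are absmax-quantized to 8-bit integers with a scale gamma, and the product is rescaled by beta * gamma. This is a simplified sketch, not Zeta's BitLinear: it omits the LayerNorm the paper applies before quantization, uses round-to-nearest with a maximum of 127, and the function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def binarize_weights(w: torch.Tensor):
    """Center the weights, take their sign (+1/-1), and keep the scale beta = mean(|W|)."""
    alpha = w.mean()
    w_bin = torch.where(w - alpha > 0, torch.ones_like(w), -torch.ones_like(w))
    beta = w.abs().mean()
    return w_bin, beta


def absmax_quantize(x: torch.Tensor, bits: int = 8, eps: float = 1e-5):
    """Per-tensor absmax quantization of activations to signed integers."""
    q_max = 2 ** (bits - 1) - 1               # 127 for 8 bits
    gamma = x.abs().max().clamp(min=eps)      # absmax of the whole tensor
    x_q = torch.round(x * q_max / gamma).clamp(-q_max - 1, q_max)
    return x_q, gamma


def bitlinear_sketch(x: torch.Tensor, w: torch.Tensor, bits: int = 8) -> torch.Tensor:
    w_bin, beta = binarize_weights(w)
    x_q, gamma = absmax_quantize(x, bits)
    q_max = 2 ** (bits - 1) - 1
    # Matrix multiply on quantized values, then dequantize with both scales
    return F.linear(x_q, w_bin) * beta * gamma / q_max


torch.manual_seed(0)
x = torch.randn(128, 10)
w = torch.randn(20, 10)  # (out_features, in_features), like nn.Linear.weight
y = bitlinear_sketch(x, w)
print(y.shape)  # torch.Size([128, 20])
```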

PalmE: Multi-Modal Architecture

A PaLM-E-style multi-modal architecture assembled from Zeta components: a ViT image encoder produces image embeddings, and a transformer decoder cross-attends to them for vision-language tasks.

import torch
from zeta.structs import (
    AutoRegressiveWrapper,
    Decoder,
    Encoder,
    Transformer,
    ViTransformerWrapper,
)

class PalmE(torch.nn.Module):
    """
    PalmE is a transformer architecture that uses a ViT encoder and a transformer decoder.
    
    This implementation demonstrates how to combine Zeta's modular components to build
    a complete multi-modal model architecture.
    """
    
    def __init__(
        self,
        image_size=256,
        patch_size=32,
        encoder_dim=512,
        encoder_depth=6,
        encoder_heads=8,
        num_tokens=20000,
        max_seq_len=1024,
        decoder_dim=512,
        decoder_depth=6,
        decoder_heads=8,
        alibi_num_heads=4,
        attn_kv_heads=2,
        use_abs_pos_emb=False,
        cross_attend=True,
        alibi_pos_bias=True,
        rotary_xpos=True,
        attn_flash=True,
        qk_norm=True,
    ):
        super().__init__()
        
        # Vision encoder
        self.encoder = ViTransformerWrapper(
            image_size=image_size,
            patch_size=patch_size,
            attn_layers=Encoder(
                dim=encoder_dim, 
                depth=encoder_depth, 
                heads=encoder_heads
            ),
        )
        
        # Language decoder
        self.decoder = Transformer(
            num_tokens=num_tokens,
            max_seq_len=max_seq_len,
            use_abs_pos_emb=use_abs_pos_emb,
            attn_layers=Decoder(
                dim=decoder_dim,
                depth=decoder_depth,
                heads=decoder_heads,
                cross_attend=cross_attend,
                alibi_pos_bias=alibi_pos_bias,
                alibi_num_heads=alibi_num_heads,
                rotary_xpos=rotary_xpos,
                attn_kv_heads=attn_kv_heads,
                attn_flash=attn_flash,
                qk_norm=qk_norm,
            ),
        )
        
        # Enable autoregressive generation
        self.decoder = AutoRegressiveWrapper(self.decoder)
    
    def forward(self, img: torch.Tensor, text: torch.Tensor):
        """Forward pass of the model."""
        encoded = self.encoder(img, return_embeddings=True)
        return self.decoder(text, context=encoded)

# Usage
img = torch.randn(1, 3, 256, 256)
text = torch.randint(0, 20000, (1, 1024))
model = PalmE()
output = model(img, text)
print(output.shape)

U-Net Architecture

A complete U-Net implementation for image segmentation and generative tasks.

import torch
from zeta.nn import Unet

model = Unet(n_channels=1, n_classes=2)
x = torch.randn(1, 1, 572, 572)
y = model(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")

Vision Embeddings

Convert images into patch embeddings suitable for transformer-based vision models.

import torch
from zeta.nn import VisionEmbedding

vision_embedding = VisionEmbedding(
    img_size=224,
    patch_size=16,
    in_chans=3,
    embed_dim=768,
    contain_mask_token=True,
    prepend_cls_token=True,
)

input_image = torch.rand(1, 3, 224, 224)
output = vision_embedding(input_image)
print(output.shape)
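The token count follows from simple arithmetic: a 224 x 224 image cut into 16 x 16 patches yields (224 / 16)^2 = 196 patches, and prepending one class token gives 197 tokens of width 768. The sketch below shows a common way to implement this step, a strided convolution that acts as a per-patch linear projection. The class name is hypothetical, and the sketch leaves out the mask-token option (contain_mask_token) of Zeta's VisionEmbedding.

```python
import torch
from torch import nn


class PatchEmbeddingSketch(nn.Module):
    """Split an image into non-overlapping patches, project each one to embed_dim,
    and prepend a learnable class token."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # kernel_size == stride == patch_size: one linear projection per patch
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        x = self.proj(x)                   # (B, embed_dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)   # (B, num_patches, embed_dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        return torch.cat([cls, x], dim=1)  # (B, num_patches + 1, embed_dim)


embed = PatchEmbeddingSketch()
out = embed(torch.rand(1, 3, 224, 224))
print(embed.num_patches, out.shape)  # 196 torch.Size([1, 197, 768])
```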

Dynamic Quantization with Niva

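Niva applies dynamic quantization to selected layer types: weights are quantized ahead of time, while activation quantization parameters are computed on the fly at inference time. This makes it a good fit for models whose activation ranges vary at runtime.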

import torch
from torch import nn
from zeta import niva

# Instantiate your model (YourModelClass is a placeholder for your own nn.Module)
model = YourModelClass()

# Quantize the model dynamically
niva(
    model=model,
    model_path="path_to_pretrained_weights.pt",
    output_path="quantized_model.pt",
    quant_type="dynamic",
    quantize_layers=[nn.Linear, nn.Conv2d],
    dtype=torch.qint8,
)
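For comparison, the underlying technique is available directly in PyTorch through torch.ao.quantization.quantize_dynamic. This sketch uses a small stand-in model and targets only nn.Linear. PyTorch's eager-mode dynamic quantization covers nn.Linear and recurrent layers such as nn.LSTM. The sketch illustrates the technique and makes no claim about how niva works internally.

```python
import torch
from torch import nn

# A small stand-in model (hypothetical; replace with your own)
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10)).eval()

# Linear weights are converted to int8 ahead of time; activation scales are
# computed dynamically on each forward pass. The original model is left unchanged.
quantized = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

x = torch.randn(4, 64)
print(quantized(x).shape)  # torch.Size([4, 10])
print(quantized)           # shows the dynamically quantized Linear modules
```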

Fused Operations

Zeta includes several fused operations that combine multiple operations into single kernels for improved performance.

FusedDenseGELUDense

Fuses two dense (linear) layers and the GELU activation between them, for up to a 2x speedup.

import torch
from zeta.nn import FusedDenseGELUDense

x = torch.randn(1, 512)
model = FusedDenseGELUDense(512, 1024)
out = model(x)
print(out.shape)  # torch.Size([1, 1024])

FusedDropoutLayerNorm

Fuses dropout and layer normalization for faster feedforward networks.

import torch
from zeta.nn import FusedDropoutLayerNorm

model = FusedDropoutLayerNorm(dim=512)
x = torch.randn(1, 512)
output = model(x)
print(output.shape)  # torch.Size([1, 512])
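Each fused module replaces a chain of separate operations, each of which would otherwise run as its own kernel and write its intermediate result to memory. The unfused reference computations look like this in plain PyTorch. The hidden width of 1024 and the dropout probability of 0.1 are assumptions chosen for illustration.

```python
import torch
from torch import nn

dim, hidden = 512, 1024  # hidden width is an assumption

# Unfused dense -> GELU -> dense: three separate ops, two intermediate tensors
dense_gelu_dense = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

# Unfused dropout -> LayerNorm (dropout probability is an assumption)
dropout_layernorm = nn.Sequential(nn.Dropout(p=0.1), nn.LayerNorm(dim))

x = torch.randn(1, dim)
print(dense_gelu_dense(x).shape)   # torch.Size([1, 512])
print(dropout_layernorm(x).shape)  # torch.Size([1, 512])
```

In eval mode dropout is the identity, so the second chain reduces to a plain LayerNorm.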

Mamba: State Space Model

PyTorch implementation of the Mamba state space model architecture.

import torch
from zeta.nn import MambaBlock

block = MambaBlock(dim=64, depth=1)
x = torch.randn(1, 10, 64)
y = block(x)
print(y.shape)  # torch.Size([1, 10, 64])
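At the heart of a Mamba block is a selective state space recurrence. For each step t, h_t = exp(delta_t * A) * h_{t-1} + delta_t * B_t * x_t and y_t = C_t h_t, where the step size delta, B, and C depend on the input. The sketch below runs this recurrence as a plain sequential loop for clarity. Real implementations compute delta, B, and C from x with learned projections, add a skip term, and use a parallel, hardware-aware scan. The function name and random inputs are illustrative assumptions.

```python
import torch


def selective_scan_sketch(x, delta, A, B, C):
    """Sequential scan of a discretized selective SSM.

    x, delta: (seq_len, d)   inputs and per-step, per-channel step sizes (delta > 0)
    A:        (d, n)         diagonal state matrix per channel, negative for stability
    B, C:     (seq_len, n)   input-dependent input and output projections
    returns:  (seq_len, d)
    """
    seq_len, d = x.shape
    h = torch.zeros(d, A.shape[1])
    ys = []
    for t in range(seq_len):
        dA = torch.exp(delta[t, :, None] * A)                     # discretized A: (d, n)
        dBx = delta[t, :, None] * B[t][None, :] * x[t, :, None]   # discretized B times x: (d, n)
        h = dA * h + dBx                                          # state update
        ys.append(h @ C[t])                                       # readout: (d,)
    return torch.stack(ys)


torch.manual_seed(0)
L, d, n = 10, 4, 8
x = torch.randn(L, d)
delta = torch.nn.functional.softplus(torch.randn(L, d))  # positive step sizes
A = -torch.rand(d, n)
B, C = torch.randn(L, n), torch.randn(L, n)
y = selective_scan_sketch(x, delta, A, B, C)
print(y.shape)  # torch.Size([10, 4])
```

The loop is causal: y_t depends only on inputs up to step t.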

FiLM: Feature-wise Linear Modulation

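Feature-wise Linear Modulation (FiLM) transforms features conditionally: a conditioning input is used to predict a per-feature scale and shift, which are then applied to the hidden features.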

import torch
from zeta.nn import Film

film_layer = Film(dim=128, hidden_dim=64, expanse_ratio=4)
conditions = torch.randn(10, 128)
hiddens = torch.randn(10, 1, 128)
modulated_features = film_layer(conditions, hiddens)
print(modulated_features.shape)  # torch.Size([10, 1, 128])
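FiLM computes out = gamma(c) * h + beta(c), where the scale gamma and the shift beta are predicted from the conditioning input c. Below is a minimal sketch in which a single linear layer predicts both. Zeta's Film also takes hidden_dim and expanse_ratio arguments that this sketch does not model, and the class name is hypothetical.

```python
import torch
from torch import nn


class FiLMSketch(nn.Module):
    """FiLM: out = gamma(c) * h + beta(c), with gamma and beta predicted from c."""

    def __init__(self, cond_dim: int, feature_dim: int):
        super().__init__()
        # One linear layer predicts both the scale (gamma) and the shift (beta)
        self.to_gamma_beta = nn.Linear(cond_dim, 2 * feature_dim)

    def forward(self, conditions: torch.Tensor, hiddens: torch.Tensor) -> torch.Tensor:
        gamma, beta = self.to_gamma_beta(conditions).chunk(2, dim=-1)  # each (B, feature_dim)
        # Add a sequence axis so the same modulation applies at every position
        return gamma.unsqueeze(1) * hiddens + beta.unsqueeze(1)


film = FiLMSketch(cond_dim=128, feature_dim=128)
conditions = torch.randn(10, 128)
hiddens = torch.randn(10, 1, 128)
print(film(conditions, hiddens).shape)  # torch.Size([10, 1, 128])
```

With gamma fixed to 1 and beta to 0 the layer is the identity, which is a common way to initialize conditioning layers.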

Model Optimization

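The hyper_optimize decorator provides a unified interface to several optimization techniques (torch.fx, TorchScript, torch.compile, quantization, mixed precision, and metrics collection), each enabled through a keyword argument.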

import torch
from zeta.nn import hyper_optimize

@hyper_optimize(
    torch_fx=False,
    torch_script=False,
    torch_compile=True,
    quantize=True,
    mixed_precision=True,
    enable_metrics=True,
)
def model(x):
    return x @ x

out = model(torch.randn(1, 3, 32, 32))
print(out)

Direct Preference Optimization (DPO)

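Direct Preference Optimization fine-tunes a policy on pairs of preferred and unpreferred sequences without training a separate reward model or running reinforcement learning. This makes it a simpler alternative to the RL stage of reinforcement learning from human feedback (RLHF).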

import torch
from torch import nn
from zeta.rl import DPO

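class PolicyModel(nn.Module):
    """Tiny token-level model: (batch, seq_len) token ids -> (batch, seq_len, num_tokens) logits."""

    def __init__(self, num_tokens, dim):
        super().__init__()
        self.embed = nn.Embedding(num_tokens, dim)
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, x):
        return self.to_logits(self.embed(x))

num_tokens = 5
dim = 10
seq_len = 10
policy_model = PolicyModel(num_tokens, dim)
dpo_model = DPO(model=policy_model, beta=0.1)

# Integer token sequences; nn.Linear cannot consume integer tensors directly,
# so the model embeds the token ids first
preferred_seq = torch.randint(0, num_tokens, (3, seq_len))
unpreferred_seq = torch.randint(0, num_tokens, (3, seq_len))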
loss = dpo_model(preferred_seq, unpreferred_seq)
print(loss)
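For a single preference pair, the DPO objective compares how much more the policy favors each response than a frozen reference model does: loss = -log sigmoid(beta * ((log pi(y_w) - log pi_ref(y_w)) - (log pi(y_l) - log pi_ref(y_l)))), where y_w is the preferred and y_l the unpreferred sequence. The sketch below is plain Python and assumes the summed sequence log-probabilities have already been computed. The function name is hypothetical, and this is not Zeta's DPO class.

```python
import math


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float, beta: float = 0.1) -> float:
    """DPO loss for one preference pair, from summed sequence log-probabilities."""
    # How much more the policy favors each response than the reference does
    chosen_margin = policy_chosen_logp - ref_chosen_logp
    rejected_margin = policy_rejected_logp - ref_rejected_logp
    z = beta * (chosen_margin - rejected_margin)
    # -log(sigmoid(z)) == softplus(-z), computed without overflowing exp()
    return math.log1p(math.exp(-z)) if z >= 0 else -z + math.log1p(math.exp(z))


# Policy identical to the reference: z = 0, so the loss is log(2)
print(dpo_loss(-12.0, -15.0, -12.0, -15.0))  # ≈ 0.6931
# Policy shifted toward the preferred response: the loss drops below log(2)
print(dpo_loss(-10.0, -16.0, -12.0, -15.0))
```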

PyTorch Model Logging

A decorator for comprehensive model execution logging, including parameters, gradients, and memory usage.

import torch
from torch import nn
from zeta.utils.verbose_execution import verbose_execution

@verbose_execution(log_params=True, log_gradients=True, log_memory=True)
class YourPyTorchModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(64 * 222 * 222, 10)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

model = YourPyTorchModel()
input_tensor = torch.randn(1, 3, 224, 224)
output = model(input_tensor)

# Logging gradient information requires a backward pass
loss = output.sum()
loss.backward()

Sigmoid Attention

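An attention mechanism that replaces the row-wise softmax with an element-wise sigmoid, so each attention weight is computed independently of the others in its row. It is reported to give up to an 18% speedup while maintaining model performance.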

import torch
from zeta import SigmoidAttention

batch_size = 32
seq_len = 128
dim = 512
heads = 8

x = torch.rand(batch_size, seq_len, dim)
mask = torch.ones(batch_size, seq_len, seq_len)

sigmoid_attn = SigmoidAttention(dim, heads, seq_len)
output = sigmoid_attn(x, mask)
print(output.shape)  # torch.Size([32, 128, 512])
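Below is a single-head sketch contrasting the two normalizations. With softmax, each row of attention weights sums to one; with sigmoid, each weight lies independently in (0, 1). The default bias of -log(seq_len) is an assumption borrowed from published work on sigmoid attention: with all-zero scores, the weights in a row then sum to n / (n + 1), close to one. The function names are hypothetical, and this is not Zeta's SigmoidAttention.

```python
import math

import torch


def sigmoid_attention_sketch(q, k, v, bias=None):
    """Single-head attention with an element-wise sigmoid instead of a row-wise softmax.

    q, k, v: (batch, seq_len, head_dim). bias defaults to -log(seq_len).
    """
    n, d = q.shape[-2], q.shape[-1]
    if bias is None:
        bias = -math.log(n)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d) + bias
    weights = torch.sigmoid(scores)  # each weight in (0, 1); rows need not sum to 1
    return weights @ v


def softmax_attention(q, k, v):
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v  # each row of weights sums to 1


torch.manual_seed(0)
q, k, v = (torch.randn(2, 16, 64) for _ in range(3))
print(sigmoid_attention_sketch(q, k, v).shape)  # torch.Size([2, 16, 64])
```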

Documentation

Comprehensive documentation is available at zeta.apac.ai.

Running Tests

Install the pre-commit hooks to run linters, type checking, and a subset of tests on every commit:

pre-commit install

To run the full test suite:

python3 -m pip install -e '.[testing]'  # Install extra dependencies for testing
python3 -m pytest tests/                # Run the entire test suite

For more details, refer to the CI workflow configuration.

Community

Join our growing community for real-time support, ideas, and discussions on building better AI models.

  • Docs: zeta.apac.ai (official documentation)
  • Discord: live chat and community
  • Twitter: @kyegomez (follow for updates)
  • LinkedIn: The Swarm Corporation (connect professionally)
  • YouTube: watch our videos

Contributing

Zeta is an open-source project, and contributions are welcome! If you want to create new features, fix bugs, or improve the infrastructure, we'd love to have you contribute.

Getting Started:

Report Issues:

Our Contributors

Thank you to all of our contributors who have built this great framework 🙌


Citation

If you use Zeta in your research or projects, please cite it:

@misc{zetascale,
    title = {Zetascale Framework},
    author = {Kye Gomez},
    year = {2024},
    howpublished = {\url{https://github.com/kyegomez/zeta}},
}

License

Apache 2.0 License

Project details



This version

2.8.8

Download files


Source Distribution

zetascale-2.8.8.tar.gz (370.2 kB)

Uploaded Source

Built Distribution


zetascale-2.8.8-py3-none-any.whl (522.7 kB)

Uploaded Python 3

File details

Details for the file zetascale-2.8.8.tar.gz.

File metadata

  • Download URL: zetascale-2.8.8.tar.gz
  • Upload date:
  • Size: 370.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.12.3 Darwin/24.5.0

File hashes

Hashes for zetascale-2.8.8.tar.gz
Algorithm Hash digest
SHA256 0012f361d4d67ff57df47e8fbf4dd7a6dc00818449a4fbfa02438d128de9cfe5
MD5 ed1bc85af81c6ae32b7211f19cbbeb53
BLAKE2b-256 78fb23030c39943af4e4fe2a33e15bb38a4dd7d9cc7fe61401a1ebb4d7ee9c46


File details

Details for the file zetascale-2.8.8-py3-none-any.whl.

File metadata

  • Download URL: zetascale-2.8.8-py3-none-any.whl
  • Upload date:
  • Size: 522.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/2.1.3 CPython/3.12.3 Darwin/24.5.0

File hashes

Hashes for zetascale-2.8.8-py3-none-any.whl
Algorithm Hash digest
SHA256 7fbe74e235dae1985ba777a2101a6d9fcd13df80c9d5fa0a583b4c80aeb230f7
MD5 d2ff756953e5b365bffbedf51b68e2ae
BLAKE2b-256 d2a1adf3ce30b6b791f0e32506b175103767aaee780bae31862eb751132b91ca

