
SwitchTransformers - PyTorch


Switch Transformers

Implementation of Switch Transformers from the paper "Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity" in PyTorch, Einops, and Zeta. Paper: https://arxiv.org/abs/2101.03961
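The paper's central idea is top-1 ("switch") routing: a learned gate sends each token to exactly one expert feed-forward network, so total parameter count scales with the number of experts while per-token compute stays roughly constant. The sketch below illustrates that routing step in plain PyTorch; it is not the code this package ships, and the class name SwitchRouter and the 4x expert hidden width are assumptions made for the example.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwitchRouter(nn.Module):
    # Illustrative top-1 (switch) routing layer, not this package's implementation
    def __init__(self, dim, num_experts):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts)
        # Each expert is a small feed-forward network (4x hidden width assumed here)
        self.experts = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(dim, dim * 4), nn.ReLU(), nn.Linear(dim * 4, dim))
                for _ in range(num_experts)
            ]
        )

    def forward(self, x):
        # x: (batch, seq, dim) -> flatten to a list of tokens for routing
        tokens = x.reshape(-1, x.shape[-1])
        probs = F.softmax(self.gate(tokens), dim=-1)   # router probabilities per token
        gate_vals, expert_idx = probs.max(dim=-1)      # top-1: one expert per token
        out = torch.zeros_like(tokens)
        for i, expert in enumerate(self.experts):
            mask = expert_idx == i
            if mask.any():
                # Scale each expert's output by its gate probability so the
                # router still receives gradients despite the hard argmax choice
                out[mask] = gate_vals[mask].unsqueeze(-1) * expert(tokens[mask])
        return out.reshape(x.shape)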

Installation

pip install switch-transformers

Usage

import torch
from switch_transformers import SwitchTransformer

# Random token IDs of shape (1, 10); randint samples integers in [0, 100)
x = torch.randint(0, 100, (1, 10))

# Create an instance of the SwitchTransformer model
# num_tokens: the vocabulary size (number of distinct token IDs)
# dim: the dimensionality of the model
# heads: the number of attention heads
# dim_head: the dimensionality of each attention head
model = SwitchTransformer(
    num_tokens=100, dim=512, heads=8, dim_head=64
)

# Pass the input tensor through the model
out = model(x)

# Print the shape of the output tensor
print(out.shape)
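With these settings the output should be the per-token logits over the vocabulary, i.e. a tensor of shape (1, 10, 100). A minimal training sketch against that output, assuming it really is (batch, seq_len, num_tokens) logits, continuing the example above:

import torch.nn.functional as F

# Standard next-token objective: position t predicts token t + 1
logits = out[:, :-1, :]                   # predictions for positions 0..seq_len-2
targets = x[:, 1:]                        # the tokens those positions should predict
loss = F.cross_entropy(
    logits.reshape(-1, logits.size(-1)),  # (batch * (seq_len - 1), num_tokens)
    targets.reshape(-1),                  # (batch * (seq_len - 1),)
)
loss.backward()                           # gradients reach the routed experts too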

Citation

@misc{fedus2022switch,
    title={Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity}, 
    author={William Fedus and Barret Zoph and Noam Shazeer},
    year={2022},
    eprint={2101.03961},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

License

MIT

