Transformer components in Triton
triton-transformers
This is an implementation of transformer components using Triton.
- This is my first introduction to low-level GPU programming for neural networks.
- I may also try to train a full model, though that is not decided yet.
- As of right now, I am still learning Triton.
Installation
- First install triformer
pip install triformer==1.3.0
- Then you can use the components
- Please keep in mind that TritonLinear can fuse a ReLU activation into the linear layer, controlled by the use_relu argument.
import torch.nn as nn
from triformer import TritonLinear

class TritonMLP(nn.Module):
    def __init__(self, input_size, num_classes, hidden_size=768):
        super(TritonMLP, self).__init__()
        # The first two layers fuse the ReLU activation into the linear kernel.
        self.fc1 = TritonLinear(input_size, hidden_size, use_relu=True)
        self.fc2 = TritonLinear(hidden_size, hidden_size * 2, use_relu=True)
        # No activation on the output layer.
        self.fc3 = TritonLinear(hidden_size * 2, num_classes, use_relu=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
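As a quick sanity check, a model built this way can be run on a random batch. The shapes below (CIFAR-10-sized inputs flattened to 3072 features, 10 classes) are illustrative assumptions on my part, and the Triton kernels are expected to need a CUDA device:

import torch

# Illustrative shapes: flattened 32x32x3 images, 10 output classes.
model = TritonMLP(input_size=32 * 32 * 3, num_classes=10).cuda()
x = torch.randn(32, 32 * 32 * 3, device="cuda")
logits = model(x)
print(logits.shape)  # torch.Size([32, 10])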
Try it out!
You can try out the TritonMLP on the CIFAR-10 dataset using this Colab notebook:
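For reference, a minimal training loop on CIFAR-10 could look like the sketch below. It uses torchvision for data loading and a plain cross-entropy objective; the batch size, learning rate, and single-epoch loop are assumptions for illustration, not settings taken from the notebook:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.ToTensor()
train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

model = TritonMLP(input_size=32 * 32 * 3, num_classes=10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

model.train()
for images, labels in train_loader:
    # Flatten each image into a vector before the linear layers.
    images = images.view(images.size(0), -1).cuda()
    labels = labels.cuda()

    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()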
Future Plans - To Do
- Create a library specifically for transformers in vision and language
- Make TritonLinear more flexible by supporting different activation functions