A transformer-based framework with finetuning as a first-class citizen.
Introduction
sat (SwissArmyTransformer) is a flexible and powerful library to develop your own Transformer variants.
sat is named after the "swiss army knife": all the models (e.g. BERT, GPT, T5, GLM, CogView, ViT...) share the same backbone code and adapt to versatile uses through a few extra lightweight mixins.
sat is powered by DeepSpeed ZeRO and model parallelism, aiming to provide the best practice for pretraining and finetuning large models (100M to 20B parameters).
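To give a rough sense of why ZeRO matters at the model sizes mentioned above, the sketch below applies the model-state accounting from the ZeRO paper for mixed-precision Adam: about 2 bytes per parameter for fp16 weights, 2 for fp16 gradients and 12 for the fp32 optimizer states (master weights, momentum, variance). Stage 1 partitions the optimizer states across GPUs, stage 2 also the gradients, and stage 3 also the weights. This is a back-of-the-envelope estimate, not something sat computes; activations, buffers and fragmentation are ignored, and zero_memory_gb is a hypothetical helper.

```python
def zero_memory_gb(num_params: float, num_gpus: int, stage: int) -> float:
    """Estimate per-GPU memory (GB) for model states under mixed-precision Adam.

    Assumes the ZeRO paper's accounting: 2 bytes fp16 weights, 2 bytes fp16
    gradients and 12 bytes fp32 optimizer states per parameter. Activations
    and temporary buffers are not included.
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2 or 3")
    weights = 2 * num_params
    grads = 2 * num_params
    optimizer_states = 12 * num_params
    if stage >= 1:
        optimizer_states /= num_gpus  # ZeRO-1: shard optimizer states
    if stage >= 2:
        grads /= num_gpus  # ZeRO-2: also shard gradients
    if stage >= 3:
        weights /= num_gpus  # ZeRO-3: also shard the weights themselves
    return (weights + grads + optimizer_states) / 1e9
```

For a 10B-parameter model on 8 GPUs this gives 160 GB per GPU without ZeRO, 55 GB with stage 1 and 37.5 GB with stage 2, which is why sharding is what makes this parameter range practical.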
Install
pip install SwissArmyTransformer
Features
- Add model-agnostic components, e.g. prefix-tuning, in just ONE line!
- Prefix-tuning (or P-tuning) improves finetuning by adding trainable parameters to each attention layer. Applying it to a GLM classification model (or any other model) is easy with our library:
class ClassificationModel(GLMModel): # can also be BertModel, RobertaModel, etc.
    def __init__(self, args, transformer=None, **kwargs):
        super().__init__(args, transformer=transformer, **kwargs)
        self.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))
        # Arm an arbitrary model with Prefix-tuning with this line!
        self.add_mixin('prefix-tuning', PrefixTuningMixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.prefix_len))
- GPT and other auto-regressive models behave differently during training and inference. During inference, text is generated token by token, and previous states must be cached for efficiency. With our library, you only need to consider the behavior during training (teacher forcing) and can turn it into a cached auto-regressive model by adding a mixin:
model, args = AutoModel.from_pretrained('glm-10b-chinese', args)
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
# Generate a sequence with beam search
from sat.generation.autoregressive_sampling import filling_sequence
from sat.generation.sampling_strategies import BeamSearchStrategy
output, *mems = filling_sequence(model, input_seq,
                    batch_size=args.batch_size,
                    strategy=BeamSearchStrategy(args.batch_size))
- Build your Transformer-based model with minimal code. Take GLM, mentioned above: it differs from the standard transformer (called BaseModel) only in its position embedding (and training losses), so we only need to focus on that part when coding.
The whole definition, extending BaseModel:
import torch
from sat.model import BaseModel, BaseMixin

class BlockPositionEmbeddingMixin(BaseMixin):
    # Here define parameters for the mixin
    def __init__(self, max_sequence_length, hidden_size, init_method_std=0.02):
        super(BlockPositionEmbeddingMixin, self).__init__()
        self.max_sequence_length = max_sequence_length
        self.hidden_size = hidden_size
        self.block_position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)
        torch.nn.init.normal_(self.block_position_embeddings.weight, mean=0.0, std=init_method_std)

    # Here define the method for the mixin
    def position_embedding_forward(self, position_ids, **kwargs):
        position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1]
        position_embeddings = self.transformer.position_embeddings(position_ids)
        block_position_embeddings = self.block_position_embeddings(block_position_ids)
        return position_embeddings + block_position_embeddings

class GLMModel(BaseModel):
    def __init__(self, args, transformer=None):
        super().__init__(args, transformer=transformer)
        self.add_mixin('block_position_embedding',
            BlockPositionEmbeddingMixin(args.max_sequence_length, args.hidden_size)
        ) # Add the mixin for GLM
- Comprehensive support for training.
sat aims to provide the best practice for pretraining and finetuning: you only need to implement forward_step and create_dataset_function, and can alter useful training configurations through hyperparameters.
- Extend the training to multiple GPUs or nodes by specifying --num_nodes, --num_gpus and a simple hostfile.
- DeepSpeed and model parallelism.
- Better integration of ZeRO-2 and activation checkpointing.
- Automatic extension and shuffling of training data, and memmap support.
- Successfully supports the training of CogView2 and CogVideo.
- Currently the only open-source codebase that supports finetuning T5-10B on GPUs.
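The auto-regressive bullet above says that inference needs cached previous states. The idea can be illustrated without sat: the toy below runs single-head causal attention either by recomputing keys and values for the whole prefix at every step, or by projecting only the newest token and appending it to a key/value cache. This is a sketch of the general KV-caching technique, not CachedAutoregressiveMixin's implementation; the scalar "hidden states", the projection functions and feeding each output back as the next input are hypothetical simplifications.

```python
import math
from typing import List, Tuple

KEY_PROJECTIONS = [0]  # counts calls to project_k, to compare the cost of both modes


# Hypothetical toy projections on scalar hidden states (stand-ins for W_q, W_k, W_v).
def project_q(x: float) -> float:
    return 0.5 * x


def project_k(x: float) -> float:
    KEY_PROJECTIONS[0] += 1
    return 1.5 * x - 0.2


def project_v(x: float) -> float:
    return x + 1.0


def attend(q: float, keys: List[float], values: List[float]) -> float:
    """Softmax attention of one query over all causally visible keys."""
    scores = [q * k for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def step_uncached(tokens: List[float]) -> float:
    """Recompute K and V for the entire prefix: O(n) projections per step."""
    keys = [project_k(t) for t in tokens]
    values = [project_v(t) for t in tokens]
    return attend(project_q(tokens[-1]), keys, values)


def step_cached(token: float, cache: Tuple[List[float], List[float]]) -> float:
    """Project only the newest token and append it to the cache: O(1) projections per step."""
    keys, values = cache
    keys.append(project_k(token))
    values.append(project_v(token))
    return attend(project_q(token), keys, values)


def generate(first: float, steps: int, cached: bool) -> List[float]:
    """Toy decoding loop: each output is fed back as the next input."""
    tokens = [first]
    cache: Tuple[List[float], List[float]] = ([], [])
    for _ in range(steps):
        out = step_cached(tokens[-1], cache) if cached else step_uncached(tokens)
        tokens.append(out)
    return tokens
```

Both modes produce the same sequence, but over 5 steps the uncached loop projects 1 + 2 + 3 + 4 + 5 = 15 keys while the cached loop projects 5. In a real model the cache holds per-layer key/value tensors, but the bookkeeping is the same.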
Quick Tour
A typical Python file for using BERT in sat (for inference) looks like this:
# @File: inference_bert.py
from sat import get_args, get_tokenizer, AutoModel
# Parse args, initialize the environment. This is necessary.
args = get_args()
# Automatically download and load model. Will also dump model-related hyperparameters to args.
model, args = AutoModel.from_pretrained('bert-base-uncased', args)
# Get the BertTokenizer according to args.tokenizer_type (automatically set).
tokenizer = get_tokenizer(args)
# Use BERT here as you want!
# ...
Then we can run the code via
SAT_HOME=/path/to/download python inference_bert.py --mode inference
All officially supported model names are in urls.py.
Finetuning or pretraining a transformer is also extremely easy!
# @File: finetune_bert.py
import torch
from sat import get_args, get_tokenizer, AutoModel
from sat.model.mixins import MLPHeadMixin
from sat.training.deepspeed_training import training_main
def create_dataset_function(path, args):
# Here to load the dataset
# ...
assert isinstance(dataset, torch.utils.data.Dataset)
return dataset
def forward_step(data_iterator, model, args, timers):
inputs = next(data_iterator) # from the dataset of create_dataset_function.
loss, *others = model(inputs)
return loss
# Parse args, initialize the environment. This is necessary.
args = get_args()
model, args = AutoModel.from_pretrained('bert-base-uncased', args)
tokenizer = get_tokenizer(args)
# Replace BERT's final mixin with a classification head.
model.del_mixin('bert-final')
model.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))
# ONE LINE to train!
# args already includes hyperparams such as lr, train-iters, zero-stage ...
training_main(args,
model_cls=model,
forward_step_function=forward_step, # user define
create_dataset_function=create_dataset_function # user define
)
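To make the division of labour concrete, here is a framework-free sketch of the control flow a trainer like training_main takes care of, written as an assumption rather than sat's code: the trainer builds the dataset once through create_dataset_function, wraps it in an endlessly reshuffled iterator, and calls forward_step once per iteration with the same (data_iterator, model, args, timers) arguments as above. toy_training_main, infinite_iterator and the toy dataset and model are hypothetical; backward passes, optimizers, DeepSpeed and logging are deliberately left out.

```python
import random
from types import SimpleNamespace
from typing import Callable, Iterator, List, Sequence


def infinite_iterator(dataset: Sequence, seed: int) -> Iterator:
    """Yield samples forever, reshuffling the order at the start of every pass."""
    if len(dataset) == 0:
        raise ValueError("dataset is empty")
    rng = random.Random(seed)
    order = list(range(len(dataset)))
    while True:
        rng.shuffle(order)
        for i in order:
            yield dataset[i]


def toy_training_main(args, model, forward_step_function: Callable,
                      create_dataset_function: Callable) -> List[float]:
    """Hypothetical trainer: it owns the loop and only calls the two user functions."""
    dataset = create_dataset_function(args.train_data, args)
    data_iterator = infinite_iterator(dataset, seed=args.seed)
    losses = []
    for _ in range(args.train_iters):
        loss = forward_step_function(data_iterator, model, args, None)  # timers omitted
        # A real trainer would run backward(), the optimizer step and logging here.
        losses.append(loss)
    return losses


# --- User side: the only two functions you have to write. ---
def create_dataset_function(path, args):
    # Toy (x, y) pairs with y = 2x + 1; a real implementation would read `path`.
    return [(x, 2 * x + 1) for x in range(4)]


def forward_step(data_iterator, model, args, timers):
    x, y = next(data_iterator)
    return (model(x) - y) ** 2  # squared error as the loss


args = SimpleNamespace(train_data="/path/to/train", train_iters=10, seed=0)
losses = toy_training_main(args, lambda x: 2 * x, forward_step, create_dataset_function)
```

Because the toy model always predicts 2x, every sample's loss is (2x - (2x + 1)) squared, i.e. 1, so losses is a list of ten ones no matter how the data is shuffled.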
Then we can run the code via
deepspeed --include localhost:0,1 finetune_bert.py \
--experiment-name ftbert \
--mode finetune --train-iters 1000 --save /path/to/save \
--train-data /path/to/train --valid-data /path/to/valid \
--lr 0.00002 --batch-size 8 --zero-stage 1 --fp16
Here we use data parallelism on GPUs 0 and 1. We can also launch the training on many interconnected machines via --hostfile /path/to/hostfile. See the tutorial for more details.
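Since the example above is started with the deepspeed launcher, the hostfile presumably follows DeepSpeed's format: one machine per line as hostname slots=N, where N is the number of GPUs on that machine. The sketch below parses such a file and derives the total number of processes; parse_hostfile and the worker names are hypothetical, and the launcher does its own parsing.

```python
from typing import Dict


def parse_hostfile(text: str) -> Dict[str, int]:
    """Parse DeepSpeed-style hostfile lines of the form 'hostname slots=N'."""
    hosts: Dict[str, int] = {}
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()  # drop comments and surrounding whitespace
        if not line:
            continue
        hostname, slots_field = line.split()
        key, value = slots_field.split("=")
        if key != "slots":
            raise ValueError(f"expected 'slots=N', got {slots_field!r}")
        if hostname in hosts:
            raise ValueError(f"duplicate host {hostname!r}")
        hosts[hostname] = int(value)
    return hosts


example = """
# two machines with 8 GPUs each
worker-1 slots=8
worker-2 slots=8
"""
hosts = parse_hostfile(example)
world_size = sum(hosts.values())  # total number of GPU processes across machines
```

With model parallelism enabled, the data-parallel size would be world_size divided by the model-parallel size.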
To write your own model, you only need to consider how it differs from the standard Transformer. For example, if you have an idea for improving the attention operation:
import torch
from sat.model import BaseMixin

class MyAttention(BaseMixin):
    def __init__(self, hidden_size):
        super(MyAttention, self).__init__()
        # MyAttention may need some new params, e.g. a learnable alpha.
        self.learnable_alpha = torch.nn.Parameter(torch.ones(hidden_size))

    # This is a hook function; the name `attention_fn` is special.
    def attention_fn(self, q, k, v, mask, dropout=None, **kwargs):
        # Code for my attention.
        # ...
        return attention_results
Here attention_fn is a hook function: it replaces the default action with the new function. All available hooks are listed in transformer_defaults.py.
Now we can use add_mixin to apply our change to all the transformers, such as BERT, ViT and CogView. See the tutorial for more details.
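The hook mechanism can be sketched without sat. The toy below is an assumption about the general pattern, not sat's actual code: a backbone with a default attention_fn, an add_mixin that records any mixin method whose name matches a known hook, a del_mixin that removes those hooks again, and a forward pass that dispatches to the registered hook when there is one. ToyBackbone, ToyMixin, ScaledAttention and HOOK_NAMES are hypothetical names.

```python
from typing import Callable, Dict, List

# Hypothetical list of overridable hook names (sat keeps its real list in transformer_defaults.py).
HOOK_NAMES = ("attention_fn", "position_embedding_forward")


class ToyMixin:
    """Base class for mixins: a subclass overrides any method named in HOOK_NAMES."""


class ToyBackbone:
    def __init__(self) -> None:
        self.mixins: Dict[str, ToyMixin] = {}
        self.hooks: Dict[str, Callable] = {}

    def attention_fn(self, q: List[float], k: List[float], v: float) -> float:
        # Default behaviour: dot-product score times the value.
        return sum(a * b for a, b in zip(q, k)) * v

    def add_mixin(self, name: str, mixin: ToyMixin) -> None:
        if name in self.mixins:
            raise ValueError(f"mixin {name!r} already exists")
        # Check for conflicts first so a failed add leaves no partial state.
        for hook in HOOK_NAMES:
            if hasattr(mixin, hook) and hook in self.hooks:
                raise ValueError(f"hook {hook!r} is already overridden")
        self.mixins[name] = mixin
        for hook in HOOK_NAMES:
            if hasattr(mixin, hook):
                self.hooks[hook] = getattr(mixin, hook)  # store the bound method

    def del_mixin(self, name: str) -> None:
        mixin = self.mixins.pop(name)
        owned = [h for h, fn in self.hooks.items() if getattr(fn, "__self__", None) is mixin]
        for hook in owned:
            del self.hooks[hook]

    def forward(self, q: List[float], k: List[float], v: float) -> float:
        # Dispatch to the registered hook if any, otherwise fall back to the default.
        attention = self.hooks.get("attention_fn", self.attention_fn)
        return attention(q, k, v)


class ScaledAttention(ToyMixin):
    """A mixin that only changes attention: it scales the default score by alpha."""

    def __init__(self, alpha: float) -> None:
        self.alpha = alpha

    def attention_fn(self, q: List[float], k: List[float], v: float) -> float:
        return self.alpha * sum(a * b for a, b in zip(q, k)) * v
```

Because the backbone never needs to know which mixins exist, the same ScaledAttention could be added to any model built on ToyBackbone, which is the property described above.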
Tutorials
Citation
Currently we don't have a paper, so you don't need to formally cite us!~
If this project helps your research or engineering, use \footnote{https://github.com/THUDM/SwissArmyTransformer} to mention us and recommend SwissArmyTransformer to others.
A tutorial on contributing to sat is on the way!
The project is based on (and is a user of) DeepSpeed, Megatron-LM and Hugging Face Transformers. Thanks for their awesome work.
Download files
File details
Details for the file SwissArmyTransformer-0.4.12.tar.gz.
File metadata
- Download URL: SwissArmyTransformer-0.4.12.tar.gz
- Upload date:
- Size: 2.4 MB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.7.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 79b469b5950cd8b00e11575b120e888baf825b3d308d212670712aace3fa7253
MD5 | 83f2acd350d0b959d90205992cf4f483
BLAKE2b-256 | 58380349a801aafbdced1e8daf76ca1d7d38e8a6e056cdc2ae9f1b8eeb593a17
File details
Details for the file SwissArmyTransformer-0.4.12-py3-none-any.whl.
File metadata
- Download URL: SwissArmyTransformer-0.4.12-py3-none-any.whl
- Upload date:
- Size: 2.4 MB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.7.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 04aa539e3cadda1a1690d26d4f77648d56ddd789da1fcf4cc297e8a098e3f6c6
MD5 | 939289b6ed6a9e0a9520a4ca9ca79c99
BLAKE2b-256 | 18a44937b81c446732cf17803d876fa15174ae5a6ab492fe90183419ccfac33f