
A Transformer-based framework with finetuning as a first-class citizen.

Project description

Introduction

sat (SwissArmyTransformer) is a flexible and powerful library for developing your own Transformer variants.

sat is named after the "Swiss army knife", meaning that all the models (e.g. BERT, GPT, T5, GLM, CogView, ViT, ...) share the same backbone code and cater to versatile usages through some extra lightweight mixins.

sat is powered by DeepSpeed-ZeRO and model parallelism, aiming to provide the best practice for pretraining and finetuning large models (100M~20B parameters).

Install

    pip install SwissArmyTransformer

Features

  • Add model-agnostic components, e.g. prefix-tuning, in just ONE line!

    • Prefix-tuning (or P-tuning) improves finetuning by adding trainable parameters to each attention layer. Applying it to a GLM classification model (or any other model) is easy with our library.
        class ClassificationModel(GLMModel): # can also be BertModel, RobertaModel, etc. 
            def __init__(self, args, transformer=None, **kwargs):
                super().__init__(args, transformer=transformer, **kwargs)
                self.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))
                # Arm an arbitrary model with Prefix-tuning with this line!
                self.add_mixin('prefix-tuning', PrefixTuningMixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.prefix_len))
    
    • GPT and other auto-regressive models act differently during training and inference. During inference, text is generated token by token and we need to cache previous states for efficiency. With our library, you only need to consider the behavior during training (teacher forcing) and can transform the model into a cached auto-regressive one by adding a mixin:
        from sat.model.mixins import CachedAutoregressiveMixin

        model, args = AutoModel.from_pretrained('glm-10b-chinese', args)
        model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
        # Generate a sequence with beam search
        from sat.generation.autoregressive_sampling import filling_sequence
        from sat.generation.sampling_strategies import BeamSearchStrategy
        output, *mems = filling_sequence(model, input_seq,
                        batch_size=args.batch_size,
                        strategy=BeamSearchStrategy(args.batch_size))
    
  • Build your Transformer-based model with minimal code. We mentioned GLM, which differs from the standard Transformer (called BaseModel) only in its position embedding (and training losses). We only need to focus on the related parts when coding.

    The whole definition is as follows:

    import torch
    from sat.model import BaseModel, BaseMixin

    class BlockPositionEmbeddingMixin(BaseMixin):
        # Define the extra parameters of the mixin here
        def __init__(self, max_sequence_length, hidden_size, init_method_std=0.02):
            super(BlockPositionEmbeddingMixin, self).__init__()
            self.max_sequence_length = max_sequence_length
            self.hidden_size = hidden_size
            self.block_position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)
            torch.nn.init.normal_(self.block_position_embeddings.weight, mean=0.0, std=init_method_std)

        # Define the hook method of the mixin here
        def position_embedding_forward(self, position_ids, **kwargs):
            position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1]
            position_embeddings = self.transformer.position_embeddings(position_ids)
            block_position_embeddings = self.block_position_embeddings(block_position_ids)
            return position_embeddings + block_position_embeddings

    class GLMModel(BaseModel):
        def __init__(self, args, transformer=None):
            super().__init__(args, transformer=transformer)
            self.add_mixin('block_position_embedding',
                BlockPositionEmbeddingMixin(args.max_sequence_length, args.hidden_size)
            ) # Add the mixin for GLM
    
  • Comprehensive support for training. sat aims to provide the best practice for pretraining and finetuning: you only need to implement forward_step and create_dataset_function, while hyperparameters let you alter useful training configurations.

    • Extend the training to multiple GPUs or nodes by specifying --num_nodes, --num_gpus and a simple hostfile (see the example after this list).
    • DeepSpeed and model parallelism.
    • Better integration of ZeRO-2 and activation checkpointing.
    • Automatic extension and shuffling of training data, plus memmap support.
    • Successfully supports the training of CogView2 and CogVideo.
    • Currently the only open-source codebase supporting finetuning T5-10B on GPUs.
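
For example, a two-node launch might look like the following sketch. It uses DeepSpeed's standard hostfile format ("hostname slots=N"); the host names are placeholders, and finetune_bert.py refers to the script from the Quick Tour below.

    # /path/to/hostfile (DeepSpeed hostfile: one machine per line)
    node1 slots=8
    node2 slots=8

    deepspeed --hostfile /path/to/hostfile --num_nodes 2 --num_gpus 8 \
        finetune_bert.py --mode finetune ...  # plus the usual training arguments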

Quick Tour

A typical Python file for using BERT in sat (for inference) looks like this:

# @File: inference_bert.py
from sat import get_args, get_tokenizer, AutoModel
# Parse args, initialize the environment. This is necessary.
args = get_args() 
# Automatically download and load model. Will also dump model-related hyperparameters to args.
model, args = AutoModel.from_pretrained('bert-base-uncased', args) 
# Get the BertTokenizer according to args.tokenizer_type (automatically set).
tokenizer = get_tokenizer(args) 
# Here to use bert as you want!
# ...

Then we can run the code via

    SAT_HOME=/path/to/download python inference_bert.py --mode inference

All officially supported model names are in urls.py.

Finetuning or pretraining a transformer is also extremely easy!

# @File: finetune_bert.py
import torch
from sat import get_args, get_tokenizer, AutoModel
from sat.model.mixins import MLPHeadMixin
from sat.training.deepspeed_training import training_main

def create_dataset_function(path, args):
    # Here to load the dataset
    # ...
    assert isinstance(dataset, torch.utils.data.Dataset)
    return dataset

def forward_step(data_iterator, model, args, timers):
    inputs = next(data_iterator) # from the dataset of create_dataset_function.
    loss, *others = model(inputs)
    return loss
    
# Parse args, initialize the environment. This is necessary.
args = get_args() 
model, args = AutoModel.from_pretrained('bert-base-uncased', args) 
tokenizer = get_tokenizer(args) 
# Here to use bert as you want!
model.del_mixin('bert-final')
model.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))
# ONE LINE to train! 
# args already includes hyperparams such as lr, train-iters, zero-stage ...
training_main(args, 
    model_cls=model, 
    forward_step_function=forward_step, # user define
    create_dataset_function=create_dataset_function # user define
)

Then we can run the code via

deepspeed --include localhost:0,1 finetune_bert.py \
    --experiment-name ftbert \
    --mode finetune --train-iters 1000 --save /path/to/save \
    --train-data /path/to/train --valid-data /path/to/valid \
    --lr 0.00002 --batch-size 8 --zero-stage 1 --fp16

Here we use data parallelism on GPUs 0 and 1. We can also launch the training on many inter-connected machines via --hostfile /path/to/hostfile. See the tutorial for more details.
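
For concreteness, here is a minimal sketch of what the create_dataset_function and forward_step stubs above might look like for a binary classification task with the MLPHeadMixin head. The pre-tokenized file layout, the field names, and the model call signature model(input_ids, position_ids, attention_mask) are illustrative assumptions, not a required sat API; adapt them to your data.

import torch

def create_dataset_function(path, args):
    # Hypothetical layout: a torch.save'd dict of pre-tokenized tensors.
    data = torch.load(path)
    dataset = torch.utils.data.TensorDataset(
        data['input_ids'], data['position_ids'], data['attention_mask'], data['labels'])
    return dataset

def forward_step(data_iterator, model, args, timers):
    # A DataLoader over the TensorDataset yields one list of tensors per batch.
    input_ids, position_ids, attention_mask, labels = next(data_iterator)
    # (In a real script you would also move the batch to the right device/dtype here.)
    logits, *mems = model(input_ids, position_ids, attention_mask)
    # The MLPHeadMixin above outputs a single logit per sequence.
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        logits.squeeze(-1), labels.float())
    return loss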

To write your own model, you only need to consider the differences from the standard Transformer. For example, if you have an idea to improve the attention operation:

import torch
from sat.model import BaseMixin

class MyAttention(BaseMixin):
    def __init__(self, hidden_size):
        super(MyAttention, self).__init__()
        # MyAttention may need some new params, e.g. a learnable alpha.
        self.learnable_alpha = torch.nn.Parameter(torch.ones(hidden_size))

    # This is a hook function; the name `attention_fn` is special.
    def attention_fn(self, q, k, v, mask, dropout=None, **kwargs):
        # Code for my attention.
        # ...
        return attention_results

Here attention_fn is a hook function, which replaces the default behavior with the new function. All available hooks are in transformer_defaults.py. Now we can use add_mixin to apply our change to all Transformer-based models, such as BERT, ViT and CogView. See the tutorial for more details.
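
As a usage sketch (the model name and args follow the Quick Tour above), attaching the new attention to a pretrained model is again a single add_mixin call:

model, args = AutoModel.from_pretrained('bert-base-uncased', args)
model.add_mixin('my-attention', MyAttention(args.hidden_size))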

Tutorials

Citation

Currently we don't have a paper, so you don't need to formally cite us!

If this project helps your research or engineering, use \footnote{https://github.com/THUDM/SwissArmyTransformer} to mention us and recommend SwissArmyTransformer to others.

The tutorial for contributing to sat is on the way!

The project is based on (and is a user of) DeepSpeed, Megatron-LM and Hugging Face Transformers. Thanks for their awesome work.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

SwissArmyTransformer-0.4.12.tar.gz (2.4 MB)

Uploaded Source

Built Distribution

SwissArmyTransformer-0.4.12-py3-none-any.whl (2.4 MB)

Uploaded Python 3

File details

Details for the file SwissArmyTransformer-0.4.12.tar.gz.

File metadata

  • Download URL: SwissArmyTransformer-0.4.12.tar.gz
  • Upload date:
  • Size: 2.4 MB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.7.4

File hashes

Hashes for SwissArmyTransformer-0.4.12.tar.gz
Algorithm Hash digest
SHA256 79b469b5950cd8b00e11575b120e888baf825b3d308d212670712aace3fa7253
MD5 83f2acd350d0b959d90205992cf4f483
BLAKE2b-256 58380349a801aafbdced1e8daf76ca1d7d38e8a6e056cdc2ae9f1b8eeb593a17


File details

Details for the file SwissArmyTransformer-0.4.12-py3-none-any.whl.

File metadata

File hashes

Hashes for SwissArmyTransformer-0.4.12-py3-none-any.whl
Algorithm Hash digest
SHA256 04aa539e3cadda1a1690d26d4f77648d56ddd789da1fcf4cc297e8a098e3f6c6
MD5 939289b6ed6a9e0a9520a4ca9ca79c99
BLAKE2b-256 18a44937b81c446732cf17803d876fa15174ae5a6ab492fe90183419ccfac33f

