A Transformer-based framework with finetuning as a first-class citizen.
Introduction
SwissArmyTransformer is a flexible and powerful library for developing your own Transformer variants. It is named after the "Swiss Army knife", meaning that all the models (e.g. BERT, GPT, T5, GLM, CogView, ViT, ...) share the same backbone code and cater to versatile usages through some extra lightweight mixins. SwissArmyTransformer is powered by deepspeed-ZeRO and model parallelism, aiming to provide the best practice for pretraining and finetuning large models (100M~20B parameters).
Install
```
pip install SwissArmyTransformer
```
Features
- Add model-agnostic components, e.g. prefix-tuning, in just ONE line!
  - Prefix-tuning (or P-tuning) improves finetuning by adding trainable parameters in each attention layer. Applying it to a GLM classification (or any other) model is easy with our library:
```python
class ClassificationModel(GLMModel):
    def __init__(self, args, transformer=None, **kwargs):
        super().__init__(args, transformer=transformer, **kwargs)
        self.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))
        # Arm an arbitrary model with prefix-tuning with this line!
        self.add_mixin('prefix-tuning', PrefixTuningMixin(args.num_layers,
            args.hidden_size // args.num_attention_heads,
            args.num_attention_heads, args.prefix_len))
```
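For intuition, here is a minimal single-head sketch in plain PyTorch of what prefix-tuning does inside an attention layer. This is a conceptual illustration, not the internals of `PrefixTuningMixin`; `ToyPrefixAttention` is a hypothetical name.

```python
import torch
import torch.nn.functional as F

class ToyPrefixAttention(torch.nn.Module):
    """Conceptual single-head self-attention with trainable prefix keys/values."""
    def __init__(self, hidden_size, prefix_len):
        super().__init__()
        self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.k_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.v_proj = torch.nn.Linear(hidden_size, hidden_size)
        # The only newly trained parameters: prefix keys/values that get
        # prepended to every sequence's projected keys and values.
        self.prefix_k = torch.nn.Parameter(torch.randn(prefix_len, hidden_size) * 0.02)
        self.prefix_v = torch.nn.Parameter(torch.randn(prefix_len, hidden_size) * 0.02)

    def forward(self, x):  # x: [batch, seq_len, hidden_size]
        b = x.size(0)
        q = self.q_proj(x)
        k = torch.cat([self.prefix_k.expand(b, -1, -1), self.k_proj(x)], dim=1)
        v = torch.cat([self.prefix_v.expand(b, -1, -1), self.v_proj(x)], dim=1)
        scores = q @ k.transpose(-1, -2) / x.size(-1) ** 0.5
        return F.softmax(scores, dim=-1) @ v
```

Freezing the pretrained weights and training only the prefixes is what makes this kind of tuning lightweight.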
  - GPT and other auto-regressive models act differently during training and inference. During inference, text is generated token by token, and we need to cache previous states for efficiency. With our library, you only need to consider the behavior during training (teacher forcing) and transform it into a cached auto-regressive model by adding a mixin:
```python
model = GLMModel(args)
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())

# Generate a sequence with beam search
from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
from SwissArmyTransformer.generation.sampling_strategies import BeamSearchStrategy

output, *mems = filling_sequence(model, input_seq,
                                 batch_size=args.batch_size,
                                 strategy=BeamSearchStrategy(args.batch_size))
```
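To see why caching matters, below is a toy greedy decoder in plain PyTorch. The `model(ids, past=...)` interface is a hypothetical stand-in that only illustrates the pattern `CachedAutoregressiveMixin` automates (the library itself is driven through `filling_sequence` as above).

```python
import torch

@torch.no_grad()
def greedy_decode(model, input_ids, max_new_tokens):
    """Toy greedy decoding with a key/value cache.

    Assumes a hypothetical `model(ids, past=...)` returning (logits, new_past).
    """
    past, ids = None, input_ids
    for _ in range(max_new_tokens):
        # With a cache, each step only needs a forward pass on the newest token.
        step_input = ids if past is None else ids[:, -1:]
        logits, past = model(step_input, past=past)
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=-1)
    return ids
```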
- Build your Transformer-based model with minimal code. We mentioned GLM, which differs from the standard transformer (called BaseModel) only in position embedding (and training losses). We only need to focus on the related part when coding. Here is the whole definition:
```python
class BlockPositionEmbeddingMixin(BaseMixin):
    # Here define the parameters of the mixin
    def __init__(self, max_sequence_length, hidden_size, init_method_std=0.02):
        super(BlockPositionEmbeddingMixin, self).__init__()
        self.max_sequence_length = max_sequence_length
        self.hidden_size = hidden_size
        self.block_position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)
        torch.nn.init.normal_(self.block_position_embeddings.weight, mean=0.0, std=init_method_std)

    # Here define the method (hook) of the mixin
    def position_embedding_forward(self, position_ids, **kwargs):
        position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1]
        position_embeddings = self.transformer.position_embeddings(position_ids)
        block_position_embeddings = self.block_position_embeddings(block_position_ids)
        return position_embeddings + block_position_embeddings

class GLMModel(BaseModel):
    def __init__(self, args, transformer=None, parallel_output=True):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output)
        # Add the mixin for GLM
        self.add_mixin('block_position_embedding',
                       BlockPositionEmbeddingMixin(args.max_sequence_length, args.hidden_size))

# We can also directly define hook-functions in the model.
# E.g., the code below would remove position embeddings:
#
# def position_embedding_forward(self, position_ids, **kwargs):
#     return 0
```
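The mixin mechanism boils down to hook lookup: when the backbone reaches a customizable step such as `position_embedding_forward`, an override defined on a mixin (or on the model itself) takes precedence over the default implementation. Below is a rough sketch of that dispatch idea; `resolve_hook` and the `mixins` registry are hypothetical names, not the library's exact code.

```python
def resolve_hook(model, name, default):
    # Prefer a hook defined on any registered mixin, then on the model itself,
    # and finally fall back to the backbone's default implementation.
    for mixin in model.mixins.values():  # assumes a name -> mixin mapping
        if hasattr(mixin, name):
            return getattr(mixin, name)
    return getattr(model, name, default)
```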
- Comprehensive support for training. SwissArmyTransformer aims to provide the best practice for pretraining and finetuning, where you only need to finish `forward_step` and `create_dataset_function`, with hyperparameters to alter useful training configurations (see the sketch after this list).
  - Extend the training to multiple GPUs or nodes by specifying `--num_nodes`, `--num_gpus` and a simple `hostfile`.
  - DeepSpeed and model parallelism.
  - Better integration of ZeRO-2 and activation checkpointing.
  - Automatic extending and shuffling of training data, with `memmap` support.
  - Successfully supports the training of CogView2.
  - Currently the only open-source codebase supporting finetuning T5-10B on GPUs.
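As referenced above, here is a schematic of the two user-supplied functions. The exact signatures and the `training_main` entry point are assumptions for this sketch and may differ across library versions.

```python
import torch.nn.functional as F
from SwissArmyTransformer import training_main  # assumed entry point

def create_dataset_function(path, args):
    # Build and return a torch Dataset (possibly memory-mapped) from `path`.
    ...

def forward_step(data_iterator, model, args, timers):
    # Fetch a batch, run the model, and return the training loss.
    batch = next(data_iterator)
    logits, *_ = model(batch['input_ids'], batch['position_ids'], batch['attention_mask'])
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), batch['labels'].view(-1))
    return loss, {}

# Hypothetical wiring:
# training_main(args, model_cls=ClassificationModel,
#               forward_step_function=forward_step,
#               create_dataset_function=create_dataset_function)
```

For multi-node runs, the `hostfile` follows the DeepSpeed launcher convention of one `hostname slots=N` line per machine.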
Get started
```
cd examples/cogview2
./scripts/text2image_cogview2.sh
```
Run GLM
- Prepare input.txt. Example: "Welcome! This is the main page of SwissArmyTransformer".
- Run the following commands:
```
cd examples/glm
./scripts/generate_glm.sh config/model_glm_10B_chinese.sh
```
Output: [CLS]Welcome! This is the main page of SwissArmyTransformer. It is a comprehensive and clear explanation of the technical problems in the transformer. It is also an introduction to the development of the SwissArmy transformers. Welcome to Swiss Army Transforters. This is the main page of Swiss army tranforter. It's a complete and clean explaination of technology problem in the Tranformer, which is an integral part of the army's technological development. It also anintroduction of the developments of the Army technicians. Well, if you have any questions, please feel free to contact the official webs
Download files
Download the file for your platform.
Source Distribution: SwissArmyTransformer-0.1.5.tar.gz
Built Distribution: SwissArmyTransformer-0.1.5-py3-none-any.whl
Hashes for SwissArmyTransformer-0.1.5.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | e56f04b3d69a4db79d8ea6df081c2f4b08c1e39b40c449046293ab93db17239a |
| MD5 | 89dfdd1600e77489a6097e2787d94275 |
| BLAKE2b-256 | b34f1ba03955fd519eb4f1d9ddbf0a60aab1ff675342a3f681dbf475587883ab |
Hashes for SwissArmyTransformer-0.1.5-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | 6f02cf847f421c4a4860c484dd31ff5289d5395039bfb40b396767475aef7294 |
| MD5 | 2fae90e04b2e72c093e5d433d5458200 |
| BLAKE2b-256 | 1c5b4888c80420df64f3149b651d2501634781912d2b4392c65e477aac4bda4e |