FMS Acceleration Plugin for Attention and Distributed Packing Optimizations
This library contains plugins to accelerate finetuning with the following optimizations:
- Padding-Free Flash Attention Computation
- Multipack Distributed Sampling
Plugins
Plugin | Description | Depends | Loading | Augmentation | Callbacks |
---|---|---|---|---|---|
padding_free | Padding-Free Flash Attention Computation | flash_attn | ✅ | | |
multipack sampler | Multipack Distributed Sampling | numba | ✅ | | |
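To give a feel for what packing samples into batches involves, here is a minimal sketch of first-fit-decreasing bin packing over sequence lengths. This is an assumption about the general idea only, not the plugin's sampler: the function name `pack_first_fit_decreasing` and the `capacity` token budget are hypothetical, and the real sampler is distributed and numba-accelerated.

```python
def pack_first_fit_decreasing(lengths, capacity):
    """Group sample indices into packs whose total length fits `capacity`.

    Illustrative only: `capacity` is a hypothetical per-pack token budget.
    """
    # Visit samples from longest to shortest.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    packs = []  # each pack is a list of sample indices
    free = []   # remaining token budget of each pack
    for i in order:
        n = lengths[i]
        if n > capacity:
            raise ValueError(f"sample {i} has length {n} > capacity {capacity}")
        # Place the sample in the first pack that still has room.
        for b, room in enumerate(free):
            if n <= room:
                packs[b].append(i)
                free[b] -= n
                break
        else:
            # No existing pack fits, so open a new one.
            packs.append([i])
            free.append(capacity - n)
    return packs


packs = pack_first_fit_decreasing([5, 3, 4, 2, 6], capacity=8)
print(packs)
```

Each pack can then be flattened into a single padding-free sequence, which is consistent with the multipack plugin requiring the padding-free plugin.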
Native Transformers Support from v4.44.0
Transformers natively supports padding-free from v4.44.0. The padding-free plugin uses the transformers library when the installed version is compatible; if transformers < v4.44.0, the plugin uses its internal implementation instead.
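The version gate described above can be sketched as follows. This is an illustration under stated assumptions, not the plugin's actual check: the helper names `release_tuple`, `meets_minimum`, and `use_native` are hypothetical, and the comparison looks only at the leading numeric release components (so a pre-release such as `4.44.0.dev0` is treated like `4.44.0`).

```python
import re
from importlib.metadata import PackageNotFoundError, version


def release_tuple(version_string, width=3):
    """Leading numeric components of a version, zero-padded to `width`."""
    match = re.match(r"\d+(?:\.\d+)*", version_string)
    if match is None:
        raise ValueError(f"unrecognised version: {version_string!r}")
    parts = [int(p) for p in match.group(0).split(".")]
    parts += [0] * (width - len(parts))  # "4.45" -> (4, 45, 0)
    return tuple(parts)


def meets_minimum(installed, minimum):
    """True if `installed` is at least `minimum` (release components only)."""
    return release_tuple(installed) >= release_tuple(minimum)


def use_native(package, minimum):
    """True if `package` is installed at `minimum` or newer (hypothetical helper)."""
    try:
        return meets_minimum(version(package), minimum)
    except PackageNotFoundError:
        return False


# The thresholds below are the ones stated in this README.
print("native transformers padding-free:", use_native("transformers", "4.44.0"))
print("native TRL padding_free collator:", use_native("trl", "0.10.1"))
```

When a check like this returns `False`, the fallback path corresponds to the internal implementation mentioned above.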
Native TRL Support for PaddingFree with DataCollatorForCompletionOnlyLM from v0.10.1
Users can use PaddingFree with untokenized data from TRL >= v0.10.1. The flattening of inputs and the addition of position_ids to the batch are carried out inside DataCollatorForCompletionOnlyLM when the keyword padding_free is passed to the collator. The plugin uses the TRL library when the installed version is compatible; if trl < v0.10.1, the plugin uses its internal implementation instead.
If a user passes in a pretokenized dataset, the plugin still uses DataCollaterForFlattening in the collate_fn.
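The flattening step can be pictured with a small sketch in plain Python. This is not the code used by TRL, transformers, or this plugin; `flatten_batch` is a hypothetical helper, and the `cu_seqlens` field is included only to illustrate how sequence boundaries can be recovered without padding.

```python
def flatten_batch(sequences):
    """Concatenate tokenized examples into one padding-free sequence.

    Illustrative only: real collators return tensors and also handle labels.
    """
    input_ids = []
    position_ids = []
    cu_seqlens = [0]  # cumulative sequence lengths, marking example boundaries
    for seq in sequences:
        input_ids.extend(seq)
        # Positions restart at 0 for every example.
        position_ids.extend(range(len(seq)))
        cu_seqlens.append(cu_seqlens[-1] + len(seq))
    return {
        "input_ids": input_ids,
        "position_ids": position_ids,
        "cu_seqlens": cu_seqlens,
    }


batch = flatten_batch([[11, 12, 13], [21, 22]])
print(batch["input_ids"])     # [11, 12, 13, 21, 22]
print(batch["position_ids"])  # [0, 1, 2, 0, 1]
print(batch["cu_seqlens"])    # [0, 3, 5]
```

Because position_ids restart at each example, the attention computation can tell the examples apart even though no padding tokens separate them.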
Running Benchmarks
To reproduce the benchmarks, run the following commands:
Reproduce Padding Free on A100 80GB
tox -e run-benches -- "1 2" "4 8" benchmark_outputs scenarios-orca.yaml "none"
Reproduce MultiPack on A100 80GB
tox -e run-benches -- "2 4 8" "16 32 64" benchmark_outputs scenarios-orca.yaml "padding-free"
Known Issues
Currently Only Supports Multipack with Padding-Free
The multipack plugin currently also requires the padding-free plugin to work. This may change in the future if there is demand for multipack to work standalone without padding free.