
FMS Acceleration using Fused Operations and Kernels

Project description

FMS Acceleration for Fused Operations and Kernels

This library contains fused operations and custom kernels, and will be expanded over time. It currently contains the following:

  1. Fused operations and kernels extracted from unsloth.
    • Low-Rank Adapter Fused Operations
    • Fast RoPE Triton Kernels
    • Fast RMS LayerNorm Triton Kernels
    • Fast Cross Entropy Triton Kernels

Plugins

| Plugin | Description | Depends | Loading | Augmentation | Callbacks |
|---|---|---|---|---|---|
| fast_kernels | Enhanced version of fast_quantized_peft; also works for full fine-tuning (full-FT) and non-quantized PEFT | Contains extracted code | | | |

Supported DataType Settings

Compatibility Matrix with Mixed Precision

| torch_dtype | Mixed Precision | Full-FT-FOAK | PEFT-FOAK | QPEFT-FOAK |
|---|---|---|---|---|
| FLOAT16 | - | Compatible | Compatible | |
| FLOAT16 | FP16 | ValueError: Attempting to unscale FP16 gradients. (see here) | Compatible | Compatible |
| BFLOAT16 | - | Compatible | Compatible | |
| BFLOAT16 | BF16 | Compatible | Compatible | Less Performant |

NOTE: this chart is also a good reference for supported types, even for the non-FOAK case.
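The chart can be encoded as a simple lookup, for example to validate a training configuration before launching a job. The sketch below is hypothetical: the function name foak_dtype_status and the mode labels "full-ft", "peft", and "qpeft" (standing for the Full-FT-FOAK, PEFT-FOAK, and QPEFT-FOAK columns) are not part of fms-acceleration. It encodes only the two mixed-precision rows of the chart and reports any other combination as "Unknown" rather than guessing.

```python
from typing import Optional

# Hypothetical encoding of the mixed-precision rows of the compatibility chart.
# Keys are (torch_dtype, mixed_precision, mode); values are the chart entries.
_COMPAT = {
    ("float16", "fp16", "full-ft"): "ValueError: Attempting to unscale FP16 gradients.",
    ("float16", "fp16", "peft"): "Compatible",
    ("float16", "fp16", "qpeft"): "Compatible",
    ("bfloat16", "bf16", "full-ft"): "Compatible",
    ("bfloat16", "bf16", "peft"): "Compatible",
    ("bfloat16", "bf16", "qpeft"): "Less Performant",
}


def foak_dtype_status(torch_dtype: str, mixed_precision: Optional[str], mode: str) -> str:
    """Return the chart entry for a FOAK setting, or 'Unknown' if it is not encoded here."""
    # Normalize case so that e.g. "BFLOAT16"/"BF16" match the lowercase keys;
    # a missing mixed-precision setting is represented as "none".
    key = (torch_dtype.lower(), (mixed_precision or "none").lower(), mode.lower())
    return _COMPAT.get(key, "Unknown")
```

For example, foak_dtype_status("BFLOAT16", "BF16", "qpeft") returns "Less Performant", matching the last row of the chart.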

Code Extracted from Unsloth

Notes on the extraction of code from unsloth:

  • While unsloth is released under Apache 2.0, comments strewn throughout the code base indicate some exceptions; see an example here:
    it would require a commercial license if used to run on more than 4 GPUs ...
    
  • These exceptions appear to be located around the trainer improvements, see another example here.
  • These exceptions appear around the Feb 2024 release; we do not extract any code from a file in which such exceptions occur.
  • In its place, we adopt model patching, as opposed to unsloth's approach of rewriting the model. Our approach is novel and written completely from scratch.
  • We have also enabled dropout on the LoRA fused operations.
  • All extracted code predates the Feb 2024 release.
  • In the table below we record what was extracted, and the exact commit from which it was taken.

| Path | Description | Extracted From | Modifications | Date |
|---|---|---|---|---|
| fused_ops/unsloth_lora | QLoRA fast dequant, activation kernels | unsloth/main @ 1ecc0185 | | 28 Jan 2024 |
| fused_ops/unsloth_lora/bnb | BNB fast lora | unsloth/main @ 1ecc0185 | fast_lora.py | 28 Jan 2024 |
| fused_ops/unsloth_lora/gptq | GPTQ fast dequant (triton_v2) | jeromeku/main @ 2839d39 | fast_lora.py, triton/layers.py | 6 Feb 2024 |
| kernels/unsloth | Fast RMS, RoPE, CrossEnt kernels | unsloth/main @ 1ecc0185 | cross_entropy_loss.py, rms_layernorm.py | 28 Jan 2024 |

Supported Models

Model norm pos emb cross-ent fused_lora
LlamaForCausalLM
MistralForCausalLM
MixtralForCausalLM
GPTBigCodeForCausalLM
GraniteForCausalLM

Adding Support For A New Model

Adding support is relatively easy if you follow an existing template. In what follows, we use GraniteForCausalLM as an example.

  • Implement a get_mp_rules function for the new model that returns a list of ModelPatcherRule objects.
  • The logic that needs to change is the set of module classes on which the rules are triggered. Import these classes like so:
    from transformers.models.granite.modeling_granite import ( 
        GraniteAttention,
        GraniteMLP,
        GraniteRMSNorm,
    )
    
  • Replace the classes in the appropriate places in each ModelPatcherRule, in particular in its ModelPatcherTrigger, and give each rule a descriptive rule_id:
    ModelPatcherRule(
        rule_id="granite-rms",
        trigger=ModelPatcherTrigger(check=GraniteRMSNorm),
        forward=fast_rms_layernorm,
    )
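The rule-based patching described in these steps can be illustrated with a small, self-contained sketch. It does not use fms-acceleration: Rule, Trigger, get_rules, apply_rules, and the toy Module, RMSNorm, Attention, and Model classes are hypothetical stand-ins for ModelPatcherRule, ModelPatcherTrigger, get_mp_rules, and the Granite modules, chosen so that the example runs with only the standard library. It shows the two ideas the steps rely on: each rule fires on instances of a given module class, and the patch replaces the forward of those instances in place instead of rewriting the model class.

```python
import types
from dataclasses import dataclass
from typing import Callable, List


# Toy stand-ins for transformer modules (hypothetical, not transformers classes).
class Module:
    def __init__(self, *children):
        self.children = list(children)

    def modules(self):
        # Depth-first walk over this module and all of its descendants.
        yield self
        for child in self.children:
            yield from child.modules()


class RMSNorm(Module):
    def forward(self, x):
        return "slow_rms"


class Attention(Module):
    def forward(self, x):
        return "slow_attn"


class Model(Module):
    pass


@dataclass
class Trigger:
    check: type  # module class whose instances the rule fires on


@dataclass
class Rule:
    rule_id: str
    trigger: Trigger
    forward: Callable  # replacement forward; receives the module as `self`


def fast_rms_forward(self, x):
    return "fast_rms"


def get_rules() -> List[Rule]:
    # Analogue of get_mp_rules: one rule per module class to patch.
    return [Rule(rule_id="toy-rms", trigger=Trigger(check=RMSNorm), forward=fast_rms_forward)]


def apply_rules(model: Module, rules: List[Rule]) -> int:
    # This sketch rejects duplicate rule ids so that each rule is identifiable.
    ids = [r.rule_id for r in rules]
    if len(ids) != len(set(ids)):
        raise ValueError("rule_id values must be unique")
    patched = 0
    for module in model.modules():
        for rule in rules:
            if isinstance(module, rule.trigger.check):
                # Bind the new forward to this instance only; the class is untouched.
                module.forward = types.MethodType(rule.forward, module)
                patched += 1
    return patched
```

For example, apply_rules(Model(Attention(), RMSNorm(), RMSNorm()), get_rules()) patches the two norm instances and returns 2, while a freshly constructed RMSNorm() still uses the original forward, because only existing instances were modified.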
    

Running Liger Kernel Benchmarks

The scenarios-liger.yaml file runs full fine-tuning, LoRA PEFT, AutoGPTQ LoRA PEFT, and bitsandbytes LoRA PEFT, first with the Triton kernels (Fast RMS, RoPE, CrossEnt) as a baseline and then with the Liger kernel for LigerFusedLinearCrossEntropy together with Fast RMS and RoPE, so that the results can be compared. It runs only against Mistral and Llama models.

The benchmarks were run separately for each num_gpu entry. They can also be run together in a single command, but running them separately is more efficient.

tox -e run-benches -- 1 "4 8 16 32" benchmark_outputs_1 scenarios-liger.yaml none
tox -e run-benches -- 2 "8 16 32 64" benchmark_outputs_2 scenarios-liger.yaml none
tox -e run-benches -- 4 "16 32 64 128" benchmark_outputs_3 scenarios-liger.yaml none
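The three commands above follow a simple pattern: for each GPU count, the second argument is the single-GPU list 4 8 16 32 scaled by the number of GPUs, and the output directory is numbered by run. The sketch below regenerates the commands from that pattern. The helper names bench_command and all_commands are hypothetical, and reading the second argument as a list of batch sizes that scale with the GPU count is an inference from the commands shown, not documented behavior of the tox environment.

```python
import shlex
from typing import List

# Assumption: the second tox argument is a list of batch sizes that scale
# linearly with the GPU count, as in the three commands in this section.
BASE_BATCH_SIZES = [4, 8, 16, 32]
GPU_COUNTS = [1, 2, 4]


def bench_command(num_gpus: int, index: int, scenarios: str = "scenarios-liger.yaml") -> str:
    batch_sizes = " ".join(str(b * num_gpus) for b in BASE_BATCH_SIZES)
    args = ["tox", "-e", "run-benches", "--", str(num_gpus), batch_sizes,
            f"benchmark_outputs_{index}", scenarios, "none"]
    # shlex.join quotes the space-separated batch-size list as a single argument.
    return shlex.join(args)


def all_commands() -> List[str]:
    # One command per GPU count, each meant to be run separately.
    return [bench_command(n, i) for i, n in enumerate(GPU_COUNTS, start=1)]
```

shlex.join quotes the batch-size list with single quotes (for example '4 8 16 32'), which a POSIX shell treats the same way as the double quotes used in the commands above.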

Known Issues

  • Mixed precision (--fp16 or --bf16) should be used with fast_lora.
  • fast_lora has issues with FSDP V1 when the PEFT style of FSDP wrapping is used.
    • This is because the adapter's forward functions are bypassed in the fused ops.
    • For AutoGPTQ/QLoRA this is addressed by distributing the adapters using DDP so they will be unsharded in time for the fused ops.
  • fast_rope_embeddings does not work with position_ids; it appears that HF has deprecated passing these ids into the RoPE embedding methods.



Download files


Source Distributions

No source distribution files are available for this release.

Built Distribution


fms_acceleration_foak-0.4.0-py3-none-any.whl (64.4 kB)

Uploaded Python 3

File details

Details for the file fms_acceleration_foak-0.4.0-py3-none-any.whl.


File hashes

Hashes for fms_acceleration_foak-0.4.0-py3-none-any.whl
| Algorithm | Hash digest |
|---|---|
| SHA256 | a4339784ef87bf78211fc60b20cb97ccb2a73b7c7c350b2a1ae58541866e992d |
| MD5 | ea9dfb9c50b84d9c4a9201f1593b3df6 |
| BLAKE2b-256 | 18b2e04d0dc0c74ce2c81395bfea306ef205ca8bcbc6f7c6d2d2f70b4c0685c4 |


Provenance

The following attestation bundles were made for fms_acceleration_foak-0.4.0-py3-none-any.whl:

Publisher: build-and-publish.yml on foundation-model-stack/fms-acceleration

