Foundation Model Stack
Foundation Model Stack is a collection of components for development, inference, training, and tuning of foundation models leveraging PyTorch native components. For inference optimizations we aim to support PyTorch compile, accelerated transformers, and tensor parallelism. At training time we aim to support FSDP, accelerated transformers, and PyTorch compile. To enable these optimizations, we will provide reimplementations of several popular model architectures starting with Llama and GPT-BigCode.
Models Supported
| Model family | Inference | Tuning and Training |
|---|---|---|
| LLaMA | ✔ | ✔ |
| GPT-BigCode | ✔ | ✘ |
| RoBERTa | ✔ | ✘ |
Installation
We recommend running this on Python 3.11 and CUDA 12.1 for best performance, as the CPU overheads of the models are reduced significantly.
PyPI
```bash
pip install ibm-fms
```
Local
Requires PyTorch >= 2.1.
```bash
pip install -e .
```
For installing the development dependencies, use:
```bash
pip install ".[dev]"
```
For installing the Hugging Face specific dependencies, use:
```bash
pip install ".[hf]"
```
Inference
Approach
Our approach for inference optimization is to use PyTorch compile, accelerated transformers, and tensor parallelism. PyTorch compile (`torch.compile`) compiles the model code into optimized kernels; accelerated transformers uses `scaled_dot_product_attention` (SDPA) to speed up attention computation while saving memory; and tensor parallelism shards a model across GPUs, which is necessary for larger models.
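For reference, the computation that SDPA accelerates is softmax(QKᵀ/√d)·V. The plain-Python sketch below spells it out for a single attention head using lists of floats. It is illustrative only: PyTorch's SDPA dispatches to fused kernels and also handles batching, masks, and dropout, and the function name `naive_attention` is hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]

def naive_attention(q: Matrix, k: Matrix, v: Matrix, causal: bool = False) -> Matrix:
    """Single-head attention: softmax(q @ k^T / sqrt(d)) @ v, computed row by row."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out: Matrix = []
    for i, q_row in enumerate(q):
        # Scaled dot product of this query against every key.
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        if causal:
            # Mask keys at positions after the query so they get zero weight.
            scores = [s if j <= i else -math.inf for j, s in enumerate(scores)]
        # Numerically stable softmax: subtract the row maximum before exponentiating.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Output row is the attention-weighted sum of the value rows.
        out.append([sum(w * v_row[c] for w, v_row in zip(weights, v)) for c in range(len(v[0]))])
    return out

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(naive_attention(q, k, v, causal=True)[0])  # first query can only see itself: [1.0, 2.0]
```

On batched tensors, the corresponding PyTorch call is `torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)`, which avoids materializing the full score matrix when a memory-efficient or flash kernel is available.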
To enable the Llama models to compile, we had to reimplement RoPE encodings without complex numbers. With this change, Llama model inference is able to leverage model compilation for latency reduction.
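The idea can be illustrated without PyTorch. RoPE rotates each pair of features by a position-dependent angle; that rotation is often written as a complex multiplication, but it is equally a 2×2 real rotation using only cos and sin. The sketch below assumes the interleaved-pair layout and base 10000 from the original RoPE formulation; FMS's actual tensor layout and code may differ, and both function names are hypothetical.

```python
import cmath
import math
from typing import List

def rope_real(x: List[float], pos: int, base: float = 10000.0) -> List[float]:
    """Apply a rotary embedding to one vector using only real arithmetic.

    Consecutive pairs (x[2i], x[2i+1]) are rotated by the angle pos * base**(-2i/d).
    """
    d = len(x)
    if d % 2:
        raise ValueError("head dimension must be even")
    out = [0.0] * d
    for i in range(d // 2):
        theta = pos * base ** (-2 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        # 2x2 rotation matrix applied to the pair (a, b)
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out

def rope_complex(x: List[float], pos: int, base: float = 10000.0) -> List[float]:
    """Reference version: the same rotation written as complex multiplication."""
    d = len(x)
    out: List[float] = []
    for i in range(d // 2):
        theta = pos * base ** (-2 * i / d)
        z = complex(x[2 * i], x[2 * i + 1]) * cmath.exp(1j * theta)
        out.extend([z.real, z.imag])
    return out
```

The two functions agree up to floating-point rounding. A useful sanity check is that the dot product of a rotated query and a rotated key depends only on the distance between their positions, e.g. positions (2, 5) and (10, 13) give the same score.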
Inference latency
We measured inference latencies with a 1024-token prompt and generation of 256 tokens on AWS P4de instance nodes with 8 A100 80GB GPUs, and report the median latency in the table below.
| Model | # GPUs | Median latency (ms) |
|---|---|---|
| 7B | 1 | 14 |
| 13B | 1 | 22 |
| 70B | 8 | 30 |
If you would like to reproduce the latencies, you can run `scripts/benchmark_inference.py`; the details are described in the inference documentation.
For more information on reproducing the benchmarks and running some examples, see here
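Reporting the median rather than the mean keeps one-off costs, such as compilation on the first call or an occasional slow step, from skewing the number. A generic timing helper along those lines might look like the sketch below. It is not the logic of `scripts/benchmark_inference.py`; `median_latency_ms` and the `step` callable are hypothetical. When timing CUDA work, `step` should end with `torch.cuda.synchronize()`, because GPU kernels launch asynchronously and would otherwise be timed only up to launch.

```python
import statistics
import time
from typing import Callable, List

def median_latency_ms(step: Callable[[], object], iterations: int = 20, warmup: int = 3) -> float:
    """Call `step` repeatedly and return the median wall-clock latency in milliseconds."""
    # Warm-up calls absorb one-time costs (e.g. torch.compile tracing) before measuring.
    for _ in range(warmup):
        step()
    samples: List[float] = []
    for _ in range(iterations):
        start = time.perf_counter()
        step()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)
```

For per-token numbers, `step` would wrap a single decoding step; for end-to-end numbers, it would wrap a whole generation call.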
HF Model Support
Support for HF models is provided by our HF model adapter. With it, one can obtain latencies similar to those tabulated above when using HF models:
```python
import torch
from transformers import AutoTokenizer, pipeline

from fms.models import get_model
from fms.models.hf import to_hf_api

# fms model (add a checkpoint path as the third argument to load weights)
llama = get_model("llama", "13b")

# huggingface model backed by fms internals
llama_hf = to_hf_api(llama)

# compile the model -- in HF, the decoder only
llama_hf.decoder = torch.compile(llama_hf.decoder)

# tokenizer matching the checkpoint; the path below is a placeholder
tokenizer = AutoTokenizer.from_pretrained("/path/to/llama-13b")

# generate some text -- the first time will be slow since the model needs to be compiled, but subsequent generations should be faster.
llama_generator = pipeline(task="text-generation", model=llama_hf, tokenizer=tokenizer)
llama_generator("""q: how are you? a: I am good. How about you? q: What is the weather like today? a:""")
```
A detailed example is provided here.
Tuning
To fine-tune LLaMA, use the `scripts/train_causal.py` training script. Here's an example of that command:
```bash
torchrun --nproc_per_node=2 \
    scripts/train_causal.py \
    --architecture=llama \
    --variant=7b \
    --tokenizer=~/models/tokenizer.model \
    --model_path=~/models/7B/ \
    --report_steps=10 \
    --checkpoint_format=meta \
    --distributed=fsdp
```
See options in the script for other ways to train and tune.
Structure and contents of this Repository
- `fms/models/` - Pure PyTorch implementations of popular model architectures, without requiring any specific common interface beyond `nn.Module`. Each model configuration is registered with `fms.models.register_model()` so that instances can be obtained through `fms.models.get_model('architecture', 'variant', '/path/to/data')`. Each model can also register sources/formats/versions of data to load (e.g. checkpoints provided by Meta, HF, or trained from this repo). Users of the repo (e.g. `fms-extras`) can register their own model architectures as well.
- `fms/models/hf/` - Adapters that compose our native PyTorch FMS model architecture implementations in HF-compatible wrapper interfaces. Each FMS model implements an adapter, and adapted instances are obtained via `fms.models.hf.to_hf_api(model)`.
- `fms/datasets/` - Code for loading data for pre-training and fine-tuning. Individual datasets are retrieved by `fms.datasets.get_dataset('name', tokenizer, 'optional path or other data reference')`. The expected tokenizer conforms to an `fms.utils.tokenizers.BaseTokenizer` interface.
- `fms/modules/` - Components extending `nn.Module` used in our model architecture implementations. Each Module has a corresponding `TPModule` so that modules can be sharded using a tensor-parallel distribution strategy. FMS modules should all support `torch.compile` without graph breaks.
- `fms/training/` - Pre-training and fine-tuning code.
- `fms/utils/` - Other operators useful in working with LLMs. These include a `generate()` function, `Tensor` subclasses, code for dealing with LLM checkpoints that might be saved/sharded in a variety of formats, tokenization code, and various other useful helper functions.
- `scripts/` - Various scripts for inference, benchmarking, and evaluation, as well as an entry-point for tuning/training.
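The registry pattern described for `fms/models/` can be pictured with the toy re-creation below. Everything in it (`ToyModel`, `register_toy_model`, `get_toy_model`, and keying by `(architecture, variant)`) is a hypothetical illustration of the pattern, not FMS's actual code or signatures; in FMS itself, use `fms.models.register_model()` and `fms.models.get_model()`.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple

@dataclass
class ToyModel:
    """Stand-in for an nn.Module; it only records its config and checkpoint path."""
    hidden_size: int
    checkpoint: Optional[str] = None

# Hypothetical registry: (architecture, variant) -> factory that builds a model.
_REGISTRY: Dict[Tuple[str, str], Callable[..., ToyModel]] = {}

def register_toy_model(architecture: str, variant: str, factory: Callable[..., ToyModel]) -> None:
    key = (architecture, variant)
    if key in _REGISTRY:
        raise KeyError(f"{architecture}/{variant} is already registered")
    _REGISTRY[key] = factory

def get_toy_model(architecture: str, variant: str, model_path: Optional[str] = None, **kwargs) -> ToyModel:
    factory = _REGISTRY.get((architecture, variant))
    if factory is None:
        raise KeyError(f"unknown model {architecture}/{variant}")
    model = factory(**kwargs)
    # Placeholder for checkpoint loading: a real loader would read weights from model_path.
    model.checkpoint = model_path
    return model

# A downstream package registers its own variant, then anyone can look it up by name.
register_toy_model("llama", "tiny", lambda **kw: ToyModel(hidden_size=kw.get("hidden_size", 64)))
```

Because lookup goes through string keys rather than imports, an external project can add architectures without modifying the core package, which is the extensibility the `fms-extras` example relies on.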
Extensions and Use Cases
This library is used by three dependent projects at IBM.
- fms-fsdp - This repo shares training code that has been used to pretrain an fms implementation of LLaMA on IBM internal data.
- fms-extras - This repo shares code for additional fms-based models trained by IBM. This repo will also be a home for other extensions, and may also include research or in-development work intended for eventual upstreaming to fms.
- TGIS - This inference server includes support for serving fms models.
Open Issues
- https://github.com/pytorch/pytorch/issues/107824 prevents training/fine-tuning from working with `torch.compile`.
- In addition, there are several open issues we are tracking to improve stability and memory footprint of inference.
References
- Huggingface TGI: https://github.com/huggingface/text-generation-inference
- IBM TGIS: https://github.com/IBM/text-generation-inference
File details
Details for the file ibm_fms-1.7.0-py3-none-any.whl.
File metadata
- Download URL: ibm_fms-1.7.0-py3-none-any.whl
- Upload date:
- Size: 212.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/6.1.0 CPython/3.13.7
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | afd2f1a05a22bbfcd538d2fb211bc03be018c16f5ca5badf42c5fd62456f20d3 |
| MD5 | 7634059109eef0d2288340bdc2661e72 |
| BLAKE2b-256 | 742f368d34b2eec3216691f97b095f65d71b07a0aebe0209794d07293bc8904c |
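To check a downloaded wheel against the SHA256 digest listed above, Python's standard `hashlib` is enough. The sketch below is illustrative; the helper names `file_digest` and `verify_file` are hypothetical, and it assumes the wheel is in the current directory.

```python
import hashlib

# SHA256 digest of ibm_fms-1.7.0-py3-none-any.whl, as published for this release.
EXPECTED_SHA256 = "afd2f1a05a22bbfcd538d2fb211bc03be018c16f5ca5badf42c5fd62456f20d3"

def file_digest(path: str, algorithm: str = "sha256", chunk_size: int = 1 << 20) -> str:
    """Hex digest of a file, read in chunks so large files need not fit in memory."""
    if algorithm.lower() == "blake2b-256":
        # BLAKE2b-256 is BLAKE2b configured with a 32-byte (256-bit) digest.
        h = hashlib.blake2b(digest_size=32)
    else:
        h = hashlib.new(algorithm)  # e.g. "sha256" or "md5"
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

def verify_file(path: str, expected: str = EXPECTED_SHA256, algorithm: str = "sha256") -> bool:
    """True if the file's digest matches the expected hex string (case-insensitive)."""
    return file_digest(path, algorithm) == expected.strip().lower()
```

For example, `verify_file("ibm_fms-1.7.0-py3-none-any.whl")` should return `True` for an intact download. pip can also enforce this itself in hash-checking mode, using a requirements line such as `ibm-fms==1.7.0 --hash=sha256:afd2f1a05a22bbfcd538d2fb211bc03be018c16f5ca5badf42c5fd62456f20d3` together with `pip install --require-hashes -r requirements.txt`.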
Provenance
The following attestation bundles were made for ibm_fms-1.7.0-py3-none-any.whl:
Publisher: build-and-publish.yaml on foundation-model-stack/foundation-model-stack
Statement:
- Statement type: https://in-toto.io/Statement/v1
- Predicate type: https://docs.pypi.org/attestations/publish/v1
- Subject name: ibm_fms-1.7.0-py3-none-any.whl
- Subject digest: afd2f1a05a22bbfcd538d2fb211bc03be018c16f5ca5badf42c5fd62456f20d3
- Sigstore transparency entry: 955663234
- Sigstore integration time:
- Permalink: foundation-model-stack/foundation-model-stack@9124614bcf394eb6eec366a5444e54933092f4e1
- Branch / Tag: refs/tags/v1.7.0
- Owner: https://github.com/foundation-model-stack
- Access: public
- Token Issuer: https://token.actions.githubusercontent.com
- Runner Environment: github-hosted
- Publication workflow: build-and-publish.yaml@9124614bcf394eb6eec366a5444e54933092f4e1
- Trigger Event: release
Statement type: