
TILEARN for LLM


Tilearn.llm Usage Guide

1. CUDA Kernel (using LLAMA as an example)

Supported GPUs: Ampere, Ada, or Hopper architectures (e.g., A100, A800, H100, H800)

New version

New version dependencies: PyTorch >= 2.0.0

This version is fully compatible with the Hugging Face interface; no additional model conversion is required.

LLAMA1/LLAMA2 on 16 A800 GPUs with seq=1024: training is about 20% faster than with DeepSpeed ZeRO-2.

CUDA kernel usage: modify the launch script as follows

### TIACC CUDA Kernel
### Open: TIACC_TRAINING_CUDA_KERNEL=1
### Close: TIACC_TRAINING_CUDA_KERNEL=0
export TIACC_TRAINING_CUDA_KERNEL=1

CUDA kernel usage: modify the code as follows

import os

### TIACC
TIACC_TRAINING_CUDA_KERNEL = int(os.getenv('TIACC_TRAINING_CUDA_KERNEL', '0'))
if TIACC_TRAINING_CUDA_KERNEL == 1:
    from tilearn.llm.transformers import LlamaForCausalLM
else:
    from transformers import LlamaForCausalLM

### The model interface is identical to standard Hugging Face
model = LlamaForCausalLM.from_pretrained(...)
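
Alternatively, with AutoModelForCausalLM:

import os

### TIACC
TIACC_TRAINING_CUDA_KERNEL = int(os.getenv('TIACC_TRAINING_CUDA_KERNEL', '0'))
if TIACC_TRAINING_CUDA_KERNEL == 1:
    from tilearn.llm.transformers import AutoModelForCausalLM
else:
    from transformers import AutoModelForCausalLM

### The model interface is identical to standard Hugging Face
model = AutoModelForCausalLM.from_pretrained(...)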
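If several scripts need the same switch, it can help to centralize it. The sketch below is a hypothetical helper (`kernel_enabled`, `model_module_name`, and `load_model_class` are not tilearn APIs) that reads `TIACC_TRAINING_CUDA_KERNEL` exactly as the snippet above does and imports the model class lazily, falling back to the standard `transformers` package when the kernel is off. It assumes only what this guide states: that `tilearn.llm.transformers` exposes the same class names as `transformers`.

```python
import importlib
import os
from typing import Mapping, Optional


def kernel_enabled(env: Optional[Mapping[str, str]] = None) -> bool:
    """True when TIACC_TRAINING_CUDA_KERNEL is 1 (default '0', as in the snippet above)."""
    env = os.environ if env is None else env
    return int(env.get("TIACC_TRAINING_CUDA_KERNEL", "0")) == 1


def model_module_name(env: Optional[Mapping[str, str]] = None) -> str:
    """Pick the package that provides the Hugging Face compatible model classes."""
    return "tilearn.llm.transformers" if kernel_enabled(env) else "transformers"


def load_model_class(name: str = "AutoModelForCausalLM",
                     env: Optional[Mapping[str, str]] = None):
    """Import lazily so tilearn is only required when the kernel is enabled."""
    module = importlib.import_module(model_module_name(env))
    return getattr(module, name)
```

Usage would then be `LlamaForCausalLM = load_model_class("LlamaForCausalLM")` followed by the usual `from_pretrained` call; keeping the import lazy means machines without tilearn installed can still run the plain Hugging Face path.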

Old version

Old version dependencies: flash-attention. Install it from https://github.com/Dao-AILab/flash-attention; building from source is recommended.

### compile from source
git clone --recursive https://github.com/Dao-AILab/flash-attention
cd flash-attention && python setup.py install

### install layer_norm, fused_dense and rotary kernels
### (run from the flash-attention repo root; each subshell returns to it afterwards)
(cd csrc/layer_norm && pip install .)
(cd csrc/fused_dense_lib && pip install .)
(cd csrc/rotary && pip install .)
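After the builds finish, a quick sanity check is to confirm that the extensions can be located. This sketch uses only `importlib.util.find_spec`; the module names in `FLASH_ATTN_MODULES` (`flash_attn`, `dropout_layer_norm`, `fused_dense_lib`, `rotary_emb`) are assumptions based on the flash-attention source tree and may differ between flash-attention versions, and `check_modules` is a hypothetical helper.

```python
import importlib.util
from typing import Dict, Iterable

# Assumed top-level module names produced by the flash-attention builds above;
# verify them against the setup.py files of the version you installed.
FLASH_ATTN_MODULES = ("flash_attn", "dropout_layer_norm", "fused_dense_lib", "rotary_emb")


def check_modules(names: Iterable[str]) -> Dict[str, bool]:
    """Map each module name to True if it can be found on the current sys.path."""
    status = {}
    for name in names:
        try:
            status[name] = importlib.util.find_spec(name) is not None
        except (ImportError, ValueError):
            # find_spec can raise for malformed names or broken parent packages
            status[name] = False
    return status


if __name__ == "__main__":
    for module, ok in check_modules(FLASH_ATTN_MODULES).items():
        print(f"{module}: {'OK' if ok else 'MISSING'}")
```

`find_spec` only locates a module without loading it, so a missing CUDA library would not show up here; an actual `import` on a GPU machine remains the definitive test.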

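This version is not compatible with the Hugging Face interface. It can directly load both Hugging Face models and original CUDA kernel models (the model structure saved during training).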

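Because models saved during training use the original CUDA kernel structure rather than the Hugging Face structure, you must run a conversion script manually if you need a Hugging Face model.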

LLAMA1/LLAMA2 on 16 A800 GPUs with seq=1024: training is about 30% faster than with DeepSpeed ZeRO-2.

CUDA kernel usage: modify the launch script as follows

### TIACC CUDA Kernel
### Open: TIACC_TRAINING_CUDA_KERNEL_V0=1
### Close: TIACC_TRAINING_CUDA_KERNEL_V0=0
export TIACC_TRAINING_CUDA_KERNEL_V0=1
### Model format: set exactly one of the following
# To load a Hugging Face model structure, set llama-hf
export TIACC_TRAINING_MODEL_FORMAT=llama-hf
# To load an original CUDA kernel model, set llama-origin instead
# export TIACC_TRAINING_MODEL_FORMAT=llama-origin

CUDA kernel usage: modify the code as follows

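import os

### TIACC
TIACC_TRAINING_CUDA_KERNEL_V0 = int(os.getenv('TIACC_TRAINING_CUDA_KERNEL_V0', '0'))
if TIACC_TRAINING_CUDA_KERNEL_V0 == 1:
    from tilearn import llm

    ### LLAMA model initialization (requires the V0 kernel to be enabled)
    TIACC_TRAINING_MODEL_FORMAT = os.getenv('TIACC_TRAINING_MODEL_FORMAT', 'llama-origin')
    model = llm.models.llama(model_args.model_name_or_path, model_format=TIACC_TRAINING_MODEL_FORMAT)
else:
    ### Standard Hugging Face path when the V0 kernel is disabled (Hugging Face format models only)
    from transformers import LlamaForCausalLM
    model = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path)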
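`TIACC_TRAINING_MODEL_FORMAT` is read as a free-form string above, so a typo is only caught, if at all, inside the library. A minimal sketch of validating it up front, assuming only the two values this guide documents (`read_model_format` is a hypothetical helper, not a tilearn API); its return value would be passed as `model_format=`:

```python
import os
from typing import Mapping, Optional

# The two formats described above: Hugging Face structure vs. original CUDA kernel structure
SUPPORTED_MODEL_FORMATS = ("llama-hf", "llama-origin")


def read_model_format(env: Optional[Mapping[str, str]] = None) -> str:
    """Read TIACC_TRAINING_MODEL_FORMAT (default 'llama-origin') and reject unknown values."""
    env = os.environ if env is None else env
    model_format = env.get("TIACC_TRAINING_MODEL_FORMAT", "llama-origin").strip()
    if model_format not in SUPPORTED_MODEL_FORMATS:
        raise ValueError(
            f"Unsupported TIACC_TRAINING_MODEL_FORMAT={model_format!r}; "
            f"expected one of {SUPPORTED_MODEL_FORMATS}"
        )
    return model_format
```

The default mirrors the `'llama-origin'` fallback in the code above, so behavior is unchanged when the variable is unset.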

2. Static Zero

Use case: switching between different DeepSpeed optimization states such as ZeRO-1, ZeRO-2, ZeRO-3, offload, and int8.
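For intuition about what switching stages buys, the ZeRO paper's rough model-state accounting for mixed-precision Adam is 2 bytes (fp16 parameters) + 2 bytes (fp16 gradients) + 12 bytes (fp32 master weights, momentum, variance) per parameter, with progressively more of it partitioned across data-parallel ranks at stages 1, 2, and 3. The sketch below applies that standard formula (`model_state_bytes_per_gpu` is a hypothetical helper); it ignores activations, buffers, fragmentation, offload, and int8, so treat the results as lower bounds, not as figures tilearn reports.

```python
def model_state_bytes_per_gpu(num_params: int, world_size: int, stage: int) -> float:
    """Approximate per-GPU bytes for params + grads + Adam states under ZeRO.

    Uses the mixed-precision Adam accounting from the ZeRO paper:
    2 bytes fp16 params, 2 bytes fp16 grads, 12 bytes optimizer states per parameter.
    """
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    params, grads, optim = 2.0 * num_params, 2.0 * num_params, 12.0 * num_params
    if stage == 0:  # plain data parallel: everything replicated
        return params + grads + optim
    if stage == 1:  # optimizer states partitioned
        return params + grads + optim / world_size
    if stage == 2:  # optimizer states + gradients partitioned
        return params + (grads + optim) / world_size
    if stage == 3:  # parameters, gradients and optimizer states all partitioned
        return (params + grads + optim) / world_size
    raise ValueError(f"unknown ZeRO stage: {stage}")


if __name__ == "__main__":
    for s in range(4):
        gib = model_state_bytes_per_gpu(7_000_000_000, 8, s) / 2**30
        print(f"stage {s}: {gib:.1f} GiB per GPU")
```

For a 7B-parameter model on 8 GPUs this gives 112 GB per GPU at stage 0 versus 14 GB at stage 3, which is why higher stages (and offload, which moves part of this state to CPU memory) free room for larger batches.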

Modify the launch script as follows

### TIACC STATIC ZERO
### Open: TIACC_TRAINING_STATIC_ZERO='O2'
### Supported: 'O2' / 'O2.5' / 'O3' / 'O3.5' / 'O3_Q8' (in progress)
### Close: TIACC_TRAINING_STATIC_ZERO='None'
export TIACC_TRAINING_STATIC_ZERO='None' #'O2'

Modify the code as follows

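import os
from transformers import HfArgumentParser

TIACC_TRAINING_STATIC_ZERO = os.getenv('TIACC_TRAINING_STATIC_ZERO', 'None')
if TIACC_TRAINING_STATIC_ZERO != 'None':
    from tilearn.llm.transformers import TrainingArguments
else:
    from transformers import TrainingArguments

### The interface is identical to standard Hugging Face
### (ModelArguments and DataTrainingArguments are defined in your own training script)
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))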
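The snippet above only compares the variable against `'None'`, so any other string, including a typo, takes the tilearn path. A minimal sketch of checking the level against the values listed in the launch-script comments, assuming those are the only valid ones (`read_static_zero` is a hypothetical helper, not part of tilearn):

```python
import os
from typing import Mapping, Optional

# Levels listed in the launch-script comments; 'O3_Q8' is marked as still in progress,
# so it is only accepted when explicitly allowed.
STATIC_ZERO_LEVELS = ("O2", "O2.5", "O3", "O3.5")
IN_PROGRESS_LEVELS = ("O3_Q8",)


def read_static_zero(env: Optional[Mapping[str, str]] = None,
                     allow_in_progress: bool = False) -> Optional[str]:
    """Return the requested static-zero level, or None when the feature is off."""
    env = os.environ if env is None else env
    # Strip stray quotes in case the value was set with literal quote characters
    level = env.get("TIACC_TRAINING_STATIC_ZERO", "None").strip().strip("'\"")
    if level == "None":
        return None
    if level in STATIC_ZERO_LEVELS:
        return level
    if allow_in_progress and level in IN_PROGRESS_LEVELS:
        return level
    raise ValueError(f"Unsupported TIACC_TRAINING_STATIC_ZERO={level!r}")
```

Failing at startup this way is a design choice: a misspelled level surfaces immediately rather than after the job has been scheduled on the cluster.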

3. Dynamic Zero

Use case: ZeRO-3 + offload scenarios; it substantially reduces GPU memory usage, allowing a larger batch size.

Modify the launch script as follows

### TIACC DYNAMIC ZERO
### Open: TIACC_TRAINING_DYNAMIC_ZERO=1 and set TIACC_ZERO_STAGE/TIACC_PLACEMENT/TIACC_SHARD_INIT/TIACC_CPU_INIT
### Close: TIACC_TRAINING_DYNAMIC_ZERO=0
export TIACC_TRAINING_DYNAMIC_ZERO=0
export TIACC_ZERO_STAGE=3 # takes effect only when TIACC_TRAINING_DYNAMIC_ZERO=1
export TIACC_PLACEMENT='cpu' # or 'cuda'; takes effect only when TIACC_TRAINING_DYNAMIC_ZERO=1
export TIACC_SHARD_INIT=0 # takes effect only when TIACC_TRAINING_DYNAMIC_ZERO=1
export TIACC_CPU_INIT=1 # takes effect only when TIACC_TRAINING_DYNAMIC_ZERO=1

# Pass the DeepSpeed config only when dynamic zero is disabled
if [ "${TIACC_TRAINING_DYNAMIC_ZERO}" = "0" ]; then
  #USE_DS="--deepspeed=./ds_config_zero3.json"
  USE_DS="--deepspeed=${deepspeed_config_file}"
else
  USE_DS=""
fi

torchrun --nnodes 1 --nproc_per_node 8 run_clm.py \
    ${USE_DS} \
    ...
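The launch-script logic above can be mirrored in Python, for example in a wrapper that assembles the `torchrun` command. This is a hypothetical sketch (`DynamicZeroEnv`, `read_dynamic_zero_env`, and `deepspeed_args` are not tilearn APIs): it reads the variables shown above with the script's exported values as defaults and, like `USE_DS`, adds `--deepspeed` only when dynamic zero is off. `deepspeed_config_file` is your own path.

```python
import os
from dataclasses import dataclass
from typing import List, Mapping, Optional


@dataclass(frozen=True)
class DynamicZeroEnv:
    enabled: bool
    zero_stage: int
    placement: str  # 'cpu' or 'cuda', as in the script comments
    shard_init: bool
    cpu_init: bool


def read_dynamic_zero_env(env: Optional[Mapping[str, str]] = None) -> DynamicZeroEnv:
    """Parse the TIACC_* variables, using the values exported in the script as defaults."""
    env = os.environ if env is None else env
    placement = env.get("TIACC_PLACEMENT", "cpu").strip("'\"")
    if placement not in ("cpu", "cuda"):
        raise ValueError(f"TIACC_PLACEMENT must be 'cpu' or 'cuda', got {placement!r}")
    return DynamicZeroEnv(
        enabled=int(env.get("TIACC_TRAINING_DYNAMIC_ZERO", "0")) == 1,
        zero_stage=int(env.get("TIACC_ZERO_STAGE", "3")),
        placement=placement,
        shard_init=int(env.get("TIACC_SHARD_INIT", "0")) == 1,
        cpu_init=int(env.get("TIACC_CPU_INIT", "1")) == 1,
    )


def deepspeed_args(cfg: DynamicZeroEnv, deepspeed_config_file: str) -> List[str]:
    """Same rule as USE_DS: pass the DeepSpeed config only when dynamic zero is off."""
    return [] if cfg.enabled else [f"--deepspeed={deepspeed_config_file}"]
```

The returned list can be spliced into an argument list for `subprocess.run`, which avoids the word-splitting pitfalls of an unquoted `${USE_DS}` in shell.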

Modify the code as follows

import os
from contextlib import nullcontext

import torch
from transformers import LlamaForCausalLM

TIACC_TRAINING_DYNAMIC_ZERO = int(os.getenv('TIACC_TRAINING_DYNAMIC_ZERO', '0'))
if TIACC_TRAINING_DYNAMIC_ZERO == 1:
    from tilearn.llm.trainer import TrainerTiacc as Trainer
    from tilearn.llm import init as llm_init
    from tilearn.llm import get_config as llm_get_config
else:
    from transformers import Trainer


### init in main func
def main():
    if TIACC_TRAINING_DYNAMIC_ZERO == 1:
        llm_config = llm_get_config()
        llm_init_context = llm_init(init_in_cpu=llm_config.cpu_init,
                                    shard_init=llm_config.shard_init,
                                    model_dtype=torch.half)

    ### add init_context when model init
    init_context = llm_init_context if TIACC_TRAINING_DYNAMIC_ZERO == 1 else nullcontext
    with init_context():
        ### The interface is identical to standard Hugging Face
        model = LlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            low_cpu_mem_usage=False,  # True
            ...
        )

    ### use trainer
    ### The interface is identical to standard Hugging Face
    trainer = Trainer(
        model=model,
        ...
    )
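The `init_context` pattern above relies on `contextlib.nullcontext` so that the same `with` block works whether or not dynamic zero is enabled. Because the code calls `init_context()`, both alternatives must be callables that return a context manager. The runnable sketch below shows the pattern in isolation; `recording_init` is a hypothetical stand-in for the tilearn init context, whose internals this guide does not document.

```python
from contextlib import contextmanager, nullcontext
from typing import Callable, ContextManager, List

events: List[str] = []


@contextmanager
def recording_init():
    """Hypothetical stand-in for the tilearn init context: only records enter/exit."""
    events.append("enter")
    try:
        yield
    finally:
        events.append("exit")


def build_model(dynamic_zero: bool) -> str:
    # Choose the factory, then call it, just like `with init_context():` above
    init_context: Callable[[], ContextManager] = recording_init if dynamic_zero else nullcontext
    with init_context():
        events.append("model created")
    return "model"
```

With `dynamic_zero=False` only `"model created"` is recorded; with `True` the model creation is wrapped by `"enter"` and `"exit"`, which is where a real init context would apply CPU or sharded initialization.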


Download files


Source Distributions

No source distribution files are available for this release.

Built Distributions


tilearn_llm-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.5 MB): CPython 3.10, manylinux (glibc 2.17+), x86-64

tilearn_llm-0.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.5 MB): CPython 3.9, manylinux (glibc 2.17+), x86-64

tilearn_llm-0.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.4 MB): CPython 3.8, manylinux (glibc 2.17+), x86-64

File hashes

Hashes for tilearn_llm-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 b1f566322a9ada847c05f86eb74dd28abf99d8a9efb86bd68fb6f4e9dd79a0db
MD5 27c0af6d28d2f69231ff434a764105e7
BLAKE2b-256 061bace63049e96c06ef5ee4fe491914a5632eb2315983d54cba797b120194ab


Hashes for tilearn_llm-0.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 7cab6e3e35a4043e0ad8969ccf5f48cc2099faa5975c7ede5985e9d944861f82
MD5 90b3f3a4e72090af1979a195842dbe9d
BLAKE2b-256 38364e29cae03cab8b33aad5cdb5262bc53fe1c4b4a3f6d023f272b19ba18941


Hashes for tilearn_llm-0.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 a2defd325193c69124b884c435553e205457a781be658a3e7c1f782c5e96f197
MD5 8996ac57c0538851ab34cc3f249ef73e
BLAKE2b-256 c38bd48ae866a2f83c6e3b1432c2535b5ca883a9d320bc4b1dd2077d1ca9fe18

