
TILEARN for LLM

Tilearn.llm User Guide

1. CUDA Kernel (using LLAMA as an example)

Supported GPUs: Ampere, Ada, or Hopper (e.g., A100, A800, H100, H800)
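The GPU requirement above can be checked before the kernels are enabled. The following is a minimal sketch, not part of tilearn: the helper names are hypothetical, and it assumes that Ampere, Ada, and Hopper GPUs report a CUDA compute capability of 8.0 or higher. It only checks the lower bound, so newer architectures that the source does not list would also pass.

```python
from typing import Tuple

# Assumption: Ampere, Ada and Hopper GPUs report CUDA compute capability
# 8.0 or higher (for example A100/A800: 8.0, H100/H800: 9.0).
MIN_CAPABILITY: Tuple[int, int] = (8, 0)


def is_supported_capability(major: int, minor: int) -> bool:
    """Return True if (major, minor) is at least the assumed minimum (8.0)."""
    return (major, minor) >= MIN_CAPABILITY


def current_gpu_supported() -> bool:
    """Check the current CUDA device.

    Returns False when PyTorch is not installed or no CUDA device is visible.
    """
    try:
        import torch
    except ImportError:
        return False
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return is_supported_capability(major, minor)
```

A launch script could call `current_gpu_supported()` and leave `TIACC_TRAINING_CUDA_KERNEL` at 0 when it returns False.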

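New version

New version dependencies: pytorch >= 2.0.0

This version is fully compatible with the huggingface interface; no additional model conversion is required.

LLAMA1/LLAMA2 on 16 A800 GPUs with seq=1024: training is about 20% faster than with deepspeed zero2.

Using the CUDA kernel: modify the launch script as follows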

### TIACC CUDA Kernel
### Open: TIACC_TRAINING_CUDA_KERNEL=1
### Close: TIACC_TRAINING_CUDA_KERNEL=0
export TIACC_TRAINING_CUDA_KERNEL=1

Using the CUDA kernel: modify the code as follows

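import os

### TIACC
TIACC_TRAINING_CUDA_KERNEL = int(os.getenv('TIACC_TRAINING_CUDA_KERNEL', '0'))
if TIACC_TRAINING_CUDA_KERNEL == 1:
    from tilearn.llm.transformers import LlamaForCausalLM
else:
    ### fall back to the standard huggingface class when the kernel is disabled
    from transformers import LlamaForCausalLM

### The model interface is the same as standard huggingface
model = LlamaForCausalLM.from_pretrained(...)

The same pattern applies to AutoModelForCausalLM:

import os

### TIACC
TIACC_TRAINING_CUDA_KERNEL = int(os.getenv('TIACC_TRAINING_CUDA_KERNEL', '0'))
if TIACC_TRAINING_CUDA_KERNEL == 1:
    from tilearn.llm.transformers import AutoModelForCausalLM
else:
    from transformers import AutoModelForCausalLM

### The model interface is the same as standard huggingface
model = AutoModelForCausalLM.from_pretrained(...)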
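The snippets above parse the switch with `int(os.getenv(...))`, which raises a bare `ValueError` if the variable is set to something like `true`. A small helper can make that failure easier to diagnose. This is a sketch only; `env_flag` is a hypothetical name and not part of tilearn.

```python
import os


def env_flag(name: str, default: bool = False) -> bool:
    """Read a 0/1 style environment switch such as TIACC_TRAINING_CUDA_KERNEL.

    Returns `default` when the variable is unset or empty, True when it is 1,
    False for any other integer, and raises ValueError with a descriptive
    message when the value is not an integer.
    """
    raw = os.getenv(name)
    if raw is None or raw.strip() == "":
        return default
    try:
        return int(raw.strip()) == 1
    except ValueError:
        raise ValueError(
            f"{name} must be 0 or 1, got {raw!r}"
        ) from None
```

With this helper, the import switch becomes `if env_flag('TIACC_TRAINING_CUDA_KERNEL'): ...`.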

Old version

Old version dependencies: flash-attention. Install it from https://github.com/Dao-AILab/flash-attention; building from source is recommended.

### compile from source
git clone --recursive https://github.com/Dao-AILab/flash-attention
cd flash-attention && python setup.py install

### install the layer_norm, fused_dense and rotary kernels
### (paths are relative to the flash-attention directory entered above)
pip install ./csrc/layer_norm
pip install ./csrc/fused_dense_lib
pip install ./csrc/rotary

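This version is not compatible with the huggingface interface, but it can directly load both huggingface models and original CUDA-kernel models (the model structure saved by training).

Models saved during training use the original CUDA-kernel structure rather than the huggingface structure. If you need a huggingface model, run the conversion script manually.

LLAMA1/LLAMA2 on 16 A800 GPUs with seq=1024: training is about 30% faster than with deepspeed zero2.

Using the CUDA kernel: modify the launch script as follows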

### TIACC CUDA Kernel
### Open: TIACC_TRAINING_CUDA_KERNEL_V0=1
### Close: TIACC_TRAINING_CUDA_KERNEL_V0=0
export TIACC_TRAINING_CUDA_KERNEL_V0=1
# When loading a huggingface-format model, set llama-hf
export TIACC_TRAINING_MODEL_FORMAT=llama-hf
# When loading an original CUDA-kernel model, set llama-origin instead
# export TIACC_TRAINING_MODEL_FORMAT=llama-origin

Using the CUDA kernel: modify the code as follows

import os

### TIACC
TIACC_TRAINING_CUDA_KERNEL_V0 = int(os.getenv('TIACC_TRAINING_CUDA_KERNEL_V0', '0'))
if TIACC_TRAINING_CUDA_KERNEL_V0 == 1:
    from tilearn import llm

    ### LLAMA model initialization (model_args comes from the training script)
    TIACC_TRAINING_MODEL_FORMAT = os.getenv('TIACC_TRAINING_MODEL_FORMAT', 'llama-origin')
    model = llm.models.llama(model_args.model_name_or_path, model_format=TIACC_TRAINING_MODEL_FORMAT)
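Because the model format is passed through as a free-form string, a typo in `TIACC_TRAINING_MODEL_FORMAT` would only surface inside the library. A minimal sketch of validating it up front follows; `resolve_model_format` is a hypothetical helper, and the two accepted values are the only ones this guide mentions.

```python
import os

# The two formats named in this guide; any others are unknown to this sketch.
VALID_MODEL_FORMATS = ("llama-hf", "llama-origin")


def resolve_model_format(default: str = "llama-origin") -> str:
    """Return the configured model format, rejecting unknown values early."""
    value = os.getenv("TIACC_TRAINING_MODEL_FORMAT", default)
    if value not in VALID_MODEL_FORMATS:
        raise ValueError(
            f"TIACC_TRAINING_MODEL_FORMAT must be one of {VALID_MODEL_FORMATS}, "
            f"got {value!r}"
        )
    return value
```

The result can then be passed as `model_format=resolve_model_format()`.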

2. Static Zero

Use case: switching between different optimization states such as deepspeed zero1, zero2, zero3, offload, and int8.

Modify the launch script as follows

### TIACC STATIC ZERO
### Open: TIACC_TRAINING_STATIC_ZERO='O2'
### supports 'O2' / 'O2.5' / 'O3' / 'O3.5' / 'O3_Q8' (in progress)
### Close: TIACC_TRAINING_STATIC_ZERO='None'
export TIACC_TRAINING_STATIC_ZERO='None' #'O2'

Modify the code as follows

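import os

from transformers import HfArgumentParser

TIACC_TRAINING_STATIC_ZERO = os.getenv('TIACC_TRAINING_STATIC_ZERO', 'None')
if TIACC_TRAINING_STATIC_ZERO != 'None':
    from tilearn.llm.transformers import TrainingArguments
else:
    ### fall back to the standard huggingface class when Static Zero is disabled
    from transformers import TrainingArguments

### The interface is the same as standard huggingface
### (ModelArguments and DataTrainingArguments are defined in the training script)
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))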
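The Static Zero switch is a string with a small set of levels, so it may help to validate it before choosing which `TrainingArguments` to import. The helper below is a hypothetical sketch based only on the levels listed in the launch script comments. It does not model what each level does.

```python
import os
from typing import Optional

# Levels listed in this guide. 'O3_Q8' is marked as in progress there,
# so it may not work in every release.
STATIC_ZERO_LEVELS = ("O2", "O2.5", "O3", "O3.5", "O3_Q8")


def static_zero_level() -> Optional[str]:
    """Return the requested Static Zero level, or None when it is disabled."""
    value = os.getenv("TIACC_TRAINING_STATIC_ZERO", "None")
    if value == "None":
        return None
    if value not in STATIC_ZERO_LEVELS:
        raise ValueError(
            f"TIACC_TRAINING_STATIC_ZERO must be 'None' or one of "
            f"{STATIC_ZERO_LEVELS}, got {value!r}"
        )
    return value
```

The import switch can then test `static_zero_level() is not None` instead of comparing raw strings.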

3. Dynamic Zero

Use case: zero3 + offload scenarios. It substantially reduces GPU memory usage, which allows a larger batch size.

Modify the launch script as follows

### TIACC DYNAMIC ZERO
### Open: TIACC_TRAINING_DYNAMIC_ZERO=1 and set TIACC_ZERO_STAGE/TIACC_PLACEMENT/TIACC_SHARD_INIT/TIACC_CPU_INIT
### Close: TIACC_TRAINING_DYNAMIC_ZERO=0
export TIACC_TRAINING_DYNAMIC_ZERO=0
export TIACC_ZERO_STAGE=3 # takes effect when TIACC_TRAINING_DYNAMIC_ZERO=1
export TIACC_PLACEMENT='cpu' # or 'cuda'; takes effect when TIACC_TRAINING_DYNAMIC_ZERO=1
export TIACC_SHARD_INIT=0 # takes effect when TIACC_TRAINING_DYNAMIC_ZERO=1
export TIACC_CPU_INIT=1 # takes effect when TIACC_TRAINING_DYNAMIC_ZERO=1

if [ ${TIACC_TRAINING_DYNAMIC_ZERO} = 0 ]; then
  #USE_DS="--deepspeed=./ds_config_zero3.json"
  USE_DS="--deepspeed=${deepspeed_config_file}"
else
  USE_DS=""
fi

torchrun --nnodes 1 --nproc_per_node 8 run_clm.py \
    ${USE_DS} \
	...
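The launch logic above passes `--deepspeed` only when Dynamic Zero is disabled. For launchers written in Python, the same decision can be sketched as follows. `build_launch_command` is a hypothetical helper; the `torchrun` arguments mirror the shell snippet, and any additional `run_clm.py` arguments elided there are left out here as well.

```python
import os
from typing import List


def build_launch_command(deepspeed_config_file: str,
                         nproc_per_node: int = 8) -> List[str]:
    """Build the torchrun argument list used in the shell snippet.

    The --deepspeed flag is added only when TIACC_TRAINING_DYNAMIC_ZERO is 0,
    because Dynamic Zero replaces the deepspeed configuration.
    """
    dynamic_zero = int(os.getenv("TIACC_TRAINING_DYNAMIC_ZERO", "0"))
    cmd = [
        "torchrun",
        "--nnodes", "1",
        "--nproc_per_node", str(nproc_per_node),
        "run_clm.py",
    ]
    if dynamic_zero == 0:
        cmd.append(f"--deepspeed={deepspeed_config_file}")
    return cmd
```

The returned list can be handed to `subprocess.run(cmd, check=True)` after the remaining script arguments are appended.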

Modify the code as follows

import os
from contextlib import nullcontext

import torch
### or import LlamaForCausalLM from tilearn.llm.transformers when the CUDA kernel in section 1 is enabled
from transformers import LlamaForCausalLM

TIACC_TRAINING_DYNAMIC_ZERO = int(os.getenv('TIACC_TRAINING_DYNAMIC_ZERO', '0'))
if TIACC_TRAINING_DYNAMIC_ZERO == 1:
    from tilearn.llm.trainer import TrainerTiacc as Trainer
    from tilearn.llm import init as llm_init
    from tilearn.llm import get_config as llm_get_config
else:
    from transformers import Trainer


### init in main func
def main():
    if TIACC_TRAINING_DYNAMIC_ZERO == 1:
        llm_config = llm_get_config()
        llm_init_context = llm_init(init_in_cpu=llm_config.cpu_init,
                                    shard_init=llm_config.shard_init,
                                    model_dtype=torch.half)

    ### add init_context when model init
    init_context = llm_init_context if TIACC_TRAINING_DYNAMIC_ZERO == 1 else nullcontext
    with init_context():
        ### The interface is the same as standard huggingface
        model = LlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            low_cpu_mem_usage=False,  # True
            ...
        )

    ### use trainer
    ### The interface is the same as standard huggingface
    trainer = Trainer(
        model=model,
        ...
    )
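The code above selects between the Dynamic Zero init context and `nullcontext`, then enters whichever was chosen with `init_context()`. The sketch below isolates that pattern so both branches can be exercised without tilearn installed. `make_init_context` is a hypothetical stand-in for `llm_init(...)`; it assumes, based on how the snippet calls `init_context()`, that `llm_init` returns a callable producing a context manager.

```python
from contextlib import contextmanager, nullcontext

events = []  # records context entry/exit so the behaviour is observable


def make_init_context(label: str):
    """Hypothetical stand-in for llm_init(...).

    Returns a zero-argument callable that produces a context manager,
    matching the call shape `with init_context():` used in this guide.
    """
    @contextmanager
    def _ctx():
        events.append(f"enter:{label}")
        try:
            yield
        finally:
            events.append(f"exit:{label}")
    return _ctx


def build_model(dynamic_zero: bool) -> dict:
    # nullcontext is a class, so calling it also yields a context manager;
    # both branches can therefore be entered with the same `with` statement.
    init_context = make_init_context("dynamic-zero") if dynamic_zero else nullcontext
    with init_context():
        model = {"initialized": True}  # placeholder for from_pretrained(...)
    return model
```

Keeping both branches callable avoids duplicating the `from_pretrained` call in separate `if`/`else` bodies.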


This version

0.5.4


Built Distributions


tilearn_llm-0.5.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.6 MB): CPython 3.10, manylinux (glibc 2.17+), x86-64

tilearn_llm-0.5.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.6 MB): CPython 3.9, manylinux (glibc 2.17+), x86-64

tilearn_llm-0.5.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.6 MB): CPython 3.8, manylinux (glibc 2.17+), x86-64

