
TILEARN for LLM


Tilearn.llm User Guide

1. CUDA Kernel (LLAMA as an example)

Supported GPUs: Ampere, Ada, or Hopper architectures (e.g., A100, A800, H100, H800)

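New version

New version dependencies: PyTorch >= 2.0.0

This version is fully compatible with the Hugging Face interface; no additional model conversion is required.

LLAMA1/LLAMA2 on 16 A800 GPUs with seq=1024: training is about 20% faster than with DeepSpeed ZeRO-2.

CUDA kernel usage: modify the launch script as follows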

### TIACC CUDA Kernel
### Open: TIACC_TRAINING_CUDA_KERNEL=1
### Close: TIACC_TRAINING_CUDA_KERNEL=0
export TIACC_TRAINING_CUDA_KERNEL=1

CUDA kernel usage: modify the code as follows

### TIACC
import os

TIACC_TRAINING_CUDA_KERNEL = int(os.getenv('TIACC_TRAINING_CUDA_KERNEL', '0'))
if TIACC_TRAINING_CUDA_KERNEL == 1:
    from tilearn.llm.transformers import LlamaForCausalLM
else:
    from transformers import LlamaForCausalLM

### The model interface is identical to standard Hugging Face
model = LlamaForCausalLM.from_pretrained(...)

Alternatively, with AutoModelForCausalLM:

### TIACC
import os

TIACC_TRAINING_CUDA_KERNEL = int(os.getenv('TIACC_TRAINING_CUDA_KERNEL', '0'))
if TIACC_TRAINING_CUDA_KERNEL == 1:
    from tilearn.llm.transformers import AutoModelForCausalLM
else:
    from transformers import AutoModelForCausalLM

### The model interface is identical to standard Hugging Face
model = AutoModelForCausalLM.from_pretrained(...)
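The environment-variable switch can also be wrapped in a small helper so the rest of a training script imports the model class from one place. This is a minimal sketch, not part of tilearn: the helper names `tiacc_enabled`, `causal_lm_module_name` and `load_causal_lm_class` are hypothetical, and it assumes only what the snippets above show, namely that `tilearn.llm.transformers` exposes the same class names as `transformers`.

```python
import importlib
import os


def tiacc_enabled(name: str = "TIACC_TRAINING_CUDA_KERNEL") -> bool:
    """Hypothetical helper: True when the integer env var `name` is set to 1."""
    return int(os.getenv(name, "0")) == 1


def causal_lm_module_name() -> str:
    """Return the module that should provide the model classes."""
    # Flag on: tilearn drop-in classes; flag off or unset: plain Hugging Face.
    return "tilearn.llm.transformers" if tiacc_enabled() else "transformers"


def load_causal_lm_class(class_name: str = "AutoModelForCausalLM"):
    """Import `class_name` (e.g. "LlamaForCausalLM") from the selected module."""
    module = importlib.import_module(causal_lm_module_name())
    return getattr(module, class_name)
```

Usage would then be `model = load_causal_lm_class("LlamaForCausalLM").from_pretrained(...)`. Resolving the module by name keeps a single switch point, at the cost of losing static import checking in editors.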

Old version

Old version dependencies: flash-attention. Install it from https://github.com/Dao-AILab/flash-attention; building from source is recommended.

### compile from source
git clone --recursive https://github.com/Dao-AILab/flash-attention
cd flash-attention && python setup.py install

### install layer_norm, fused_dense and rotary kernels (run from the flash-attention root)
cd csrc/layer_norm && pip install . && cd ../..
cd csrc/fused_dense_lib && pip install . && cd ../..
cd csrc/rotary && pip install . && cd ../..
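To confirm the builds above succeeded, a quick check can look for each compiled module on the Python path without importing it. This is a sketch under an assumption: the import names `flash_attn`, `dropout_layer_norm`, `fused_dense_lib` and `rotary_emb` are what the flash-attention repository's `setup.py` files have produced for these directories; verify them against the version you cloned.

```python
import importlib.util

# Assumed import names of the extensions built above (flash-attention itself,
# csrc/layer_norm, csrc/fused_dense_lib, csrc/rotary); check your checkout's setup.py files.
EXPECTED_MODULES = ("flash_attn", "dropout_layer_norm", "fused_dense_lib", "rotary_emb")


def check_flash_attention_extensions(modules=EXPECTED_MODULES) -> dict:
    """Map each top-level module name to True if Python can locate it (no import is run)."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, found in check_flash_attention_extensions().items():
        print(f"{name}: {'ok' if found else 'MISSING'}")
```

Using `find_spec` rather than `import` avoids initializing CUDA just to check that the packages are installed.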

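This version is not compatible with the Hugging Face interface. It can directly load both Hugging Face models and original CUDA kernel models (the model structure saved by training).

Because models saved during training use the original CUDA kernel structure rather than the Hugging Face structure, a conversion script must be run manually if a Hugging Face model is needed.

LLAMA1/LLAMA2 on 16 A800 GPUs with seq=1024: training is about 30% faster than with DeepSpeed ZeRO-2.

CUDA kernel usage: modify the launch script as follows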

### TIACC CUDA Kernel
### Open: TIACC_TRAINING_CUDA_KERNEL_V0=1
### Close: TIACC_TRAINING_CUDA_KERNEL_V0=0
export TIACC_TRAINING_CUDA_KERNEL_V0=1
# Set exactly one model format:
# to load a Hugging Face model structure, use llama-hf
export TIACC_TRAINING_MODEL_FORMAT=llama-hf
# to load an original CUDA kernel model, use llama-origin instead
# export TIACC_TRAINING_MODEL_FORMAT=llama-origin
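Since a typo in `TIACC_TRAINING_MODEL_FORMAT` would only surface when the model loads, a script can validate it up front. A minimal sketch: `resolve_model_format` is a hypothetical helper, the two accepted values are the ones named in the comments above, and the fallback `'llama-origin'` matches the `os.getenv(..., 'llama-origin')` default in the model-initialization code.

```python
import os

# The two formats named in the launch-script comments.
VALID_MODEL_FORMATS = {"llama-hf", "llama-origin"}


def resolve_model_format(env=None) -> str:
    """Hypothetical helper: read TIACC_TRAINING_MODEL_FORMAT and reject unknown values early."""
    env = os.environ if env is None else env
    # Same default as the model-initialization code in this section.
    fmt = env.get("TIACC_TRAINING_MODEL_FORMAT", "llama-origin")
    if fmt not in VALID_MODEL_FORMATS:
        raise ValueError(
            f"TIACC_TRAINING_MODEL_FORMAT={fmt!r}; expected one of {sorted(VALID_MODEL_FORMATS)}"
        )
    return fmt
```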

CUDA kernel usage: modify the code as follows

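### TIACC
import os

TIACC_TRAINING_CUDA_KERNEL_V0 = int(os.getenv('TIACC_TRAINING_CUDA_KERNEL_V0', '0'))
if TIACC_TRAINING_CUDA_KERNEL_V0 == 1:
    from tilearn import llm

    ### LLAMA model initialization
    TIACC_TRAINING_MODEL_FORMAT = os.getenv('TIACC_TRAINING_MODEL_FORMAT', 'llama-origin')
    model = llm.models.llama(model_args.model_name_or_path, model_format=TIACC_TRAINING_MODEL_FORMAT)
else:
    ### Kernel disabled: standard Hugging Face loading (Hugging Face format only)
    from transformers import LlamaForCausalLM
    model = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path)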

2. Static Zero

Use case: switching between DeepSpeed optimization modes such as ZeRO-1, ZeRO-2, ZeRO-3, offload, and int8

Modify the launch script as follows

### TIACC STATIC ZERO
### Open: TIACC_TRAINING_STATIC_ZERO='O2'
### support 'O2' / 'O2.5' / 'O3' / 'O3.5' / 'O3_Q8' (in progress)
### Close: TIACC_TRAINING_STATIC_ZERO='None'
export TIACC_TRAINING_STATIC_ZERO='None' #'O2'
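Because Static Zero is controlled by a string rather than a 0/1 flag, it can help to normalize and check the value once. A sketch only: `static_zero_level` is a hypothetical helper, the accepted levels are copied from the comments above, and warning (rather than failing) on `'O3_Q8'` is a design choice of this sketch, since the README only marks that level as in progress.

```python
import os
import warnings

# Levels listed in the launch-script comments; 'O3_Q8' is marked as still in progress.
SUPPORTED_LEVELS = ("O2", "O2.5", "O3", "O3.5")
IN_PROGRESS_LEVELS = ("O3_Q8",)


def static_zero_level(env=None):
    """Hypothetical helper: return the requested level, or None when Static Zero is off."""
    env = os.environ if env is None else env
    level = env.get("TIACC_TRAINING_STATIC_ZERO", "None")
    if level == "None":
        # The literal string 'None' is the "off" sentinel used by the script and the code.
        return None
    if level in IN_PROGRESS_LEVELS:
        warnings.warn(f"TIACC_TRAINING_STATIC_ZERO={level} is listed as in progress")
        return level
    if level not in SUPPORTED_LEVELS:
        raise ValueError(f"unknown TIACC_TRAINING_STATIC_ZERO value: {level!r}")
    return level
```

The check `if static_zero_level() is not None:` can then guard the `TrainingArguments` import.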

Modify the code as follows

import os
from transformers import HfArgumentParser

TIACC_TRAINING_STATIC_ZERO = os.getenv('TIACC_TRAINING_STATIC_ZERO', 'None')
if TIACC_TRAINING_STATIC_ZERO != 'None':
    from tilearn.llm.transformers import TrainingArguments
else:
    from transformers import TrainingArguments

### The interface is identical to standard Hugging Face
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

3. Dynamic Zero

Use case: ZeRO-3 + offload; it substantially reduces GPU memory usage, which allows a larger batch size

Modify the launch script as follows

### TIACC DYNAMIC ZERO
### Open: TIACC_TRAINING_DYNAMIC_ZERO=1 and set TIACC_ZERO_STAGE/TIACC_PLACEMENT/TIACC_SHARD_INIT/TIACC_CPU_INIT
### Close: TIACC_TRAINING_DYNAMIC_ZERO=0
export TIACC_TRAINING_DYNAMIC_ZERO=0
export TIACC_ZERO_STAGE=3 #work when TIACC_TRAINING_DYNAMIC_ZERO=1
export TIACC_PLACEMENT='cpu' #'cuda' #work when TIACC_TRAINING_DYNAMIC_ZERO=1
export TIACC_SHARD_INIT=0 #work when TIACC_TRAINING_DYNAMIC_ZERO=1
export TIACC_CPU_INIT=1 #work when TIACC_TRAINING_DYNAMIC_ZERO=1

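# Pass the DeepSpeed config only when Dynamic Zero is off; with Dynamic Zero on, --deepspeed is dropped
if [ "${TIACC_TRAINING_DYNAMIC_ZERO}" = 0 ]; then
  #USE_DS="--deepspeed=./ds_config_zero3.json"
  USE_DS="--deepspeed=${deepspeed_config_file}"
else
  USE_DS=""
fi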

torchrun --nnodes 1 --nproc_per_node 8 run_clm.py \
    ${USE_DS} \
	...
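The launch variables above, and the `USE_DS` branch, can be mirrored in Python for logging or sanity checks before training starts. This is a hypothetical sketch, not tilearn's `get_config`: the class and function names are invented, and the defaults are simply the example values from the script (`TIACC_ZERO_STAGE=3`, `TIACC_PLACEMENT='cpu'`, `TIACC_SHARD_INIT=0`, `TIACC_CPU_INIT=1`).

```python
import os
from dataclasses import dataclass


@dataclass
class DynamicZeroEnv:
    """Hypothetical mirror of the launch-script variables (not tilearn's own config class)."""
    enabled: bool
    zero_stage: int
    placement: str  # 'cpu' or 'cuda'
    shard_init: bool
    cpu_init: bool


def read_dynamic_zero_env(env=None) -> DynamicZeroEnv:
    env = os.environ if env is None else env
    placement = env.get("TIACC_PLACEMENT", "cpu")
    if placement not in ("cpu", "cuda"):
        raise ValueError(f"TIACC_PLACEMENT must be 'cpu' or 'cuda', got {placement!r}")
    return DynamicZeroEnv(
        enabled=int(env.get("TIACC_TRAINING_DYNAMIC_ZERO", "0")) == 1,
        zero_stage=int(env.get("TIACC_ZERO_STAGE", "3")),
        placement=placement,
        shard_init=int(env.get("TIACC_SHARD_INIT", "0")) == 1,
        cpu_init=int(env.get("TIACC_CPU_INIT", "1")) == 1,
    )


def deepspeed_args(cfg: DynamicZeroEnv, deepspeed_config_file: str) -> list:
    """Python equivalent of the USE_DS branch: --deepspeed only when Dynamic Zero is off."""
    return [] if cfg.enabled else [f"--deepspeed={deepspeed_config_file}"]
```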

Modify the code as follows

import os
from contextlib import nullcontext

import torch
from transformers import LlamaForCausalLM

TIACC_TRAINING_DYNAMIC_ZERO = int(os.getenv('TIACC_TRAINING_DYNAMIC_ZERO', '0'))
if TIACC_TRAINING_DYNAMIC_ZERO == 1:
    from tilearn.llm.trainer import TrainerTiacc as Trainer
    from tilearn.llm import init as llm_init
    from tilearn.llm import get_config as llm_get_config
else:
    from transformers import Trainer


### init in main func
def main():
    if TIACC_TRAINING_DYNAMIC_ZERO == 1:
        llm_config = llm_get_config()
        llm_init_context = llm_init(init_in_cpu=llm_config.cpu_init,
                                    shard_init=llm_config.shard_init,
                                    model_dtype=torch.half)

    ### add init_context when model init
    init_context = llm_init_context if TIACC_TRAINING_DYNAMIC_ZERO == 1 else nullcontext
    with init_context():
        ### The interface is identical to standard Hugging Face
        model = LlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            low_cpu_mem_usage=False,  # or True
            ...
        )

    ### use trainer
    ### The interface is identical to standard Hugging Face
    trainer = Trainer(
        model=model,
        ...
    )
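The `init_context` line works because both branches are zero-argument callables that return a context manager: `nullcontext` is a class, and the object returned by `llm_init(...)` is called the same way (`init_context()`). The sketch below reproduces that swap with a hypothetical stand-in, `make_recording_init`, that only records when the context is entered and exited; it does not reproduce what tilearn does inside the context.

```python
from contextlib import contextmanager, nullcontext


def make_recording_init(events):
    """Hypothetical stand-in for llm_init(...): returns a context-manager factory."""
    @contextmanager
    def init_context():
        events.append("enter")  # a real init context would set up parameter placement here
        try:
            yield
        finally:
            events.append("exit")
    return init_context


def build_model(enabled, events, build):
    """Build the model inside whichever init context the flag selects."""
    init_context = make_recording_init(events) if enabled else nullcontext
    with init_context():
        return build()
```

With the flag off, `nullcontext()` does nothing, so the model is built exactly as in plain Hugging Face code.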


Download files


Source Distributions

No source distribution files are available for this release.

Built Distributions


tilearn_llm-0.5.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.7 MB): CPython 3.10, manylinux (glibc 2.17+), x86-64

tilearn_llm-0.5.10-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.7 MB): CPython 3.9, manylinux (glibc 2.17+), x86-64

tilearn_llm-0.5.10-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.7 MB): CPython 3.8, manylinux (glibc 2.17+), x86-64

File hashes

tilearn_llm-0.5.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
SHA256: eccfa716ddb8ee0c1fd1c087e3acba06d1498200b583787a9a6122a881811506
MD5: 611d5bbfaf127409fdc0fe26661db803
BLAKE2b-256: 4646684438114e130a995190a56e512669ee181933186a484721e34af714f195

tilearn_llm-0.5.10-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
SHA256: 6e8ca2f2f5aa36affec2a8080da479c2f92eb6a120437740fec8e4d2e4385b88
MD5: f3cf5463a79f5c8ba97b31ceadf14fd6
BLAKE2b-256: adebd2b645d104452fd5db8296f30ba90377c0e858d20441ff2901766e4a6e07

tilearn_llm-0.5.10-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
SHA256: 51328d3bdbc304ac6ba001fa53e151386eb371b588a6857508c3ac8117edb0c4
MD5: 16112edf24e94b980745d2f5be141c3e
BLAKE2b-256: 570feea7bfb3af1557f21fb52039c39e1f7148c3c3e15da0836bc65315ebbe62
