TILEARN for LLM

Project description

Tilearn.llm Usage Guide

1. CUDA Kernel (using LLAMA as an example)

Supported GPUs: Ampere, Ada, or Hopper GPUs (e.g., A100, A800, H100, H800)

Dependencies: pytorch >= 2.0.0
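
The two requirements above can be checked before launching a job. The following is a minimal sketch, not part of tilearn.llm: the helper names `gpu_architecture` and `torch_version_ok` are illustrative, and the compute-capability numbers (8.0/8.6/8.7 for Ampere, 8.9 for Ada, 9.0 for Hopper) are taken from NVIDIA's public tables rather than from this package.

```python
import re

# Compute capability -> architecture name for the GPU families listed above.
# Assumption: these values come from NVIDIA's compute-capability tables,
# not from tilearn.llm itself.
SUPPORTED_CAPABILITIES = {
    (8, 0): "Ampere",  # e.g. A100, A800
    (8, 6): "Ampere",
    (8, 7): "Ampere",
    (8, 9): "Ada",
    (9, 0): "Hopper",  # e.g. H100, H800
}


def gpu_architecture(major, minor):
    """Return the architecture name for a listed GPU family, or None otherwise."""
    return SUPPORTED_CAPABILITIES.get((major, minor))


def torch_version_ok(version, minimum=(2, 0, 0)):
    """Check a version string such as '2.1.0+cu118' against pytorch >= 2.0.0."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        return False
    return tuple(int(part) for part in match.groups()) >= minimum


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
    else:
        print("pytorch >= 2.0.0:", torch_version_ok(torch.__version__))
        if torch.cuda.is_available():
            major, minor = torch.cuda.get_device_capability(0)
            print("GPU architecture:", gpu_architecture(major, minor))
```

A result of `None` from `gpu_architecture` only means the device is outside the families named above; whether the kernels run on other hardware is not stated here.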

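The current version is fully compatible with the huggingface interface; no additional steps are required.

For LLAMA1/LLAMA2 on 16 A800 GPUs with seq=1024, training is about 20% faster than deepspeed zero2.

To use the CUDA kernel, modify your code as follows: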

### TILEARN.LLM
from tilearn.llm.transformers import LlamaForCausalLM

### The model interface is identical to standard huggingface
model = LlamaForCausalLM.from_pretrained(...)

Alternatively, use the AutoModelForCausalLM interface:

### TILEARN.LLM
from tilearn.llm.transformers import AutoModelForCausalLM

### The model interface is identical to standard huggingface
model = AutoModelForCausalLM.from_pretrained(...)
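
Because the interface matches huggingface, the import is the only line that changes. One way to keep a script usable on machines without tilearn.llm is to fall back to stock `transformers`. This is a minimal sketch under that assumption; `first_available` and `load_causal_lm_class` are hypothetical helpers, not part of either library.

```python
import importlib


def first_available(attr, module_names):
    """Return `attr` from the first module in `module_names` that imports cleanly.

    Modules that are missing, or that lack `attr`, are skipped.
    """
    for name in module_names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            continue
        if hasattr(module, attr):
            return getattr(module, attr)
    raise ImportError(f"{attr!r} not found in any of: {', '.join(module_names)}")


def load_causal_lm_class():
    """Prefer the accelerated class, else fall back to stock huggingface."""
    return first_available(
        "AutoModelForCausalLM",
        ["tilearn.llm.transformers", "transformers"],
    )
```

The returned class is then used as usual, for example `model = load_causal_lm_class().from_pretrained(...)` with the same arguments in either case.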

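Notes:

1. Because baichuan1 13B and baichuan2 13B conflict with each other, tilearn.llm.transformers.AutoModelForCausalLM currently enables baichuan1 13B by default. To use baichuan2 13B, set the following environment variable in the training launch script: export TILEARN_LLM_BAICHUAN_13B=2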

### Set TILEARN_LLM_BAICHUAN_13B=2 to enable the baichuan2 13B model
export TILEARN_LLM_BAICHUAN_13B=2

2. Models currently supported for acceleration:

# llama
from tilearn.llm.transformers.models.llama.modeling_llama import LlamaForCausalLM

# bloom
from tilearn.llm.transformers.models.bloom.modeling_bloom import BloomForCausalLM

# baichuan1
from tilearn.llm.transformers.models.baichuan.baichuan1_13B.modeling_baichuan import BaichuanForCausalLM
from tilearn.llm.transformers.models.baichuan.baichuan1_7B.modeling_baichuan import BaiChuanForCausalLM

# baichuan2
# Uses TILEARN.LLM by default; no configuration is needed
# To use xformers alone, install xformers and set the environment variable TIACC_TRAINING_CUDA_KERNEL=2
from tilearn.llm.transformers.models.baichuan.baichuan2_7B.modeling_baichuan import BaichuanForCausalLM
from tilearn.llm.transformers.models.baichuan.baichuan2_13B.modeling_baichuan import BaichuanForCausalLM
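
The two notes above can be combined into a small lookup table. The sketch below copies the module paths and class names from the list and mirrors the documented 13B rule (baichuan1 unless `TILEARN_LLM_BAICHUAN_13B` is `2`). The registry keys and the helpers `baichuan_13b_key` and `load_model_class` are hypothetical; how the library reads the variable internally is not documented here, so this reflects only the behavior described in the first note.

```python
import importlib
import os

# Module paths and class names copied from the supported-model list above.
# The dictionary keys are illustrative labels, not library identifiers.
_BASE = "tilearn.llm.transformers.models"
MODEL_REGISTRY = {
    "llama": (f"{_BASE}.llama.modeling_llama", "LlamaForCausalLM"),
    "bloom": (f"{_BASE}.bloom.modeling_bloom", "BloomForCausalLM"),
    "baichuan1_7B": (f"{_BASE}.baichuan.baichuan1_7B.modeling_baichuan", "BaiChuanForCausalLM"),
    "baichuan1_13B": (f"{_BASE}.baichuan.baichuan1_13B.modeling_baichuan", "BaichuanForCausalLM"),
    "baichuan2_7B": (f"{_BASE}.baichuan.baichuan2_7B.modeling_baichuan", "BaichuanForCausalLM"),
    "baichuan2_13B": (f"{_BASE}.baichuan.baichuan2_13B.modeling_baichuan", "BaichuanForCausalLM"),
}


def baichuan_13b_key(environ=None):
    """Pick the 13B variant as the first note describes: baichuan1 unless the variable is '2'."""
    environ = os.environ if environ is None else environ
    if environ.get("TILEARN_LLM_BAICHUAN_13B") == "2":
        return "baichuan2_13B"
    return "baichuan1_13B"


def load_model_class(key):
    """Import and return the model class; requires tilearn.llm to be installed."""
    module_path, class_name = MODEL_REGISTRY[key]
    return getattr(importlib.import_module(module_path), class_name)
```

Note that the 7B baichuan1 class is spelled `BaiChuanForCausalLM` (capital C), while the other baichuan classes use `BaichuanForCausalLM`; a table like this keeps that difference in one place.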

2. Static Zero

Use case: switching between different deepspeed optimization states such as zero1, zero2, zero3, offload, and int8.

Modify the launch script as follows:

### TILEARN STATIC ZERO
### Enable: TIACC_TRAINING_STATIC_ZERO='O2'
### Supported values: 'O2' / 'O2.5' / 'O3' / 'O3.5' / 'O3_Q8' (in progress)
### Disable: TIACC_TRAINING_STATIC_ZERO='None'
export TIACC_TRAINING_STATIC_ZERO='None' #'O2'
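
When the training job is started from Python rather than from a shell script, the same setting can be passed through the child environment. The following is a minimal sketch: `static_zero_env` is a hypothetical helper, and it assumes the values listed in the comments above apply to the exported variable `TIACC_TRAINING_STATIC_ZERO`.

```python
import os

# Values listed in the launch-script comments above.
# 'O3_Q8' is described there as still being worked on.
STATIC_ZERO_LEVELS = ("O2", "O2.5", "O3", "O3.5", "O3_Q8")
DISABLED = "None"


def static_zero_env(level=None, base_env=None):
    """Return a copy of the environment with TIACC_TRAINING_STATIC_ZERO set.

    `level=None` disables Static Zero by writing the string 'None'.
    """
    if level is not None and level not in STATIC_ZERO_LEVELS:
        raise ValueError(
            f"unsupported level {level!r}; expected one of {STATIC_ZERO_LEVELS}"
        )
    env = dict(os.environ if base_env is None else base_env)
    env["TIACC_TRAINING_STATIC_ZERO"] = DISABLED if level is None else level
    return env
```

A launcher could then call `subprocess.run(["bash", "train.sh"], env=static_zero_env("O2"))`, where `train.sh` stands in for your own launch script. Validating the value up front avoids discovering a typo only after the job has started.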

Modify the code as follows:

from transformers import HfArgumentParser
from tilearn.llm.transformers import TrainingArguments
	
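### The interface is identical to standard huggingface
### ModelArguments and DataTrainingArguments are dataclasses defined in your own training script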
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
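
The snippet above relies on `ModelArguments` and `DataTrainingArguments` being defined elsewhere. A fuller sketch is shown below; the two dataclasses and their fields are hypothetical placeholders, and `parse_arguments` assumes both `transformers` and tilearn.llm are installed.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelArguments:
    """Hypothetical model options; define whatever fields your script needs."""

    model_name_or_path: str = field(
        default="path/to/model",
        metadata={"help": "Local path or hub name of the pretrained model."},
    )


@dataclass
class DataTrainingArguments:
    """Hypothetical data options."""

    train_file: Optional[str] = field(
        default=None, metadata={"help": "Path to the training data file."}
    )
    max_seq_length: int = field(
        default=1024, metadata={"help": "Maximum sequence length after tokenization."}
    )


def parse_arguments(argv=None):
    """Parse command-line flags into the three dataclasses.

    The imports live inside the function so the dataclasses above can be
    used (and tested) without transformers or tilearn.llm installed.
    """
    from transformers import HfArgumentParser
    from tilearn.llm.transformers import TrainingArguments

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    return parser.parse_args_into_dataclasses(args=argv)
```

In a training script this would be called as `model_args, data_args, training_args = parse_arguments()`, after which `training_args` is passed to the trainer as with the standard huggingface class.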

Download files

Source Distributions

No source distribution files are available for this release.

Built Distributions

tilearn_llm-0.6.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
CPython 3.10, manylinux: glibc 2.17+, x86-64

tilearn_llm-0.6.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
CPython 3.9, manylinux: glibc 2.17+, x86-64

tilearn_llm-0.6.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
CPython 3.8, manylinux: glibc 2.17+, x86-64

File details

Details for the file tilearn_llm-0.6.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Algorithm    Hash digest
SHA256       d4775718c52e98a44d1ac3365b3dbea559ae62428d0ff31f19bb78ffd29304b2
MD5          40240b16fc0560e86999862649cfcfff
BLAKE2b-256  204098ee9887a3b98f0b887ed54a2f0cc0b2051c7b735a9bd433f42748494487

File details

Details for the file tilearn_llm-0.6.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Algorithm    Hash digest
SHA256       dde6bc2895361f6cc5802f7602e6d3c7dd443043111a5a6d839ef20ddf17f79c
MD5          1186ec6993c17bc6ce02baea5dec8b40
BLAKE2b-256  75da3c520dc51b66ae14a85b3e24420181e66545c162820673d26facdd8fcaa1

File details

Details for the file tilearn_llm-0.6.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Algorithm    Hash digest
SHA256       3c3c1ebaa1f6f9ca3562e1110f7cbc05305413751690452c46cb62f22c3318cd
MD5          27eaf3a8c3132e298770d3323af3db21
BLAKE2b-256  ed6c68ec45d19cea9b3e8cabe6dfa7d79894cde2ff5b42cf50fcca4c0f379d00
