TILEARN for LLM

Tilearn.llm Usage Guide

1. CUDA Kernel (using LLAMA as an example)

Supported GPUs: Ampere, Ada, or Hopper GPUs (e.g., A100, A800, H100, H800)

Dependencies: pytorch >= 2.0.0
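As a rough pre-flight check for the GPU requirement above, the sketch below maps CUDA compute capabilities to the listed architectures. The capability table is general NVIDIA knowledge rather than anything tilearn exposes, and the helper names `architecture_for` and `current_gpu_architecture` are made up for illustration.

```python
# Compute capabilities for the architectures listed as supported.
# Assumption: this table comes from NVIDIA's public documentation,
# not from tilearn itself.
SUPPORTED_CAPABILITIES = {
    (8, 0): "Ampere",  # e.g., A100, A800
    (8, 6): "Ampere",
    (8, 9): "Ada",
    (9, 0): "Hopper",  # e.g., H100, H800
}


def architecture_for(major, minor):
    """Return the architecture name for a compute capability, or None if unlisted."""
    return SUPPORTED_CAPABILITIES.get((major, minor))


def current_gpu_architecture():
    """Look up the architecture of GPU 0; returns None without PyTorch or CUDA."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return architecture_for(*torch.cuda.get_device_capability(0))
```

Calling `current_gpu_architecture()` before training starts gives an early warning when the result is `None`, instead of a failure later inside a kernel.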

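The current version is fully compatible with the Hugging Face interface; no additional steps are required.

LLAMA1/LLAMA2 on 16 A800 GPUs with seq=1024: training is about 20% faster than with DeepSpeed zero2.

To use the CUDA kernels, modify the code as follows: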

### TILEARN.LLM
from tilearn.llm.transformers import LlamaForCausalLM

### The model interface is the same as standard Hugging Face
model = LlamaForCausalLM.from_pretrained(...)

Alternatively, use the AutoModelForCausalLM interface:

### TILEARN.LLM
from tilearn.llm.transformers import AutoModelForCausalLM

### The model interface is the same as standard Hugging Face
model = AutoModelForCausalLM.from_pretrained(...)
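Because the interface matches Hugging Face, one script can, in principle, run with or without the accelerated kernels. The following is a minimal sketch of that idea; the helper names `backend_module_name` and `load_causal_lm` are hypothetical, and the fallback to plain `transformers` is an assumption about how you might structure your own code, not a tilearn feature.

```python
import importlib
import importlib.util


def backend_module_name():
    """Prefer tilearn.llm.transformers when the tilearn package is installed."""
    if importlib.util.find_spec("tilearn") is not None:
        return "tilearn.llm.transformers"
    return "transformers"


def load_causal_lm(model_name_or_path, **kwargs):
    """Load a causal LM through whichever backend is available.

    Both backends are expected to expose AutoModelForCausalLM.from_pretrained
    with the same signature, as stated in the text above.
    """
    module = importlib.import_module(backend_module_name())
    return module.AutoModelForCausalLM.from_pretrained(model_name_or_path, **kwargs)
```

The import is deferred until `load_causal_lm` is called, so the module can be imported on machines that have neither package installed.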

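Special notes:

1. Because baichuan1 13B and baichuan2 13B conflict with each other, tilearn.llm.transformers.AutoModelForCausalLM currently enables baichuan1 13B by default. To use baichuan2 13B, set the following environment variable in the training launch script: export TILEARN_LLM_BAICHUAN_13B=2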

### TILEARN_LLM_BAICHUAN_13B=2 enables the baichuan2 13B model
export TILEARN_LLM_BAICHUAN_13B=2
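If the launch script is hard to edit, the same switch can be set from Python. This is a sketch under two assumptions that the source does not confirm: that leaving the variable unset keeps the baichuan1 13B default, and that the variable is read when tilearn is imported or the model is loaded, so it must be set first. The helper `select_baichuan_13b` is a hypothetical name.

```python
import os

BAICHUAN_ENV_VAR = "TILEARN_LLM_BAICHUAN_13B"


def select_baichuan_13b(version):
    """Choose which Baichuan 13B implementation the Auto class should use."""
    if version == 2:
        os.environ[BAICHUAN_ENV_VAR] = "2"
    elif version == 1:
        # Assumption: an unset variable keeps the default (baichuan1 13B).
        os.environ.pop(BAICHUAN_ENV_VAR, None)
    else:
        raise ValueError(f"unsupported Baichuan 13B version: {version!r}")


select_baichuan_13b(2)
# Import only after the variable is set (assumed to be read at import or load time):
# from tilearn.llm.transformers import AutoModelForCausalLM
```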

2. Models currently supported for acceleration:

# llama
from tilearn.llm.transformers.models.llama.modeling_llama import LlamaForCausalLM

# bloom
from tilearn.llm.transformers.models.bloom.modeling_bloom import BloomForCausalLM

# baichuan1
from tilearn.llm.transformers.models.baichuan.baichuan1_13B.modeling_baichuan import BaichuanForCausalLM
from tilearn.llm.transformers.models.baichuan.baichuan1_7B.modeling_baichuan import BaiChuanForCausalLM

# baichuan2
# Uses TILEARN.LLM by default; no configuration is needed
# To use xformer on its own, install xformer and set the environment variable TIACC_TRAINING_CUDA_KERNEL=2
from tilearn.llm.transformers.models.baichuan.baichuan2_7B.modeling_baichuan import BaichuanForCausalLM
from tilearn.llm.transformers.models.baichuan.baichuan2_13B.modeling_baichuan import BaichuanForCausalLM
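The import lines above can be collected into a small lookup table so that a training script selects the model class by name. This is only a sketch: the registry keys and the `get_model_class` helper are hypothetical, while the module paths and class names are copied from the list above (note that the baichuan1 7B class is spelled `BaiChuanForCausalLM`).

```python
import importlib

_BASE = "tilearn.llm.transformers.models"

# (module path, class name) pairs copied from the import lines above.
SUPPORTED_MODELS = {
    "llama": (f"{_BASE}.llama.modeling_llama", "LlamaForCausalLM"),
    "bloom": (f"{_BASE}.bloom.modeling_bloom", "BloomForCausalLM"),
    "baichuan1_13B": (f"{_BASE}.baichuan.baichuan1_13B.modeling_baichuan", "BaichuanForCausalLM"),
    "baichuan1_7B": (f"{_BASE}.baichuan.baichuan1_7B.modeling_baichuan", "BaiChuanForCausalLM"),
    "baichuan2_7B": (f"{_BASE}.baichuan.baichuan2_7B.modeling_baichuan", "BaichuanForCausalLM"),
    "baichuan2_13B": (f"{_BASE}.baichuan.baichuan2_13B.modeling_baichuan", "BaichuanForCausalLM"),
}


def get_model_class(name):
    """Import and return the accelerated model class registered under `name`."""
    if name not in SUPPORTED_MODELS:
        raise KeyError(f"unknown model {name!r}; choose from {sorted(SUPPORTED_MODELS)}")
    module_path, class_name = SUPPORTED_MODELS[name]
    # The import happens here, so tilearn is only required when a class is requested.
    return getattr(importlib.import_module(module_path), class_name)
```

With tilearn installed, `get_model_class("llama").from_pretrained(...)` would then behave like the direct import shown earlier.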

2. Static Zero

Use case: switching between different optimization states such as DeepSpeed zero1, zero2, zero3, offload, and int8.

Modify the launch script as follows:

### TILEARN STATIC ZERO
### Open: TIACC_TRAINING_STATIC_ZERO='O2'
### Supported: 'O2' / 'O2.5' / 'O3' / 'O3.5' / 'O3_Q8' (in progress)
### Close: TIACC_TRAINING_STATIC_ZERO='None'
export TIACC_TRAINING_STATIC_ZERO='None' #'O2'
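A typo in the level string is easy to miss in a shell script, so the value can be validated before it is exported. The sketch below does this from Python for the `TIACC_TRAINING_STATIC_ZERO` variable exported in the launch script. The helper `set_static_zero` is a hypothetical name, 'O3_Q8' is rejected because the source marks it as still in progress, and the source does not say which DeepSpeed configuration each level corresponds to.

```python
import os

STATIC_ZERO_ENV_VAR = "TIACC_TRAINING_STATIC_ZERO"

# Levels listed in the launch-script comments. 'O3_Q8' is marked as in
# progress there, so this sketch treats it as not yet available.
SUPPORTED_LEVELS = ("O2", "O2.5", "O3", "O3.5")


def set_static_zero(level=None):
    """Enable Static Zero at `level`, or disable it when `level` is None."""
    if level is None:
        # The launch script disables the feature with the literal string 'None'.
        os.environ[STATIC_ZERO_ENV_VAR] = "None"
        return
    if level not in SUPPORTED_LEVELS:
        raise ValueError(f"unsupported level {level!r}; choose from {SUPPORTED_LEVELS}")
    os.environ[STATIC_ZERO_ENV_VAR] = level
```

As with any environment-based switch, this presumably needs to run before the training code that reads the variable is imported.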

Modify the code as follows:

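from transformers import HfArgumentParser
from tilearn.llm.transformers import TrainingArguments

### The interface is the same as standard Hugging Face
### ModelArguments and DataTrainingArguments are dataclasses defined in your own training script
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))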
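The parser line above refers to `ModelArguments` and `DataTrainingArguments`, which come from the user's own training script. A minimal sketch of what they might look like follows; the field names and defaults are purely illustrative, and `parse_arguments` is a hypothetical wrapper that needs both `transformers` and `tilearn` installed when called.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelArguments:
    """Hypothetical model options; field names are illustrative only."""
    model_name_or_path: str = field(metadata={"help": "Model path or hub ID."})
    use_fast_tokenizer: bool = field(default=True, metadata={"help": "Use a fast tokenizer."})


@dataclass
class DataTrainingArguments:
    """Hypothetical data options; field names are illustrative only."""
    train_file: Optional[str] = field(default=None, metadata={"help": "Path to training data."})
    max_seq_length: int = field(default=1024, metadata={"help": "Maximum sequence length."})


def parse_arguments(argv=None):
    """Parse command-line arguments into the three dataclasses."""
    # Imported lazily so the dataclasses can be used without these packages.
    from transformers import HfArgumentParser
    from tilearn.llm.transformers import TrainingArguments

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    return parser.parse_args_into_dataclasses(args=argv)
```

In a training script this would typically be used as `model_args, data_args, training_args = parse_arguments()`.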


This version

0.7.1

Download files

Source Distributions

No source distribution files are available for this release.

Built Distributions

tilearn_llm-0.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)

Uploaded: CPython 3.10, manylinux: glibc 2.17+ x86-64

tilearn_llm-0.7.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)

Uploaded: CPython 3.9, manylinux: glibc 2.17+ x86-64

tilearn_llm-0.7.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)

Uploaded: CPython 3.8, manylinux: glibc 2.17+ x86-64

File details

Details for the file tilearn_llm-0.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for tilearn_llm-0.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 16c650b8c61f8e02cbe07c3e7939f3269ea43745aed7d94da90177048d73a265
MD5 39b9943fbae23495c656e84b0e0ab02f
BLAKE2b-256 eebc22cbae35121810e12eb03d042ac76fdcce5bdddc6bda4361cc5487a68b45

File details

Details for the file tilearn_llm-0.7.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for tilearn_llm-0.7.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 6aa69c02f31a9410ee6413746e740a6cf188c0217f564c12d64b84f166a79d38
MD5 3cf16c9fc94d9d546a0520b71ad7afed
BLAKE2b-256 d0084ae5ba4677a68864e9c8aded08139c5b3f231353ff8f0554bfc56d2994bc

File details

Details for the file tilearn_llm-0.7.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.

File hashes

Hashes for tilearn_llm-0.7.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Algorithm Hash digest
SHA256 0964dee14179e477ae80b484408269dddeb8b0d0d6f7ef0c346ce0b9590f07ac
MD5 31e724834343fd6a3e2992cd77c979ad
BLAKE2b-256 0cea8da5c4d1147fb55b72eb3fdc9e71e63f6f1799c90520bdbefcd09a08cdcb
