YaRN Implementation compatible with HuggingFace Transformers
meta-llama-3-yarn
Getting Started
Install with pip
pip install meta-llama-3-yarn
Build from source
You can also build and install meta-llama-3-yarn from source:
git clone https://github.com/MeetKai/meta-llama-3-yarn.git
cd meta-llama-3-yarn
pip install -e .
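After either install route, a quick check that the distribution is visible to the current interpreter can save some confusion. The following is a minimal sketch using only the standard library; the helper name `installed_version` is hypothetical and is not part of meta-llama-3-yarn.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(distribution_name: str):
    """Return the installed version string, or None if the distribution is absent."""
    try:
        return version(distribution_name)
    except PackageNotFoundError:
        return None


# Prints the installed version, or None when the package is not installed
# in the active environment.
print(installed_version("meta-llama-3-yarn"))
```

A result of `None` usually means pip installed the package into a different environment than the one running the script.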
Usage
To use a Llama-3 model, regardless of whether it is a YaRN-scaled model:
import torch
from transformers import AutoTokenizer
from meta_llama_3_yarn import LlamaForCausalLM
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the weights in bfloat16 and let device_map place them on the available devices.
model = LlamaForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=torch.bfloat16
)
messages = [
    {"role": "user", "content": "Write a piece of quicksort code in C++"}
]
# Render the chat template to token IDs, ending with the assistant generation prompt.
input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(input_tensor.to(model.device), max_new_tokens=100)
# Decode only the newly generated tokens, skipping the prompt.
result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)
print(result)
To scale a model with YaRN, set rope_scaling on its config before loading it:
import torch
from transformers import AutoTokenizer
from meta_llama_3_yarn import LlamaForCausalLM, LlamaConfig
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
context_length = 32768  # target context window, in tokens
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = LlamaConfig.from_pretrained(model_name)
# The scaling factor is the target context length divided by the model's original one.
config.rope_scaling = {"type": "yarn", "factor": context_length / config.max_position_embeddings}
model = LlamaForCausalLM.from_pretrained(
    model_name, config=config, device_map="auto", torch_dtype=torch.bfloat16
)
messages = [
    {"role": "user", "content": "Write a piece of quicksort code in C++"}
]
# Render the chat template to token IDs, ending with the assistant generation prompt.
input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(input_tensor.to(model.device), max_new_tokens=100)
# Decode only the newly generated tokens, skipping the prompt.
result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)
print(result)
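The rope_scaling factor in the example above is plain arithmetic: the target context length divided by the model's original context length. The sketch below works it through without loading a model. It assumes an original context length of 8192 tokens for Meta-Llama-3-8B-Instruct (the value expected in `config.max_position_embeddings`); the helper name `yarn_rope_scaling` is hypothetical and not part of the package.

```python
def yarn_rope_scaling(target_context_length: int, original_context_length: int) -> dict:
    """Build a rope_scaling dict in the same shape used in the usage example."""
    if original_context_length <= 0:
        raise ValueError("original_context_length must be positive")
    if target_context_length < original_context_length:
        raise ValueError("target context must be at least the original context")
    return {
        "type": "yarn",
        "factor": target_context_length / original_context_length,
    }


# Assumption: 8192 is the original max_position_embeddings of the model.
scaling = yarn_rope_scaling(target_context_length=32768, original_context_length=8192)
print(scaling)  # {'type': 'yarn', 'factor': 4.0}
```

Reading the original length from the loaded config, as the usage example does, avoids hard-coding this assumption.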
Download files
Source Distribution
- meta_llama_3_yarn-0.1.3.tar.gz (23.5 kB)
Built Distribution
- meta_llama_3_yarn-0.1.3-py3-none-any.whl (24.8 kB)
File details
Details for the file meta_llama_3_yarn-0.1.3.tar.gz.
File metadata
- Download URL: meta_llama_3_yarn-0.1.3.tar.gz
- Upload date:
- Size: 23.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.0 CPython/3.12.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | ce67e915aef2fbda3da3d1e13774a7a69c1eb8b679a2b059bcf5369682a2c665
MD5 | 70527e0626b06a6ca4d353ba25cd2f67
BLAKE2b-256 | 147030bd2f256cc5bb209d24faf3f1f161d546e211c0042dcf8e6ba979585c43
File details
Details for the file meta_llama_3_yarn-0.1.3-py3-none-any.whl.
File metadata
- Download URL: meta_llama_3_yarn-0.1.3-py3-none-any.whl
- Upload date:
- Size: 24.8 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.0 CPython/3.12.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | ed4df3993c2cbce05f805e3be7a8af75087a1dd38c81e1ac0cfe0a41a1f57e5e
MD5 | 600388b467f298f45fe149520d1b68b4
BLAKE2b-256 | b7da512872590065ce30a2a519ce7785742d77546cbc7c837ebdb66fca45448f
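The SHA256 digests listed for each file can be used to check a manually downloaded archive. The following is a minimal sketch using only the standard library; the helper name `sha256_of_file` and the local file path are assumptions for illustration, and the expected digest is the one listed for the source distribution.

```python
import hashlib
from pathlib import Path


def sha256_of_file(path, chunk_size: int = 1 << 20) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Digest listed for meta_llama_3_yarn-0.1.3.tar.gz.
EXPECTED_SHA256 = "ce67e915aef2fbda3da3d1e13774a7a69c1eb8b679a2b059bcf5369682a2c665"

# Assumption: the archive was downloaded into the current working directory.
archive = Path("meta_llama_3_yarn-0.1.3.tar.gz")
if archive.exists():
    matches = sha256_of_file(archive) == EXPECTED_SHA256
    print("digest matches" if matches else "digest MISMATCH")
else:
    print(f"{archive} not found; download it first")
```

When installing through pip, a similar check can be enforced with a requirements file that pins hashes and pip's `--require-hashes` option.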