YaRN Implementation compatible with HuggingFace Transformers
Project description
meta-llama-3-yarn
Getting Started
Install with pip
pip install meta-llama-3-yarn
Build from source
You can also build and install meta-llama-3-yarn from source:
git clone https://github.com/MeetKai/meta-llama-3-yarn.git
cd meta-llama-3-yarn
pip install -e .
Usage
To use a Llama-3 model, regardless of whether it is a YaRN-scaled model:
import torch
from transformers import AutoTokenizer
from meta_llama_3_yarn import LlamaForCausalLM
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=torch.bfloat16
)
messages = [
    {"role": "user", "content": "Write a piece of quicksort code in C++"}
]
input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(input_tensor.to(model.device), max_new_tokens=100)
result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)
print(result)
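The decode step above slices the prompt tokens off the front of the generated sequence, so only the newly generated text is printed. A toy illustration of that slicing with plain token ids (the id values are made up):

```python
# model.generate returns the prompt followed by the new tokens,
# so decoding starts at index input_tensor.shape[1].
prompt_ids = [101, 102, 103]            # stands in for input_tensor[0]
output_ids = prompt_ids + [7, 8, 9]     # stands in for outputs[0]
new_tokens = output_ids[len(prompt_ids):]
print(new_tokens)  # [7, 8, 9]
```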
To explicitly apply YaRN scaling to a model:
import torch
from transformers import AutoTokenizer
from meta_llama_3_yarn import LlamaForCausalLM, LlamaConfig
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
context_length = 32768
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = LlamaConfig.from_pretrained(model_name)
config.rope_scaling = {"type": "yarn", "factor": context_length / config.max_position_embeddings}
model = LlamaForCausalLM.from_pretrained(
    model_name, config=config, device_map="auto", torch_dtype=torch.bfloat16
)
messages = [
    {"role": "user", "content": "Write a piece of quicksort code in C++"}
]
input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(input_tensor.to(model.device), max_new_tokens=100)
result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)
print(result)
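Under the hood, YaRN stretches the RoPE inverse frequencies non-uniformly: high-frequency dimensions (which complete many rotations within the original context) are left untouched, low-frequency dimensions are interpolated by the scaling factor, and a linear ramp blends the two regimes in between. A minimal, self-contained sketch of that frequency adjustment, following the formulation in the YaRN paper; `dim`, `base`, and the ramp boundaries below are assumed values for illustration, not read from the model config:

```python
import math

# Assumed values (not read from a real config):
dim = 128            # per-head embedding dimension
base = 500000.0      # RoPE theta
orig_len = 8192      # original context window
factor = 4.0         # scaling factor, e.g. 32768 / 8192
beta_fast, beta_slow = 32.0, 1.0  # ramp boundaries (YaRN paper defaults)

def find_dim(num_rotations):
    # Dimension index at which a position completes `num_rotations`
    # full rotations over the original context length.
    return (dim * math.log(orig_len / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )

low = math.floor(find_dim(beta_fast))   # below this: keep original freqs
high = math.ceil(find_dim(beta_slow))   # above this: fully interpolate

# Standard RoPE inverse frequencies.
inv_freq = [base ** (-2 * i / dim) for i in range(dim // 2)]

def ramp(i):
    # 0 for high-frequency dims (kept), 1 for low-frequency dims
    # (interpolated), linear in between.
    return min(1.0, max(0.0, (i - low) / (high - low)))

yarn_inv_freq = [
    f * (1 - ramp(i)) + (f / factor) * ramp(i)
    for i, f in enumerate(inv_freq)
]
```

The full YaRN method also rescales attention logits by a small temperature that grows with the scaling factor; the sketch above covers only the frequency interpolation.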
Project details
Download files
Source Distribution
meta_llama_3_yarn-0.1.2.tar.gz (23.5 kB)
File details
Details for the file meta_llama_3_yarn-0.1.2.tar.gz.
File metadata
- Download URL: meta_llama_3_yarn-0.1.2.tar.gz
- Upload date:
- Size: 23.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? Yes
- Uploaded via: twine/5.1.0 CPython/3.12.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | aea22067c28cb80890544752bc40b74e25a9ee5caebbee1c0b6a7ea82ede6494
MD5 | 07fd15366abeed860fff688614d1a2a1
BLAKE2b-256 | 7fea0dd773d3a980f4047af0687edb9515930c7b965762d13a8d06df9ab92d5f