
YaRN Implementation compatible with HuggingFace Transformers


meta-llama-3-yarn

Getting Started

Install with pip

pip install meta-llama-3-yarn

Build from source

You can also build and install meta-llama-3-yarn from source:

git clone https://github.com/MeetKai/meta-llama-3-yarn.git
cd meta-llama-3-yarn
pip install -e .

Usage

To use a Llama-3 model, whether or not it is YaRN-scaled:

import torch
from transformers import AutoTokenizer
from meta_llama_3_yarn import LlamaForCausalLM

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=torch.bfloat16
)

messages = [
    {"role": "user", "content": "Write a piece of quicksort code in C++"}
]
input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(input_tensor.to(model.device), max_new_tokens=100)

result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)
print(result)
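The example above works for both plain and YaRN-scaled checkpoints. If you need to tell them apart, for example to log which positional scaling is active, a small hypothetical helper (is_yarn_scaled is not part of meta-llama-3-yarn) can inspect the config's rope_scaling entry. This sketch assumes rope_scaling is either None or a dict, as in the YaRN example in this README, and accepts both the older "type" key and the "rope_type" key used by newer Transformers releases.

```python
from typing import Any, Mapping, Optional


def is_yarn_scaled(rope_scaling: Optional[Mapping[str, Any]]) -> bool:
    """Hypothetical helper: True if a rope_scaling dict requests YaRN.

    Assumes rope_scaling is None (no scaling) or a dict; checks "rope_type"
    first (newer Transformers) and falls back to "type" (older configs).
    """
    if not rope_scaling:
        return False
    kind = rope_scaling.get("rope_type", rope_scaling.get("type"))
    return kind == "yarn"
```

After loading, you could call it as is_yarn_scaled(model.config.rope_scaling).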

To explicitly scale a model with YaRN, set rope_scaling on the config before loading the weights:

import torch
from transformers import AutoTokenizer
from meta_llama_3_yarn import LlamaForCausalLM, LlamaConfig

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
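context_length = 32768  # target context window, in tokens
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = LlamaConfig.from_pretrained(model_name)
# YaRN scale factor = target context / native context
# (Llama-3-8B has max_position_embeddings = 8192, so the factor here is 4.0)
config.rope_scaling = {"type": "yarn", "factor": context_length / config.max_position_embeddings}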
model = LlamaForCausalLM.from_pretrained(
    model_name, config=config, device_map="auto", torch_dtype=torch.bfloat16
)

messages = [
    {"role": "user", "content": "Write a piece of quicksort code in C++"}
]
input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(input_tensor.to(model.device), max_new_tokens=100)

result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)
print(result)
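For intuition about what the factor in rope_scaling controls, here is a minimal, framework-free sketch of the YaRN frequency adjustment as formulated in the YaRN paper: RoPE dimensions whose wavelength is short relative to the original context keep their frequency, long-wavelength dimensions are interpolated (divided by the scale factor s), a linear ramp blends the two in between, and an attention temperature of 0.1 * ln(s) + 1 is added. This is an illustration, not meta-llama-3-yarn's implementation; the function names are hypothetical, and the ramp bounds alpha=1, beta=32 and the Llama-3-8B values (head_dim=128, rope_theta=500000, 8192 native positions) are assumptions.

```python
import math
from typing import List


def rope_inv_freq(head_dim: int, base: float) -> List[float]:
    """Standard RoPE inverse frequencies theta_i = base^(-2i/d), i in [0, d/2)."""
    return [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]


def yarn_ramp(r: float, alpha: float, beta: float) -> float:
    """Linear ramp gamma(r): 0 below alpha, 1 above beta, linear in between."""
    if r < alpha:
        return 0.0
    if r > beta:
        return 1.0
    return (r - alpha) / (beta - alpha)


def yarn_inv_freq(
    head_dim: int,
    base: float,
    factor: float,
    original_max_pos: int,
    alpha: float = 1.0,
    beta: float = 32.0,
) -> List[float]:
    """Blend interpolated (theta / factor) and original (theta) frequencies per dimension."""
    scaled = []
    for theta in rope_inv_freq(head_dim, base):
        wavelength = 2 * math.pi / theta
        # r = number of full rotations this dimension makes over the original context
        r = original_max_pos / wavelength
        gamma = yarn_ramp(r, alpha, beta)
        # gamma = 1 keeps high-frequency dims; gamma = 0 fully interpolates long wavelengths
        scaled.append((1 - gamma) * theta / factor + gamma * theta)
    return scaled


def yarn_attention_factor(factor: float) -> float:
    """sqrt(1/t) = 0.1 * ln(s) + 1 from the YaRN paper; no change when s <= 1."""
    return 0.1 * math.log(factor) + 1.0 if factor > 1 else 1.0


if __name__ == "__main__":
    # Assumed Llama-3-8B values: head_dim 128, rope_theta 500000, 8192 native positions
    factor = 32768 / 8192
    original = rope_inv_freq(128, 500000.0)
    scaled = yarn_inv_freq(128, 500000.0, factor, 8192)
    for i in (0, 25, 63):
        print(f"dim {i}: {original[i]:.3e} -> {scaled[i]:.3e}")
    print("attention factor:", round(yarn_attention_factor(factor), 4))
```

With factor = 32768 / 8192 = 4.0, the highest-frequency dimension is left unchanged, the lowest-frequency one is divided by 4, dimensions in between are partially interpolated, and the attention factor is about 1.139. Transformers-style implementations commonly apply that factor by scaling the cos/sin rotary caches, which scales both queries and keys; treat that as an assumption about typical practice rather than a description of this package.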


Download files


Source Distribution

meta_llama_3_yarn-0.1.3.tar.gz (23.5 kB)

Uploaded Source

Built Distribution

meta_llama_3_yarn-0.1.3-py3-none-any.whl (24.8 kB)

Uploaded Python 3

File details

Details for the file meta_llama_3_yarn-0.1.3.tar.gz.

File metadata

  • Download URL: meta_llama_3_yarn-0.1.3.tar.gz
  • Size: 23.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.0 CPython/3.12.4

File hashes

Hashes for meta_llama_3_yarn-0.1.3.tar.gz
Algorithm Hash digest
SHA256 ce67e915aef2fbda3da3d1e13774a7a69c1eb8b679a2b059bcf5369682a2c665
MD5 70527e0626b06a6ca4d353ba25cd2f67
BLAKE2b-256 147030bd2f256cc5bb209d24faf3f1f161d546e211c0042dcf8e6ba979585c43


File details

Details for the file meta_llama_3_yarn-0.1.3-py3-none-any.whl.

File hashes

Hashes for meta_llama_3_yarn-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 ed4df3993c2cbce05f805e3be7a8af75087a1dd38c81e1ac0cfe0a41a1f57e5e
MD5 600388b467f298f45fe149520d1b68b4
BLAKE2b-256 b7da512872590065ce30a2a519ce7785742d77546cbc7c837ebdb66fca45448f

