YaRN Implementation compatible with HuggingFace Transformers

Project description

meta-llama-3-yarn

Getting Started

Install with pip

pip install meta-llama-3-yarn

Build from source

You can also build and install meta-llama-3-yarn from source:

git clone https://github.com/MeetKai/meta-llama-3-yarn.git
cd meta-llama-3-yarn
pip install -e .

Usage

To use a Llama-3 model, regardless of whether it is a YaRN-scaled model:

import torch
from transformers import AutoTokenizer
from meta_llama_3_yarn import LlamaForCausalLM

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=torch.bfloat16
)

messages = [
    {"role": "user", "content": "Write a piece of quicksort code in C++"}
]
input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(input_tensor.to(model.device), max_new_tokens=100)

result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)
print(result)
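When loading a checkpoint whose config already carries YaRN parameters, the model picks them up automatically; you can inspect `config.rope_scaling` to see what a loaded model uses. A minimal sketch of that check, with the field modelled as a plain dict (or None) and a hypothetical `is_yarn_scaled` helper:

```python
# Hypothetical helper: decide whether a rope_scaling entry uses YaRN.
# In practice this would be `model.config.rope_scaling`; here it is modelled
# as a plain dict, or None for an unscaled model, purely for illustration.
def is_yarn_scaled(rope_scaling) -> bool:
    return rope_scaling is not None and rope_scaling.get("type") == "yarn"

print(is_yarn_scaled({"type": "yarn", "factor": 4.0}))  # True
print(is_yarn_scaled(None))                             # False
```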

To explicitly apply YaRN scaling to a model:

import torch
from transformers import AutoTokenizer
from meta_llama_3_yarn import LlamaForCausalLM, LlamaConfig

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
context_length = 32768
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = LlamaConfig.from_pretrained(model_name)
config.rope_scaling = {"type": "yarn", "factor": context_length / config.max_position_embeddings}
model = LlamaForCausalLM.from_pretrained(
    model_name, config=config, device_map="auto", torch_dtype=torch.bfloat16
)

messages = [
    {"role": "user", "content": "Write a piece of quicksort code in C++"}
]
input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(input_tensor.to(model.device), max_new_tokens=100)

result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)
print(result)
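The `factor` set above is simply the target context length divided by the model's original `max_position_embeddings` (8192 for Llama-3), so extending to 32768 tokens gives a factor of 4.0. A minimal sketch of that arithmetic, with a hypothetical `yarn_factor` helper:

```python
# Hypothetical helper: compute the YaRN rope_scaling factor for a target
# context length. Llama-3 is pretrained with max_position_embeddings = 8192.
def yarn_factor(target_context: int, original_context: int = 8192) -> float:
    return target_context / original_context

print(yarn_factor(32768))  # 4.0
```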

Download files


Source Distribution

meta_llama_3_yarn-0.1.2.tar.gz (23.5 kB)


File details

Details for the file meta_llama_3_yarn-0.1.2.tar.gz.

File metadata

  • Download URL: meta_llama_3_yarn-0.1.2.tar.gz
  • Upload date:
  • Size: 23.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/5.1.0 CPython/3.12.4

File hashes

Hashes for meta_llama_3_yarn-0.1.2.tar.gz
Algorithm Hash digest
SHA256 aea22067c28cb80890544752bc40b74e25a9ee5caebbee1c0b6a7ea82ede6494
MD5 07fd15366abeed860fff688614d1a2a1
BLAKE2b-256 7fea0dd773d3a980f4047af0687edb9515930c7b965762d13a8d06df9ab92d5f

