Project description

neotune

Supervised fine-tuning of LLMs with LoRA and DeepSpeed, packaged as a simple Python API.

Installation

pip install neotune

With optional extras:

pip install "neotune[ray]"       # Ray Train + Kubernetes support
pip install "neotune[logging]"   # MLflow + Weights & Biases
pip install "neotune[all]"       # everything

Quick Start

from datasets import load_dataset
from neotune import finetune

# Load and prepare your datasets (each split needs a "text" column)
# tatsu-lab/alpaca ships only a "train" split, so carve out a small validation set
ds = load_dataset("tatsu-lab/alpaca")["train"].train_test_split(test_size=0.1)
train_ds = ds["train"]  # already has a fully formatted "text" column
val_ds = ds["test"]

results = finetune(
    model="meta-llama/Llama-3.1-8B-Instruct",
    datasets={"train": train_ds, "validation": val_ds},
    hyperparameters={"learning_rate": 2e-4, "num_train_epochs": 3},
)

Or use the class-based API for more control:

from neotune import NeoTune

nt = NeoTune(
    model="meta-llama/Llama-3.1-8B-Instruct",
    datasets={"train": train_ds, "validation": val_ds, "test": test_ds},
    hyperparameters={"learning_rate": 2e-4, "output_dir": "./my-adapter"},
)

results = nt.train()
nt.evaluate()

API Reference

finetune(model, datasets, hyperparameters) → dict

One-call convenience function. Creates a NeoTune instance and trains immediately.

NeoTune(model, datasets, hyperparameters)

model: str

HuggingFace model ID or local path.

model="meta-llama/Llama-3.1-8B-Instruct"

datasets: dict[str, Dataset]

A dict of HuggingFace Dataset objects. Only "train" is required; "validation" and "test" are optional.

Each dataset must contain either:

  • A "text" column with fully-formatted prompt/response text (will be tokenized automatically), or
  • Pre-tokenized columns: input_ids, attention_mask, and labels (a pre-tokenization sketch follows the example below)

from datasets import load_dataset

ds = load_dataset("my_dataset")
datasets = {
    "train": ds["train"],
    "validation": ds["validation"],
    "test": ds["test"],       # optional — used for final evaluation
}
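
If you supply pre-tokenized data instead, here is a minimal sketch of building the expected input_ids, attention_mask, and labels columns with a standard HuggingFace tokenizer. The tokenize helper and the label-mirroring convention are illustrative, not part of neotune's API:

from datasets import load_dataset
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

def tokenize(batch):
    # Truncate to the same length neotune would use for a "text" column (max_len default: 2048)
    enc = tok(batch["text"], truncation=True, max_length=2048)
    enc["labels"] = enc["input_ids"].copy()  # causal-LM labels mirror the inputs
    return enc

raw = load_dataset("tatsu-lab/alpaca")["train"]
tokenized = raw.map(tokenize, batched=True, remove_columns=raw.column_names)
# "tokenized" now carries input_ids, attention_mask, and labels,
# ready to pass as datasets={"train": tokenized, ...}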

hyperparameters: dict, optional

Override any default. All keys are optional.

Training:

Key Default Description
learning_rate 1e-4 Learning rate
num_train_epochs 3 Number of training epochs
batch_size 1 Per-device train/eval batch size
gradient_accumulation_steps 4 Gradient accumulation steps
warmup_ratio 0.03 Warmup ratio
weight_decay 0.01 Weight decay
bf16 True Use bfloat16 mixed precision
gradient_checkpointing False Enable gradient checkpointing
logging_steps 10 Log every N steps
eval_steps 50 Evaluate every N steps
save_steps 100 Save checkpoint every N steps
save_total_limit 3 Maximum checkpoints to keep

LoRA:

Key Default Description
lora_r 16 LoRA rank
lora_alpha 32 LoRA alpha scaling
lora_dropout 0.05 LoRA dropout
lora_target_modules ["q_proj", "k_proj", ...] Modules to apply LoRA to

Output:

Key Default Description
output_dir "./adapter-output" Where to save the trained adapter
hf_repo None HuggingFace Hub repo to push to

DeepSpeed:

Key Default Description
ds_config None DeepSpeed config: None (built-in ZeRO-2), a file path, or a dict

Data:

Key Default Description
max_len 2048 Max token length (used when tokenizing a "text" column)

Example:

hyperparameters={
    "learning_rate": 2e-4,
    "num_train_epochs": 5,
    "lora_r": 32,
    "output_dir": "./my-adapter",
    "hf_repo": "username/my-adapter",
}
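
The ds_config key also accepts a full DeepSpeed configuration as a dict. A minimal sketch using standard DeepSpeed config keys; the specific values are illustrative, chosen to mirror the documented batch-size and accumulation defaults, not something neotune requires:

hyperparameters={
    "ds_config": {
        "zero_optimization": {"stage": 2, "overlap_comm": True},
        "bf16": {"enabled": True},
        "gradient_accumulation_steps": 4,      # matches the default above
        "train_micro_batch_size_per_gpu": 1,   # matches the batch_size default above
    },
}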

Methods

.train() → dict

Fine-tunes the model and returns test-set evaluation metrics (accuracy, f1_macro, precision_macro, recall_macro). Returns an empty dict if no test split was provided.

.evaluate() → tuple

Runs inference on the test split using the saved adapter. Returns (accuracy, f1, precision, recall). Requires a "test" split.
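
Putting the two together (this assumes a "test" split was supplied; metric names as documented above):

metrics = nt.train()                          # e.g. {"accuracy": ..., "f1_macro": ...}
accuracy, f1, precision, recall = nt.evaluate()
print(f"test accuracy: {accuracy:.3f}, macro F1: {f1:.3f}")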

Advanced Usage

Distributed training with DeepSpeed (CLI)

deepspeed --num_gpus 4 -m neotune.train --config config.yaml --mode train

Distributed training with Ray

python -m neotune.ray_train --config config.yaml --num_workers 4

Kubernetes (KubeRay)

See k8s/rayjob-lora-sft.yaml for a KubeRay RayJob template.

Environment Variables

Variable Description
HF_TOKEN HuggingFace access token (required for gated models)
WANDB_API_KEY Weights & Biases API key (optional)

Create a .env file in your working directory:

HF_TOKEN=your_token_here
WANDB_API_KEY=your_wandb_key_here
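
If you prefer not to keep a .env file, the same variables can be exported in your shell or set from Python before importing neotune (plain os.environ usage, nothing neotune-specific):

import os

os.environ["HF_TOKEN"] = "your_token_here"             # required for gated models
os.environ["WANDB_API_KEY"] = "your_wandb_key_here"    # optional, enables W&B logging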

License

MIT

Download files

Download the file for your platform.

Source Distribution

neotune-0.3.0.tar.gz (15.8 kB)


Built Distribution


neotune-0.3.0-py3-none-any.whl (17.7 kB)


File details

Details for the file neotune-0.3.0.tar.gz.

File metadata

  • Download URL: neotune-0.3.0.tar.gz
  • Upload date:
  • Size: 15.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.12

File hashes

Hashes for neotune-0.3.0.tar.gz
Algorithm Hash digest
SHA256 e2ff6497576039c1a0920aa2a808bf0d52e70f149022f39e4aef04f94a0c8fc6
MD5 597654ab78d9cc07e69f60ff86cfa681
BLAKE2b-256 cdb54ff214325bd28cf4fd4ef249cca1862029c23c98ce590e99c68a8ff5b011


File details

Details for the file neotune-0.3.0-py3-none-any.whl.

File metadata

  • Download URL: neotune-0.3.0-py3-none-any.whl
  • Upload date:
  • Size: 17.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.10.12

File hashes

Hashes for neotune-0.3.0-py3-none-any.whl
Algorithm Hash digest
SHA256 3d766cc09652a3a1d5ce209c5d979a7ada48e915b148a8506210b665808ad6cc
MD5 39cc6069736b5ab973e74e1566a13ded
BLAKE2b-256 cac8d062702ccb6c55a5f77cff9de2a13860da425616b4d4d1c51a67706ff75a

