Supervised Fine-Tuning (SFT) with LoRA and DeepSpeed
This project provides a streamlined pipeline for fine-tuning Large Language Models (LLMs) like Llama 3.1 using Low-Rank Adaptation (LoRA) and DeepSpeed for efficient distributed training.
📂 Directory Structure
SFT/
├── configs/ # Configuration files (planned)
├── <placeholder>/ # Datasets
├── <placeholder>/ # Output directory for checkpoints and adapters
├── scripts/
│   └── run.sh # Main entry point script
├── lora_sft.py # Main training and inference script
├── ray_train_lora_sft.py # Ray Train entrypoint
├── k8s/ # KubeRay RayJob template and Dockerfile
├── config.yaml # Hyperparameters and paths configuration
├── ds_config.json # DeepSpeed configuration
├── requirements.txt # Python dependencies
└── README.md # This file
🚀 Setup
- Create and activate a virtual environment:
python3 -m venv venv
source venv/bin/activate
- Install dependencies:
pip install -r requirements.txt
(Ensure the installed deepspeed version is compatible with your CUDA version.)
- Configure Environment: Create a .env file in the root directory:
HF_TOKEN=your_huggingface_token
WANDB_API_KEY=your_wandb_key # Optional, for logging
🛠️ Configuration
- config.yaml: Controls model ID, dataset paths, training hyperparameters (learning rate, epochs, batch size), and LoRA settings.
- ds_config.json: Configures DeepSpeed optimization (ZeRO stage, offloading, mixed precision).
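For orientation, a sketch of the kind of fields config.yaml typically carries is shown below. The key names here are illustrative assumptions, not the repo's actual schema; consult the file itself for the authoritative names.

# Illustrative config.yaml sketch -- key names are assumptions
model_id: meta-llama/Llama-3.1-8B   # base model to fine-tune
train_file: data/train.jsonl        # dataset path
output_dir: outputs/                # checkpoints and LoRA adapters
learning_rate: 2.0e-4
num_epochs: 3
per_device_batch_size: 4
lora:
  r: 16        # LoRA rank
  alpha: 32    # LoRA scaling factor
  dropout: 0.05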
🏃 Usage
Use the provided scripts/run.sh wrapper for easy execution. It automatically handles directory paths.
Training
To start fine-tuning the model:
# Default: Train on 2 GPUs
./scripts/run.sh train 2
# Train on 4 GPUs
./scripts/run.sh train 4
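Under the hood, the wrapper presumably expands to a DeepSpeed launcher invocation. A rough, hypothetical equivalent of the two-GPU run (the flags lora_sft.py accepts are an assumption, not documented here):

# Hypothetical expansion of `./scripts/run.sh train 2` -- script flags are assumptions
deepspeed --num_gpus 2 lora_sft.py --config config.yaml --deepspeed ds_config.json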
Inference
To evaluate the fine-tuned model on the test set:
./scripts/run.sh inference
Custom Configuration
You can specify a custom DeepSpeed config file:
./scripts/run.sh --config custom_ds_config.json train 4
📊 Monitoring
Training progress (loss, accuracy, etc.) is logged to MLflow (and/or WandB if configured).
To view MLflow logs locally:
mlflow ui
Then open http://localhost:5000 in your browser.
🐛 Troubleshooting
- deepspeed: command not found: Ensure you have activated the virtual environment where deepspeed is installed.
- CUDA Errors: Check ds_config.json to ensure batch sizes and offloading settings fit your GPU memory.
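As a reference point for that tuning, the sketch below shows a generic ZeRO stage-2 DeepSpeed config with CPU optimizer offloading and bf16. These are standard DeepSpeed option names, but this is a common shape for such files rather than the contents of this repo's ds_config.json.

{
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 8,
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": { "device": "cpu" }
  }
}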
☸️ Ray Train + Kubernetes (KubeRay)
This repo now includes a Ray Train entrypoint for running multi-GPU training on a Ray cluster (including on Kubernetes via KubeRay).
- Ray entrypoint: SFT/ray_train_lora_sft.py
- KubeRay RayJob template: SFT/k8s/rayjob-lora-sft.yaml
- Container build: SFT/k8s/Dockerfile
Local Ray (single node)
pip install -r requirements.txt
python ray_train_lora_sft.py --num_workers 2
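For a sense of the shape of such an entrypoint, here is a minimal Ray Train sketch. A toy linear model stands in for the actual LoRA/DeepSpeed wiring in ray_train_lora_sft.py; it assumes a recent Ray 2.x with ray[train] and torch installed.

# Minimal Ray Train sketch -- a toy model stands in for the real
# LoRA/DeepSpeed training loop; not the repo's actual ray_train_lora_sft.py.
import torch
import ray.train
import ray.train.torch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config):
    # Runs on each Ray worker; prepare_model moves the model to the worker's
    # device and wraps it for distributed data-parallel training.
    model = ray.train.torch.prepare_model(torch.nn.Linear(8, 1))
    device = ray.train.torch.get_device()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])
    for epoch in range(config["epochs"]):
        x = torch.randn(64, 8, device=device)  # stand-in batch
        loss = model(x).pow(2).mean()          # dummy objective
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ray.train.report({"epoch": epoch, "loss": loss.item()})

if __name__ == "__main__":
    trainer = TorchTrainer(
        train_loop_per_worker,
        train_loop_config={"lr": 2e-4, "epochs": 2},
        # Mirrors --num_workers 2 above; set use_gpu=False for a CPU-only test.
        scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
    )
    trainer.fit()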
Kubernetes (high level)
- Build/push the image from SFT/k8s/Dockerfile and set it in rayjob-lora-sft.yaml.
- Create PVCs for:
  - /workspace/SFT/data_dir (your training data)
  - /mnt/ray-results (Ray Train run storage / checkpoints)
- Apply the RayJob:
kubectl apply -f SFT/k8s/rayjob-lora-sft.yaml
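For reference, a KubeRay RayJob generally has the shape sketched below. This skeleton is generic, not the repo's rayjob-lora-sft.yaml: the image name, replica counts, and GPU limits are placeholders you would replace with your own values.

# Generic KubeRay RayJob skeleton -- placeholders throughout, not the repo's file
apiVersion: ray.io/v1
kind: RayJob
metadata:
  name: lora-sft
spec:
  entrypoint: python /workspace/SFT/ray_train_lora_sft.py --num_workers 2
  rayClusterSpec:
    headGroupSpec:
      rayStartParams: {}
      template:
        spec:
          containers:
            - name: ray-head
              image: your-registry/sft:latest   # built from SFT/k8s/Dockerfile
    workerGroupSpecs:
      - groupName: gpu-workers
        replicas: 2
        rayStartParams: {}
        template:
          spec:
            containers:
              - name: ray-worker
                image: your-registry/sft:latest
                resources:
                  limits:
                    nvidia.com/gpu: 1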