Rundpo Python Client

A modern, async-first Python client for the Rundpo API. This client provides a convenient way to interact with the Rundpo API for running DPO (Direct Preference Optimization) training.

Installation

pip install rundpo transformers torch peft

Complete Example

Here's a complete example that shows how to:

  1. Train a DPO adapter
  2. Download the trained adapter
  3. Run inference with the adapter

import os
import time
import torch
from rundpo import RundpoClient, DPOConfig, RunConfig, RunStatus, download_and_extract
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize the client
client = RundpoClient()

# Check credits
credits = client.get_credits()
print(f"Remaining credits: {credits}")

# Upload your data file (assuming you have a JSONL file with chosen/rejected pairs)
file_upload = client.upload_file("training_data.jsonl")
print(f"File uploaded successfully! ID: {file_upload.file_id}")

# Configure DPO run
base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
config = DPOConfig(
    file_id=file_upload.file_id,
    run_config=RunConfig(
        base_model=base_model_name,
        gpus=2,
        dpo_num_train_epochs=5
    )
)

# Start DPO training
run_id = client.run_dpo(config)
print(f"Started DPO run with ID: {run_id}")

# Poll for completion
adapter_path = None
while True:
    result = client.get_status(run_id)
    status = result["status"]
    print(f"Run status: {status}")

    if status == RunStatus.COMPLETED:
        print("✓ Run completed successfully!")
        # Download and extract the model
        if result.get("download_url"):
            print("Downloading and extracting model...")
            adapter_path = download_and_extract(result["download_url"], run_id)
            print(f"Model downloaded and extracted to: {adapter_path}")
        break
    elif status == RunStatus.FAILED:
        print("✗ Run failed!")
        break

    # Wait 30 seconds before checking again
    time.sleep(30)

# Stop here if there is no adapter to load (failed run or missing download URL)
if adapter_path is None:
    raise SystemExit("No adapter available; skipping inference.")

# Now let's run inference with the trained adapter!
print("\nRunning inference with the trained adapter...")

# Load the base model and adapter
model = AutoModelForCausalLM.from_pretrained(base_model_name)
model = PeftModel.from_pretrained(model, adapter_path)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Prepare the chat prompt
chat = [
    {"role": "user", "content": "Hey there, frontier language model! I'm curious about the potential intersections between machine learning techniques and the visual analysis of Renaissance art composition. Could you help me explore how neural networks might be able to detect subtle stylistic variations in perspective and color usage across different Italian city-state painting schools during the 15th century? I'm particularly interested in computational methods that could provide new insights into attribution and artistic lineage. Also, as a side note, I'd love to hear your thoughts on how machine learning might predict optimal ski route selection in backcountry terrain with variable snow conditions."}
]

chat_prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

print("Prompt:")
print(chat_prompt)

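# Tokenize the input (the chat template already inserted the special tokens,
# so don't let the tokenizer prepend a second beginning-of-sequence token)
inputs = tokenizer(chat_prompt, return_tensors="pt", add_special_tokens=False)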

# Set parameters for generation
generation_config = {
    "max_new_tokens": 1000,  # Maximum number of newly generated tokens
    "temperature": 0.7,      # Controls randomness (lower = more deterministic, higher = more random)
    "top_p": 0.9,           # Nucleus sampling parameter
    "do_sample": True,      # Enable sampling
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": tokenizer.eos_token_id
}

# Generate output
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        **generation_config
    )

# Decode the generated response
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
print("\nGenerated Response:")
print(generated_text)

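# Extract just the assistant's response by decoding only the newly generated tokens
# (the Llama 3.1 chat template does not use <|im_start|> markers, so splitting on them would not work)
prompt_length = inputs["input_ids"].shape[-1]
assistant_response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
print("\nAssistant's Response Only:")
print(assistant_response)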
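
The polling loop in the example above waits indefinitely and exits only on COMPLETED or FAILED. A minimal sketch of a bounded version follows. The helper `wait_for_run` and its `interval`, `timeout`, and `sleep` parameters are hypothetical and not part of the rundpo package. It assumes only what the example shows: a status call that returns a mapping with a "status" key, and status values that compare equal to the completed and failed values passed in.

```python
import time
from typing import Any, Callable, Mapping


def wait_for_run(
    get_status: Callable[[], Mapping[str, Any]],
    completed: Any,
    failed: Any,
    interval: float = 30.0,
    timeout: float = 6 * 60 * 60,
    sleep: Callable[[float], None] = time.sleep,
) -> Mapping[str, Any]:
    """Poll get_status() until the run completes, fails, or the timeout expires.

    Hypothetical helper, not part of rundpo. Returns the final status mapping
    on completion, raises RuntimeError on failure and TimeoutError on timeout.
    """
    deadline = time.monotonic() + timeout
    while True:
        result = get_status()
        status = result["status"]
        if status == completed:
            return result
        if status == failed:
            raise RuntimeError("Run failed")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Run still in state {status!r} after {timeout} seconds")
        sleep(interval)
```

With the client from the example, this could be called as `wait_for_run(lambda: client.get_status(run_id), RunStatus.COMPLETED, RunStatus.FAILED)`. The six-hour default timeout is an arbitrary placeholder, not a documented limit.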

Training Data Format

Your training data should be a JSONL file where each line is a JSON object containing a prompt and a chosen/rejected response pair. Here's an example line:

{"chosen": "This is a high-quality response", "rejected": "This is a lower-quality response", "prompt": "What is the meaning of life?"}
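
For illustration, here is a minimal sketch that writes and checks a file in this format using only the standard library. The `prompt`, `chosen`, and `rejected` keys come from the example line above. The helper names `write_pairs` and `validate_pairs` are hypothetical, and the check that every value is a non-empty string is an assumption based on that single example; the API may accept other shapes or enforce additional rules.

```python
import json
from pathlib import Path

# Keys taken from the example line in this README
REQUIRED_KEYS = ("prompt", "chosen", "rejected")


def write_pairs(path, pairs):
    """Write one JSON object per line (JSONL)."""
    with Path(path).open("w", encoding="utf-8") as f:
        for pair in pairs:
            f.write(json.dumps(pair, ensure_ascii=False) + "\n")


def validate_pairs(path):
    """Return the number of valid records; raise ValueError on the first bad line."""
    count = 0
    with Path(path).open(encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            # json.loads raises json.JSONDecodeError (a ValueError) on malformed JSON
            record = json.loads(line)
            if not isinstance(record, dict):
                raise ValueError(f"line {line_number}: expected a JSON object")
            for key in REQUIRED_KEYS:
                value = record.get(key)
                if not isinstance(value, str) or not value:
                    raise ValueError(f"line {line_number}: {key!r} must be a non-empty string")
            count += 1
    return count
```

Running `validate_pairs("training_data.jsonl")` before uploading catches malformed lines locally instead of after a run has been started.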

Using Hugging Face Datasets

Instead of uploading a file, you can reference Hugging Face datasets by name:

config = DPOConfig(
    hf_sft_dataset_name="your-sft-dataset",
    hf_dpo_dataset_name="your-dpo-dataset",
    run_config=RunConfig(
        base_model="meta-llama/Llama-3.1-8B-Instruct",
        gpus=2,
        dpo_num_train_epochs=5
    )
)

Available Configuration Options

The RunConfig class supports all parameters from the API:

config = RunConfig(
    # Required
    base_model="meta-llama/Llama-3.1-8B-Instruct",  # Base model to use
    
    # SFT (Supervised Fine-Tuning) parameters
    sft_learning_rate=0.0002,                    # Learning rate for SFT (default: 0.0002)
    sft_ratio=0.05,                             # Ratio of data to use for SFT (default: 0.05)
    sft_packing=True,                           # Whether to use packing for SFT (default: True)
    sft_per_device_train_batch_size=2,          # Batch size per device for SFT (default: 2)
    sft_gradient_accumulation_steps=8,          # Gradient accumulation steps for SFT (default: 8)
    sft_gradient_checkpointing=True,            # Whether to use gradient checkpointing (default: True)
    sft_lora_r=32,                             # LoRA r parameter for SFT (default: 32)
    sft_lora_alpha=16,                         # LoRA alpha parameter for SFT (default: 16)
    
    # DPO (Direct Preference Optimization) parameters
    dpo_learning_rate=0.000005,                # Learning rate for DPO (default: 0.000005)
    dpo_num_train_epochs=1,                    # Number of epochs for DPO (default: 1)
    dpo_per_device_train_batch_size=8,         # Batch size per device for DPO (default: 8)
    dpo_gradient_accumulation_steps=2,         # Gradient accumulation steps for DPO (default: 2)
    dpo_gradient_checkpointing=True,           # Whether to use gradient checkpointing (default: True)
    dpo_lora_r=16,                            # LoRA r parameter for DPO (default: 16)
    dpo_lora_alpha=8,                         # LoRA alpha parameter for DPO (default: 8)
    dpo_bf16=True,                            # Whether to use bfloat16 for DPO (default: True)
    dpo_max_length=None,                      # Maximum sequence length for DPO (default: None)
    
    # Infrastructure
    gpus=2                                    # Number of GPUs to use (default: 2)
)
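
To see how these settings combine, the sketch below computes the effective batch size per optimizer step and the LoRA scaling factor for the defaults listed above. It assumes standard data-parallel training, where the effective batch size is per-device batch size × gradient accumulation steps × number of GPUs, and the conventional LoRA scaling of alpha / r. This README does not state how the Rundpo backend applies these values, so treat the numbers as estimates. The function names are illustrative only.

```python
def effective_batch_size(per_device_batch_size: int, gradient_accumulation_steps: int, gpus: int) -> int:
    """Examples consumed per optimizer step under plain data parallelism (assumed)."""
    return per_device_batch_size * gradient_accumulation_steps * gpus


def lora_scaling(lora_alpha: float, lora_r: int) -> float:
    """Conventional LoRA scaling factor applied to the adapter update (assumed)."""
    return lora_alpha / lora_r


# Defaults from the RunConfig listing above, with gpus=2
sft_batch = effective_batch_size(per_device_batch_size=2, gradient_accumulation_steps=8, gpus=2)
dpo_batch = effective_batch_size(per_device_batch_size=8, gradient_accumulation_steps=2, gpus=2)
sft_scale = lora_scaling(lora_alpha=16, lora_r=32)
dpo_scale = lora_scaling(lora_alpha=8, lora_r=16)

print(f"SFT: {sft_batch} examples per step, LoRA scaling {sft_scale}")  # SFT: 32 examples per step, LoRA scaling 0.5
print(f"DPO: {dpo_batch} examples per step, LoRA scaling {dpo_scale}")  # DPO: 32 examples per step, LoRA scaling 0.5
```

Under these assumptions, changing `gpus` without adjusting the accumulation steps changes the effective batch size, which may call for a different learning rate.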

Model Downloads

By default, downloaded models are stored in ~/.cache/rundpo/adapters. You can customize this location by setting the RD_HOME environment variable:

export RD_HOME="/path/to/your/preferred/cache"
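
As a rough illustration of how such an override is commonly resolved, here is a minimal sketch. `resolve_cache_dir` is a hypothetical stand-in, not the library's implementation. In particular, whether rundpo uses RD_HOME as the adapter directory itself or appends a subdirectory to it is not stated here, so the sketch simply returns the value as given. The `get_cache_dir` utility listed under Utility Functions is the way to check the location actually in use.

```python
import os
from pathlib import Path

# Default location documented above
DEFAULT_CACHE_DIR = Path.home() / ".cache" / "rundpo" / "adapters"


def resolve_cache_dir(env=os.environ) -> Path:
    """Hypothetical sketch: prefer RD_HOME when set and non-empty, else the default."""
    override = env.get("RD_HOME")
    if override:
        return Path(override).expanduser()
    return DEFAULT_CACHE_DIR
```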

API Reference

Clients

  • AsyncRundpoClient: Async-first client for modern Python applications
  • RundpoClient: Synchronous client for simpler use cases

Data Classes

  • RunConfig: Configuration for training runs (see Available Configuration Options above)
  • DPOConfig: Configuration specific to DPO training
  • FileUpload: Represents an uploaded file
  • RunStatus: Enum of possible run statuses:
    • PENDING: Initial state
    • PROVISIONING: Setting up GPUs
    • LAUNCHING_SFT: Starting SFT training
    • TRAINING_SFT: Running SFT training
    • PREPARING_DPO: Preparing for DPO training
    • LAUNCHING_DPO: Starting DPO training
    • TRAINING_DPO: Running DPO training
    • SAVING_MODEL: Saving the trained model
    • FREEING_GPUS: Cleaning up resources
    • COMPLETED: Run completed successfully
    • FAILED: Run failed
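
The statuses above read as a pipeline. As an illustration, the sketch below turns a status name into a short progress string. `STAGES` and `describe_progress` are hypothetical, and the sketch assumes that the listed order is the order a run passes through, which the list suggests but does not state. It works on plain strings because how a `RunStatus` member converts to its name is not documented here.

```python
# Status names in the order listed above (assumed to be the pipeline order)
STAGES = [
    "PENDING",
    "PROVISIONING",
    "LAUNCHING_SFT",
    "TRAINING_SFT",
    "PREPARING_DPO",
    "LAUNCHING_DPO",
    "TRAINING_DPO",
    "SAVING_MODEL",
    "FREEING_GPUS",
    "COMPLETED",
]


def describe_progress(status_name: str) -> str:
    """Return a human-readable progress string for a status name."""
    if status_name == "FAILED":
        return "failed"
    if status_name not in STAGES:
        return f"unknown status: {status_name}"
    step = STAGES.index(status_name) + 1
    return f"step {step}/{len(STAGES)}: {status_name}"


print(describe_progress("TRAINING_DPO"))  # step 7/10: TRAINING_DPO
```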

Utility Functions

  • download_and_extract_async: Download and extract a model asynchronously
  • download_and_extract: Download and extract a model synchronously
  • get_cache_dir: Get the current cache directory path

License

MIT
