Advanced Machine Learning Training Platform - IN DEVELOPMENT

Project description

AITraining Interactive Wizard

Train state-of-the-art ML models with minimal code


📖 Comprehensive Documentation Available

Visit docs.monostate.com for detailed guides, tutorials, API reference, and examples covering all features including LLM fine-tuning, PEFT/LoRA, DPO/ORPO training, hyperparameter sweeps, and more.


AITraining is an advanced machine learning training platform built on top of AutoTrain Advanced. It provides a streamlined interface for fine-tuning LLMs, vision models, and more.

Highlights

Automatic Dataset Conversion

Provide a dataset in any of the 6 supported input formats and AITraining detects and converts it automatically:

| Format | Detection | Example / Purpose |
|---|---|---|
| Alpaca | instruction/input/output | {"instruction": "...", "output": "..."} |
| ShareGPT | from/value pairs | {"conversations": [{"from": "human", ...}]} |
| Messages | role/content | {"messages": [{"role": "user", ...}]} |
| Q&A | question/answer variants | {"question": "...", "answer": "..."} |
| DPO | prompt/chosen/rejected | For preference training |
| Plain Text | Single text column | Raw text for pretraining |

aitraining llm --train --auto-convert-dataset --chat-template gemma3 \
  --data-path tatsu-lab/alpaca --model google/gemma-3-270m-it
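To make the detection rules in the table concrete, here is a minimal sketch of key-based format detection for a single example. The function detect_format, its rule order, and its return labels are hypothetical illustrations of the table, not AITraining's actual detector; in particular, the Q&A branch only checks the plain question/answer keys, while the table says AITraining accepts variants.

```python
from collections.abc import Mapping
from typing import Any


def _first_turn(value: Any) -> Mapping:
    """Return the first element of a list of dicts, or an empty dict."""
    if isinstance(value, list) and value and isinstance(value[0], Mapping):
        return value[0]
    return {}


def detect_format(row: Mapping) -> str:
    """Guess a dataset format from one example's keys (hypothetical rules)."""
    keys = set(row)
    # Preference data: a prompt plus a chosen/rejected pair
    if {"prompt", "chosen", "rejected"} <= keys:
        return "dpo"
    # Chat messages whose turns use role/content
    if "role" in _first_turn(row.get("messages")):
        return "messages"
    # ShareGPT turns use from/value pairs
    if "from" in _first_turn(row.get("conversations")):
        return "sharegpt"
    # Alpaca needs instruction and output; input is optional
    if {"instruction", "output"} <= keys:
        return "alpaca"
    # Only the canonical Q&A column names are checked in this sketch
    if {"question", "answer"} <= keys:
        return "qa"
    # A lone text column is treated as raw pretraining text
    if keys == {"text"}:
        return "plain_text"
    return "unknown"


print(detect_format({"instruction": "Add 2+2", "output": "4"}))  # alpaca
print(detect_format({"conversations": [{"from": "human", "value": "hi"}]}))  # sharegpt
```

Checking the DPO keys first matters because a preference row may also carry other columns; the most specific signature should win.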

32 Chat Templates

Comprehensive template library with token-level weight control:

  • Llama family: llama, llama-3, llama-3.1
  • Gemma family: gemma, gemma-2, gemma-3, gemma-3n
  • Others: mistral, qwen-2.5, phi-3, phi-4, chatml, alpaca, vicuna, zephyr
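from transformers import AutoTokenizer
from autotrain.rendering import get_renderer, ChatFormat, RenderConfig

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # any HF tokenizer
conversation = [  # role/content messages (assumed input format)
    {"role": "user", "content": "What is LoRA?"},
    {"role": "assistant", "content": "A parameter-efficient fine-tuning method."},
]
config = RenderConfig(format=ChatFormat.CHATML, only_assistant=True)  # train on assistant turns only
renderer = get_renderer('chatml', tokenizer, config)
encoded = renderer.build_supervised_example(conversation)
# Returns: {'input_ids', 'labels', 'token_weights', 'attention_mask'}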
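The effect of only_assistant=True can be illustrated without a real tokenizer. The toy sketch below renders a conversation in the ChatML layout (<|im_start|>role, newline, content, <|im_end|>) and builds labels and token weights so that only assistant tokens count toward the loss. Whitespace splitting stands in for real tokenization, the helper names are hypothetical, and the -100 ignore index follows the PyTorch/Transformers convention; whether AITraining masks role headers and end-of-turn markers exactly this way is an assumption.

```python
IGNORE_INDEX = -100  # PyTorch's cross-entropy skips this label by default


def chatml_segments(conversation: list[dict[str, str]]):
    """Yield (text, trainable) pieces of a ChatML rendering."""
    for turn in conversation:
        is_assistant = turn["role"] == "assistant"
        yield f"<|im_start|>{turn['role']}\n", False  # role header: context only
        yield turn["content"], is_assistant  # assistant text is a training target
        yield "<|im_end|>\n", is_assistant  # so the model learns when to stop


def render_chatml(conversation: list[dict[str, str]]) -> str:
    return "".join(text for text, _ in chatml_segments(conversation))


def toy_supervised_example(conversation, vocab: dict[str, int]) -> dict[str, list]:
    """Whitespace pieces stand in for real tokens; vocab grows as needed."""
    input_ids, labels, weights = [], [], []
    for text, trainable in chatml_segments(conversation):
        for piece in text.split():
            token_id = vocab.setdefault(piece, len(vocab))
            input_ids.append(token_id)
            labels.append(token_id if trainable else IGNORE_INDEX)
            weights.append(1.0 if trainable else 0.0)
    return {
        "input_ids": input_ids,
        "labels": labels,
        "token_weights": weights,
        "attention_mask": [1] * len(input_ids),
    }


conversation = [
    {"role": "user", "content": "What is LoRA?"},
    {"role": "assistant", "content": "Low-rank adapters."},
]
example = toy_supervised_example(conversation, vocab={})
print(example["labels"])  # [-100, -100, -100, -100, -100, -100, 6, 7, 4]
```

Non-uniform token_weights (values other than 0.0 and 1.0) would let specific spans count more or less in the loss, which is what "token-level weight control" suggests.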

GRPO Training with Custom Environments

Train with Group Relative Policy Optimization (GRPO) using your own reward environment. The environment provides prompts and scores multi-turn episodes; GRPO generates completions, has the environment score them, and optimizes the policy on those scores:

aitraining llm --train --trainer grpo \
  --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
  --rl-env-module my_envs.hotel_env \
  --rl-env-class HotelEnv \
  --rl-num-generations 4 \
  --rl-max-new-tokens 256
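GRPO does not need a learned value model: for each prompt it samples a group of completions (four here, via --rl-num-generations 4), scores them, and uses each reward's deviation from the group mean, divided by the group's standard deviation, as that completion's advantage. A minimal sketch of that normalization follows; the population standard deviation and the epsilon value are assumptions, and trainer implementations differ in these details.

```python
import statistics


def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalize one prompt's group of rewards: (r - mean) / (std + eps)."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std (assumed estimator)
    return [(r - mean) / (std + eps) for r in rewards]


# Environment scores (0.0-1.0) for 4 completions of the same prompt
rewards = [1.0, 0.0, 0.5, 0.5]
advantages = group_relative_advantages(rewards)
print([round(a, 2) for a in advantages])  # [1.41, -1.41, 0.0, 0.0]
```

When every completion in a group receives the same score, all advantages are zero and that prompt contributes no reward signal, so environments whose scores spread across the 0.0-1.0 range give GRPO more to learn from.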

Your environment implements three methods (get_tools is optional):

from datasets import Dataset


class HotelEnv:
    def build_dataset(self, tokenizer) -> Dataset:
        """Return HF Dataset with 'prompt' column."""

    def score_episode(self, model, tokenizer, completion, case_idx) -> float:
        """Run multi-turn episode from completion, return 0.0-1.0 score."""

    def get_tools(self) -> list[dict]:
        """Return tool schemas for generation (optional)."""

Custom RL Environments (PPO)

Build custom reward functions for PPO training with three environment types:

# Text generation with custom reward
aitraining llm --train --trainer ppo \
  --rl-env-type text_generation \
  --rl-env-config '{"stop_sequences": ["</s>"]}' \
  --rl-reward-model-path ./reward_model

# Multi-objective rewards (correctness + formatting)
aitraining llm --train --trainer ppo \
  --rl-env-type multi_objective \
  --rl-env-config '{"reward_components": {"correctness": {"type": "keyword"}, "formatting": {"type": "length"}}}' \
  --rl-reward-weights '{"correctness": 1.0, "formatting": 0.1}'
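The --rl-reward-weights flag suggests the component rewards are combined as a weighted sum. The sketch below assumes exactly that and gives the keyword and length component types simple hypothetical definitions (keyword coverage, and a target word-count window with linear decay outside it); AITraining's actual component implementations and their config keys may differ.

```python
def keyword_reward(text: str, keywords: list[str]) -> float:
    """Fraction of expected keywords present (case-insensitive)."""
    lowered = text.lower()
    return sum(k.lower() in lowered for k in keywords) / len(keywords)


def length_reward(text: str, min_words: int, max_words: int) -> float:
    """1.0 inside the target word-count window, decaying linearly outside it."""
    n = len(text.split())
    if min_words <= n <= max_words:
        return 1.0
    gap = min_words - n if n < min_words else n - max_words
    return max(0.0, 1.0 - gap / max_words)


def combined_reward(components: dict[str, float], weights: dict[str, float]) -> float:
    """Weighted sum of component rewards; missing weights default to 0."""
    return sum(weights.get(name, 0.0) * value for name, value in components.items())


completion = "The capital of France is Paris."
components = {
    "correctness": keyword_reward(completion, ["Paris"]),
    "formatting": length_reward(completion, min_words=3, max_words=20),
}
weights = {"correctness": 1.0, "formatting": 0.1}
print(combined_reward(components, weights))  # 1.1
```

The small formatting weight (0.1) keeps length shaping from outweighing correctness, mirroring the weights in the command above.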

Hyperparameter Sweeps

Automated optimization with Optuna, random search, or grid search:

from autotrain.utils import HyperparameterSweep, SweepConfig, ParameterRange

config = SweepConfig(
    backend="optuna",
    optimization_metric="eval_loss",
    optimization_mode="minimize",
    num_trials=20,
)

sweep = HyperparameterSweep(
    objective_function=train_model,  # your own training function (not defined here)
    config=config,
    parameters=[
        ParameterRange("learning_rate", "log_uniform", low=1e-5, high=1e-3),
        ParameterRange("batch_size", "categorical", choices=[4, 8, 16]),
    ]
)
result = sweep.run()
# Returns best_params, best_value, trial history
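For intuition about what the log_uniform and categorical ranges mean, here is a self-contained random-search loop over the same two parameters. Sampling the exponent uniformly gives each decade of learning rates (1e-5 to 1e-4, 1e-4 to 1e-3) equal probability. The objective fake_eval_loss is a made-up stand-in for a real training run, and this is plain random search, not the Optuna backend configured above.

```python
import math
import random


def sample_log_uniform(rng: random.Random, low: float, high: float) -> float:
    """Sample so that log(value) is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def fake_eval_loss(params: dict) -> float:
    """Hypothetical objective: pretend loss is lowest near lr=3e-4, batch 8."""
    lr_penalty = (math.log10(params["learning_rate"]) - math.log10(3e-4)) ** 2
    batch_penalty = 0.01 * abs(params["batch_size"] - 8)
    return 1.0 + lr_penalty + batch_penalty


def random_search(num_trials: int, seed: int = 0) -> tuple[dict, float, list]:
    rng = random.Random(seed)  # seeded for reproducible trials
    history = []
    for _ in range(num_trials):
        params = {
            "learning_rate": sample_log_uniform(rng, 1e-5, 1e-3),
            "batch_size": rng.choice([4, 8, 16]),
        }
        history.append((params, fake_eval_loss(params)))
    # optimization_mode="minimize": keep the trial with the lowest loss
    best_params, best_value = min(history, key=lambda trial: trial[1])
    return best_params, best_value, history


best_params, best_value, history = random_search(num_trials=20)
print(best_params, round(best_value, 3))
```

A model-based backend such as Optuna uses earlier trials to decide where to sample next, which usually needs fewer trials than this independent sampling.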

Enhanced Evaluation Metrics

8 metrics beyond loss, with callbacks for periodic evaluation:

| Metric | Type | Use Case |
|---|---|---|
| Perplexity | Auto-computed | Language model quality |
| BLEU | Generation | Translation, summarization |
| ROUGE (1/2/L) | Generation | Summarization |
| BERTScore | Generation | Semantic similarity |
| METEOR | Generation | Translation |
| F1/Accuracy | Classification | Standard metrics |
| Exact Match | QA | Question answering |

from autotrain.evaluation import Evaluator, EvaluationConfig, MetricType

config = EvaluationConfig(
    metrics=[MetricType.PERPLEXITY, MetricType.BLEU, MetricType.ROUGE, MetricType.BERTSCORE],
    save_predictions=True,
)
# model, tokenizer, dataset: a trained model, its tokenizer, and an eval split loaded beforehand
evaluator = Evaluator(model, tokenizer, config)
result = evaluator.evaluate(dataset)
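Two of these metrics are simple enough to compute by hand, which helps when sanity-checking reported numbers. Perplexity is the exponential of the mean per-token negative log-likelihood, so an eval_loss of 2.0 (cross-entropy in nats) corresponds to a perplexity of about 7.39. Exact match is usually computed after normalizing both strings; the SQuAD-style normalization below (lowercase, strip punctuation and the articles a/an/the, collapse whitespace) is a common convention, and whether AITraining's Evaluator normalizes the same way is an assumption.

```python
import math
import re
import string


def perplexity_from_log_probs(token_log_probs: list[float]) -> float:
    """exp of the mean negative log-likelihood over target tokens."""
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)


def normalize_answer(text: str) -> str:
    """SQuAD-style normalization: lowercase, drop punctuation and articles."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, reference: str) -> float:
    return float(normalize_answer(prediction) == normalize_answer(reference))


# Each target token predicted with probability 0.25 gives perplexity of about 4
print(perplexity_from_log_probs([math.log(0.25)] * 3))
print(exact_match("The Eiffel Tower!", "eiffel tower"))  # 1.0
```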

Auto LoRA Merge

After PEFT training, automatically merge adapters and save deployment-ready models:

# Default: merges adapters into full model
aitraining llm --train --peft --model meta-llama/Llama-3.2-1B

# Keep adapters separate (smaller files)
aitraining llm --train --peft --no-merge-adapter --model meta-llama/Llama-3.2-1B
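Merging works because a LoRA adapter is an additive low-rank update: with rank r and scaling lora_alpha / r, the merged weight is W + (lora_alpha / r) · B·A, and a layer using it produces the same output as the base layer plus the adapter path. The pure-Python sketch below checks that identity on a tiny 2×2 layer with made-up values (r = 1, lora_alpha = 2). The alpha / r scaling is the standard LoRA formulation, and the PEFT library performs the equivalent fold-in with merge_and_unload(); how AITraining invokes the merge internally is not covered by this sketch.

```python
def matmul(x, y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def add(x, y, scale=1.0):
    """Return x + scale * y element-wise."""
    return [[a + scale * b for a, b in zip(rx, ry)] for rx, ry in zip(x, y)]


W = [[1.0, 0.0], [0.0, 1.0]]  # frozen base weight, out x in = 2x2
A = [[0.5, -1.0]]             # LoRA down-projection, r x in = 1x2
B = [[2.0], [1.0]]            # LoRA up-projection, out x r = 2x1
r, lora_alpha = 1, 2
scale = lora_alpha / r

x = [[3.0], [4.0]]            # input as a column vector

# Unmerged: base output plus the scaled adapter path
unmerged = add(matmul(W, x), matmul(B, matmul(A, x)), scale)
# Merged: fold the adapter into one weight, then a single matmul
W_merged = add(W, matmul(B, A), scale)
merged = matmul(W_merged, x)

print(unmerged)  # [[-7.0], [-1.0]]
print(merged)    # [[-7.0], [-1.0]]
```

With --no-merge-adapter only the small A and B matrices are saved, which is why the output is smaller; the base weights must then be loaded alongside the adapter at inference time.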

Screenshots

Built-in chat interface for testing trained models with conversation history

Terminal UI with real-time W&B LEET metrics visualization


Installation

pip install aitraining

Requirements: Python >= 3.10, PyTorch

Quick Start

Interactive Wizard

aitraining

The wizard guides you through:

  1. Trainer type selection (LLM, vision, NLP, tabular)
  2. Model selection with curated catalogs from Hugging Face
  3. Dataset configuration with auto-format detection
  4. Advanced parameters (PEFT, quantization, sweeps)

Config File

aitraining --config config.yaml

Python API

from autotrain.trainers.clm import train
from autotrain.trainers.clm.params import LLMTrainingParams

config = LLMTrainingParams(
    model="meta-llama/Llama-3.2-1B",
    data_path="your-dataset",
    trainer="sft",
    epochs=3,
    batch_size=4,
    lr=2e-5,
    peft=True,
    auto_convert_dataset=True,
    chat_template="llama3",
)

train(config)

Comparison

AITraining vs AutoTrain vs Tinker

| Feature | AutoTrain | AITraining | Tinker |
|---|---|---|---|
| **Trainers** | | | |
| SFT/DPO/ORPO | Yes | Yes | Yes |
| PPO (RLHF) | Basic | Enhanced (TRL) | Advanced |
| GRPO | No | Yes (TRL 0.28) | Custom |
| Reward Modeling | Yes | Yes | No |
| Knowledge Distillation | No | Yes (KL + CE loss) | Yes (text-only) |
| **Data** | | | |
| Auto Format Detection | No | Yes (6 formats) | No |
| Chat Template Library | Basic | 32 templates | 5 templates |
| Runtime Column Mapping | No | Yes | No |
| Conversation Extension | No | Yes | No |
| **Training** | | | |
| Hyperparameter Sweeps | No | Yes (Optuna) | Manual |
| Custom RL Environments | No | Yes (3 types) | Yes |
| Multi-objective Rewards | No | Yes | Yes |
| Forward-Backward Pipeline | No | Yes | Yes |
| Async Off-Policy RL | No | No | Yes |
| Stream Minibatch | No | No | Yes |
| **Evaluation** | | | |
| Metrics Beyond Loss | No | 8 metrics | Manual |
| Periodic Eval Callbacks | No | Yes | Yes |
| Custom Metric Registration | No | Yes | No |
| **Interface** | | | |
| Interactive CLI Wizard | No | Yes | No |
| TUI (Experimental) | No | Yes | No |
| W&B LEET Visualizer | No | Yes | Yes |
| **Hardware** | | | |
| Apple Silicon (MPS) | Limited | Full | No |
| Quantization (int4/int8) | Yes | Yes | Unknown |
| Multi-GPU | Yes | Yes | Yes |
| **Task Coverage** | | | |
| Vision Tasks | Yes | Yes | No |
| NLP Tasks | Yes | Yes | No |
| Tabular Tasks | Yes | Yes | No |
| Tool Use Environments | No | Yes (GRPO) | Yes |
| Multiplayer RL | No | No | Yes |

Supported Tasks

| Task | Trainers / Variants | Status |
|---|---|---|
| LLM Fine-tuning | SFT, DPO, ORPO, PPO, GRPO, Reward, Distillation | Stable |
| Text Classification | Single/Multi-label | Stable |
| Token Classification | NER, POS tagging | Stable |
| Sequence-to-Sequence | Translation, Summarization | Stable |
| Image Classification | Single/Multi-label | Stable |
| Object Detection | YOLO, DETR | Stable |
| VLM Training | Vision-Language Models | Beta |
| Tabular | XGBoost, sklearn | Stable |
| Sentence Transformers | Semantic similarity | Stable |
| Extractive QA | SQuAD format | Stable |

Configuration Example

task: llm-sft
base_model: meta-llama/Llama-3.2-1B
project_name: my-finetune

data:
  path: your-dataset
  train_split: train
  auto_convert_dataset: true
  chat_template: llama3

params:
  epochs: 3
  batch_size: 4
  lr: 2e-5
  peft: true
  lora_r: 16
  lora_alpha: 32
  quantization: int4
  mixed_precision: bf16

# Optional: hyperparameter sweep
sweep:
  enabled: true
  backend: optuna
  n_trials: 10
  metric: eval_loss

Documentation

📚 docs.monostate.com — Complete documentation with tutorials, API reference, and examples.


License

Licensed under Apache 2.0. See LICENSE for details.

Based on AutoTrain Advanced by Hugging Face.


Monostate AI


Download files

Source Distribution

aitraining-0.0.36.tar.gz (541.4 kB)

Built Distribution

aitraining-0.0.36-py3-none-any.whl (569.3 kB)

File details

Details for the file aitraining-0.0.36.tar.gz.

File metadata

  • Download URL: aitraining-0.0.36.tar.gz
  • Upload date:
  • Size: 541.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.5

File hashes

Hashes for aitraining-0.0.36.tar.gz
Algorithm Hash digest
SHA256 8c5e3f0c50aeda5c9031841a2a900c8f8d57fa128d8d300fb09549a566a752a3
MD5 a871439f2e6047e8b39398ad3193e280
BLAKE2b-256 5b34fb9acd3d3d311c052530f16a0b352ee26ada0597cbb732320582ee1ffa99

File details

Details for the file aitraining-0.0.36-py3-none-any.whl.

File metadata

  • Download URL: aitraining-0.0.36-py3-none-any.whl
  • Upload date:
  • Size: 569.3 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.13.5

File hashes

Hashes for aitraining-0.0.36-py3-none-any.whl
Algorithm Hash digest
SHA256 9d8d8ff550e31d6d4b583a28ccefdbb56cf41bbf3cb6bd19e2675436ca54d29a
MD5 71daf90637695e35f657d4fbaeab79cc
BLAKE2b-256 633a03f6c5c167928b12794e46e87eed9cb0b6a0e17a0b38fb28f1306c806ee1
