
CRAFT: Contrastive Representation Aware Fine-Tuning toolkit

Project description

CRAFT · Contrastive Representation Aware Fine-Tuning

CRAFT is a library that layers a contrastive InfoNCE objective on top of standard SFT and preference-optimization trainers. It provides:

  • Composable losses – configurable InfoNCE loss with projection/pooling and weighted blending against supervised losses via craft_alpha.
  • Mixed data loading – automated cycling of SFT and contrastive batches according to a configurable craft_beta ratio, with optional auto-tuning via craft_beta_mode.
  • Trainer wrappers – drop-in replacements for TRL's SFT/ORPO/GRPO/PPO/DPO trainers plus utilities for plain transformers.Trainer usage.
  • Metrics – contrastive accuracy, representation consistency, and reference tracking.
  • Dataset utilities – helpers for paired datasets or self-aligned positives, plus a default collator ready for mixed InfoNCE/SFT batches.
  • Flexible length matching – options to oversample, cap, auto-adjust ratios, or raise if SFT and contrastive lengths diverge, alongside per-loader batch size overrides.
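
The loss blending can be illustrated without any framework. Below is a minimal, pure-Python sketch of an in-batch InfoNCE objective: `positives[i]` is the match for `anchors[i]`, and every other entry in the batch acts as a negative. CRAFT's own loss adds projection and pooling over model hidden states, which this sketch omits. The cosine similarity, the `temperature` default, and the blend `alpha * sft + (1 - alpha) * contrastive` are assumptions for illustration, and `info_nce` / `blend_losses` are hypothetical names, not the library's API.

```python
import math
from typing import List, Sequence

def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm

def info_nce(anchors: List[Sequence[float]], positives: List[Sequence[float]],
             temperature: float = 0.07) -> float:
    """Mean in-batch InfoNCE loss (hypothetical sketch).

    positives[i] is the positive for anchors[i]; positives[j] with j != i
    serve as negatives for anchors[i].
    """
    total = 0.0
    for i, anchor in enumerate(anchors):
        logits = [cosine(anchor, p) / temperature for p in positives]
        # Numerically stable log-sum-exp over all candidates
        m = max(logits)
        log_denom = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_denom - logits[i]  # equals -log softmax(logits)[i]
    return total / len(anchors)

def blend_losses(sft_loss: float, contrastive_loss: float, alpha: float) -> float:
    """Hypothetical convex blend; assumes alpha weights the SFT term."""
    return alpha * sft_loss + (1.0 - alpha) * contrastive_loss
```

With `temperature=0.07`, a batch whose anchors already point at their own positives scores close to zero, while swapping the positives pushes the loss to roughly 1/0.07 ≈ 14.3.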

Installation

# Editable install with testing extras
uv pip install -e '.[test]'

# Optional dependency groups
uv pip install -e '.[trl]'    # TRL trainers
uv pip install -e '.[hf]'     # transformers integration only
uv pip install -e '.[peft]'   # LoRA/PEFT examples
uv pip install -e '.[all]'    # everything

Package layout

craft/
  ├── config.py     # CRAFT config mixin + TRL-specific configs
  ├── data.py       # Dataset bundle, collator, mixed dataloader
  ├── losses.py     # InfoNCELoss + loss combination helpers
  ├── metrics.py    # Metric utilities and EMA helpers
  ├── trainers.py   # CRAFT trainer mixin + TRL wrappers
  └── __init__.py   # Public exports

What's New

Custom Data Loaders

CRAFT now supports custom PyTorch DataLoader instances for both SFT and contrastive training, giving you more control over batching, sampling, and collation logic.

trainer = CRAFTSFTTrainer(
    model=model,
    args=args,
    train_dataset=sft_dataset,  # Still required for length calculations
    craft_bundle=bundle,
    craft_sft_loader=custom_sft_loader,          # Custom SFT loader
    craft_contrastive_loader=custom_contrast_loader  # Custom contrastive loader
)
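
Because the custom loaders are ordinary PyTorch `DataLoader` instances, collation logic lives in the `collate_fn` you give them. The function below is a hypothetical example of such logic, written in plain Python so it runs anywhere: it right-pads `input_ids` to the longest sequence in the batch, builds a matching `attention_mask`, and fills padded `labels` positions with -100. The padding id, the right-padding choice, and copying `input_ids` into `labels` when none are given are all assumptions, not CRAFT behaviour.

```python
from typing import Dict, List

def pad_collate(examples: List[Dict[str, List[int]]],
                pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Hypothetical collate_fn: right-pad every example to the batch max length."""
    max_len = max(len(ex["input_ids"]) for ex in examples)
    batch: Dict[str, List[List[int]]] = {"input_ids": [], "attention_mask": [], "labels": []}
    for ex in examples:
        ids = list(ex["input_ids"])
        pad = max_len - len(ids)
        batch["input_ids"].append(ids + [pad_id] * pad)
        # 1 for real tokens, 0 for padding
        batch["attention_mask"].append([1] * len(ids) + [0] * pad)
        # Assumption: fall back to input_ids when no labels are provided
        labels = list(ex.get("labels", ids))
        batch["labels"].append(labels + [-100] * pad)
    return batch
```

In practice you would pass it as `DataLoader(dataset, batch_size=2, collate_fn=pad_collate)` and wrap each list with `torch.tensor` before returning if the trainer expects tensors.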

Enhanced Self-align Validation

When using strategy="self_align", CRAFT now performs additional validation to ensure your data is properly formatted:

  • Validates presence of either labels or assistant_mask in SFT batches
  • Ensures at least one token is marked as an assistant token
  • Provides clear error messages for common configuration issues
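
# Example of a valid self-align batch (one 5-token sequence; token ids are placeholders)
import torch

batch = {
    "input_ids": torch.tensor([101, 2023, 1234, 5678, 102]),
    "attention_mask": torch.tensor([1, 1, 1, 1, 1]),
    # Either of the next two keys is sufficient; both are shown for illustration.
    "labels": torch.tensor([-100, -100, 1234, 5678, -100]),  # assistant tokens where labels != -100
    "assistant_mask": torch.tensor([0, 0, 1, 1, 0]),  # 1 marks assistant tokens
}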
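
The checks listed above can be approximated in a few lines. This is a hypothetical, framework-free re-creation, not CRAFT's validator: plain lists stand in for tensors, and the name `assistant_positions` and its error messages are invented for illustration. It prefers an explicit `assistant_mask`, otherwise derives one from `labels != -100`, and raises if neither key is present or if no token is marked as an assistant token.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # label value marking tokens excluded from the loss

def assistant_positions(batch: Dict[str, List[int]]) -> List[int]:
    """Return indices of assistant tokens, validating the batch on the way."""
    if "assistant_mask" in batch:
        mask = [int(m) for m in batch["assistant_mask"]]
    elif "labels" in batch:
        # Assistant tokens are those whose label is not ignored
        mask = [int(label != IGNORE_INDEX) for label in batch["labels"]]
    else:
        raise ValueError("self_align requires 'labels' or 'assistant_mask' in SFT batches")
    positions = [i for i, m in enumerate(mask) if m == 1]
    if not positions:
        raise ValueError("self_align found no assistant tokens in the batch")
    return positions
```

For the example batch above, both the `labels` route and the `assistant_mask` route yield positions `[2, 3]`.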

Quick start

from transformers import AutoModelForCausalLM
from craft.config import CRAFTSFTConfig
from craft.data import CRAFTCollator, make_craft_datasets
from craft.trainers import CRAFTSFTTrainer

# Assume `sft_dataset` and `contrastive_dataset` are tokenized datasets with the
# appropriate columns (`input_ids`, `attention_mask`, optional *_tgt columns).

bundle = make_craft_datasets(
    sft_dataset,
    contrastive_dataset=contrastive_dataset,
    strategy="paired_dataset",
)

model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

args = CRAFTSFTConfig(
    output_dir="./outputs",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    craft_alpha=0.6,
    craft_beta=0.5,
)

trainer = CRAFTSFTTrainer(
    model=model,
    args=args,
    train_dataset=sft_dataset,
    craft_bundle=bundle,
    data_collator=CRAFTCollator(),
)

trainer.train()
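
The quick start sets `craft_beta=0.5` without spelling out how batches are interleaved. The sketch below assumes β is the fraction of batches drawn from the contrastive loader and spaces those batches as evenly as possible; `step_kinds` is a hypothetical helper, and CRAFT's actual scheduling (including `craft_beta_mode="auto"`) may differ.

```python
import math
from typing import List

def step_kinds(num_steps: int, beta: float) -> List[str]:
    """Label each step 'sft' or 'contrastive' so roughly a beta fraction is contrastive."""
    if not 0.0 <= beta <= 1.0:
        raise ValueError("beta must be in [0, 1]")
    kinds = []
    for i in range(num_steps):
        # A contrastive step fires whenever the running quota (i + 1) * beta
        # crosses the next integer, which spreads those steps evenly.
        due = math.floor((i + 1) * beta) > math.floor(i * beta)
        kinds.append("contrastive" if due else "sft")
    return kinds
```

Under this assumption `step_kinds(4, 0.5)` alternates `sft` and `contrastive`, and with β = 0.25 one step in four is contrastive.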

Length matching & batching strategies

CRAFT lets you control how supervised (SFT) and contrastive datasets are balanced:

  • craft_length_strategy="oversample" – loop the shorter loader (default).
  • "cap" – stop when either loader exhausts, keeping epochs perfectly aligned.
  • "auto_beta" – cap like above and recompute craft_beta from observed batch counts.
  • "error" – raise if lengths diverge, useful for deterministic experiments.
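
The four strategies can be modelled as rules for how many batches each loader contributes to an epoch. In the hypothetical helper below, loaders are plain lists of batches consumed in lockstep (a simplification of the mixed loader): `oversample` cycles the shorter list up to the longer one's length, `cap` and `auto_beta` truncate to the shorter length, and `error` refuses unequal lengths. The `auto_beta` formula (contrastive batches divided by all batches, measured before capping) is an assumption about what "recompute craft_beta from observed batch counts" means, not the library's documented rule.

```python
from itertools import cycle, islice
from typing import List, Optional, Sequence, Tuple

def match_lengths(sft: Sequence, contrastive: Sequence, strategy: str = "oversample"
                  ) -> Tuple[List, List, Optional[float]]:
    """Return (sft_batches, contrastive_batches, recomputed_beta_or_None)."""
    n_sft, n_con = len(sft), len(contrastive)
    if strategy == "oversample":
        n = max(n_sft, n_con)
        # cycle() repeats the shorter sequence until both reach length n
        return list(islice(cycle(sft), n)), list(islice(cycle(contrastive), n)), None
    if strategy in ("cap", "auto_beta"):
        n = min(n_sft, n_con)  # stop as soon as either loader is exhausted
        beta = None
        if strategy == "auto_beta" and n_sft + n_con > 0:
            beta = n_con / (n_sft + n_con)  # assumed recomputation rule
        return list(sft[:n]), list(contrastive[:n]), beta
    if strategy == "error":
        if n_sft != n_con:
            raise ValueError(f"length mismatch: {n_sft} SFT vs {n_con} contrastive batches")
        return list(sft), list(contrastive), None
    raise ValueError(f"unknown strategy: {strategy!r}")
```

For three SFT batches and one contrastive batch, `oversample` repeats the contrastive batch three times, `cap` keeps one batch from each, and `error` raises.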

Combine this with craft_contrastive_batch_size to size contrastive batches independently of per_device_train_batch_size:

config = CRAFTSFTConfig(
    output_dir="./outputs",
    per_device_train_batch_size=2,
    craft_contrastive_batch_size=4,
    craft_beta=0.5,
    craft_beta_mode="auto",
    craft_length_strategy="auto_beta",
)
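
Different batch sizes change how many batches each loader yields, which is where the length strategies come into play. A quick calculation, assuming 1,000 examples in each dataset (an illustrative figure) and standard `DataLoader` semantics, where the final partial batch is kept unless `drop_last=True`:

```python
import math

def num_batches(num_examples: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches a DataLoader-style loader yields in one pass over the data."""
    if drop_last:
        return num_examples // batch_size  # the partial final batch is discarded
    return math.ceil(num_examples / batch_size)

sft_batches = num_batches(1000, 2)          # per_device_train_batch_size=2
contrastive_batches = num_batches(1000, 4)  # craft_contrastive_batch_size=4
print(sft_batches, contrastive_batches)     # 500 250
```

The contrastive loader runs out after 250 batches versus 500 for SFT; that mismatch is what craft_length_strategy resolves, and "error" would raise in this setup.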

These knobs are honoured by all CRAFT*Trainer classes and the CRAFTMixedDataLoader.

Notebooks

Six notebooks under packages/craft/notebooks cover end-to-end workflows:

  1. 01-craft-basic-sft – minimal CRAFTSFTTrainer run with paired datasets.
  2. 02-craft-best-practices – conversation packing, assistant masking, LoRA.
  3. 03a-craft-loss-transformers-trainer – integrate InfoNCELoss with vanilla transformers.Trainer.
  4. 03b-craft-trl-sft – TRL SFTTrainer wrapper with CRAFT metrics.
  5. 03c-craft-trl-orpo – ORPO preference optimisation with contrastive batches.
  6. 04-craft-qlora-translation-eval – QLoRA fine-tune of unsloth/gemma-3-270M-it on Flores translations, with before/after BLEU, loss curves, and metric plots.

Testing

CRAFT ships with a pytest suite covering losses, metrics, data utilities, and trainer mixins.

uv pip install -e '.[test]'
uv run python -m pytest -q

Contributing

  1. Add or update tests for new functionality.
  2. Run the lint/test suite before submitting patches.
  3. Update notebooks and documentation to reflect API changes.



Download files


Source Distribution

contrastive_ft-0.2.2.tar.gz (15.4 kB)

Uploaded Source

Built Distribution


contrastive_ft-0.2.2-py3-none-any.whl (14.4 kB)

Uploaded Python 3

File details

Details for the file contrastive_ft-0.2.2.tar.gz.

File metadata

  • Download URL: contrastive_ft-0.2.2.tar.gz
  • Size: 15.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for contrastive_ft-0.2.2.tar.gz
Algorithm    Hash digest
SHA256       f961ef815f8d18294ad7518fba0748ead453d5c0df8750f37adb452f8773aef7
MD5          87812b5f3464e6fa5f8f3298e46e69b8
BLAKE2b-256  09db3ac382cd874f8357ef23b3f9fa3eaa0a6d3b8193134e3d7eeea7bab1f4a0


Provenance

The following attestation bundles were made for contrastive_ft-0.2.2.tar.gz:

Publisher: publish.yml on omarkamali/craft

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file contrastive_ft-0.2.2-py3-none-any.whl.

File metadata

  • Download URL: contrastive_ft-0.2.2-py3-none-any.whl
  • Size: 14.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.7

File hashes

Hashes for contrastive_ft-0.2.2-py3-none-any.whl
Algorithm    Hash digest
SHA256       45d5488d33229358c57aea61ae4958ca4b325c0d987490f6c0fe7009d602a29b
MD5          306bf313e493bfa813b4e278d492c44d
BLAKE2b-256  ef1502196dd753d4eb8887396a0c2f3e5adb13e13771ac262ee4f64e88148cbc


Provenance

The following attestation bundles were made for contrastive_ft-0.2.2-py3-none-any.whl:

Publisher: publish.yml on omarkamali/craft

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
