CRAFT: Contrastive Representation Aware Fine-Tuning toolkit
CRAFT · Contrastive Representation Aware Fine-Tuning
CRAFT is a library that layers a contrastive InfoNCE objective on top of standard SFT and preference-optimization trainers. It provides:
- Composable losses – configurable InfoNCE loss with projection/pooling and weighted blending against supervised losses via craft_alpha.
- Mixed data loading – automated cycling of SFT and contrastive batches according to a configurable craft_beta ratio, with optional auto-tuning via craft_beta_mode.
- Trainer wrappers – drop-in replacements for TRL's SFT/ORPO/GRPO/PPO/DPO trainers plus utilities for plain transformers.Trainer usage.
- Metrics – contrastive accuracy, representation consistency, and reference tracking.
- Dataset utilities – helpers for paired datasets or self-aligned positives, plus a default collator ready for mixed InfoNCE/SFT batches.
- Flexible length matching – options to oversample, cap, auto-adjust ratios, or raise if SFT and contrastive lengths diverge, alongside per-loader batch size overrides.
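To make the loss blending concrete, here is a minimal pure-Python sketch of an in-batch InfoNCE loss, the matching contrastive accuracy, and a weighted blend with a supervised loss. It is not CRAFT's implementation: the helper names (cosine, info_nce, blend), the temperature value, and the convention that alpha weights the SFT term are all assumptions made for illustration.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def info_nce(anchors, positives, temperature=0.1):
    """Return (mean InfoNCE loss, in-batch contrastive accuracy).

    Row i of `positives` is the positive for row i of `anchors`;
    every other row in the batch acts as a negative.
    """
    losses, hits = [], 0
    for i, anchor in enumerate(anchors):
        logits = [cosine(anchor, p) / temperature for p in positives]
        # Numerically stable log-sum-exp over the batch.
        peak = max(logits)
        log_z = peak + math.log(sum(math.exp(x - peak) for x in logits))
        losses.append(log_z - logits[i])  # equals -log softmax(logits)[i]
        # A hit means the anchor's own positive has the highest similarity.
        hits += int(logits.index(peak) == i)
    n = len(anchors)
    return sum(losses) / n, hits / n


def blend(sft_loss, contrastive_loss, alpha):
    """ASSUMPTION: alpha weights the SFT term; CRAFT's convention may differ."""
    return alpha * sft_loss + (1.0 - alpha) * contrastive_loss


anchors = [[1.0, 0.0], [0.0, 1.0]]
positives = [[0.9, 0.1], [0.2, 0.8]]
nce, acc = info_nce(anchors, positives)
total = blend(sft_loss=2.0, contrastive_loss=nce, alpha=0.6)
print(f"info_nce={nce:.4f} accuracy={acc:.2f} total={total:.4f}")
```

In a real run the vectors would be pooled and projected hidden states rather than hand-written lists, and the library's own InfoNCELoss should be preferred over this sketch.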
Installation
# Editable install with testing extras
uv pip install -e '.[test]'
# Optional dependency groups
uv pip install -e '.[trl]' # TRL trainers
uv pip install -e '.[hf]' # transformers integration only
uv pip install -e '.[peft]' # LoRA/PEFT examples
uv pip install -e '.[all]' # everything
Package layout
craft/
├── config.py # CRAFT config mixin + TRL-specific configs
├── data.py # Dataset bundle, collator, mixed dataloader
├── losses.py # InfoNCELoss + loss combination helpers
├── metrics.py # Metric utilities and EMA helpers
├── trainers.py # CRAFT trainer mixin + TRL wrappers
└── __init__.py # Public exports
What's New
Custom Data Loaders
CRAFT now supports custom PyTorch DataLoader instances for both SFT and contrastive training, giving you more control over batching, sampling, and collation logic.
trainer = CRAFTSFTTrainer(
model=model,
args=args,
train_dataset=sft_dataset, # Still required for length calculations
craft_bundle=bundle,
craft_sft_loader=custom_sft_loader, # Custom SFT loader
craft_contrastive_loader=custom_contrast_loader # Custom contrastive loader
)
Enhanced Self-align Validation
When using strategy="self_align", CRAFT now performs additional validation to ensure your data is properly formatted:
- Validates presence of either labels or assistant_mask in SFT batches
- Ensures at least one token is marked as an assistant token
- Provides clear error messages for common configuration issues
# Example of valid self-align batch
{
"input_ids": torch.tensor([...]),
"attention_mask": torch.tensor([...]),
"labels": torch.tensor([-100, -100, 1234, 5678, -100]), # Assistant tokens where labels != -100
# OR
"assistant_mask": torch.tensor([0, 0, 1, 1, 0]) # 1 marks assistant tokens
}
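The checks listed above can be sketched in a few lines. This is an illustration only, not CRAFT's code: the function name validate_self_align_batch, the error messages, the preference for assistant_mask when both keys are present, and the use of plain lists instead of tensors are assumptions.

```python
IGNORE_INDEX = -100  # label value excluded from the loss, as in the example batch


def validate_self_align_batch(batch):
    """Return a 0/1 assistant mask for one sequence, or raise ValueError.

    Hypothetical helper: name, messages, and key precedence are illustrative.
    """
    if "assistant_mask" in batch:
        mask = [int(bool(flag)) for flag in batch["assistant_mask"]]
    elif "labels" in batch:
        # Tokens whose label is not the ignore index count as assistant tokens.
        mask = [int(label != IGNORE_INDEX) for label in batch["labels"]]
    else:
        raise ValueError(
            "self_align batches need either 'labels' or 'assistant_mask'"
        )
    if not any(mask):
        raise ValueError(
            "no assistant tokens found: every label is -100 or the mask is all zeros"
        )
    return mask


batch = {
    "input_ids": [11, 12, 13, 14, 15],  # placeholder token IDs
    "attention_mask": [1, 1, 1, 1, 1],
    "labels": [-100, -100, 1234, 5678, -100],
}
print(validate_self_align_batch(batch))  # [0, 0, 1, 1, 0]
```

Running a check like this on a handful of collated batches before training surfaces masking mistakes early, which is cheaper than discovering them after a failed run.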
Quick start
from transformers import AutoModelForCausalLM
from craft.config import CRAFTSFTConfig
from craft.data import CRAFTCollator, make_craft_datasets
from craft.trainers import CRAFTSFTTrainer
# Assume `sft_dataset` and `contrastive_dataset` are tokenized datasets with the
# appropriate columns (`input_ids`, `attention_mask`, optional *_tgt columns).
bundle = make_craft_datasets(
sft_dataset,
contrastive_dataset=contrastive_dataset,
strategy="paired_dataset",
)
model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
args = CRAFTSFTConfig(
output_dir="./outputs",
per_device_train_batch_size=2,
gradient_accumulation_steps=8,
craft_alpha=0.6,
craft_beta=0.5,
)
trainer = CRAFTSFTTrainer(
model=model,
args=args,
train_dataset=sft_dataset,
craft_bundle=bundle,
data_collator=CRAFTCollator(),
)
trainer.train()
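The quick start sets craft_beta=0.5, which governs how SFT and contrastive batches are interleaved. The sketch below shows one deterministic way such a ratio could be turned into a step schedule. It assumes that beta is the target share of SFT steps; the source does not say which side the ratio favours, and mixed_schedule is a hypothetical helper, not part of the CRAFT API.

```python
def mixed_schedule(num_steps, beta):
    """Yield "sft" or "contrastive" for each training step.

    ASSUMPTION: beta is the target share of SFT steps (0.0 to 1.0).
    CRAFT may define craft_beta differently.
    """
    if not 0.0 <= beta <= 1.0:
        raise ValueError("beta must lie in [0, 1]")
    sft_taken = 0
    for step in range(1, num_steps + 1):
        # Draw an SFT batch whenever the running SFT share has fallen behind beta.
        if sft_taken < beta * step:
            sft_taken += 1
            yield "sft"
        else:
            yield "contrastive"


print(list(mixed_schedule(6, 0.5)))
# ['sft', 'contrastive', 'sft', 'contrastive', 'sft', 'contrastive']
```

A deterministic interleave like this keeps the ratio exact over any prefix of training, whereas sampling each step at random only matches the ratio in expectation.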
Length matching & batching strategies
CRAFT lets you control how supervised (SFT) and contrastive datasets are balanced:
- craft_length_strategy="oversample" – loop the shorter loader (default).
- "cap" – stop when either loader exhausts, keeping epochs perfectly aligned.
- "auto_beta" – cap like above and recompute craft_beta from observed batch counts.
- "error" – raise if lengths diverge, useful for deterministic experiments.
Combine this with craft_contrastive_batch_size to decouple batch sizes:
config = CRAFTSFTConfig(
output_dir="./outputs",
per_device_train_batch_size=2,
craft_contrastive_batch_size=4,
craft_beta=0.5,
craft_beta_mode="auto",
craft_length_strategy="auto_beta",
)
These knobs are honoured by all CRAFT*Trainer classes and the CRAFTMixedDataLoader.
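As a rough illustration of how the four craft_length_strategy values differ, the sketch below computes how many batches each loader would contribute in one epoch. It is a simplified model, not CRAFT's logic: plan_epoch and batches_per_epoch are hypothetical names, "oversample" and "cap" are modelled with one-to-one alternation, and the recomputed beta is assumed to be the SFT share of all batches.

```python
import math


def batches_per_epoch(num_examples, batch_size):
    """Number of batches a loader yields when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)


def plan_epoch(n_sft, n_contrastive, strategy="oversample"):
    """Return (sft_batches, contrastive_batches, beta) for one epoch.

    beta is None unless the strategy recomputes it. Hypothetical model.
    """
    if n_sft <= 0 or n_contrastive <= 0:
        raise ValueError("both loaders must yield at least one batch")
    if strategy == "oversample":
        # The shorter loader is looped until the longer one is exhausted.
        longest = max(n_sft, n_contrastive)
        return longest, longest, None
    if strategy == "cap":
        # Stop as soon as either loader runs out.
        shortest = min(n_sft, n_contrastive)
        return shortest, shortest, None
    if strategy == "auto_beta":
        # ASSUMPTION: beta becomes the SFT share, so both loaders finish together.
        beta = n_sft / (n_sft + n_contrastive)
        return n_sft, n_contrastive, beta
    if strategy == "error":
        if n_sft != n_contrastive:
            raise ValueError(
                f"loader lengths diverge: {n_sft} SFT vs {n_contrastive} contrastive"
            )
        return n_sft, n_contrastive, None
    raise ValueError(f"unknown strategy: {strategy!r}")


# 1000 examples each, with SFT batch size 2 and contrastive batch size 4.
n_sft = batches_per_epoch(1000, 2)
n_con = batches_per_epoch(1000, 4)
for name in ("oversample", "cap", "auto_beta"):
    print(name, plan_epoch(n_sft, n_con, name))
```

The example shows why a larger contrastive batch size changes the picture: equal dataset sizes still produce unequal batch counts, so the chosen strategy decides whether data is repeated, dropped, or rebalanced.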
Notebooks
Six notebooks under packages/craft/notebooks cover end-to-end workflows:
- 01-craft-basic-sft – minimal CRAFTSFTTrainer run with paired datasets.
- 02-craft-best-practices – conversation packing, assistant masking, LoRA.
- 03a-craft-loss-transformers-trainer – integrate InfoNCELoss with vanilla transformers.Trainer.
- 03b-craft-trl-sft – TRL SFTTrainer wrapper with CRAFT metrics.
- 03c-craft-trl-orpo – ORPO preference optimisation with contrastive batches.
- 04-craft-qlora-translation-eval – QLoRA fine-tune of unsloth/gemma-3-270M-it on Flores translations, with before/after BLEU, loss curves, and metric plots.
Testing
CRAFT ships with a pytest suite covering losses, metrics, data utilities, and trainer mixins.
uv pip install -e '.[test]'
uv run python -m pytest -q
Contributing
- Add or update tests for new functionality.
- Run the lint/test suite before submitting patches.
- Update notebooks and documentation to reflect API changes.