
CRAFT: Contrastive Representation Aware Fine-Tuning toolkit


CRAFT · Contrastive Representation Aware Fine-Tuning


CRAFT is a library and fine-tuning technique that layers a contrastive InfoNCE objective on top of standard SFT and preference-optimization trainers. It provides:

  • Composable losses – configurable InfoNCE loss with projection/pooling and weighted blending against supervised losses via craft_alpha.
  • Accumulation-aware scaling – keeps the effective SFT-to-contrastive gradient ratio at craft_alpha regardless of how batches are distributed within an accumulation window.
  • Memory-efficient training – hook-based hidden state capture and GradCache support for large-batch contrastive learning under memory constraints.
  • Single forward pass – for self-align strategy, both SFT and contrastive losses are computed from one forward pass using dual pooling.
  • Trainer wrappers – drop-in replacements for TRL's SFT/ORPO/GRPO/PPO/DPO trainers plus utilities for plain transformers.Trainer usage.
  • Metrics – contrastive accuracy, representation consistency, and reference tracking.
  • Dataset utilities – helpers for paired datasets or self-aligned positives, plus a default collator ready for mixed InfoNCE/SFT batches.
  • Flexible length matching – options to oversample, cap, auto-adjust ratios, or raise if SFT and contrastive lengths diverge, alongside per-loader batch size overrides.
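The blended objective behind the composable losses can be written as standard in-batch InfoNCE with temperature τ (0.05 by default, per the references table), combined with the supervised loss through α = craft_alpha, so that craft_alpha = 0.6 corresponds to a 60:40 SFT-to-contrastive split. Cosine similarity and in-batch negatives are assumptions here; with the queue strategy the denominator would also range over queued negatives.

```latex
\mathcal{L}_{\text{InfoNCE}} = -\frac{1}{B}\sum_{i=1}^{B}
  \log \frac{\exp\big(\operatorname{sim}(z_i, z_i^{+})/\tau\big)}
            {\sum_{j=1}^{B} \exp\big(\operatorname{sim}(z_i, z_j^{+})/\tau\big)},
\qquad
\mathcal{L} = \alpha\,\mathcal{L}_{\text{SFT}} + (1-\alpha)\,\mathcal{L}_{\text{InfoNCE}}
```

Here z_i is the projected anchor embedding of example i, z_i^+ its positive, and the positives of the other B − 1 rows act as negatives.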

Installation

# Install from PyPI
uv pip install contrastive-ft

# Optional dependency groups
uv pip install 'contrastive-ft[trl]'    # TRL trainers
uv pip install 'contrastive-ft[hf]'     # transformers integration only
uv pip install 'contrastive-ft[peft]'   # LoRA/PEFT examples
uv pip install 'contrastive-ft[all]'    # everything

# Editable install with testing extras for local development
git clone https://github.com/omarkamali/craft.git
cd craft
uv pip install -e '.[test]'

Package layout

craft/
  ├── config.py              # CRAFT config mixin + TRL-specific configs
  ├── data.py                # Dataset bundle, collator, mixed dataloader
  ├── losses.py              # InfoNCELoss, ProjectionHead, pooling strategies
  ├── metrics.py             # Metric utilities and EMA helpers
  ├── trainers.py            # CRAFT trainer mixin + TRL wrappers
  ├── accumulator.py         # Accumulation-aware loss scaling
  ├── hooks.py               # Memory-efficient hidden state capture
  ├── gradcache.py           # GradCache for large-batch contrastive
  ├── gradient_balancing.py  # Gradient dominance mitigation strategies
  └── __init__.py            # Public exports

What's New

v0.4.0: Gradient Balancing & Presets

This release addresses gradient dominance and simplifies configuration with presets.

Gradient Balancing Strategies:

  • loss_scale: Simple loss normalization by running mean (recommended starting point)
  • uncertainty: Homoscedastic uncertainty weighting (Kendall et al., CVPR 2018)
  • gradnorm: Dynamic gradient normalization (Chen et al., ICML 2018)
  • pcgrad: Project conflicting gradients (Yu et al., NeurIPS 2020)
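As a rough illustration of two of these strategies, the sketch below implements loss_scale as division of each loss by an exponential moving average of its own value, plus the core PCGrad projection for one pair of gradients. The class and function names, the momentum value, and the use of plain floats and lists are assumptions for illustration; CRAFT's gradient_balancing.py may work differently.

```python
from typing import Dict, List


class RunningMeanLossScaler:
    """Divide each named loss by an exponential moving average of its own value."""

    def __init__(self, momentum: float = 0.9, eps: float = 1e-8) -> None:
        self.momentum = momentum  # weight kept on the previous running mean
        self.eps = eps            # guards against division by zero
        self.running: Dict[str, float] = {}

    def scale(self, name: str, loss: float) -> float:
        prev = self.running.get(name)
        # The first observation seeds the running mean; later ones update the EMA.
        mean = loss if prev is None else self.momentum * prev + (1 - self.momentum) * loss
        self.running[name] = mean
        return loss / (mean + self.eps)


def pcgrad_project(g_i: List[float], g_j: List[float]) -> List[float]:
    """PCGrad step: if g_i conflicts with g_j (negative dot product),
    remove the component of g_i that points along g_j."""
    dot = sum(a * b for a, b in zip(g_i, g_j))
    if dot >= 0:
        return list(g_i)  # no conflict, keep the gradient as-is
    norm_sq = sum(b * b for b in g_j)
    return [a - dot / norm_sq * b for a, b in zip(g_i, g_j)]


scaler = RunningMeanLossScaler()
sft_term = scaler.scale("sft", 2.0)   # ~1.0 on the first call
con_term = scaler.scale("con", 0.1)   # also ~1.0, despite a 20x smaller raw loss
projected = pcgrad_project([1.0, 0.0], [-1.0, 1.0])  # [0.5, 0.5], orthogonal to g_j
```

In a real training loop the losses would be tensors and the running mean would be detached so that it only rescales the gradient. Full PCGrad applies this projection against every other task gradient, in random order, before summing.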

Presets & Auto-Configuration:

from craft.config import CRAFTSFTConfig

# Start from a preset
config = CRAFTSFTConfig.from_preset("balanced", output_dir="./outputs")

# Or auto-detect settings from the model, dataset, and available memory
config = CRAFTSFTConfig.auto(
    output_dir="./outputs",
    model=my_model,
    sft_dataset=train_data,
    available_memory_gb=16,
)

Available presets: minimal, balanced, memory_efficient, large_batch, aggressive

v0.3.0: Memory Efficiency & Training Correctness

This release introduces significant optimizations for memory efficiency and training correctness:

Accumulation-Aware Loss Scaling: The loss scaling now correctly accounts for batch distribution within gradient accumulation windows. Previously, with alpha=0.6 and beta=0.6, the effective gradient ratio was ~72:28 instead of the intended 60:40. Now alpha means exactly what it says regardless of beta.
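The idea can be sketched as per-micro-batch multipliers computed from how many SFT and contrastive micro-batches actually land in one accumulation window, so that the summed weights equal craft_alpha and 1 − craft_alpha. The helper name and signature are hypothetical, and the sketch balances loss weight rather than measured gradient norms.

```python
from typing import Tuple


def accumulation_aware_weights(n_sft: int, n_con: int, alpha: float) -> Tuple[float, float]:
    """Per-micro-batch loss multipliers for one accumulation window.

    Summed over the window, SFT batches carry a total weight of `alpha` and
    contrastive batches carry `1 - alpha`, however many of each were drawn.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    w_sft = alpha / n_sft if n_sft else 0.0
    w_con = (1.0 - alpha) / n_con if n_con else 0.0
    return w_sft, w_con


# Example: 8 micro-batches per optimizer step, of which 5 are SFT and 3 contrastive.
w_sft, w_con = accumulation_aware_weights(n_sft=5, n_con=3, alpha=0.6)
total_sft, total_con = 5 * w_sft, 3 * w_con  # 0.6 and 0.4 (up to float rounding)
```

For comparison, uniform 1/8 scaling of the same 5:3 window would give the SFT side 5/8 × 0.6 = 0.375 against 3/8 × 0.4 = 0.15, roughly 71:29 instead of 60:40, which is the kind of drift described above.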

Single Forward Pass for Self-Align: When using strategy="self_align", CRAFT now computes both SFT and contrastive losses from a single forward pass using dual pooling. This eliminates the redundant second forward pass, reducing compute by ~50% for self-align.
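One way to get both signals from a single forward pass is to run two masked mean-pools over the same final hidden states: here the prompt tokens form the anchor view and the assistant tokens (labels != -100) form the positive view, while the logits from that same pass feed the SFT loss. The README does not specify which two views CRAFT pools, so this split and the helper names are assumptions.

```python
import torch


def masked_mean(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean-pool hidden states [B, T, H] over positions where mask [B, T] is True."""
    weights = mask.unsqueeze(-1).to(hidden.dtype)      # [B, T, 1]
    summed = (hidden * weights).sum(dim=1)             # [B, H]
    counts = weights.sum(dim=1).clamp(min=1.0)         # avoid dividing by zero
    return summed / counts


def dual_pool(hidden: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor):
    """Pool two views from one set of hidden states (assumed prompt/assistant split)."""
    valid = attention_mask.bool()
    assistant = (labels != -100) & valid   # tokens the SFT loss is computed on
    prompt = ~assistant & valid            # every other non-padding token
    return masked_mean(hidden, prompt), masked_mean(hidden, assistant)


# hidden would normally be the final-layer output of the same forward pass
# whose logits feed the SFT loss; random values stand in for it here.
hidden = torch.randn(2, 5, 8)
attention_mask = torch.ones(2, 5, dtype=torch.long)
labels = torch.tensor([[-100, -100, 7, 8, -100], [-100, 3, 4, 5, -100]])
anchor, positive = dual_pool(hidden, attention_mask, labels)  # each [2, 8]
```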

Memory-Efficient Hidden State Capture: New hook-based hidden state extraction captures only the final layer output instead of all layers. This reduces memory overhead from O(num_layers × batch × seq × hidden) to O(batch × seq × hidden).
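A forward hook on a single module keeps a reference to that module's output only, rather than every layer's output as with output_hidden_states=True. This sketch hooks the last layer of a toy model; for a Hugging Face model you would hook the final decoder block or final norm, whose attribute path varies by architecture. The class name is hypothetical.

```python
import torch
from torch import nn


class FinalHiddenStateCapture:
    """Store the output of one module via a forward hook (graph kept for backprop)."""

    def __init__(self, module: nn.Module) -> None:
        self.hidden = None
        self._handle = module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        # Transformer blocks often return tuples whose first element is the hidden state.
        self.hidden = output[0] if isinstance(output, tuple) else output

    def remove(self) -> None:
        self._handle.remove()


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8))
capture = FinalHiddenStateCapture(model[-1])
_ = model(torch.randn(2, 3, 4))
print(capture.hidden.shape)  # torch.Size([2, 3, 8])
capture.remove()             # detach the hook when it is no longer needed
```

The captured tensor is not detached, so a contrastive loss computed from it still backpropagates into the model.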

GradCache Support: For paired dataset training with large batches, enable craft_use_gradcache=True to compute contrastive loss with gradient caching. This allows effective batch sizes of 1000+ even on a single GPU.
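GradCache (Gao et al., 2021) splits the backward pass in two: embed the whole batch without a graph, take the full-batch loss gradient with respect to those embeddings, then re-encode chunk by chunk and backpropagate the cached embedding gradients. The sketch below assumes a deterministic encoder (no dropout; the original method also restores RNG state between passes) and in-batch InfoNCE. The function names are illustrative, and chunk_size only presumably mirrors what craft_gradcache_chunk_size controls.

```python
import torch
import torch.nn.functional as F
from torch import nn


def info_nce(q: torch.Tensor, p: torch.Tensor, temperature: float = 0.05) -> torch.Tensor:
    """In-batch InfoNCE: row i of q should match row i of p; other rows are negatives."""
    logits = F.normalize(q, dim=-1) @ F.normalize(p, dim=-1).T / temperature
    return F.cross_entropy(logits, torch.arange(q.size(0), device=q.device))


def gradcache_step(encoder: nn.Module, queries: torch.Tensor, positives: torch.Tensor,
                   chunk_size: int, temperature: float = 0.05) -> float:
    """Accumulate the full-batch InfoNCE gradient into encoder's .grad,
    keeping only one chunk's activation graph alive at a time."""
    # 1) Embed everything without building a graph.
    with torch.no_grad():
        q_emb = torch.cat([encoder(c) for c in queries.split(chunk_size)])
        p_emb = torch.cat([encoder(c) for c in positives.split(chunk_size)])
    # 2) Full-batch loss on the detached embeddings; cache d(loss)/d(embedding).
    q_emb.requires_grad_(True)
    p_emb.requires_grad_(True)
    loss = info_nce(q_emb, p_emb, temperature)
    loss.backward()
    # 3) Re-encode chunk by chunk with a graph and push the cached gradients through.
    for inputs, grads in ((queries, q_emb.grad), (positives, p_emb.grad)):
        for x, g in zip(inputs.split(chunk_size), grads.split(chunk_size)):
            encoder(x).backward(gradient=g)
    return loss.item()


torch.manual_seed(0)
encoder = nn.Linear(6, 4)  # stand-in for model + pooling + projection head
queries, positives = torch.randn(8, 6), torch.randn(8, 6)
loss = gradcache_step(encoder, queries, positives, chunk_size=3)
```

Peak activation memory during step 3 scales with chunk_size rather than the full batch, while the accumulated .grad matches a single full-batch backward; optimizer.step() then follows as usual.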

Improved Projection Head: The projection head now uses a 2-layer MLP with GELU activation (following SimCLR), replacing the previous single-layer Tanh design. Output dimension is configurable via craft_projection_dim.
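A minimal head of this shape is sketched below, assuming it is applied to pooled hidden states and that its intermediate width equals the model's hidden size (the README does not say). ProjectionHeadSketch is a hypothetical name, not CRAFT's ProjectionHead.

```python
import torch
from torch import nn


class ProjectionHeadSketch(nn.Module):
    """SimCLR-style head: Linear -> GELU -> Linear, hidden_size -> projection_dim."""

    def __init__(self, hidden_size: int, projection_dim: int = 256) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),  # intermediate width is an assumption
            nn.GELU(),
            nn.Linear(hidden_size, projection_dim),
        )

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        return self.net(pooled)


head = ProjectionHeadSketch(hidden_size=2048, projection_dim=256)
z = head(torch.randn(4, 2048))  # shape [4, 256]
```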

from craft.config import CRAFTSFTConfig

config = CRAFTSFTConfig(
    output_dir="./outputs",

    # Memory optimization
    craft_use_gradcache=True,           # Enable GradCache for large batches
    craft_gradcache_chunk_size=8,       # Chunk size for backward pass
    craft_use_hidden_state_hook=True,   # Hook-based hidden state capture

    # Projection head
    craft_projection_dim=256,           # Lower dim = more efficient
    craft_learnable_temperature=True,   # CLIP-style learnable temp

    # Negative sampling
    craft_negative_strategy="queue",    # MoCo-style negative queue
    craft_negative_queue_size=65536,
)
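A MoCo-style negative queue can be sketched as a fixed-size ring buffer of past normalised embeddings that are appended to the in-batch candidates as extra negatives. This omits MoCo's momentum encoder (the README does not say whether CRAFT uses one), keeps the buffer on the CPU for simplicity, and uses hypothetical names.

```python
import torch
import torch.nn.functional as F


class NegativeQueue:
    """Fixed-size FIFO of past (detached, L2-normalised) embeddings."""

    def __init__(self, dim: int, size: int = 65536) -> None:
        self.buffer = torch.zeros(size, dim)
        self.ptr = 0      # next slot to overwrite
        self.filled = 0   # number of valid rows

    @torch.no_grad()
    def enqueue(self, keys: torch.Tensor) -> None:
        keys = F.normalize(keys.detach(), dim=-1)
        size = self.buffer.size(0)
        idx = (self.ptr + torch.arange(keys.size(0))) % size  # wrap around the ring
        self.buffer[idx] = keys
        self.ptr = (self.ptr + keys.size(0)) % size
        self.filled = min(self.filled + keys.size(0), size)

    def negatives(self) -> torch.Tensor:
        return self.buffer[: self.filled]


def info_nce_with_queue(q: torch.Tensor, p: torch.Tensor, queue: NegativeQueue,
                        temperature: float = 0.05) -> torch.Tensor:
    q, p = F.normalize(q, dim=-1), F.normalize(p, dim=-1)
    candidates = torch.cat([p, queue.negatives()])          # [B + filled, D]
    logits = q @ candidates.T / temperature
    return F.cross_entropy(logits, torch.arange(q.size(0), device=q.device))


torch.manual_seed(0)
queue = NegativeQueue(dim=4, size=5)
q, p = torch.randn(2, 4), torch.randn(2, 4)
loss = info_nce_with_queue(q, p, queue)  # queue still empty: plain in-batch InfoNCE
queue.enqueue(p)                         # current positives become future negatives
```

Enqueueing after the loss is computed ensures a sample is never scored against its own queued copy in the same step.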

Custom Data Loaders

CRAFT now supports custom PyTorch DataLoader instances for both SFT and contrastive training, giving you more control over batching, sampling, and collation logic.

trainer = CRAFTSFTTrainer(
    model=model,
    args=args,
    train_dataset=sft_dataset,  # Still required for length calculations
    craft_bundle=bundle,
    craft_sft_loader=custom_sft_loader,          # Custom SFT loader
    craft_contrastive_loader=custom_contrast_loader  # Custom contrastive loader
)
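For example, a custom SFT loader can be an ordinary torch.utils.data.DataLoader with its own collate_fn. The toy data, the pad_collate helper, and the assumption that the trainer expects input_ids, attention_mask, and labels keys are illustrative; compare against the keys CRAFTCollator emits before relying on them.

```python
import torch
from torch.utils.data import DataLoader


def pad_collate(examples, pad_token_id: int = 0):
    """Right-pad `input_ids`, build the attention mask, and mask padded labels with -100."""
    max_len = max(len(ex["input_ids"]) for ex in examples)
    shape = (len(examples), max_len)
    input_ids = torch.full(shape, pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros(shape, dtype=torch.long)
    labels = torch.full(shape, -100, dtype=torch.long)
    for i, ex in enumerate(examples):
        ids = torch.tensor(ex["input_ids"], dtype=torch.long)
        n = ids.numel()
        input_ids[i, :n] = ids
        attention_mask[i, :n] = 1
        labels[i, :n] = ids  # train on every real token in this toy example
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


# Hypothetical toy data standing in for a tokenized SFT dataset.
toy_sft = [{"input_ids": [5, 6, 7]}, {"input_ids": [8, 9]}, {"input_ids": [10]}]
custom_sft_loader = DataLoader(toy_sft, batch_size=2, shuffle=True, collate_fn=pad_collate)
```

The loader is then passed as craft_sft_loader=custom_sft_loader, with train_dataset still supplied for length calculations.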

Enhanced Self-align Validation

When using strategy="self_align", CRAFT now performs additional validation to ensure your data is properly formatted:

  • Validates presence of either labels or assistant_mask in SFT batches
  • Ensures at least one token is marked as an assistant token
  • Provides clear error messages for common configuration issues

import torch

# Example of a valid self-align batch (token IDs are placeholders)
batch = {
    "input_ids": torch.tensor([[11, 22, 1234, 5678, 33]]),
    "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]),
    # Assistant tokens are those where labels != -100
    "labels": torch.tensor([[-100, -100, 1234, 5678, -100]]),
    # OR mark them explicitly (1 = assistant token); either key is sufficient
    "assistant_mask": torch.tensor([[0, 0, 1, 1, 0]]),
}
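The checks listed above amount to deriving an assistant-token mask and refusing an empty one. This hypothetical helper reproduces them under two assumptions: assistant_mask takes precedence when both keys are present, and padding positions (attention_mask == 0) never count as assistant tokens.

```python
import torch


def validate_self_align_batch(batch: dict) -> torch.Tensor:
    """Return a bool mask of assistant tokens, raising ValueError on malformed batches."""
    if "assistant_mask" in batch:
        mask = batch["assistant_mask"].bool()
    elif "labels" in batch:
        mask = batch["labels"] != -100
    else:
        raise ValueError("self_align batches need either `labels` or `assistant_mask`.")
    if "attention_mask" in batch:
        mask = mask & batch["attention_mask"].bool()  # ignore padding positions
    if not mask.any():
        raise ValueError("No assistant tokens found: all labels are -100 or the mask is all 0.")
    return mask


mask = validate_self_align_batch({
    "labels": torch.tensor([[-100, -100, 1234, 5678, -100]]),
    "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]),
})  # marks positions 2 and 3 as assistant tokens
```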

Quick start

from transformers import AutoModelForCausalLM
from craft.config import CRAFTSFTConfig
from craft.data import CRAFTCollator, make_craft_datasets
from craft.trainers import CRAFTSFTTrainer

# Assume `sft_dataset` and `contrastive_dataset` are tokenized datasets with the
# appropriate columns (`input_ids`, `attention_mask`, optional *_tgt columns).

bundle = make_craft_datasets(
    sft_dataset,
    contrastive_dataset=contrastive_dataset,
    strategy="paired_dataset",
)

model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

args = CRAFTSFTConfig(
    output_dir="./outputs",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    craft_alpha=0.6,
    craft_beta=0.5,
)

trainer = CRAFTSFTTrainer(
    model=model,
    args=args,
    train_dataset=sft_dataset,
    craft_bundle=bundle,
    data_collator=CRAFTCollator(),
)

trainer.train()

Length matching & batching strategies

CRAFT lets you control how supervised (SFT) and contrastive datasets are balanced:

  • craft_length_strategy="oversample" – loop the shorter loader (default).
  • "cap" – stop when either loader exhausts, keeping epochs perfectly aligned.
  • "auto_beta" – cap like above and recompute craft_beta from observed batch counts.
  • "error" – raise if lengths diverge, useful for deterministic experiments.

Combine this with craft_contrastive_batch_size to decouple batch sizes:

config = CRAFTSFTConfig(
    output_dir="./outputs",
    per_device_train_batch_size=2,
    craft_contrastive_batch_size=4,
    craft_beta=0.5,
    craft_beta_mode="auto",
    craft_length_strategy="auto_beta",
)

These knobs are honoured by all CRAFT*Trainer classes and the CRAFTMixedDataLoader.
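To make the strategies concrete, the sketch below models each one as a way of pairing SFT and contrastive batches over an epoch. Treating training steps as (SFT, contrastive) pairs is a simplification, the function is hypothetical, and auto_beta is left out because the README does not give the formula it uses to recompute craft_beta.

```python
from itertools import cycle
from typing import Iterator, Tuple


def align_loaders(sft: list, con: list, strategy: str = "oversample") -> Iterator[Tuple[object, object]]:
    """Pair SFT and contrastive batches according to a length strategy (illustrative only)."""
    if strategy == "error" and len(sft) != len(con):
        raise ValueError(f"SFT has {len(sft)} batches but contrastive has {len(con)}")
    if strategy == "oversample":
        # Loop the shorter side until the longer one is exhausted.
        if len(sft) >= len(con):
            return zip(sft, cycle(con))
        return zip(cycle(sft), con)
    if strategy in ("cap", "error"):
        return zip(sft, con)  # stops as soon as either side runs out
    raise ValueError(f"unknown strategy: {strategy}")


steps = list(align_loaders(["s1", "s2", "s3"], ["c1"], strategy="oversample"))
# [("s1", "c1"), ("s2", "c1"), ("s3", "c1")]
```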

Review the guide for more details.

Techniques & References

CRAFT incorporates techniques from several influential papers:

| Technique | Reference | Usage in CRAFT |
|---|---|---|
| InfoNCE Loss | van den Oord et al., "Representation Learning with Contrastive Predictive Coding" (2018) | Core contrastive objective |
| Projection Head | Chen et al., "A Simple Framework for Contrastive Learning of Visual Representations" (SimCLR, 2020) | 2-layer MLP with GELU for projection |
| Temperature Scaling | Gao et al., "SimCSE: Simple Contrastive Learning of Sentence Embeddings" (2021) | Configurable temperature (0.05 default) |
| Learnable Temperature | Radford et al., "Learning Transferable Visual Models From Natural Language Supervision" (CLIP, 2021) | Optional craft_learnable_temperature |
| GradCache | Gao et al., "Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup" (2021) | Memory-efficient large-batch training |
| Negative Queue | He et al., "Momentum Contrast for Unsupervised Visual Representation Learning" (MoCo, 2020) | Optional craft_negative_strategy="queue" |
| Multi-task Accumulation | Raffel et al., "Exploring the Limits of Transfer Learning" (T5, 2020) | Accumulation-aware loss scaling |
| GradNorm | Chen et al., "Gradient Normalization for Adaptive Loss Balancing" (ICML 2018) | craft_gradient_balancing="gradnorm" |
| Uncertainty Weighting | Kendall et al., "Multi-Task Learning Using Uncertainty to Weigh Losses" (CVPR 2018) | craft_gradient_balancing="uncertainty" |
| PCGrad | Yu et al., "Gradient Surgery for Multi-Task Learning" (NeurIPS 2020) | craft_gradient_balancing="pcgrad" |

Notebooks

Six notebooks under packages/craft/notebooks cover end-to-end workflows:

  1. 01-craft-basic-sft – minimal CRAFTSFTTrainer run with paired datasets.
  2. 02-craft-best-practices – conversation packing, assistant masking, LoRA.
  3. 03a-craft-loss-transformers-trainer – integrate InfoNCELoss with vanilla transformers.Trainer.
  4. 03b-craft-trl-sft – TRL SFTTrainer wrapper with CRAFT metrics.
  5. 03c-craft-trl-orpo – ORPO preference optimisation with contrastive batches.
  6. 04-craft-qlora-translation-eval – QLoRA fine-tune of unsloth/gemma-3-270M-it on Flores translations, with before/after BLEU, loss curves, and metric plots.

Testing

CRAFT ships with a pytest suite covering losses, metrics, data utilities, and trainer mixins.

uv pip install -e '.[test]'
uv run python -m pytest -q

Contributing

  1. Add or update tests for new functionality.
  2. Run the lint/test suite before submitting patches.
  3. Update notebooks and documentation to reflect API changes.

Citation

If you find CRAFT useful for your research, please cite it as follows:

@misc{kamali2025craft,
  title={CRAFT: Contrastive Representation Aware Fine-Tuning},
  author={Kamali, Omar},
  year={2025},
  publisher={Zenodo},
  doi={10.5281/zenodo.18053757},
  url={https://doi.org/10.5281/zenodo.18053757},
  institution={Omneity Labs}
}

