RLUF: Reinforcement Learning with Uncertainty Feedback via Conformal Prediction for DPO

RLUF: Reinforcement Learning with Uncertainty Feedback

Conformal Feedback Alignment — using Conformal Prediction to quantify LLM generation uncertainty and integrate it as per-example weights into DPO training. This enables the model to learn more from high-confidence preference pairs and less from uncertain ones.

Paper: Conformal Feedback Alignment for LLM Fine-Tuning

Architecture

┌──────────────────────────────────────────────────────────────┐
│                        RLUF Pipeline                          │
│                                                               │
│  ┌──────┐    ┌────────────┐    ┌──────────────────┐           │
│  │ SFT  │───>│ Generation │───>│ Conformal        │           │
│  │      │    │ + Scoring  │    │ Prediction (CP)  │           │
│  └──────┘    └────────────┘    └────────┬─────────┘           │
│                                         │                     │
│                               Prediction Sets                 │
│                               (50% & 80% coverage)            │
│                                         │                     │
│  ┌───────────────┐    ┌─────────────────▼─────────────┐       │
│  │ AI Feedback   │───>│ Uncertainty Weight Assignment │       │
│  │ (Preference)  │    └─────────────────┬─────────────┘       │
│  └───────────────┘                      │                     │
│                               Weighted DPO Pairs              │
│                                         │                     │
│  ┌──────────────┐    ┌──────────────────▼───┐                 │
│  │ Weighted DPO │───>│ Inference + Evaluate │                 │
│  │ Training     │    └──────────────────────┘                 │
│  └──────────────┘                                             │
└──────────────────────────────────────────────────────────────┘

Quick Start

1. Install

git clone https://github.com/tiejin98/Conformal-Feedback-Alignment.git
cd Conformal-Feedback-Alignment
pip install -r requirements.txt

2. Configure

cp .env.example .env
# Edit .env with your API keys:
#   OPENAI_API_KEY=sk-...
#   HF_TOKEN=hf_...

# Review and customize the config:
cp configs/default.yaml configs/my_config.yaml
# Edit configs/my_config.yaml with your model paths

3. Run

# Run individual stages:
python -m cfa sft --config configs/my_config.yaml
python -m cfa generate --config configs/my_config.yaml
python -m cfa calibrate --config configs/my_config.yaml --quantile 0.2
python -m cfa calibrate --config configs/my_config.yaml --quantile 0.5
python -m cfa feedback --config configs/my_config.yaml
python -m cfa assign-weights --config configs/my_config.yaml
python -m cfa train --config configs/my_config.yaml
python -m cfa infer --config configs/my_config.yaml
python -m cfa evaluate --config configs/my_config.yaml

# Or run the full pipeline:
python -m cfa run-all --config configs/my_config.yaml

4. Run Tests

pytest tests/ -v

Pipeline Stages

Stage 1: Generation with Conformal Prediction

| Step | Command | Description | Output |
|------|---------|-------------|--------|
| 1a | cfa sft | Fine-tune Llama-2-7B on summarization data (loss on summary tokens only) | SFT model checkpoint |
| 1b | cfa generate | Sample 60 responses per prompt, score unique ones with GPT-4o | Frequency dicts + accuracy scores |
| 1c | cfa calibrate | Grid-search CP hyperparameters, calibrate the quantile threshold | Prediction sets (JSON) |
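
For step 1b, a minimal sketch of the per-prompt sampling loop, assuming a Hugging Face causal LM (sampling_num=60 and temperature=0.35 mirror the config defaults; max_new_tokens is illustrative and the GPT-4o scoring call is omitted):

```python
from collections import Counter

import torch

def sample_responses(model, tokenizer, prompt, sampling_num=60, temperature=0.35):
    """Draw sampling_num completions for one prompt and count duplicates.

    The resulting frequency dict feeds the nonconformity score in step 1c.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            do_sample=True,
            temperature=temperature,
            num_return_sequences=sampling_num,  # may be chunked to fit memory
            max_new_tokens=128,
        )
    # Strip the prompt tokens; keep only the generated continuation.
    completions = tokenizer.batch_decode(
        outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    return Counter(text.strip() for text in completions)
```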

Key algorithm — Nonconformity Score:

score = 10 - (freq/total)*10 + (entropy/2)*weight - similarity_to_top*weight_2

Lower score = higher confidence = more likely in prediction set.
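
A minimal sketch of this score and the split-conformal thresholding it feeds (weight and weight_2 are the grid-searched hyperparameters from step 1c; treating the --quantile flag as the miscoverage level alpha is an assumption):

```python
import numpy as np

def nonconformity(freq, total, entropy, similarity_to_top, weight, weight_2):
    # The score above: lower = higher confidence.
    return 10 - (freq / total) * 10 + (entropy / 2) * weight - similarity_to_top * weight_2

def prediction_set(candidate_scores, calibration_scores, alpha):
    """Split conformal prediction: keep candidates whose score is at or below
    the finite-sample-corrected (1 - alpha) quantile of calibration scores."""
    n = len(calibration_scores)
    level = min(1.0, np.ceil((n + 1) * (1 - alpha)) / n)
    threshold = np.quantile(calibration_scores, level)
    return {resp for resp, s in candidate_scores.items() if s <= threshold}
```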

Stage 2: AI Feedback with Uncertainty Weights

| Step | Command | Description | Output |
|------|---------|-------------|--------|
| 2a | cfa feedback | Pairwise preference annotation via AlpacaFarm | DPO pairs (JSONL) |
| 2b | cfa assign-weights | Weight pairs by CP prediction set membership | Weighted DPO pairs |
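
The pairs flow between stages as JSON records; a plausible shape for one weighted pair (the field names are illustrative, not confirmed from the repo):

```json
{
  "prompt": "Summarize the following article: ...",
  "chosen": "The article reports that ...",
  "rejected": "This text is about ...",
  "weight": 0.65
}
```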

Weight assignment:

  • In the 50% coverage set only → weight 0.5
  • In the 80% coverage set only → weight 0.8
  • In both sets → weight 0.65
  • In neither set → weight 0.0
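
A minimal sketch of that mapping (the membership flag names are illustrative):

```python
def uncertainty_weight(in_50: bool, in_80: bool) -> float:
    """Map prediction-set membership to a DPO pair weight (rules above)."""
    if in_50 and in_80:
        return 0.65
    if in_80:
        return 0.8
    if in_50:
        return 0.5
    return 0.0
```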

Stage 3: Training and Evaluation

| Step | Command | Description | Output |
|------|---------|-------------|--------|
| 3a | cfa train | Weighted DPO training (loss *= weight) | RLUF model |
| 3b | cfa infer | Generate summaries on the test set | Predictions (pkl) |
| 3c | cfa evaluate | Score with GPT-4o (Accuracy, Relevance, Completeness, Expression) | Scores (pkl) |
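
The "loss *= weight" description suggests scaling the standard DPO objective per pair; a minimal sketch consistent with that (the function name, beta default, and mean reduction are assumptions):

```python
import torch.nn.functional as F

def weighted_dpo_loss(policy_chosen_logps, policy_rejected_logps,
                      ref_chosen_logps, ref_rejected_logps,
                      weights, beta=0.1):
    """Standard DPO loss scaled per example by its uncertainty weight."""
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    logits = beta * (chosen_logratios - rejected_logratios)
    # Pairs outside both prediction sets (weight 0.0) contribute nothing.
    losses = -F.logsigmoid(logits) * weights
    return losses.mean()
```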

Output Directory Structure

outputs/
├── generation/          # Stage 1b outputs
│   ├── generation_llama2.txt
│   ├── generation_llama2_accuracy.txt
│   ├── generation_test_llama2.txt
│   └── response_dict_llama2.pkl
├── calibration/         # Stage 1c outputs
│   ├── prediction_set_quantile0.2_threshold0.7_llama2.json
│   └── prediction_set_quantile0.5_threshold0.7_llama2.json
├── feedback/            # Stage 2 outputs
│   ├── dpo_data_llama2.json
│   └── dpo_data_llama2_withuncertainty.json
├── inference/           # Stage 3b outputs
│   ├── test_dict_question.pkl
│   └── test_dict_RLUF.pkl
└── evaluation/          # Stage 3c outputs
    └── evaluation_scores_llama2.pkl
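
The .pkl artifacts are plain pickles and the prediction sets plain JSON, so they can be inspected directly (paths from the tree above; the exact object types are not documented, so this is a hedged sketch):

```python
import json
import pickle

# Prediction sets from stage 1c are JSON.
with open("outputs/calibration/prediction_set_quantile0.2_threshold0.7_llama2.json") as f:
    pred_sets = json.load(f)

# Inference and evaluation artifacts are pickles.
with open("outputs/evaluation/evaluation_scores_llama2.pkl", "rb") as f:
    scores = pickle.load(f)

print(type(pred_sets), type(scores))
```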

Configuration

All hyperparameters are in configs/default.yaml. Key settings:

| Parameter | Default | Description |
|-----------|---------|-------------|
| generation.calibration_size | 50 | Number of calibration samples |
| generation.sampling_num | 60 | Samples per prompt |
| generation.temperature | 0.35 | Sampling temperature |
| conformal.quantile_bars | [0.2, 0.5] | Quantile levels for the CP prediction sets (50% and 80% coverage) |
| conformal.accuracy_threshold | 0.7 | GPT score threshold for "correct" |
| dpo.learning_rate | 1.5e-6 | DPO training learning rate |
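
The dotted parameter names imply a nesting like the following (a hedged sketch of configs/default.yaml, reconstructed from the table above rather than copied from the repo):

```yaml
generation:
  calibration_size: 50
  sampling_num: 60
  temperature: 0.35

conformal:
  quantile_bars: [0.2, 0.5]
  accuracy_threshold: 0.7

dpo:
  learning_rate: 1.5e-6
```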

Hardware Requirements

| Stage | GPU Memory | Estimated Time |
|-------|------------|----------------|
| SFT | ~24 GB (bfloat16) | ~2-4 hours |
| Generation | ~16 GB (bfloat16) | ~6-12 hours (60 samples x 50+ prompts) |
| Calibration | CPU only | ~10-30 minutes |
| Feedback | CPU + OpenAI API | ~1-2 hours (API dependent) |
| DPO Training | ~24 GB (float16) | ~2-4 hours |
| Inference | ~16 GB (bfloat16) | ~1-2 hours |
| Evaluation | CPU + OpenAI API | ~1-2 hours (API dependent) |

Project Structure

Conformal-Feedback-Alignment/
├── cfa/                    # Main package
│   ├── cli.py              # CLI entry point
│   ├── config.py           # Configuration loading
│   ├── stages/             # Pipeline stage implementations
│   │   ├── sft.py          # SFT training
│   │   ├── generation.py   # Multi-sample generation + scoring
│   │   ├── calibration.py  # Conformal prediction
│   │   ├── feedback.py     # AI preference annotation
│   │   ├── weights.py      # Uncertainty weight assignment
│   │   ├── train.py        # Weighted DPO training
│   │   ├── inference.py    # Test set inference
│   │   └── evaluation.py   # GPT-4o evaluation
│   ├── models/
│   │   └── weighted_dpo.py # WeightedDPOTrainer
│   └── utils/
│       ├── text_processing.py
│       ├── scoring.py
│       └── io.py
├── configs/
│   └── default.yaml        # Default configuration
├── tests/                  # Unit tests
├── Generation With CP/     # Original scripts (legacy)
├── AI Feedback/            # Original scripts (legacy)
├── Training and Testing/   # Original scripts (legacy)
├── requirements.txt
├── Makefile
├── Dockerfile
└── .env.example

Dependencies

  • torch>=2.3.0
  • transformers>=4.51.0
  • trl>=0.7.0
  • datasets>=3.5.0
  • openai>=0.28.1
  • gensim>=4.3.0
  • alpaca-farm
  • accelerate>=0.30.0

License

See LICENSE file for details.
