A compact implementation of the Direct Preference Optimization (DPO) algorithm with Causal Transformer and LSTM models for time series data, inspired by the DPO paper on fine-tuning unsupervised language models.
nanoDPO: Direct Preference Optimization for Time Series Data
Introduction
Welcome to nanoDPO, a library for Direct Preference Optimization (DPO) tailored to time series data. Inspired by the use of DPO to fine-tune unsupervised Language Models (LMs) so that they align with user preferences, nanoDPO applies the same approach to time series analysis. The library offers a toolset for time series forecasting that uses the principles of DPO to model and predict preferences in sequential data.
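For orientation, the DPO objective from the cited paper can be sketched for a single preference pair as follows. This is a framework-free illustration of the published formula, not nanoDPO's own code: the function names, the argument names, and the default `beta` are assumptions made for this example, and the presence of a separate reference policy follows the paper rather than anything confirmed about this library.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x)) for a scalar."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,  # assumed default; beta scales the implicit reward
) -> float:
    """DPO loss for one (chosen, rejected) pair, following the paper's formula.

    Each argument is the log-probability that the policy or the frozen
    reference model assigns to the chosen or rejected outcome.
    """
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    return -log_sigmoid(margin)


# A policy identical to the reference has zero margin, so the loss is log(2).
baseline = dpo_loss(-1.0, -1.0, -1.0, -1.0)

# Shifting probability mass toward the chosen outcome lowers the loss.
improved = dpo_loss(-0.5, -2.0, -1.0, -1.0)
```

In a PyTorch training loop the same arithmetic would typically be applied to batched tensors, for example with `torch.nn.functional.logsigmoid`.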
Installation
To get started with nanoDPO, install the package using pip:
pip install nanoDPO
Key Features
- Causal Transformer & Simple Sequence Model: Incorporates both a Causal Transformer and an LSTM-based Simple Sequence Model for diverse modeling needs.
- Preference Data Simulation: Utilizes a custom function, simulate_dpo_dataset_noise, to generate synthetic preference-based time series data.
- Sequence Data Preparation: Prepares data for training with prepare_sequence_datasets, aligning time series data with the DPO framework.
- DPO Training with PyTorch: Trains models in PyTorch with customizable parameters.
- Multiclass Training: MulticlassTrainer provides an additional approach to time series data, focusing on traditional multiclass classification tasks.
- Cross-Entropy Loss for Multiclass Classification: Optimized for handling multiple classes in time series data.
- Customizable Training and Evaluation: Flexible parameters for epochs, batch size, and learning rate.
- Model Evaluation and Visualization: Offers tools for model evaluation and metrics visualization.
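To make the "causal" part of the Causal Transformer feature concrete, here is a minimal sketch of causal masking in attention: each time step may attend only to itself and earlier steps. It is written in plain Python for readability and is not taken from nanoDPO; `causal_mask` and `masked_softmax` are hypothetical helper names introduced for this example.

```python
import math


def causal_mask(seq_len: int) -> list[list[bool]]:
    """Lower-triangular mask: position i may attend to position j only if j <= i."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def masked_softmax(scores: list[list[float]], mask: list[list[bool]]) -> list[list[float]]:
    """Row-wise softmax over attention scores, giving masked positions zero weight."""
    weights = []
    for row, mask_row in zip(scores, mask):
        visible = [s for s, keep in zip(row, mask_row) if keep]
        peak = max(visible)  # subtract the row maximum for numerical stability
        exps = [math.exp(s - peak) if keep else 0.0 for s, keep in zip(row, mask_row)]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


# Uniform scores over three time steps: each step spreads its attention
# evenly over the past and present, and never looks at the future.
mask = causal_mask(3)
attention = masked_softmax([[0.0, 0.0, 0.0] for _ in range(3)], mask)
```

With this mask, a prediction at time step t cannot depend on observations after t, which is what allows a transformer to be trained on sequential data without leaking future values.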
Usage
Import the necessary modules from nanoDPO, including the CausalTransformer, SimpleSequenceModel, and dataset preparation functions. Utilize the DPOOneModelTrainer for Direct Preference Optimization or MulticlassTrainer for conventional multiclass training.
import torch
from nanodpo.causal_transformer import CausalTransformer
from nanodpo.simple_sequence_model import SimpleSequenceModel
from nanodpo.preference_data import simulate_dpo_dataset_noise
from nanodpo.sequence_data import prepare_sequence_datasets
from nanodpo.dpo_onemodel_trainer import DPOOneModelTrainer
from nanodpo.multiclass_trainer import MulticlassTrainer
# Initialize and train your model
...
model = CausalTransformer(d_feature=feature_dim, d_model=d_model, n_head=n_head, n_layer=n_layer,
                          num_actions=num_actions, max_len=sequence_len,
                          device=device).to(device)
...
trainer = DPOOneModelTrainer(model=model, model_dir=f"dpo_{model_type}_model/", device=device,
                             learning_rate=learning_rate, batch_size=batch_size)
trainer.train(train_dataset, test_dataset, epochs=epochs, eval_interval=eval_interval)
# Evaluate and visualize the results
trainer.plot_metrics()
trainer.evaluate(test_dataset)
License
nanoDPO is licensed under the Apache License 2.0. See the LICENSE file for more details.
Acknowledgments
Inspired by the paper "Direct Preference Optimization: Your Language Model is Secretly a Reward Model," nanoDPO extends the concept of DPO to the domain of time series data, opening new avenues for research and application.
Citation
@misc{rafailov2023direct,
title = {Direct Preference Optimization: Your Language Model is Secretly a Reward Model},
author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Stefano Ermon and Christopher D. Manning and Chelsea Finn},
year = {2023},
eprint = {2305.18290},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}