A nimble implementation of the Direct Preference Optimization (DPO) algorithm with Causal Transformer and LSTM models for time series data, inspired by the DPO paper on fine-tuning unsupervised language models.

nanoDPO: Direct Preference Optimization for Time Series Data

Introduction

nanoDPO is a library for Direct Preference Optimization (DPO) tailored to time series data. DPO was introduced for fine-tuning unsupervised language models (LMs) to align with user preferences; nanoDPO applies the same approach to time series analysis. The library provides tools for time series forecasting that use DPO to model and predict preferences in sequential data.
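For reference, the per-pair DPO objective from the paper cited under Acknowledgments can be sketched in plain Python. This is a scalar illustration of the published formula, not nanoDPO's own implementation: the function name dpo_loss, its arguments, and the default beta are assumptions made for this example, and whether nanoDPO uses a separate reference model is not stated here.

```python
import math

def dpo_loss(policy_logp_chosen, policy_logp_rejected,
             ref_logp_chosen, ref_logp_rejected, beta=0.1):
    """Hypothetical helper: DPO loss for one (chosen, rejected) pair.

    Inputs are log-probabilities of the chosen and rejected outputs under
    the policy being trained and under a frozen reference model.
    """
    # Scaled difference of log-ratios between chosen and rejected outputs.
    margin = beta * ((policy_logp_chosen - ref_logp_chosen)
                     - (policy_logp_rejected - ref_logp_rejected))
    # loss = -log(sigmoid(margin)) = log(1 + exp(-margin)), computed stably.
    if margin > 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))

# Policy equal to the reference: margin is 0, loss is log(2).
print(round(dpo_loss(-1.0, -1.0, -1.0, -1.0), 4))  # 0.6931
# Policy favors the chosen output more than the reference does: lower loss.
print(dpo_loss(-0.5, -2.0, -1.0, -1.0) < math.log(2))  # True
```

When the policy and reference assign the same log-probabilities, the margin is zero and the loss equals log 2; the loss falls as the policy shifts relatively more probability toward the preferred output.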

Installation

To get started with nanoDPO, simply install the package using pip:

pip install nanoDPO

Key Features

  • Causal Transformer & Simple Sequence Model: Incorporates both a Causal Transformer and an LSTM-based Simple Sequence Model for diverse modeling needs.
  • Preference Data Simulation: Utilizes a custom function, simulate_dpo_dataset_noise, to generate synthetic preference-based time series data.
  • Sequence Data Preparation: Prepares data for training with prepare_sequence_datasets, aligning time series data with the DPO framework.
  • DPO Training with PyTorch: Trains models with PyTorch, with customizable training parameters.
  • MulticlassTrainer: Provides an additional way to handle time series data, focusing on traditional multiclass classification tasks.
  • Cross-Entropy Loss for Multiclass Classification: Optimized for handling multiple classes in time series data.
  • Customizable Training and Evaluation: Flexible parameters for epochs, batch size, and learning rate.
  • Model Evaluation and Visualization: Offers tools for evaluating models and plotting training metrics.
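To make the preference-data and sequence-preparation features more concrete, here is a minimal standard-library sketch of the general idea: simulate a noisy series with a preferred action per step, then slice it into fixed-length windows. The functions simulate_preference_series and make_windows, the binary up/down action, and the label-flip noise model are all assumptions for illustration; they are not the behavior or signatures of simulate_dpo_dataset_noise or prepare_sequence_datasets.

```python
import random

def simulate_preference_series(n_steps, noise=0.1, seed=0):
    """Hypothetical: random walk plus a noisy preferred action per step.

    Action 1 means the next value is higher, 0 means it is not.
    With probability `noise` the label is flipped.
    """
    rng = random.Random(seed)
    values = [0.0]
    for _ in range(n_steps):
        values.append(values[-1] + rng.gauss(0.0, 1.0))
    labels = []
    for t in range(n_steps):
        label = 1 if values[t + 1] > values[t] else 0
        if rng.random() < noise:
            label = 1 - label  # inject label noise
        labels.append(label)
    return values[:-1], labels

def make_windows(values, labels, sequence_len):
    """Hypothetical: build (window, preferred, rejected) training triples.

    The preferred action is the label attached to the window's last step;
    with two actions, the rejected action is simply the other one.
    """
    samples = []
    for end in range(sequence_len, len(values) + 1):
        window = values[end - sequence_len:end]
        preferred = labels[end - 1]
        samples.append((window, preferred, 1 - preferred))
    return samples

values, labels = simulate_preference_series(n_steps=50, noise=0.1)
samples = make_windows(values, labels, sequence_len=10)
print(len(samples))  # 41
```

Each triple pairs an input window with a preferred and a rejected action, which is the shape of data a pairwise preference loss needs; a multiclass trainer would use only the window and the preferred label.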

Usage

Import the necessary modules from nanoDPO, including CausalTransformer, SimpleSequenceModel, and the dataset preparation functions. Use DPOOneModelTrainer for Direct Preference Optimization or MulticlassTrainer for conventional multiclass training.

import torch
from nanodpo.causal_transformer import CausalTransformer
from nanodpo.simple_sequence_model import SimpleSequenceModel
from nanodpo.preference_data import simulate_dpo_dataset_noise
from nanodpo.sequence_data import prepare_sequence_datasets
from nanodpo.dpo_onemodel_trainer import DPOOneModelTrainer
from nanodpo.multiclass_trainer import MulticlassTrainer

# Initialize and train your model.
# Hyperparameters, device, and datasets are assumed to be defined in the elided code.
...
model = CausalTransformer(d_feature=feature_dim, d_model=d_model, n_head=n_head, n_layer=n_layer,
                          num_actions=num_actions, max_len=sequence_len,
                          device=device).to(device)
...
trainer = DPOOneModelTrainer(model=model, model_dir=f"dpo_{model_type}_model/", device=device,
                             learning_rate=learning_rate, batch_size=batch_size)
trainer.train(train_dataset, test_dataset, epochs=epochs, eval_interval=eval_interval)
# Evaluate and visualize the results
trainer.plot_metrics()
trainer.evaluate(test_dataset)

License

nanoDPO is licensed under the Apache License 2.0. See the LICENSE file for more details.

Acknowledgments

Inspired by the paper "Direct Preference Optimization: Your Language Model is Secretly a Reward Model," nanoDPO extends the concept of DPO to the domain of time series data, opening new avenues for research and application.

Citation

@misc{rafailov2023direct,
  title   = {Direct Preference Optimization: Your Language Model is Secretly a Reward Model}, 
  author  = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Stefano Ermon and Christopher D. Manning and Chelsea Finn},
  year    = {2023},
  eprint  = {2305.18290},
  archivePrefix = {arXiv},
  primaryClass = {cs.LG}
}
