
A nimble implementation of the Direct Preference Optimization (DPO) algorithm for time series data, built around a Causal Transformer and an LSTM model and inspired by the paper that introduced DPO for fine-tuning unsupervised language models.

Project description

nanoDPO: Direct Preference Optimization for Time Series Data


Introduction

Welcome to nanoDPO – a novel, cutting-edge library for Direct Preference Optimization (DPO) tailored for time series data. Inspired by the concept of utilizing DPO in fine-tuning unsupervised Language Models (LMs) to align with user preferences, nanoDPO pivots this approach to the realm of time series analysis. This library offers a unique perspective and toolset for time series forecasting, leveraging the principles of DPO to model and predict preferences in sequential data.
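At its core, DPO trains directly on pairs of preferred and non-preferred outcomes without fitting a separate reward model. As a point of reference only (not a description of nanoDPO's internals), a minimal PyTorch sketch of the pairwise loss from the cited paper looks like this; all names are illustrative:

import torch.nn.functional as F

def dpo_pairwise_loss(policy_chosen_logp, policy_rejected_logp,
                      ref_chosen_logp, ref_rejected_logp, beta=0.1):
    # Inputs are per-sequence log-probabilities of the preferred ("chosen")
    # and non-preferred ("rejected") sequences under the trained policy and
    # under a frozen reference model.
    policy_logratio = policy_chosen_logp - policy_rejected_logp
    ref_logratio = ref_chosen_logp - ref_rejected_logp
    # Loss = -log sigmoid(beta * (policy log-ratio - reference log-ratio))
    return -F.logsigmoid(beta * (policy_logratio - ref_logratio)).mean()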

Installation

To get started with nanoDPO, simply install the package using pip:

pip install nanoDPO

Key Features

  • Causal Transformer & Simple Sequence Model: Incorporates both a Causal Transformer and an LSTM-based Simple Sequence Model for diverse modeling needs.
  • Preference Data Simulation: Utilizes a custom function, simulate_dpo_dataset_noise, to generate synthetic preference-based time series data.
  • Sequence Data Preparation: Prepares data for training with prepare_sequence_datasets, aligning time series data with the DPO framework (a combined usage sketch follows this list).
  • DPO Training with PyTorch: Leverages the power of PyTorch for efficient and effective model training, complete with customizable parameters.
  • Multiclass Baseline: MulticlassTrainer provides an additional approach to handling time series data, focusing on traditional multiclass classification tasks.
  • Cross-Entropy Loss for Multiclass Classification: Optimized for handling multiple classes in time series data.
  • Customizable Training and Evaluation: Flexible parameters for epochs, batch size, and learning rate.
  • Model Evaluation and Visualization: Offers tools for model evaluation and metrics visualization, ensuring an insightful analysis of performance.
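As a rough illustration of how these pieces fit together, here is a minimal sketch that feeds the simulated preference data into the dataset preparation step consumed by the trainers. The keyword arguments are hypothetical placeholders, not the documented signatures of simulate_dpo_dataset_noise and prepare_sequence_datasets; consult the library source for the actual API.

from nanodpo.preference_data import simulate_dpo_dataset_noise
from nanodpo.sequence_data import prepare_sequence_datasets

# NOTE: all keyword arguments below are assumed for illustration --
# check the nanodpo source for the real signatures.
preference_data = simulate_dpo_dataset_noise(
    num_samples=1000,    # assumed: number of simulated preference pairs
    sequence_len=16,     # assumed: length of each time series window
    noise_level=0.1,     # assumed: noise injected into the preferences
)
train_dataset, test_dataset = prepare_sequence_datasets(preference_data)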

Usage

Import the necessary modules from nanoDPO, including the CausalTransformer, SimpleSequenceModel, and dataset preparation functions. Utilize the DPOOneModelTrainer for Direct Preference Optimization or MulticlassTrainer for conventional multiclass training.

import torch
from nanodpo.causal_transformer import CausalTransformer
from nanodpo.simple_sequence_model import SimpleSequenceModel
from nanodpo.preference_data import simulate_dpo_dataset_noise
from nanodpo.sequence_data import prepare_sequence_datasets
from nanodpo.dpo_onemodel_trainer import DPOOneModelTrainer
from nanodpo.multiclass_trainer import MulticlassTrainer

# Initialize and train your model
...
model = CausalTransformer(d_feature=feature_dim, d_model=d_model, n_head=n_head, n_layer=n_layer,
                          num_actions=num_actions, max_len=sequence_len,
                          device=device).to(device)
...
trainer = DPOOneModelTrainer(model=model, model_dir=f"dpo_{model_type}_model/", device=device,
                             learning_rate=learning_rate, batch_size=batch_size)
trainer.train(train_dataset, test_dataset, epochs=epochs, eval_interval=eval_interval)
# Evaluate and visualize the results
trainer.plot_metrics()
trainer.evaluate(test_dataset)
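The snippet above exercises the DPO path. For the conventional multiclass baseline mentioned earlier, a comparable sketch is shown below; it assumes MulticlassTrainer mirrors DPOOneModelTrainer's constructor and train/evaluate methods, which is an assumption rather than documented behaviour.

# Assumption: MulticlassTrainer mirrors DPOOneModelTrainer's interface.
trainer = MulticlassTrainer(model=model, model_dir=f"multiclass_{model_type}_model/",
                            device=device, learning_rate=learning_rate,
                            batch_size=batch_size)
trainer.train(train_dataset, test_dataset, epochs=epochs, eval_interval=eval_interval)
trainer.evaluate(test_dataset)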

[Figure: Weights & Biases training metrics for the DPO Causal Transformer run]

License

nanoDPO is licensed under the Apache License 2.0. See the LICENSE file for more details.

Acknowledgments

Inspired by the paper "Direct Preference Optimization: Your Language Model is Secretly a Reward Model," nanoDPO extends the concept of DPO to the domain of time series data, opening new avenues for research and application.

Citation

@misc{rafailov2023direct,
  title   = {Direct Preference Optimization: Your Language Model is Secretly a Reward Model}, 
  author  = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Stefano Ermon and Christopher D. Manning and Chelsea Finn},
  year    = {2023},
  eprint  = {2305.18290},
  archivePrefix = {arXiv},
  primaryClass = {cs.LG}
}

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

nanodpo-0.1.post1.tar.gz (14.3 kB)

Uploaded: Source

Built Distribution

nanodpo-0.1.post1-py2.py3-none-any.whl (17.5 kB)

Uploaded: Python 2, Python 3

File details

Details for the file nanodpo-0.1.post1.tar.gz.

File metadata

  • Download URL: nanodpo-0.1.post1.tar.gz
  • Upload date:
  • Size: 14.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.6.1 CPython/3.10.10 Linux/5.4.0-166-generic

File hashes

Hashes for nanodpo-0.1.post1.tar.gz
  • SHA256: 8577ba21441395d5fd1259b1109f4636e3c61d862620b31d653733a737b53081
  • MD5: 969ff8a253df548697bac35c1f1088ea
  • BLAKE2b-256: e8c966edd600487776f1abf806338acae98a065d5aba9893871dd6fc3b87bc72

See more details on using hashes here.
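For example, to check a downloaded archive against the published SHA256 digest before installing it, a minimal verification in Python could look like this (the file name and digest are the ones listed above):

import hashlib

expected_sha256 = "8577ba21441395d5fd1259b1109f4636e3c61d862620b31d653733a737b53081"
with open("nanodpo-0.1.post1.tar.gz", "rb") as f:
    actual_sha256 = hashlib.sha256(f.read()).hexdigest()
assert actual_sha256 == expected_sha256, "SHA256 mismatch -- do not install"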

File details

Details for the file nanodpo-0.1.post1-py2.py3-none-any.whl.

File metadata

  • Download URL: nanodpo-0.1.post1-py2.py3-none-any.whl
  • Upload date:
  • Size: 17.5 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.6.1 CPython/3.10.10 Linux/5.4.0-166-generic

File hashes

Hashes for nanodpo-0.1.post1-py2.py3-none-any.whl
  • SHA256: ac1593cf854a834f43274b41744b426d20b00f9ae5c1a29f15c4f5dfdb4e9a0a
  • MD5: 8a4336931541ae20f170d4a7107cc087
  • BLAKE2b-256: 7b6996ca903dfabc9e549b3545f14b5f4a4ecd302cef7354c60e1c57ff582d33

See more details on using hashes here.
