
A Policy Gradient RL agent for time series prediction using PyTorch Lightning.


TimeSeries Agent


A Python package implementing a Policy Gradient Reinforcement Learning agent for predicting directional movements (Up, Down, Same) in time series data. This package uses PyTorch Lightning for structuring the training process.
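To make the three directional classes concrete, here is a minimal sketch of how a true action could be derived from consecutive target values. It assumes, without confirmation from the package, that the label at step t compares the value at t+1 with the value at t and that actions are encoded as 0 = Up, 1 = Down, 2 = Same. `direction_labels`, the encoding constants, and the `tol` dead-band are hypothetical.

```python
from typing import List, Sequence

# Hypothetical action encoding; the package's actual mapping may differ.
UP, DOWN, SAME = 0, 1, 2


def direction_labels(values: Sequence[float], tol: float = 0.0) -> List[int]:
    """Label each step t (except the last) by comparing values[t + 1] with values[t].

    Moves whose absolute size is <= tol are treated as SAME.
    """
    labels = []
    for current, nxt in zip(values, values[1:]):
        change = nxt - current
        if change > tol:
            labels.append(UP)
        elif change < -tol:
            labels.append(DOWN)
        else:
            labels.append(SAME)
    return labels


print(direction_labels([1.0, 2.0, 2.0, 1.5]))  # [0, 2, 1]
```

With real prices, exact equality is rare, so a small `tol` keeps "Same" from being an almost-empty class.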

Features

  • Implements a basic Policy Gradient algorithm.
  • Uses a flexible Multi-Layer Perceptron (MLP) as the policy network.
  • Configurable lookback window, hidden layers, and normalization.
  • Handles multi-feature time series data.
  • Uses an epsilon-greedy strategy with decaying epsilon to balance exploration and exploitation.
  • Built with PyTorch Lightning for easier training management and logging (via CSVLogger).
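The epsilon-greedy schedule can be sketched as follows. This assumes a linear decay from `epsilon_start` to `epsilon_end` over `epsilon_decay_epochs` epochs (the parameter names come from the Configuration section; the linear shape and the helpers `epsilon_at` and `select_action` are assumptions), with the exploit branch taking the most probable action from the policy network's output.

```python
import random
from typing import Sequence


def epsilon_at(epoch: int, epsilon_start: float, epsilon_end: float,
               epsilon_decay_epochs: int) -> float:
    """Linearly interpolate epsilon from epsilon_start to epsilon_end, then hold it."""
    if epsilon_decay_epochs <= 0 or epoch >= epsilon_decay_epochs:
        return epsilon_end
    fraction = epoch / epsilon_decay_epochs
    return epsilon_start + fraction * (epsilon_end - epsilon_start)


def select_action(probs: Sequence[float], epsilon: float, rng: random.Random) -> int:
    """With probability epsilon pick a uniformly random action, else the most probable one."""
    if rng.random() < epsilon:
        return rng.randrange(len(probs))  # explore
    return max(range(len(probs)), key=lambda a: probs[a])  # exploit


# Example: how epsilon shrinks over the first 10 epochs (actions depend on the RNG).
rng = random.Random(42)
for epoch in (0, 5, 10):
    eps = epsilon_at(epoch, epsilon_start=1.0, epsilon_end=0.1, epsilon_decay_epochs=10)
    print(epoch, round(eps, 2), select_action([0.2, 0.7, 0.1], eps, rng))
```

Holding epsilon at `epsilon_end` once the decay finishes keeps a small amount of exploration for the rest of training.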

Installation

pip install timeseries_agent

You may also need to install the following dependencies:

pip install lightning torch
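A quick way to confirm the dependencies are present is to ask Python's import system whether each top-level module can be found, without importing it. This sketch assumes the package is importable as `timeseries_agent` (matching the wheel name) and that Lightning is installed under its `lightning` import name; `check_installed` is a hypothetical helper.

```python
import importlib.util
from typing import Dict, Iterable


def check_installed(
    modules: Iterable[str] = ("timeseries_agent", "lightning", "torch"),
) -> Dict[str, bool]:
    """Report which top-level modules can be located on the current Python path."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


print(check_installed())
```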

Usage

Using the Example Notebook

The package includes an example Jupyter notebook (examples/example_timeseries_agent.ipynb) that demonstrates how to:

  1. Set up and train a reinforcement learning agent
  2. Load a trained model and test it with simulated live data

To run the notebook:

  1. Install Jupyter if you haven't already: pip install jupyter
  2. Navigate to the examples directory: cd examples
  3. Launch Jupyter: jupyter notebook
  4. Open example_timeseries_agent.ipynb

Using the Example Python Scripts

There are two example scripts provided: example_train.py and example_test.py.

Training

example_train.py is used to train the agent.

python examples/example_train.py

This script:

  1. Generates sample multi-feature time series data using Pandas.
  2. Creates an RLTimeSeriesDataset and DataLoader.
  3. Instantiates the PolicyGradientAgent.
  4. Trains the agent using PyTorch Lightning Trainer.
  5. Logs training progress to logs/rl_agent/.
  6. Saves the final checkpoint in the current directory.
  7. Optionally plots the training progress.
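A basic policy gradient (REINFORCE) training step minimizes the negative log-probability of the chosen action weighted by its reward. The dependency-free sketch below assumes a reward of +1 for predicting the correct direction and -1 otherwise; the package's actual reward, computed from `target_column`, may differ, and `softmax`, `reward`, and `reinforce_loss` are hypothetical names.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reward(action: int, true_action: int) -> float:
    """Hypothetical reward: +1 for the correct direction, -1 otherwise."""
    return 1.0 if action == true_action else -1.0


def reinforce_loss(batch_logits: Sequence[Sequence[float]], actions: Sequence[int],
                   true_actions: Sequence[int]) -> float:
    """Batch mean of -log pi(a | s) * r, i.e. the negated REINFORCE objective."""
    terms = []
    for logits, a, y in zip(batch_logits, actions, true_actions):
        log_prob = math.log(softmax(logits)[a])
        terms.append(-log_prob * reward(a, y))
    return sum(terms) / len(terms)
```

In PyTorch the same batch loss would typically be written as `-(torch.log_softmax(logits, dim=-1).gather(1, actions.unsqueeze(1)).squeeze(1) * rewards).mean()`, with autograd supplying the gradient.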

Note about logging: The training process logs metrics to CSV files by default in the logs/rl_agent/ directory.
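To inspect those logs outside the example's plotting code, a small reader helps. It assumes the example uses Lightning's `CSVLogger` with its usual layout, where each run writes `metrics.csv` into a new `version_N` subdirectory (here `logs/rl_agent/version_N/metrics.csv`). The logged column names depend on what the agent logs and are not assumed. `latest_metrics_file` and `load_metrics` are hypothetical helpers.

```python
import csv
from pathlib import Path
from typing import Dict, List


def latest_metrics_file(log_root: str = "logs/rl_agent") -> Path:
    """Return metrics.csv from the highest-numbered version_N directory under log_root."""
    versions = [
        p for p in Path(log_root).glob("version_*")
        if p.is_dir() and p.name.split("_", 1)[1].isdigit()
    ]
    if not versions:
        raise FileNotFoundError(f"no version_* directories under {log_root}")
    # Compare numerically so version_10 ranks above version_2.
    latest = max(versions, key=lambda p: int(p.name.split("_", 1)[1]))
    return latest / "metrics.csv"


def load_metrics(path: Path) -> List[Dict[str, str]]:
    """Read every logged row; empty cells mean that metric was not logged at that step."""
    with open(path, newline="") as f:
        return list(csv.DictReader(f))
```

Picking the version by its numeric suffix avoids the lexicographic trap where `version_10` would sort before `version_2`.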

Testing

example_test.py is used to test/evaluate the trained agent.

python examples/example_test.py

This script:

  1. Generates sample multi-feature time series data using Pandas.
  2. Loads a trained agent from checkpoint.
  3. Evaluates the trained agent and prints results.
  4. Optionally plots the predicted actions vs true actions.
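Comparing predicted and true actions reduces to an accuracy figure and a confusion table. This sketch assumes both are integer action indices in the hypothetical order Up, Down, Same; `evaluate` and `ACTION_NAMES` are illustrative names, not the script's API.

```python
from collections import Counter
from typing import Dict, Sequence, Tuple

ACTION_NAMES = ("Up", "Down", "Same")  # hypothetical index order


def evaluate(predicted: Sequence[int],
             true: Sequence[int]) -> Tuple[float, Dict[Tuple[str, str], int]]:
    """Return overall accuracy and a (true, predicted) -> count confusion table."""
    if len(predicted) != len(true):
        raise ValueError("predicted and true must have the same length")
    if not true:
        raise ValueError("need at least one prediction")
    correct = sum(p == t for p, t in zip(predicted, true))
    confusion = Counter((ACTION_NAMES[t], ACTION_NAMES[p]) for p, t in zip(predicted, true))
    return correct / len(true), dict(confusion)


acc, table = evaluate([0, 1, 2, 0], [0, 1, 1, 2])
print(acc)  # 0.5
```

The confusion table shows which directions get mixed up, which a single accuracy number hides.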

Using Your Own Data

  1. Prepare your data: Load your time series data into a Pandas DataFrame.
    import pandas as pd
    data_df = pd.read_csv("path/to/your/data.csv", index_col="timestamp", parse_dates=True)
    data_df.sort_index(inplace=True)
    
  2. Adapt the example scripts:
    • Replace sample data generation with your data loading.
    • Update TARGET_COLUMN and NUM_FEATURES.
    • Adjust hyperparameters in PolicyGradientAgent and Trainer.
  3. Run your modified script:
    python your_modified_script.py
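Before updating `TARGET_COLUMN` and `NUM_FEATURES`, it can help to check what the CSV actually contains. This standard-library sketch assumes the layout from step 1 (a `timestamp` index column) and that every other column, including the target, is a numeric input feature; `inspect_csv` is a hypothetical helper.

```python
import csv
from typing import List, Tuple


def inspect_csv(path: str, target_column: str,
                index_col: str = "timestamp") -> Tuple[List[str], int]:
    """Return (feature_columns, num_features) after basic sanity checks.

    Assumes every column except index_col is a numeric model input,
    including the target column itself.
    """
    with open(path, newline="") as f:
        reader = csv.reader(f)
        header = next(reader)
        features = [c for c in header if c != index_col]
        if target_column not in features:
            raise ValueError(f"target column {target_column!r} not found in {features}")
        for line_no, row in enumerate(reader, start=2):
            for name, cell in zip(header, row):
                if name == index_col:
                    continue
                try:
                    float(cell)  # empty or non-numeric cells fail here
                except ValueError:
                    raise ValueError(
                        f"line {line_no}, column {name!r}: {cell!r} is not numeric"
                    ) from None
    return features, len(features)


# Hypothetical usage in a modified script:
# features, NUM_FEATURES = inspect_csv("path/to/your/data.csv", TARGET_COLUMN)
```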
    

Configuration

Key parameters in the examples:

PolicyGradientAgent parameters:

  • full_data: Pandas DataFrame holding the complete time series.
  • target_column: Target column name for reward.
  • input_features: Number of input features (columns) in the data.
  • lookback: Lookback window size.
  • hidden_layers: MLP architecture (e.g., [128, 64]).
  • output_size: Number of actions (default: 3).
  • learning_rate: Optimizer learning rate.
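  • normalize_state: Whether to normalize each state within its lookback window.
  • epsilon_start, epsilon_end, epsilon_decay_epochs: Epsilon-greedy schedule (initial epsilon, final epsilon, and the number of epochs over which epsilon decays).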
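The interaction between `lookback`, `input_features`, and `normalize_state` can be sketched as building one flattened state vector per time step. The sketch assumes the state is the last `lookback` rows ending at step t and that normalization is a per-feature z-score computed only from that window; the package may use a different scheme, and `build_state` is a hypothetical name.

```python
import math
from typing import List, Sequence


def build_state(rows: Sequence[Sequence[float]], t: int, lookback: int,
                normalize_state: bool = True) -> List[float]:
    """Flatten the `lookback` rows ending at index t into one state vector.

    If normalize_state is True, each feature is z-scored using only the
    window's own mean and standard deviation (no look-ahead).
    """
    if lookback < 1 or t + 1 < lookback or t >= len(rows):
        raise ValueError("invalid t/lookback for the available rows")
    window = [list(r) for r in rows[t + 1 - lookback : t + 1]]
    if normalize_state:
        for j in range(len(window[0])):
            col = [r[j] for r in window]
            mean = sum(col) / len(col)
            std = math.sqrt(sum((x - mean) ** 2 for x in col) / len(col))
            for r in window:
                # A constant feature carries no information inside the window.
                r[j] = (r[j] - mean) / std if std > 0 else 0.0
    return [x for r in window for x in r]
```

Under these assumptions a flat MLP input has `lookback * input_features` values, and normalizing within the window avoids leaking statistics from future rows.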

Trainer parameters:

  • max_epochs: Number of training epochs.
  • accelerator, devices: Hardware configuration: the accelerator type ('cpu', 'gpu', 'tpu') and the devices to use.
  • logger: Logging configuration (the examples use CSVLogger).

License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.


Download files

Source Distributions

No source distribution files are available for this release.

Built Distribution

timeseries_agent-0.0.21-py3-none-any.whl (11.1 kB, Python 3)

Hashes for timeseries_agent-0.0.21-py3-none-any.whl:

Algorithm    Hash digest
SHA256       c251b47eee239909a041afed643f61a5f0e1de23cd01f8d64ccb4ecf5eaa9d18
MD5          d24c45cb8357d5ed4bc2c0d8196a9558
BLAKE2b-256  e20bdb16e6f20a21860437ddcbdb63699f564da825d019e8eafb29efc4df08f8
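To check a downloaded wheel against the SHA256 digest listed for this release, the file can be hashed locally. `sha256_of` and `verify_wheel` are hypothetical helpers; the expected value is the SHA256 digest published for timeseries_agent-0.0.21-py3-none-any.whl.

```python
import hashlib

EXPECTED_SHA256 = "c251b47eee239909a041afed643f61a5f0e1de23cd01f8d64ccb4ecf5eaa9d18"


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Hash a file in chunks so large files need not fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_wheel(path: str, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.lower()


# Hypothetical usage after downloading the wheel:
# verify_wheel("timeseries_agent-0.0.21-py3-none-any.whl")
```

pip can enforce the same check at install time through hash-checking mode, using a requirements line such as `timeseries_agent==0.0.21 --hash=sha256:c251b47eee239909a041afed643f61a5f0e1de23cd01f8d64ccb4ecf5eaa9d18`.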

