A Policy Gradient RL agent for time series prediction using PyTorch Lightning.
TimeSeries Agent
A Python package implementing a Policy Gradient Reinforcement Learning agent for predicting directional movements (Up, Down, Same) in time series data. This package uses PyTorch Lightning for structuring the training process.
Features
- Implements a basic Policy Gradient algorithm.
- Uses a flexible Multi-Layer Perceptron (MLP) as the policy network.
- Configurable lookback window, hidden layers, and normalization.
- Handles multi-feature time series data.
- Uses an epsilon-greedy strategy with a decaying epsilon to balance exploration and exploitation.
- Built with PyTorch Lightning for easier training management and logging (via CSVLogger).
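The epsilon-greedy feature can be illustrated with a small standalone sketch. It assumes a linear decay schedule and arbitrary default values; the package's actual schedule and defaults are not documented here, so treat both as assumptions. Only the parameter names (epsilon_start, epsilon_end, epsilon_decay_epochs) come from the Configuration section.

```python
import random


def linear_epsilon(epoch, epsilon_start=1.0, epsilon_end=0.05, epsilon_decay_epochs=10):
    """Anneal epsilon linearly from epsilon_start to epsilon_end, then hold it.

    The linear shape and the default values are assumptions for illustration.
    """
    if epsilon_decay_epochs <= 0:
        return epsilon_end
    fraction = min(epoch / epsilon_decay_epochs, 1.0)
    return epsilon_start + fraction * (epsilon_end - epsilon_start)


def select_action(action_probs, epsilon, rng=random):
    """With probability epsilon pick a random action, otherwise the most probable one."""
    if rng.random() < epsilon:
        return rng.randrange(len(action_probs))  # explore
    return max(range(len(action_probs)), key=lambda i: action_probs[i])  # exploit


# Early epochs explore often; later epochs mostly follow the policy.
for epoch in (0, 5, 10, 20):
    print(epoch, round(linear_epsilon(epoch), 3))
```

With this schedule, epsilon stops changing once epsilon_decay_epochs is reached, so later epochs explore at the fixed rate epsilon_end.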
Installation
pip install timeseries_agent
You may also need to install the following dependencies:
pip install lightning torch
Usage
Using the Example Notebook
The package includes an example Jupyter notebook (examples/example_timeseries_agent.ipynb) that demonstrates how to:
- Set up and train a reinforcement learning agent
- Load a trained model and test it with simulated live data
To run the notebook:
- Install Jupyter if you haven't already:
pip install jupyter
- Navigate to the examples directory:
cd examples
- Launch Jupyter:
jupyter notebook
- Open example_timeseries_agent.ipynb
Using the Example Python Scripts
There are two example scripts provided: example_train.py and example_test.py.
Training
example_train.py is used to train the agent.
python examples/example_train.py
This script:
- Generates sample multi-feature time series data using Pandas.
- Creates an RLTimeSeriesDataset and DataLoader.
- Instantiates the PolicyGradientAgent.
- Trains the agent using PyTorch Lightning Trainer.
- Logs training progress to logs/rl_agent/.
- Finally saves the checkpoint in the current directory.
- Optionally plots the training progress.
Note about logging: The training process logs metrics to CSV files by default in the logs/rl_agent/ directory.
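To inspect the logged metrics without plotting, a short helper like the one below may be enough. It relies on Lightning's CSVLogger writing a metrics.csv file inside numbered version_N subdirectories. The helper name load_latest_metrics is hypothetical, and the column names in the file depend on what the training script logs, so none are assumed here.

```python
import csv
from pathlib import Path


def load_latest_metrics(log_dir="logs/rl_agent"):
    """Return the rows of metrics.csv from the highest-numbered version_N directory.

    Hypothetical helper: returns a list of dicts (one per CSV row), or an
    empty list when no metrics file is found.
    """
    versions = []
    for path in Path(log_dir).glob("version_*/metrics.csv"):
        suffix = path.parent.name.split("_", 1)[1]
        if suffix.isdigit():
            versions.append((int(suffix), path))
    if not versions:
        return []
    _, latest = max(versions)
    with latest.open(newline="") as handle:
        return list(csv.DictReader(handle))


rows = load_latest_metrics()
print(f"Loaded {len(rows)} logged rows")
if rows:
    print("Available columns:", list(rows[0].keys()))
```

All values come back as strings, so convert them with float() before computing anything from them.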
Testing
example_test.py is used to test/evaluate the trained agent.
python examples/example_test.py
This script:
- Generates sample multi-feature time series data using Pandas.
- Loads a trained agent from checkpoint.
- Evaluates the trained agent and prints results.
- Optionally plots the predicted actions vs true actions.
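As a rough illustration of what comparing predicted and true actions can look like, the sketch below computes accuracy and a confusion count from two lists of action indices. The function name and the index-to-label mapping (0 = Up, 1 = Down, 2 = Same) are assumptions made for this example; the example script may report different statistics.

```python
from collections import Counter

# Hypothetical mapping from action index to label; the package may order these differently.
ACTION_NAMES = {0: "Up", 1: "Down", 2: "Same"}


def summarize_predictions(predicted, actual):
    """Return (accuracy, confusion), where confusion counts (actual, predicted) pairs."""
    if len(predicted) != len(actual):
        raise ValueError("predicted and actual must have the same length")
    if not predicted:
        return 0.0, Counter()
    correct = sum(p == a for p, a in zip(predicted, actual))
    confusion = Counter(zip(actual, predicted))
    return correct / len(predicted), confusion


accuracy, confusion = summarize_predictions([0, 1, 2, 0], [0, 1, 1, 0])
print(f"accuracy = {accuracy:.2f}")
for (true_action, predicted_action), count in sorted(confusion.items()):
    print(f"true {ACTION_NAMES[true_action]:>4} -> predicted {ACTION_NAMES[predicted_action]:>4}: {count}")
```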
Using Your Own Data
- Prepare your data: Load your time series data into a Pandas DataFrame.
import pandas as pd
data_df = pd.read_csv("path/to/your/data.csv", index_col="timestamp", parse_dates=True)
data_df.sort_index(inplace=True)
- Adapt the example scripts:
  - Replace sample data generation with your data loading.
  - Update TARGET_COLUMN and NUM_FEATURES.
  - Adjust hyperparameters in PolicyGradientAgent and Trainer.
- Run your modified script:
python your_modified_script.py
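Before adapting the scripts, it can help to see how a lookback window relates to the direction being predicted. The sketch below uses plain Python lists for a single target series. The label encoding (0 = Up, 1 = Down, 2 = Same), the tolerance parameter, and both function names are assumptions for illustration, not the package's internal logic.

```python
def direction_label(previous, current, tolerance=0.0):
    """Encode the move from previous to current as 0 = Up, 1 = Down, 2 = Same.

    The encoding and the tolerance band are hypothetical choices.
    """
    change = current - previous
    if change > tolerance:
        return 0
    if change < -tolerance:
        return 1
    return 2


def make_windows(values, lookback):
    """Pair each window of `lookback` past values with the direction of the next step."""
    samples = []
    for end in range(lookback, len(values)):
        window = values[end - lookback:end]
        label = direction_label(values[end - 1], values[end])
        samples.append((window, label))
    return samples


target = [1.0, 2.0, 2.0, 1.5]
for window, label in make_windows(target, lookback=2):
    print(window, "->", label)
```

A series of length N yields N - lookback samples, so very short series combined with a long lookback leave little to train on.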
Configuration
Key parameters in the examples:
PolicyGradientAgent parameters:
- full_data: Pandas DataFrame.
- target_column: Target column name for reward.
- input_features: Number of features.
- lookback: Lookback window size.
- hidden_layers: MLP architecture (e.g., [128, 64]).
- output_size: Number of actions (default: 3).
- learning_rate: Optimizer learning rate.
- normalize_state: Normalize state within lookback window.
- epsilon_start, epsilon_end, epsilon_decay_epochs: Epsilon-greedy parameters.
Trainer parameters:
- max_epochs: Number of training epochs.
- accelerator, devices: Hardware configuration ('cpu', 'gpu', 'tpu').
- logger: Logging configuration (CSVLogger by default).
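To make the normalize_state option more concrete, here is one way normalization within a lookback window could work: each feature is z-scored using only the rows inside the window. Whether the package uses z-scoring, min-max scaling, or something else is not stated here, so the method, the eps guard, and the function name are all assumptions.

```python
from statistics import fmean, pstdev


def normalize_window(window, eps=1e-8):
    """Z-score each feature column using statistics from this window only.

    `window` is a list of rows, each row a list of feature values.
    Hypothetical sketch; eps avoids division by zero for constant features.
    """
    columns = list(zip(*window))
    means = [fmean(column) for column in columns]
    stds = [pstdev(column) for column in columns]
    return [
        [(value - mean) / (std + eps) for value, mean, std in zip(row, means, stds)]
        for row in window
    ]


# Two rows (lookback = 2) with two features; the second feature is constant.
print(normalize_window([[1.0, 10.0], [3.0, 10.0]]))
```

Normalizing per window keeps the state on a comparable scale even when the raw series drifts, at the cost of discarding the absolute level of each feature.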
License
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.