Accelerating Model Predictive Control via Transformers
TransformerMPC
Overview
TransformerMPC is a Python package that enhances the efficiency of solving Quadratic Programming (QP) problems in Model Predictive Control (MPC) using transformer-based neural networks. It employs two specialized transformer models:
- Constraint Predictor: Identifies inactive constraints in QP formulations
- Warm Start Predictor: Generates better initial points for QP solvers
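Together, these models aim to reduce QP solve time while maintaining solution quality: constraints predicted to be inactive need not burden the solver, and a better initial point can reduce the number of solver iterations.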
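To make the idea behind the Constraint Predictor concrete, the sketch below solves a small QP of the form min 0.5 x^T Q x + c^T x subject to A x <= b using only the constraints assumed to be active, then checks that the dropped constraints still hold. It is a minimal NumPy illustration, not the package's implementation: the helper name `solve_reduced_qp`, the `A x <= b` convention, and the hand-picked active set are all assumptions made for this example.

```python
import numpy as np

def solve_reduced_qp(Q, c, A, b, active):
    """Solve min 0.5 x^T Q x + c^T x keeping only rows `active` of A x <= b,
    treated as equalities (hypothetical helper, for illustration only)."""
    A_act = A[active]
    b_act = b[active]
    n, k = Q.shape[0], len(active)
    # KKT system: [Q  A_act^T; A_act  0] [x; lam] = [-c; b_act]
    kkt = np.block([[Q, A_act.T], [A_act, np.zeros((k, k))]])
    rhs = np.concatenate([-c, b_act])
    sol = np.linalg.solve(kkt, rhs)
    return sol[:n], sol[n:]

Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])

# Assume a predictor marked every constraint except row 2 as inactive.
active = [2]
x, lam = solve_reduced_qp(Q, c, A, b, active)

# The reduced solution is valid only if the dropped constraints still hold
# and the multipliers of the kept constraints are non-negative.
is_feasible = bool(np.all(A @ x <= b + 1e-9))
is_optimal = is_feasible and bool(np.all(lam >= -1e-9))
print(x, lam, is_optimal)
```

Here the reduced problem has one constraint instead of four, yet its solution satisfies the KKT conditions of the full problem. If the predicted active set were wrong, the feasibility check would fail and a full solve would be needed.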
Installation
Install directly from PyPI:
pip install transformermpc
Or install from source:
git clone https://github.com/vrushabh/transformermpc.git
cd transformermpc
pip install -e .
Dependencies
- Python >= 3.7
- PyTorch >= 1.9.0
- Transformers >= 4.15.0
- OSQP >= 0.6.2
- Additional dependencies specified in requirements.txt
Running the Demo
The package includes a simplified demo script that demonstrates the complete workflow:
python simple_demo.py
This script performs the entire pipeline: generating QP problems, training models, and evaluating performance. After completion, it saves performance comparison plots in the demo_results/results directory.
Additional Scripts
The package also includes these utility scripts:
- run_demo.py: A wrapper script that executes the main demo module. Use it when working with the installed package:
  python run_demo.py
- test_benchmark.py: For comprehensive performance evaluation with more metrics and visualizations:
  python test_benchmark.py
  This script focuses on benchmarking performance across multiple problems and generates detailed visualizations, including boxplots comparing solve times between the standard OSQP solver and transformer-enhanced methods.
All scripts save their results (including performance metrics and visualization plots) in the demo_results/results directory. The key visualizations include:
- solve_time_comparison.png: Line plot comparing baseline vs. transformer solve times
- solve_time_boxplot.png: Statistical distribution of solve times across different solver configurations
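For readers who want to collect comparable timing statistics for their own solvers, the following is a minimal sketch of a timing harness that produces the five-number summary a boxplot is drawn from. It does not reproduce `test_benchmark.py`; the function names `time_solver` and `summarize` are hypothetical, and the workload is a stand-in (sorting lists) rather than a QP solve. `statistics.quantiles` requires Python 3.8 or newer.

```python
import statistics
import time

def time_solver(solve_fn, problems, repeats=3):
    """Return the best-of-`repeats` wall-clock time (seconds) for each problem."""
    times = []
    for problem in problems:
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            solve_fn(problem)
            best = min(best, time.perf_counter() - start)
        times.append(best)
    return times

def summarize(times):
    """Five-number summary of a list of timings (the data behind a boxplot)."""
    q1, median, q3 = statistics.quantiles(times, n=4)
    return {"min": min(times), "q1": q1, "median": median,
            "q3": q3, "max": max(times)}

# Stand-in workload: replace `sorted` with a function that solves one problem.
problems = [list(range(n, 0, -1)) for n in (100, 200, 300, 400, 500)]
summary = summarize(time_solver(sorted, problems))
print(summary)
```

Taking the best of several repeats reduces noise from the operating system scheduler; running the same harness on a baseline solver and an accelerated one gives two summaries that can be compared directly.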
Customizing Demo Parameters
You can customize the demo by modifying the following parameters:
Data Generation Parameters:
# Generate 5000 QP problems with state dimension 6, input dimension 3, and horizon 15
python simple_demo.py --num_samples 5000 --state_dim 6 --input_dim 3 --horizon 15
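To show how the state dimension, input dimension, and horizon determine the size of the resulting QP, here is a sketch of the standard condensed MPC formulation for linear dynamics x_{k+1} = A_dyn x_k + B_dyn u_k with quadratic stage costs. The package's own problem generator may use a different formulation; the function `condensed_mpc_qp`, the names `A_dyn`, `B_dyn`, `Q_cost`, `R_cost`, and the random system are assumptions for illustration.

```python
import numpy as np

def condensed_mpc_qp(A_dyn, B_dyn, Q_cost, R_cost, x0, horizon):
    """Build H, f so that the MPC cost equals 0.5 U^T H U + f^T U + const,
    where U stacks the inputs u_0, ..., u_{N-1} (hypothetical helper)."""
    n, m = B_dyn.shape
    # Stacked states x_1..x_N satisfy X = Sx @ x0 + Su @ U.
    Sx = np.vstack([np.linalg.matrix_power(A_dyn, k + 1) for k in range(horizon)])
    Su = np.zeros((horizon * n, horizon * m))
    for row in range(horizon):
        for col in range(row + 1):
            block = np.linalg.matrix_power(A_dyn, row - col) @ B_dyn
            Su[row * n:(row + 1) * n, col * m:(col + 1) * m] = block
    Q_bar = np.kron(np.eye(horizon), Q_cost)
    R_bar = np.kron(np.eye(horizon), R_cost)
    H = Su.T @ Q_bar @ Su + R_bar
    f = Su.T @ Q_bar @ Sx @ x0
    return H, f

# Dimensions taken from the command above; the system itself is random.
rng = np.random.default_rng(0)
state_dim, input_dim, horizon = 6, 3, 15
A_dyn = 0.2 * rng.standard_normal((state_dim, state_dim))
B_dyn = rng.standard_normal((state_dim, input_dim))
x0 = rng.standard_normal(state_dim)

H, f = condensed_mpc_qp(A_dyn, B_dyn, np.eye(state_dim), np.eye(input_dim), x0, horizon)
print(H.shape, f.shape)  # (45, 45) (45,)
```

In this formulation the QP has horizon × input_dim decision variables, so increasing either `--horizon` or `--input_dim` enlarges every generated problem.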
Training Parameters:
# Train the constraint predictor for 200 epochs and warm start predictor for 300 epochs
python simple_demo.py --cp_epochs 200 --ws_epochs 300 --batch_size 128
The number of epochs and the number of samples strongly affect training time and model performance:
- --cp_epochs: Number of training epochs for the Constraint Predictor model
- --ws_epochs: Number of training epochs for the Warm Start Predictor model
- --num_samples: Number of QP problems to generate for training
Increasing these values will generally improve model accuracy but require more computation time. For quick experimentation, use lower values (e.g., 50-100 epochs, 1000-2000 samples). For production-quality models, consider higher values (300+ epochs, 5000+ samples).
Hardware Options:
# Use GPU for training if available
python simple_demo.py --use_gpu
Other Options:
# Skip training and use pre-trained models if available
python simple_demo.py --skip_training
# Specify custom output directory
python simple_demo.py --output_dir custom_results
All Available Options
| Parameter | Description | Default |
|---|---|---|
| --num_samples | Number of QP problems to generate | 2000 |
| --state_dim | State dimension for MPC problems | 4 |
| --input_dim | Input dimension for MPC problems | 2 |
| --horizon | Time horizon for MPC problems | 10 |
| --cp_epochs | Epochs for constraint predictor training | 100 |
| --ws_epochs | Epochs for warm start predictor training | 100 |
| --batch_size | Batch size for training | 64 |
| --test_size | Fraction of data to use for testing | 0.2 |
| --output_dir | Directory to save results | "demo_results" |
| --skip_training | Skip training and use pretrained models | False |
| --use_gpu | Use GPU if available | False |
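If you want to expose the same options in your own experiment script, the table can be mirrored with `argparse` as sketched below. This is an illustration built from the documented names and defaults only; whether `simple_demo.py` itself uses `argparse`, and how it declares these options internally, is an assumption.

```python
import argparse

def build_parser():
    """Parser mirroring the documented demo options (illustrative sketch)."""
    parser = argparse.ArgumentParser(description="TransformerMPC-style demo options")
    parser.add_argument("--num_samples", type=int, default=2000,
                        help="Number of QP problems to generate")
    parser.add_argument("--state_dim", type=int, default=4,
                        help="State dimension for MPC problems")
    parser.add_argument("--input_dim", type=int, default=2,
                        help="Input dimension for MPC problems")
    parser.add_argument("--horizon", type=int, default=10,
                        help="Time horizon for MPC problems")
    parser.add_argument("--cp_epochs", type=int, default=100,
                        help="Epochs for constraint predictor training")
    parser.add_argument("--ws_epochs", type=int, default=100,
                        help="Epochs for warm start predictor training")
    parser.add_argument("--batch_size", type=int, default=64,
                        help="Batch size for training")
    parser.add_argument("--test_size", type=float, default=0.2,
                        help="Fraction of data to use for testing")
    parser.add_argument("--output_dir", type=str, default="demo_results",
                        help="Directory to save results")
    # Boolean flags default to False and become True when present.
    parser.add_argument("--skip_training", action="store_true",
                        help="Skip training and use pretrained models")
    parser.add_argument("--use_gpu", action="store_true",
                        help="Use GPU if available")
    return parser

# Parse an explicit argument list so the example runs without a command line.
args = build_parser().parse_args(["--num_samples", "5000", "--horizon", "15", "--use_gpu"])
print(args.num_samples, args.horizon, args.use_gpu, args.skip_training)
```

Unspecified options keep the defaults from the table, so the example above still uses a state dimension of 4 and a batch size of 64.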
Usage in Projects
Basic Example
from transformermpc import TransformerMPC
import numpy as np
# Define your QP problem parameters
Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])
# Initialize the TransformerMPC solver
solver = TransformerMPC()
# Note: By default, this uses pre-trained models that come with the package.
# To train custom models with different epochs and samples, use the demo script
# or the training utilities as described in the "Customizing Demo Parameters" section.
# Solve with transformer acceleration
solution, solve_time = solver.solve(Q, c, A, b)
print(f"Solution: {solution}")
print(f"Solve time: {solve_time} seconds")
Advanced Usage
from transformermpc import TransformerMPC, QPProblem
import numpy as np
# Define your QP problem parameters
Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])
initial_state = np.array([0.5, 0.5])  # Optional: initial state for MPC problems
# Create a QP problem instance
qp_problem = QPProblem(
    Q=Q,
    c=c,
    A=A,
    b=b,
    initial_state=initial_state  # Optional
)
# Initialize with custom settings
solver = TransformerMPC(
    use_constraint_predictor=True,
    use_warm_start_predictor=True,
    fallback_on_violation=True
)
# Solve the problem
solution, solve_time = solver.solve(qp_problem=qp_problem)
print(f"Solution: {solution}")
print(f"Solve time: {solve_time} seconds")
# Compare with baseline
baseline_solution, baseline_time = solver.solve_baseline(qp_problem=qp_problem)
print(f"Baseline time: {baseline_time} seconds")
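The `fallback_on_violation` option suggests a pattern in which an accelerated solution is accepted only if it satisfies the original constraints. The sketch below shows one way such a check could work; it is not taken from the package's source. The helper `solve_with_fallback`, the `A x <= b` convention, the tolerance, and both stand-in solvers are assumptions for illustration.

```python
import numpy as np

def solve_with_fallback(fast_solve, baseline_solve, A, b, tol=1e-6):
    """Return (x, used_fallback). Accept fast_solve() unless it violates
    A x <= b by more than `tol` (hypothetical helper)."""
    x = fast_solve()
    violation = float(np.max(A @ x - b))
    if violation > tol:
        return baseline_solve(), True
    return x, False

Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])

# Stand-in "fast" solver: the unconstrained minimizer, which is what results
# if every constraint is wrongly predicted inactive.
fast = lambda: np.linalg.solve(Q, -c)
# Stand-in baseline: the constrained optimum of this QP, which satisfies
# the KKT conditions with x_1 + x_2 = 1 active.
baseline = lambda: np.array([0.25, 0.75])

x, used_fallback = solve_with_fallback(fast, baseline, A, b)
print(x, used_fallback)
```

The check costs one matrix-vector product, which is small compared with a QP solve, so guarding the accelerated path this way adds little overhead when the prediction is correct.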
License
This project is licensed under the MIT License - see the LICENSE file for details.
File details
Details for the file transformermpc-0.1.3.tar.gz.
File metadata
- Download URL: transformermpc-0.1.3.tar.gz
- Upload date:
- Size: 49.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 1409786e5b44044ba833b4d6bcaf60e8b508f35316601aa83b58b28e51a28799 |
| MD5 | 38e5acd29638d15113bfde1b6c1354f9 |
| BLAKE2b-256 | 3538fe42dcae211659e108ec802a0f8f1d286ee4e997cf0199adbf5ce395d395 |
File details
Details for the file transformermpc-0.1.3-py3-none-any.whl.
File metadata
- Download URL: transformermpc-0.1.3-py3-none-any.whl
- Upload date:
- Size: 60.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 8d9cbb62a961dd51a43b8f99a5756d5b2e9769273d1213e48f7dea2179dec0b6 |
| MD5 | 663e379ca21ab304a8ce61dc1020fe2c |
| BLAKE2b-256 | 6329a3b11019a6425f11e09924e0a0edbed2a8aa024528454058fa1da867d8a9 |