Accelerating Model Predictive Control via Neural Networks
TransformerMPC: Accelerating Model Predictive Control via Transformers [ICRA '25]
Vrushabh Zinage¹ · Ahmed Khalil¹ · Efstathios Bakolas¹
¹University of Texas at Austin
Overview
TransformerMPC speeds up Model Predictive Control (MPC) by accelerating the quadratic programs (QPs) solved at each control step, using transformer-based prediction models. It employs two of them:
- Constraint Predictor: predicts which inequality constraints will be inactive at the solution, so they can be removed before solving
- Warm Start Predictor: predicts an initial point (warm start) for the QP solver that is closer to the solution than a default start
Used together, the two models reduce solve time while maintaining solution quality.
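To make the constraint-prediction idea concrete, here is a minimal NumPy sketch; it is not the package's implementation. It assumes the QP form minimize ½xᵀQx + cᵀx subject to Ax ≤ b with Q positive definite, and takes the extreme case where the predictor returns exactly the active constraints, so the QP collapses to a single linear KKT solve. The result is accepted only if it satisfies every original constraint and has nonnegative multipliers; otherwise the sketch falls back to an exhaustive active-set search, which is only practical for tiny problems and stands in for a full solver such as OSQP. The function names are hypothetical.

```python
import itertools
import numpy as np

def solve_with_active_set(Q, c, A, b, active):
    """Solve min 0.5 x'Qx + c'x with rows A[active] x = b[active] as equalities.

    Returns the primal point x and the multipliers of the active rows.
    """
    if len(active) == 0:
        return np.linalg.solve(Q, -c), np.zeros(0)
    A_act, b_act = A[active], b[active]
    n, k = Q.shape[0], A_act.shape[0]
    # KKT system: [[Q, A_act^T], [A_act, 0]] [x; lam] = [-c; b_act]
    kkt = np.block([[Q, A_act.T], [A_act, np.zeros((k, k))]])
    sol = np.linalg.solve(kkt, np.concatenate([-c, b_act]))
    return sol[:n], sol[n:]

def is_kkt_point(A, b, x, lam, tol=1e-9):
    """Check primal feasibility on ALL rows (including removed ones) and lam >= 0."""
    return bool(np.all(A @ x <= b + tol) and np.all(lam >= -tol))

def solve_with_prediction(Q, c, A, b, predicted_active, tol=1e-9):
    """Try the predicted active set; fall back to enumerating active sets if it fails."""
    x, lam = solve_with_active_set(Q, c, A, b, list(predicted_active))
    if is_kkt_point(A, b, x, lam, tol):
        return x, True  # prediction certified with a single linear solve
    # Fallback for tiny problems: try every subset of linearly independent rows
    m = A.shape[0]
    for size in range(m + 1):
        for subset in itertools.combinations(range(m), size):
            rows = list(subset)
            if size and np.linalg.matrix_rank(A[rows]) < size:
                continue  # dependent rows would make the KKT matrix singular
            x, lam = solve_with_active_set(Q, c, A, b, rows)
            if is_kkt_point(A, b, x, lam, tol):
                return x, False
    raise ValueError("no KKT point found; the QP may be infeasible")

# Same data as the Basic Example (constraints: x1 >= 0, x2 >= 0, 1 <= x1 + x2 <= 2)
Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])

print(solve_with_prediction(Q, c, A, b, [2]))  # correct prediction
print(solve_with_prediction(Q, c, A, b, [0]))  # wrong prediction, triggers the fallback
```

With the correct prediction (row 2, i.e. x₁ + x₂ ≥ 1) the point is certified after one linear solve; the wrong prediction (row 0) produces a point that violates x₂ ≥ 0, so the fallback runs and reaches the same optimum.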
Package Structure
The package follows a standard Python package layout:
transformermpc/
├── transformermpc/ # Core package module
│ ├── data/ # Data generation utilities
│ ├── models/ # Model implementations
│ ├── utils/ # Utility functions and metrics
│ ├── training/ # Training infrastructure
│ └── demo/ # Demo scripts
├── scripts/ # Demo and utility scripts
├── tests/ # Testing infrastructure
├── setup.py # Package installation script
└── requirements.txt # Dependencies
Installation
Install directly from PyPI:
pip install transformermpc
Or install from source:
git clone https://github.com/Vrushabh27/transformermpc.git
cd transformermpc
pip install -e .
Dependencies
- Python >= 3.7
- PyTorch >= 1.9.0
- OSQP >= 0.6.2
- NumPy, SciPy, and other scientific computing libraries
- Additional dependencies specified in requirements.txt
Running the Demos
The package includes several demo scripts to showcase its capabilities:
Boxplot Demo (Recommended)
python scripts/boxplot_demo.py
This script provides a visual comparison of different QP solving strategies using randomly generated problems, without requiring model training. It demonstrates the core concepts behind TransformerMPC by showing the performance impact of:
- Removing inactive constraints
- Using warm starts of varying quality
- Combining these strategies
The visualizations include boxplots, violin plots, and bar charts comparing solve times across different strategies.
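Why warm-start quality matters can be seen with any iterative QP solver. The sketch below is a self-contained NumPy illustration, not the demo script: it runs plain projected gradient descent on a random box-constrained QP (a stand-in for OSQP, which the package lists as a dependency) and counts the iterations needed from warm starts of decreasing quality. The problem generator and `projected_gradient_qp` are hypothetical.

```python
import numpy as np

def projected_gradient_qp(Q, c, lo, hi, x0, tol=1e-8, max_iter=10_000):
    """Minimize 0.5 x'Qx + c'x subject to lo <= x <= hi by projected gradient descent.

    Returns the final iterate and the number of iterations taken.
    """
    L = np.linalg.eigvalsh(Q).max()  # step size 1/L, with L the largest eigenvalue of Q
    x = np.clip(np.asarray(x0, dtype=float), lo, hi)  # project the warm start onto the box
    for k in range(1, max_iter + 1):
        x_next = np.clip(x - (Q @ x + c) / L, lo, hi)  # gradient step, then projection
        if np.linalg.norm(x_next - x) < tol:
            return x_next, k
        x = x_next
    return x, max_iter

# Hypothetical random problem: strictly convex objective, box constraints -1 <= x <= 1
rng = np.random.default_rng(0)
n = 20
M = rng.standard_normal((n, n))
Q = M @ M.T / n + np.eye(n)
c = 3.0 * rng.standard_normal(n)
lo, hi = -np.ones(n), np.ones(n)

# Reference solution, solved to a much tighter tolerance
x_star, _ = projected_gradient_qp(Q, c, lo, hi, np.zeros(n), tol=1e-12)

# Warm starts of decreasing quality, down to a cold start at zero
starts = {
    "exact": x_star,
    "noise 0.01": x_star + 0.01 * rng.standard_normal(n),
    "noise 0.5": x_star + 0.5 * rng.standard_normal(n),
    "cold": np.zeros(n),
}
iterations = {name: projected_gradient_qp(Q, c, lo, hi, x0)[1] for name, x0 in starts.items()}
for name, iters in iterations.items():
    print(f"{name:>10}: {iters} iterations")
```

The exact warm start stops after a single iteration while the cold start needs more; the perturbed starts show how the benefit depends on how close the guess is. Timings in the real demo also depend on OSQP's own settings.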
Simple Demo
python scripts/simple_demo.py
This script demonstrates the complete pipeline: generating QP problems, training models, and evaluating performance. After completion, it saves performance comparison plots in the demo_results/results directory.
Verify Package Structure
To check that the package is installed correctly:
python scripts/verify_structure.py
Customizing Demo Parameters
You can customize the boxplot demo with command-line arguments:
# Generate more problems with different dimensions
python scripts/boxplot_demo.py --num_samples 50 --state_dim 6 --input_dim 3 --horizon 10
# Save results to a custom directory
python scripts/boxplot_demo.py --output_dir my_results
Similarly, for the simple demo:
# Generate QP problems with specific parameters
python scripts/simple_demo.py --num_samples 200 --state_dim 6 --input_dim 3 --horizon 10
# Customize training parameters
python scripts/simple_demo.py --epochs 20 --batch_size 32
# Use GPU for training if available
python scripts/simple_demo.py --use_gpu
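The `--state_dim`, `--input_dim` and `--horizon` flags suggest that the demo QPs come from linear MPC problems. One standard way to obtain such a QP is the condensed formulation sketched below: the states are eliminated through the dynamics x_{k+1} = A x_k + B u_k, leaving the stacked inputs as the only decision variables. This is an assumption for illustration; the package's data generator may build its problems differently, and `mpc_to_qp` is a hypothetical helper. The output uses the Q, c, A, b naming of the usage examples, under the assumed form minimize ½UᵀQU + cᵀU subject to AU ≤ b.

```python
import numpy as np

def mpc_to_qp(A_dyn, B_dyn, Q_x, R_u, x0, horizon, u_max):
    """Condense a linear MPC problem with input bounds |u| <= u_max into QP data."""
    n, m = B_dyn.shape
    N = horizon
    powers = [np.linalg.matrix_power(A_dyn, k) for k in range(N + 1)]  # A^0 .. A^N
    # Stacked predictions X = Sx @ x0 + Su @ U, where X = [x_1; ...; x_N]
    Sx = np.vstack([powers[k + 1] for k in range(N)])
    Su = np.zeros((N * n, N * m))
    for k in range(N):          # block row k holds x_{k+1}
        for j in range(k + 1):  # x_{k+1} depends on u_j through A^(k-j) B
            Su[k * n:(k + 1) * n, j * m:(j + 1) * m] = powers[k - j] @ B_dyn
    Q_bar = np.kron(np.eye(N), Q_x)
    R_bar = np.kron(np.eye(N), R_u)
    # sum_k x_{k+1}' Q_x x_{k+1} + u_k' R_u u_k  =  0.5 U'QU + c'U + const
    Q = 2.0 * (Su.T @ Q_bar @ Su + R_bar)
    c = 2.0 * Su.T @ Q_bar @ Sx @ x0
    # Input bounds -u_max <= U <= u_max written as A U <= b
    A = np.vstack([np.eye(N * m), -np.eye(N * m)])
    b = np.full(2 * N * m, float(u_max))
    return Q, c, A, b

# Dimensions matching --state_dim 6 --input_dim 3 --horizon 10
rng = np.random.default_rng(0)
n, m, N = 6, 3, 10
A_dyn = 0.3 * rng.standard_normal((n, n))
B_dyn = rng.standard_normal((n, m))
x0 = rng.standard_normal(n)
Q, c, A, b = mpc_to_qp(A_dyn, B_dyn, np.eye(n), 0.1 * np.eye(m), x0, N, u_max=1.0)
print(Q.shape, c.shape, A.shape, b.shape)  # (30, 30) (30,) (60, 30) (60,)
```

Condensing keeps the QP small (N·m variables) at the cost of a dense Q; a sparse formulation that keeps the states as variables is the usual alternative when the horizon is long.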
Usage in Projects
Basic Example
from transformermpc import TransformerMPC
import numpy as np
# Define your QP problem parameters
Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])
# Initialize the TransformerMPC solver
solver = TransformerMPC()
# Solve with model acceleration
solution, solve_time = solver.solve(Q, c, A, b)
print(f"Solution: {solution}")
print(f"Solve time: {solve_time} seconds")
General Usage
from transformermpc import TransformerMPC, QPProblem
import numpy as np
# Define your QP problem parameters
Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])
initial_state = np.array([0.5, 0.5]) # Optional: initial state for MPC problems
# Create a QP problem instance
qp_problem = QPProblem(
Q=Q,
c=c,
A=A,
b=b,
initial_state=initial_state # Optional
)
# Initialize with custom settings
solver = TransformerMPC(
use_constraint_predictor=True,
use_warm_start_predictor=True,
fallback_on_violation=True
)
# Solve the problem
solution, solve_time = solver.solve(qp_problem=qp_problem)
print(f"Solution: {solution}")
print(f"Solve time: {solve_time} seconds")
# Compare with baseline
baseline_solution, baseline_time = solver.solve_baseline(qp_problem=qp_problem)
print(f"Baseline time: {baseline_time} seconds")
If you find our work useful, please cite us:
@article{zinage2024transformermpc,
title={TransformerMPC: Accelerating Model Predictive Control via Transformers},
author={Zinage, Vrushabh and Khalil, Ahmed and Bakolas, Efstathios},
journal={arXiv preprint arXiv:2409.09266},
year={2024}
}
License
This project is licensed under the MIT License.
Download files
Source Distribution: transformermpc-0.1.7.tar.gz
Built Distribution: transformermpc-0.1.7-py3-none-any.whl
File details
Details for the file transformermpc-0.1.7.tar.gz.
File metadata
- Download URL: transformermpc-0.1.7.tar.gz
- Upload date:
- Size: 34.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 194eb4db3d3b524eaec293a4af61c1e875b53a1b334686be63d208d3231cc40f |
| MD5 | 073776f8e58fb901b7e1fde0fd7e74c2 |
| BLAKE2b-256 | 61f175c773043f564542a22c7b60bf2f9fa5c332b08c696a58d8ffb78e3129b4 |
File details
Details for the file transformermpc-0.1.7-py3-none-any.whl.
File metadata
- Download URL: transformermpc-0.1.7-py3-none-any.whl
- Upload date:
- Size: 34.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.11.5
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 6ca938480bd05098a660ea7bd3fdd974185038a3a7f0ad416b871d60c0137202 |
| MD5 | b28f4c15d6961826bc01e7f4954b30d4 |
| BLAKE2b-256 | beb5b8da4cfad43033531a129ff72950dfc0128152e2426edeb639f3d6112fc9 |