Accelerating Model Predictive Control via Neural Networks

TransformerMPC: Accelerating Model Predictive Control via Transformers [ICRA '25]

Vrushabh Zinage¹ · Ahmed Khalil¹ · Efstathios Bakolas¹

¹University of Texas at Austin

Overview

TransformerMPC improves the computational efficiency of solving Model Predictive Control (MPC) problems using transformer-based neural network models. It employs two prediction models:

  1. Constraint Predictor: Identifies inactive constraints in MPC formulations
  2. Warm Start Predictor: Generates better initial points for MPC solvers

By combining these models, TransformerMPC reduces computation time while maintaining solution quality: the solver receives a smaller problem (fewer constraints) and a starting point closer to the optimum.
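The effect of dropping inactive constraints can be illustrated on the small QP used in the usage examples of this README. The following is a minimal NumPy sketch, not the package's implementation. It assumes the problem has the form min 0.5 xᵀQx + cᵀx subject to Ax ≤ b (the README does not state the constraint convention), and it assumes a hypothetical predictor has already reported that only row 2 is active. The helper `solve_with_active_set` is an illustrative name and treats every kept row as an equality, which is only valid when all kept rows are in fact active.

```python
import numpy as np

# QP data from the usage examples; the A x <= b convention is an assumption.
Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])


def solve_with_active_set(Q, c, A, b, active):
    """Solve min 0.5 x'Qx + c'x with the rows in `active` held as equalities.

    Hypothetical helper: solves the KKT system
        [Q  A_act'] [x  ]   [-c   ]
        [A_act  0 ] [lam] = [b_act]
    """
    A_act, b_act = A[active], b[active]
    n, m = Q.shape[0], A_act.shape[0]
    kkt = np.block([[Q, A_act.T], [A_act, np.zeros((m, m))]])
    rhs = np.concatenate([-c, b_act])
    sol = np.linalg.solve(kkt, rhs)
    return sol[:n], sol[n:]


# Pretend a predictor flagged rows 0, 1 and 3 as inactive, leaving row 2.
predicted_active = [2]
x, lam = solve_with_active_set(Q, c, A, b, predicted_active)

# Verify the reduced solution against the FULL constraint set.
feasible = bool(np.all(A @ x <= b + 1e-9))   # no dropped constraint is violated
dual_ok = bool(np.all(lam >= -1e-9))         # multipliers have the right sign

print("x =", x, "feasible:", feasible, "dual feasible:", dual_ok)
```

If either check failed, the prediction was wrong and the full problem would have to be solved instead. A check of this kind is one plausible reading of the `fallback_on_violation` option shown later, which this README does not define.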

Package Structure

The repository follows a standard Python package layout:

transformermpc/
├── transformermpc/       # Core package module
│   ├── data/             # Data generation utilities
│   ├── models/           # Model implementations
│   ├── utils/            # Utility functions and metrics
│   ├── training/         # Training infrastructure
│   └── demo/             # Demo scripts
├── scripts/              # Demo and utility scripts
├── tests/                # Testing infrastructure
├── setup.py              # Package installation script
└── requirements.txt      # Dependencies

Installation

Install directly from PyPI:

pip install transformermpc

Or install from source:

git clone https://github.com/vrushabh/transformermpc.git
cd transformermpc
pip install -e .

Dependencies

  • Python >= 3.7
  • PyTorch >= 1.9.0
  • OSQP >= 0.6.2
  • NumPy, SciPy, and other scientific computing libraries
  • Additional dependencies specified in requirements.txt

Running the Demos

The package includes several demo scripts to showcase its capabilities:

Boxplot Demo (Recommended)

python scripts/boxplot_demo.py

This script provides a visual comparison of different QP solving strategies using randomly generated problems, without requiring model training. It demonstrates the core concepts behind TransformerMPC by showing the performance impact of:

  • Removing inactive constraints
  • Using warm starts with different qualities
  • Combining these strategies

The visualizations include boxplots, violin plots, and bar charts comparing solve times across different strategies.
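To see why warm-start quality matters for an iterative solver, the sketch below counts iterations from a poor and a good starting point. It is an illustration under stated assumptions, not what `boxplot_demo.py` does: it swaps in a plain projected-gradient method on a box-constrained QP in place of OSQP, and the "good" start is simulated by perturbing the known solution slightly. The function name `projected_gradient` and all problem data are made up for the example.

```python
import numpy as np


def projected_gradient(Q, c, lo, hi, x0, tol=1e-8, max_iter=10_000):
    """Minimize 0.5 x'Qx + c'x subject to lo <= x <= hi.

    Returns the final iterate and the number of iterations used.
    """
    step = 1.0 / np.linalg.norm(Q, 2)  # 1/L, with L the largest eigenvalue of Q
    x = np.clip(x0, lo, hi)
    for k in range(1, max_iter + 1):
        x_new = np.clip(x - step * (Q @ x + c), lo, hi)
        if np.linalg.norm(x_new - x) < tol:
            return x_new, k
        x = x_new
    return x, max_iter


Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([-1.0, -1.0])
lo, hi = np.zeros(2), np.ones(2)

# Cold start: a corner of the box, far from the optimum.
x_cold, iters_cold = projected_gradient(Q, c, lo, hi, x0=np.array([1.0, 1.0]))

# Warm start: simulated here by a small perturbation of the solution.
x_warm, iters_warm = projected_gradient(
    Q, c, lo, hi, x0=x_cold + np.array([0.01, -0.01])
)

print(f"cold start: {iters_cold} iterations, warm start: {iters_warm} iterations")
```

Both runs reach the same minimizer; only the iteration count differs, which is the quantity a learned warm-start predictor tries to reduce.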

Simple Demo

python scripts/simple_demo.py

This script demonstrates the complete pipeline: generating QP problems, training models, and evaluating performance. After completion, it saves performance comparison plots in the demo_results/results directory.

Verify Package Structure

To check that the package is installed correctly:

python scripts/verify_structure.py

Customizing Demo Parameters

Both demos accept command-line flags. For the boxplot demo:

# Generate more problems with different dimensions
python scripts/boxplot_demo.py --num_samples 50 --state_dim 6 --input_dim 3 --horizon 10

# Save results to a custom directory 
python scripts/boxplot_demo.py --output_dir my_results

Similarly, for the simple demo:

# Generate QP problems with specific parameters
python scripts/simple_demo.py --num_samples 200 --state_dim 6 --input_dim 3 --horizon 10

# Customize training parameters
python scripts/simple_demo.py --epochs 20 --batch_size 32

# Use GPU for training if available
python scripts/simple_demo.py --use_gpu
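The `--state_dim`, `--input_dim`, and `--horizon` flags determine the size of the underlying MPC problem. How the package's data generator builds its QPs is not described here; one common construction is the condensed formulation, in which the states are eliminated and the QP has `horizon * input_dim` decision variables. The sketch below assumes linear dynamics x[k+1] = A x[k] + B u[k] with a quadratic stage cost. The function `condensed_mpc_qp` and the random system are hypothetical and only meant to show how the three dimensions combine.

```python
import numpy as np


def condensed_mpc_qp(A, B, Qx, Ru, horizon, x0):
    """Return H, f such that the MPC cost equals 0.5 U'HU + f'U + const.

    Cost: 0.5 * sum_{k=1..N} x_k' Qx x_k + 0.5 * sum_{k=0..N-1} u_k' Ru u_k,
    where U stacks u_0, ..., u_{N-1}.
    """
    n, m = B.shape
    # Stacked prediction X = Phi x0 + Gamma U, with X = [x_1; ...; x_N].
    Phi = np.vstack([np.linalg.matrix_power(A, k + 1) for k in range(horizon)])
    Gamma = np.zeros((horizon * n, horizon * m))
    for i in range(horizon):
        for j in range(i + 1):
            block = np.linalg.matrix_power(A, i - j) @ B
            Gamma[i * n:(i + 1) * n, j * m:(j + 1) * m] = block
    Qbar = np.kron(np.eye(horizon), Qx)
    Rbar = np.kron(np.eye(horizon), Ru)
    H = Gamma.T @ Qbar @ Gamma + Rbar
    f = Gamma.T @ Qbar @ Phi @ x0
    return H, f


# Dimensions matching the flags in the commands above.
state_dim, input_dim, horizon = 6, 3, 10

rng = np.random.default_rng(0)
A_dyn = 0.3 * rng.standard_normal((state_dim, state_dim))  # made-up dynamics
B_dyn = rng.standard_normal((state_dim, input_dim))
Qx, Ru = np.eye(state_dim), np.eye(input_dim)
x0 = rng.standard_normal(state_dim)

H, f = condensed_mpc_qp(A_dyn, B_dyn, Qx, Ru, horizon, x0)
print("QP variables:", H.shape[0])  # horizon * input_dim
```

Larger horizons or input dimensions therefore grow the QP quadratically in the number of Hessian entries, which is where reducing constraints and iterations pays off.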

Usage in Projects

Basic Example

from transformermpc import TransformerMPC
import numpy as np

# Define your QP problem parameters
Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])

# Initialize the TransformerMPC solver
solver = TransformerMPC()

# Solve with model acceleration
solution, solve_time = solver.solve(Q, c, A, b)

print(f"Solution: {solution}")
print(f"Solve time: {solve_time} seconds")

General Usage

from transformermpc import TransformerMPC, QPProblem
import numpy as np

# Define your QP problem parameters
Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])
initial_state = np.array([0.5, 0.5])  # Optional: initial state for MPC problems

# Create a QP problem instance
qp_problem = QPProblem(
    Q=Q,
    c=c,
    A=A,
    b=b,
    initial_state=initial_state  # Optional
)

# Initialize with custom settings
solver = TransformerMPC(
    use_constraint_predictor=True,
    use_warm_start_predictor=True,
    fallback_on_violation=True
)

# Solve the problem
solution, solve_time = solver.solve(qp_problem=qp_problem)
print(f"Solution: {solution}")
print(f"Solve time: {solve_time} seconds")

# Compare with baseline
baseline_solution, baseline_time = solver.solve_baseline(qp_problem=qp_problem)
print(f"Baseline time: {baseline_time} seconds")
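When comparing the accelerated and baseline solves, it helps to look at solution quality alongside timing. The helper below is a small sketch that does not call the TransformerMPC API: `compare_runs` is a hypothetical function, and the solutions and timings passed to it are placeholder values chosen for illustration, not measurements. In practice they would come from `solver.solve` and `solver.solve_baseline`.

```python
import numpy as np


def qp_objective(Q, c, x):
    """Value of 0.5 x'Qx + c'x."""
    return 0.5 * x @ Q @ x + c @ x


def compare_runs(Q, c, x_fast, t_fast, x_base, t_base):
    """Report the speedup and the relative objective gap between two solves."""
    f_fast = qp_objective(Q, c, x_fast)
    f_base = qp_objective(Q, c, x_base)
    gap = abs(f_fast - f_base) / max(abs(f_base), 1e-12)
    return {"speedup": t_base / t_fast, "rel_objective_gap": gap}


Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])

# Placeholder inputs (not measured results).
x_base, t_base = np.array([0.25, 0.75]), 2e-3
x_fast, t_fast = np.array([0.2501, 0.7499]), 5e-4

report = compare_runs(Q, c, x_fast, t_fast, x_base, t_base)
print(report)
```

A large speedup is only meaningful if the objective gap stays small and the accelerated solution remains feasible for the original constraints.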

If you find our work useful, please cite:

@article{zinage2024transformermpc,
  title={TransformerMPC: Accelerating Model Predictive Control via Transformers},
  author={Zinage, Vrushabh and Khalil, Ahmed and Bakolas, Efstathios},
  journal={arXiv preprint arXiv:2409.09266},
  year={2024}
}

License

This project is licensed under the MIT License.
