
Accelerating Model Predictive Control via Neural Networks


TransformerMPC: Accelerating Model Predictive Control via Transformers [ICRA '25]

Vrushabh Zinage¹ · Ahmed Khalil¹ · Efstathios Bakolas¹

¹University of Texas at Austin

arXiv · Project Page

Overview

TransformerMPC improves the computational efficiency of solving Model Predictive Control (MPC) problems with transformer-based neural network models. It uses two prediction models:

  1. Constraint Predictor: Identifies inactive constraints in MPC formulations
  2. Warm Start Predictor: Generates better initial points for MPC solvers

By combining these models, TransformerMPC significantly reduces computation time while maintaining solution quality.
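The constraint-removal idea can be shown on a tiny problem. The NumPy sketch below is not the package's implementation. It rests on three assumptions:

- The QP has the form minimize 0.5 x^T Q x + c^T x subject to A x <= b.
- A hand-written list of "predicted inactive" indices stands in for the Constraint Predictor.
- QPs are solved by enumerating active sets, which only works for a handful of constraints.

The helpers `solve_small_qp` and `solve_with_predicted_inactive` are hypothetical names. Dropping constraints relaxes the problem. So a reduced solution that satisfies every original constraint is also optimal for the full QP. If it does not, the sketch falls back to solving the full problem.

```python
from itertools import combinations

import numpy as np


def solve_small_qp(Q, c, A, b, tol=1e-9):
    """Reference solver for tiny QPs: min 0.5 x^T Q x + c^T x  s.t.  A x <= b.

    Tries candidate active sets, solves each equality-constrained KKT system,
    and returns the first point that is primal feasible with nonnegative
    multipliers. Assumes Q is positive definite, so that point is the unique
    optimum. Cost grows exponentially with the number of rows of A.
    """
    n, m = Q.shape[0], A.shape[0]
    for k in range(min(n, m) + 1):
        for active in combinations(range(m), k):
            idx = np.array(active, dtype=int)
            Aa = A[idx]
            # KKT system: [Q  Aa^T; Aa  0] [x; lam] = [-c; b_active]
            kkt = np.zeros((n + k, n + k))
            kkt[:n, :n] = Q
            kkt[:n, n:] = Aa.T
            kkt[n:, :n] = Aa
            rhs = np.concatenate([-c, b[idx]])
            try:
                sol = np.linalg.solve(kkt, rhs)
            except np.linalg.LinAlgError:
                continue  # linearly dependent active rows: skip this set
            x, lam = sol[:n], sol[n:]
            if np.all(A @ x <= b + tol) and np.all(lam >= -tol):
                return x
    raise ValueError("no KKT point found")


def solve_with_predicted_inactive(Q, c, A, b, inactive, tol=1e-9):
    """Drop constraints predicted inactive; fall back to the full QP if needed.

    Returns (x, fell_back). The reduced QP is a relaxation, so a reduced
    solution that satisfies all original constraints is optimal for the full QP.
    """
    drop = set(inactive)
    keep = np.array([i for i in range(A.shape[0]) if i not in drop], dtype=int)
    x = solve_small_qp(Q, c, A[keep], b[keep], tol)
    if np.all(A @ x <= b + tol):
        return x, False
    # A dropped constraint is violated: re-solve with every constraint.
    return solve_small_qp(Q, c, A, b, tol), True


# Data from the Basic Example further down this page.
Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])

# Good prediction: only row 2 (x1 + x2 >= 1) is kept; x is about [0.25, 0.75].
print(solve_with_predicted_inactive(Q, c, A, b, inactive=[0, 1, 3]))
# Bad prediction: row 2 is dropped, the reduced solution [0, 0] violates it,
# and the fallback recovers the same optimum.
print(solve_with_predicted_inactive(Q, c, A, b, inactive=[2]))
```

In this example, solving the reduced problem means solving one small linear system. The full problem needs a search over active sets. A real QP solver such as OSQP handles constraints differently, but the same relaxation argument justifies the fallback check.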

Package Structure

The package is organized with a standard Python package structure:

transformermpc/
├── transformermpc/       # Core package module
│   ├── data/             # Data generation utilities
│   ├── models/           # Model implementations
│   ├── utils/            # Utility functions and metrics
│   ├── training/         # Training infrastructure
│   └── demo/             # Demo scripts
├── scripts/              # Demo and utility scripts
├── tests/                # Testing infrastructure
├── setup.py              # Package installation script
└── requirements.txt      # Dependencies

Installation

Install directly from PyPI:

pip install transformermpc

Or install from source:

git clone https://github.com/Vrushabh27/transformermpc.git
cd transformermpc
pip install -e .

Dependencies

  • Python >= 3.7
  • PyTorch >= 1.9.0
  • OSQP >= 0.6.2
  • NumPy, SciPy, and other scientific computing libraries
  • Additional dependencies specified in requirements.txt
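To check an existing environment against the minimum versions listed above, a small standard-library script is enough. This is a sketch, not a script shipped with the package. It assumes the distribution names `torch` and `osqp` and uses `importlib.metadata`, which requires Python 3.8 or newer (the package itself lists Python >= 3.7). Version parsing keeps only the leading numeric components, so pre-release tags are ignored; the `packaging` library is more robust if it is available.

```python
import re
from importlib import metadata  # Python 3.8+

# Minimum versions taken from the dependency list above.
MINIMUMS = {"torch": "1.9.0", "osqp": "0.6.2"}


def version_tuple(version: str) -> tuple:
    """Leading numeric release components, e.g. '1.13.1+cu117' -> (1, 13, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that '2.0' compares equal to '2.0.0'.
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_dependencies(minimums=MINIMUMS) -> dict:
    report = {}
    for dist, minimum in minimums.items():
        try:
            installed = metadata.version(dist)
        except metadata.PackageNotFoundError:
            report[dist] = "missing"
            continue
        report[dist] = "ok" if meets_minimum(installed, minimum) else f"too old ({installed})"
    return report


if __name__ == "__main__":
    for name, status in check_dependencies().items():
        print(f"{name}: {status}")
```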

Running the Demos

The package includes several demo scripts to showcase its capabilities:

Boxplot Demo (Recommended)

python scripts/boxplot_demo.py

This script provides a visual comparison of different QP solving strategies using randomly generated problems, without requiring model training. It demonstrates the core concepts behind TransformerMPC by showing the performance impact of:

  • Removing inactive constraints
  • Using warm starts of varying quality
  • Combining these strategies

The visualizations include boxplots, violin plots, and bar charts comparing solve times across different strategies.
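The warm-start effect the demo measures can be reproduced in miniature without OSQP. The sketch below makes several simplifying assumptions. It uses plain projected gradient descent on a box-constrained QP, whereas OSQP itself is an ADMM solver and behaves differently. It counts iterations instead of wall-clock time, which keeps the comparison deterministic. The function name `projected_gradient_qp` and the random problem are illustrative only.

```python
import numpy as np


def projected_gradient_qp(Q, c, lower, upper, x0, tol=1e-8, max_iter=10_000):
    """Minimize 0.5 x^T Q x + c^T x subject to lower <= x <= upper.

    Projected gradient descent with step 1/L (L = largest eigenvalue of Q),
    started from x0. Returns (x, iterations) so that different starting
    points can be compared.
    """
    step = 1.0 / np.linalg.eigvalsh(Q).max()
    x = np.clip(np.asarray(x0, dtype=float), lower, upper)
    for it in range(1, max_iter + 1):
        x_new = np.clip(x - step * (Q @ x + c), lower, upper)
        if np.linalg.norm(x_new - x) < tol:
            return x_new, it
        x = x_new
    return x, max_iter


# Random strongly convex, box-constrained test problem.
rng = np.random.default_rng(0)
n = 20
M = rng.standard_normal((n, n))
Q = M @ M.T + n * np.eye(n)
c = 10.0 * rng.standard_normal(n)
lower, upper = -np.ones(n), np.ones(n)

# High-accuracy solve used to build warm starts of different quality.
x_star, _ = projected_gradient_qp(Q, c, lower, upper, np.zeros(n), tol=1e-12)
starts = {
    "cold start (zeros)": np.zeros(n),
    "noisy warm start": x_star + 0.1 * rng.standard_normal(n),
    "exact warm start": x_star,
}
for label, x0 in starts.items():
    _, iters = projected_gradient_qp(Q, c, lower, upper, x0)
    print(f"{label:>20}: {iters} iterations")
```

Starting from the exact solution terminates after a single step. The noisy start typically lands between the other two cases, which mirrors the "warm starts of varying quality" comparison in the demo.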

Simple Demo

python scripts/simple_demo.py

This script demonstrates the complete pipeline: generating QP problems, training models, and evaluating performance. After completion, it saves performance comparison plots in the demo_results/results directory.
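To train a constraint predictor, each generated QP needs labels saying which constraints were inactive at its solution. One plausible way to produce them is to threshold the slack b - A x*. This sketch assumes that problems have the form A x <= b and that a solver has already returned x*. The function `inactive_constraint_labels` and its tolerance are hypothetical, not taken from the package.

```python
import numpy as np


def inactive_constraint_labels(A, b, x_star, tol=1e-6):
    """Return a 0/1 vector with 1 where constraint i is inactive at x_star.

    A constraint a_i^T x <= b_i counts as inactive when its slack
    b_i - a_i^T x_star exceeds tol. For a convex QP, dropping such
    constraints leaves x_star optimal.
    """
    slack = b - A @ x_star
    if np.any(slack < -tol):
        raise ValueError("x_star violates at least one constraint")
    return (slack > tol).astype(np.int64)


# Constraints from the Basic Example below. Under the assumed convention
# min 0.5 x^T Q x + c^T x s.t. A x <= b, its optimum is x* = [0.25, 0.75].
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])
labels = inactive_constraint_labels(A, b, np.array([0.25, 0.75]))
print(labels)  # [1 1 0 1]: only x1 + x2 >= 1 is active
```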

Verify Package Structure

To check that the package is installed correctly:

python scripts/verify_structure.py

Customizing Demo Parameters

You can customize the boxplot demo with command-line arguments:

# Generate more problems with different dimensions
python scripts/boxplot_demo.py --num_samples 50 --state_dim 6 --input_dim 3 --horizon 10

# Save results to a custom directory 
python scripts/boxplot_demo.py --output_dir my_results

Similarly, for the simple demo:

# Generate QP problems with specific parameters
python scripts/simple_demo.py --num_samples 200 --state_dim 6 --input_dim 3 --horizon 10

# Customize training parameters
python scripts/simple_demo.py --epochs 20 --batch_size 32

# Use GPU for training if available
python scripts/simple_demo.py --use_gpu
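The `--state_dim`, `--input_dim`, and `--horizon` flags describe the size of the MPC problems. The sketch below shows how such sizes might map to a QP; it is an assumption, not the package's problem generator. It builds the standard condensed formulation for a linear system x_{k+1} = A x_k + B u_k, with quadratic stage costs and elementwise input bounds. The decision variables are the stacked inputs U = (u_0, ..., u_{N-1}). The function name `condensed_mpc_qp` and the random system are illustrative.

```python
import numpy as np


def condensed_mpc_qp(A, B, Qx, R, x0, horizon, u_min, u_max):
    """Condensed linear MPC as min 0.5 U^T H U + f^T U  s.t.  G U <= h.

    Dynamics x_{k+1} = A x_k + B u_k, cost sum_{k=0}^{N-1} (x_{k+1}^T Qx x_{k+1}
    + u_k^T R u_k), and elementwise bounds u_min <= u_k <= u_max, where
    U = (u_0, ..., u_{N-1}) is the decision vector.
    """
    nx, nu = B.shape
    N = horizon
    powers = [np.linalg.matrix_power(A, k) for k in range(N + 1)]
    # Stacked predictions X = (x_1, ..., x_N) = Sx @ x0 + Su @ U.
    Sx = np.vstack(powers[1:])
    Su = np.zeros((N * nx, N * nu))
    for i in range(N):
        for j in range(i + 1):
            Su[i * nx:(i + 1) * nx, j * nu:(j + 1) * nu] = powers[i - j] @ B
    Q_bar = np.kron(np.eye(N), Qx)
    R_bar = np.kron(np.eye(N), R)
    # Expand X^T Q_bar X + U^T R_bar U and drop the U-independent term.
    H = 2.0 * (Su.T @ Q_bar @ Su + R_bar)
    f = 2.0 * (Su.T @ Q_bar @ Sx @ x0)
    G = np.vstack([np.eye(N * nu), -np.eye(N * nu)])
    h = np.concatenate([np.tile(u_max, N), -np.tile(u_min, N)])
    return H, f, G, h


rng = np.random.default_rng(0)
state_dim, input_dim, horizon = 6, 3, 10  # same sizes as the flags above
A = np.eye(state_dim) + 0.1 * rng.standard_normal((state_dim, state_dim))
B = rng.standard_normal((state_dim, input_dim))
Qx, R = np.eye(state_dim), 0.1 * np.eye(input_dim)
x0 = rng.standard_normal(state_dim)
H, f, G, h = condensed_mpc_qp(A, B, Qx, R, x0, horizon,
                              -np.ones(input_dim), np.ones(input_dim))
print(H.shape, f.shape, G.shape, h.shape)  # (30, 30) (30,) (60, 30) (60,)
```

The number of QP variables is `input_dim * horizon`, and box bounds on the inputs give twice that many inequality rows. This is why larger horizons make constraint removal and warm starting more valuable.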

Usage in Projects

Basic Example

from transformermpc import TransformerMPC
import numpy as np

# Define your QP problem parameters
Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])

# Initialize the TransformerMPC solver
solver = TransformerMPC()

# Solve with model acceleration
solution, solve_time = solver.solve(Q, c, A, b)

print(f"Solution: {solution}")
print(f"Solve time: {solve_time} seconds")
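The returned solution comes from a pipeline that may drop constraints, so it can be worth checking it independently. The README does not state the package's QP convention; this NumPy sketch assumes minimize 0.5 x^T Q x + c^T x subject to A x <= b. It reports three things:

- the primal infeasibility;
- the stationarity residual, using multipliers for the nearly active constraints estimated by least squares;
- the most negative of those multipliers.

`kkt_report` is a hypothetical helper, not part of transformermpc.

```python
import numpy as np


def kkt_report(Q, c, A, b, x, active_tol=1e-6):
    """KKT diagnostics for min 0.5 x^T Q x + c^T x  s.t.  A x <= b."""
    slack = b - A @ x
    active = np.flatnonzero(np.abs(slack) <= active_tol)
    grad = Q @ x + c
    if active.size:
        # Solve A_active^T lam = -grad in the least-squares sense.
        lam, *_ = np.linalg.lstsq(A[active].T, -grad, rcond=None)
        residual = grad + A[active].T @ lam
    else:
        lam, residual = np.zeros(0), grad
    return {
        "max_violation": float(max(0.0, -slack.min())),
        "stationarity": float(np.linalg.norm(residual)),
        "min_multiplier": float(lam.min()) if lam.size else 0.0,
        "active": active.tolist(),
    }


Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])

# Optimum of this QP under the assumed convention, worked out by hand.
# In practice, pass np.asarray(solution) from the example above.
print(kkt_report(Q, c, A, b, np.array([0.25, 0.75])))
```

A trustworthy solution has `max_violation` and `stationarity` near zero and a nonnegative `min_multiplier`. Here only constraint 2 (x1 + x2 >= 1) is active, with multiplier 2.75.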

General Usage

from transformermpc import TransformerMPC, QPProblem
import numpy as np

# Define your QP problem parameters
Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])
initial_state = np.array([0.5, 0.5])  # Optional: initial state for MPC problems

# Create a QP problem instance
qp_problem = QPProblem(
    Q=Q,
    c=c,
    A=A,
    b=b,
    initial_state=initial_state  # Optional
)

# Initialize with custom settings
solver = TransformerMPC(
    use_constraint_predictor=True,
    use_warm_start_predictor=True,
    fallback_on_violation=True
)

# Solve the problem
solution, solve_time = solver.solve(qp_problem=qp_problem)
print(f"Solution: {solution}")
print(f"Solve time: {solve_time} seconds")

# Compare with baseline
baseline_solution, baseline_time = solver.solve_baseline(qp_problem=qp_problem)
print(f"Baseline time: {baseline_time} seconds")

If you find our work useful, please cite us:

@article{zinage2024transformermpc,
  title={TransformerMPC: Accelerating Model Predictive Control via Transformers},
  author={Zinage, Vrushabh and Khalil, Ahmed and Bakolas, Efstathios},
  journal={arXiv preprint arXiv:2409.09266},
  year={2024}
}

License

This project is licensed under the MIT License.

Download files


Source Distribution

transformermpc-0.1.7.tar.gz (34.8 kB view details)

Uploaded Source

Built Distribution


transformermpc-0.1.7-py3-none-any.whl (34.2 kB view details)

Uploaded Python 3

File details

Details for the file transformermpc-0.1.7.tar.gz.

File metadata

  • Download URL: transformermpc-0.1.7.tar.gz
  • Upload date:
  • Size: 34.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.5

File hashes

Hashes for transformermpc-0.1.7.tar.gz
Algorithm Hash digest
SHA256 194eb4db3d3b524eaec293a4af61c1e875b53a1b334686be63d208d3231cc40f
MD5 073776f8e58fb901b7e1fde0fd7e74c2
BLAKE2b-256 61f175c773043f564542a22c7b60bf2f9fa5c332b08c696a58d8ffb78e3129b4


File details

Details for the file transformermpc-0.1.7-py3-none-any.whl.

File metadata

  • Download URL: transformermpc-0.1.7-py3-none-any.whl
  • Upload date:
  • Size: 34.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.5

File hashes

Hashes for transformermpc-0.1.7-py3-none-any.whl
Algorithm Hash digest
SHA256 6ca938480bd05098a660ea7bd3fdd974185038a3a7f0ad416b871d60c0137202
MD5 b28f4c15d6961826bc01e7f4954b30d4
BLAKE2b-256 beb5b8da4cfad43033531a129ff72950dfc0128152e2426edeb639f3d6112fc9

