Accelerating Model Predictive Control via Transformers

TransformerMPC: Accelerating Model Predictive Control via Transformers [ICRA '25]

Vrushabh Zinage¹ · Ahmed Khalil¹ · Efstathios Bakolas¹

¹University of Texas at Austin

Overview

TransformerMPC speeds up the solution of Model Predictive Control (MPC) problems using transformer-based neural networks. It employs two transformer models:

  1. Constraint Predictor: Identifies inactive constraints in MPC formulations
  2. Warm Start Predictor: Generates better initial points for MPC solvers

By removing constraints predicted to be inactive and giving the solver a better starting point, TransformerMPC reduces computation time while maintaining solution quality.
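To make "inactive constraint" concrete: assuming inequality constraints of the form A x <= b (a common convention, though the package's internal formulation may differ), a constraint is inactive at a feasible point x when it holds with strict slack. The sketch below labels each constraint at a given point and drops the inactive rows. The helper names are hypothetical and not part of the transformermpc API; plain Python lists are used so the snippet runs without dependencies.

```python
# Illustrative sketch only: these helpers are hypothetical, not transformermpc API.
# Assumes inequality constraints of the form A x <= b.

def constraint_slacks(A, b, x):
    """Return b_i - a_i . x for every constraint row (>= 0 means satisfied)."""
    return [bi - sum(aij * xj for aij, xj in zip(row, x)) for row, bi in zip(A, b)]


def active_mask(A, b, x, tol=1e-8):
    """True where a constraint holds with (near) equality at x."""
    return [abs(s) <= tol for s in constraint_slacks(A, b, x)]


def drop_inactive(A, b, mask):
    """Keep only the constraint rows flagged as active."""
    kept = [(row, bi) for row, bi, keep in zip(A, b, mask) if keep]
    return [row for row, _ in kept], [bi for _, bi in kept]


A = [[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]]
b = [0.0, 0.0, -1.0, 2.0]
x = [0.25, 0.75]  # a feasible point chosen for illustration

mask = active_mask(A, b, x)
A_reduced, b_reduced = drop_inactive(A, b, mask)
print(mask)                  # which constraints are active at x
print(A_reduced, b_reduced)  # the smaller constraint set
```

In this example only one of the four constraints is active, so the reduced problem has a single constraint row. A learned predictor has to guess this mask before the solution is known, which is why a feasibility check on the full constraint set remains useful afterwards.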

Installation

Install directly from PyPI:

pip install transformermpc

Or install from source:

git clone https://github.com/vrushabh/transformermpc.git
cd transformermpc
pip install -e .

Dependencies

  • Python >= 3.7
  • PyTorch >= 1.9.0
  • Transformers >= 4.15.0
  • OSQP >= 0.6.2
  • Additional dependencies specified in requirements.txt

Running the Demo

The package includes a simplified demo script that demonstrates the complete workflow:

python simple_demo.py

This script performs the entire pipeline: generating QP problems, training models, and evaluating performance. After completion, it saves performance comparison plots in the demo_results/results directory.

Additional Scripts

The package also includes these utility scripts:

  1. run_demo.py: A wrapper script that executes the main demo module. Use it when working with the installed package:

    python run_demo.py
    
  2. test_benchmark.py: Runs a more comprehensive performance evaluation with additional metrics and visualizations:

    python test_benchmark.py

    This script benchmarks performance across multiple problems and generates detailed visualizations, including boxplots comparing solve times between the standard OSQP solver and the transformer-enhanced methods.

All scripts save their results (including performance metrics and visualization plots) in the demo_results/results directory. The key visualizations include:

  • solve_time_comparison.png: Line plot comparing baseline vs. transformer solve times
  • solve_time_boxplot.png: Statistical distribution of solve times across different solver configurations

Customizing Demo Parameters

You can customize the demo by modifying the following parameters:

Data Generation Parameters:

# Generate 5000 QP problems with state dimension 6, input dimension 3, and horizon 15
python simple_demo.py --num_samples 5000 --state_dim 6 --input_dim 3 --horizon 15

Training Parameters:

# Train the constraint predictor for 200 epochs and warm start predictor for 300 epochs
python simple_demo.py --cp_epochs 200 --ws_epochs 300 --batch_size 128

The number of epochs and samples significantly impact training time and model performance:

  • --cp_epochs: Number of training epochs for the Constraint Predictor model
  • --ws_epochs: Number of training epochs for the Warm Start Predictor model
  • --num_samples: Number of QP problems to generate for training

Increasing these values will generally improve model accuracy but require more computation time. For quick experimentation, use lower values (e.g., 50-100 epochs, 1000-2000 samples). For production-quality models, consider higher values (300+ epochs, 5000+ samples).

Hardware Options:

# Use GPU for training if available
python simple_demo.py --use_gpu

Other Options:

# Skip training and use pre-trained models if available
python simple_demo.py --skip_training

# Specify custom output directory
python simple_demo.py --output_dir custom_results

Available Options

Parameter        Description                                 Default
--num_samples    Number of QP problems to generate           2000
--state_dim      State dimension for MPC problems            4
--input_dim      Input dimension for MPC problems            2
--horizon        Time horizon for MPC problems               10
--cp_epochs      Epochs for constraint predictor training    100
--ws_epochs      Epochs for warm start predictor training    100
--batch_size     Batch size for training                     64
--test_size      Fraction of data to use for testing         0.2
--output_dir     Directory to save results                   "demo_results"
--skip_training  Skip training and use pre-trained models    False
--use_gpu        Use GPU if available                        False
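For readers who want to mirror these options in their own scripts, the table maps naturally onto an argparse parser. The following is a minimal sketch built only from the names and defaults listed above; it is not the actual parser in simple_demo.py, and build_parser is a hypothetical name.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented options and defaults."""
    p = argparse.ArgumentParser(description="TransformerMPC demo options (illustrative)")
    p.add_argument("--num_samples", type=int, default=2000, help="Number of QP problems to generate")
    p.add_argument("--state_dim", type=int, default=4, help="State dimension for MPC problems")
    p.add_argument("--input_dim", type=int, default=2, help="Input dimension for MPC problems")
    p.add_argument("--horizon", type=int, default=10, help="Time horizon for MPC problems")
    p.add_argument("--cp_epochs", type=int, default=100, help="Epochs for constraint predictor training")
    p.add_argument("--ws_epochs", type=int, default=100, help="Epochs for warm start predictor training")
    p.add_argument("--batch_size", type=int, default=64, help="Batch size for training")
    p.add_argument("--test_size", type=float, default=0.2, help="Fraction of data to use for testing")
    p.add_argument("--output_dir", type=str, default="demo_results", help="Directory to save results")
    # Boolean switches default to False and become True when the flag is present.
    p.add_argument("--skip_training", action="store_true", help="Skip training and use pre-trained models")
    p.add_argument("--use_gpu", action="store_true", help="Use GPU if available")
    return p


# Parse an explicit argument list so the example does not depend on sys.argv.
args = build_parser().parse_args(["--horizon", "15", "--use_gpu"])
print(args.horizon, args.use_gpu, args.num_samples)
```

Any option not given on the command line keeps the default from the table.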

Usage in Projects

Basic Example

from transformermpc import TransformerMPC
import numpy as np

# Define your QP problem parameters
Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])

# Initialize the TransformerMPC solver
solver = TransformerMPC()
# Note: By default, this uses pre-trained models that come with the package.
# To train custom models with different epochs and samples, use the demo script
# or the training utilities as described in the "Customizing Demo Parameters" section.

# Solve with transformer acceleration
solution, solve_time = solver.solve(Q, c, A, b)

print(f"Solution: {solution}")
print(f"Solve time: {solve_time} seconds")
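A quick sanity check on any returned solution is to evaluate the objective and the worst constraint violation independently of the solver. The sketch below assumes the common QP convention of minimizing 0.5 x'Qx + c'x subject to A x <= b; this README does not state which convention the package uses, so treat that as an assumption. The helper names are hypothetical, and the point x is hand-picked (it is the minimizer under the assumed convention, worked out by hand) rather than produced by the solver.

```python
# Illustrative sketch only: helper names are hypothetical, not transformermpc API.
# Assumed convention: minimize 0.5 * x'Qx + c'x  subject to  A x <= b.

def qp_objective(Q, c, x):
    """Evaluate 0.5 * x'Qx + c'x for a small dense problem."""
    n = len(x)
    quad = sum(x[i] * Q[i][j] * x[j] for i in range(n) for j in range(n))
    lin = sum(ci * xi for ci, xi in zip(c, x))
    return 0.5 * quad + lin


def max_violation(A, b, x):
    """Largest amount by which any constraint a_i . x <= b_i is exceeded (0.0 if none)."""
    residuals = [sum(aij * xj for aij, xj in zip(row, x)) - bi for row, bi in zip(A, b)]
    return max(0.0, max(residuals))


Q = [[4.0, 1.0], [1.0, 2.0]]
c = [1.0, 1.0]
A = [[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]]
b = [0.0, 0.0, -1.0, 2.0]

x = [0.25, 0.75]  # hand-derived candidate, not solver output
print("objective:", qp_objective(Q, c, x))
print("max violation:", max_violation(A, b, x))
```

The helpers only index and iterate over their inputs, so they also accept the NumPy arrays from the example above.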

General Usage

from transformermpc import TransformerMPC, QPProblem
import numpy as np

# Define your QP problem parameters
Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])
initial_state = np.array([0.5, 0.5])  # Optional: initial state for MPC problems

# Create a QP problem instance
qp_problem = QPProblem(
    Q=Q,
    c=c,
    A=A,
    b=b,
    initial_state=initial_state  # Optional
)

# Initialize with custom settings
solver = TransformerMPC(
    use_constraint_predictor=True,
    use_warm_start_predictor=True,
    fallback_on_violation=True
)

# Solve the problem
solution, solve_time = solver.solve(qp_problem=qp_problem)
print(f"Solution: {solution}")
print(f"Solve time: {solve_time} seconds")

# Compare with baseline
baseline_solution, baseline_time = solver.solve_baseline(qp_problem=qp_problem)
print(f"Baseline time: {baseline_time} seconds")
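The fallback_on_violation flag suggests a verify-then-fallback pattern: solve the reduced problem, check the result against the full constraint set, and re-solve the full problem if a removed constraint turns out to be violated. The sketch below shows that pattern in isolation. It is an assumption about the general idea, not the package's implementation; solve_reduced and solve_full are hypothetical callables standing in for real solver calls, and constraints are assumed to have the form A x <= b.

```python
# Illustrative sketch only: not the transformermpc implementation.
# solve_reduced / solve_full are hypothetical stand-ins for real solver calls.

def is_feasible(A, b, x, tol=1e-6):
    """Check every constraint a_i . x <= b_i up to a small tolerance."""
    return all(
        sum(aij * xj for aij, xj in zip(row, x)) <= bi + tol
        for row, bi in zip(A, b)
    )


def solve_with_fallback(A, b, solve_reduced, solve_full, tol=1e-6):
    """Return (solution, used_fallback)."""
    x = solve_reduced()
    if is_feasible(A, b, x, tol):
        return x, False
    # The reduced solution violates a constraint that was removed: solve the full problem.
    return solve_full(), True


A = [[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]]
b = [0.0, 0.0, -1.0, 2.0]

# Stand-in "solvers" returning fixed points, purely to exercise the control flow.
x_ok, fallback_ok = solve_with_fallback(A, b, lambda: [0.25, 0.75], lambda: [0.25, 0.75])
x_fb, fallback_fb = solve_with_fallback(A, b, lambda: [0.0, 0.0], lambda: [0.25, 0.75])
print(x_ok, fallback_ok)
print(x_fb, fallback_fb)
```

The feasibility check costs one matrix-vector product, which is small compared with a QP solve, so the safeguard adds little overhead when the prediction is correct.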

If you find our work useful, please cite us

@article{zinage2024transformermpc,
  title={TransformerMPC: Accelerating Model Predictive Control via Transformers},
  author={Zinage, Vrushabh and Khalil, Ahmed and Bakolas, Efstathios},
  journal={arXiv preprint arXiv:2409.09266},
  year={2024}
}

License

This project is licensed under the MIT License.
