TransformerMPC

Accelerating Model Predictive Control via Transformers

Overview

TransformerMPC is a Python package that enhances the efficiency of solving Quadratic Programming (QP) problems in Model Predictive Control (MPC) using transformer-based neural networks. It employs two specialized transformer models:

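  1. Constraint Predictor: identifies constraints that are likely to be inactive at the optimum, so they can be removed from the QP
  2. Warm Start Predictor: predicts an initial point (warm start) close to the optimum for the QP solver

By combining these models, TransformerMPC reduces computation time while maintaining solution quality: removing inactive constraints shrinks the problem the solver has to handle, and a good warm start cuts the number of solver iterations.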
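To make the constraint-predictor step concrete, here is a minimal NumPy sketch. It assumes the QP constraints have the form A x <= b and that the predictor outputs a boolean mask of constraints it believes are inactive; the helper names drop_predicted_inactive and violated_dropped_constraints are hypothetical and not part of the package API. The reduced QP keeps only the rows not flagged as inactive; after it is solved, any dropped row that the candidate solution violates shows that the prediction was wrong and the full QP must be solved instead.

```python
import numpy as np

def drop_predicted_inactive(A, b, inactive_mask):
    """Return (A_reduced, b_reduced) keeping only rows NOT predicted inactive.

    Assumes constraints of the form A @ x <= b (an assumed convention).
    """
    keep = ~np.asarray(inactive_mask, dtype=bool)
    return A[keep], b[keep]

def violated_dropped_constraints(x, A, b, inactive_mask, tol=1e-6):
    """Indices of dropped constraints that x violates (A[i] @ x > b[i] + tol).

    A non-empty result means the reduced solution is infeasible for the
    full QP, so a fallback to the full problem is needed.
    """
    dropped = np.flatnonzero(np.asarray(inactive_mask, dtype=bool))
    residual = A[dropped] @ x - b[dropped]
    return dropped[residual > tol]
```

Under the A x <= b reading of the Basic Example data, only the third row (x1 + x2 >= 1) is active at the optimum x = [0.25, 0.75], so a mask flagging rows 0, 1 and 3 as inactive leaves a one-row reduced problem whose solution passes the check.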

Installation

Install directly from PyPI:

pip install transformermpc

Or install from source:

git clone https://github.com/Vrushabh27/transformermpc.git
cd transformermpc
pip install -e .

Dependencies

  • Python >= 3.7
  • PyTorch >= 1.9.0
  • Transformers >= 4.15.0
  • OSQP >= 0.6.2
  • Additional dependencies specified in requirements.txt
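A quick way to confirm that an environment meets these minimums is sketched below using only the standard library. It assumes the PyPI distribution names torch, transformers and osqp, and it relies on importlib.metadata, which needs Python 3.8 or newer. The numeric comparison is a simplification that ignores pre-release tags; packaging.version is more robust if it is available.

```python
import sys
from importlib import metadata

# Minimum versions taken from the dependency list above (PyPI distribution names).
MINIMUMS = {"torch": "1.9.0", "transformers": "4.15.0", "osqp": "0.6.2"}

def version_tuple(version):
    """Turn '1.13.1+cu117' into (1, 13, 1); non-numeric suffixes are ignored."""
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)

def meets_minimum(installed, required):
    """Numeric comparison, so '1.13.1' >= '1.9.0' (plain string comparison says no)."""
    have, need = version_tuple(installed), version_tuple(required)
    width = max(len(have), len(need))
    # Pad with zeros so that "1.9" compares equal to "1.9.0".
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have >= need

def check_environment():
    """Return a list of problems; an empty list means all minimums are met."""
    problems = []
    if sys.version_info < (3, 7):
        problems.append(f"Python {sys.version.split()[0]} < 3.7")
    for dist, required in MINIMUMS.items():
        try:
            installed = metadata.version(dist)
        except metadata.PackageNotFoundError:
            problems.append(f"{dist} not installed (need >= {required})")
            continue
        if not meets_minimum(installed, required):
            problems.append(f"{dist} {installed} < {required}")
    return problems

if __name__ == "__main__":
    for line in check_environment() or ["All listed minimums are met."]:
        print(line)
```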

Running the Demo

The package includes a simplified demo script that runs the complete workflow:

python simple_demo.py

This script runs the entire pipeline: generating QP problems, training both models, and evaluating performance. When it finishes, it saves performance comparison plots to the demo_results/results directory.

Additional Scripts

The package also includes these utility scripts:

  1. run_demo.py: A wrapper script that executes the main demo module. Use it when working with the installed package:

    python run_demo.py
    
  2. test_benchmark.py: For comprehensive performance evaluation with more metrics and advanced visualizations:

    python test_benchmark.py
    

    This script focuses on benchmarking performance across multiple problems and generates detailed visualizations, including boxplots comparing solve times between the standard OSQP solver and transformer-enhanced methods.

All scripts save their results (including performance metrics and visualization plots) in the demo_results/results directory. The key visualizations include:

  • solve_time_comparison.png: Line plot comparing baseline vs. transformer solve times
  • solve_time_boxplot.png: Statistical distribution of solve times across different solver configurations
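The plots above summarize paired solve times. A minimal sketch of the corresponding numbers is shown below, assuming two equal-length lists of per-problem solve times in seconds (standard OSQP and transformer-enhanced). summarize_solve_times is a hypothetical helper, not part of the package.

```python
from statistics import median

def summarize_solve_times(baseline, accelerated):
    """Summarize paired per-problem solve times in seconds (hypothetical helper).

    Speedup is computed per problem (baseline / accelerated) before taking
    the median, so a single outlier problem cannot dominate the summary.
    """
    if len(baseline) != len(accelerated) or not baseline:
        raise ValueError("need two non-empty lists of equal length")
    speedups = [b / a for b, a in zip(baseline, accelerated)]
    return {
        "median_baseline": median(baseline),
        "median_accelerated": median(accelerated),
        "median_speedup": median(speedups),
        "fraction_faster": sum(s > 1.0 for s in speedups) / len(speedups),
    }
```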

Customizing Demo Parameters

You can customize the demo by passing the following command-line options:

Data Generation Parameters:

# Generate 5000 QP problems with state dimension 6, input dimension 3, and horizon 15
python simple_demo.py --num_samples 5000 --state_dim 6 --input_dim 3 --horizon 15
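One standard way to turn an MPC instance with these dimensions into a QP is the condensed formulation, in which the inputs over the horizon are the only decision variables. The sketch below assumes linear dynamics x_{k+1} = Ad x_k + Bd u_k, the cost 0.5 * sum(x_k' Qx x_k + u_k' Ru u_k), and symmetric input bounds |u| <= u_max. It is not necessarily how the demo generates its problems, and build_condensed_mpc_qp is a hypothetical name. With state_dim = n, input_dim = m and horizon = N it yields N*m variables and 2*N*m inequality rows; H, f, G, h play the roles of Q, c, A, b in the Basic Example.

```python
import numpy as np

def build_condensed_mpc_qp(Ad, Bd, x0, horizon, Qx, Ru, u_max):
    """Condensed MPC QP (assumed formulation): min 0.5 U'HU + f'U  s.t.  G U <= h.

    U stacks u_0..u_{N-1}. Dynamics x_{k+1} = Ad x_k + Bd u_k, cost
    0.5 * sum_{k=1..N} x_k' Qx x_k + 0.5 * sum_{k=0..N-1} u_k' Ru u_k
    (terms depending only on x0 are dropped), bounds -u_max <= u_k <= u_max.
    """
    n, m = Bd.shape
    N = horizon
    # Predicted states: X = Sx @ x0 + Su @ U, where X stacks x_1..x_N.
    Sx = np.zeros((N * n, n))
    Su = np.zeros((N * n, N * m))
    for i in range(N):  # block row i predicts x_{i+1}
        Sx[i*n:(i+1)*n] = np.linalg.matrix_power(Ad, i + 1)
        for j in range(i + 1):  # u_j reaches x_{i+1} through Ad^(i-j) Bd
            Su[i*n:(i+1)*n, j*m:(j+1)*m] = np.linalg.matrix_power(Ad, i - j) @ Bd
    Q_bar = np.kron(np.eye(N), Qx)
    R_bar = np.kron(np.eye(N), Ru)
    H = Su.T @ Q_bar @ Su + R_bar
    f = Su.T @ Q_bar @ Sx @ x0
    # Box constraints on every input, written as G U <= h.
    G = np.vstack([np.eye(N * m), -np.eye(N * m)])
    h = np.full(2 * N * m, float(u_max))
    return H, f, G, h
```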

Training Parameters:

# Train the constraint predictor for 200 epochs and warm start predictor for 300 epochs
python simple_demo.py --cp_epochs 200 --ws_epochs 300 --batch_size 128

The number of training epochs and the number of generated samples strongly affect both training time and model performance:

  • --cp_epochs: Number of training epochs for the Constraint Predictor model
  • --ws_epochs: Number of training epochs for the Warm Start Predictor model
  • --num_samples: Number of QP problems to generate for training

Increasing these values will generally improve model accuracy but require more computation time. For quick experimentation, use lower values (e.g., 50-100 epochs, 1000-2000 samples). For production-quality models, consider higher values (300+ epochs, 5000+ samples).

Hardware Options:

# Use GPU for training if available
python simple_demo.py --use_gpu

Other Options:

# Skip training and use pre-trained models if available
python simple_demo.py --skip_training

# Specify custom output directory
python simple_demo.py --output_dir custom_results

All Available Options

Parameter        Description                                 Default
--num_samples    Number of QP problems to generate           2000
--state_dim      State dimension for MPC problems            4
--input_dim      Input dimension for MPC problems            2
--horizon        Time horizon for MPC problems               10
--cp_epochs      Epochs for constraint predictor training    100
--ws_epochs      Epochs for warm start predictor training    100
--batch_size     Batch size for training                     64
--test_size      Fraction of data to use for testing         0.2
--output_dir     Directory to save results                   "demo_results"
--skip_training  Skip training and use pretrained models     False
--use_gpu        Use GPU if available                        False
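For scripting around the demo, the option table can be mirrored in an argparse parser. This is a hypothetical reconstruction from the table (names, types and defaults copied from it), not the actual source of simple_demo.py.

```python
import argparse

def build_demo_parser():
    """Hypothetical parser mirroring the documented simple_demo.py options."""
    p = argparse.ArgumentParser(description="TransformerMPC demo (reconstructed options)")
    p.add_argument("--num_samples", type=int, default=2000, help="Number of QP problems to generate")
    p.add_argument("--state_dim", type=int, default=4, help="State dimension for MPC problems")
    p.add_argument("--input_dim", type=int, default=2, help="Input dimension for MPC problems")
    p.add_argument("--horizon", type=int, default=10, help="Time horizon for MPC problems")
    p.add_argument("--cp_epochs", type=int, default=100, help="Epochs for constraint predictor training")
    p.add_argument("--ws_epochs", type=int, default=100, help="Epochs for warm start predictor training")
    p.add_argument("--batch_size", type=int, default=64, help="Batch size for training")
    p.add_argument("--test_size", type=float, default=0.2, help="Fraction of data to use for testing")
    p.add_argument("--output_dir", type=str, default="demo_results", help="Directory to save results")
    # Boolean switches: absent means False, present means True.
    p.add_argument("--skip_training", action="store_true", help="Skip training and use pretrained models")
    p.add_argument("--use_gpu", action="store_true", help="Use GPU if available")
    return p

# Parse the training example from "Customizing Demo Parameters".
args = build_demo_parser().parse_args(["--cp_epochs", "200", "--ws_epochs", "300", "--batch_size", "128"])
```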

Usage in Projects

Basic Example

from transformermpc import TransformerMPC
import numpy as np

# Define your QP problem parameters
Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])

# Initialize the TransformerMPC solver
solver = TransformerMPC()
# Note: By default, this uses pre-trained models that come with the package.
# To train custom models with different epochs and samples, use the demo script
# or the training utilities as described in the "Customizing Demo Parameters" section.

# Solve with transformer acceleration
solution, solve_time = solver.solve(Q, c, A, b)

print(f"Solution: {solution}")
print(f"Solve time: {solve_time} seconds")
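On a problem this small, the accelerated result can be cross-checked against an exact solution obtained by enumerating candidate active sets and testing the KKT conditions. The sketch assumes the problem is minimize 0.5 x'Qx + c'x subject to A x <= b with Q positive definite (the README does not state its sign convention). It needs up to 2^m linear solves for m constraints, so it is only a reference for toy problems; solve_qp_by_enumeration is a hypothetical helper.

```python
import itertools
import numpy as np

def solve_qp_by_enumeration(Q, c, A, b, tol=1e-9):
    """Exact solution of min 0.5 x'Qx + c'x s.t. A x <= b (Q positive definite).

    For each subset S of constraints, treat S as active (equalities), solve
    the KKT system, and accept the point if it is primal feasible and all
    multipliers are non-negative. For a strictly convex QP that point is the
    unique optimum. Returns (x, active_indices).
    """
    n, m = Q.shape[0], A.shape[0]
    for size in range(m + 1):
        for S in itertools.combinations(range(m), size):
            idx = np.array(S, dtype=int)
            A_S = A[idx]
            # KKT system: [Q A_S'; A_S 0] [x; lam] = [-c; b_S]
            K = np.zeros((n + size, n + size))
            K[:n, :n] = Q
            K[:n, n:] = A_S.T
            K[n:, :n] = A_S
            rhs = np.concatenate([-c, b[idx]])
            try:
                sol = np.linalg.solve(K, rhs)
            except np.linalg.LinAlgError:
                continue  # singular KKT matrix, e.g. linearly dependent active rows
            x, lam = sol[:n], sol[n:]
            if np.all(A @ x <= b + tol) and np.all(lam >= -tol):
                return x, list(S)
    raise ValueError("no KKT point found; is the problem feasible?")
```

For the Basic Example data this returns x = [0.25, 0.75] with only the third constraint (x1 + x2 >= 1) active, so the transformer-accelerated solution should agree with it to solver tolerance if the A x <= b convention holds.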

Advanced Usage

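from transformermpc import TransformerMPC, QPProblem
import numpy as np

# QP data (same as the Basic Example)
Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
A = np.array([[-1.0, 0.0], [0.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
b = np.array([0.0, 0.0, -1.0, 2.0])
# Initial MPC state the QP was built from. Placeholder value: the shape
# QPProblem expects depends on how the QP was formulated.
x0 = np.zeros(2)

# Create a QP problem instance
qp_problem = QPProblem(
    Q=Q,
    c=c,
    A=A,
    b=b,
    initial_state=x0
)

# Initialize with custom settings
solver = TransformerMPC(
    use_constraint_predictor=True,   # remove constraints predicted to be inactive
    use_warm_start_predictor=True,   # start the QP solver from a predicted point
    fallback_on_violation=True       # fall back to the full QP if a constraint is violated
)

# Solve the problem (the Basic Example unpacks (solution, solve_time) from
# solve(Q, c, A, b); check which return form your installed version uses here)
solution = solver.solve(qp_problem)

# Compare with the baseline (standard OSQP, no transformer acceleration)
baseline_solution = solver.solve_baseline(qp_problem)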
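Why a warm start helps can be seen with any iterative QP method. The sketch below uses plain projected gradient descent on a box-constrained QP as a stand-in: OSQP itself is an ADMM-based solver, and here a hand-picked initial point plays the role of the Warm Start Predictor's output. projected_gradient_qp is a hypothetical helper. Starting closer to the optimum means fewer iterations to reach the same stopping tolerance.

```python
import numpy as np

def projected_gradient_qp(Q, c, lo, hi, x_init, tol=1e-8, max_iter=10_000):
    """Minimize 0.5 x'Qx + c'x subject to lo <= x <= hi by projected gradient.

    Stand-in iterative solver (not OSQP). A step of 1/L, with L the largest
    eigenvalue of Q, guarantees convergence for symmetric positive definite Q.
    Returns (x, iterations).
    """
    step = 1.0 / np.linalg.eigvalsh(Q).max()
    x = np.clip(x_init, lo, hi)  # make the starting point feasible
    for k in range(1, max_iter + 1):
        x_new = np.clip(x - step * (Q @ x + c), lo, hi)
        if np.linalg.norm(x_new - x) < tol:  # iterate has stopped moving
            return x_new, k
        x = x_new
    return x, max_iter

Q = np.array([[4.0, 1.0], [1.0, 2.0]])
c = np.array([1.0, 1.0])
lo, hi = -np.ones(2), np.ones(2)
_, cold_iters = projected_gradient_qp(Q, c, lo, hi, np.zeros(2))                  # cold start
_, warm_iters = projected_gradient_qp(Q, c, lo, hi, np.array([-0.14, -0.43]))   # hypothetical predicted start
print(cold_iters, warm_iters)
```

In this run both starts reach the same optimum, and the start near the optimum needs fewer iterations than the cold start from zero.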

License

This project is licensed under the MIT License - see the LICENSE file for details.

Download files

Source Distribution

transformermpc-0.1.1.tar.gz (4.6 kB)

Uploaded Source

Built Distribution

transformermpc-0.1.1-py3-none-any.whl (60.6 kB)

Uploaded Python 3

File details

Details for the file transformermpc-0.1.1.tar.gz.

File metadata

  • Download URL: transformermpc-0.1.1.tar.gz
  • Upload date:
  • Size: 4.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.5

File hashes

Hashes for transformermpc-0.1.1.tar.gz
Algorithm Hash digest
SHA256 0f0b98119eb45fae1fc5c15f4ba1acf29d33403e04047b1317562f76851cab3c
MD5 b723bd5a62ec51e7cad23db74a2bb0e5
BLAKE2b-256 d45184ba37bd86f827c6cedeb9e543cc1449a804de57e634c6caee941a7c139e

File details

Details for the file transformermpc-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: transformermpc-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 60.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.11.5

File hashes

Hashes for transformermpc-0.1.1-py3-none-any.whl
Algorithm Hash digest
SHA256 b50c6c99ab364fa7b10988a21dcdfba320dadf628a745b8d90ce8c2b0cd432fe
MD5 a4002349d168ab8cdf8c25446db81353
BLAKE2b-256 9d5286a620e043a6fcfab6cf7c065fcbe91c0e01d82907e2a78f74b68024d187
