
A toolbox for using Constraint Guided Gradient Descent when training neural networks.

Reason this release was yanked:

Release yanked because the library is still in active development and backwards compatibility is not guaranteed. Please use the new v0.x series.

Project description

Incorporate constraints into neural network training for more reliable and robust models.





Congrads is a Python toolbox that brings constraint-guided gradient descent capabilities to your machine learning projects. Built with seamless integration into PyTorch, Congrads empowers you to enhance the training and optimization process by incorporating constraints into your training pipeline.

Whether you're working with simple inequality constraints, combinations of input-output relations, or custom constraint formulations, Congrads provides the tools and flexibility needed to build more robust and generalized models.

Key Features

  • Constraint-Guided Training: Add constraints to guide the optimization process, ensuring that your model generalizes better by trying to satisfy the constraints.
  • Flexible Constraint Definition: Define constraints on inputs, outputs, or combinations thereof, using an intuitive and extendable interface. Make use of pre-programmed constraint classes or write your own.
  • Seamless PyTorch Integration: Use Congrads within your existing PyTorch workflows with minimal setup.
  • Flexible and extendible: Write your own custom networks, constraints and dataset classes to easily extend the functionality of the toolbox.
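
To make the idea of constraint-guided training concrete, the sketch below shows one way a bound constraint could steer a gradient step for a single scalar prediction. It is a simplified illustration in plain Python, not Congrads' implementation: the names `guided_gradient` and `descent_step`, the `rescale` parameter, and the exact combination rule are assumptions made for this example.

```python
def guided_gradient(pred, target, lower, upper, rescale=1.5, eps=1e-8):
    """Gradient w.r.t. `pred` of a squared-error loss, adjusted so that a
    violated bound constraint outweighs the loss term.

    Hypothetical helper for illustration only (not part of Congrads).
    """
    loss_grad = 2.0 * (pred - target)  # d/dpred of (pred - target)**2

    # Direction that, when followed by gradient *descent*, moves `pred`
    # back inside [lower, upper]. Zero when the constraint is satisfied.
    if pred > upper:
        direction = 1.0   # descent subtracts the gradient, so pred decreases
    elif pred < lower:
        direction = -1.0  # descent subtracts a negative value, so pred increases
    else:
        direction = 0.0

    # Scale the constraint term relative to the loss gradient. With
    # rescale > 1 the constraint term is strictly larger in magnitude,
    # so the sign of the result follows the constraint when it is violated.
    magnitude = max(abs(loss_grad), eps)
    return loss_grad + rescale * magnitude * direction


def descent_step(pred, target, lower, upper, lr=0.1):
    """Apply one gradient-descent update directly to the prediction."""
    return pred - lr * guided_gradient(pred, target, lower, upper)
```

In this sketch, a prediction of 1.2 with target 1.5 and bounds [0, 1] is pushed downward even though the loss alone would push it upward, because the constraint term dominates whenever the bound is violated.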

Getting Started

1. Installation

First, install PyTorch, since Congrads relies heavily on it as its deep learning framework. Please refer to PyTorch's getting started guide, and install with CUDA support if you want to train on a GPU.

Next, install the Congrads toolbox. The recommended way to install it is to use pip:

pip install congrads

You can also install Congrads together with extra packages required to run the examples:

pip install "congrads[examples]"

This should automatically install all required dependencies for you. If you would like to install dependencies manually, Congrads depends on the following:

  • Python 3.11 - 3.13
  • PyTorch (install with CUDA support for GPU training, refer to PyTorch's getting started guide)
  • NumPy (install with pip install numpy, or refer to NumPy's install guide)
  • pandas (install with pip install pandas, or refer to the pandas install guide)
  • Tqdm (install with pip install tqdm)
  • Torchvision (install with pip install torchvision)
  • Optional: Tensorboard (install with pip install tensorboard)
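
If you install the dependencies manually, a quick check like the one below can confirm what is present in the current environment. This is a minimal sketch using only the standard library; the helper names `check_dependencies` and `python_supported` are made up for this example, and the distribution names are the usual PyPI names for the packages listed above.

```python
import sys
from importlib import metadata


def check_dependencies(names):
    """Map each distribution name to its installed version, or None."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def python_supported(version_info=sys.version_info):
    """True if the interpreter falls in the 3.11 - 3.13 range listed above."""
    return (3, 11) <= tuple(version_info[:2]) <= (3, 13)


if __name__ == "__main__":
    required = ["torch", "numpy", "pandas", "tqdm", "torchvision"]
    optional = ["tensorboard"]
    print("Python supported:", python_supported())
    for name, version in check_dependencies(required + optional).items():
        print(f"{name}: {version or 'not installed'}")
```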

2. Core concepts

Before diving into the toolbox, we recommend familiarizing yourself with Congrads' core concepts and topics. Please read the documentation at https://congrads.readthedocs.io/en/latest/ to get up to speed.

3. Basic Usage

The basic example below illustrates how to work with the Congrads toolbox. For additional examples, refer to the examples and notebooks folders in the repository.

1. First, select the device to run your code on.

import torch

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

2. Next, load your data and split it into training, validation and testing subsets.

data = BiasCorrection(
    "./datasets", preprocess_BiasCorrection, download=True
)
loaders = split_data_loaders(
    data,
    loader_args={"batch_size": 100, "shuffle": True},
    valid_loader_args={"shuffle": False},
    test_loader_args={"shuffle": False},
)
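
How `split_data_loaders` divides the data is not shown in this example. The sketch below only illustrates the general idea of a shuffled three-way split on sample indices, using a hypothetical `split_indices` helper and assumed 80/10/10 proportions; it does not reflect the library's actual behavior or defaults.

```python
import random


def split_indices(n_samples, fractions=(0.8, 0.1, 0.1), seed=0):
    """Shuffle range(n_samples) and cut it into train/valid/test index lists.

    Hypothetical helper; the fractions are assumed values for illustration.
    """
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")

    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded for reproducible splits

    n_train = int(fractions[0] * n_samples)
    n_valid = int(fractions[1] * n_samples)
    train = indices[:n_train]
    valid = indices[n_train:n_train + n_valid]
    test = indices[n_train + n_valid:]  # remainder absorbs rounding
    return train, valid, test


train_idx, valid_idx, test_idx = split_indices(1000)
```

Only the training subset is typically reshuffled each epoch, which matches the `"shuffle": False` arguments passed for the validation and test loaders above.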

3. Instantiate your neural network, making sure its dimensions match your data.

network = MLPNetwork(25, 2, n_hidden_layers=3, hidden_dim=35)
network = network.to(device)

4. Choose your loss function and optimizer.

from torch.nn import MSELoss
from torch.optim import Adam
criterion = MSELoss()
optimizer = Adam(network.parameters(), lr=0.001)

5. Then, set up the descriptor, which attaches names to specific parts of your network.

descriptor = Descriptor()
descriptor.add("output", 0, "Tmax")
descriptor.add("output", 1, "Tmin")
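
Conceptually, the descriptor acts as a lookup from a human-readable name to a position in the network. The plain-Python sketch below mimics that idea with a hypothetical `NameRegistry` class. It is not the Congrads `Descriptor`; it only mirrors the `add(layer, index, name)` call pattern shown above, and the prediction values are made up.

```python
class NameRegistry:
    """Toy name -> (layer, index) lookup; hypothetical, for illustration."""

    def __init__(self):
        self._locations = {}

    def add(self, layer, index, name):
        if name in self._locations:
            raise ValueError(f"name {name!r} is already registered")
        self._locations[name] = (layer, index)

    def locate(self, name):
        return self._locations[name]

    def select(self, name, tensors):
        """Pick the named value out of a {layer: sequence} mapping."""
        layer, index = self.locate(name)
        return tensors[layer][index]


registry = NameRegistry()
registry.add("output", 0, "Tmax")
registry.add("output", 1, "Tmin")

prediction = {"output": [0.7, 0.3]}  # made-up values for one sample
print(registry.select("Tmax", prediction))  # 0.7
```

Naming outputs this way lets constraints refer to "Tmax" and "Tmin" instead of raw indices.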

6. Define your constraints on the network.

# ge, le and gt are comparison functions; this example assumes torch.ge, torch.le and torch.gt.
Constraint.descriptor = descriptor
Constraint.device = device
constraints = [
    ScalarConstraint("Tmin", ge, 0),
    ScalarConstraint("Tmin", le, 1),
    ScalarConstraint("Tmax", ge, 0),
    ScalarConstraint("Tmax", le, 1),
    BinaryConstraint("Tmax", gt, "Tmin"),
]
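
Read together, the five constraints above require both outputs to lie in [0, 1] and Tmax to be strictly greater than Tmin. The sketch below checks the same rules against a single named sample using the standard library's `operator` comparators. The `RULES` list and `find_violations` helper are hypothetical and only illustrate what the constraints mean, not how Congrads evaluates them during training.

```python
from operator import ge, gt, le

# (left name, comparator, right name or number), mirroring the list above.
RULES = [
    ("Tmin", ge, 0),
    ("Tmin", le, 1),
    ("Tmax", ge, 0),
    ("Tmax", le, 1),
    ("Tmax", gt, "Tmin"),
]


def find_violations(sample, rules=RULES):
    """Return the rules that a {name: value} sample fails to satisfy."""
    violated = []
    for left, compare, right in rules:
        # A string on the right-hand side refers to another named value.
        right_value = sample[right] if isinstance(right, str) else right
        if not compare(sample[left], right_value):
            violated.append((left, compare.__name__, right))
    return violated


print(find_violations({"Tmax": 0.8, "Tmin": 0.2}))  # []
print(find_violations({"Tmax": 0.4, "Tmin": 1.3}))
# [('Tmin', 'le', 1), ('Tmax', 'gt', 'Tmin')]
```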

7. Instantiate the metric manager and core, and start training.

metric_manager = MetricManager()
# NOTE: checkpoint_manager is not defined in this snippet; create it beforehand.
core = CongradsCore(
    descriptor,
    constraints,
    loaders,
    network,
    criterion,
    optimizer,
    metric_manager,
    device,
    checkpoint_manager,
)

core.fit(max_epochs=50)

Example Use Cases

  • Optimization with Domain Knowledge: Ensure outputs meet real-world restrictions or safety standards.
  • Improve Training Process: Inject domain knowledge in the training stage, increasing learning efficiency.
  • Physics-Informed Neural Networks (PINNs): Coming soon. Enforce physical laws as constraints in your models.

Planned changes / Roadmap

  • Add ODE/PDE constraints to support PINNs
  • Rework callback system
  • Add support for constraint parser that can interpret equations

Research

If you make use of this package or its concepts in your research, please consider citing the following papers.

  • Van Baelen, Q., & Karsmakers, P. (2023). Constraint guided gradient descent: Training with inequality constraints with applications in regression and semantic segmentation. Neurocomputing, 556, 126636. doi:10.1016/j.neucom.2023.126636

Contributing

We welcome contributions to Congrads! You can report issues, suggest features, or contribute code via issues and pull requests.

License

Congrads is licensed under the 3-Clause BSD License. We encourage companies that are interested in collaborating on a specific topic to contact the authors for more information or to set up joint research projects.

Contacts

Feel free to contact any of the contact persons below for more information or details about the project. Companies interested in a collaboration or in setting up joint research projects are also encouraged to get in touch with us.

Contributors

Below is a list of people who contributed to the toolbox. Feel free to contact them with any repository- or code-specific questions, suggestions or remarks.


Elevate your neural networks with Congrads! 🚀


