
A toolbox for using Constraint Guided Gradient Descent when training neural networks.


Congrads

Congrads is a Python toolbox that brings constraint-guided gradient descent to your machine learning projects. Built to integrate with PyTorch and PyTorch Lightning, Congrads lets you incorporate constraints directly into your training pipeline to steer the optimization process.

Whether you're working with simple inequality constraints, combinations of input-output relations, or custom constraint formulations, Congrads provides the tools and flexibility needed to build more robust models that generalize better.

Note: The Congrads toolbox is currently in alpha phase. Expect significant changes, potential bugs, and incomplete features as we continue to develop and improve the functionality. Feedback is highly appreciated during this phase to help us refine the toolbox and ensure its reliability in later stages.

Key Features

  • Constraint-Guided Training: Add constraints to guide the optimization process, helping your model generalize better by steering it toward predictions that satisfy the constraints.
  • Flexible Constraint Definition: Define constraints on inputs, outputs, or combinations thereof, using an intuitive and extendable interface. Make use of pre-programmed constraint classes or write your own.
  • Seamless PyTorch Integration: Use Congrads within your existing PyTorch workflows with minimal setup.
  • PyTorch Lightning Support: Easily plug into PyTorch Lightning projects for scalable and structured model training.
  • Flexible and extensible: Write your own custom networks, constraints, and dataset classes to extend the functionality of the toolbox.

Installation

Currently, the Congrads toolbox can only be installed using pip. We will later expand to other package managers such as conda.

pip install congrads

Getting Started

1. Prerequisites

Before you can use Congrads, make sure you have the following installed:

2. Installation

Please install Congrads via pip:

pip install congrads

3. Basic Usage

1. Import the toolbox

from congrads.descriptor import Descriptor
from congrads.constraints import Constraint, ScalarConstraint, BinaryConstraint
from congrads.learners import Learner

2. Instantiate and configure descriptor

The descriptor describes your specific use case. It assigns names to individual neurons so you can easily reference them when defining constraints. By setting flags, you can specify whether a layer is fixed (constant=True) or whether it is an output layer (output=True).

# Descriptor setup
descriptor = Descriptor()
descriptor.add("input", ["I1", "I2", "I3", "I4"], constant=True)
descriptor.add("output", ["O1", "O2"], output=True)
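To make the role of the descriptor more concrete, here is a minimal, hypothetical stand-in written in plain Python. It is not Congrads' Descriptor class: the name MiniDescriptor and its locate method are invented for this sketch. It simply records, for every neuron name, which layer it belongs to and its position in that layer, and remembers which layers were flagged as constant or as outputs.

```python
from dataclasses import dataclass, field


@dataclass
class MiniDescriptor:
    """Hypothetical descriptor stand-in: neuron name -> (layer, index)."""

    neurons: dict = field(default_factory=dict)
    constant_layers: set = field(default_factory=set)
    output_layers: set = field(default_factory=set)

    def add(self, layer, names, constant=False, output=False):
        # Register each neuron name with its layer and its position in that layer.
        for index, name in enumerate(names):
            if name in self.neurons:
                raise ValueError(f"duplicate neuron name: {name}")
            self.neurons[name] = (layer, index)
        # Remember the flags so later code can tell fixed layers from outputs.
        if constant:
            self.constant_layers.add(layer)
        if output:
            self.output_layers.add(layer)

    def locate(self, name):
        # Resolve a neuron name to (layer, index), e.g. "O2" -> ("output", 1).
        return self.neurons[name]


descriptor = MiniDescriptor()
descriptor.add("input", ["I1", "I2", "I3", "I4"], constant=True)
descriptor.add("output", ["O1", "O2"], output=True)
print(descriptor.locate("O2"))  # ('output', 1)
```

With a mapping like this, a constraint written as "O1 > 0" can be turned into "column 0 of the output layer must be positive".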

3. Define constraints on your network

You can define constraints on your network using the names previously configured in the descriptor. A set of predefined constraint classes can be used to define inequalities on input or output data.

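# Constraints definition
from torch import gt, le  # element-wise comparison functions used by the constraints

Constraint.descriptor = descriptor  # constraints look up neuron names in the descriptor
constraints = [
    ScalarConstraint("O1", gt, 0),      # O1 > 0
    BinaryConstraint("O1", le, "O2"),   # O1 <= O2
]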
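Conceptually, a scalar constraint compares one named neuron against a constant, while a binary constraint compares two named neurons. The plain-Python sketch below illustrates those semantics on a small batch and reports the fraction of samples satisfying each constraint. The helpers scalar_satisfied, binary_satisfied, and satisfaction_rate are hypothetical, and lists of per-sample values stand in for the tensors Congrads actually works with.

```python
from operator import gt, le


def scalar_satisfied(batch, name, comparator, bound):
    # One boolean per sample: does comparator(batch[name], bound) hold?
    return [comparator(value, bound) for value in batch[name]]


def binary_satisfied(batch, left, comparator, right):
    # One boolean per sample: does comparator(batch[left], batch[right]) hold?
    return [comparator(a, b) for a, b in zip(batch[left], batch[right])]


def satisfaction_rate(flags):
    # Fraction of samples for which the constraint holds (True counts as 1).
    return sum(flags) / len(flags)


# A toy batch of four predictions, keyed by the neuron names from the descriptor.
batch = {"O1": [0.5, -0.2, 1.0, 2.0], "O2": [1.0, 0.3, 0.8, 2.0]}

o1_positive = scalar_satisfied(batch, "O1", gt, 0)      # O1 > 0
o1_below_o2 = binary_satisfied(batch, "O1", le, "O2")   # O1 <= O2

print(satisfaction_rate(o1_positive))  # 0.75
print(satisfaction_rate(o1_below_o2))  # 0.75
```

The second sample violates O1 > 0 and the third violates O1 <= O2; these per-sample violations are what constraint-guided training tries to reduce.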

4. Adjust network

Your regular PyTorch network can be used with this toolbox. The only requirement is that your model's forward pass returns a dictionary that maps layer names to their values. The keys must match the layer names registered in the descriptor.

def forward(self, X):
    # Pass the input through the network's layers
    output = self.out(self.hidden(self.input(X)))

    # Keys must match the layer names registered in the descriptor
    return {"input": X, "output": output}
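Since a mismatch between the returned keys and the descriptor only surfaces once training starts, a quick sanity check beforehand can save time. The helper check_forward_keys below is hypothetical (it is not part of Congrads): it compares the layer names in a forward-pass dictionary with the expected layers and checks that every sample has one value per named neuron. Lists of rows stand in for batched tensors.

```python
def check_forward_keys(outputs, layer_sizes):
    """Raise ValueError if a forward-pass dictionary does not match the descriptor.

    outputs:     layer name -> batch of rows (lists stand in for tensors here)
    layer_sizes: layer name -> number of neuron names registered for that layer
    """
    missing = sorted(set(layer_sizes) - set(outputs))
    extra = sorted(set(outputs) - set(layer_sizes))
    if missing or extra:
        raise ValueError(f"missing layers: {missing}, unexpected layers: {extra}")
    for layer, size in layer_sizes.items():
        for row in outputs[layer]:
            # Each sample must have exactly one value per named neuron.
            if len(row) != size:
                raise ValueError(f"layer {layer!r}: expected {size} values, got {len(row)}")


# Layer widths taken from the descriptor example: 4 inputs, 2 outputs.
layer_sizes = {"input": 4, "output": 2}
outputs = {"input": [[0.1, 0.2, 0.3, 0.4]], "output": [[0.5, 0.6]]}
check_forward_keys(outputs, layer_sizes)  # passes silently
```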

You can then pass your own network directly to the learner.

5. Set up network and data

Next, instantiate the adjusted network and the data. At the moment, we require the data to be implemented as a LightningDataModule class.

# Data and network setup
network = YourOwnNetwork(n_inputs=4, n_outputs=2, n_hidden_layers=3, hidden_dim=10)
data = YourOwnData(batch_size=100)

6. Set up learner

You can specify your own loss function and optimizer, each with its own settings, to be used for training the model.

# Learner setup
from torch.nn import MSELoss
from torch.optim import Adam

loss_function = MSELoss()
optimizer = Adam(network.parameters(), lr=0.001)

learner = Learner(network, descriptor, constraints, loss_function, optimizer)
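The learner is where the loss, the constraints, and the optimizer come together. As a rough intuition for constraint-guided gradient descent, the one-dimensional toy below adds a correction to the loss gradient whenever the constraint O1 > 0 is violated. The correction points toward the feasible region and its size is tied to the size of the loss gradient through a rescale factor. This is a simplified illustration, not the toolbox's actual update rule: loss_gradient, guided_gradient, train, and rescale are hypothetical, and the output is treated as a directly trainable value instead of the output of a network.

```python
def loss_gradient(value, target):
    # Derivative of the squared error (value - target) ** 2 with respect to value.
    return 2.0 * (value - target)


def guided_gradient(value, target, rescale=1.5):
    """Loss gradient plus a correction when the constraint value > 0 is violated.

    The correction has magnitude rescale * |loss gradient|, so with rescale > 1
    it outweighs the pull of the loss toward an infeasible target.
    """
    grad = loss_gradient(value, target)
    if value > 0:  # constraint satisfied: plain loss gradient
        return grad
    # Gradient descent subtracts the gradient, so a negative term pushes value up.
    return grad - rescale * abs(grad)


def train(value, target, steps=200, lr=0.1, guided=True):
    # Plain gradient descent on a single value, with or without the correction.
    for _ in range(steps):
        grad = guided_gradient(value, target) if guided else loss_gradient(value, target)
        value -= lr * grad
    return value


# The regression target (-1) violates the constraint O1 > 0.
plain = train(2.0, target=-1.0, guided=False)
guided = train(2.0, target=-1.0, guided=True)
print(round(plain, 3))  # -1.0
print(guided)
```

Without guidance the value converges to the infeasible target. With guidance it is held near the constraint boundary at 0; with a fixed learning rate it keeps oscillating in a narrow band around that boundary instead of settling exactly on it. Note that in this toy the correction vanishes whenever the loss gradient itself is zero.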

7. Set up trainer

Finally, set up a trainer to start the actual training of the model.

# Trainer setup
from lightning.pytorch import Trainer  # or: from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=100)

# Train model
trainer.fit(learner, data)

Example Use Cases

  • Optimization with Domain Knowledge: Ensure outputs meet real-world restrictions or safety standards.
  • Physics-Informed Neural Networks (PINNs): Enforce physical laws as constraints in your models.
  • Improve Training Process: Inject domain knowledge in the training stage, increasing learning efficiency.

Roadmap

  • Documentation and Notebook examples
  • Add support for constraint parser that can interpret equations
  • Add better handling of metric logging and visualization
  • Evaluate whether PyTorch Lightning is preferable to plain PyTorch
  • Determine whether it is feasible to add unit and/or functional tests

Contributing

We welcome contributions to Congrads! Whether you want to report issues, suggest features, or contribute code, you can do so via issues and pull requests.

License

Congrads is licensed under the MIT License with a Commons Clause. This means you are free to use, modify, and distribute the software, but you may not sell or offer it as part of a paid service without permission. We encourage companies that are interested in a collaboration for a specific topic to contact the authors for more information.


Elevate your neural networks with Congrads! 🚀
