
A PyTorch model training protocol for environment-invariant deployment

Project description

This is a PyTorch implementation of the Inter-environmental Gradient Alignment (IGA) algorithm proposed by Koyama and Yamaguchi in their paper "Out-of-Distribution Generalization with Maximal Invariant Predictor".
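For reference, the objective that IGA minimizes, as we read it from the paper, can be written as below. Here E is the set of m training environments, R^e(θ) is the risk of the model with parameters θ in environment e, and λ is the penalty weight. We assume the population (1/m) variance; the package may normalize differently.

```latex
\mathcal{L}_{\mathrm{IGA}}(\theta)
  = \frac{1}{m}\sum_{e\in\mathcal{E}} R^{e}(\theta)
  + \lambda\,\operatorname{tr}\!\Big(\operatorname{Var}_{\mathcal{E}}\big(\nabla_\theta R^{e}(\theta)\big)\Big),
\qquad
\operatorname{tr}\!\Big(\operatorname{Var}_{\mathcal{E}}\big(\nabla_\theta R^{e}\big)\Big)
  = \frac{1}{m}\sum_{e\in\mathcal{E}} \big\lVert \nabla_\theta R^{e}(\theta) - \overline{\nabla_\theta R}(\theta) \big\rVert_2^2,
\qquad
\overline{\nabla_\theta R}(\theta) = \frac{1}{m}\sum_{e\in\mathcal{E}} \nabla_\theta R^{e}(\theta)
```

The penalty is zero only when every environment produces the same gradient, which is what "gradient alignment" refers to.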

Quick start

Install pytorch-iga in the terminal:

pip install pytorch-iga

Import IGA in python:

from iga import IGA

IGA is defined with the following parameters:

IGA(model, optimizer, criterion, data, num_epochs, batch_size, lamda, verbose=10,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))

Parameters:

- model (torch.nn.Module): neural network model to be trained or fine-tuned
- optimizer (torch.optim.Optimizer): PyTorch optimizer object, such as torch.optim.SGD
- criterion (function): loss function used to evaluate the model
- data (list of torch.utils.data.Dataset): one Dataset per environment
- num_epochs (int): number of training epochs
- batch_size (int): number of data points per batch
- lamda (float): weight of the inter-environmental gradient variance penalty
- verbose (int): number of iterations between progress log messages
- device (torch.device): optional; defaults to 'cuda' if available, otherwise 'cpu'
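To make the role of `lamda` concrete, the hypothetical helper below (not part of pytorch-iga) evaluates the IGA objective in plain Python from per-environment losses and flattened per-environment gradients. It returns the mean loss across environments plus `lamda` times the trace of the gradient variance, which equals the mean squared distance of each environment's gradient from the average gradient. It assumes the population (1/m) variance.

```python
from typing import Sequence


def iga_objective(env_losses: Sequence[float],
                  env_grads: Sequence[Sequence[float]],
                  lamda: float) -> float:
    """Hypothetical helper: mean loss + lamda * trace(Var of env gradients)."""
    m = len(env_losses)
    if m == 0 or len(env_grads) != m:
        raise ValueError("need one loss and one gradient per environment")
    dim = len(env_grads[0])
    if any(len(g) != dim for g in env_grads):
        raise ValueError("all gradients must have the same length")
    # Average gradient across environments, coordinate by coordinate.
    mean_grad = [sum(g[i] for g in env_grads) / m for i in range(dim)]
    # Trace of the population covariance: mean squared distance to mean_grad.
    penalty = sum(
        sum((g[i] - mean_grad[i]) ** 2 for i in range(dim)) for g in env_grads
    ) / m
    mean_loss = sum(env_losses) / m
    return mean_loss + lamda * penalty
```

For two environments with losses 1.0 and 3.0 and gradients [1.0, 0.0] and [-1.0, 2.0], the mean gradient is [0.0, 1.0] and each environment sits at squared distance 2.0 from it, so the penalty is 2.0 and `iga_objective([1.0, 3.0], [[1.0, 0.0], [-1.0, 2.0]], 0.5)` returns 2.0 + 0.5 * 2.0 = 3.0. Setting `lamda` to 0 reduces the objective to the plain average loss.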

Returns:

- model (torch.nn.Module): the updated (trained) model
- IGA_loss (float): the final IGA loss value at the end of training
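The sketch below is one way a training loop with these inputs and outputs could be written; it is our illustration, not the pytorch-iga source. The function name `iga_train_sketch` is hypothetical, and it assumes each Dataset yields `(x, y)` pairs, that every trainable parameter contributes to the loss, and that the environment loaders are walked in lockstep so each optimizer step sees one batch per environment.

```python
import torch
from torch.utils.data import DataLoader


def iga_train_sketch(model, optimizer, criterion, data, num_epochs, batch_size,
                     lamda, verbose=10,
                     device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Hypothetical IGA-style loop; returns (model, final loss as float)."""
    model.to(device)
    params = [p for p in model.parameters() if p.requires_grad]
    # One loader per environment; zip() below stops at the shortest one.
    loaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in data]
    loss = torch.tensor(float("nan"))
    step = 0
    for epoch in range(num_epochs):
        for batches in zip(*loaders):
            env_losses, env_grads = [], []
            for x, y in batches:
                x, y = x.to(device), y.to(device)
                env_loss = criterion(model(x), y)
                # create_graph=True keeps the graph so the penalty is differentiable.
                grads = torch.autograd.grad(env_loss, params, create_graph=True)
                env_losses.append(env_loss)
                env_grads.append(torch.cat([g.reshape(-1) for g in grads]))
            grads = torch.stack(env_grads)  # shape: (num_envs, num_params)
            # Trace of gradient variance: mean squared distance to the mean gradient.
            penalty = ((grads - grads.mean(dim=0)) ** 2).sum(dim=1).mean()
            loss = torch.stack(env_losses).mean() + lamda * penalty
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1
            # Log every `verbose` steps; verbose=0 disables logging.
            if verbose and step % verbose == 0:
                print(f"epoch {epoch} step {step} loss {loss.item():.4f}")
    return model, loss.item()
```

Walking the loaders with `zip` means an epoch ends when the smallest environment runs out of batches; an implementation could instead cycle the shorter loaders so that no data is skipped.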

Example

to be continued...
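Until the official example is filled in, here is a hedged sketch of one way to assemble the `data` argument: a list with one `TensorDataset` per environment. The helper `make_environments` and its toy data are hypothetical. Feature 0 relates to the label the same way in every environment, while feature 1 tracks the label with an environment-specific strength, the kind of unstable relationship an invariance method such as IGA is intended to discourage the model from relying on.

```python
import torch
from torch.utils.data import TensorDataset


def make_environments(spurious_strengths, n=256, seed=0):
    """Hypothetical toy data: one TensorDataset per environment.

    Feature 0 drives the label identically in every environment (invariant);
    feature 1 follows the label with an environment-specific strength (spurious).
    """
    gen = torch.Generator().manual_seed(seed)
    envs = []
    for strength in spurious_strengths:
        invariant = torch.randn(n, 1, generator=gen)
        y = invariant + 0.1 * torch.randn(n, 1, generator=gen)
        spurious = strength * y + torch.randn(n, 1, generator=gen)
        envs.append(TensorDataset(torch.cat([invariant, spurious], dim=1), y))
    return envs
```

For example, `make_environments([0.5, 2.0])` returns two such datasets, which could then be passed as `data`.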



Download files


Source distribution: pytorch-iga-0.0.3.tar.gz (4.3 kB)

Built distribution (Python 3): pytorch_iga-0.0.3-py3-none-any.whl (3.8 kB)
