A PyTorch model training protocol for environment-invariant deployment
Project description
This is a PyTorch implementation of the Inter-environmental Gradient Alignment (IGA) algorithm proposed by Koyama and Yamaguchi in their paper Out-of-Distribution Generalization with Maximal Invariant Predictor.
Quick start
Install pytorch-iga in the terminal:
pip install pytorch-iga
Import IGA in Python:
from iga import IGA
IGA is defined with the following parameters:
IGA(model, optimizer, criterion, data, num_epochs, batch_size, lamda, verbose=10, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
Parameters:
model (torch.nn.Module): neural network model to be trained/tuned
optimizer (torch.optim.Optimizer): PyTorch optimizer, such as torch.optim.SGD
criterion (callable): loss function for model evaluation
data (list of torch.utils.data.Dataset): a list of Datasets, one per environment
num_epochs (int): number of training epochs
batch_size (int): number of data points per batch
lamda (float): importance weight of the inter-environmental variance penalty
verbose (int): number of iterations between progress logs
device (torch.device): optional; defaults to 'cuda' if available, else 'cpu'
Returns:
model (torch.nn.Module): updated torch model
IGA_loss (float): ending loss value
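The lamda parameter weights the variance of per-environment gradients around their mean, which is the quantity IGA penalizes on top of the average risk. A minimal pure-Python sketch of that objective for a one-parameter linear model (the function names and the scalar-gradient setup are illustrative, not part of the pytorch-iga API):

```python
def env_risk_and_grad(w, xs, ys):
    """Mean squared error and its gradient d(risk)/dw for y_hat = w * x."""
    n = len(xs)
    risk = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
    grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n
    return risk, grad

def iga_objective(w, environments, lamda):
    """Average risk plus lamda times the inter-environment gradient variance."""
    stats = [env_risk_and_grad(w, xs, ys) for xs, ys in environments]
    risks = [r for r, _ in stats]
    grads = [g for _, g in stats]
    mean_risk = sum(risks) / len(risks)
    mean_grad = sum(grads) / len(grads)
    grad_var = sum((g - mean_grad) ** 2 for g in grads) / len(grads)
    return mean_risk + lamda * grad_var

# Two toy environments with slightly different input/label relationships.
env_a = ([1.0, 2.0], [2.0, 4.0])   # roughly y = 2x
env_b = ([1.0, 2.0], [3.0, 5.0])   # a shifted relationship
loss = iga_objective(1.0, [env_a, env_b], lamda=0.5)
```

With lamda = 0 this reduces to ordinary empirical risk averaged over environments; increasing lamda pushes the optimizer toward weights whose gradients agree across environments.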
Example
to be continued...
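Until an official example lands, a hypothetical usage sketch might look like the following; the model, datasets, and hyperparameter values are illustrative assumptions, and only the IGA signature comes from the parameter list above:

```python
import torch
from torch.utils.data import TensorDataset
from iga import IGA

# One small regression dataset per environment (synthetic placeholders).
env1 = TensorDataset(torch.randn(64, 3), torch.randn(64, 1))
env2 = TensorDataset(torch.randn(64, 3), torch.randn(64, 1))

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = torch.nn.functional.mse_loss

# Train across both environments with the variance penalty weighted by lamda.
model, IGA_loss = IGA(
    model, optimizer, criterion,
    data=[env1, env2],
    num_epochs=5,
    batch_size=16,
    lamda=1.0,
)
print(IGA_loss)
```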
Project details
Download files
Download the file for your platform.
Hashes for pytorch_iga-0.0.3-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | fa7d83670f3c9d965c6bf58f2f18d7f2fc6197e2d05814fedf097bbe8d704e39
MD5 | 4212db0cf4ffb0e8cf50a4f7bff6f2e8
BLAKE2b-256 | 1e6bde1917be8a4381353c7a1c6dcd42d5559c1ca44ddb835c6d5ec4c8f5ce71