Gradient Reversal Layer implemented with torch.library

Project description

PyTorch Gradient Reversal Layer

PyPI version License: MIT

Apparently there are no graph-capturable implementations of the Gradient Reversal Layer from the well-known DANN paper available for PyTorch. This package implements the Gradient Reversal Layer (GRL) in PyTorch using the torch.library API, making it fully compatible with torch.compile, CUDA graphs, and distributed training as of PyTorch v2.7. I am releasing this so no one else has to go through the pain and suffering of getting it right; expect limited updates for future versions.

Installation

From PyPI (Recommended)

pip install torch-gradient-reversal

From Source

git clone https://github.com/andrewbistras/torch-gradient-reversal.git
cd torch-gradient-reversal
pip install -e .

Basic Usage

import torch
from gradient_reversal import GradientReversalLayer

# Create a gradient reversal layer
grl = GradientReversalLayer(alpha=1.0)

# Use in a model
model = torch.nn.Sequential(
    torch.nn.Linear(10, 20),
    torch.nn.ReLU(),
    grl,  # Reverses gradients during backprop
    torch.nn.Linear(20, 2)
)

# Forward pass works normally
x = torch.randn(32, 10)
output = model(x)

# During backward pass, gradients are reversed and scaled by alpha
loss = output.sum()
loss.backward()

Domain Adaptation Example

import torch.nn as nn
from gradient_reversal import GradientReversalLayer

class DomainAdaptationModel(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Shared feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )
        
        # Task classifier (for main task)
        self.task_classifier = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)  # 10 classes
        )
        
        # Domain classifier (with gradient reversal)
        self.domain_classifier = nn.Sequential(
            GradientReversalLayer(alpha=1.0),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 2)  # Binary: source/target
        )
    
    def forward(self, x):
        features = self.feature_extractor(x)
        task_output = self.task_classifier(features)
        domain_output = self.domain_classifier(features)
        return task_output, domain_output

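A single training step for a model like the one above combines the task loss (on labeled source data) with the domain loss (on both domains) and backpropagates once. The sketch below is illustrative and self-contained: nn.Identity stands in for the GRL so it runs without this package installed; substitute GradientReversalLayer(alpha) in practice, and note that the batch sizes, layer widths, and data here are hypothetical.

```python
import torch
import torch.nn as nn

# Stand-in model: nn.Identity replaces GradientReversalLayer so this sketch
# runs without the package installed; swap in the real GRL in practice.
feat = nn.Sequential(nn.Linear(784, 128), nn.ReLU())
task_head = nn.Linear(128, 10)
domain_head = nn.Sequential(nn.Identity(), nn.Linear(128, 2))

params = list(feat.parameters()) + list(task_head.parameters()) + list(domain_head.parameters())
opt = torch.optim.SGD(params, lr=0.01)
criterion = nn.CrossEntropyLoss()

# Hypothetical batches: labeled source data and unlabeled target data.
src_x, src_y = torch.randn(8, 784), torch.randint(0, 10, (8,))
tgt_x = torch.randn(8, 784)

src_feat, tgt_feat = feat(src_x), feat(tgt_x)
task_loss = criterion(task_head(src_feat), src_y)  # main task on source only

# Domain labels: 0 = source, 1 = target; the domain head sees both domains.
dom_logits = domain_head(torch.cat([src_feat, tgt_feat]))
dom_labels = torch.cat([torch.zeros(8, dtype=torch.long), torch.ones(8, dtype=torch.long)])
domain_loss = criterion(dom_logits, dom_labels)

loss = task_loss + domain_loss  # one backward pass covers both objectives
opt.zero_grad()
loss.backward()
opt.step()
```

With the real GRL in place, the same backward pass that trains the domain head also pushes the feature extractor away from domain-discriminative features.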
Dynamic Alpha Scheduling

import numpy as np  # required for the exponential schedule below

# Gradually increase gradient reversal strength during training
for epoch in range(num_epochs):
    # Schedule alpha from 0 toward 1 (the lambda schedule from Ganin et al.)
    p = epoch / num_epochs
    alpha = 2 / (1 + np.exp(-10 * p)) - 1
    
    # Update the GRL alpha
    model.domain_classifier[0].set_alpha(alpha)
    
    # Training loop...
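The schedule above ramps alpha smoothly from 0 at the start of training toward 1 at the end. A quick standalone check of the endpoints (using math.exp in place of numpy, which is equivalent here):

```python
import math

def grl_alpha(p: float) -> float:
    """Ganin et al. schedule: alpha = 2 / (1 + e^(-10p)) - 1 for p in [0, 1]."""
    return 2.0 / (1.0 + math.exp(-10.0 * p)) - 1.0

print(grl_alpha(0.0))             # 0.0 at the start of training
print(round(grl_alpha(1.0), 4))   # approaches 1 near the end
```

Starting at alpha = 0 lets the feature extractor stabilize before adversarial pressure from the domain classifier kicks in.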

How It Works

The Gradient Reversal Layer reverses and scales the backprop gradient as described in "Unsupervised Domain Adaptation by Backpropagation" by Ganin et al.

  • Forward Pass: Acts as the identity function (output = input)
  • Backward Pass: Multiplies the incoming gradient by -alpha before passing it upstream

This lets a single backward pass train the domain classifier to distinguish domains while simultaneously pushing the shared feature extractor toward domain-invariant features.
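To make the semantics concrete, here is a minimal illustrative GRL built on torch.autograd.Function. Note this is not how the package is implemented (it uses the torch.library API, which is what makes it compatible with torch.compile and CUDA graphs); it is only a sketch of the forward/backward behavior described above.

```python
import torch

class _GradReverse(torch.autograd.Function):
    """Illustrative GRL: identity on the forward pass, -alpha-scaled
    gradient on the backward pass. Sketch only; the package itself
    registers a custom op via torch.library instead."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)  # identity

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient; alpha itself gets no gradient.
        return -ctx.alpha * grad_output, None

x = torch.ones(3, requires_grad=True)
y = _GradReverse.apply(x, 0.5)
y.sum().backward()  # d(sum)/dx is 1 everywhere, reversed to -alpha
print(x.grad)       # tensor([-0.5000, -0.5000, -0.5000])
```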

Using with torch.compile

import torch

# Build a model containing a GRL (create_model_with_grl is a user-defined constructor)
model = create_model_with_grl()

# Compile the model for faster execution
compiled_model = torch.compile(model)

# Use as normal
output = compiled_model(input)

Distributed Training

# Works seamlessly with DDP
model = DomainAdaptationModel()
model = torch.nn.parallel.DistributedDataParallel(model)

API Reference

GradientReversalLayer

class GradientReversalLayer(nn.Module):
    """Gradient Reversal Layer for domain adaptation.
    
    Args:
        alpha (float): Gradient scaling factor. Default: 1.0
    """
    
    def forward(self, x: Tensor) -> Tensor:
        """Forward pass (identity function)."""
        
    def set_alpha(self, alpha: float) -> None:
        """Update gradient scaling factor."""

Functional Interface

def gradient_reversal(x: Tensor, alpha: float = 1.0) -> Tensor:
    """Functional gradient reversal operation."""

License

This project is licensed under the MIT License; see the LICENSE file for details.


Download files

Download the file for your platform.

Source Distribution

torch_gradient_reversal-1.0.0.tar.gz (5.5 kB view details)

Uploaded Source

Built Distribution


torch_gradient_reversal-1.0.0-py3-none-any.whl (6.3 kB view details)

Uploaded Python 3

File details

Details for the file torch_gradient_reversal-1.0.0.tar.gz.

File metadata

  • Download URL: torch_gradient_reversal-1.0.0.tar.gz
  • Upload date:
  • Size: 5.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.0

File hashes

Hashes for torch_gradient_reversal-1.0.0.tar.gz
  • SHA256: 8a35b7845190ee92070692bfcee3d6a3d3f36dfe28f43feff2a638264ac489ae
  • MD5: 721726e41850b20818af04899251014a
  • BLAKE2b-256: 2850238d7a3c452f41933a1f7914a04246fa5da569790a45f1e73ac131d21ec4


File details

Details for the file torch_gradient_reversal-1.0.0-py3-none-any.whl.

File metadata

File hashes

Hashes for torch_gradient_reversal-1.0.0-py3-none-any.whl
  • SHA256: 13221c2c594756cc7af128a57a984068e8fd7cdec5efa8ac400149f47adb01bf
  • MD5: 943685a7f75da5689e336cc0bce71ced
  • BLAKE2b-256: f5042348b1841004665d30e6ca8c8fe3b4a0b0014c273b302fa55760fe407953

