Gradient Reversal Layer implemented with torch.library

Project description

PyTorch Gradient Reversal Layer

Apparently there are no graph-safe implementations of the Gradient Reversal Layer from the well-known DANN paper available for PyTorch. This package implements the Gradient Reversal Layer (GRL) using the torch.library API, and it is fully compatible with torch.compile, CUDA graphs, and distributed training as of PyTorch v2.7. I am releasing it so that no one else has to go through the pain and suffering of getting this right; expect limited updates for future versions.

Installation

From PyPI (Recommended)

pip install torch-gradient-reversal

From Source

git clone https://github.com/yourusername/torch-gradient-reversal.git
cd torch-gradient-reversal
pip install -e .

Basic Usage

import torch
from gradient_reversal import GradientReversalLayer

# Create a gradient reversal layer
grl = GradientReversalLayer(alpha=1.0)

# Use in a model
model = torch.nn.Sequential(
    torch.nn.Linear(10, 20),
    torch.nn.ReLU(),
    grl,  # Reverses gradients during backprop
    torch.nn.Linear(20, 2)
)

# Forward pass works normally
x = torch.randn(32, 10)
output = model(x)

# During backward pass, gradients are reversed and scaled by alpha
loss = output.sum()
loss.backward()

Domain Adaptation Example

import torch.nn as nn
from gradient_reversal import GradientReversalLayer

class DomainAdaptationModel(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Shared feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )
        
        # Task classifier (for main task)
        self.task_classifier = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)  # 10 classes
        )
        
        # Domain classifier (with gradient reversal)
        self.domain_classifier = nn.Sequential(
            GradientReversalLayer(alpha=1.0),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 2)  # Binary: source/target
        )
    
    def forward(self, x):
        features = self.feature_extractor(x)
        task_output = self.task_classifier(features)
        domain_output = self.domain_classifier(features)
        return task_output, domain_output
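For reference, here is a minimal sketch of one training step that combines the task loss and the domain loss. To keep the snippet self-contained, a stand-in autograd function with the same identity-forward / negated-backward semantics is used in place of the package's GRL, and the optimizer, loss weighting, batch sizes, and data below are illustrative assumptions:

```python
import torch
import torch.nn as nn

# Stand-in for the package's GRL so this sketch runs on its own:
# identity on forward, gradient multiplied by -alpha on backward.
class _GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.alpha * grad_output, None

feature_extractor = nn.Sequential(nn.Linear(784, 128), nn.ReLU())
task_head = nn.Linear(128, 10)
domain_head = nn.Linear(128, 2)

params = (list(feature_extractor.parameters())
          + list(task_head.parameters())
          + list(domain_head.parameters()))
opt = torch.optim.SGD(params, lr=0.01)

# Illustrative batch: labeled source samples, unlabeled target samples.
src_x, src_y = torch.randn(8, 784), torch.randint(0, 10, (8,))
tgt_x = torch.randn(8, 784)

feats_src = feature_extractor(src_x)
feats_tgt = feature_extractor(tgt_x)

# Task loss is computed on source samples only.
task_loss = nn.functional.cross_entropy(task_head(feats_src), src_y)

# Domain labels: 0 = source, 1 = target.
feats_all = torch.cat([feats_src, feats_tgt])
dom_y = torch.cat([torch.zeros(8, dtype=torch.long),
                   torch.ones(8, dtype=torch.long)])
dom_logits = domain_head(_GradReverse.apply(feats_all, 1.0))
domain_loss = nn.functional.cross_entropy(dom_logits, dom_y)

opt.zero_grad()
(task_loss + domain_loss).backward()
opt.step()
```

Because of the reversal, the domain loss pushes the domain classifier to distinguish source from target while pushing the feature extractor in the opposite direction.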

Dynamic Alpha Scheduling

import numpy as np

# Gradually increase gradient reversal strength during training
for epoch in range(num_epochs):
    # Schedule alpha from 0 to 1 (sigmoid ramp from the DANN paper)
    p = epoch / num_epochs
    alpha = 2 / (1 + np.exp(-10 * p)) - 1
    
    # Update the GRL alpha
    model.domain_classifier[0].set_alpha(alpha)
    
    # Training loop...
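The sigmoid ramp above starts at exactly 0 and saturates near 1 as training progresses, which delays the adversarial signal until the feature extractor has learned something useful. A small standalone check of the schedule (using `math` in place of `numpy` so it needs no extra dependencies):

```python
import math

def alpha_schedule(p):
    # p in [0, 1]: fraction of training completed
    return 2 / (1 + math.exp(-10 * p)) - 1

print(alpha_schedule(0.0))  # 0.0
print(alpha_schedule(0.5))  # ~0.987
print(alpha_schedule(1.0))  # ~1.0
```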

How It Works

The Gradient Reversal Layer reverses and scales the backprop gradient as described in "Unsupervised Domain Adaptation by Backpropagation" by Ganin et al.

  • Forward Pass: Acts as an identity function (output = input)
  • Backward Pass: Multiplies the incoming gradient by -alpha (reversing and scaling it)

This allows the feature extractor to be trained adversarially against the domain classifier in a single backward pass, pushing the shared features toward domain invariance while they still serve the main task.
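The same semantics can be expressed as a plain torch.autograd.Function. The package uses torch.library instead so that the op stays graphable under torch.compile and CUDA graphs; this sketch only illustrates the forward/backward behavior:

```python
import torch

class GradReverseFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)  # identity on the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        # Incoming gradient is multiplied by -alpha; None for alpha itself.
        return -ctx.alpha * grad_output, None

x = torch.ones(3, requires_grad=True)
y = GradReverseFn.apply(x, 0.5)
assert torch.equal(y, x)      # forward is the identity
y.sum().backward()
print(x.grad)                 # each entry is -0.5, not +1.0
```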

Using with torch.compile

import torch

# Model with GRL
model = create_model_with_grl()

# Compile the model for faster execution
compiled_model = torch.compile(model)

# Use as normal
output = compiled_model(input)

Distributed Training

# Works seamlessly with DDP
model = DomainAdaptationModel()
model = torch.nn.parallel.DistributedDataParallel(model)

API Reference

GradientReversalLayer

class GradientReversalLayer(nn.Module):
    """Gradient Reversal Layer for domain adaptation.
    
    Args:
        alpha (float): Gradient scaling factor. Default: 1.0
    """
    
    def forward(self, x: Tensor) -> Tensor:
        """Forward pass (identity function)."""
        
    def set_alpha(self, alpha: float) -> None:
        """Update gradient scaling factor."""

Functional Interface

def gradient_reversal(x: Tensor, alpha: float = 1.0) -> Tensor:
    """Functional gradient reversal operation."""

License

See the LICENSE file for details.

Download files

Download the file for your platform.

Source Distribution

torch_gradient_reversal-0.1.0.tar.gz (5.5 kB)

Uploaded Source

Built Distribution

torch_gradient_reversal-0.1.0-py3-none-any.whl (6.2 kB)

Uploaded Python 3

File details

Details for the file torch_gradient_reversal-0.1.0.tar.gz.

File metadata

  • Download URL: torch_gradient_reversal-0.1.0.tar.gz
  • Upload date:
  • Size: 5.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.12.0

File hashes

Hashes for torch_gradient_reversal-0.1.0.tar.gz
Algorithm Hash digest
SHA256 ec3b723416f207994eabcbafc2abedf00f46d26b4f1e446fca6ac943103777af
MD5 d7263beb1310999a616d933a86127e03
BLAKE2b-256 d66a009a7ca1961dfef6b28fe25b56cffd49f252d3a9f4410f1a3af4d29f9d49

File details

Details for the file torch_gradient_reversal-0.1.0-py3-none-any.whl.

File metadata

File hashes

Hashes for torch_gradient_reversal-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 845b5fbc4e1fca0980ed4b9696c298fd9d0f4fdc793b2a6d3ca9b2d34004eb35
MD5 0821e9da3a365573d789de39a2552f8d
BLAKE2b-256 842fd0bac5fae25dd90a199122364d89a891642b38b62e71592a15466894877b
