Gradient Reversal Layer implemented with torch.library
PyTorch Gradient Reversal Layer
Apparently there are no graph-capturable implementations of the Gradient Reversal Layer from the well-known DANN paper in PyTorch. This package implements the Gradient Reversal Layer (GRL) using the torch.library API. It is compatible with torch.compile, CUDA graphs, and distributed training as of Torch v2.7. I am releasing this so that no one else has to go through the pain of getting this right; expect limited updates for future versions.
Installation
From PyPI (Recommended)
pip install torch-gradient-reversal
From Source
git clone https://github.com/yourusername/torch-gradient-reversal.git
cd torch-gradient-reversal
pip install -e .
Basic Usage
import torch
from gradient_reversal import GradientReversalLayer
# Create a gradient reversal layer
grl = GradientReversalLayer(alpha=1.0)
# Use in a model
model = torch.nn.Sequential(
    torch.nn.Linear(10, 20),
    torch.nn.ReLU(),
    grl,  # Reverses gradients during backprop
    torch.nn.Linear(20, 2),
)
# Forward pass works normally
x = torch.randn(32, 10)
output = model(x)
# During backward pass, gradients are reversed and scaled by alpha
loss = output.sum()
loss.backward()
Domain Adaptation Example
import torch.nn as nn
from gradient_reversal import GradientReversalLayer
class DomainAdaptationModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
        )
        # Task classifier (for main task)
        self.task_classifier = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10),  # 10 classes
        )
        # Domain classifier (with gradient reversal)
        self.domain_classifier = nn.Sequential(
            GradientReversalLayer(alpha=1.0),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 2),  # Binary: source/target
        )

    def forward(self, x):
        features = self.feature_extractor(x)
        task_output = self.task_classifier(features)
        domain_output = self.domain_classifier(features)
        return task_output, domain_output
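The model above returns both heads' outputs but does not show how the two losses are combined. The following is a minimal sketch of one training step in the usual DANN arrangement: task loss on labeled source samples, domain loss on source and target samples together. To keep it runnable without the package installed, `reverse_grad` is a hypothetical stand-in for the layer, and the smaller heads, batch size, learning rate, and random data are all made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def reverse_grad(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    """Hypothetical stand-in for the GRL (not the package's code).

    The value equals x (up to rounding); the gradient w.r.t. x is -alpha
    times the upstream gradient, because only the last term carries grad.
    """
    return (1.0 + alpha) * x.detach() - alpha * x


torch.manual_seed(0)
feature_extractor = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 128))
task_classifier = nn.Linear(128, 10)    # simplified task head
domain_classifier = nn.Linear(128, 2)   # simplified domain head

params = [
    *feature_extractor.parameters(),
    *task_classifier.parameters(),
    *domain_classifier.parameters(),
]
optimizer = torch.optim.SGD(params, lr=0.01)

# Made-up batches: labeled source data, unlabeled target data.
batch = 16
source_x = torch.randn(batch, 784)
source_y = torch.randint(0, 10, (batch,))
target_x = torch.randn(batch, 784)

# Domain labels: 0 = source, 1 = target.
x = torch.cat([source_x, target_x])
domain_y = torch.cat([
    torch.zeros(batch, dtype=torch.long),
    torch.ones(batch, dtype=torch.long),
])

features = feature_extractor(x)
# Task loss uses only the labeled source half of the batch.
task_loss = F.cross_entropy(task_classifier(features[:batch]), source_y)
# Domain loss uses every sample; the reversal sits between features and head.
domain_logits = domain_classifier(reverse_grad(features, alpha=1.0))
domain_loss = F.cross_entropy(domain_logits, domain_y)

optimizer.zero_grad()
(task_loss + domain_loss).backward()
optimizer.step()
```

A single backward call updates the domain head to classify domains better while pushing the feature extractor in the opposite direction, so no alternating optimizer steps are needed.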
Dynamic Alpha Scheduling
import numpy as np

# Gradually increase gradient reversal strength during training
# (num_epochs and model are assumed to be defined by your training script)
for epoch in range(num_epochs):
    # Schedule alpha from 0 towards 1
    p = epoch / num_epochs
    alpha = 2 / (1 + np.exp(-10 * p)) - 1
    # Update the GRL alpha
    model.domain_classifier[0].set_alpha(alpha)
    # Training loop...
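To see what this schedule produces, here is a small standalone sketch using only the standard library. The function name `dann_alpha` and the parameter name `gamma` are illustrative, not part of the package; `gamma=10` matches the constant in the loop above.

```python
import math


def dann_alpha(progress: float, gamma: float = 10.0) -> float:
    """Sigmoid-shaped ramp from 0 towards 1 for progress in [0, 1].

    Algebraically this equals tanh(gamma * progress / 2).
    """
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0


num_epochs = 5  # example value
schedule = [dann_alpha(epoch / num_epochs) for epoch in range(num_epochs)]

for epoch, alpha in enumerate(schedule):
    print(f"epoch {epoch}: alpha = {alpha:.4f}")
```

Because `epoch / num_epochs` never reaches 1 inside the loop, alpha approaches but does not hit its final value; with few epochs, most of the ramp happens in the first one or two.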
How It Works
The Gradient Reversal Layer reverses and scales the backprop gradient as described in "Unsupervised Domain Adaptation by Backpropagation" by Ganin et al.
- Forward Pass: Acts as an identity function (output = input)
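- Backward Pass: Multiplies the incoming gradient by -alpha (sign reversal plus scaling)
Because the reversal happens inside a single backward pass, the feature extractor and the domain classifier can be trained adversarially at the same time, on the same shared features, without alternating optimizer steps.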
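The two rules above can be illustrated with a classic `torch.autograd.Function`. This is a sketch of the behavior only: it is not this package's torch.library-based implementation, and the class name `ReverseGrad` is made up for the example.

```python
import torch


class ReverseGrad(torch.autograd.Function):
    """Illustrative GRL: identity in forward, -alpha * grad in backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha      # plain float, stored on the context
        return x.view_as(x)    # same values as x

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input; alpha gets no gradient.
        return -ctx.alpha * grad_output, None


x = torch.ones(3, requires_grad=True)
y = ReverseGrad.apply(x, 0.5)

# Without the layer, d/dx of (2 * x).sum() would be +2 for every element.
(2.0 * y).sum().backward()
print(x.grad)  # every element is -0.5 * 2.0 = -1.0
```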
Using with torch.compile
import torch
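# Model with GRL (create_model_with_grl is a placeholder for your own model factory)
model = create_model_with_grl()
# Compile the model for faster execution
compiled_model = torch.compile(model)
# Use as normal (inputs is a batch tensor matching your model's input shape)
output = compiled_model(inputs)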
Distributed Training
# Works with DDP (the process group must already be initialized, e.g. when launched via torchrun)
model = DomainAdaptationModel()
model = torch.nn.parallel.DistributedDataParallel(model)
API Reference
GradientReversalLayer
class GradientReversalLayer(nn.Module):
    """Gradient Reversal Layer for domain adaptation.

    Args:
        alpha (float): Gradient scaling factor. Default: 1.0
    """

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass (identity function)."""

    def set_alpha(self, alpha: float) -> None:
        """Update gradient scaling factor."""
Functional Interface
def gradient_reversal(x: Tensor, alpha: float = 1.0) -> Tensor:
    """Functional gradient reversal operation."""
License
See the LICENSE file for details.
File details
Details for the file torch_gradient_reversal-0.1.0.tar.gz.
File metadata
- Download URL: torch_gradient_reversal-0.1.0.tar.gz
- Upload date:
- Size: 5.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | ec3b723416f207994eabcbafc2abedf00f46d26b4f1e446fca6ac943103777af |
| MD5 | d7263beb1310999a616d933a86127e03 |
| BLAKE2b-256 | d66a009a7ca1961dfef6b28fe25b56cffd49f252d3a9f4410f1a3af4d29f9d49 |
File details
Details for the file torch_gradient_reversal-0.1.0-py3-none-any.whl.
File metadata
- Download URL: torch_gradient_reversal-0.1.0-py3-none-any.whl
- Upload date:
- Size: 6.2 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.1.0 CPython/3.12.0
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 845b5fbc4e1fca0980ed4b9696c298fd9d0f4fdc793b2a6d3ca9b2d34004eb35 |
| MD5 | 0821e9da3a365573d789de39a2552f8d |
| BLAKE2b-256 | 842fd0bac5fae25dd90a199122364d89a891642b38b62e71592a15466894877b |