
HAT (Hard Attention to the Task) Modules for Continual Learning


Redesigned Hard-Attention-to-the-Task for Continual Learning

HAT-CL is a redesign of the Hard-Attention-to-the-Task (HAT) mechanism, which was developed to combat catastrophic forgetting in Continual Learning (CL). Originally proposed in the paper "Overcoming catastrophic forgetting with hard attention to the task", HAT lets a neural network learn successive tasks without erasing prior knowledge. The original implementation has drawbacks, however, notably its incompatibility with PyTorch's optimizers and its need for manual gradient manipulation. HAT-CL addresses these issues with a user-friendly design and several new features:

  • Seamless compatibility with all PyTorch operations and optimizers.
  • Automated gradient manipulation through PyTorch hooks.
  • Simple transformation of PyTorch modules to HAT modules with a single line of code.
  • Out-of-the-box HAT networks integrated with timm.

Link to the paper: HAT-CL: A Hard-Attention-to-the-Task PyTorch Library for Continual Learning

Quick Start


To install via pip:

pip install hat-cl

Or, if you are using poetry:

poetry add hat-cl

Basic Usage

To use HAT modules, swap generic PyTorch modules for their HAT counterparts; for instance, replace torch.nn.Linear with hat.modules.HATLinear (more examples in Modules). HAT modules take hat.HATPayload instances as input and output. A payload carries the tensor, the task ID, and the other variables the HAT mechanism needs.

Here's a simple 2-layer MLP example:

import torch
import torch.nn as nn
from hat import HATPayload, HATConfig
from hat.modules import HATLinear

hat_config = HATConfig(num_tasks=5)

class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        # nn.Module must be initialized before submodules are assigned
        super().__init__()
        self.linear1 = HATLinear(input_dim, hidden_dim, hat_config)
        self.relu = nn.ReLU()
        self.linear2 = HATLinear(hidden_dim, output_dim, hat_config)

    def forward(self, x: HATPayload):
        x = self.linear1(x)
        # You can still pass the payload to the non-HAT modules like this
        x = x.forward_by(self.relu)
        x = self.linear2(x)
        return x


mlp = MLP(input_dim=128, hidden_dim=32, output_dim=2)

input_payload = HATPayload(torch.rand(10, 128), task_id=0, mask_scale=10.0)
output_payload = mlp(input_payload)
output_data =

With these steps, you've created a 2-layer MLP with the HAT mechanism and run a forward pass through it. Like any other PyTorch module, it is ready to be trained and evaluated; the HAT modules handle the mechanism's internal operations under the hood.
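For intuition about what task_id and mask_scale control, here is a minimal sketch of the gating idea described in the original HAT paper, written in plain PyTorch. It is not HAT-CL's implementation: the class GatedLinearSketch and its attribute names are hypothetical, chosen only for illustration, and the sketch omits the gradient masking and mask regularization that the library takes care of.

```python
import torch
import torch.nn as nn


class GatedLinearSketch(nn.Module):
    """Illustrative only: a linear layer whose outputs are gated per task."""

    def __init__(self, in_features, out_features, num_tasks):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # One learnable embedding per task, one entry per output unit
        self.task_embedding = nn.Embedding(num_tasks, out_features)

    def mask(self, task_id, mask_scale):
        # Sigmoid of the scaled task embedding; a larger mask_scale pushes
        # the values toward 0 or 1, i.e. toward a "hard" attention mask
        return torch.sigmoid(mask_scale * self.task_embedding.weight[task_id])

    def forward(self, x, task_id, mask_scale):
        # Each output unit is multiplied by its task-specific gate
        return self.linear(x) * self.mask(task_id, mask_scale)


layer = GatedLinearSketch(in_features=128, out_features=32, num_tasks=5)
out = layer(torch.rand(10, 128), task_id=0, mask_scale=10.0)
```

In this sketch, units whose gate is close to 0 for a task contribute almost nothing to that task's output, which is what allows different tasks to occupy different parts of the same layer.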

Additionally, HAT-CL provides ready-to-use HAT networks with timm integration. Creating a HAT model is as simple as creating any other timm model:

import timm
import hat.timm_models  # This line is necessary to register the HAT models to timm
from hat import HATConfig

hat_config = HATConfig(num_tasks=5)
hat_resnet18 = timm.create_model('hat_resnet18', hat_config=hat_config)


Modules

Here's a handy table of PyTorch modules and their HAT counterparts:

PyTorch module          HAT module
torch.nn.Linear         hat.modules.HATLinear
torch.nn.Conv1d         hat.modules.HATConv1d
torch.nn.Conv2d         hat.modules.HATConv2d
torch.nn.Conv3d         hat.modules.HATConv3d
torch.nn.BatchNorm1d    hat.modules.TaskIndexedBatchNorm1d
torch.nn.BatchNorm2d    hat.modules.TaskIndexedBatchNorm2d
torch.nn.BatchNorm3d    hat.modules.TaskIndexedBatchNorm3d
torch.nn.LayerNorm      hat.modules.TaskIndexedLayerNorm
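When migrating an existing model, it can help to list which of its layers have a HAT counterpart in the table above. The following is a small sketch that needs only PyTorch; the mapping is copied from the table as plain strings, and the names HAT_COUNTERPARTS and report_hat_counterparts are hypothetical helpers, not part of HAT-CL.

```python
import torch.nn as nn

# Mapping copied from the table above; values are names only, so this
# sketch runs without HAT-CL installed
HAT_COUNTERPARTS = {
    nn.Linear: "hat.modules.HATLinear",
    nn.Conv1d: "hat.modules.HATConv1d",
    nn.Conv2d: "hat.modules.HATConv2d",
    nn.Conv3d: "hat.modules.HATConv3d",
    nn.BatchNorm1d: "hat.modules.TaskIndexedBatchNorm1d",
    nn.BatchNorm2d: "hat.modules.TaskIndexedBatchNorm2d",
    nn.BatchNorm3d: "hat.modules.TaskIndexedBatchNorm3d",
    nn.LayerNorm: "hat.modules.TaskIndexedLayerNorm",
}


def report_hat_counterparts(model):
    """Return {submodule name: HAT counterpart name} for listed layer types."""
    report = {}
    for name, module in model.named_modules():
        # Exact type match, so subclasses are deliberately not reported
        target = HAT_COUNTERPARTS.get(type(module))
        if target is not None:
            report[name] = target
    return report


model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8, 2),
)
report = report_hat_counterparts(model)
```

Layers that do not appear in the report (here nn.ReLU and nn.Flatten) have no entry in the table; the MLP example shows how a payload can still be passed through such modules with forward_by.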


Here are the currently available timm-compatible HAT networks:

HAT network name            Pretrained weights    Description
hat_resnet18                Yes                   HAT ResNet-18
hat_resnet18s               No                    HAT ResNet-18 for smaller images
hat_resnet34                Yes                   HAT ResNet-34
hat_resnet34s               No                    HAT ResNet-34 for smaller images
hat_vit_tiny_patch16_224    Yes                   HAT ViT-Tiny (16, 224)



HAT-CL, while designed for broad compatibility with PyTorch, faces some constraints due to the inherent characteristics of the HAT mechanism:

  • Optimizer Re-initialization: We recommend resetting the optimizer state after each task so that momentum accumulated on prior tasks does not carry over. The simplest way is to re-initialize the optimizer.
  • Weight Decay (L2 Regularization): Weight decay is incompatible with HAT. It alters the parameter update itself, so it can change parameters that the HAT mechanism is meant to block and thereby cause forgetting. This applies to the weight_decay optimizer parameter and to any optimizer that applies weight decay, such as AdamW.
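Both constraints can be reproduced with plain PyTorch. The sketch below uses a single stand-in parameter whose gradient is set to zero, as an assumption standing in for a parameter that HAT has blocked; it does not use HAT-CL itself, and the helper step_with_zero_grad is a hypothetical name.

```python
import torch


def step_with_zero_grad(optimizer, param):
    """Take one optimizer step with an all-zero gradient; return the drift."""
    param.grad = torch.zeros_like(param)
    before = param.detach().clone()
    optimizer.step()
    return (param.detach() - before).abs().max().item()


# Stand-in for a parameter that should stay fixed on later tasks
w = torch.nn.Parameter(torch.ones(4))

# "Task 1": one ordinary update fills the momentum buffer
opt = torch.optim.SGD([w], lr=0.1, momentum=0.9)
w.grad = torch.ones_like(w)
opt.step()

# "Task 2" with the stale optimizer: the gradient is zero, yet the
# leftover momentum still moves the parameter
drift_stale = step_with_zero_grad(opt, w)

# "Task 2" with a re-initialized optimizer: no momentum, no movement
opt = torch.optim.SGD([w], lr=0.1, momentum=0.9)
drift_fresh = step_with_zero_grad(opt, w)

# Weight decay moves the parameter even with a fresh optimizer and a
# zero gradient, because the decay term is added inside optimizer.step()
opt_wd = torch.optim.SGD([w], lr=0.1, weight_decay=0.01)
drift_wd = step_with_zero_grad(opt_wd, w)
```

Only the re-initialized optimizer without weight decay leaves the parameter untouched, which matches the two recommendations above.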


TODO

  • Add example notebook for pruning
  • Package paper for implementation details
  • Add CLOM notebook example
  • Link PyPI package to GitHub repo


If you use HAT-CL in your research, please cite:

    title={HAT-CL: A Hard-Attention-to-the-Task PyTorch Library for Continual Learning}, 
    author={Xiaotian Duan},


Xiaotian Duan (xduan7 at gmail dot com)
