
HAT (Hard Attention to the Task) Modules for Continual Learning

Project description

HAT-CL

Redesigned Hard-Attention-to-the-Task for Continual Learning

HAT-CL is a comprehensive reimagining of the Hard-Attention-to-the-Task (HAT) mechanism, designed specifically to combat catastrophic forgetting during Continual Learning (CL). Originally proposed in the paper Overcoming catastrophic forgetting with hard attention to the task, HAT has been instrumental in enabling neural networks to learn successive tasks without erasure of prior knowledge. However, the original implementation had its drawbacks, notably incompatibility with PyTorch's optimizers and the requirement for manual gradient manipulation. HAT-CL aims to rectify these issues with a user-friendly design and a host of new features:

  • Seamless compatibility with all PyTorch operations and optimizers.
  • Automated gradient manipulation through PyTorch hooks.
  • Simple transformation of PyTorch modules to HAT modules with a single line of code.
  • Out-of-the-box HAT networks integrated with timm.

Link to the paper: HAT-CL: A Hard-Attention-to-the-Task PyTorch Library for Continual Learning



Quick Start

Installation

To install via pip:

pip install hat-cl

Or, if you are using poetry:

poetry add hat-cl

Basic Usage

To use HAT modules, swap generic PyTorch modules for their HAT counterparts (for instance, replace torch.nn.Linear with hat.modules.HATLinear; see Modules for more examples). HAT modules take and return hat.HATPayload instances, which carry the tensor, the task ID, and the other variables essential to the HAT mechanism.

Here's a simple 2-layer MLP example:

import torch
import torch.nn as nn
from hat import HATPayload, HATConfig
from hat.modules import HATLinear


# A single HATConfig is shared by every HAT module in the network.
hat_config = HATConfig(num_tasks=5)


class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.linear1 = HATLinear(input_dim, hidden_dim, hat_config)
        self.relu = nn.ReLU()
        self.linear2 = HATLinear(hidden_dim, output_dim, hat_config)

    def forward(self, x: HATPayload):
        x = self.linear1(x)
        # You can still pass the payload to non-HAT modules like this
        x = x.forward_by(self.relu)
        x = self.linear2(x)
        return x


mlp = MLP(input_dim=128, hidden_dim=32, output_dim=2)

# Wrap the input tensor in a payload together with the task ID.
input_payload = HATPayload(torch.rand(10, 128), task_id=0, mask_scale=10.0)
output_payload = mlp(input_payload)
output_data = output_payload.data  # plain tensor output

With these steps, you've created a 2-layer MLP with the HAT mechanism and successfully conducted a forward pass through the model. Just like any other PyTorch modules, it's ready to be trained, evaluated, and more—all under-the-hood operations are handled by the HAT modules.
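
Below is a minimal single-task training sketch for the MLP above, assuming random placeholder data and labels and a plain cross-entropy objective; HAT-specific details such as mask-scale annealing and the attention regularization described in the paper are omitted:

import torch
import torch.nn.functional as F

# Placeholder hyperparameters; the optimizer is created fresh for this task.
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)
task_id = 0

for step in range(100):
    inputs = torch.rand(10, 128)          # placeholder batch
    targets = torch.randint(0, 2, (10,))  # placeholder labels

    payload = HATPayload(inputs, task_id=task_id, mask_scale=10.0)
    logits = mlp(payload).data            # unwrap the output tensor

    loss = F.cross_entropy(logits, targets)
    optimizer.zero_grad()
    loss.backward()                       # HAT's hooks adjust gradients here
    optimizer.step()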

Additionally, HAT-CL provides ready-to-use HAT networks with timm integration. Creating a HAT model is as simple as creating any other timm model:

import timm
import hat.timm_models  # This line is necessary to register the HAT models to timm
from hat import HATConfig

hat_config = HATConfig(num_tasks=5)
hat_resnet18 = timm.create_model('hat_resnet18', hat_config=hat_config)
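
Assuming the timm-integrated networks follow the same payload convention as the individual HAT modules above (an assumption, not something stated here), a forward pass would look like:

import torch
from hat import HATPayload

# Assumption: timm-created HAT networks accept HATPayload inputs, like HATLinear does.
payload = HATPayload(torch.rand(1, 3, 224, 224), task_id=0, mask_scale=10.0)
logits = hat_resnet18(payload).data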

Modules

Here's a handy table of PyTorch modules and their HAT counterparts:

PyTorch module          HAT module
torch.nn.Linear         hat.modules.HATLinear
torch.nn.Conv1d         hat.modules.HATConv1d
torch.nn.Conv2d         hat.modules.HATConv2d
torch.nn.Conv3d         hat.modules.HATConv3d
torch.nn.BatchNorm1d    hat.modules.TaskIndexedBatchNorm1d
torch.nn.BatchNorm2d    hat.modules.TaskIndexedBatchNorm2d
torch.nn.BatchNorm3d    hat.modules.TaskIndexedBatchNorm3d
torch.nn.LayerNorm      hat.modules.TaskIndexedLayerNorm
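
As a sketch of how these fit together, the block below mirrors a conventional Conv2d + BatchNorm2d + ReLU stack; it assumes the HAT and task-indexed modules take the usual PyTorch constructor arguments plus the shared HATConfig, in the same way HATLinear does above:

import torch.nn as nn
from hat import HATConfig, HATPayload
from hat.modules import HATConv2d, TaskIndexedBatchNorm2d

hat_config = HATConfig(num_tasks=5)


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Assumed signatures: PyTorch arguments followed by the shared HATConfig.
        self.conv = HATConv2d(in_channels, out_channels, kernel_size=3,
                              padding=1, hat_config=hat_config)
        self.bn = TaskIndexedBatchNorm2d(out_channels, hat_config=hat_config)
        self.act = nn.ReLU()

    def forward(self, x: HATPayload):
        x = self.conv(x)
        x = self.bn(x)
        # Non-HAT modules are applied through the payload, as in the MLP example.
        return x.forward_by(self.act)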

Networks

Here are the currently available timm-compatible HAT networks:

HAT network name            Pretrained weights    Description
hat_resnet18                Yes                   HAT ResNet-18
hat_resnet18s               No                    HAT ResNet-18 for smaller images
hat_resnet34                Yes                   HAT ResNet-34
hat_resnet34s               No                    HAT ResNet-34 for smaller images
hat_vit_tiny_patch16_224    Yes                   HAT ViT-Tiny (16, 224)


Limitations

HAT-CL, while designed for broad compatibility with PyTorch, faces some constraints due to the inherent characteristics of the HAT mechanism:

  • Optimizer Re-initialization: Re-initialize the optimizer after each task so that momentum (and other optimizer state) accumulated on one task does not carry over to the next (see the sketch after this list).
  • Weight Decay (L2 Regularization): Weight decay is incompatible with HAT: its gradient-altering mechanism cannot stop weight decay from updating parameters that are meant to be blocked, which can cause forgetting. Avoid the weight_decay optimizer parameter and any optimizer that applies weight decay by default, such as AdamW.
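
A minimal sketch of both recommendations in a task loop; the training routine and hyperparameters are placeholders:

import torch

num_tasks = 5
for task_id in range(num_tasks):
    # Re-initialize the optimizer at the start of every task so momentum
    # accumulated on earlier tasks does not leak into masked parameters.
    # weight_decay stays at 0.0 and AdamW is avoided, since weight decay
    # bypasses HAT's gradient blocking and can cause forgetting.
    optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3, weight_decay=0.0)

    train_one_task(mlp, optimizer, task_id)  # hypothetical training routine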

TODO

  • Add example notebook for pruning
  • Package paper for implementation details
  • Add CLOM notebook example
  • Link PyPI package to GitHub repo

Citation

If you use HAT-CL in your research, please cite:

@misc{duan2023hatcl,
    title={HAT-CL: A Hard-Attention-to-the-Task PyTorch Library for Continual Learning}, 
    author={Xiaotian Duan},
    year={2023},
    eprint={2307.09653},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

Authors

Xiaotian Duan (xduan7 at gmail dot com)


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

hat_cl-0.1.1.tar.gz (44.4 kB)

Uploaded Source

Built Distribution

hat_cl-0.1.1-py3-none-any.whl (59.1 kB)

Uploaded Python 3

File details

Details for the file hat_cl-0.1.1.tar.gz.

File metadata

  • Download URL: hat_cl-0.1.1.tar.gz
  • Upload date:
  • Size: 44.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.4.0 CPython/3.9.16 Darwin/23.4.0

File hashes

Hashes for hat_cl-0.1.1.tar.gz
Algorithm      Hash digest
SHA256         ace91554ef88f94c2c1dd099db922709dc09ac2982c4f2b511b93796a14b8d1e
MD5            49b4b95cf3cfaef75e25c28902809060
BLAKE2b-256    50d776c144a78906699672c49c011c2a46d4006d3754c09047b6e08c901bc09b

See more details on using hashes here.

File details

Details for the file hat_cl-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: hat_cl-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 59.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.4.0 CPython/3.9.16 Darwin/23.4.0

File hashes

Hashes for hat_cl-0.1.1-py3-none-any.whl
Algorithm      Hash digest
SHA256         55e82213f5a00cb7ac4c99d676a0a4d25cddeef3b41fbfb497da95ce41b131e3
MD5            f0ba7fa2a460c136020c6d94fb923af5
BLAKE2b-256    57cc5b5b6385c1cb990d76a375dfc39ea89282980a35a189aa9a68d0bfa5e0c3

See more details on using hashes here.
