

Project description

TrAct - Training Activations

Official implementation of our NeurIPS 2024 paper TrAct: Making First-layer Pre-Activations Trainable. In this work, we present TrAct, a method for more effective and efficient training of a network's first layer that accelerates training by between 25% and 300%. Rather than following the usual gradient dynamics of the first-layer weights, we provide a closed-form solution for training the activations and update the weights indirectly, leading to faster and better training.

Video @ YouTube.
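For intuition, the core idea can be sketched as a two-step procedure: take a plain gradient step on the first-layer pre-activations z = Wx, then solve in closed form for the weight update that best realizes this change. The snippet below is only a conceptual sketch under a ridge-regularized least-squares objective; the function name tract_like_step and the default values of lr and lam are illustrative assumptions, and the exact TrAct update is the one defined in the paper and in tract.py.

import torch

# Conceptual sketch of "update the pre-activations, then solve for the weights".
# NOT the exact TrAct update from tract.py; it only illustrates the idea with a
# ridge-regularized least-squares step.
def tract_like_step(W, x, grad_z, lr=0.1, lam=0.1):
    # W: (out, in) first-layer weights
    # x: (batch, in) first-layer inputs
    # grad_z: (batch, out) gradient of the loss w.r.t. the pre-activations z = x @ W.T
    delta_z = -lr * grad_z  # desired change of the pre-activations
    # weight change dW minimizing ||x @ dW.T - delta_z||^2 + lam * ||dW||^2
    A = x.T @ x + lam * torch.eye(x.shape[1], dtype=x.dtype, device=x.device)
    dW = torch.linalg.solve(A, x.T @ delta_z).T  # closed-form ridge solution
    return W + dW

In practice, none of this needs to be done by hand: wrapping the first layer with TrAct (see Usage below) applies the modified update inside the wrapped module, and the rest of the training loop and optimizer stay unchanged.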

💻 Installation

TrAct can be installed via pip from PyPI with

pip install torch_tract

Alternatively, it is sufficient to copy src/tract.py into your existing project.

👩‍💻 Usage

The tract.py file contains the TrAct wrapper, which replaces torch.nn.Linear and torch.nn.Conv2d modules with TrActLinear and TrActConv2d modules, respectively.

After initializing your model, simply wrap its first layer in TrAct(layer, l=l), where l (i.e., λ) is the only hyperparameter. The default of l=0.1 performs very well. Please refer to the paper for additional information on this hyperparameter.
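For example, for a model whose first layer is a torch.nn.Linear, the wrapping could look as follows (a minimal sketch; the small MLP is an illustrative assumption, and the import path tract assumes tract.py was copied into the project as described above):

import torch.nn as nn
from tract import TrAct  # assuming tract.py is available in the project

# illustrative model whose first layer is an nn.Linear
mlp = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
# wrap only the first layer; l = 0.1 is the default
mlp[0] = TrAct(mlp[0], l=0.1)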

In the script, TrAct is implemented via:

from tract import TrAct  # assuming tract.py is available in the project

# regular initialization of the model (resnet18 as defined in the script)
model = resnet18(num_classes=num_classes)
# apply TrAct to the first layer
model.conv1 = TrAct(model.conv1, l=args.l)

With the same change, TrAct can also be applied to, e.g., ResNet training on ImageNet. Analogously, it can be used in other codebases and experiments, e.g.:

# For CIFAR ViT: https://github.com/omihub777/ViT-CIFAR/
model.emb = TrAct(model.emb, l=args.l)
# For DeiT: https://github.com/facebookresearch/deit/
model.patch_embed.proj = TrAct(model.patch_embed.proj, l=args.l)

To reproduce the first seed of Figure 1, run:

python train_cifar.py --method normal        --n_epochs 100 --lr 0.08 --optim sgd_w_momentum_cosine --seed 0
python train_cifar.py --method normal        --n_epochs 200 --lr 0.08 --optim sgd_w_momentum_cosine --seed 0
python train_cifar.py --method normal        --n_epochs 400 --lr 0.08 --optim sgd_w_momentum_cosine --seed 0
python train_cifar.py --method normal        --n_epochs 800 --lr 0.08 --optim sgd_w_momentum_cosine --seed 0
python train_cifar.py --method tract --l 0.1 --n_epochs 100 --lr 0.08 --optim sgd_w_momentum_cosine --seed 0
python train_cifar.py --method tract --l 0.1 --n_epochs 200 --lr 0.08 --optim sgd_w_momentum_cosine --seed 0
python train_cifar.py --method tract --l 0.1 --n_epochs 400 --lr 0.08 --optim sgd_w_momentum_cosine --seed 0
python train_cifar.py --method tract --l 0.1 --n_epochs 800 --lr 0.08 --optim sgd_w_momentum_cosine --seed 0

python train_cifar.py --method normal        --n_epochs 100 --lr 0.010 --optim adam_cosine --seed 0
python train_cifar.py --method normal        --n_epochs 200 --lr 0.010 --optim adam_cosine --seed 0
python train_cifar.py --method normal        --n_epochs 400 --lr 0.001 --optim adam_cosine --seed 0
python train_cifar.py --method normal        --n_epochs 800 --lr 0.001 --optim adam_cosine --seed 0
python train_cifar.py --method tract --l 0.1 --n_epochs 100 --lr 0.010 --optim adam_cosine --seed 0
python train_cifar.py --method tract --l 0.1 --n_epochs 200 --lr 0.010 --optim adam_cosine --seed 0
python train_cifar.py --method tract --l 0.1 --n_epochs 400 --lr 0.001 --optim adam_cosine --seed 0
python train_cifar.py --method tract --l 0.1 --n_epochs 800 --lr 0.001 --optim adam_cosine --seed 0

📖 Citing

@inproceedings{petersen2024tract,
  title={TrAct: Making First-layer Pre-Activations Trainable},
  author={Petersen, Felix and Borgelt, Christian and Ermon, Stefano},
  booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
  year={2024}
}

License

TrAct is released under the MIT license. See LICENSE for details.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torch_tract-1.0.0.tar.gz (6.4 kB)

Uploaded Source

Built Distribution

If you're not sure about the file name format, learn more about wheel file names.

torch_tract-1.0.0-py3-none-any.whl (6.5 kB)

Uploaded Python 3

File details

Details for the file torch_tract-1.0.0.tar.gz.

File metadata

  • Download URL: torch_tract-1.0.0.tar.gz
  • Upload date:
  • Size: 6.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.6

File hashes

Hashes for torch_tract-1.0.0.tar.gz
  • SHA256: cf8d245d3738afbce9c07157fb9617ffe6bf418685094a52f27fccb34320378b
  • MD5: fd1bfa8e4c2b32eff6a5512a5f3db5e2
  • BLAKE2b-256: d96193d3909492bbb173b96fce3e381fa29bcb6578f1f9980caf974a19d5a7a7


File details

Details for the file torch_tract-1.0.0-py3-none-any.whl.

File metadata

  • Download URL: torch_tract-1.0.0-py3-none-any.whl
  • Upload date:
  • Size: 6.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.9.6

File hashes

Hashes for torch_tract-1.0.0-py3-none-any.whl
  • SHA256: 87de498da4c54eae0702d94aaa4d5cee542ddaffd39e113ed51bd9096a92aba3
  • MD5: a91556888a87b35b564825c565e14f4a
  • BLAKE2b-256: 1e0064645611f74fceb20fe6b1724ee670124316eda4423ae8981f94892476fa

