TrAct: Training Activations
Official implementation of our NeurIPS 2024 paper TrAct: Making First-layer Pre-Activations Trainable. TrAct is a method for more effective and efficient training of a model's first layer, accelerating training by between 25% and 300%. Rather than following the training dynamics of the first-layer weights directly, TrAct derives a closed-form solution for training the first-layer pre-activations and updates the weights indirectly, leading to faster and better training.
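Conceptually, the weight update is chosen as the one that best realizes a gradient step on the pre-activations. The following is an illustrative sketch of this idea as a regularized least-squares problem — an assumption for intuition, not the paper's exact formula (see the paper for the actual TrAct update):

```python
import torch

def activation_driven_weight_grad(x, grad_z, lam=0.1):
    """Sketch: precondition a first-layer weight gradient so that the
    gradient step effectively acts on the pre-activations z = x @ W.T.

    Illustrative assumption, not the exact TrAct update: choose dW
    minimizing ||x @ dW.T - grad_z||^2 + regularization, whose closed
    form right-multiplies the plain weight gradient by the inverse of
    the regularized Gram matrix (x.T @ x / b + lam * I).
    """
    b, d = x.shape
    gram = x.T @ x / b + lam * torch.eye(d)   # (d, d) regularized Gram matrix
    plain_grad = grad_z.T @ x / b             # ordinary weight gradient
    return plain_grad @ torch.linalg.inv(gram)
```

For large `lam`, the preconditioner approaches a plain (scaled) gradient step; for small `lam`, the input statistics dominate the update.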
Video @ YouTube.
💻 Installation
TrAct can be installed via pip from PyPI with
```shell
pip install torch_tract
```
Alternatively, it is sufficient to copy src/tract.py into your existing project.
👩‍💻 Usage
The tract.py file contains the TrAct wrapper, which replaces torch.nn.Linear and torch.nn.Conv2d modules with TrActLinear and TrActConv2d modules, respectively.
After initialization, simply wrap your first layer as TrAct(layer, l=l), where l (λ in the paper) is the only hyperparameter.
The default of l=0.1 performs very well.
Please refer to the paper for additional information on the hyperparameter.
In the script, TrAct is implemented via:
```python
# regular initialization of model
model = resnet18(num_classes=num_classes)
# apply TrAct to first layer
model.conv1 = TrAct(model.conv1, l=args.l)
```
The same change also applies to, e.g., ResNet training on ImageNet, and analogously to other codebases and experiments:
```python
# For CIFAR ViT: https://github.com/omihub777/ViT-CIFAR/
model.emb = TrAct(model.emb, l=args.l)

# For DeiT: https://github.com/facebookresearch/deit/
model.patch_embed.proj = TrAct(model.patch_embed.proj, l=args.l)
```
To reproduce the first seed of Figure 1, run:
```shell
python train_cifar.py --method normal --n_epochs 100 --lr 0.08 --optim sgd_w_momentum_cosine --seed 0
python train_cifar.py --method normal --n_epochs 200 --lr 0.08 --optim sgd_w_momentum_cosine --seed 0
python train_cifar.py --method normal --n_epochs 400 --lr 0.08 --optim sgd_w_momentum_cosine --seed 0
python train_cifar.py --method normal --n_epochs 800 --lr 0.08 --optim sgd_w_momentum_cosine --seed 0
python train_cifar.py --method tract --l 0.1 --n_epochs 100 --lr 0.08 --optim sgd_w_momentum_cosine --seed 0
python train_cifar.py --method tract --l 0.1 --n_epochs 200 --lr 0.08 --optim sgd_w_momentum_cosine --seed 0
python train_cifar.py --method tract --l 0.1 --n_epochs 400 --lr 0.08 --optim sgd_w_momentum_cosine --seed 0
python train_cifar.py --method tract --l 0.1 --n_epochs 800 --lr 0.08 --optim sgd_w_momentum_cosine --seed 0
python train_cifar.py --method normal --n_epochs 100 --lr 0.010 --optim adam_cosine --seed 0
python train_cifar.py --method normal --n_epochs 200 --lr 0.010 --optim adam_cosine --seed 0
python train_cifar.py --method normal --n_epochs 400 --lr 0.001 --optim adam_cosine --seed 0
python train_cifar.py --method normal --n_epochs 800 --lr 0.001 --optim adam_cosine --seed 0
python train_cifar.py --method tract --l 0.1 --n_epochs 100 --lr 0.010 --optim adam_cosine --seed 0
python train_cifar.py --method tract --l 0.1 --n_epochs 200 --lr 0.010 --optim adam_cosine --seed 0
python train_cifar.py --method tract --l 0.1 --n_epochs 400 --lr 0.001 --optim adam_cosine --seed 0
python train_cifar.py --method tract --l 0.1 --n_epochs 800 --lr 0.001 --optim adam_cosine --seed 0
```
📖 Citing
```bibtex
@inproceedings{petersen2024tract,
  title={TrAct: Making First-layer Pre-Activations Trainable},
  author={Petersen, Felix and Borgelt, Christian and Ermon, Stefano},
  booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
  year={2024}
}
```
License
TrAct is released under the MIT license. See LICENSE for details.