A continual learning PyTorch package
ContinualFlame
A small, lightweight package for continual learning in PyTorch.
Installation
For now, the package is hosted on TestPyPI. To install it, run:
pip install continual-flame
Usage
To use the package, import it in your project:
import contflame as cf
At the moment, the package contains only the dataset module.
Dataset
This module contains datasets commonly used in continual learning scenarios. The main ones are:
- SplitMNIST - MNIST dataset split by class. It lets you create different subtasks, each containing a custom subset of classes.
- PermutedMNIST - permuted MNIST dataset. It lets you choose the tile shape used by the applied permutation.
- SplitCIFAR100
- PermutedCIFAR100
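The class-subset idea behind SplitMNIST can be sketched in plain Python. This is not ContinualFlame's implementation: the helper names class_subset_indices and split_tasks are hypothetical, and whether SplitMNIST also remaps the labels of each task (for example to 0 and 1) is not stated here.

```python
from typing import List, Sequence


def class_subset_indices(labels: Sequence[int], classes: Sequence[int]) -> List[int]:
    """Return the indices of the samples whose label belongs to `classes` (hypothetical helper)."""
    wanted = set(classes)
    return [idx for idx, y in enumerate(labels) if y in wanted]


def split_tasks(n_classes: int, classes_per_task: int) -> List[List[int]]:
    """Group class ids 0..n_classes-1 into consecutive tasks (hypothetical helper).

    The last task is shorter if n_classes is not a multiple of classes_per_task.
    """
    return [
        list(range(start, min(start + classes_per_task, n_classes)))
        for start in range(0, n_classes, classes_per_task)
    ]
```

With torchvision, for example, the indices returned for mnist.targets.tolist() and one entry of split_tasks(10, 2) could be passed to torch.utils.data.Subset to build a single task's dataset.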
Examples
SplitMNIST
In the following example, the training tasks are five binary classification tasks on consecutive pairs of digits (i.e. task 1: (0, 1), task 2: (2, 3), ...).
from contflame.dataset import SplitMNIST

epochs = 10  # number of training epochs per task (example value)
valid = []
for i in range(0, 10, 2):  # tasks (0, 1), (2, 3), ..., (8, 9)
    train_dataset = SplitMNIST(classes=[i, i + 1], dset='train', valid=0.2)
    valid.append(SplitMNIST(classes=[i, i + 1], dset='valid', valid=0.2))
    for e in range(epochs):
        # train the model on train_dataset
        # ...
        pass
    for v in valid:
        # test the model on the current and the previous tasks
        # ...
        pass
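The loop above evaluates, after each task, on every task seen so far. A common way to record this is a lower-triangular accuracy matrix, where acc[t][j] is the accuracy on task j measured right after training on task t, summarized by the average accuracy after the last task. The sketch below assumes you supply train_on and evaluate callbacks (hypothetical names, not part of contflame) that wrap the training and validation loops shown above.

```python
from typing import Callable, List


def run_protocol(
    n_tasks: int,
    train_on: Callable[[int], None],
    evaluate: Callable[[int], float],
) -> List[List[float]]:
    """Train on tasks in order; after each one, evaluate on every task seen so far.

    Returns a lower-triangular matrix: acc[t][j] is the accuracy on task j
    measured right after training on task t (j <= t).
    """
    acc: List[List[float]] = []
    for t in range(n_tasks):
        train_on(t)  # e.g. the epoch loop over this task's train_dataset
        acc.append([evaluate(j) for j in range(t + 1)])  # current and previous tasks
    return acc


def average_accuracy(acc: List[List[float]]) -> float:
    """Mean accuracy over all tasks, measured after training on the last task."""
    final_row = acc[-1]
    return sum(final_row) / len(final_row)
```

Comparing a column of acc over time (for example acc[0][0] against acc[-1][0]) shows how much the model has forgotten the first task.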
PermutedMNIST
To get a fully random pixel permutation, set tile to (1, 1). The same random permutation, selected by the task id, is applied to all data points:
PermutedMNIST(tile=(1, 1), task=1)
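A tile of (1, 1) amounts to a fixed permutation of the 28 × 28 = 784 pixel positions. The sketch below assumes the task id is used directly as the random seed, which is one way to make the permutation "selected by the task id"; ContinualFlame's actual RNG and seeding may differ, and task_permutation and permute_pixels are hypothetical names.

```python
import random
from typing import List, Sequence


def task_permutation(task: int, n_pixels: int = 28 * 28) -> List[int]:
    """Pixel permutation that depends only on the task id (assumed to be the RNG seed)."""
    rng = random.Random(task)  # same task id -> same permutation
    perm = list(range(n_pixels))
    rng.shuffle(perm)
    return perm


def permute_pixels(flat_image: Sequence[float], perm: Sequence[int]) -> List[float]:
    """Output pixel k is input pixel perm[k]; the same perm is reused for every image of a task."""
    return [flat_image[p] for p in perm]
```

Because every image of a task goes through the same permutation, the task stays learnable, while each new task id yields a different input distribution.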
You can also permute whole rows (or columns) by setting the tile width (or height) equal to that of the image:
PermutedMNIST(tile=(1, 28), task=1)
Or preserve some high-level spatial features by using a bigger tile:
PermutedMNIST(tile=(8, 8), task=1)
To get the images without any permutation, set the tile to (28, 28), which is the default value.
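The tile option can be pictured as cutting the image into tiles and shuffling the tile positions, not the pixels inside each tile. A minimal sketch, assuming tile is given as (height, width), the task id seeds the shuffle, and both image dimensions are divisible by the tile size. How ContinualFlame handles a tile such as (8, 8), which does not divide 28, is not covered here; this sketch simply rejects it. tile_permute is a hypothetical helper, not the package's code.

```python
import random
from typing import List, Tuple

Image = List[List[float]]


def tile_permute(image: Image, tile: Tuple[int, int], task: int) -> Image:
    """Split `image` into (th, tw) tiles and shuffle their positions (hypothetical helper)."""
    h, w = len(image), len(image[0])
    th, tw = tile
    if h % th or w % tw:
        raise ValueError("image size must be divisible by the tile size in this sketch")
    rows, cols = h // th, w // tw
    order = list(range(rows * cols))
    random.Random(task).shuffle(order)  # assumption: the task id is the seed
    out = [row[:] for row in image]  # every position is overwritten below
    for dst, src in enumerate(order):
        dr, dc = divmod(dst, cols)  # destination tile (row, column)
        sr, sc = divmod(src, cols)  # source tile (row, column)
        for i in range(th):
            for j in range(tw):
                out[dr * th + i][dc * tw + j] = image[sr * th + i][sc * tw + j]
    return out
```

With tile (1, 28) each output row is an intact input row, so only the row order changes; with (28, 28) there is a single tile and the image comes back unchanged.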
Hashes for continual_flame-1.0-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 54445c365d41014d781a936166cfcaf4df031f7e37cd2cbb98ef74470ffd40e7
MD5 | 4edca8513d85a12d334f8b03af3c5cfa
BLAKE2b-256 | a7c44943662f2d1e513ce7314db9bbc42e9b5d6a43d5626030ebb881356d1bcd