
A continual learning PyTorch package



ContinualFlame

A small, lightweight package for continual learning in PyTorch.

Installation

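For now the package is hosted on TestPyPI. To install it, run:

pip install continual-flame

If pip cannot find the package on the default index, point it at TestPyPI with --index-url https://test.pypi.org/simple/.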

Usage

To use the package, import it in your project:

import contflame as cf

At the moment the package contains just the dataset module.

Dataset

This module contains datasets commonly used in continual learning. The main ones are:

  • SplitMNIST - the MNIST dataset split by class. It lets you create different subtasks, each containing a custom subset of the classes.
  • PermutedMNIST - the permuted MNIST dataset. It lets you choose the shape of the applied permutation.
  • SplitCIFAR100
  • PermutedCIFAR100
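The idea behind a split dataset such as SplitMNIST can be sketched in plain Python: a subtask keeps only the samples whose label belongs to the chosen subset of classes. This is a minimal illustration under our own assumptions, not the package's implementation; the helper `select_classes` and the toy labels are hypothetical.

```python
from typing import List, Sequence


def select_classes(labels: Sequence[int], classes: Sequence[int]) -> List[int]:
    """Return the indices of the samples whose label is in `classes`.

    Hypothetical helper: it mimics the idea behind SplitMNIST(classes=...),
    not the package's actual code.
    """
    wanted = set(classes)
    return [idx for idx, label in enumerate(labels) if label in wanted]


# Toy labels standing in for MNIST targets (digits 0-9).
labels = [3, 0, 1, 7, 2, 1, 0, 9, 3, 2]

# Five binary tasks on consecutive digit pairs: (0, 1), (2, 3), ..., (8, 9).
tasks = [[d, d + 1] for d in range(0, 10, 2)]
subsets = {tuple(t): select_classes(labels, t) for t in tasks}
```

With PyTorch, index lists like these could be passed to torch.utils.data.Subset to obtain one dataset per task.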

Examples

SplitMNIST

In the following example the training tasks are five binary classification tasks on consecutive pairs of digits (i.e. task 1 uses (0, 1), task 2 uses (2, 3), ...).

from contflame.dataset import SplitMNIST

epochs = 10  # number of training epochs per task (example value)
valid = []
for i in range(0, 10, 2):
    # task on digits i and i + 1, with 20% of the data held out for validation
    train_dataset = SplitMNIST(classes=[i, i + 1], dset='train', valid=0.2)
    valid.append(SplitMNIST(classes=[i, i + 1], dset='valid', valid=0.2))

    for e in range(epochs):
        # train the model on train_dataset
        ...

    for v in valid:
        # test the model on the current and the previous tasks
        ...
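The example passes valid=0.2 to both the 'train' and the 'valid' datasets, which reads as holding out 20% of each task's training data for validation. The sketch below shows one way such a split could work; the semantics are our assumption, and the helper `train_valid_split` and its `seed` parameter are hypothetical, not part of the package.

```python
import random
from typing import List, Sequence, Tuple


def train_valid_split(
    indices: Sequence[int], valid: float, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Shuffle `indices` deterministically and hold out a `valid` fraction.

    Hypothetical helper illustrating what a `valid=0.2` argument could mean.
    A fixed seed makes repeated calls return the same split, so the 'train'
    and 'valid' subsets never overlap.
    """
    if not 0.0 <= valid < 1.0:
        raise ValueError("valid must be in [0, 1)")
    shuffled = list(indices)
    random.Random(seed).shuffle(shuffled)
    n_valid = int(len(shuffled) * valid)
    # First part of the shuffled list becomes validation, the rest training.
    return shuffled[n_valid:], shuffled[:n_valid]


# 100 sample indices belonging to one task, 20% held out for validation.
train_idx, valid_idx = train_valid_split(range(100), valid=0.2)
```

The fixed seed is the important design choice: two separate constructor calls, one for 'train' and one for 'valid', can only produce disjoint subsets if they derive the same shuffle.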

PermutedMNIST

To get a random pixel permutation, set tile to (1, 1). The same random permutation, selected by the task id, is applied to all the data points.

PermutedMNIST(tile=(1, 1), task=1)

You can also apply the permutation row-wise (or column-wise) by setting the corresponding dimension of the tile equal to that of the image:

PermutedMNIST(tile=(1, 28), task=1)

Or try to preserve high-level spatial features by using a bigger tile:

PermutedMNIST(tile=(8, 8), task=1)

To get the images without any permutation, set the tile to (28, 28) (the default value).
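The tile mechanism described above can be sketched in plain Python: split the image into non-overlapping tiles and shuffle their positions with a random generator seeded by the task id. This is our reading of the description, not the package's implementation; `permute_tiles` is a hypothetical name, and the sketch works on nested lists rather than tensors.

```python
import random
from typing import List, Tuple

Image = List[List[int]]


def permute_tiles(image: Image, tile: Tuple[int, int], task: int) -> Image:
    """Shuffle the positions of non-overlapping tiles of `image`.

    Hypothetical sketch of the PermutedMNIST idea: the permutation comes from
    an RNG seeded with the task id, so every image of a task gets the same
    shuffle. Assumes the tile shape evenly divides the image shape.
    """
    height, width = len(image), len(image[0])
    th, tw = tile
    if height % th or width % tw:
        raise ValueError("tile shape must evenly divide the image shape")
    rows, cols = height // th, width // tw

    # Same task id -> same seed -> same permutation of tile positions.
    order = list(range(rows * cols))
    random.Random(task).shuffle(order)

    out = [[0] * width for _ in range(height)]
    for dst, src in enumerate(order):
        # Grid coordinates of the destination and source tiles.
        dr, dc = divmod(dst, cols)
        sr, sc = divmod(src, cols)
        for r in range(th):
            for c in range(tw):
                out[dr * th + r][dc * tw + c] = image[sr * th + r][sc * tw + c]
    return out


# A 28x28 "image" whose pixel values are simply 0..783.
img = [[r * 28 + c for c in range(28)] for r in range(28)]

pixel_perm = permute_tiles(img, (1, 1), task=1)   # every pixel moves independently
row_perm = permute_tiles(img, (1, 28), task=1)    # whole rows are shuffled
identity = permute_tiles(img, (28, 28), task=1)   # a single tile: no change
```

Note that this sketch rejects tiles that do not divide the image, such as (8, 8) on a 28x28 image; how the package itself handles that case is not covered here.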


Download files


Source Distribution

continual-flame-1.0.tar.gz (9.7 kB)

Built Distribution

continual_flame-1.0-py3-none-any.whl (21.3 kB)
