
Ready-to-use PyTorch code to boost your way into few-shot image classification


Easy Few-Shot Learning

License: MIT | Code style: black

Ready-to-use code and tutorial notebooks to boost your way into few-shot image classification. This repository is made for you if:

  • you're new to few-shot learning and want to learn;
  • or you're looking for reliable, clear and easy-to-use code for your own projects.

Don't get lost in large repositories with hundreds of methods and no explanation of how to use them. Here, we want each line of code to be covered by a tutorial.

What's in there?

Notebooks: learn and practice

You want to learn few-shot learning and don't know where to start? Start with our tutorial.

Code that you can use and understand

Models:

Tools for data loading:

  • EasySet: a ready-to-use Dataset object to handle datasets of images with a class-wise directory split
  • TaskSampler: samples batches in the shape of few-shot classification tasks
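What TaskSampler does can be pictured with a small pure-Python sketch: pick n_way classes at random, then draw n_shot + n_query distinct images from each of them. The helper sample_task_indices below is hypothetical and written under our own assumptions (labels given as a plain list with one integer per image); it illustrates the idea and is not the library's code.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence


def sample_task_indices(
    labels: Sequence[int],
    n_way: int,
    n_shot: int,
    n_query: int,
    rng: random.Random,
) -> List[int]:
    """Return dataset indices for one few-shot task (hypothetical helper).

    For each of n_way randomly chosen classes, draw n_shot + n_query distinct
    images. The returned list is grouped class by class.
    """
    # Group dataset indices by class label.
    indices_per_label: Dict[int, List[int]] = defaultdict(list)
    for index, label in enumerate(labels):
        indices_per_label[label].append(index)

    # Only classes with enough images can take part in a task.
    eligible = [
        label
        for label, indices in indices_per_label.items()
        if len(indices) >= n_shot + n_query
    ]
    chosen_labels = rng.sample(eligible, n_way)

    task: List[int] = []
    for label in chosen_labels:
        task.extend(rng.sample(indices_per_label[label], n_shot + n_query))
    return task


# 10 classes with 20 images each; one 5-way 5-shot task with 10 queries per class.
labels = [c for c in range(10) for _ in range(20)]
task = sample_task_indices(labels, n_way=5, n_shot=5, n_query=10, rng=random.Random(0))
print(len(task))  # 75
```

Keeping the indices grouped by class makes it easy for a collate step to later split each group into support and query images.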

Datasets to test your model

QuickStart

  1. Install the package with pip:

pip install git+https://github.com/sicara/easy-few-shot-learning.git

Note: alternatively, you can clone the repository so that you can modify the code as you wish.

  2. Download CU-Birds and the few-shot train/val/test split:
mkdir -p data/CUB && cd data/CUB
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1GDr1OkoXdhaXWGA8S3MAq3a522Tak-nx' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1GDr1OkoXdhaXWGA8S3MAq3a522Tak-nx" -O images.tgz
rm -rf /tmp/cookies.txt
tar  --exclude='._*' -zxvf images.tgz
wget https://raw.githubusercontent.com/sicara/easy-few-shot-learning/master/data/CUB/train.json
wget https://raw.githubusercontent.com/sicara/easy-few-shot-learning/master/data/CUB/val.json
wget https://raw.githubusercontent.com/sicara/easy-few-shot-learning/master/data/CUB/test.json
cd ../..
  3. Check that you now have a 680.9 MB images folder in ./data/CUB, along with three JSON files.

  4. From the training subset of CUB, create a dataloader that yields few-shot classification tasks:

from easyfsl.data_tools import EasySet, TaskSampler
from torch.utils.data import DataLoader

train_set = EasySet(specs_file="./data/CUB/train.json", training=True)
train_sampler = TaskSampler(
    train_set, n_way=5, n_shot=5, n_query=10, n_tasks=40000
)
train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=12,
    pin_memory=True,
    collate_fn=train_sampler.episodic_collate_fn,
)
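With n_way=5, n_shot=5 and n_query=10, each batch from this loader holds 5 * (5 + 10) = 75 images, and the collate function has to turn that flat batch into a task. The sketch below shows one way to do it in pure Python. split_episode is a hypothetical helper, not TaskSampler.episodic_collate_fn, and it assumes the images arrive grouped by class with n_shot + n_query consecutive items per class.

```python
from typing import Any, List, Sequence, Tuple


def split_episode(
    batch: Sequence[Tuple[Any, int]],
    n_shot: int,
    n_query: int,
) -> Tuple[List[Any], List[int], List[Any], List[int], List[int]]:
    """Split a class-grouped batch into support and query sets (hypothetical helper).

    The first n_shot items of each class group go to the support set and the
    remaining n_query items to the query set. Original class ids are remapped
    to 0..n_way-1 so that every task looks the same to the model.
    """
    group_size = n_shot + n_query
    if len(batch) % group_size != 0:
        raise ValueError("batch size must be a multiple of n_shot + n_query")

    support_items: List[Any] = []
    support_labels: List[int] = []
    query_items: List[Any] = []
    query_labels: List[int] = []
    true_class_ids: List[int] = []
    for task_label, start in enumerate(range(0, len(batch), group_size)):
        group = batch[start : start + group_size]
        # Remember which dataset class this task label stands for.
        true_class_ids.append(group[0][1])
        for position, (item, _) in enumerate(group):
            if position < n_shot:
                support_items.append(item)
                support_labels.append(task_label)
            else:
                query_items.append(item)
                query_labels.append(task_label)
    return support_items, support_labels, query_items, query_labels, true_class_ids
```

Remapping labels to 0..n_way-1 matters because a class that is "class 3" in one task may be a completely different bird species than "class 3" in the next.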
  5. Create and train a model:
from easyfsl.methods import PrototypicalNetworks
from torch import nn
from torch.optim import Adam
from torchvision.models import resnet18

convolutional_network = resnet18(pretrained=False)
convolutional_network.fc = nn.Flatten()
model = PrototypicalNetworks(convolutional_network).cuda()

optimizer = Adam(params=model.parameters())

model.fit(train_loader, optimizer)
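Here the ResNet18 with its fc layer replaced by nn.Flatten acts as the embedding network, mapping each image to a 512-dimensional vector. The core idea of Prototypical Networks can be sketched on plain vectors: average the support embeddings of each class into a prototype, then assign each query to the nearest prototype. The functions below are an illustration under our own assumptions, not easyfsl's implementation; they use plain Euclidean distance, whereas the Prototypical Networks paper scores with negative squared Euclidean distance (the predicted class is the same).

```python
import math
from typing import Dict, List, Sequence


def prototypes(
    support: Sequence[Sequence[float]], labels: Sequence[int]
) -> Dict[int, List[float]]:
    """Average the support embeddings of each class into one prototype."""
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = {}
    for vector, label in zip(support, labels):
        if label not in sums:
            sums[label] = [0.0] * len(vector)
            counts[label] = 0
        sums[label] = [s + v for s, v in zip(sums[label], vector)]
        counts[label] += 1
    return {label: [s / counts[label] for s in total] for label, total in sums.items()}


def classify(query: Sequence[float], protos: Dict[int, List[float]]) -> int:
    """Predict the class whose prototype is closest in Euclidean distance."""
    return min(protos, key=lambda label: math.dist(query, protos[label]))


# Toy 2-D embeddings: two support images per class.
support = [[0.0, 0.0], [0.0, 2.0], [10.0, 10.0], [12.0, 10.0]]
protos = prototypes(support, [0, 0, 1, 1])
print(protos)                       # {0: [0.0, 1.0], 1: [11.0, 10.0]}
print(classify([1.0, 1.0], protos))  # 0
```

Because prototypes are recomputed from each task's support set, the same trained network can classify classes it never saw during training.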

Troubleshooting: a ResNet18 with a batch size of 5 * (5 + 10) = 75 would use about 4.2 GB of GPU memory. If your GPU doesn't have that much, switch to CPU, choose a smaller model or reduce the batch size (in the TaskSampler above).
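The batch size in this note follows from the task shape: each of the n_way classes contributes n_shot support images and n_query query images. A tiny helper (hypothetical, for illustration only) makes the effect of each TaskSampler parameter explicit:

```python
def episode_batch_size(n_way: int, n_shot: int, n_query: int) -> int:
    """Number of images in one few-shot task: every class contributes shots and queries."""
    return n_way * (n_shot + n_query)


# Settings used above: 5 * (5 + 10) = 75 images per batch.
print(episode_batch_size(n_way=5, n_shot=5, n_query=10))  # 75
# Lowering n_query to 5 gives 5 * (5 + 5) = 50 images per batch.
print(episode_batch_size(n_way=5, n_shot=5, n_query=5))  # 50
```

GPU memory also depends on the model, the image size and the optimizer state, so the batch size is one lever among several rather than an exact predictor of memory use.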

  6. Evaluate your model on the test set:
test_set = EasySet(specs_file="./data/CUB/test.json", training=False)
test_sampler = TaskSampler(
    test_set, n_way=5, n_shot=5, n_query=10, n_tasks=100
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=12,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

model.evaluate(test_loader)
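The test loader above yields 100 tasks. Few-shot results are commonly reported as the mean accuracy over tasks, often with a 95% confidence interval. The sketch below shows that aggregation with hypothetical helpers task_accuracy and summarize; it is not necessarily what model.evaluate computes, and the interval uses the usual normal approximation (1.96 standard errors).

```python
import math
import statistics
from typing import Sequence, Tuple


def task_accuracy(predictions: Sequence[int], labels: Sequence[int]) -> float:
    """Fraction of query images classified correctly in one task."""
    if len(predictions) != len(labels) or not labels:
        raise ValueError("need one prediction per query label")
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


def summarize(accuracies: Sequence[float]) -> Tuple[float, float]:
    """Mean accuracy over tasks and the half-width of a normal-approximation 95% interval."""
    mean = statistics.fmean(accuracies)
    if len(accuracies) < 2:
        return mean, 0.0
    half_width = 1.96 * statistics.stdev(accuracies) / math.sqrt(len(accuracies))
    return mean, half_width


per_task = [task_accuracy([0, 1, 2, 2], [0, 1, 2, 0]), 0.5, 0.75]
mean, half_width = summarize(per_task)
print(f"{100 * mean:.1f}% +/- {100 * half_width:.1f}%")
```

Evaluating on many independently sampled tasks keeps the interval narrow; with only a handful of tasks, the reported mean can vary noticeably from run to run.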

Roadmap

  • Implement unit tests
  • Add validation to AbstractMetaLearner.fit()
  • Integrate more methods:
    • Matching Networks
    • Relation Networks
    • MAML
    • Transductive Propagation Network
  • Integrate non-episodic training
  • Integrate more benchmarks:
    • miniImageNet
    • tieredImageNet
    • Meta-Dataset

Contribute

This project is very open to contributions! You can help in various ways:

  • raise issues
  • resolve open issues
  • tackle new features from the roadmap
  • fix typos, improve code quality



Download files

Source distribution: easyfsl-0.1.0.tar.gz (10.8 kB)

Built distribution: easyfsl-0.1.0-py3-none-any.whl (10.7 kB, Python 3)
