Flash is a framework for fast prototyping, finetuning, and solving most standard deep learning challenges

Collection of tasks for fast prototyping, baselining, finetuning and solving problems with deep learning



Installation

Pip / conda

pip install lightning-flash

Master

pip install git+https://github.com/PytorchLightning/lightning-flash.git@master --upgrade

Source

git clone https://github.com/PyTorchLightning/lightning-flash.git
cd lightning-flash
pip install -e .
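
To confirm that one of the install routes above worked, a small check like the following can help. This is a minimal sketch using only the standard library (Python 3.8+); the helper name `installed_version` is hypothetical and not part of Flash.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_version("lightning-flash")
    print(version if version else "lightning-flash is not installed")
```

Returning `None` instead of raising keeps the check usable in setup scripts that need to decide whether to install the package.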

What is Flash

Flash is a framework of tasks for fast prototyping, baselining, finetuning and solving business and scientific problems with deep learning. It is focused on:

  • Predictions
  • Finetuning
  • Task-based training

It is built for data scientists, machine learning practitioners, and applied researchers.

Scalability

Flash is built on top of PyTorch Lightning (by the Lightning team), which is a thin organizational layer on top of PyTorch. If you know PyTorch, you know PyTorch Lightning and Flash already!

As a result, Flash can scale up across any hardware (GPUs, TPUs) with zero changes to your code. It also has the best practices in AI research embedded into each task, so you don't have to be a deep learning PhD to leverage its power :)

Predictions

# import our libraries
from flash.text import TextClassifier

# 1. Load finetuned task
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")

# 2. Perform inference from list of sequences
predictions = model.predict([
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
    "Very, very afraid",
    "This guy has done a great job with this movie!",
])
print(predictions)

Finetuning

First, finetune:

import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 2. Load the data
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    valid_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

# 3. Build the model
model = ImageClassifier(num_classes=datamodule.num_classes, backbone="resnet18")

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 5. Finetune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Save it!
trainer.save_checkpoint("image_classification_model.pt")
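
The `strategy="freeze"` argument above can be pictured in plain PyTorch. The sketch below assumes that freezing means keeping the pretrained backbone's weights fixed and training only a new head; the tiny `backbone` and `head` modules are stand-ins for illustration, not Flash's actual internals.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in for a pretrained feature extractor (hypothetical, not a real ResNet).
backbone = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
# New task-specific head, e.g. two classes such as ants vs. bees.
head = nn.Linear(4, 2)

# Freeze: exclude every backbone parameter from gradient updates.
for param in backbone.parameters():
    param.requires_grad = False

model = nn.Sequential(backbone, head)

# Only parameters that still require gradients are handed to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)

# Snapshot the weights, then take one training step on random data.
backbone_before = backbone[0].weight.detach().clone()
head_before = head.weight.detach().clone()
inputs, targets = torch.randn(16, 8), torch.randint(0, 2, (16,))
loss = nn.functional.cross_entropy(model(inputs), targets)
loss.backward()
optimizer.step()

# Number of values the optimizer can update: head weights plus head biases.
num_trainable = sum(p.numel() for p in trainable)
print(num_trainable)
```

After the step, the backbone weights are unchanged while the head weights have moved, which is the behaviour a frozen-backbone finetune relies on.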

Then use the finetuned model

# load the finetuned model
classifier = ImageClassifier.load_from_checkpoint('image_classification_model.pt')

# predict!
predictions = classifier.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
print(predictions)

Tasks

Flash is built as a collection of community-built tasks. A task is highly opinionated and laser-focused on solving a single problem well, using state-of-the-art methods.

Example 1: Image classification

Flash has an ImageClassification task to tackle any image classification problem.

To illustrate, say we want to develop a model that classifies ants versus bees.

import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 2. Load the data
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    valid_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

# 3. Build the model
model = ImageClassifier(num_classes=datamodule.num_classes)

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 5. Train the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze_unfreeze")

# 6. Test the model
trainer.test()

# 7. Predict!
predictions = model.predict([
    "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
    "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
    "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
])
print(predictions)

To run the example:

python flash_examples/finetuning/image_classifier.py

Example 2: Text Classification

Flash has a TextClassification task to tackle any text classification problem.

To illustrate, say you want to classify movie reviews as positive or negative.

import flash
from flash import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')

# 2. Load the data
datamodule = TextClassificationData.from_files(
    train_file="data/imdb/train.csv",
    valid_file="data/imdb/valid.csv",
    test_file="data/imdb/test.csv",
    input="review",
    target="sentiment",
    batch_size=512
)

# 3. Build the model
model = TextClassifier(num_classes=datamodule.num_classes)

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze_unfreeze")

# 6. Test model
trainer.test()

# 7. Classify a few sentences! How was the movie?
predictions = model.predict([
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
    "Very, very afraid",
    "This guy has done a great job with this movie!",
])
print(predictions)

To run the example:

python flash_examples/finetuning/classify_text.py

Example 3: Tabular Classification

Flash has a TabularClassification task to tackle any tabular classification problem.

To illustrate, say we want to build a model to predict if a passenger survived on the Titanic.

from pytorch_lightning.metrics.classification import Accuracy, Precision, Recall
import flash
from flash import download_data
from flash.tabular import TabularClassifier, TabularData

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')

# 2. Load the data
datamodule = TabularData.from_csv(
    "./data/titanic/titanic.csv",
    test_csv="./data/titanic/test.csv",
    categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
    numerical_input=["Fare"],
    target="Survived",
    val_size=0.25,
)

# 3. Build the model
model = TabularClassifier.from_data(datamodule, metrics=[Accuracy(), Precision(), Recall()])

# 4. Create the trainer. Run 10 times on data
trainer = flash.Trainer(max_epochs=10)

# 5. Train the model
trainer.fit(model, datamodule=datamodule)

# 6. Test model
trainer.test()

# 7. Predict!
predictions = model.predict("data/titanic/titanic.csv")
print(predictions)

To run the example:

python flash_examples/finetuning/tabular_data.py
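
The split between `categorical_input` and `numerical_input` in the example matters because the two kinds of columns are typically prepared differently before they reach a network. As a rough illustration only (this is one common approach, not necessarily what Flash does internally; the helper `encode_categories` and the sample rows are made up), categorical values can be mapped to integer codes while numerical values stay as floats:

```python
from typing import Dict, List, Tuple


def encode_categories(values: List[str]) -> Tuple[List[int], Dict[str, int]]:
    """Map each distinct value to an integer code, in order of first appearance."""
    codes: Dict[str, int] = {}
    encoded = []
    for value in values:
        if value not in codes:
            codes[value] = len(codes)
        encoded.append(codes[value])
    return encoded, codes


# Made-up rows with one categorical and one numerical column.
rows = [
    {"Sex": "male", "Fare": "7.25"},
    {"Sex": "female", "Fare": "71.28"},
    {"Sex": "female", "Fare": "7.92"},
]

sex_codes, sex_vocab = encode_categories([row["Sex"] for row in rows])
fares = [float(row["Fare"]) for row in rows]

print(sex_codes)  # [0, 1, 1]
print(sex_vocab)  # {'male': 0, 'female': 1}
print(fares)      # [7.25, 71.28, 7.92]
```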

A general task

Flash comes prebuilt with a task to handle a huge portion of deep learning problems.

import flash
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets
import pytorch_lightning as pl

# model
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)

# data
dataset = datasets.MNIST('./data_folder', download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])

# task
classifier = flash.Task(model, loss_fn=nn.functional.cross_entropy, optimizer=optim.Adam)

# train
flash.Trainer().fit(classifier, DataLoader(train), DataLoader(val))
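
Conceptually, `flash.Task` pairs a model with a loss function and an optimizer so the trainer can run the loop for you. The sketch below writes out a single training step in plain PyTorch to show what is being bundled; it uses random tensors in place of MNIST and approximates the idea rather than reproducing Flash's implementation.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Same architecture as the example above.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)
loss_fn = nn.functional.cross_entropy
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Random stand-in batch: 32 grayscale 28x28 images with labels in [0, 10).
images = torch.randn(32, 1, 28, 28)
labels = torch.randint(0, 10, (32,))

# One training step: forward pass, loss, backward pass, parameter update.
optimizer.zero_grad()
logits = model(images)
loss = loss_fn(logits, labels)
loss.backward()
optimizer.step()

print(logits.shape)  # torch.Size([32, 10])
```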

Infinitely customizable

Tasks can be built in just a few minutes because Flash is built on top of PyTorch Lightning LightningModules, which are infinitely extensible and let you train across GPUs, TPUs, etc. without any code changes.

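from typing import Callable, Mapping, Sequence, Type, Union

import torch
import torch.nn.functional as F
from flash.core.classification import ClassificationTask
from pytorch_lightning.metrics.classification import Accuracy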

class LinearClassifier(ClassificationTask):
    def __init__(
        self,
        num_inputs,
        num_classes,
        loss_fn: Callable = F.cross_entropy,
        optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
        metrics: Union[Callable, Mapping, Sequence, None] = [Accuracy()],
        learning_rate: float = 1e-3,
    ):
        super().__init__(
            model=None,
            loss_fn=loss_fn,
            optimizer=optimizer,
            metrics=metrics,
            learning_rate=learning_rate,
        )
        self.save_hyperparameters()

        self.linear = torch.nn.Linear(num_inputs, num_classes)

    def forward(self, x):
        return self.linear(x)

# example sizes only: flattened 28x28 inputs and 10 classes
classifier = LinearClassifier(num_inputs=28 * 28, num_classes=10)
...

When you reach the limits of the flexibility that tasks provide, you can transition seamlessly to PyTorch Lightning, which gives you the most flexibility because it is simply organized PyTorch.

Contribute!

The Lightning + Flash team is hard at work building more tasks for common deep-learning use cases. But we're looking for incredible contributors like you to submit new tasks!

Join our Slack to get help becoming a contributor!

Community

For help or questions, join our huge community on Slack!

License

Please observe the Apache 2.0 license that is listed in this repository. In addition, the Lightning framework is Patent Pending.
