
TorchFlare is a simple, beginner-friendly, and easy-to-use PyTorch framework to train your models effortlessly.


TorchFlare

TorchFlare is a simple, beginner-friendly, easy-to-use PyTorch framework for training your models with ease. It provides an almost Keras-like training experience, complete with callbacks, metrics, and more.

Features

  • A high-level module for Keras-like training.
  • Flexibility to write custom training and validation loops for advanced use cases.
  • Off-the-shelf PyTorch-style Datasets/DataLoaders for standard tasks such as image classification, image segmentation, text classification, etc.
  • Callbacks for model checkpointing, early stopping, and much more!
  • Metric computation powered by torchmetrics in the backend.
  • Less boilerplate code required to train your models.
  • Interactive UI creation for model prototyping and proofs of concept (POC).

Currently, TorchFlare supports CPU and GPU training. DDP and TPU support will be coming soon!


Installation

pip install torchflare

Documentation

The documentation is available here.


Getting Started

The core idea behind TorchFlare is the Experiment class. It handles all the internals, such as the boilerplate training code and the calls to callbacks and metrics. The only thing you need to focus on is creating your PyTorch model.

TorchFlare also provides off-the-shelf PyTorch-style datasets/dataloaders for standard tasks, so you don't have to write your own PyTorch Datasets/DataLoaders.

Here is an easy-to-understand example showing how the Experiment class works.

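import torch
import torchmetrics
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu, F.max_pool2d and F.dropout in Net
from torchflare.experiments import Experiment, ModelConfig
import torchflare.callbacks as cbs

# Some dummy dataloaders (placeholders: replace them with your own torch.utils.data.DataLoader objects)
train_dl = SomeTrainingDataloader()
valid_dl = SomeValidationDataloader()
test_dl = SomeTestingDataloader()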

Create a PyTorch model

class Net(nn.Module):
    # Assumes single-channel 28x28 inputs (e.g. MNIST); other sizes break the hard-coded 320 below.
    def __init__(self, n_classes, p_dropout):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d(p=p_dropout)
        self.fc1 = nn.Linear(320, 50)  # 20 channels * 4 * 4 after two conv + pool stages
        self.fc2 = nn.Linear(50, n_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)  # flatten to (batch, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)  # raw logits (no softmax), as expected by a cross-entropy loss
        return x
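The hard-coded `320` in `fc1` and in `x.view(-1, 320)` is the number of features left after the two convolution and pooling stages, so it only holds for one input size. Assuming single-channel 28x28 inputs (MNIST-sized images; the example itself does not state the input size), the small helpers below reproduce that arithmetic in plain Python. The function names are illustrative and are not part of TorchFlare or PyTorch.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d layer (dilation assumed to be 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def maxpool2d_out(size: int, kernel: int) -> int:
    """Spatial output size of max_pool2d, whose stride defaults to kernel_size."""
    return (size - kernel) // kernel + 1


def flattened_features(input_size: int = 28) -> int:
    # conv1 (kernel 5) -> max_pool2d(2)
    size = maxpool2d_out(conv2d_out(input_size, kernel=5), kernel=2)
    # conv2 (kernel 5) -> Dropout2d (no shape change) -> max_pool2d(2)
    size = maxpool2d_out(conv2d_out(size, kernel=5), kernel=2)
    channels = 20  # out_channels of conv2
    return channels * size * size


print(flattened_features(28))  # 320  (28 -> 24 -> 12 -> 8 -> 4, then 20 * 4 * 4)
```

With 32x32 inputs the same computation gives 500, which is why the linear layer's input size has to change together with the image size.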

Define callbacks and metrics

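num_classes = 10  # must match n_classes in module_params below

metric_list = [
    # Newer torchmetrics releases also require task="multiclass" here.
    torchmetrics.Accuracy(num_classes=num_classes)
]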
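For single-label classification, accuracy compares the predicted class (the argmax of the model's logits) with the target. The plain-Python sketch below shows that computation for one batch; it only illustrates the metric and is not how torchmetrics implements it (torchmetrics works on tensors and accumulates counts across batches before computing the final value). The helper names are hypothetical.

```python
from typing import List, Sequence


def argmax(row: Sequence[float]) -> int:
    # Index of the largest logit (the first one wins on ties).
    return max(range(len(row)), key=lambda i: row[i])


def multiclass_accuracy(logits: List[Sequence[float]], targets: List[int]) -> float:
    """Fraction of samples whose highest-scoring class matches the target."""
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    correct = sum(argmax(row) == t for row, t in zip(logits, targets))
    return correct / len(targets)


logits = [[2.0, 0.1, -1.0], [0.2, 0.3, 3.0], [1.5, 2.5, 0.0], [0.0, 0.0, 1.0]]
targets = [0, 2, 0, 2]
print(multiclass_accuracy(logits, targets))  # 0.75 (the third sample is misclassified)
```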

callbacks = [
    cbs.EarlyStopping(monitor="val_accuracy", mode="max"),
    cbs.ModelCheckpoint(monitor="val_accuracy"),
    cbs.ReduceLROnPlateau(mode="max", patience=2),
]
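Both `EarlyStopping` and `ModelCheckpoint` watch one monitored value (`val_accuracy` here), and `mode="max"` means larger values are better. The sketch below shows that shared logic in plain Python: save whenever the value improves, and stop after `patience` epochs without improvement. The class name, the `min_delta` parameter, and the patience value are assumptions for illustration; TorchFlare's own callbacks may differ in their details and defaults.

```python
import math


class BestMetricTracker:
    """Hypothetical helper that tracks the best value of a monitored metric.

    mode="max": larger is better (e.g. accuracy); mode="min": smaller is better (e.g. loss).
    patience: number of epochs without improvement tolerated before should_stop is set.
    """

    def __init__(self, mode: str = "max", patience: int = 3, min_delta: float = 0.0):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.mode = mode
        self.patience = patience
        self.min_delta = min_delta
        self.best = -math.inf if mode == "max" else math.inf
        self.bad_epochs = 0
        self.should_stop = False

    def is_improvement(self, value: float) -> bool:
        if self.mode == "max":
            return value > self.best + self.min_delta
        return value < self.best - self.min_delta

    def update(self, value: float) -> bool:
        """Record one epoch's value; return True if it is a new best (time to checkpoint)."""
        if self.is_improvement(value):
            self.best = value
            self.bad_epochs = 0
            return True
        self.bad_epochs += 1
        if self.bad_epochs >= self.patience:
            self.should_stop = True
        return False


tracker = BestMetricTracker(mode="max", patience=2)
for epoch, val_accuracy in enumerate([0.80, 0.85, 0.84, 0.83, 0.90]):
    if tracker.update(val_accuracy):
        print(f"epoch {epoch}: new best {val_accuracy}, save checkpoint")
    if tracker.should_stop:
        print(f"epoch {epoch}: no improvement for {tracker.patience} epochs, stop")
        break
```

Here training stops after epoch 3, so the better value at epoch 4 is never seen; that is the usual trade-off of a small patience.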

Define Model Configuration

# Define the model config for the experiment.
config = ModelConfig(
    nn_module=Net,
    module_params={"n_classes": 10, "p_dropout": 0.3},
    optimizer="Adam",
    optimizer_params={"lr": 3e-4},
    criterion="cross_entropy",
)
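Note that `nn_module` receives the class `Net` itself rather than an instance, and that the optimizer is given by name. A common way to use such a config is deferred construction: store the class and its keyword arguments, then build the objects when the experiment starts. The sketch below illustrates that pattern only; `SketchModelConfig`, `DummyNet`, and `DummyAdam` are hypothetical stand-ins so it runs without PyTorch, and this is not TorchFlare's implementation.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict


@dataclass
class SketchModelConfig:
    """Hypothetical config that stores *how* to build the model and optimizer."""
    nn_module: Callable[..., Any]
    module_params: Dict[str, Any] = field(default_factory=dict)
    optimizer: str = "Adam"
    optimizer_params: Dict[str, Any] = field(default_factory=dict)

    def build_model(self) -> Any:
        # The class is instantiated only when needed, with the stored kwargs.
        return self.nn_module(**self.module_params)

    def build_optimizer(self, registry: Dict[str, Callable[..., Any]], params: Any) -> Any:
        # Resolve the optimizer by name (in a real setup "Adam" could map to torch.optim.Adam).
        try:
            optimizer_cls = registry[self.optimizer]
        except KeyError:
            raise ValueError(f"Unknown optimizer: {self.optimizer!r}") from None
        return optimizer_cls(params, **self.optimizer_params)


class DummyNet:
    def __init__(self, n_classes: int, p_dropout: float):
        self.n_classes = n_classes
        self.p_dropout = p_dropout

    def parameters(self):
        return []  # a real nn.Module would return its trainable tensors


class DummyAdam:
    def __init__(self, params, lr: float):
        self.params = list(params)
        self.lr = lr


config = SketchModelConfig(
    nn_module=DummyNet,
    module_params={"n_classes": 10, "p_dropout": 0.3},
    optimizer="Adam",
    optimizer_params={"lr": 3e-4},
)
model = config.build_model()
opt = config.build_optimizer({"Adam": DummyAdam}, model.parameters())
print(model.n_classes, opt.lr)  # 10 0.0003
```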

Define your experiment

# Set some constants for training
exp = Experiment(
    num_epochs=5,
    fp16=False,
    device="cuda",
    seed=42,
)

exp.compile_experiment(
    model_config=config,
    callbacks=callbacks,
    metrics=metric_list,
    main_metrics="accuracy",
)
# Run your experiment with training dataloader and validation dataloader.
exp.fit_loader(train_dl=train_dl, valid_dl=valid_dl)
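Conceptually, `fit_loader` runs the kind of epoch loop that the Experiment class is meant to hide: train, validate, record the results, and let callbacks react. The plain-Python sketch below is a generic Keras-style loop written under that assumption, not TorchFlare's actual implementation; `fit`, `train_one_epoch`, `validate`, and the convention that a callback returns True to request a stop are all hypothetical.

```python
from typing import Callable, Dict, List


def fit(
    num_epochs: int,
    train_one_epoch: Callable[[int], float],
    validate: Callable[[int], Dict[str, float]],
    on_epoch_end: List[Callable[[int, Dict[str, float]], bool]],
) -> Dict[str, List[float]]:
    """Hypothetical Keras-style loop: train, validate, log, run callbacks."""
    history: Dict[str, List[float]] = {}
    for epoch in range(num_epochs):
        logs = {"train_loss": train_one_epoch(epoch)}
        logs.update(validate(epoch))  # e.g. {"val_loss": ..., "val_accuracy": ...}
        for key, value in logs.items():
            history.setdefault(key, []).append(value)
        # Every callback runs (the list is built before any()); True requests a stop.
        stop_requests = [callback(epoch, logs) for callback in on_epoch_end]
        if any(stop_requests):
            break
    return history


fake_val_acc = [0.70, 0.75, 0.74, 0.73, 0.80]
history = fit(
    num_epochs=5,
    train_one_epoch=lambda epoch: 1.0 / (epoch + 1),  # fake, decreasing loss
    validate=lambda epoch: {"val_accuracy": fake_val_acc[epoch]},
    on_epoch_end=[lambda epoch, logs: epoch >= 2],  # toy callback: stop after epoch 2
)
print(history["val_accuracy"])  # [0.7, 0.75, 0.74]
```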

For inference, use the predict_on_loader method, which yields the output for each batch. You can use it as follows:

outputs = []
for op in exp.predict_on_loader(
    test_loader=test_dl, path_to_model="./models/model.bin", device="cuda"
):
    op = some_post_process_function(op)  # placeholder for your own post-processing
    outputs.extend(op)
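Because predictions arrive one batch at a time, each batch can be post-processed as soon as it is produced, and `extend` (rather than `append`) flattens the per-batch results into a single list with one entry per sample. The sketch below mimics that pattern in plain Python; `predict_batches` and `post_process` are hypothetical stand-ins for `exp.predict_on_loader` and `some_post_process_function`.

```python
from typing import Iterator, List


def predict_batches(batches: List[List[float]]) -> Iterator[List[float]]:
    """Hypothetical stand-in for a generator that yields predictions per batch."""
    for batch in batches:
        # A real model forward pass would go here; this toy version doubles each input.
        yield [2.0 * x for x in batch]


def post_process(batch_outputs: List[float]) -> List[int]:
    # Hypothetical post-processing: threshold scores into 0/1 labels.
    return [int(score > 1.0) for score in batch_outputs]


outputs: List[int] = []
for op in predict_batches([[0.1, 0.9], [0.4, 0.6, 0.2]]):
    op = post_process(op)
    outputs.extend(op)  # flatten: one entry per sample, not one list per batch
print(outputs)  # [0, 1, 0, 1, 0]
```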

To access your experiment's history, or to get it as a dataframe, do the following:

history = exp.history  # This will return a dict
exp.get_logs()  # This will return a dataframe constructed from model-history.
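The exact layout of `exp.history` is not documented here. Assuming it maps each metric name to a list with one value per epoch, the sketch below converts it to one row per epoch, which is the shape a dataframe expects (`pandas.DataFrame(rows)` accepts such a list of dicts, and `pandas.DataFrame(history)` accepts the dict of lists directly). `history_to_rows` is a hypothetical helper, not part of TorchFlare.

```python
from typing import Dict, List


def history_to_rows(history: Dict[str, List[float]]) -> List[Dict[str, float]]:
    """Turn {"metric": [epoch0, epoch1, ...]} into one dict per epoch."""
    lengths = {len(values) for values in history.values()}
    if len(lengths) > 1:
        raise ValueError("all metrics must have one value per epoch")
    num_epochs = lengths.pop() if lengths else 0
    return [
        {"epoch": epoch, **{name: values[epoch] for name, values in history.items()}}
        for epoch in range(num_epochs)
    ]


history = {"train_loss": [0.9, 0.6], "val_accuracy": [0.81, 0.86]}
for row in history_to_rows(history):
    print(row)
# {'epoch': 0, 'train_loss': 0.9, 'val_accuracy': 0.81}
# {'epoch': 1, 'train_loss': 0.6, 'val_accuracy': 0.86}
```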

Examples


Contributions

To contribute, please refer to the Contributing Guide.


Current Contributors


Author

Citation

Please use this BibTeX entry if you want to cite this repository in your publications:

@misc{TorchFlare,
    author = {Atharva Phatak},
    title = {TorchFlare: Easy model training and experimentation.},
    year = {2020},
    publisher = {GitHub},
    journal = {GitHub repository},
    howpublished = {\url{https://github.com/Atharva-Phatak/torchflare}},
}



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

torchflare-0.2.4.tar.gz (48.5 kB)

Uploaded Source

Built Distribution

torchflare-0.2.4-py3-none-any.whl (65.0 kB)

Uploaded Python 3

File details

Details for the file torchflare-0.2.4.tar.gz.

File metadata

  • Download URL: torchflare-0.2.4.tar.gz
  • Upload date:
  • Size: 48.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for torchflare-0.2.4.tar.gz
Algorithm Hash digest
SHA256 5e1eca84f60436c3a32eed86e0013c57e4c16de52dbbce83a1cde8ba23fa95fc
MD5 e088b868d4bc57a8fd19451f3140d457
BLAKE2b-256 26e735ce6b1071ef099bde839dedfe91e93ce9241541734031c0b6e8d933008c

See more details on using hashes here.

File details

Details for the file torchflare-0.2.4-py3-none-any.whl.

File metadata

  • Download URL: torchflare-0.2.4-py3-none-any.whl
  • Upload date:
  • Size: 65.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.2 importlib_metadata/4.8.1 pkginfo/1.7.1 requests/2.26.0 requests-toolbelt/0.9.1 tqdm/4.62.3 CPython/3.9.7

File hashes

Hashes for torchflare-0.2.4-py3-none-any.whl
Algorithm Hash digest
SHA256 9b63620a6d19956817d39e62cd2e2ee7cf2940526a13057015fa3664da933ede
MD5 9c4630e1c34b7d0a273ca5358228ba32
BLAKE2b-256 a419b5fde107f59f3731263544f6e11a873d821392f5ba977aa37ec6943d59a6

