
TorchFlare is a simple, beginner-friendly, and easy-to-use PyTorch framework for training your models effortlessly.

Project description


TorchFlare

TorchFlare is a simple, beginner-friendly, and easy-to-use PyTorch framework for training your models without much effort. It provides an almost Keras-like training experience, with callbacks, metrics, and more built in.

Features

  • A high-level module for Keras-like training.
  • Flexibility to write custom training and validation loops for advanced use cases.
  • Off-the-shelf PyTorch-style Datasets/DataLoaders for standard tasks such as image classification, image segmentation, and text classification.
  • Callbacks for model checkpointing, early stopping, and much more!
  • Metrics such as accuracy and F1 score, and more.
  • Less boilerplate code required to train your models.
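As an illustration of what metrics like these compute, the sketch below implements accuracy and macro-averaged F1 for single-label multiclass predictions in plain Python. It is not TorchFlare's implementation: the function names are hypothetical, and macro averaging (an unweighted mean over classes) is an assumption, since the averaging used by TorchFlare's F1Score is not stated here.

```python
from typing import List


def accuracy(preds: List[int], targets: List[int]) -> float:
    """Fraction of samples whose predicted class equals the target class."""
    if not preds or len(preds) != len(targets):
        raise ValueError("preds and targets must be non-empty and the same length")
    correct = sum(p == t for p, t in zip(preds, targets))
    return correct / len(targets)


def macro_f1(preds: List[int], targets: List[int], num_classes: int) -> float:
    """Unweighted mean of per-class F1 scores for single-label multiclass data."""
    if not preds or len(preds) != len(targets):
        raise ValueError("preds and targets must be non-empty and the same length")
    scores = []
    for c in range(num_classes):
        tp = sum(p == c and t == c for p, t in zip(preds, targets))
        fp = sum(p == c and t != c for p, t in zip(preds, targets))
        fn = sum(p != c and t == c for p, t in zip(preds, targets))
        denom = 2 * tp + fp + fn  # F1 = 2TP / (2TP + FP + FN)
        # Convention used here: a class absent from both preds and targets scores 0.
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / num_classes


preds = [0, 1, 2, 2]
targets = [0, 1, 1, 2]
print(accuracy(preds, targets))  # 0.75
print(round(macro_f1(preds, targets, num_classes=3), 4))  # 0.7778
```

Here class 0 scores F1 = 1.0 while classes 1 and 2 each score 2/3, so the macro average is 7/9 even though accuracy is 0.75.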

Currently, TorchFlare supports CPU and GPU training. Support for distributed training (DDP) and TPUs is coming soon!


Installation

pip install torchflare

Documentation

The documentation is available here.


Getting Started

The core idea behind TorchFlare is the Experiment class. It handles all the internals, such as the boilerplate training code and calling callbacks and metrics. The only thing you need to focus on is creating your PyTorch model.

There are also off-the-shelf PyTorch-style datasets/dataloaders for standard tasks, so you don't have to write your own PyTorch Datasets/DataLoaders.
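The loop that such an experiment class runs for you can be sketched in a few lines. The following is a hypothetical, framework-free illustration (MiniExperiment, Callback, on_epoch_end, and StopAfterEpoch are names made up for this sketch, not TorchFlare's API): each epoch it runs a training step and a validation step, records the results in a history dict, and gives every callback a chance to react, for example by stopping training early.

```python
from typing import Callable, Dict, List, Optional


class Callback:
    """Hypothetical base class: subclasses override the hooks they need."""

    def on_epoch_end(self, epoch: int, logs: Dict[str, float], loop: "MiniExperiment") -> None:
        pass


class MiniExperiment:
    """Hypothetical, stripped-down training driver (not TorchFlare's implementation)."""

    def __init__(self, num_epochs: int, callbacks: Optional[List[Callback]] = None) -> None:
        self.num_epochs = num_epochs
        self.callbacks = callbacks or []
        self.history: Dict[str, List[float]] = {}
        self.stop_training = False  # callbacks may set this to end training early

    def fit(
        self,
        train_epoch: Callable[[], float],
        valid_epoch: Callable[[], Dict[str, float]],
    ) -> Dict[str, List[float]]:
        for epoch in range(self.num_epochs):
            logs = {"train_loss": train_epoch()}  # one pass over the training data
            logs.update(valid_epoch())  # e.g. {"val_loss": ..., "val_accuracy": ...}
            for name, value in logs.items():
                self.history.setdefault(name, []).append(value)
            for callback in self.callbacks:
                callback.on_epoch_end(epoch, logs, self)
            if self.stop_training:
                break
        return self.history


class StopAfterEpoch(Callback):
    """Toy callback: request a stop once a given epoch has finished."""

    def __init__(self, last_epoch: int) -> None:
        self.last_epoch = last_epoch

    def on_epoch_end(self, epoch: int, logs: Dict[str, float], loop: MiniExperiment) -> None:
        if epoch >= self.last_epoch:
            loop.stop_training = True


# Stand-in epoch functions; a real loop would iterate over DataLoaders here.
losses = iter([0.9, 0.7, 0.5])
exp = MiniExperiment(num_epochs=3, callbacks=[StopAfterEpoch(last_epoch=1)])
history = exp.fit(lambda: next(losses), lambda: {"val_accuracy": 0.5})
print(history)  # {'train_loss': [0.9, 0.7], 'val_accuracy': [0.5, 0.5]}
```

Keeping the loop itself generic and pushing behaviour such as early stopping into callbacks is one way a framework can offer a single Keras-like fit call while still allowing custom logic.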

Here is an easy-to-understand example that shows how the Experiment class works.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchflare.experiments import Experiment, ModelConfig
import torchflare.callbacks as cbs
import torchflare.metrics as metrics

# Placeholder dataloaders: replace these with your own dataloaders.
train_dl = SomeTrainingDataloader()
valid_dl = SomeValidationDataloader()
test_dl = SomeTestingDataloader()

Create a PyTorch model

class Net(nn.Module):
    # For 28x28 single-channel inputs, two conv(k=5) + max_pool(2) stages leave 20 x 4 x 4 = 320 features.
    def __init__(self, n_classes, p_dropout):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d(p=p_dropout)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, n_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)  # flatten the conv features for the linear layers
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x
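The hard-coded 320 in x.view(-1, 320) depends on the input size. Assuming single-channel 28x28 images (MNIST-sized; the README does not name the dataset, so this is an assumption), the arithmetic can be checked with the standard output-size formulas for an unpadded convolution and a max-pool whose stride equals its kernel size. conv_out, pool_out, and flattened_features are hypothetical helper names.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of a Conv2d along one spatial dimension (dilation = 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size: int, kernel: int) -> int:
    """Output length of max_pool2d with stride equal to kernel size and no padding."""
    return (size - kernel) // kernel + 1


def flattened_features(height: int, width: int, channels_out: int = 20) -> int:
    """Features left after Net's two conv(k=5) + max_pool(2) stages."""
    h, w = height, width
    for _ in range(2):
        h = pool_out(conv_out(h, 5), 2)
        w = pool_out(conv_out(w, 5), 2)
    return channels_out * h * w


print(flattened_features(28, 28))  # 28 -> 24 -> 12 -> 8 -> 4, so 20 * 4 * 4 = 320
```

For other image sizes the value changes (for 32x32 inputs it comes out as 500), so both the view call and fc1 would need updating.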

Define callbacks and metrics

num_classes = 10  # must match n_classes passed to the model in ModelConfig below

metric_list = [metrics.Accuracy(num_classes=num_classes, multilabel=False),
               metrics.F1Score(num_classes=num_classes, multilabel=False)]

callbacks = [cbs.EarlyStopping(monitor="val_accuracy", mode="max"),
             cbs.ModelCheckpoint(monitor="val_accuracy"),
             cbs.ReduceLROnPlateau(mode="max", patience=2)]
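Callbacks such as EarlyStopping(monitor="val_accuracy", mode="max") typically track the best value of the monitored metric and count epochs without improvement. The class below is a hypothetical, framework-free sketch of that logic, not TorchFlare's code; the patience default of 3 and the min_delta threshold are assumptions chosen for illustration. A checkpoint callback can use the same kind of "improved" check to decide when to save the model.

```python
import math


class EarlyStoppingSketch:
    """Hypothetical early-stopping logic; not TorchFlare's implementation."""

    def __init__(self, mode: str = "max", patience: int = 3, min_delta: float = 0.0) -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        self.patience = patience  # assumed default: epochs to wait without improvement
        self.min_delta = min_delta  # smallest change that counts as an improvement
        self.best = -math.inf if mode == "max" else math.inf
        self.wait = 0

    def improved(self, value: float) -> bool:
        """True if value beats the best seen so far by more than min_delta."""
        if self.mode == "max":
            return value > self.best + self.min_delta
        return value < self.best - self.min_delta

    def step(self, value: float) -> bool:
        """Record one epoch's monitored value; return True when training should stop."""
        if self.improved(value):
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


# Example: monitor validation accuracy (higher is better) with patience=2.
stopper = EarlyStoppingSketch(mode="max", patience=2)
val_accuracy = [0.80, 0.85, 0.84, 0.83, 0.90]
stopped_at = None
for epoch, acc in enumerate(val_accuracy):
    if stopper.step(acc):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)  # 3 0.85
```

The later 0.90 is never reached: a larger patience spends extra epochs in exchange for robustness to noisy validation scores.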

Define your experiment

# Set some constants for training
exp = Experiment(
    num_epochs=5,
    fp16=False,
    device="cuda",
    seed=42,
)

# Compile your experiment with model, optimizer, schedulers, etc
config = ModelConfig(nn_module=Net,
                     module_params={"n_classes": 10, "p_dropout": 0.3},
                     optimizer="Adam",
                     optimizer_params={"lr": 3e-4},
                     criterion="cross_entropy")

exp.compile_experiment(model_config=config,
                       callbacks=callbacks,
                       metrics=metric_list,
                       main_metrics="accuracy")
# Run your experiment with the training and validation dataloaders.
exp.fit_loader(train_dl=train_dl, valid_dl=valid_dl)

For inference, use the predict_on_loader method, which yields the outputs batch by batch. You can use it as follows:

outputs = []

# some_post_process_function is a placeholder for your own post-processing.
for op in exp.predict_on_loader(test_loader=test_dl, path_to_model="./models/model.bin", device="cuda"):
    op = some_post_process_function(op)
    outputs.extend(op)
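Because the prediction method yields one output per batch, the loop above consumes it like any Python generator. The sketch below mimics that pattern with a hypothetical fake_predict_on_loader that yields lists of logits, plus an argmax step standing in for some_post_process_function; neither is part of TorchFlare. Using extend rather than append flattens the per-batch lists into a single list of predictions.

```python
from typing import Iterator, List, Sequence


def fake_predict_on_loader(batches: Sequence[Sequence[Sequence[float]]]) -> Iterator[List[List[float]]]:
    """Hypothetical stand-in for a per-batch prediction generator: yields one batch of logits at a time."""
    for batch in batches:
        yield [list(row) for row in batch]


def argmax_post_process(batch_logits: List[List[float]]) -> List[int]:
    """Example post-processing: turn each row of logits into a predicted class index."""
    return [max(range(len(row)), key=row.__getitem__) for row in batch_logits]


batches = [
    [[0.1, 2.0, -1.0], [3.0, 0.5, 0.2]],  # batch 1: two samples, three classes
    [[0.0, 0.1, 0.9]],                    # batch 2: one sample
]

outputs: List[int] = []
for op in fake_predict_on_loader(batches):
    op = argmax_post_process(op)
    outputs.extend(op)  # extend flattens per-batch results into one list

print(outputs)  # [1, 0, 2]
```

With append instead of extend, the result would be the nested [[1, 0], [2]].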

To access your experiment's history, or to get it as a DataFrame, do the following:

history = exp.history  # returns a dict
exp.get_logs()  # returns a DataFrame constructed from the model history
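The exact layout of exp.history is not documented here. Assuming it maps each metric name to a list with one value per epoch, the sketch below turns it into one row per epoch, which is essentially the table a DataFrame view shows. history_to_rows and the sample values are hypothetical. With pandas installed, pandas.DataFrame(history) builds the equivalent table directly from such a dict.

```python
from typing import Dict, List


def history_to_rows(history: Dict[str, List[float]]) -> List[Dict[str, float]]:
    """Turn {metric: [per-epoch values]} into one dict per epoch (assumed history layout)."""
    lengths = {len(values) for values in history.values()}
    if len(lengths) > 1:
        raise ValueError("all metrics must have one value per epoch")
    num_epochs = lengths.pop() if lengths else 0
    return [
        {"epoch": epoch, **{name: values[epoch] for name, values in history.items()}}
        for epoch in range(num_epochs)
    ]


history = {"train_loss": [0.9, 0.6], "val_loss": [0.8, 0.7], "val_accuracy": [0.71, 0.78]}
rows = history_to_rows(history)
print(rows[1])  # {'epoch': 1, 'train_loss': 0.6, 'val_loss': 0.7, 'val_accuracy': 0.78}
```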

Examples


Current Contributors


Stability

The library isn't yet mature or stable enough for production use.

For now, it is best suited to non-production use and rapid prototyping.


Contribution

Contributions are always welcome. It would be great to have people use and contribute to this project to help users understand and benefit from the library.

How to contribute

  • Create an issue: If you have a new feature in mind, feel free to open an issue with a short description of the feature.
  • Create a PR: If you have a bug fix, enhancement, or new feature, create a pull request and the repository maintainers will review and merge it.

Author

Project details


Download files


Source Distribution

torchflare-0.2.3.tar.gz (57.6 kB)

Uploaded Source

Built Distribution

torchflare-0.2.3-py3-none-any.whl (85.4 kB)

Uploaded Python 3

File details

Details for the file torchflare-0.2.3.tar.gz.

File metadata

  • Download URL: torchflare-0.2.3.tar.gz
  • Upload date:
  • Size: 57.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.6.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.9.5

File hashes

Hashes for torchflare-0.2.3.tar.gz
Algorithm Hash digest
SHA256 795f0f86ce7f822c6591ba13b421a86e49754caeff452384aa1fe7010cfd00d0
MD5 660c4fbe9b1fdbde3d58a70d5cba8435
BLAKE2b-256 157ab4ff6d352bdec98edba30682b8b5bbddf2fd4e6020f313ec2905833f7ada


File details

Details for the file torchflare-0.2.3-py3-none-any.whl.

File metadata

  • Download URL: torchflare-0.2.3-py3-none-any.whl
  • Upload date:
  • Size: 85.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.6.0 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.61.1 CPython/3.9.5

File hashes

Hashes for torchflare-0.2.3-py3-none-any.whl
Algorithm Hash digest
SHA256 4413542f52d2a04fe74c9a7b58b5798d048f15f0cbe851d3d187fcc961f339c7
MD5 02a3b76cffef8751b3e374d32f259ea9
BLAKE2b-256 89aa09f554cf578e07aafcb0d38e754db0aa0d7c296f2bf297097bc63ac940f8

