TorchFlare is a simple, beginner-friendly, and easy-to-use PyTorch framework for training your models effortlessly.

TorchFlare

TorchFlare is a simple, beginner-friendly, and easy-to-use PyTorch framework for training your models without much effort. It provides an almost Keras-like experience for training your models, with callbacks, metrics, and more.

Features

  • A high-level module for Keras-like training.
  • Off-the-shelf PyTorch-style Datasets/Dataloaders for standard tasks such as image classification, image segmentation, text classification, etc.
  • Callbacks for model checkpoints, early stopping, and much more!
  • Metrics and much more.

Currently, TorchFlare supports CPU and GPU training. DDP and TPU support will be coming soon!

Note: This library is in its nascent stage, so there might be breaking changes.

Getting Started

The core idea of TorchFlare is the Experiment class. It handles the internals, such as the boilerplate training code and the calls to callbacks and metrics. The only thing you need to focus on is creating your PyTorch model.

There are also off-the-shelf PyTorch-style datasets/dataloaders available for standard tasks, so you don't have to worry about creating PyTorch Datasets/Dataloaders yourself.

Here is an easy-to-understand example that shows how the Experiment class works.

import torch
import torch.nn as nn
from torchflare.experiments import Experiment
import torchflare.callbacks as cbs
import torchflare.metrics as metrics

# Some dummy dataloaders (placeholders; substitute your own)
train_dl = SomeTrainingDataloader()
valid_dl = SomeValidationDataloader()
test_dl = SomeTestingDataloader()
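
The three dataloaders above are placeholders. As a minimal sketch, they could be stood in for by plain PyTorch dataloaders built from random tensors. The sizes, the `make_loader` helper, and the assumption that each batch is an `(inputs, targets)` pair are illustrative choices, not something the TorchFlare documentation specifies.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical sizes, chosen only for illustration.
num_features, num_classes = 20, 3


def make_loader(num_samples, shuffle):
    """Build a DataLoader over random features and integer class labels."""
    features = torch.randn(num_samples, num_features)
    labels = torch.randint(0, num_classes, (num_samples,))
    return DataLoader(TensorDataset(features, labels), batch_size=16, shuffle=shuffle)


train_dl = make_loader(128, shuffle=True)
valid_dl = make_loader(32, shuffle=False)
test_dl = make_loader(32, shuffle=False)

# Each batch is an (inputs, targets) pair.
inputs, targets = next(iter(train_dl))
print(inputs.shape, targets.shape)
```

Random data like this is only useful for checking that the pipeline runs end to end; real training needs a real dataset.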

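Create a PyTorch model

# num_features, hidden_state_size, and num_classes are placeholders
# for the dimensions of your own data.
model = nn.Sequential(
    nn.Linear(num_features, hidden_state_size),
    nn.ReLU(),
    nn.Linear(hidden_state_size, num_classes)
)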

Define callbacks and metrics

metric_list = [metrics.Accuracy(num_classes=num_classes, multilabel=False),
                metrics.F1Score(num_classes=num_classes, multilabel=False)]

callbacks = [cbs.EarlyStopping(monitor="accuracy", mode="max"), cbs.ModelCheckpoint(monitor="accuracy")]
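
To make the `monitor` and `mode` arguments more concrete, here is a minimal, library-free sketch of the general early-stopping idea: track the best value of a monitored metric and stop once it has failed to improve for a number of epochs. The `EarlyStoppingSketch` class and its `patience` parameter are hypothetical and written only for illustration; this is not TorchFlare's implementation.

```python
class EarlyStoppingSketch:
    """Illustrative early-stopping logic; not TorchFlare's implementation."""

    def __init__(self, mode="max", patience=3):
        self.mode = mode          # "max": higher is better, "min": lower is better
        self.patience = patience  # epochs to wait without improvement
        self.best = None
        self.bad_epochs = 0

    def should_stop(self, value):
        improved = (
            self.best is None
            or (self.mode == "max" and value > self.best)
            or (self.mode == "min" and value < self.best)
        )
        if improved:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStoppingSketch(mode="max", patience=2)
accuracies = [0.60, 0.72, 0.71, 0.70, 0.90]  # made-up per-epoch accuracies

stopped_at = None
for epoch, acc in enumerate(accuracies):
    if stopper.should_stop(acc):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)
```

With `mode="max"`, a metric such as accuracy counts as improved only when it rises above the best value seen so far, which is why the final 0.90 in the list is never reached.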

Define your experiment

# Set some constants for training
exp = Experiment(
    num_epochs=5,
    save_dir="./models",
    model_name="model.bin",
    fp16=False,
    using_batch_mixers=False,
    device="cuda",
    compute_train_metrics=True,
    seed=42,
)

# Compile your experiment with model, optimizer, schedulers, etc
exp.compile_experiment(
    model=model,
    optimizer="Adam",
    optimizer_params=dict(lr=3e-4),
    callbacks=callbacks,
    scheduler="ReduceLROnPlateau",
    scheduler_params=dict(mode="max", patience=5),
    criterion="cross_entropy",
    metrics=metric_list,
    main_metric="accuracy",
)

# Run your experiment with training dataloader and validation dataloader.
exp.run_experiment(train_dl=train_dl, valid_dl=valid_dl)
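
For comparison, the kind of boilerplate that a high-level training class saves you from writing looks roughly like the plain PyTorch loop below. This is a sketch under stated assumptions: the tiny model, the random batches, and the two-epoch loop are made up for illustration, and it does not reproduce what Experiment does internally.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)

# Hypothetical tiny model and random data, for illustration only.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()
batches = [(torch.randn(16, 4), torch.randint(0, 2, (16,))) for _ in range(5)]

losses = []
for epoch in range(2):
    model.train()
    for inputs, targets in batches:
        optimizer.zero_grad()                     # reset accumulated gradients
        loss = criterion(model(inputs), targets)  # forward pass and loss
        loss.backward()                           # backpropagate
        optimizer.step()                          # update parameters
        losses.append(loss.item())

print(len(losses))
```

Validation, metric computation, checkpointing, and scheduler steps would all add further code to a hand-written loop like this one.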

For inference, you can use the infer method, which yields the output for each batch. You can use it as follows.

outputs = []

for op in exp.infer(test_loader=test_dl, path="./models/model.bin", device="cuda"):
    op = some_post_process_function(op)
    outputs.extend(op)
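
The loop above relies on a generator that yields one batch of outputs at a time, so results can be post-processed and collected incrementally. The pattern can be sketched without any library; `infer_batches` and `argmax_post_process` below are hypothetical stand-ins for `exp.infer` and `some_post_process_function`, and the scores are made up.

```python
def infer_batches(batches, predict_fn):
    """Yield predictions one batch at a time (stand-in for a per-batch infer method)."""
    for batch in batches:
        yield predict_fn(batch)


def argmax_post_process(batch_scores):
    """Turn per-class scores into predicted class indices."""
    return [max(range(len(scores)), key=scores.__getitem__) for scores in batch_scores]


# Two made-up batches of per-class scores; the identity function plays the model.
batches = [
    [[0.1, 0.9], [0.8, 0.2]],
    [[0.3, 0.7]],
]

outputs = []
for op in infer_batches(batches, predict_fn=lambda batch: batch):
    outputs.extend(argmax_post_process(op))

print(outputs)
```

Using `extend` rather than `append` flattens the per-batch lists into one list with a single entry per sample.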

The Experiment class internally saves a history.csv file that records your training and validation metrics per epoch. This file is written to the directory given by the save_dir argument.
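
A per-epoch CSV file like this can be inspected with the standard library. The sketch below parses an in-memory sample instead of a real file; the column names (`epoch`, `train_accuracy`, `val_accuracy`) and the values are assumptions made for illustration, and the actual layout of history.csv may differ.

```python
import csv
import io

# Hypothetical file contents; the real column names may differ.
sample_csv = """epoch,train_accuracy,val_accuracy
1,0.61,0.58
2,0.74,0.69
3,0.80,0.73
"""

reader = csv.DictReader(io.StringIO(sample_csv))
rows = [{key: float(value) for key, value in row.items()} for row in reader]

# Pick the epoch with the highest validation accuracy.
best = max(rows, key=lambda row: row["val_accuracy"])
print(int(best["epoch"]), best["val_accuracy"])
```

To read a real file, replace the `io.StringIO` object with an open file handle for the history.csv inside your save_dir.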

If you want to access your experiment's history or plot it, you can do so as follows.

history = exp.history.history  # This will return a dict

# To plot the progress of a particular metric across epochs, use this.

exp.plot_history(key="accuracy", save_fig=False, plot_fig=True)
