TorchFlare is a simple, beginner-friendly, and easy-to-use PyTorch framework for training your models effortlessly.
Project description
TorchFlare
TorchFlare is a simple, beginner-friendly, and easy-to-use PyTorch framework for training your models without much effort. It provides an almost Keras-like training experience, with callbacks, metrics, and more.
Features
- A high-level module for Keras-like training.
- Off-the-shelf PyTorch-style Datasets/Dataloaders for standard tasks such as image classification, image segmentation, text classification, etc.
- Callbacks for model checkpoints, early stopping, and much more!
- Metrics and much more.
Currently, TorchFlare supports CPU and GPU training. DDP and TPU support will be coming soon!
Note: This library is in its nascent stage, so there might be breaking changes.
Installation
pip install torchflare
Documentation
The documentation is available here.
Getting Started
The core idea of TorchFlare is the Experiment class. It handles all the internals, such as the boilerplate training code and the calls to callbacks, metrics, etc. The only thing you need to focus on is creating your PyTorch model.
There are also off-the-shelf PyTorch-style datasets/dataloaders for standard tasks, so you don't have to write your own PyTorch Datasets/Dataloaders.
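To make that division of labor concrete, here is a toy, framework-free sketch of the kind of loop a Keras-like trainer runs on your behalf. It is not TorchFlare's implementation: `ToyExperiment`, `ToyEarlyStopping`, and the `on_epoch_end` hook are hypothetical names chosen for illustration, and the per-epoch accuracy values are written by hand instead of coming from a real model.

```python
# Toy sketch only: hypothetical names, not TorchFlare's internals.

class ToyEarlyStopping:
    """Request a stop when the monitored value fails to improve for `patience` epochs."""

    def __init__(self, monitor, mode="max", patience=2):
        self.monitor = monitor
        self.mode = mode
        self.patience = patience
        self.best = None
        self.bad_epochs = 0

    def on_epoch_end(self, logs):
        value = logs[self.monitor]
        improved = (
            self.best is None
            or (self.mode == "max" and value > self.best)
            or (self.mode == "min" and value < self.best)
        )
        if improved:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # Returning True asks the loop to stop.
        return self.bad_epochs >= self.patience


class ToyExperiment:
    """Runs epochs, records history, and calls every callback after each epoch."""

    def __init__(self, num_epochs, callbacks):
        self.num_epochs = num_epochs
        self.callbacks = callbacks
        self.history = {}

    def run(self, epoch_fn):
        for epoch in range(self.num_epochs):
            logs = epoch_fn(epoch)  # stands in for the train + validation pass
            for key, value in logs.items():
                self.history.setdefault(key, []).append(value)
            # Call every callback first, then stop if any of them asked to.
            stop_requests = [cb.on_epoch_end(logs) for cb in self.callbacks]
            if any(stop_requests):
                break
        return self.history


# Hand-written validation accuracies standing in for real training results.
fake_accuracy = [0.60, 0.70, 0.69, 0.68, 0.90]
toy = ToyExperiment(
    num_epochs=5,
    callbacks=[ToyEarlyStopping(monitor="accuracy", mode="max", patience=2)],
)
toy_history = toy.run(lambda epoch: {"accuracy": fake_accuracy[epoch]})
```

With `patience=2`, this run stops after the fourth epoch, because epochs three and four both fail to beat 0.70. The user only supplies the per-epoch work and the list of callbacks; the loop, the history bookkeeping, and the stop decision live in one place.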
Here is an easy-to-understand example that shows how the Experiment class works.
import torch
import torch.nn as nn
from torchflare.experiments import Experiment
import torchflare.callbacks as cbs
import torchflare.metrics as metrics
# Some dummy dataloaders (placeholders, substitute your own)
train_dl = SomeTrainingDataloader()
valid_dl = SomeValidationDataloader()
test_dl = SomeTestingDataloader()
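Create a PyTorch model
# Illustrative sizes (placeholders, set these to match your data)
num_features, hidden_state_size, num_classes = 20, 64, 3
model = nn.Sequential(
    nn.Linear(num_features, hidden_state_size),
    nn.ReLU(),
    nn.Linear(hidden_state_size, num_classes),
)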
Define callbacks and metrics
metric_list = [metrics.Accuracy(num_classes=num_classes, multilabel=False),
               metrics.F1Score(num_classes=num_classes, multilabel=False)]
callbacks = [cbs.EarlyStopping(monitor="accuracy", mode="max"), cbs.ModelCheckpoint(monitor="accuracy"),
             cbs.ReduceLROnPlateau(mode="max", patience=2)]
Define your experiment
# Set some constants for training
exp = Experiment(
    num_epochs=5,
    save_dir="./models",
    model_name="model.bin",
    fp16=False,
    using_batch_mixers=False,
    device="cuda",
    compute_train_metrics=True,
    seed=42,
)
# Compile your experiment with model, optimizer, schedulers, etc.
exp.compile_experiment(
    model=model,  # the nn.Sequential model defined above
    optimizer="Adam",
    optimizer_params=dict(lr=3e-4),
    callbacks=callbacks,
    criterion="cross_entropy",
    metrics=metric_list,
    main_metric="accuracy",
)
# Run your experiment with the training and validation dataloaders.
exp.run_experiment(train_dl=train_dl, valid_dl=valid_dl)
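The example above leaves `SomeTrainingDataloader` and its siblings undefined. As a minimal sketch of what could stand in for them, the following builds plain PyTorch dataloaders over random data with `TensorDataset` and `DataLoader`. The sizes, the `make_loader` helper, and the `(inputs, targets)` batch layout are assumptions made for illustration; check the TorchFlare documentation for the batch format the Experiment class actually expects.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative sizes (assumptions, not values from the TorchFlare docs).
num_features, num_classes = 20, 3


def make_loader(num_samples, batch_size=16, shuffle=False):
    """Build a DataLoader over random features and integer class labels."""
    features = torch.randn(num_samples, num_features)
    labels = torch.randint(0, num_classes, (num_samples,))
    dataset = TensorDataset(features, labels)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


train_dl = make_loader(128, shuffle=True)
valid_dl = make_loader(32)
test_dl = make_loader(32)

# Each batch produced by these loaders is an (inputs, targets) pair.
inputs, targets = next(iter(train_dl))
```

Random data is only good for checking that shapes line up; swap in a real dataset before reading anything into the metrics.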
For inference, use the infer method, which yields the output for each batch. You can use it as follows:
outputs = []
for op in exp.infer(test_loader=test_dl, path="./models/model.bin", device="cuda"):
    op = some_post_process_function(op)  # placeholder for your own post-processing
    outputs.extend(op)
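The per-batch pattern above can be tried out without a trained model. The sketch below mimics it in plain Python: `toy_infer` is a hypothetical generator standing in for `exp.infer`, and `label_batch` is a made-up post-processing step that thresholds scores into 0/1 labels. Neither is part of TorchFlare.

```python
def toy_infer(batches, predict_fn):
    """Yield the predictions for one batch at a time (stand-in for exp.infer)."""
    for batch in batches:
        yield predict_fn(batch)


def label_batch(scores, threshold=0.5):
    """Example post-processing step: turn scores into 0/1 labels."""
    return [int(score >= threshold) for score in scores]


# Three hand-written "batches" of scores stand in for a test dataloader.
fake_batches = [[0.1, 0.9], [0.4, 0.6], [0.8]]

predictions = []
# The identity predict_fn just passes the scores through unchanged.
for batch_output in toy_infer(fake_batches, predict_fn=lambda batch: batch):
    predictions.extend(label_batch(batch_output))
```

Because each batch is post-processed and then merged with `extend`, the result is one flat list with an entry per sample, and only one batch of raw outputs needs to be held at a time.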
If you want to access or plot your experiment's history, you can do so as follows:
history = exp.history  # This will return a dict
# To plot how particular metrics progress over the epochs, use this.
exp.plot_history(keys=["loss", "accuracy"], save_fig=False, plot_fig=True)
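Since the history is a dict, it can also be inspected directly. The sketch below assumes the dict maps each key to a list with one value per epoch; that layout, the sample numbers, and the `best_epoch` helper are assumptions for illustration, so inspect `exp.history` in your own run to confirm its actual structure.

```python
# Hypothetical history: one list of per-epoch values per key (layout assumed).
history = {
    "loss": [0.92, 0.61, 0.48, 0.51],
    "accuracy": [0.58, 0.71, 0.79, 0.77],
}


def best_epoch(history, key, mode="max"):
    """Return (1-based epoch, value) of the best recorded value for `key`."""
    values = history[key]
    pick = max if mode == "max" else min
    # Choose the index whose value is largest (or smallest); ties go to the earliest epoch.
    index = pick(range(len(values)), key=values.__getitem__)
    return index + 1, values[index]


best_acc_epoch, best_acc = best_epoch(history, "accuracy", mode="max")
best_loss_epoch, best_loss = best_epoch(history, "loss", mode="min")
```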
Examples
- Image Classification on CIFAR-10 using TorchFlare.
- Text Classification on IMDB data.
- Binary Classification of Tabular Data on a previous Kaggle competition.
- Tutorial on using Hydra and TorchFlare for efficient workflow and parameter management.
File details
Details for the file torchflare-0.2.0.tar.gz.
File metadata
- Download URL: torchflare-0.2.0.tar.gz
- Upload date:
- Size: 47.3 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.9.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 5039caeb3e2f294ee325f34f77e6f738b27cf6fdf9744087663d06a063fc244f
MD5 | 324fc7c85b231cc7b1f9c75387e2526c
BLAKE2b-256 | 001f56ef61e260cd0329a4d232142f7a8b383fe4bb5092f6f7581f69c128a141
File details
Details for the file torchflare-0.2.0-py3-none-any.whl.
File metadata
- Download URL: torchflare-0.2.0-py3-none-any.whl
- Upload date:
- Size: 76.4 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.25.1 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.9.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 338d712d57b6dd43e150e00e69e27dd59bd9fc6c0969627a4276e0bbafbde7ff
MD5 | 8764fafacdf288b0bce21f9af9ef6dc3
BLAKE2b-256 | 694a5c3eac27ce2a8fcda0a72ec9c843298d6d2efdc5c471ad0d7f20f96face7