Gradsflow
An AutoML Library made with Optuna and PyTorch Lightning
Installation
Recommended
pip install -U gradsflow
From source
pip install git+https://github.com/gradsflow/gradsflow@main
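To confirm which version pip installed (a PyPI release or the `main` branch), the standard library's `importlib.metadata` can report it. This small sketch assumes only the distribution name `gradsflow` used in the commands above; `installed_version` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "gradsflow") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_version()
    print(f"gradsflow {v}" if v else "gradsflow is not installed")
```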
Examples
Image Classification
from gradsflow.autoclassifier import AutoImageClassifier
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
)
# 2. Define the search space
suggested_conf = dict(
    optimizers=["adam"],
    lr=(5e-4, 1e-3),
)
# 3. Create the AutoML model and run the hyperparameter search
model = AutoImageClassifier(
    datamodule,
    suggested_backbones=["ssl_resnet18"],
    suggested_conf=suggested_conf,
    max_epochs=1,
    optimization_metric="val_accuracy",
    timeout=30,
)
print("AutoImageClassifier initialised!")
model.hp_tune()
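`suggested_conf` describes the search space. Reading it as the names suggest, the `optimizers` list is a set of categorical choices and the `lr` tuple is a continuous (min, max) range. The stdlib sketch below samples one configuration from such a dict under exactly those assumptions, drawing `lr` log-uniformly. It illustrates the idea only and is not gradsflow's implementation; in an Optuna study, `trial.suggest_categorical` and `trial.suggest_float(..., log=True)` fill these roles. `sample_config` is a hypothetical name.

```python
import math
import random
from typing import Any, Dict, Optional


def sample_config(space: Dict[str, Any], rng: Optional[random.Random] = None) -> Dict[str, Any]:
    """Draw one configuration from a suggested_conf-style dict.

    Assumed semantics (illustrative only, not gradsflow's documented behaviour):
      * list -> categorical: pick one element
      * (lo, hi) tuple of positive numbers -> continuous range, sampled log-uniformly
      * anything else -> fixed value, passed through unchanged
    """
    rng = rng or random.Random()
    config = {}
    for name, spec in space.items():
        if isinstance(spec, list):
            config[name] = rng.choice(spec)
        elif isinstance(spec, tuple) and len(spec) == 2:
            lo, hi = spec
            # Uniform in log space gives each order of magnitude equal weight.
            config[name] = math.exp(rng.uniform(math.log(lo), math.log(hi)))
        else:
            config[name] = spec
    return config


suggested_conf = dict(optimizers=["adam"], lr=(5e-4, 1e-3))
print(sample_config(suggested_conf, random.Random(0)))
```

Log-uniform sampling is a common choice for learning rates because their effect scales multiplicatively; for a narrow range like this one it behaves almost like uniform sampling.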
Text Classification
from gradsflow.autoclassifier import AutoTextClassifier
from flash.core.data.utils import download_data
from flash.text import TextClassificationData
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")
datamodule = TextClassificationData.from_csv(
    "review",     # input text column
    "sentiment",  # target label column
    train_file="data/imdb/train.csv",
    val_file="data/imdb/valid.csv",
    backbone="prajjwal1/bert-medium",
)
# 2. Define the search space
suggested_conf = dict(
    optimizers=["adam"],
    lr=(5e-4, 1e-3),
)
# 3. Create the AutoML model and run the hyperparameter search
model = AutoTextClassifier(
    datamodule,
    suggested_backbones=["sgugger/tiny-distilbert-classification"],
    suggested_conf=suggested_conf,
    max_epochs=1,
    optimization_metric="val_accuracy",
    timeout=30,
)
print("AutoTextClassifier initialised!")
model.hp_tune()
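Going by the argument names, `model.hp_tune()` runs trials that each train for at most `max_epochs` epochs, compares them on `optimization_metric` (`val_accuracy`, so higher is better), and stops searching once `timeout` is used up. gradsflow presumably forwards `timeout` to Optuna, whose `Study.optimize(timeout=...)` is measured in seconds. The loop below is a stdlib sketch of that control flow using plain random search and a made-up objective. `toy_objective` and `search` are hypothetical stand-ins, and Optuna's default sampler (TPE) chooses trials more cleverly than random sampling does.

```python
import random
import time
from typing import Callable, Dict, Optional, Tuple


def toy_objective(config: Dict[str, float]) -> float:
    """Hypothetical stand-in for 'train one trial, then report val_accuracy'.

    Peaks at lr = 7e-4 so the search has something to find.
    """
    return 1.0 - abs(config["lr"] - 7e-4) * 1000


def search(
    objective: Callable[[Dict[str, float]], float],
    lr_range: Tuple[float, float],
    timeout: float,
    n_trials: Optional[int] = None,
    seed: int = 0,
) -> Tuple[Optional[Dict[str, float]], float]:
    """Random search that maximizes `objective` until the time or trial budget runs out.

    The deadline is checked between trials, so a running trial is never interrupted.
    """
    rng = random.Random(seed)
    deadline = time.monotonic() + timeout
    best_config, best_score = None, float("-inf")
    trials = 0
    while time.monotonic() < deadline and (n_trials is None or trials < n_trials):
        config = {"lr": rng.uniform(*lr_range)}
        score = objective(config)
        trials += 1
        if score > best_score:  # maximize, as for an accuracy metric
            best_config, best_score = config, score
    return best_config, best_score


best, score = search(toy_objective, (5e-4, 1e-3), timeout=30, n_trials=200)
print(best, round(score, 3))
```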
Download files
Source Distribution
gradsflow-0.0.1a1.tar.gz (18.1 kB)
Built Distribution
gradsflow-0.0.1a1-py3-none-any.whl
Hashes for gradsflow-0.0.1a1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 0bb3d8d7ffd7f4e75810026aa4f7c8841ee67a072551acf32c98635470f81f2a
MD5 | 53f0eafaeded9d4c9e1d85d36cb0a078
BLAKE2b-256 | 60d1031960351a2810c1956d49c437da18f33d8fde538815d9021cc4cf082ac1