An AutoML Library made with Optuna and PyTorch Lightning
Installation
Recommended
pip install -U gradsflow
From source
pip install git+https://github.com/gradsflow/gradsflow@main
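After either install, you can confirm which version Python actually sees by querying the package metadata. This is a minimal standard-library sketch (Python 3.8+). `installed_gradsflow_version` is an illustrative helper, not part of gradsflow.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_gradsflow_version() -> Optional[str]:
    """Return the installed gradsflow version string, or None if it is not installed."""
    try:
        return version("gradsflow")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_gradsflow_version()
    if found is None:
        print("gradsflow is not installed; run `pip install -U gradsflow`.")
    else:
        print(f"gradsflow {found} is installed.")
```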
Examples
Image Classification
from gradsflow.autoclassifier import AutoImageClassifier
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData

# 1. Download the hymenoptera (ants vs. bees) dataset and create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
)

# 2. Define the hyperparameter search space
suggested_conf = dict(
    optimizers=["adam"],
    lr=(5e-4, 1e-3),
)

# 3. Create the AutoImageClassifier and run the hyperparameter search
model = AutoImageClassifier(
    datamodule,
    suggested_backbones=["ssl_resnet18"],
    suggested_conf=suggested_conf,
    max_epochs=1,
    optimization_metric="val_accuracy",
    timeout=30,
)
print("AutoImageClassifier initialised!")
model.hp_tune()
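The `suggested_conf` dictionary mixes two kinds of entries: a list (`optimizers=["adam"]`) reads as a set of choices, and a tuple (`lr=(5e-4, 1e-3)`) reads as a numeric range. gradsflow hands the real search to Optuna. The standard-library sketch below is not its implementation. It only illustrates the idea, assuming that lists mean categorical choices and 2-tuples mean a uniform float range. `sample_config` is a hypothetical helper.

```python
import random
from typing import Any, Dict, Sequence, Tuple, Union

# Hypothetical search-space format mirroring the README's suggested_conf:
# lists are treated as categorical choices, 2-tuples as (low, high) float ranges.
SearchSpace = Dict[str, Union[Sequence[Any], Tuple[float, float]]]


def sample_config(space: SearchSpace, rng: random.Random) -> Dict[str, Any]:
    """Draw one candidate configuration from the search space (assumed semantics)."""
    config: Dict[str, Any] = {}
    for name, spec in space.items():
        if isinstance(spec, tuple) and len(spec) == 2:
            low, high = spec
            config[name] = rng.uniform(low, high)  # continuous range
        elif isinstance(spec, list):
            config[name] = rng.choice(spec)  # categorical choice
        else:
            raise ValueError(f"Unsupported spec for {name!r}: {spec!r}")
    return config


suggested_conf = dict(optimizers=["adam"], lr=(5e-4, 1e-3))
rng = random.Random(0)
for _ in range(3):
    print(sample_config(suggested_conf, rng))
```

A real tuner would typically search a learning-rate range on a log scale. Uniform sampling is used here only to keep the sketch short.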
Text Classification
from gradsflow.autoclassifier import AutoTextClassifier
from flash.core.data.utils import download_data
from flash.text import TextClassificationData

# 1. Download the IMDb reviews dataset and create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")
datamodule = TextClassificationData.from_csv(
    "review",      # input column
    "sentiment",   # target column
    train_file="data/imdb/train.csv",
    val_file="data/imdb/valid.csv",
    backbone="prajjwal1/bert-medium",
)

# 2. Define the hyperparameter search space
suggested_conf = dict(
    optimizers=["adam"],
    lr=(5e-4, 1e-3),
)

# 3. Create the AutoTextClassifier and run the hyperparameter search
model = AutoTextClassifier(
    datamodule,
    suggested_backbones=["sgugger/tiny-distilbert-classification"],
    suggested_conf=suggested_conf,
    max_epochs=1,
    optimization_metric="val_accuracy",
    timeout=30,
)
print("AutoTextClassifier initialised!")
model.hp_tune()
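The `timeout` and `optimization_metric` arguments suggest a time-budgeted search that keeps the configuration with the best validation metric. The sketch below illustrates that loop with plain random search from the standard library. It is not gradsflow's or Optuna's implementation, and `run_search` and `toy_objective` are hypothetical stand-ins for training a model and reading back `val_accuracy`.

```python
import random
import time
from typing import Any, Callable, Dict, Optional, Tuple


def run_search(
    objective: Callable[[Dict[str, Any]], float],
    lr_range: Tuple[float, float],
    optimizers: list,
    timeout: float,
    max_trials: int = 1000,
    seed: int = 0,
) -> Tuple[Optional[Dict[str, Any]], float, int]:
    """Random search maximizing objective() until timeout (seconds) or max_trials.

    Returns (best_config, best_score, trials_run); best_config is None if no trial ran.
    """
    rng = random.Random(seed)
    deadline = time.monotonic() + timeout
    best_config: Optional[Dict[str, Any]] = None
    best_score = float("-inf")
    trials = 0
    while trials < max_trials and time.monotonic() < deadline:
        config = {"optimizer": rng.choice(optimizers), "lr": rng.uniform(*lr_range)}
        score = objective(config)  # e.g. val_accuracy after training with this config
        trials += 1
        if score > best_score:  # higher is better, as with accuracy
            best_config, best_score = config, score
    return best_config, best_score, trials


def toy_objective(config: Dict[str, Any]) -> float:
    """Hypothetical stand-in for training: the score peaks at lr = 7.5e-4."""
    return 1.0 - abs(config["lr"] - 7.5e-4) * 1000


best, score, n = run_search(toy_objective, (5e-4, 1e-3), ["adam"], timeout=0.5)
print(f"{n} trials, best lr={best['lr']:.2e}, score={score:.3f}")
```

Stopping on whichever limit comes first (wall-clock budget or trial count) keeps a quick demo like the one above bounded even when each trial is cheap.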
Download files
Source distribution: gradsflow-0.0.1b1.tar.gz (19.7 kB)
Built distribution: gradsflow-0.0.1b1-py3-none-any.whl
Hashes for gradsflow-0.0.1b1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | 7a94c5719514ff5be8ba3a123408b4cdd0e651ad5b206ae113ae307c5fbfa366
MD5 | 1f98cabee02d17881ecf6a978dcb9f9a
BLAKE2b-256 | 58905037dd04fe531bf6423f29a8df61f94d64cd425e75dfe53ab182deb74c18
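Before installing a wheel you downloaded by hand, you can compare its digest with the published SHA256 value. This is a minimal standard-library sketch. `sha256_of` and `verify_file` are illustrative helpers, and the wheel path is an assumption about where the file was saved.

```python
import hashlib
from pathlib import Path

# SHA256 digest published for gradsflow-0.0.1b1-py3-none-any.whl.
EXPECTED_SHA256 = "7a94c5719514ff5be8ba3a123408b4cdd0e651ad5b206ae113ae307c5fbfa366"


def sha256_of(path: Path, chunk_size: int = 1 << 16) -> str:
    """Hash the file in chunks so large downloads need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_file(path: Path, expected_hex: str) -> bool:
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected_hex.lower()


if __name__ == "__main__":
    wheel = Path("gradsflow-0.0.1b1-py3-none-any.whl")  # assumed download location
    if wheel.exists():
        print("OK" if verify_file(wheel, EXPECTED_SHA256) else "MISMATCH: do not install")
    else:
        print(f"{wheel} not found")
```

pip can also enforce digests itself when packages are pinned with `--hash=sha256:...` entries in a requirements file and installed in `--require-hashes` mode.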