An AutoML Library made with Optuna and PyTorch Lightning
Installation
Recommended
pip install -U gradsflow
From source
pip install git+https://github.com/gradsflow/gradsflow@main
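To confirm which version pip actually installed, you can query the package metadata. This is a minimal sketch using the standard library's `importlib.metadata` (Python 3.8+); `installed_version` is a hypothetical helper, not part of gradsflow.

```python
# Report the installed gradsflow version, or None if it is not installed.
# installed_version is a hypothetical helper, not part of gradsflow.
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "gradsflow") -> Optional[str]:
    """Return the installed version string of a distribution, or None."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    print(installed_version() or "gradsflow is not installed")
```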
Examples
Auto Image Classification
from gradsflow.autoclassifier import AutoImageClassifier
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
)

# 2. Define the search space: optimizer choices and a learning-rate range
suggested_conf = dict(
    optimizers=["adam"],
    lr=(5e-4, 1e-3),
)

# 3. Build the AutoML model and run the hyperparameter search
model = AutoImageClassifier(
    datamodule,
    suggested_backbones=["ssl_resnet18"],
    suggested_conf=suggested_conf,
    max_epochs=1,
    optimization_metric="val_accuracy",
    timeout=30,
)
print("AutoImageClassifier initialised!")
model.hp_tune()
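`ImageClassificationData.from_folders` expects one subfolder per class inside each split folder (in `hymenoptera_data`, `ants/` and `bees/`). Before pointing the example at your own data, a quick check of that layout can save a failed run. This is a sketch; `count_images_per_class` and the set of accepted file extensions are illustrative assumptions, not part of gradsflow or Flash.

```python
# Summarise a one-subfolder-per-class image directory before handing it to
# ImageClassificationData.from_folders. count_images_per_class is a
# hypothetical helper; IMAGE_SUFFIXES is an assumed list of image extensions.
from pathlib import Path
from typing import Dict

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images_per_class(folder: str) -> Dict[str, int]:
    """Map each class subfolder name to the number of image files it holds."""
    root = Path(folder)
    counts: Dict[str, int] = {}
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


if __name__ == "__main__":
    for split in ("train", "val"):
        path = Path("data/hymenoptera_data") / split
        if path.is_dir():
            print(split, count_images_per_class(str(path)))
```

A class folder that reports zero images usually means the files use an extension the loader does not recognise, or the images sit one directory too deep.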
Auto Text Classification
from gradsflow.autoclassifier import AutoTextClassifier
from flash.core.data.utils import download_data
from flash.text import TextClassificationData

# 1. Create the DataModule ("review" is the input column, "sentiment" the target)
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")
datamodule = TextClassificationData.from_csv(
    "review",
    "sentiment",
    train_file="data/imdb/train.csv",
    val_file="data/imdb/valid.csv",
    backbone="prajjwal1/bert-medium",
)

# 2. Define the search space: optimizer choices and a learning-rate range
suggested_conf = dict(
    optimizers=["adam"],
    lr=(5e-4, 1e-3),
)

# 3. Build the AutoML model and run the hyperparameter search
model = AutoTextClassifier(
    datamodule,
    suggested_backbones=["sgugger/tiny-distilbert-classification"],
    suggested_conf=suggested_conf,
    max_epochs=1,
    optimization_metric="val_accuracy",
    timeout=30,
)
print("AutoTextClassifier initialised!")
model.hp_tune()
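`TextClassificationData.from_csv` takes the input column and the target column by name ("review" and "sentiment" for the IMDB data). If you swap in your own CSV, a quick pre-flight check that both columns exist, and how the labels are distributed, can be sketched with the standard library. `check_csv_columns` is a hypothetical helper, not part of gradsflow or Flash.

```python
# Sanity-check a CSV before passing it to TextClassificationData.from_csv.
# check_csv_columns is a hypothetical helper; the column names mirror the
# "review"/"sentiment" arguments used in the example.
import csv
import os
from collections import Counter
from typing import Dict


def check_csv_columns(path: str, input_col: str, target_col: str) -> Dict[str, int]:
    """Raise ValueError if a column is missing; otherwise return label counts."""
    with open(path, newline="", encoding="utf-8") as fh:
        reader = csv.DictReader(fh)
        # fieldnames comes from the header row (None for an empty file)
        missing = {input_col, target_col} - set(reader.fieldnames or [])
        if missing:
            raise ValueError(f"{path} is missing column(s): {sorted(missing)}")
        return dict(Counter(row[target_col] for row in reader))


if __name__ == "__main__":
    train_csv = "data/imdb/train.csv"
    if os.path.exists(train_csv):
        print(check_csv_columns(train_csv, "review", "sentiment"))
```

A heavily skewed label count is worth knowing about before reading too much into `val_accuracy` as the optimization metric.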
Auto Text Summarization
from gradsflow.autoclassifier import AutoSummarization
from flash.core.data.utils import download_data
from flash.text import SummarizationData

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/")

# 2. Load the data ("input" holds the documents, "target" the summaries)
datamodule = SummarizationData.from_csv(
    "input",
    "target",
    train_file="data/xsum/train.csv",
    val_file="data/xsum/valid.csv",
    test_file="data/xsum/test.csv",
)

# 3. Define the search space: optimizer choices and a learning-rate range
suggested_conf = dict(
    optimizers=["adam"],
    lr=(5e-4, 1e-3),
)

# 4. Build the AutoML model (a single trial with a short timeout) and tune
model = AutoSummarization(
    datamodule,
    suggested_backbones="sshleifer/distilbart-cnn-12-6",
    suggested_conf=suggested_conf,
    max_epochs=1,
    timeout=5,
    n_trials=1,
)
print("AutoSummarization initialised!")
model.hp_tune()
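Conceptually, `hp_tune()` samples configurations from `suggested_conf`, trains and scores each one, and stops after `n_trials` trials or once `timeout` is spent. gradsflow delegates this to Optuna; the plain random-search sketch below only illustrates the idea and is not gradsflow's implementation. The sampling rules (lists as categorical choices, `(low, high)` tuples as log-uniform floats), the time unit (seconds), and the toy objective are all assumptions.

```python
# A stdlib random-search sketch of the idea behind hp_tune(). Assumptions:
# lists are categorical, (low, high) tuples are sampled log-uniformly, the
# timeout is in seconds, and the objective is a stand-in for val_accuracy.
import math
import random
import time
from typing import Any, Callable, Dict, Optional, Tuple


def sample_config(space: Dict[str, Any], rng: random.Random) -> Dict[str, Any]:
    """Draw one configuration from the search space."""
    config: Dict[str, Any] = {}
    for name, spec in space.items():
        if isinstance(spec, list):
            config[name] = rng.choice(spec)
        elif isinstance(spec, tuple) and len(spec) == 2:
            low, high = spec
            config[name] = math.exp(rng.uniform(math.log(low), math.log(high)))
        else:
            config[name] = spec  # treat anything else as a fixed value
    return config


def random_search(
    space: Dict[str, Any],
    objective: Callable[[Dict[str, Any]], float],
    n_trials: int,
    timeout: Optional[float] = None,
    seed: int = 0,
) -> Tuple[Dict[str, Any], float]:
    """Maximise objective; return the best config and its score."""
    rng = random.Random(seed)
    start = time.monotonic()
    best_config: Dict[str, Any] = {}
    best_score = -math.inf
    for _ in range(n_trials):
        # Check the budget before starting a trial; a running trial finishes.
        if timeout is not None and time.monotonic() - start >= timeout:
            break
        config = sample_config(space, rng)
        score = objective(config)
        if score > best_score:
            best_config, best_score = config, score
    return best_config, best_score


if __name__ == "__main__":
    suggested_conf = dict(optimizers=["adam"], lr=(5e-4, 1e-3))
    # Toy objective: pretend validation accuracy peaks near lr = 7e-4.
    best, score = random_search(
        suggested_conf, lambda c: -abs(c["lr"] - 7e-4), n_trials=20, timeout=5
    )
    print(best, score)
```

Log-uniform sampling is a common default for learning rates because it spreads trials evenly across orders of magnitude; Optuna offers the same behaviour through `suggest_float(..., log=True)`, and it also supports smarter samplers than random search.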
📑 For detailed usage examples, please visit our documentation page.
💬 Join the Slack group to chat with us.
🤗 Contribute
Contributions of any kind are welcome. Please check the Contributing Guidelines before contributing.
Code Of Conduct
We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
Read the full Contributor Covenant Code of Conduct.
Acknowledgement
Gradsflow is built with the help of Optuna and PyTorch Lightning 💜
Citing
@software{aniket_maurya_2021_5245151,
  author    = {Aniket Maurya},
  title     = {{gradsflow/gradsflow: An AutoML Library made with
               Optuna and PyTorch Lightning}},
  month     = aug,
  year      = 2021,
  publisher = {Zenodo},
  version   = {v0.0.1b1},
  doi       = {10.5281/zenodo.5245151},
  url       = {https://doi.org/10.5281/zenodo.5245151}
}
Download files
Source Distribution
gradsflow-0.0.1.tar.gz (25.3 kB)
Built Distribution
gradsflow-0.0.1-py3-none-any.whl (15.0 kB)
Hashes for gradsflow-0.0.1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | c9a80a2c2ce3a268aba979d26a7abd5387f8ff5b147968d5c95497bddd0e6176
MD5 | e5db63b9eec419a0646b5a2fd832b4d1
BLAKE2b-256 | f1cd2ab77565d5ae2cd6a629c5e9137e4c72b7d1afc28039a1d44d816225c31d
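To check a manually downloaded file against the SHA256 digest listed above, the standard library's `hashlib` is enough. This is a minimal sketch; `sha256_of` and `verify_sha256` are hypothetical helpers, and the wheel filename assumes the file sits in the current directory.

```python
# Verify a downloaded distribution file against a published SHA256 digest.
# sha256_of and verify_sha256 are hypothetical helpers built on hashlib.
import hashlib
import hmac
import os


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Return the hex SHA256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256(path: str, expected_hex: str) -> bool:
    """True if the file's digest matches the expected hex string."""
    return hmac.compare_digest(sha256_of(path), expected_hex.strip().lower())


if __name__ == "__main__":
    wheel = "gradsflow-0.0.1-py3-none-any.whl"
    expected = "c9a80a2c2ce3a268aba979d26a7abd5387f8ff5b147968d5c95497bddd0e6176"
    if os.path.exists(wheel):
        print("hash OK" if verify_sha256(wheel, expected) else "hash MISMATCH")
```

For installs driven by a requirements file, pip's hash-checking mode (`--require-hashes`) performs the same comparison automatically.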