# lightray

Distribute a `LightningCLI` hyperparameter search with Ray Tune.
Easily integrate a `LightningCLI` with Ray Tune hyperparameter optimization.
## Getting Started
Extend your custom `LightningCLI` to a `Trainable` compatible with Ray Tune.
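Ray Tune hands each trial a dictionary of sampled hyperparameters, while a `LightningCLI` expects command-line arguments. With dotted keys such as `model.init_args.hidden_dim`, one plausible way to bridge the two is to turn every key into a `--key=value` override appended after the base arguments. The sketch below assumes exactly that; `to_cli_args` is a hypothetical helper for illustration, not part of lightray's API.

```python
from typing import Any, Dict, List


def to_cli_args(base_args: List[str], config: Dict[str, Any]) -> List[str]:
    """Append one ``--key=value`` override per sampled hyperparameter.

    Hypothetical helper: ``config`` is assumed to be the flat dict a Ray Tune
    trial receives, keyed by dotted LightningCLI argument names.
    """
    # Overrides go after the base args (e.g. ``--config``) so that, when the
    # CLI parses arguments left to right, they take precedence over the file.
    overrides = [f"--{key}={value}" for key, value in config.items()]
    return list(base_args) + overrides


if __name__ == "__main__":
    trial_config = {
        "model.init_args.hidden_dim": 64,
        "model.init_args.learning_rate": 0.001,
    }
    print(to_cli_args(["--config", "/path/to/config.yaml"], trial_config))
```

Placing the overrides last means the sampled values win over whatever the config file sets for the same keys.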
Imagine you have a `LightningCLI` parser for a model and a data module:
```python
import lightning.pytorch as pl
from lightning.pytorch.cli import LightningCLI


class CustomCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        # Share one value between the data module and the model: the model's
        # `parameter` is filled in from the data module's at parse time.
        parser.link_arguments(
            "data.init_args.parameter",
            "model.init_args.parameter",
            apply_on="parse",
        )


class DataModule(pl.LightningDataModule):
    def __init__(self, parameter: int):
        super().__init__()
        self.parameter = parameter

    def train_dataloader(self):
        ...


class Model(pl.LightningModule):
    def __init__(self, hidden_dim: int, learning_rate: float, parameter: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate
        self.parameter = parameter

    def training_step(self, batch, batch_idx):
        ...
```
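With `apply_on="parse"`, the parser copies the parsed value of the source key into the target key before any class is instantiated, so `parameter` is configured once on the data module and the model receives the same value. The toy function below only mimics that effect on a nested dict shaped like a parsed config; `apply_link` is hypothetical and is not how jsonargparse implements linking.

```python
from typing import Any, Dict


def apply_link(config: Dict[str, Any], source: str, target: str) -> Dict[str, Any]:
    """Copy the value at dotted key ``source`` to dotted key ``target``.

    Toy illustration only; ``config`` mirrors a parsed LightningCLI config.
    Mutates ``config`` in place and returns it for convenience.
    """

    def get(node: Dict[str, Any], path: str) -> Any:
        # Walk the nested dicts one dotted component at a time.
        for part in path.split("."):
            node = node[part]
        return node

    *parents, leaf = target.split(".")
    node = config
    for part in parents:
        # Create intermediate dicts if the target branch does not exist yet.
        node = node.setdefault(part, {})
    node[leaf] = get(config, source)
    return config


if __name__ == "__main__":
    parsed = {
        "data": {"init_args": {"parameter": 16}},
        "model": {"init_args": {"hidden_dim": 32, "learning_rate": 1e-3}},
    }
    apply_link(parsed, "data.init_args.parameter", "model.init_args.parameter")
    print(parsed["model"]["init_args"]["parameter"])  # 16
```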
Launching a hyperparameter tuning job with Ray Tune using this `LightningCLI` is as simple as:
```python
from ray import tune
from ray.tune.schedulers import ASHAScheduler

from lightray import run

# define search space; keys are dotted LightningCLI argument names
search_space = {
    "model.init_args.hidden_dim": tune.choice([32, 64]),
    "model.init_args.learning_rate": tune.loguniform(1e-4, 1e-2),
}

# define scheduler
scheduler = ASHAScheduler(max_t=50, grace_period=1, reduction_factor=2)

# pass command line style arguments as if you were
# running your `LightningCLI` from the command line
args = ["--config", "/path/to/config.yaml"]

storage_dir = "/path/to/storage"  # local directory or s3://{bucket}/{folder}
num_samples = 10  # example value; presumably the number of trials sampled

results = run(
    args=args,
    cli_cls=CustomCLI,
    name="tune-test",
    metric="val_loss",
    objective="min",
    search_space=search_space,
    scheduler=scheduler,
    storage_dir=storage_dir,
    address=None,
    num_samples=num_samples,
    workers_per_trial=1,
    gpus_per_worker=1.0,
    cpus_per_gpu=4.0,
    temp_dir=None,
)
```
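The `ASHAScheduler` arguments in the example control early stopping: trials are compared at rungs spaced geometrically from `grace_period` by a factor of `reduction_factor`, roughly the top `1/reduction_factor` of trials at each rung keep running, and no trial runs past `max_t`. The sketch below is a simplified re-derivation of the rung milestones for a single bracket, not Ray's own code. The time unit is whatever the scheduler's `time_attr` counts (by default `training_iteration`, i.e. the number of metric reports); that this equals one validation epoch per report in lightray is an assumption.

```python
from typing import List


def asha_rungs(max_t: int, grace_period: int, reduction_factor: int) -> List[int]:
    """Return the milestones at which ASHA compares trials (single bracket).

    Simplified sketch: milestones are grace_period * reduction_factor**k for
    every k that keeps the milestone at or below max_t.
    """
    if grace_period < 1 or reduction_factor < 2 or max_t < grace_period:
        raise ValueError("expected max_t >= grace_period >= 1 and reduction_factor >= 2")
    rungs = []
    milestone = grace_period
    while milestone <= max_t:
        rungs.append(milestone)
        milestone *= reduction_factor
    return rungs


if __name__ == "__main__":
    print(asha_rungs(max_t=50, grace_period=1, reduction_factor=2))  # [1, 2, 4, 8, 16, 32]
```

With these settings a poorly performing trial can be stopped after its very first report, while trials that clear every rung continue until `max_t=50`.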
S3 storage works out of the box. Make sure the `AWS_ENDPOINT_URL`, `AWS_ACCESS_KEY_ID`, and `AWS_SECRET_ACCESS_KEY` environment variables are set, then pass an S3 path (e.g. `s3://{bucket}/{folder}`) as the `storage_dir` argument.
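Before launching a long search, it can help to fail fast when the S3 configuration is incomplete. The sketch below checks whether `storage_dir` is an `s3://` URI and, if so, reports which of the three environment variables named above are unset or empty; `missing_s3_env` and the bucket name are hypothetical, not part of lightray.

```python
import os
from typing import List, Mapping, Optional

REQUIRED_S3_VARS = ("AWS_ENDPOINT_URL", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")


def missing_s3_env(storage_dir: str, env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return the required AWS variables that are unset or empty.

    Non-S3 storage needs none of them, so an empty list is returned.
    ``env`` defaults to the current process environment.
    """
    if not storage_dir.startswith("s3://"):
        return []
    env = os.environ if env is None else env
    return [name for name in REQUIRED_S3_VARS if not env.get(name)]


if __name__ == "__main__":
    missing = missing_s3_env("s3://my-bucket/tune-runs")  # hypothetical bucket
    print("missing:", ", ".join(missing) if missing else "none")
```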
## Download files
- Source distribution: `lightray-0.1.1.tar.gz`
- Built distribution: `lightray-0.1.1-py3-none-any.whl`
### File details: `lightray-0.1.1.tar.gz`

File metadata:

- Download URL: lightray-0.1.1.tar.gz
- Upload date:
- Size: 8.0 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.8.3 CPython/3.10.9 Linux/6.5.0-1025-azure

File hashes:

| Algorithm | Hash digest |
|---|---|
| SHA256 | `2e8bcb718d32ce14d0780aaf96fd2f81288886c2fb9ecad9ff71a8e42aa00b88` |
| MD5 | `79b930af02ee6b7385b8e2a713b6bef0` |
| BLAKE2b-256 | `49d18276b819048d08dc12d80dcf383f3f2e987f01cbac8ae41cc25f5956e605` |
### File details: `lightray-0.1.1-py3-none-any.whl`

File metadata:

- Download URL: lightray-0.1.1-py3-none-any.whl
- Upload date:
- Size: 9.0 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: poetry/1.8.3 CPython/3.10.9 Linux/6.5.0-1025-azure

File hashes:

| Algorithm | Hash digest |
|---|---|
| SHA256 | `40bae6757302918593e32f3cba591ddddca91c5304da05865b3d2d3bcd7bbcd9` |
| MD5 | `5ee2791dce9c91476f0921d69a1c8248` |
| BLAKE2b-256 | `4dfc048bbf918e4885c4d5a33dc553bdd71b483794654db37a00fd16d17870fa` |