
Distribute a LightningCLI hyperparameter search with Ray Tune


lightray

Easily integrate a LightningCLI with RayTune hyperparameter optimization
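
Installation

Install from PyPI:

pip install lightray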

Getting Started

Extend your custom LightningCLI into a Trainable compatible with RayTune.

Imagine you have a LightningCLI parser for a model and data module:

import lightning.pytorch as pl

class CustomCLI(pl.cli.LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.link_arguments(
            "data.init_args.parameter", 
            "model.init_args.parameter", 
            apply_on="parse"
        )

class DataModule(pl.LightningDataModule):
    def __init__(self, parameter: int):
        super().__init__()
        self.parameter = parameter

    def train_dataloader(self):
        ...

class Model(pl.LightningModule):
    def __init__(self, hidden_dim: int, learning_rate: float, parameter: int):
        super().__init__()
        self.parameter = parameter
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate

    def training_step(self):
        ...
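
The launch step below will pass a standard LightningCLI YAML config via --config. A minimal sketch for the classes above (module paths and values are placeholders) might look like:

model:
  class_path: my_project.Model
  init_args:
    hidden_dim: 32
    learning_rate: 0.001
data:
  class_path: my_project.DataModule
  init_args:
    parameter: 4

model.init_args.parameter is omitted here because add_arguments_to_parser links it to data.init_args.parameter at parse time.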

Launching a hyperparameter tuning job with RayTune using this LightningCLI is as simple as:

from ray import tune
from ray.tune.schedulers import ASHAScheduler

from lightray import run

# define search space
search_space = {
    "model.init_args.hidden_dim": tune.choice([32, 64]),
    "model.init_args.learning_rate": tune.loguniform(1e-4, 1e-2)
}

# define scheduler
scheduler = ASHAScheduler(max_t=50, grace_period=1, reduction_factor=2)

# pass command line style arguments as if you were
# running your `LightningCLI` from the command line
args = ["--config", "/path/to/config.yaml"]
# placeholders: replace with your results directory (local or s3, see below)
# and the number of hyperparameter configurations to try
storage_dir = "/path/to/results"
num_samples = 10

results = run(
    args=args,
    cli_cls=CustomCLI,
    name="tune-test",
    metric="val_loss",
    objective="min",
    search_space=search_space,
    scheduler=scheduler,
    storage_dir=storage_dir,
    address=None,
    num_samples=num_samples,
    workers_per_trial=1,
    gpus_per_worker=1.0,
    cpus_per_gpu=4.0,
    temp_dir=None,
)
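
Assuming run returns a Ray Tune ResultGrid (an assumption, not confirmed above), you can inspect the best trial afterwards:

# hypothetical follow-up, assuming `run` returns a `ray.tune.ResultGrid`
best = results.get_best_result(metric="val_loss", mode="min")
print(best.config)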

S3 storage works out of the box. Make sure you have the AWS_ENDPOINT_URL, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY environment variables set. Then, simply pass an s3 path (e.g. s3://{bucket}/{folder}) to the storage_dir argument.
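
For example (endpoint, credentials, and bucket below are placeholders):

import os

# credentials for an S3-compatible endpoint (placeholder values)
os.environ["AWS_ENDPOINT_URL"] = "https://my-endpoint.example.com"
os.environ["AWS_ACCESS_KEY_ID"] = "my-access-key"
os.environ["AWS_SECRET_ACCESS_KEY"] = "my-secret-key"

# then pass an s3 path as the storage location
storage_dir = "s3://my-bucket/tune-results"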
