Distribute a LightningCLI hyperparameter search with Ray Tune

lightray

Easily integrate a LightningCLI with Ray Tune hyperparameter optimization

Getting Started

Extend your custom LightningCLI into a Trainable compatible with Ray Tune.
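Conceptually, each Ray Tune trial samples a config whose keys are dotted LightningCLI argument names (such as `model.init_args.hidden_dim`), and those sampled values have to reach the CLI somehow. One plausible mechanism is to append them as command-line overrides after the base arguments. The stdlib sketch below illustrates that idea only; `to_cli_overrides` is a hypothetical helper, not lightray's actual internals.

```python
from typing import Any, Dict, List


def to_cli_overrides(base_args: List[str], config: Dict[str, Any]) -> List[str]:
    """Append one ``--key=value`` flag per sampled hyperparameter.

    Hypothetical helper: the keys are assumed to already be dotted
    LightningCLI argument names (e.g. "model.init_args.hidden_dim").
    """
    overrides = [f"--{key}={value}" for key, value in sorted(config.items())]
    # With LightningCLI (jsonargparse), arguments that come after --config
    # override the values loaded from the config file.
    return list(base_args) + overrides


sampled = {"model.init_args.hidden_dim": 64, "model.init_args.learning_rate": 0.001}
print(to_cli_overrides(["--config", "/path/to/config.yaml"], sampled))
```

Returning a new list leaves the caller's base arguments untouched, so the same `args` can be reused for every trial.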

Imagine you have a LightningCLI parser for a LightningModule and a LightningDataModule like the following:

```python
import lightning.pytorch as pl
from lightning.pytorch.cli import LightningCLI


class CustomCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        # set `parameter` once on the datamodule and share it with the model
        parser.link_arguments(
            "data.init_args.parameter",
            "model.init_args.parameter",
            apply_on="parse",
        )


class DataModule(pl.LightningDataModule):
    def __init__(self, parameter: int):
        super().__init__()
        self.parameter = parameter

    def train_dataloader(self):
        ...


class Model(pl.LightningModule):
    def __init__(self, hidden_dim: int, learning_rate: float, parameter: int):
        super().__init__()
        self.parameter = parameter
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate

    def training_step(self, batch, batch_idx):
        ...
```
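To make the effect of `link_arguments(..., apply_on="parse")` concrete: after parsing, the value given for `data.init_args.parameter` is copied into `model.init_args.parameter`, so it only has to be specified once. The stdlib sketch below mimics that copy on a plain nested dict; `get_path`, `set_path`, and `link_on_parse` are hypothetical helpers, and they ignore the validation and bookkeeping jsonargparse actually performs.

```python
import copy
from typing import Any, Dict


def get_path(cfg: Dict[str, Any], dotted: str) -> Any:
    """Follow a dotted key such as "data.init_args.parameter" into a nested dict."""
    node = cfg
    for part in dotted.split("."):
        node = node[part]
    return node


def set_path(cfg: Dict[str, Any], dotted: str, value: Any) -> None:
    """Set a dotted key, creating intermediate dicts if needed."""
    *parents, leaf = dotted.split(".")
    node = cfg
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value


def link_on_parse(cfg: Dict[str, Any], source: str, target: str) -> Dict[str, Any]:
    """Return a copy of the parsed config with `source` copied to `target`."""
    linked = copy.deepcopy(cfg)
    set_path(linked, target, get_path(linked, source))
    return linked


parsed = {
    "data": {"class_path": "DataModule", "init_args": {"parameter": 3}},
    "model": {"class_path": "Model", "init_args": {"hidden_dim": 32, "learning_rate": 1e-3}},
}
linked = link_on_parse(parsed, "data.init_args.parameter", "model.init_args.parameter")
print(linked["model"]["init_args"])
```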

Launching a hyperparameter tuning job with Ray Tune using this LightningCLI is as simple as:

```python
from ray import tune
from ray.tune.schedulers import ASHAScheduler

from lightray import run

# define the search space; keys are the same dotted argument
# names you would use with your LightningCLI on the command line
search_space = {
    "model.init_args.hidden_dim": tune.choice([32, 64]),
    "model.init_args.learning_rate": tune.loguniform(1e-4, 1e-2),
}

# define scheduler
scheduler = ASHAScheduler(max_t=50, grace_period=1, reduction_factor=2)

# pass command line style arguments as if you were
# running your `LightningCLI` from the command line
args = ["--config", "/path/to/config.yaml"]

storage_dir = "/path/to/storage"  # local directory or s3://{bucket}/{folder}
num_samples = 10  # number of trials to run (example value)

results = run(
    args=args,
    cli_cls=CustomCLI,  # the LightningCLI subclass defined above
    name="tune-test",
    metric="val_loss",
    objective="min",
    search_space=search_space,
    scheduler=scheduler,
    storage_dir=storage_dir,
    address=None,
    num_samples=num_samples,
    workers_per_trial=1,
    gpus_per_worker=1.0,
    cpus_per_gpu=4.0,
    temp_dir=None,
)
```
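With `max_t=50`, `grace_period=1`, and `reduction_factor=2`, ASHA compares trials at geometrically spaced milestones and lets only roughly the best `1/reduction_factor` fraction continue past each one. The stdlib sketch below computes those milestones and a simplified, rank-based promotion rule. It illustrates the asynchronous successive halving idea under the assumption of a single bracket and a lower-is-better metric (matching `val_loss` with `"min"`); it is not Ray's `ASHAScheduler` code (Ray derives its cutoff from a percentile of recorded results), and the helper names are hypothetical.

```python
from typing import List


def asha_milestones(max_t: int, grace_period: int, reduction_factor: int) -> List[int]:
    """Iterations at which a trial is compared with its peers.

    Simplified single-bracket view: grace_period * reduction_factor**k, capped at max_t.
    """
    if grace_period < 1 or reduction_factor < 2:
        raise ValueError("need grace_period >= 1 and reduction_factor >= 2")
    milestones = []
    t = grace_period
    while t <= max_t:
        milestones.append(t)
        t *= reduction_factor
    return milestones


def should_continue(loss: float, recorded: List[float], reduction_factor: int) -> bool:
    """Keep a trial only if its loss ranks in the best 1/reduction_factor at this rung.

    `recorded` holds losses that earlier trials reported at the same milestone;
    lower is better. The first trial to reach a rung always continues.
    """
    pool = sorted(recorded + [loss])
    keep = max(1, len(pool) // reduction_factor)
    return loss <= pool[keep - 1]


print(asha_milestones(max_t=50, grace_period=1, reduction_factor=2))  # [1, 2, 4, 8, 16, 32]
print(should_continue(0.35, [0.5, 0.4, 0.3], reduction_factor=2))  # True
```

Because promotion decisions only look at results already recorded at a rung, trials never wait for each other, which is what makes the scheme asynchronous.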

S3 storage works out of the box. Make sure the AWS_ENDPOINT_URL, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY environment variables are set, then pass an S3 path (e.g. s3://{bucket}/{folder}) as the storage_dir argument.
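A small pre-flight check can catch missing credentials before any trials are launched. The sketch below is a hypothetical helper (lightray is not known to provide it): it returns the required variables that are unset when `storage_dir` is an S3 path, and nothing for local paths. The bucket and folder names in the usage line are placeholders.

```python
import os
from typing import List, Mapping, Optional

REQUIRED_S3_VARS = ("AWS_ENDPOINT_URL", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")


def missing_s3_vars(storage_dir: str, environ: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return the S3 environment variables that are unset or empty.

    `environ` defaults to the real process environment; passing a dict
    makes the check easy to test.
    """
    env = os.environ if environ is None else environ
    if not storage_dir.startswith("s3://"):
        return []  # local paths need no credentials
    return [name for name in REQUIRED_S3_VARS if not env.get(name)]


# placeholder bucket/folder, with only one of the three variables set
print(missing_s3_vars("s3://my-bucket/tune-runs", {"AWS_ACCESS_KEY_ID": "example"}))
```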

Download files

Source Distribution

lightray-0.1.5.tar.gz (8.0 kB)

Built Distribution

lightray-0.1.5-py3-none-any.whl (9.0 kB)

File details

Details for the file lightray-0.1.5.tar.gz.

File metadata

  • Download URL: lightray-0.1.5.tar.gz
  • Upload date:
  • Size: 8.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.10.9 Linux/6.5.0-1025-azure

File hashes

Hashes for lightray-0.1.5.tar.gz

| Algorithm | Hash digest |
|---|---|
| SHA256 | 6677facd2cbb8292bde00efcf1aec13ff6342598503900c4fa4f85a54f1473e1 |
| MD5 | a891e58c8e3cdfd1bbe18d3b4603823c |
| BLAKE2b-256 | da5d6b225749f794e97fef2bb1821002df52bdda873b44ee433b46c11aece005 |

See more details on using hashes here.

File details

Details for the file lightray-0.1.5-py3-none-any.whl.

File metadata

  • Download URL: lightray-0.1.5-py3-none-any.whl
  • Upload date:
  • Size: 9.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.10.9 Linux/6.5.0-1025-azure

File hashes

Hashes for lightray-0.1.5-py3-none-any.whl

| Algorithm | Hash digest |
|---|---|
| SHA256 | 5d5b8cdb54e3990d95bfd8147b86f7651cc632685514db9150a069994252a508 |
| MD5 | 079760021f86e8abcab1b43466d81573 |
| BLAKE2b-256 | 5e493de65893a75e43e622d9fc4bfc9f362c5d54852b076135751343af5d441d |

See more details on using hashes here.

Supported by

AWS AWS Cloud computing and Security Sponsor Datadog Datadog Monitoring Fastly Fastly CDN Google Google Download Analytics Microsoft Microsoft PSF Sponsor Pingdom Pingdom Monitoring Sentry Sentry Error logging StatusPage StatusPage Status page