
lightray

Easily integrate a LightningCLI with Ray Tune hyperparameter optimization

Getting Started

Extend your custom LightningCLI into a Trainable compatible with Ray Tune.
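First, install lightray from PyPI:

pip install lightray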

Imagine you have a custom LightningCLI for a model and data module:

import lightning.pytorch as pl
from lightning.pytorch.cli import LightningCLI

class CustomCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        # forward the data module's `parameter` to the model at parse time
        parser.link_arguments(
            "data.init_args.parameter",
            "model.init_args.parameter",
            apply_on="parse"
        )

class DataModule(pl.LightningDataModule):
    def __init__(self, parameter: int):
        super().__init__()
        self.parameter = parameter

    def train_dataloader(self):
        ...

class Model(pl.LightningModule):
    def __init__(self, hidden_dim: int, learning_rate: float, parameter: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate
        self.parameter = parameter

    def training_step(self, batch, batch_idx):
        ...

Launching a hyperparameter tuning job with Ray Tune using this LightningCLI is as simple as:

from ray import tune
from ray.tune.schedulers import ASHAScheduler

from lightray import run

# define the search space over model hyperparameters
search_space = {
    "model.init_args.hidden_dim": tune.choice([32, 64]),
    "model.init_args.learning_rate": tune.loguniform(1e-4, 1e-2)
}

# define the trial scheduler
scheduler = ASHAScheduler(max_t=50, grace_period=1, reduction_factor=2)

# pass command-line style arguments as if you were
# running your `LightningCLI` from the command line
args = ["--config", "/path/to/config.yaml"]
results = run(
    args=args,
    cli_cls=CustomCLI,
    name="tune-test",
    metric="val_loss",
    objective="min",
    search_space=search_space,
    scheduler=scheduler,
    storage_dir="/path/to/storage",
    address=None,
    num_samples=10,
    workers_per_trial=1,
    gpus_per_worker=1.0,
    cpus_per_gpu=4.0,
    temp_dir=None,
)
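
Assuming run returns a Ray Tune ResultGrid (the type returned by Ray's Tuner.fit), you can then inspect the best trial:

# assumes `results` is a ray.tune.ResultGrid; check your version's return type
best = results.get_best_result(metric="val_loss", mode="min")
print(best.config)   # hyperparameters of the best trial
print(best.metrics)  # final metrics reported by that trial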

S3 storage works out of the box. Make sure you have set the AWS_ENDPOINT_URL, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY environment variables. Then, simply pass an s3 path (e.g. s3://{bucket}/{folder}) as the storage_dir argument.
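
For example (the endpoint, bucket, and credentials below are placeholders):

import os

# placeholders: point these at your S3-compatible storage
os.environ["AWS_ENDPOINT_URL"] = "https://s3.example.com"
os.environ["AWS_ACCESS_KEY_ID"] = "my-key-id"
os.environ["AWS_SECRET_ACCESS_KEY"] = "my-secret"

# then pass an s3 path as the storage location, e.g.
# results = run(..., storage_dir="s3://my-bucket/tune-results")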
