
Distribute a LightningCLI hyperparameter search with Ray Tune

Project description

lightray

Easily integrate a LightningCLI with Ray Tune hyperparameter optimization.

Getting Started

Extend your custom LightningCLI into a Trainable compatible with Ray Tune.

Imagine you have a LightningCLI with a LightningModule and a LightningDataModule that share a linked argument:

import lightning.pytorch as pl
from lightning.pytorch.cli import LightningCLI

class CustomCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        # Copy the value parsed for the data module's `parameter`
        # onto the model's `parameter`, so the two always agree.
        parser.link_arguments(
            "data.init_args.parameter",
            "model.init_args.parameter",
            apply_on="parse"
        )

class DataModule(pl.LightningDataModule):
    def __init__(self, parameter: int):
        super().__init__()
        self.parameter = parameter

    def train_dataloader(self):
        ...

class Model(pl.LightningModule):
    def __init__(self, hidden_dim: int, learning_rate: float, parameter: int):
        super().__init__()
        self.parameter = parameter
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate

    def training_step(self):
        ...
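As a toy illustration (plain Python dicts, no Lightning required), the link above copies the parsed value of `data.init_args.parameter` onto `model.init_args.parameter` before either class is instantiated:

```python
# Toy sketch of what parser.link_arguments does, using plain dicts:
# the value parsed for the data module is copied to the model namespace
# right after parsing (apply_on="parse"), before instantiation.
config = {
    "data": {"init_args": {"parameter": 10}},
    "model": {"init_args": {}},
}

# The link fires here, so the model never needs the value repeated
# in the config file or on the command line.
config["model"]["init_args"]["parameter"] = config["data"]["init_args"]["parameter"]

print(config["model"]["init_args"])  # {'parameter': 10}
```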

Launching a hyperparameter tuning job with RayTune using this LightningCLI is as simple as

from ray import tune
from ray.tune.schedulers import ASHAScheduler
from lightray import run

# define search space
search_space = {
    "model.init_args.hidden_dim": tune.choice([32, 64]),
    "model.init_args.learning_rate": tune.loguniform(1e-4, 1e-2)
}

# define scheduler
scheduler = ASHAScheduler(max_t=50, grace_period=1, reduction_factor=2)

# pass command line style arguments as if you were
# running your `LightningCLI` from the command line
args = ["--config", "/path/to/config.yaml"]
results = run(
    args=args,
    cli_cls=CustomCLI,
    name="tune-test",
    metric="val_loss",
    objective="min",
    search_space=search_space,
    scheduler=scheduler,
    storage_dir="./tune-results",
    address=None,
    num_samples=10,
    workers_per_trial=1,
    gpus_per_worker=1.0,
    cpus_per_gpu=4.0,
    temp_dir=None,
)
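Under the hood, each trial needs its sampled hyperparameters expressed as LightningCLI-style overrides. A minimal sketch of that mapping (the `to_cli_args` helper below is hypothetical, not part of lightray's API):

```python
def to_cli_args(sampled: dict) -> list[str]:
    # Turn dotted Ray Tune keys into `--section.init_args.name=value`
    # overrides, the same syntax LightningCLI accepts on the command line.
    return [f"--{key}={value}" for key, value in sampled.items()]

sampled = {
    "model.init_args.hidden_dim": 64,
    "model.init_args.learning_rate": 0.001,
}
print(to_cli_args(sampled))
# ['--model.init_args.hidden_dim=64', '--model.init_args.learning_rate=0.001']
```

Because the search space keys mirror the CLI's namespace, anything you can set from the command line can also be tuned.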

S3 storage works out of the box. Make sure the AWS_ENDPOINT_URL, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY environment variables are set. Then, simply pass an s3 path (e.g. s3://{bucket}/{folder}) as the storage_dir argument.
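For example (the endpoint, credentials, and bucket path below are placeholders, not real values):

```python
import os

# Credentials for the S3-compatible endpoint; all values are placeholders.
os.environ["AWS_ENDPOINT_URL"] = "https://s3.example.com"
os.environ["AWS_ACCESS_KEY_ID"] = "my-access-key"
os.environ["AWS_SECRET_ACCESS_KEY"] = "my-secret-key"

# Any s3:// path can then be passed as the storage_dir argument.
storage_dir = "s3://my-bucket/tune-results"
```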

Project details


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

lightray-0.1.3.tar.gz (8.0 kB)

Uploaded Source

Built Distribution

lightray-0.1.3-py3-none-any.whl (9.0 kB)

Uploaded Python 3

File details

Details for the file lightray-0.1.3.tar.gz.

File metadata

  • Download URL: lightray-0.1.3.tar.gz
  • Upload date:
  • Size: 8.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.10.9 Linux/6.5.0-1025-azure

File hashes

Hashes for lightray-0.1.3.tar.gz
Algorithm Hash digest
SHA256 b08e6d8d37b5a7d39fa603ae7a3fee3535bff6ec494242516c4a1acc2b08b1d4
MD5 6511efef6e8bbd66c456e89bce16e1ff
BLAKE2b-256 2f409b72112b0cb269a57a81cf213f5ab6175de657b5bafd1639c3b06d1e8dca

See more details on using hashes here.

File details

Details for the file lightray-0.1.3-py3-none-any.whl.

File metadata

  • Download URL: lightray-0.1.3-py3-none-any.whl
  • Upload date:
  • Size: 9.0 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.8.3 CPython/3.10.9 Linux/6.5.0-1025-azure

File hashes

Hashes for lightray-0.1.3-py3-none-any.whl
Algorithm Hash digest
SHA256 86e9c53eb24e8eb1cb8091d66f8c40287cc1c4f92bfd30d8af1a9c8df79da822
MD5 6b174c89313a6302df66356f47c25e40
BLAKE2b-256 20edce87fb677be430b9c3e713ddaae1e7cfc90749b0cce91ac9938f8fac3098

See more details on using hashes here.
