Distribute a LightningCLI hyperparameter search with Ray Tune

Project description

lightray

A CLI for easily integrating LightningCLI with Ray Tune hyperparameter optimization

Getting Started

Imagine you have the following Lightning DataModule, LightningModule, and LightningCLI:

import lightning.pytorch as pl
from lightning.pytorch.cli import LightningCLI


class DataModule(pl.LightningDataModule):
    def __init__(self, hidden_dim: int, learning_rate: float, parameter: int):
        super().__init__()
        self.parameter = parameter
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate

    def train_dataloader(self):
        ...


class LightningModule(pl.LightningModule):
    def __init__(self, parameter: int):
        super().__init__()
        self.parameter = parameter

    def training_step(self):
        ...


class CustomCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.link_arguments(
            "data.init_args.parameter",
            "model.init_args.parameter",
            apply_on="parse"
        )

A hyperparameter tuning job with Ray Tune using this LightningCLI can then be launched by writing a YAML configuration like the following:

# tune.yaml

# ray.tune.TuneConfig
tune_config:
  mode: "min"
  metric: "val_loss"
  scheduler: 
    class_path: ray.tune.schedulers.ASHAScheduler
    init_args:
      max_t: 200
      grace_period: 21
      reduction_factor: 2
  num_samples: 1
  reuse_actors: true

# ray.train.RunConfig
run_config:
  name: "my-first-run"
  storage_path: s3://aframe-test/new-tune/
  failure_config:
    class_path: ray.train.FailureConfig
    init_args:
      max_failures: 1
  checkpoint_config:
    class_path: ray.train.CheckpointConfig
    init_args:
      num_to_keep: 5
      checkpoint_score_attribute: "val_loss"
      checkpoint_score_order: "min"
  verbose: null

# ray.train.SyncConfig
sync_config:
  sync_period: 1000

# ray.init
ray_init:
  address: null
  
# ray.tune.Tuner param_space
param_space:
  model.learning_rate: tune.loguniform(1e-3, 4)

# ray.tune.TuneCallback
tune_callback:
  class_path: lightray.callbacks.LightRayReportCheckpointCallback
  init_args:
    'on': "validation_end"
    checkpoint_every: 10

# resources per trial
cpus_per_trial: 2
gpus_per_trial: 1

# lightning cli
lightning_cli_cls: example.lightning.Cli
lightning_config: /home/ethan.marx/projects/lightray/example/cli.yaml
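Note that the param_space entry above is given as a string, tune.loguniform(1e-3, 4). As a rough illustration of how such a string could be resolved into a search-space object, here is a minimal sketch. lightray's actual parsing is not shown in this README; the resolve_param_space helper is hypothetical, and a stub stands in for ray.tune.loguniform so the sketch runs without Ray installed:

```python
# Illustrative only: sketches how a param_space string such as
# "tune.loguniform(1e-3, 4)" might be evaluated against a "tune" namespace.
# The loguniform stub below stands in for ray.tune.loguniform.
import math
import random


def loguniform(lower, upper):
    """Stub: returns a sampler drawing log-uniformly from [lower, upper]."""
    def sample():
        return math.exp(random.uniform(math.log(lower), math.log(upper)))
    return sample


def resolve_param_space(param_space, namespace):
    """Evaluate each string value into a search-space object using the
    provided namespace (e.g. one exposing a "tune" attribute)."""
    return {
        key: eval(value, {"__builtins__": {}}, namespace)
        for key, value in param_space.items()
    }


# Usage with the stub namespace:
tune_ns = type("tune", (), {"loguniform": staticmethod(loguniform)})
resolved = resolve_param_space(
    {"model.learning_rate": "tune.loguniform(1e-3, 4)"},
    {"tune": tune_ns},
)
sample = resolved["model.learning_rate"]()
```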

Then, launch the tuning job

lightray --config tune.yaml 

S3 Support

In addition, there is automatic support for S3 storage. Make sure you have set the AWS_ENDPOINT_URL, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY environment variables. Then, simply pass the path to your bucket with the s3 prefix (e.g. s3://{bucket}/{folder}) as the run_config.storage_path argument.

lightray --config tune.yaml --run_config.storage_path s3://{bucket}/{folder}
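Before launching, it can be useful to confirm those environment variables are actually set. The missing_s3_vars helper below is not part of lightray; it is a purely illustrative preflight check:

```python
# Illustrative preflight check (not part of lightray): report which of the
# S3 environment variables mentioned above are missing.
import os

REQUIRED_S3_VARS = ("AWS_ENDPOINT_URL", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")


def missing_s3_vars(environ=os.environ):
    """Return the names of any required S3 environment variables not set."""
    return [name for name in REQUIRED_S3_VARS if not environ.get(name)]


# Example against a fake environment with only the endpoint set:
missing = missing_s3_vars({"AWS_ENDPOINT_URL": "http://localhost:9000"})
```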

There is also a wrapper around ray.tune.integration.pytorch_lightning.TuneReportCheckpointCallback that performs checkpoint reporting with retries, to handle transient S3 errors. It is provided at lightray.callbacks.LightRayReportCheckpointCallback.
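The retry idea can be sketched generically. The with_retries helper below is hypothetical and only illustrates the retry-with-backoff pattern, not the callback's actual implementation:

```python
# Sketch of retry-with-backoff for transient errors (illustrative only;
# not the implementation of LightRayReportCheckpointCallback).
import time


def with_retries(fn, max_attempts=3, base_delay=0.0, retry_on=(OSError,)):
    """Call fn, retrying up to max_attempts times on the given exceptions,
    doubling the delay between attempts."""
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except retry_on:
            if attempt == max_attempts:
                raise
            time.sleep(base_delay * 2 ** (attempt - 1))


# Simulated flaky checkpoint report: fails twice, then succeeds.
attempts = {"n": 0}


def flaky_report():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise OSError("transient S3 error")
    return "checkpoint reported"


result = with_retries(flaky_report, max_attempts=5)
```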

Remote cluster

To connect to a remote Ray cluster, pass its IP address (using port 10001) to ray_init.address:

lightray --config tune.yaml --ray_init.address ray://{ip}:10001
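For reference, the address above follows the Ray client format ray://{ip}:10001, where 10001 is the default Ray client server port. A trivial, purely illustrative helper (not part of lightray) for assembling it:

```python
# Purely illustrative: build a ray:// client address, defaulting to the
# standard Ray client server port 10001.
def ray_client_address(host, port=10001):
    return f"ray://{host}:{port}"


addr = ray_client_address("10.0.0.5")
```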

