Distribute a LightningCLI hyperparameter search with Ray Tune

lightray

A CLI for easily integrating LightningCLI with Ray Tune hyperparameter optimization

Getting Started

Imagine you have the following Lightning DataModule, LightningModule, and LightningCLI:

import lightning.pytorch as pl
from lightning.pytorch.cli import LightningCLI

class DataModule(pl.LightningDataModule):
    def __init__(self, hidden_dim: int, learning_rate: float, parameter: int):
        super().__init__()
        self.parameter = parameter
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate
    
    def train_dataloader(self):
        ...

class LightningModule(pl.LightningModule):
    def __init__(self, parameter: int):
        super().__init__()
        self.parameter = parameter
    
    def training_step(self, batch, batch_idx):
        ...

class CustomCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.link_arguments(
            "data.init_args.parameter", 
            "model.init_args.parameter", 
            apply_on="parse"
        )
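lightray refers to this CLI class by its dotted import path (see lightning_cli_cls in the config below), so it needs to live in an importable module. A minimal sketch of such an entrypoint, assuming the hypothetical module path example/lightning.py:

# example/lightning.py (hypothetical path; referenced by lightning_cli_cls in tune.yaml)

def main():
    # Standard LightningCLI entrypoint. subclass_mode_* makes the parser
    # expect the class_path/init_args config structure, which matches the
    # "data.init_args.parameter" links used above.
    CustomCLI(
        LightningModule,
        DataModule,
        subclass_mode_model=True,
        subclass_mode_data=True,
    )

if __name__ == "__main__":
    main()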

Launching a hyperparameter tuning job with Ray Tune using this LightningCLI can be done by configuring a YAML file that looks like the following:

# tune.yaml

# ray.tune.TuneConfig
tune_config:
  mode: "min"
  metric: "val_loss"
  scheduler: 
    class_path: ray.tune.schedulers.ASHAScheduler
    init_args:
      max_t: 200
      grace_period: 21
      reduction_factor: 2
  num_samples: 1
  reuse_actors: true

# ray.train.RunConfig
run_config:
  name: "my-first-run"
  storage_path: s3://aframe-test/new-tune/
  failure_config:
    class_path: ray.train.FailureConfig
    init_args:
      max_failures: 1
  checkpoint_config:
    class_path: ray.train.CheckpointConfig
    init_args:
      num_to_keep: 5
      checkpoint_score_attribute: "val_loss"
      checkpoint_score_order: "min"
  verbose: null

# ray.train.SyncConfig
sync_config:
  sync_period: 1000

# ray.init
ray_init:
  address: null
  
# ray.tune.Tuner param_space (see the sketch after this config)
param_space:
  model.learning_rate: tune.loguniform(1e-3, 4)

# ray.tune.TuneCallback
tune_callback:
  class_path: lightray.callbacks.LightRayReportCheckpointCallback
  init_args:
    'on': "validation_end"
    checkpoint_every: 10

# resources per trial
cpus_per_trial: 2
gpus_per_trial: 1

# dotted import path of the LightningCLI subclass to tune
lightning_cli_cls: example.lightning.CustomCLI
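Each entry in param_space is a Ray Tune search-space expression written as a string, which lightray presumably evaluates into the corresponding Ray Tune distribution. For reference, the equivalent search space built directly against the Ray Tune API would look like this (a sketch of the equivalence only, not of lightray internals):

from ray import tune

# Sample the learning rate log-uniformly between 1e-3 and 4,
# matching the string expression in tune.yaml above.
param_space = {"model.learning_rate": tune.loguniform(1e-3, 4)}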

Then, launch the tuning job.

lightray --config tune.yaml -- --config cli.yaml

Any arguments passed after "--" will automatically be forwarded to the specified LightningCLI class as if it were being run from the command line. This way, tuning configuration and training configuration can be kept separate, and you can easily reuse existing configuration files compatible with your LightningCLI.
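For reference, cli.yaml is just an ordinary LightningCLI configuration file. A minimal sketch for the classes above, assuming the hypothetical example.lightning module paths (the exact keys depend on your own modules):

# cli.yaml (illustrative LightningCLI config)
data:
  class_path: example.lightning.DataModule
  init_args:
    hidden_dim: 64
    learning_rate: 1.0e-3
    parameter: 10
model:
  class_path: example.lightning.LightningModule
  # model.init_args.parameter is linked from data.init_args.parameter
trainer:
  max_epochs: 200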

S3 Support

There is also built-in support for S3 storage. Make sure you have set the AWS_ENDPOINT_URL, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY environment variables, then pass the path to your bucket with the s3 prefix (e.g. s3://{bucket}/{folder}) as the run_config.storage_path argument.
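For example (placeholder values):

export AWS_ENDPOINT_URL=https://s3.example.com
export AWS_ACCESS_KEY_ID=...
export AWS_SECRET_ACCESS_KEY=...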

lightray --config tune.yaml --run_config.storage_path s3://{bucket}/{folder}

There is also a wrapper around ray.tune.integration.pytorch_lightning.TuneReportCheckpointCallback that performs checkpoint reporting with retries to handle transient S3 errors. It is provided at lightray.callbacks.LightRayReportCheckpointCallback.
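The exact retry behavior lives inside lightray, but the underlying pattern is straightforward. A minimal, illustrative sketch of retrying a report call on transient I/O errors (not lightray's actual implementation):

import time

def report_with_retries(report_fn, max_retries=3, base_delay=5.0):
    # Call a checkpoint-reporting function, retrying on transient
    # I/O errors (e.g. flaky S3 connections) with a linear backoff.
    for attempt in range(max_retries):
        try:
            return report_fn()
        except OSError:
            if attempt == max_retries - 1:
                raise
            time.sleep(base_delay * (attempt + 1))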

Remote cluster

To connect to a remote Ray cluster, pass its IP address (using port 10001) to ray_init.address:

lightray --config tune.yaml --ray_init.address ray://{ip}:10001 -- --config cli.yaml
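If you manage the cluster yourself, the Ray client server that this address points at runs on the head node and listens on port 10001 by default:

# on the head node
ray start --head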
