
Distribute a LightningCLI hyperparameter search with Ray Tune

Project description

lightray

A CLI for easily integrating LightningCLI with Ray Tune hyperparameter optimization

Getting Started

Imagine you have the following Lightning DataModule, LightningModule, and LightningCLI:

import lightning.pytorch as pl

class DataModule(pl.LightningDataModule):
    def __init__(self, hidden_dim: int, learning_rate: float, parameter: int):
        super().__init__()
        self.parameter = parameter
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate
    
    def train_dataloader(self):
        ...

class LightningModule(pl.LightningModule):
    def __init__(self, parameter: int):
        super().__init__()
        self.parameter = parameter
    
    def training_step(self):
        ...

class CustomCLI(pl.cli.LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.link_arguments(
            "data.init_args.parameter", 
            "model.init_args.parameter", 
            apply_on="parse"
        )

Launching a hyperparameter tuning job with Ray Tune using this LightningCLI can then be done by configuring a YAML file that looks like the following:

# tune.yaml

# ray.tune.TuneConfig
tune_config:
  mode: "min"
  metric: "val_loss"
  scheduler: 
    class_path: ray.tune.schedulers.ASHAScheduler
    init_args:
      max_t: 200
      grace_period: 21
      reduction_factor: 2
  num_samples: 1
  reuse_actors: true

# ray.train.RunConfig
run_config:
  name: "my-first-run"
  storage_path: s3://aframe-test/new-tune/
  failure_config:
    class_path: ray.train.FailureConfig
    init_args:
      max_failures: 1
  checkpoint_config:
    class_path: ray.train.CheckpointConfig
    init_args:
      num_to_keep: 5
      checkpoint_score_attribute: "val_loss"
      checkpoint_score_order: "min"
  verbose: null

# ray.train.SyncConfig
sync_config:
  sync_period: 1000

# ray.init
ray_init:
  address: null
  
# ray.tune.Tuner param_space
param_space:
  model.learning_rate: tune.loguniform(1e-3, 4)

# ray.tune.TuneCallback
tune_callback:
  class_path: lightray.callbacks.LightRayReportCheckpointCallback
  init_args:
    'on': "validation_end"
    checkpoint_every: 10

# resources per trial
cpus_per_trial: 2
gpus_per_trial: 1

# lightning cli class to tune
lightning_cli_cls: example.lightning.Cli
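Each entry in `param_space` is a string that gets evaluated into a Ray Tune search-space object. For intuition, `tune.loguniform(1e-3, 4)` draws values uniformly in log space, so small learning rates are sampled as often as large ones. The sketch below reproduces that sampling behavior with only the standard library (the function name `loguniform_sample` is illustrative, not part of lightray or Ray):

```python
import math
import random

def loguniform_sample(low, high, rng=random):
    # Draw uniformly between log(low) and log(high), then exponentiate.
    # This matches the distribution of ray.tune.loguniform(low, high).
    return math.exp(rng.uniform(math.log(low), math.log(high)))

samples = [loguniform_sample(1e-3, 4) for _ in range(1000)]
```

Roughly half of the samples land below `sqrt(1e-3 * 4) ≈ 0.063`, which is the point of sampling in log space for scale-sensitive hyperparameters like learning rates.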

Then, launch the tuning job.

lightray --config tune.yaml -- --config cli.yaml

Any arguments passed after "--" are automatically forwarded to the specified LightningCLI class as if it were being run from the command line. This keeps the tuning configuration and the training configuration separate, and lets you reuse existing configuration files that are already compatible with your LightningCLI.
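For illustration, a minimal cli.yaml for the classes defined above might look like the following (all values are hypothetical; `parameter` is propagated from data to model by the `link_arguments` call, so it only needs to be set once):

```yaml
# cli.yaml -- plain LightningCLI training config, kept separate from tune.yaml
data:
  init_args:
    hidden_dim: 64
    learning_rate: 1e-3
    parameter: 10
trainer:
  max_epochs: 200
```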

S3 Support

In addition, S3 storage is supported automatically. Make sure the AWS_ENDPOINT_URL, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY environment variables are set. Then, simply pass the path to your bucket with the s3 prefix (e.g. s3://{bucket}/{folder}) as the run_config.storage_path argument.
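Since a missing credential only surfaces once Ray tries to write a checkpoint, it can be worth checking the environment before launching. A minimal sketch (the helper name `check_s3_env` is hypothetical, not part of lightray):

```python
import os

# The three variables the S3 support expects, per the docs above
REQUIRED = ("AWS_ENDPOINT_URL", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")

def check_s3_env(env=os.environ):
    # Return the names of any unset credentials needed for s3:// storage paths
    return [name for name in REQUIRED if not env.get(name)]

missing = check_s3_env({"AWS_ENDPOINT_URL": "https://s3.example.com"})
# Only the endpoint was set, so both key variables are reported missing
```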

lightray --config tune.yaml --run_config.storage_path s3://{bucket}/{folder}

There is also a wrapper around ray.tune.integration.pytorch_lightning.TuneReportCheckpointCallback that performs checkpoint reporting with retries to handle transient S3 errors. It is provided as lightray.callbacks.LightRayReportCheckpointCallback.
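The retry idea itself is simple: re-attempt the report call a few times before giving up, since S3 errors are often transient. This is a generic sketch of that pattern, not the library's actual implementation (`with_retries` and `flaky_report` are illustrative names):

```python
import time

def with_retries(fn, retries=3, delay=0.0, exceptions=(OSError,)):
    # Call fn, retrying on transient errors (e.g. flaky S3 uploads);
    # re-raise the last error once retries are exhausted.
    for attempt in range(retries):
        try:
            return fn()
        except exceptions:
            if attempt == retries - 1:
                raise
            time.sleep(delay)

# Simulate a checkpoint report that fails twice before succeeding
calls = {"n": 0}
def flaky_report():
    calls["n"] += 1
    if calls["n"] < 3:
        raise OSError("transient s3 error")
    return "reported"

result = with_retries(flaky_report)
```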

Remote cluster

To connect to a remote Ray cluster, pass its IP address (using port 10001) to ray_init.address:

lightray --config tune.yaml --ray_init.address ray://{ip}:10001 -- --config cli.yaml

Project details


Download files


Source Distribution

lightray-0.2.4.tar.gz (144.8 kB)

Uploaded Source

Built Distribution


lightray-0.2.4-py3-none-any.whl (7.2 kB)

Uploaded Python 3

File details

Details for the file lightray-0.2.4.tar.gz.

File metadata

  • Download URL: lightray-0.2.4.tar.gz
  • Upload date:
  • Size: 144.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.6.1

File hashes

Hashes for lightray-0.2.4.tar.gz
Algorithm Hash digest
SHA256 aa90faed38f236552cd688ba624e70c1db1ce3d64f0430ada4307e42c7fd8eb1
MD5 a7b6d14981f70ba0d1984abe07759298
BLAKE2b-256 a206e213ba1ffd71da2ea0f0c8b354c90d809273b3f586180441375f18c33fdf


File details

Details for the file lightray-0.2.4-py3-none-any.whl.

File metadata

  • Download URL: lightray-0.2.4-py3-none-any.whl
  • Upload date:
  • Size: 7.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: uv/0.6.1

File hashes

Hashes for lightray-0.2.4-py3-none-any.whl
Algorithm Hash digest
SHA256 c0868ee362f967a86cc57b4af9578b362e51ef09c998603d7ff62218130198fe
MD5 da8a9c4577b02baedb77ebbf92b07edf
BLAKE2b-256 679fc885e222ec4e50fe9f6e3067ee6cc9f8631e3b642c69db83934423043921

