
Distribute a LightningCLI hyperparameter search with Ray Tune

Project description

lightray

A CLI for easily integrating LightningCLI with Ray Tune hyperparameter optimization

Getting Started

Imagine you have the following Lightning DataModule, LightningModule, and LightningCLI:

import lightning.pytorch as pl

class DataModule(pl.LightningDataModule):
    def __init__(self, hidden_dim: int, learning_rate: float, parameter: int):
        super().__init__()
        self.parameter = parameter
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate
    
    def train_dataloader(self):
        ...

class LightningModule(pl.LightningModule):
    def __init__(self, parameter: int):
        super().__init__()
        self.parameter = parameter
    
    def training_step(self):
        ...

from lightning.pytorch.cli import LightningCLI

class CustomCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.link_arguments(
            "data.init_args.parameter", 
            "model.init_args.parameter", 
            apply_on="parse"
        )
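Because of the link_arguments call above, the shared parameter only needs to be set once under the data section; it is copied to the model when the config is parsed. A hypothetical config fragment illustrating this (class paths and values are made up for this example):

```yaml
data:
  class_path: example.lightning.DataModule
  init_args:
    hidden_dim: 64
    learning_rate: 1.0e-3
    parameter: 7
model:
  class_path: example.lightning.LightningModule
  # parameter is filled in from data.init_args.parameter at parse time,
  # so it does not need to be repeated here
```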

Launching a hyperparameter tuning job with Ray Tune using this LightningCLI can be done by configuring a YAML file like the following:

# tune.yaml

# ray.tune.TuneConfig
tune_config:
  mode: "min"
  metric: "val_loss"
  scheduler: 
    class_path: ray.tune.schedulers.ASHAScheduler
    init_args:
      max_t: 200
      grace_period: 21
      reduction_factor: 2
  num_samples: 1
  reuse_actors: true

# ray.train.RunConfig
run_config:
  name: "my-first-run"
  storage_path: s3://aframe-test/new-tune/
  failure_config:
    class_path: ray.train.FailureConfig
    init_args:
      max_failures: 1
  checkpoint_config:
    class_path: ray.train.CheckpointConfig
    init_args:
      num_to_keep: 5
      checkpoint_score_attribute: "val_loss"
      checkpoint_score_order: "min"
  verbose: null

# ray.train.SyncConfig
sync_config:
  sync_period: 1000

# ray.init
ray_init:
  address: null
  
# tune.Tune.param_space
param_space:
  model.learning_rate: tune.loguniform(1e-3, 4)

# ray.tune.TuneCallback
tune_callback:
  class_path: lightray.callbacks.LightRayReportCheckpointCallback
  init_args:
    'on': "validation_end"
    checkpoint_every: 10

# resources per trial
cpus_per_trial: 2
gpus_per_trial: 1

# lightning cli
lightning_cli_cls: example.lightning.Cli
lightning_config: /home/ethan.marx/projects/lightray/example/cli.yaml

Then, launch the tuning job:

lightray --config tune.yaml 

S3 Support

In addition, there is automatic support for S3 storage. Make sure you have set the AWS_ENDPOINT_URL, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY environment variables. Then, simply pass the path to your bucket with the s3:// prefix (e.g. s3://{bucket}/{folder}) to the run_config.storage_path argument:
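A quick pre-flight check for the required credentials can help catch misconfiguration before launching a long tuning run. This helper is hypothetical (not part of lightray), shown only to illustrate which variables must be set:

```python
import os

# The three environment variables lightray's S3 support expects
REQUIRED_VARS = ("AWS_ENDPOINT_URL", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")

def missing_s3_credentials(environ=os.environ):
    """Return the names of any required S3 credential variables that are unset."""
    return [name for name in REQUIRED_VARS if not environ.get(name)]

# With an empty environment, all three are reported missing
print(missing_s3_credentials({}))
```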

lightray --config tune.yaml --run_config.storage_path s3://{bucket}/{folder}

There is also a wrapper around ray.tune.integration.pytorch_lightning.TuneReportCheckpointCallback that performs checkpoint reporting with retries to handle transient S3 errors. This is provided at lightray.callbacks.LightRayReportCheckpointCallback.
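The retry behavior can be sketched with a generic backoff helper. This is a simplified illustration of the idea, not lightray's actual implementation; the function and parameter names are made up:

```python
import time

def with_retries(fn, max_attempts=3, base_delay=0.0, exceptions=(OSError,)):
    """Call fn, retrying on transient errors with exponential backoff."""
    for attempt in range(1, max_attempts + 1):
        try:
            return fn()
        except exceptions:
            if attempt == max_attempts:
                raise  # exhausted retries: surface the original error
            time.sleep(base_delay * 2 ** (attempt - 1))

# Example: an upload that fails twice with a transient error, then succeeds
calls = {"n": 0}

def flaky_upload():
    calls["n"] += 1
    if calls["n"] < 3:
        raise OSError("transient S3 error")
    return "ok"

print(with_retries(flaky_upload))  # "ok" on the third attempt
```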

Remote cluster

To connect to a remote Ray cluster, pass the IP address (using port 10001) to ray_init.address:

lightray --config tune.yaml --ray_init.address ray://{ip}:10001
