Distribute a LightningCLI hyperparameter search with Ray Tune

lightray

Easily integrate a LightningCLI with RayTune hyperparameter optimization

Getting Started

Extend your custom LightningCLI to a Trainable compatible with RayTune.

Imagine you have a LightningCLI parser for a model and data module:

import lightning.pytorch as pl
from lightning.pytorch.cli import LightningCLI


class CustomCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        # Keep the model's `parameter` in sync with the data module's value
        parser.link_arguments(
            "data.init_args.parameter",
            "model.init_args.parameter",
            apply_on="parse",
        )


class DataModule(pl.LightningDataModule):
    def __init__(self, parameter: int):
        super().__init__()
        self.parameter = parameter

    def train_dataloader(self):
        ...


class Model(pl.LightningModule):
    # hidden_dim and learning_rate are the model hyperparameters to tune
    def __init__(self, hidden_dim: int, learning_rate: float, parameter: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.learning_rate = learning_rate
        self.parameter = parameter

    def training_step(self, batch, batch_idx):
        ...

Launching a hyperparameter tuning job with RayTune using this LightningCLI is as simple as:

from ray import tune
from ray.tune.schedulers import ASHAScheduler

# assumes `run` is exported at the top level of the `lightray` package
from lightray import run

# define search space
search_space = {
    "model.init_args.hidden_dim": tune.choice([32, 64]),
    "model.init_args.learning_rate": tune.loguniform(1e-4, 1e-2),
}

# define scheduler
scheduler = ASHAScheduler(max_t=50, grace_period=1, reduction_factor=2)

# placeholder values for illustration; substitute your own
storage_dir = "/path/to/storage"
num_samples = 10

# pass command line style arguments as if you were
# running your `LightningCLI` from the command line
args = ["--config", "/path/to/config.yaml"]
results = run(
    args=args,
    cli_cls=CustomCLI,
    name="tune-test",
    metric="val_loss",
    objective="min",
    search_space=search_space,
    scheduler=scheduler,
    storage_dir=storage_dir,
    address=None,
    num_samples=num_samples,
    workers_per_trial=1,
    gpus_per_worker=1.0,
    cpus_per_gpu=4.0,
    temp_dir=None,
)
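
The search space keys use the same dotted paths that a LightningCLI accepts as command line flags (for example `--model.init_args.hidden_dim=64`). The sketch below illustrates that correspondence for a single sampled trial. The helper `with_overrides` is hypothetical and not part of lightray; how lightray applies sampled values internally is not documented here, so treat this only as a way to reason about which keys belong in the search space.

```python
from typing import Any, Mapping, Sequence


def with_overrides(base_args: Sequence[str], sampled: Mapping[str, Any]) -> list[str]:
    """Append one `--dotted.key=value` flag per sampled hyperparameter.

    Hypothetical helper for illustration only; not lightray's API.
    """
    args = list(base_args)
    for key, value in sampled.items():
        args.append(f"--{key}={value}")
    return args


# One concrete draw from the search space (values are illustrative)
sampled = {
    "model.init_args.hidden_dim": 64,
    "model.init_args.learning_rate": 0.001,
}
trial_args = with_overrides(["--config", "/path/to/config.yaml"], sampled)
print(trial_args)
```

If a key in the search space would not be a valid flag for your CLI when written this way, it is likely misspelled or points at the wrong component (for example `data` instead of `model`).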

S3 storage works out of the box. Make sure the AWS_ENDPOINT_URL, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY environment variables are set, then pass an S3 path (e.g. s3://{bucket}/{folder}) to the storage_dir argument.
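
Before launching a job against S3, it can help to confirm that the three variables are present. The following is a minimal sketch using only the standard library; the function name `missing_s3_vars` is an assumption made for this example and is not provided by lightray.

```python
import os
from typing import Mapping

# Environment variables named in the S3 storage instructions
REQUIRED_S3_VARS = (
    "AWS_ENDPOINT_URL",
    "AWS_ACCESS_KEY_ID",
    "AWS_SECRET_ACCESS_KEY",
)


def missing_s3_vars(env: Mapping[str, str] = os.environ) -> list[str]:
    """Return the required variable names that are unset or empty."""
    return [name for name in REQUIRED_S3_VARS if not env.get(name)]


missing = missing_s3_vars()
if missing:
    print("Missing S3 environment variables: " + ", ".join(missing))
else:
    print("All S3 environment variables are set")
```

Running a check like this up front surfaces a misconfigured environment immediately rather than partway through a tuning run.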
