
Select module classes and functions using yaml, without any if-statements.

Project description

easy_module_attribute_getter

Installation

pip install easy_module_attribute_getter

The Problem: unmaintainable if-statements and switches

It's common to specify script parameters in yaml config files. For example:

models:
  modelA:
    densenet121:
      pretrained: True
      memory_efficient: True
  modelB:
    resnext50_32x4d:
      pretrained: True

losses:
  lossA:
    CrossEntropyLoss:
  lossB:
    L1Loss:

Usually, the config file is loaded, and then chains of if-statements or switches are used to instantiate the objects:

import torch
import torchvision

# Each yaml entry maps one class name to its keyword arguments
modelA_name, modelA_kwargs = next(iter(args.models["modelA"].items()))
if modelA_name == "densenet121":
    modelA = torchvision.models.densenet121(pretrained=modelA_kwargs["pretrained"])
elif modelA_name == "googlenet":
    modelA = torchvision.models.googlenet(pretrained=modelA_kwargs["pretrained"])
elif modelA_name == "resnet50":
    modelA = torchvision.models.resnet50(pretrained=modelA_kwargs["pretrained"])
elif modelA_name == "inception_v3":
    modelA = torchvision.models.inception_v3(pretrained=modelA_kwargs["pretrained"])
# ... and so on, for every other model and again for modelB

lossA_name = next(iter(args.losses["lossA"]))
if lossA_name == "CrossEntropyLoss":
    lossA = torch.nn.CrossEntropyLoss()
elif lossA_name == "L1Loss":
    lossA = torch.nn.L1Loss()
# ... and so on, for every other loss and again for lossB

The Solution

Use this package and get rid of all those annoying if-statements and switches:

from easy_module_attribute_getter import PytorchGetter
pytorch_getter = PytorchGetter()
models = pytorch_getter.get_multiple("model", args.models)
losses = pytorch_getter.get_multiple("loss", args.losses)

"models" and "losses" are dictionaries that map from strings to the desired objects.

Load one or multiple yaml files into one args object

from easy_module_attribute_getter import YamlReader
yaml_reader = YamlReader()
args, _, _ = yaml_reader.load_yamls(['models.yaml'])

Provide a list of filepaths:

args, _, _ = yaml_reader.load_yamls(['models.yaml', 'optimizers.yaml', 'transforms.yaml'])

Or provide a root path and a dictionary that maps each subfolder name to the bare name of the yaml file to load from that subfolder:

root_path = "/where/your/yaml/subfolders/are/"
subfolder_to_name_dict = {"models": "default", "optimizers": "special_trial", "transforms": "blah"}
args, _, _ = yaml_reader.load_yamls(root_path=root_path, subfolder_to_name_dict=subfolder_to_name_dict)
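Assuming each subfolder holds files named <name>.yaml (both the extension and the folder layout are assumptions, not confirmed by the package docs), the root-path form amounts to the same kind of file list as the list form. The hypothetical helper below illustrates that mapping on a POSIX system:

```python
import os
from typing import Dict, List


def resolve_yaml_paths(root_path: str, subfolder_to_name_dict: Dict[str, str]) -> List[str]:
    """Hypothetical helper: turn {subfolder: bare_name} into full yaml file paths.

    Assumes the layout <root_path>/<subfolder>/<bare_name>.yaml.
    """
    return [
        os.path.join(root_path, subfolder, name + ".yaml")
        for subfolder, name in subfolder_to_name_dict.items()
    ]


paths = resolve_yaml_paths(
    "/where/your/yaml/subfolders/are/",
    {"models": "default", "optimizers": "special_trial", "transforms": "blah"},
)
for p in paths:
    print(p)
```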

Merge or override complex config options via the command line:

The example yaml file contains 'models', which maps to a nested dictionary containing modelA and modelB. To add another key to models from the command line, pass the new entry in curly-brace nested-dictionary notation:

python example.py --models {modelC: {googlenet: {pretrained: True}}}

Then in your script:

import argparse
from easy_module_attribute_getter import YamlReader
yaml_reader = YamlReader(argparse.ArgumentParser())
args, _, _ = yaml_reader.load_yamls(['models.yaml', 'losses.yaml'], max_merge_depth=1)

Now args.models contains three models: modelA and modelB from the yaml file, plus modelC from the command line.

More generally, to merge config options rather than replace them, set the max_merge_depth argument of load_yamls to the number of nested dictionary levels the merge should apply to.

What if you have max_merge_depth set to 1, but want to do a total override for a particular flag? In that case, just append ~OVERRIDE~ to the flag:

python example.py --models~OVERRIDE~ {modelC: {googlenet: {pretrained: True}}}

Now args.models will contain just modelC, even though max_merge_depth is set to 1.
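The merge and override rules described above can be sketched as a depth-limited recursive merge. This is a minimal illustration under assumptions of our own: a flag's value is merged into the loaded config up to max_merge_depth dictionary levels, anything below that depth is replaced wholesale, and a flag ending in ~OVERRIDE~ skips merging entirely. The helpers merge_sketch and apply_flag_sketch are hypothetical and not part of the package.

```python
import copy
from typing import Any, Dict

OVERRIDE_SUFFIX = "~OVERRIDE~"


def merge_sketch(base: Any, new: Any, depth: int) -> Any:
    """Merge `new` into `base`, descending at most `depth` dictionary levels."""
    if depth <= 0 or not (isinstance(base, dict) and isinstance(new, dict)):
        return copy.deepcopy(new)  # past the merge depth: replace outright
    merged = copy.deepcopy(base)
    for key, value in new.items():
        if key in merged:
            merged[key] = merge_sketch(merged[key], value, depth - 1)
        else:
            merged[key] = copy.deepcopy(value)  # new keys are simply added
    return merged


def apply_flag_sketch(config: Dict[str, Any], flag: str, value: Any, max_merge_depth: int) -> None:
    """Apply one command-line flag to the loaded config, in place."""
    if flag.endswith(OVERRIDE_SUFFIX):
        config[flag[: -len(OVERRIDE_SUFFIX)]] = copy.deepcopy(value)  # total override
    else:
        config[flag] = merge_sketch(config.get(flag), value, max_merge_depth)


yaml_models = {
    "modelA": {"densenet121": {"pretrained": True, "memory_efficient": True}},
    "modelB": {"resnext50_32x4d": {"pretrained": True}},
}
cli_value = {"modelC": {"googlenet": {"pretrained": True}}}

merged = {"models": copy.deepcopy(yaml_models)}
apply_flag_sketch(merged, "models", cli_value, max_merge_depth=1)
print(sorted(merged["models"]))  # ['modelA', 'modelB', 'modelC']

overridden = {"models": copy.deepcopy(yaml_models)}
apply_flag_sketch(overridden, "models~OVERRIDE~", cli_value, max_merge_depth=1)
print(sorted(overridden["models"]))  # ['modelC']
```

With max_merge_depth=1 in this sketch, passing a new value for an existing key such as modelA would replace modelA's contents rather than merge into them, because that level lies below the merge depth.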

Easily register your own modules into an existing getter.

from easy_module_attribute_getter import PytorchGetter
from pytorch_metric_learning import losses, miners, samplers
pytorch_getter = PytorchGetter()
pytorch_getter.register('loss', losses)
pytorch_getter.register('miner', miners)
pytorch_getter.register('sampler', samplers)
metric_loss = pytorch_getter.get('loss', class_name='ProxyNCALoss', return_uninitialized=True)
kl_div_loss = pytorch_getter.get('loss', class_name='KLDivLoss', return_uninitialized=True)

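In the above example, the 'loss' key already exists, so the pytorch_metric_learning losses module is appended to the module already registered under that key instead of replacing it. That is why both ProxyNCALoss (from pytorch_metric_learning) and KLDivLoss (from torch.nn) can be retrieved with the 'loss' key.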
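Appending a module to an existing key can be pictured as a registry that keeps a list of modules per key and searches them in order. The sketch below assumes that model; RegistrySketch and its methods are hypothetical stand-ins, not the PytorchGetter API, and the standard-library collections and queue modules stand in for torch.nn and pytorch_metric_learning.losses.

```python
import collections
import queue
from types import ModuleType
from typing import Any, Dict, List


class RegistrySketch:
    """Hypothetical registry: each key maps to a list of modules searched in order."""

    def __init__(self) -> None:
        self.modules: Dict[str, List[ModuleType]] = {}

    def register(self, key: str, module: ModuleType) -> None:
        # An existing key keeps its modules; the new module is appended to the list
        self.modules.setdefault(key, []).append(module)

    def get(self, key: str, class_name: str, return_uninitialized: bool = False, **kwargs: Any) -> Any:
        # Return the first match, either as the class itself or as an instance
        for module in self.modules.get(key, []):
            if hasattr(module, class_name):
                cls = getattr(module, class_name)
                return cls if return_uninitialized else cls(**kwargs)
        raise AttributeError(f"{class_name!r} not found in any module registered under {key!r}")


registry = RegistrySketch()
registry.register("container", collections)  # the pre-existing module for this key
registry.register("container", queue)        # appended, not replacing collections
deque_cls = registry.get("container", class_name="deque", return_uninitialized=True)
fifo = registry.get("container", class_name="Queue", maxsize=3)
print(deque_cls)     # <class 'collections.deque'>
print(fifo.maxsize)  # 3
```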

PyTorch-specific features

Transforms

Specify transforms in your config file:

transforms:
  train:
    Resize:
      size: 256
    RandomResizedCrop:
      scale: 0.16 1
      ratio: 0.75 1.33
      size: 227
    RandomHorizontalFlip:
      p: 0.5

  eval:
    Resize:
      size: 256
    CenterCrop:
      size: 227

Then load composed transforms in your script:

transforms = {}
for k, v in args.transforms.items():
    transforms[k] = pytorch_getter.get_composed_img_transform(v, mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])

The transforms dict now contains:

{'train': Compose(
    Resize(size=256, interpolation=PIL.Image.BILINEAR)
    RandomResizedCrop(size=(227, 227), scale=(0.16, 1), ratio=(0.75, 1.33), interpolation=PIL.Image.BILINEAR)
    RandomHorizontalFlip(p=0.5)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
), 'eval': Compose(
    Resize(size=256, interpolation=PIL.Image.BILINEAR)
    CenterCrop(size=(227, 227))
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)}
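One way to picture what get_composed_img_transform does with the yaml above: read each entry as a class name plus keyword arguments, turn space-separated strings such as "0.16 1" into tuples, and append ToTensor and Normalize at the end, which matches the printed Compose objects. The following hypothetical sketch covers only that parsing step. It is not the package's code, and it returns (class name, kwargs) pairs instead of torchvision objects so that it runs without PyTorch.

```python
from typing import Any, Dict, List, Optional, Tuple, Union

Number = Union[int, float]


def parse_number(text: str) -> Number:
    """Return an int for integral text such as '1', otherwise a float."""
    try:
        return int(text)
    except ValueError:
        return float(text)


def parse_value(value: Any) -> Any:
    """Turn space-separated numeric strings like '0.16 1' into tuples."""
    if isinstance(value, str) and " " in value:
        return tuple(parse_number(part) for part in value.split())
    return value


def transform_specs_sketch(
    yaml_dict: Dict[str, Optional[Dict[str, Any]]],
    mean: List[float],
    std: List[float],
) -> List[Tuple[str, Dict[str, Any]]]:
    """Hypothetical: list the (class_name, kwargs) pairs a Compose would be built from."""
    specs = [
        (name, {k: parse_value(v) for k, v in (kwargs or {}).items()})
        for name, kwargs in yaml_dict.items()
    ]
    # The printed output above ends every pipeline with ToTensor and Normalize
    specs.append(("ToTensor", {}))
    specs.append(("Normalize", {"mean": mean, "std": std}))
    return specs


train_yaml = {
    "Resize": {"size": 256},
    "RandomResizedCrop": {"scale": "0.16 1", "ratio": "0.75 1.33", "size": 227},
    "RandomHorizontalFlip": {"p": 0.5},
}
for name, kwargs in transform_specs_sketch(train_yaml, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]):
    print(name, kwargs)
```

With torchvision installed, each pair could then be instantiated as getattr(torchvision.transforms, name)(**kwargs), and the resulting list wrapped in torchvision.transforms.Compose.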

Optimizers, schedulers, and gradient clippers

Optionally, specify a scheduler and a gradient clipping norm within the optimizer parameters:

optimizers:
  modelA:
    Adam:
      lr: 0.00001
      weight_decay: 0.00005
      scheduler:
        StepLR:
          step_size: 2
          gamma: 0.95
      clip_grad_norm: 1
  modelB:
    RMSprop:
      lr: 0.00001
      weight_decay: 0.00005

Create the optimizers:

optimizers = {}
schedulers = {}
grad_clippers = {}
for k, v in models.items():
    optimizers[k], schedulers[k], grad_clippers[k] = pytorch_getter.get_optimizer(v, yaml_dict=args.optimizers[k])
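The yaml layout above nests the scheduler and the clipping norm inside the optimizer entry, so they have to be separated from the optimizer's own keyword arguments before an optimizer such as Adam can be constructed. The sketch below shows that separation step under this assumption; split_optimizer_config is a hypothetical name, not the package's code, and it returns plain values rather than PyTorch objects so that it runs without PyTorch.

```python
from typing import Any, Dict, Optional, Tuple


def split_optimizer_config(
    yaml_dict: Dict[str, Dict[str, Any]],
) -> Tuple[str, Dict[str, Any], Optional[Tuple[str, Dict[str, Any]]], Optional[float]]:
    """Hypothetical: split one optimizer entry into its separate parts.

    Returns (optimizer_name, optimizer_kwargs, scheduler_spec, clip_grad_norm),
    where scheduler_spec is (scheduler_name, scheduler_kwargs) or None.
    """
    (optimizer_name, params), = yaml_dict.items()
    optimizer_kwargs = dict(params)  # copy so the loaded config is not mutated
    scheduler = optimizer_kwargs.pop("scheduler", None)
    clip_grad_norm = optimizer_kwargs.pop("clip_grad_norm", None)
    scheduler_spec = None
    if scheduler is not None:
        (scheduler_name, scheduler_kwargs), = scheduler.items()
        scheduler_spec = (scheduler_name, dict(scheduler_kwargs))
    return optimizer_name, optimizer_kwargs, scheduler_spec, clip_grad_norm


model_a = {
    "Adam": {
        "lr": 0.00001,
        "weight_decay": 0.00005,
        "scheduler": {"StepLR": {"step_size": 2, "gamma": 0.95}},
        "clip_grad_norm": 1,
    }
}
model_b = {"RMSprop": {"lr": 0.00001, "weight_decay": 0.00005}}
print(split_optimizer_config(model_a))
print(split_optimizer_config(model_b))
```

In PyTorch terms, these pieces would feed torch.optim.Adam(model.parameters(), **optimizer_kwargs), torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95) and, after each backward pass, torch.nn.utils.clip_grad_norm_(model.parameters(), 1).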



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

easy_module_attribute_getter-0.9.20.tar.gz (6.6 kB)

Uploaded Source

Built Distribution

easy_module_attribute_getter-0.9.20-py3-none-any.whl (9.1 kB)

Uploaded Python 3

File details

Details for the file easy_module_attribute_getter-0.9.20.tar.gz.

File metadata

  • Download URL: easy_module_attribute_getter-0.9.20.tar.gz
  • Upload date:
  • Size: 6.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.6.0.post20191030 requests-toolbelt/0.9.1 tqdm/4.38.0 CPython/3.7.5

File hashes

Hashes for easy_module_attribute_getter-0.9.20.tar.gz
Algorithm     Hash digest
SHA256        99c995992f49a6b6ccc7b48fd9edb5111fa34a028f195f396ed7d0ee1ee6eb4b
MD5           97b472c9130e92242105274c87d1f813
BLAKE2b-256   24797672c8fe06ecbe00c993a66688783c6ff76d80a1eb211a24a647c27e973f


File details

Details for the file easy_module_attribute_getter-0.9.20-py3-none-any.whl.

File metadata

  • Download URL: easy_module_attribute_getter-0.9.20-py3-none-any.whl
  • Upload date:
  • Size: 9.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/2.0.0 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.6.0.post20191030 requests-toolbelt/0.9.1 tqdm/4.38.0 CPython/3.7.5

File hashes

Hashes for easy_module_attribute_getter-0.9.20-py3-none-any.whl
Algorithm     Hash digest
SHA256        866b7f10286b73d4ce9880e940c76a9a13617abb0acad43cfda47a32f1688b44
MD5           85ee342c318088ef1b6725a2cd3de8b1
BLAKE2b-256   33555d2a8b53dc933d9c8e1e14bd12c22210e2945728aaea80e8ec4814c47049

