Select module classes and functions using yaml, without any if-statements.

easy-module-attribute-getter

Installation

pip install easy-module-attribute-getter

The Problem: unmaintainable if-statements and dictionaries

It's common to specify script parameters in yaml config files. For example:

models:
  modelA:
    densenet121:
      pretrained: True
      memory_efficient: True
  modelB:
    resnext50_32x4d:
      pretrained: True

Usually, the config file is loaded, and then a chain of if-statements (or a switch) is used to instantiate the right objects. Depending on how the config file is organized, it might look something like this:

import torchvision

# args holds the loaded yaml config
models = {}
for k in ["modelA", "modelB"]:
    # Each entry holds a single {model_name: kwargs} pair
    model_name = list(args.models[k].keys())[0]
    if model_name == "densenet121":
        models[k] = torchvision.models.densenet121(**args.models[k][model_name])
    elif model_name == "googlenet":
        models[k] = torchvision.models.googlenet(**args.models[k][model_name])
    elif model_name == "resnet50":
        models[k] = torchvision.models.resnet50(**args.models[k][model_name])
    elif model_name == "inception_v3":
        models[k] = torchvision.models.inception_v3(**args.models[k][model_name])
    ...  # and so on, one branch per model you want to support

This is kind of annoying to do, and every time PyTorch adds new classes or functions that you want access to, you need to add new cases to your giant if-statement. An alternative is to make a dictionary:

import torchvision

model_dict = {"densenet121": torchvision.models.densenet121,
              "googlenet": torchvision.models.googlenet,
              "resnet50": torchvision.models.resnet50,
              "inception_v3": torchvision.models.inception_v3,
              # ... one entry per model you want to support
              }
models = {}
for k in ["modelA", "modelB"]:
    model_name = list(args.models[k].keys())[0]
    models[k] = model_dict[model_name](**args.models[k][model_name])

This is shorter than the if-statement, but it still requires you to spell out every key and class by hand, and you still have to update it yourself whenever the package adds something new.

The Solution

Fetch and initialize multiple models in one line

With this package, the above for-loop and if-statements get reduced to this:

from easy_module_attribute_getter import PytorchGetter
pytorch_getter = PytorchGetter()
models = pytorch_getter.get_multiple("model", args.models)

"models" is a dictionary that maps from strings ("modelA" and "modelB") to the desired objects, which have already been initialized with the parameters specified in the config file.

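The core idea can be sketched with nothing but getattr: each config entry names an attribute of a module, and the dictionary under that name becomes the keyword arguments. The function get_multiple_from_module below is hypothetical and is not the package's API; the standard-library collections module stands in for torchvision.models so the sketch runs without PyTorch.

```python
import collections
from types import ModuleType
from typing import Any, Dict


def get_multiple_from_module(module: ModuleType, config: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical sketch: instantiate one object per top-level key.

    Each value in `config` must be a single-entry dict mapping an
    attribute name in `module` to the keyword arguments for it.
    """
    objects = {}
    for key, spec in config.items():
        # Unpack the single {name: kwargs} pair, e.g. {"densenet121": {...}}
        [(name, kwargs)] = spec.items()
        # Look the class or function up by name instead of using if-statements
        constructor = getattr(module, name)
        objects[key] = constructor(**(kwargs or {}))
    return objects


config = {
    "counterA": {"Counter": {"a": 2, "b": 1}},
    "dequeB": {"deque": {"maxlen": 3}},
}
objects = get_multiple_from_module(collections, config)
```

Because nothing in the loop names a specific class, any attribute the module exposes is reachable from the config file.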
Access multiple modules in one line

Say you want access to the default package (torchvision.models), as well as the pretrainedmodels package, and two other custom model modules, X and Y. You can register these:

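import pretrainedmodels  # third-party package of pretrained models
import X, Y  # placeholders for your own model modules

pytorch_getter.register('model', pretrainedmodels)
pytorch_getter.register('model', X)
pytorch_getter.register('model', Y)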

You can still use the same one-liner:

models = pytorch_getter.get_multiple("model", args.models)

pytorch_getter will then try all 4 registered modules until it finds a match.
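One way to picture the multi-module search is a list of modules tried in turn with getattr, where the first module that defines the requested name wins. The class below is a hypothetical sketch, not PytorchGetter itself, and the search order (registration order here) is an assumption; the package may order or resolve duplicate names differently. The standard-library modules collections and fractions stand in for model modules.

```python
import collections
import fractions
from types import ModuleType
from typing import Any, Dict, List


class MultiModuleGetter:
    """Hypothetical sketch of name lookup across several registered modules."""

    def __init__(self) -> None:
        self.modules: Dict[str, List[ModuleType]] = {}

    def register(self, kind: str, module: ModuleType) -> None:
        # Keep modules in registration order for each kind ("model", ...)
        self.modules.setdefault(kind, []).append(module)

    def get(self, kind: str, name: str, **kwargs: Any) -> Any:
        # Try each registered module until one defines `name`
        for module in self.modules.get(kind, []):
            if hasattr(module, name):
                return getattr(module, name)(**kwargs)
        raise AttributeError(f"no registered {kind!r} module defines {name!r}")


getter = MultiModuleGetter()
getter.register("model", collections)
getter.register("model", fractions)
counts = getter.get("model", "Counter", x=1)  # found in collections
half = getter.get("model", "Fraction", numerator=1, denominator=2)  # found in fractions
```

Raising an error when no module matches keeps a typo in the yaml file from failing silently.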

Automatically have yaml access to new classes

If you upgrade to a new version of PyTorch that adds 20 new classes, you don't have to change anything. Since classes are looked up by name, all the new classes are available immediately, and you can specify them in your yaml file.

Merge or override complex config options via the command line

The example yaml file contains 'models', which maps to a nested dictionary containing modelA and modelB. You can add another key to models at the command line by writing the nested dictionary in curly-brace notation:

python example.py --models {modelC: {googlenet: {pretrained: True}}}

Then in your script:

import argparse
from easy_module_attribute_getter import YamlReader

yaml_reader = YamlReader(argparse.ArgumentParser())
args, _, _ = yaml_reader.load_yamls(['models.yaml', 'losses.yaml'], max_merge_depth=1)

Now args.models contains 3 models: modelA and modelB from the yaml file, plus modelC from the command line.

More generally, to merge config options, set the max_merge_depth argument of load_yamls to the number of nested sub-dictionary levels the merge should apply to.
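The effect of max_merge_depth can be sketched as a recursive merge that descends into matching sub-dictionaries only while depth remains; below that depth, the command-line value replaces the yaml value outright. merge_with_depth is a hypothetical helper written for illustration, and it assumes depth is counted from top-level keys such as models, which matches the example above (with depth 1, modelC is added alongside modelA and modelB).

```python
from typing import Any, Dict


def merge_with_depth(base: Dict[str, Any], override: Dict[str, Any], max_merge_depth: int) -> Dict[str, Any]:
    """Hypothetical sketch: merge `override` into a copy of `base`.

    While max_merge_depth > 0 and both sides hold dicts, keys are merged
    recursively; otherwise the override value replaces the base value.
    """
    merged = dict(base)  # shallow copy, so `base` itself is never mutated
    for key, value in override.items():
        if max_merge_depth > 0 and isinstance(merged.get(key), dict) and isinstance(value, dict):
            merged[key] = merge_with_depth(merged[key], value, max_merge_depth - 1)
        else:
            merged[key] = value
    return merged


yaml_config = {"models": {"modelA": {"densenet121": {"pretrained": True}},
                          "modelB": {"resnext50_32x4d": {"pretrained": True}}}}
cli_config = {"models": {"modelC": {"googlenet": {"pretrained": True}}}}

merged = merge_with_depth(yaml_config, cli_config, max_merge_depth=1)    # modelA, modelB, modelC
replaced = merge_with_depth(yaml_config, cli_config, max_merge_depth=0)  # modelC only
```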

What if max_merge_depth is set to 1, but you want to completely override a particular flag? Just append ~OVERRIDE~ to the flag name:

python example.py --models~OVERRIDE~ {modelC: {googlenet: {pretrained: True}}}

Now args.models will contain just modelC, even though max_merge_depth is set to 1.
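The ~OVERRIDE~ suffix can be pictured as a per-flag switch that forces replacement regardless of max_merge_depth. The sketch below is hypothetical and is not the package's implementation: it assumes the parsed command-line flags arrive as a dict whose keys may end in ~OVERRIDE~, and apply_cli_flags and _merge are illustrative names.

```python
from typing import Any, Dict

OVERRIDE_SUFFIX = "~OVERRIDE~"


def _merge(base: Any, override: Any, depth: int) -> Any:
    # Recurse into matching dicts while depth remains; otherwise replace
    if depth > 0 and isinstance(base, dict) and isinstance(override, dict):
        merged = dict(base)
        for key, value in override.items():
            merged[key] = _merge(merged.get(key), value, depth - 1)
        return merged
    return override


def apply_cli_flags(config: Dict[str, Any], flags: Dict[str, Any], max_merge_depth: int) -> Dict[str, Any]:
    """Hypothetical sketch: fold command-line flags into a yaml config."""
    result = dict(config)
    for flag, value in flags.items():
        if flag.endswith(OVERRIDE_SUFFIX):
            # "models~OVERRIDE~" -> "models": replace the whole value
            result[flag[: -len(OVERRIDE_SUFFIX)]] = value
        else:
            result[flag] = _merge(result.get(flag), value, max_merge_depth)
    return result


yaml_config = {"models": {"modelA": {"densenet121": {"pretrained": True}},
                          "modelB": {"resnext50_32x4d": {"pretrained": True}}}}
new_model = {"modelC": {"googlenet": {"pretrained": True}}}

merged = apply_cli_flags(yaml_config, {"models": new_model}, max_merge_depth=1)
overridden = apply_cli_flags(yaml_config, {"models~OVERRIDE~": new_model}, max_merge_depth=1)
```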

Load one or multiple yaml files into one args object

from easy_module_attribute_getter import YamlReader
yaml_reader = YamlReader()
args, _, _ = yaml_reader.load_yamls(['models.yaml'])

Provide a list of filepaths:

args, _, _ = yaml_reader.load_yamls(['models.yaml', 'optimizers.yaml', 'transforms.yaml'])

Or provide a root path and a dictionary mapping subfolder names to bare filenames:

root_path = "/where/your/yaml/subfolders/are/"
subfolder_to_name_dict = {"models": "default", "optimizers": "special_trial", "transforms": "blah"}
args, _, _ = yaml_reader.load_yamls(root_path=root_path, subfolder_to_name_dict=subfolder_to_name_dict)
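The root_path form can be read as shorthand for an ordinary list of file paths. The helper below is a hypothetical sketch of that mapping; both the root/subfolder/name layout and the .yaml extension are assumptions rather than documented behavior.

```python
import os
from typing import Dict, List


def resolve_yaml_paths(root_path: str, subfolder_to_name_dict: Dict[str, str],
                       extension: str = ".yaml") -> List[str]:
    """Hypothetical: build root/subfolder/name + extension for each entry."""
    return [os.path.join(root_path, subfolder, name + extension)
            for subfolder, name in subfolder_to_name_dict.items()]


paths = resolve_yaml_paths("/where/your/yaml/subfolders/are/",
                           {"models": "default", "optimizers": "special_trial", "transforms": "blah"})
```

Under these assumptions, the result has the same shape as the explicit list of filepaths passed to load_yamls earlier.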

PyTorch-specific features

Transforms

Specify transforms in your config file:

transforms:
  train:
    Resize:
      size: 256
    RandomResizedCrop:
      scale: 0.16 1
      ratio: 0.75 1.33
      size: 227
    RandomHorizontalFlip:
      p: 0.5

  eval:
    Resize:
      size: 256
    CenterCrop:
      size: 227

Then load composed transforms in your script:

transforms = {}
for k, v in args.transforms.items():
    # Build one Compose object per split ("train", "eval")
    transforms[k] = pytorch_getter.get_composed_img_transform(
        v, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

The transforms dict now contains:

{'train': Compose(
    Resize(size=256, interpolation=PIL.Image.BILINEAR)
    RandomResizedCrop(size=(227, 227), scale=(0.16, 1), ratio=(0.75, 1.33), interpolation=PIL.Image.BILINEAR)
    RandomHorizontalFlip(p=0.5)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
), 'eval': Compose(
    Resize(size=256, interpolation=PIL.Image.BILINEAR)
    CenterCrop(size=(227, 227))
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)}
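Two details of this output are worth noting: space-separated values such as scale: 0.16 1 arrive as tuples (scale=(0.16, 1)), and ToTensor and Normalize are appended after the transforms listed in the yaml file. The sketch below reproduces both steps on plain data. parse_value and build_transform_spec are hypothetical names, and a real implementation would instantiate torchvision.transforms classes rather than return (name, kwargs) pairs.

```python
from typing import Any, Dict, List, Tuple


def parse_value(value: Any) -> Any:
    """Hypothetical: turn "0.16 1" into (0.16, 1); assumes every part is numeric."""
    if isinstance(value, str) and " " in value:
        return tuple(int(p) if p.lstrip("-").isdigit() else float(p) for p in value.split())
    return value


def build_transform_spec(yaml_dict: Dict[str, Dict[str, Any]],
                         mean: List[float], std: List[float]) -> List[Tuple[str, Dict[str, Any]]]:
    """Hypothetical: (transform name, kwargs) in yaml order, then ToTensor and Normalize."""
    spec = [(name, {k: parse_value(v) for k, v in (kwargs or {}).items()})
            for name, kwargs in yaml_dict.items()]
    spec.append(("ToTensor", {}))
    spec.append(("Normalize", {"mean": mean, "std": std}))
    return spec


train = {"Resize": {"size": 256},
         "RandomResizedCrop": {"scale": "0.16 1", "ratio": "0.75 1.33", "size": 227},
         "RandomHorizontalFlip": {"p": 0.5}}
spec = build_transform_spec(train, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
```

The sketch relies on Python dicts preserving insertion order, so the transforms keep the order in which they appear in the yaml file.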

Optimizers, schedulers, and gradient clippers

Optionally, specify a scheduler and a gradient-clipping norm within the optimizer parameters:

optimizers:
  modelA:
    Adam:
      lr: 0.00001
      weight_decay: 0.00005
      scheduler:
        StepLR:
          step_size: 2
          gamma: 0.95
      clip_grad_norm: 1
  modelB:
    RMSprop:
      lr: 0.00001
      weight_decay: 0.00005

Create the optimizers:

optimizers = {}
schedulers = {}
grad_clippers = {}
for k, v in models.items():
    # Each model gets its own optimizer, scheduler, and gradient clipper
    optimizers[k], schedulers[k], grad_clippers[k] = pytorch_getter.get_optimizer(v, yaml_dict=args.optimizers[k])
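Because scheduler and clip_grad_norm live inside the optimizer's parameters, they have to be separated from the arguments passed to the optimizer constructor. The function below is a hypothetical sketch of that split on plain dictionaries; how get_optimizer does it internally is not described here.

```python
from typing import Any, Dict, Optional, Tuple


def split_optimizer_config(
    yaml_dict: Dict[str, Dict[str, Any]]
) -> Tuple[str, Dict[str, Any], Optional[Dict[str, Any]], Optional[float]]:
    """Hypothetical: separate optimizer kwargs from scheduler and clipping options."""
    [(optimizer_name, params)] = yaml_dict.items()
    optimizer_kwargs = dict(params)  # copy so the config itself is untouched
    # Pull out the extra keys so they are not passed to the optimizer constructor
    scheduler_spec = optimizer_kwargs.pop("scheduler", None)
    clip_grad_norm = optimizer_kwargs.pop("clip_grad_norm", None)
    return optimizer_name, optimizer_kwargs, scheduler_spec, clip_grad_norm


modelA_cfg = {"Adam": {"lr": 0.00001, "weight_decay": 0.00005,
                       "scheduler": {"StepLR": {"step_size": 2, "gamma": 0.95}},
                       "clip_grad_norm": 1}}
name, kwargs, scheduler, clip = split_optimizer_config(modelA_cfg)
```

From there, an implementation could call getattr(torch.optim, name)(model.parameters(), **kwargs) and, when clip is set, torch.nn.utils.clip_grad_norm_(model.parameters(), clip) after each backward pass.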

Not just for PyTorch

Note that the YamlReader and EasyModuleAttributeGetter classes are completely independent of PyTorch. I wrote the child class PytorchGetter because that's what I use this package for, but the other two classes can be used in general cases and extended for your own purposes.

Download files

Source Distribution

easy-module-attribute-getter-0.9.26.tar.gz (8.1 kB)

Built Distribution

easy_module_attribute_getter-0.9.26-py3-none-any.whl (9.7 kB)

File details

Details for the file easy-module-attribute-getter-0.9.26.tar.gz.

File metadata

  • Download URL: easy-module-attribute-getter-0.9.26.tar.gz
  • Size: 8.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.6.0.post20191030 requests-toolbelt/0.9.1 tqdm/4.38.0 CPython/3.7.5

File hashes

Hashes for easy-module-attribute-getter-0.9.26.tar.gz
  • SHA256: 9810f64f524784a635a2e95063237915413b88b9a7ccfad8d7d7c1774bc3eda6
  • MD5: 952bd56afe9cb36387c6f6ce224ddac3
  • BLAKE2b-256: dc75f9a59d408377f744ab4bdc8df676c9b0ee008221196547120bfc6f3983dc

File details

Details for the file easy_module_attribute_getter-0.9.26-py3-none-any.whl.

File metadata

  • Download URL: easy_module_attribute_getter-0.9.26-py3-none-any.whl
  • Size: 9.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/41.6.0.post20191030 requests-toolbelt/0.9.1 tqdm/4.38.0 CPython/3.7.5

File hashes

Hashes for easy_module_attribute_getter-0.9.26-py3-none-any.whl
  • SHA256: a79c5542dfc1411d0abd8d9466b059a4e94cac200e61cd0f2c77a11c9f7ab17f
  • MD5: c5cf8273e4277e02269719b02d081f90
  • BLAKE2b-256: 04a5f8f6213b063f6ec293a8c387a91dff4a32af426c05f6685e643e89c4dd64
