
Instantiate Python objects with Jsonnet

Project description

cresset

cresset makes it super simple to construct PyTorch modules directly from config files. You no longer have to rely on the unsafe pickle module, which can execute arbitrary code during deserialization. Instead, cresset builds on the power of confactory to load the full model definition directly from JSON files.
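To see why this matters, here is a minimal, self-contained illustration (unrelated to cresset's API) of how unpickling can execute arbitrary code: pickle honors a class's `__reduce__` hook, which may return any callable to invoke at load time.

```python
import pickle

class Pwned:
    """Deliberately malicious class: unpickling an instance runs attacker-chosen code."""

    def __reduce__(self):
        # pickle calls eval("6 * 7") during loading; any callable could go here,
        # e.g. os.system, which is why untrusted pickles are dangerous
        return (eval, ("6 * 7",))

payload = pickle.dumps(Pwned())
result = pickle.loads(payload)  # does not restore a Pwned object at all
print(result)  # → 42
```

Loading a model from a plain JSON or Jsonnet config avoids this class of attack entirely, since the config is data, not code.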

Simple Example

Here's a simple example demonstrating how to create the same model as in the PyTorch Quickstart Tutorial. First, write a quickstart.json file with the following structure.

{
  "type": "Sequential",
  "args": [
    { "type": "Linear", "in_features": 784, "out_features": 512 },
    { "type": "ReLU" },
    { "type": "Linear", "in_features": 512, "out_features": 512 },
    { "type": "ReLU" },
    { "type": "Linear", "in_features": 512, "out_features": 10 }
  ]
}

Now in Python you can build the Module directly from the config file:

from cresset.nn import Module

model = Module.from_config("quickstart.json")
print(model)
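Under the hood, this style of construction amounts to looking up the `type` field in a registry of known classes and passing the remaining keys as constructor arguments, recursing into nested objects. The sketch below is not cresset's actual implementation; it is a stdlib-only illustration of the pattern, with strings standing in for real modules.

```python
import json

# Hypothetical registry mapping "type" names to constructors.
# Strings stand in for torch.nn classes so the sketch stays dependency-free.
REGISTRY = {
    "Linear": lambda in_features, out_features: f"Linear({in_features}->{out_features})",
    "ReLU": lambda: "ReLU()",
    "Sequential": lambda args: f"Sequential[{', '.join(args)}]",
}

def build(config):
    """Recursively instantiate a config via the registry."""
    if isinstance(config, dict) and "type" in config:
        kwargs = {k: build(v) for k, v in config.items() if k != "type"}
        return REGISTRY[config["type"]](**kwargs)
    if isinstance(config, list):
        return [build(item) for item in config]
    return config  # plain values (ints, strings) pass through unchanged

config = json.loads("""
{
  "type": "Sequential",
  "args": [
    { "type": "Linear", "in_features": 784, "out_features": 512 },
    { "type": "ReLU" }
  ]
}
""")
print(build(config))  # → Sequential[Linear(784->512), ReLU()]
```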

A More Complex Example

Large models have lots of layers and might even have conditional configuration options. Defining such models directly in JSON would be cumbersome. Fortunately, cresset supports Jsonnet, making configuration a breeze. Below is a more involved example that creates a Transformer-like model from built-in PyTorch modules. First create a file named transformer-like.jsonnet with the following configuration:

# An attention-less Transformer-like model (only performs feedforward ops)
function (vocab_size, heads=8, dim=512, hidden_dim=2048, layers=6, dropout_p=0.1) {
  # Make a reference to this transformer so nested objects can reference the top-level helpers
  local transformer = self,

  # Define a number of hidden helpers that can be overridden by callers of this function
  "dropout":: function(module, name="module") {
    "type": "Sequential",
    "args": {
      [name]: module,
      "dropout": { "type": "Dropout", "p": dropout_p }
    }
  },

  "residual":: function(name, module) {
    "type": "Sequential",
    "args": {
      "residual": {
        "type": "Parallel",
        "args": [
          { "type": "Identity" },
          transformer.dropout(module, name=name)
        ]
      },
      "sum": { "type": "Reduce", "op": "SUM" },
      "norm": { "type": "LayerNorm", "normalized_shape": dim }
    }
  },

  "ffn":: {
    "type": "Sequential",
    "args": {
      "hidden": { "type": "Linear", "in_features": dim, "out_features": hidden_dim },
      "activation": { "type": "ReLU", "inplace": true },
      "output": { "type": "Linear", "in_features": hidden_dim, "out_features": dim }
    }
  },

  "type": "Sequential",
  "args": [{
    "type": "Embedding",
    "embedding_dim": dim,
    "num_embeddings": vocab_size,
  }] + std.repeat([transformer.residual("ffn", transformer.ffn)], layers),
}

The top-level Jsonnet function has one required parameter, vocab_size. You can supply it through the bindings parameter of Module.from_config, as in the Python below:

from cresset.nn import Module

# vocab_size is required
Module.from_config("transformer-like.jsonnet", bindings={"vocab_size": 65535})

# can override default values, e.g. disable dropout for inference
Module.from_config("transformer-like.jsonnet", bindings={"vocab_size": 65535, "dropout_p": 0})

Additional Features

In addition to supporting PyTorch's torch.nn.Module, cresset also wraps torch.optim.Optimizer and torch.optim.lr_scheduler.LRScheduler, so full training configurations can be defined the same way. It also wraps the initialization functions in torch.nn.init so they can be used to initialize your model. For example, the code below runs a set of initializers on every model parameter whose name matches a regular expression.

import logging
import warnings
from typing import List, Set

import attr

from confactory import configurable
from cresset.nn import Sequential
from cresset.nn.init import Initializer, Zeros

class UninitializedParameters(Warning):
    """A warning indicating an uninitialized parameter"""

    def __init__(self, parameters: List[str]):
        super().__init__(", ".join(parameters))

@configurable
class ModelWithInitializers(Sequential):
    """A sequence model"""

    initializers: List[Initializer] = attr.ib(
        # The mypy attrs plugin has a bug, so we have to ignore the arg type
        # issue below, see https://github.com/python/mypy/issues/5313
        default=[Zeros(regex=r".*")],  # type: ignore[arg-type]
        validator=attr.validators.deep_iterable(
            attr.validators.instance_of(Initializer)
        ),
    )
    warn_uninitialized_params: bool = False
    log_initializers: bool = False

    def __attrs_post_init__(self):
        matched: Set[str] = set()
        all_params: Set[str] = set()
        for name, parameter in self.named_parameters():
            all_params.add(name)
            for initializer in self.initializers:
                if initializer.match(name):
                    if self.log_initializers:
                        logging.info(f"Running initializer {initializer} for {name}")
                    initializer(parameter)
                    matched.add(name)

        uninitialized = all_params - matched
        if self.warn_uninitialized_params and uninitialized:
            warnings.warn(UninitializedParameters(sorted(uninitialized)))
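The matching loop above is just a regular-expression search over parameter names. Here is a stdlib-only sketch of the same bookkeeping, with hypothetical parameter names standing in for what named_parameters() might yield:

```python
import re

# Hypothetical parameter names, mimicking PyTorch's dotted naming scheme
param_names = [
    "0.weight",                        # embedding, matched by nothing below
    "1.norm.weight",
    "1.norm.bias",
    "1.residual.1.ffn.hidden.weight",  # matched by nothing below
    "1.residual.1.ffn.hidden.bias",
]
patterns = [re.compile(r"bias"), re.compile(r"norm\.weight")]

# A name counts as initialized if any pattern is found anywhere in it
matched = {name for name in param_names if any(p.search(name) for p in patterns)}
uninitialized = sorted(set(param_names) - matched)
print(uninitialized)  # → ['0.weight', '1.residual.1.ffn.hidden.weight']
```

Note that `search` (substring matching) rather than `fullmatch` is what makes short regexes like `"bias"` hit every bias parameter in the tree.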

Now if we use the previously defined Transformer-like model as a base, we can create an initialized version. Create a new Jsonnet file named transformer-like-initialized.jsonnet with the following contents:

# Imports
local transformer = import "transformer-like.jsonnet";

function(vocab_size, log=false, warn_uninitialized=false)
  transformer(vocab_size, layers=1) {
    "type": "ModelWithInitializers",
    "initializers": [
      {"type": "Zeros", "regex": "bias"},
      {"type": "Ones", "regex": "norm.weight"},
      {"type": "XavierUniform", "regex": "ffn.hidden.weight", "gain": "linear"},
      {"type": "XavierUniform", "regex": "ffn.output.weight", "gain": "relu"}
    ],
    "log_initializers": log,
    "warn_uninitialized_params": warn_uninitialized,
  };
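The `transformer(vocab_size, layers=1) { ... }` syntax is Jsonnet object composition: fields in the right-hand object override those produced by the function call, while the remaining fields are inherited. Ignoring Jsonnet's richer inheritance rules (`+:` merging, `self` rebinding), the net effect here is roughly a shallow merge, which can be sketched in Python as:

```python
# Rough Python analogy, not how Jsonnet actually evaluates; the values are placeholders
base = {"type": "Sequential", "args": ["<embedding>", "<residual ffn block>"]}
override = {"type": "ModelWithInitializers", "log_initializers": True}

# Right-hand keys win, everything else is inherited from the base
merged = {**base, **override}
print(merged["type"])  # → ModelWithInitializers
print(merged["args"])  # the layer stack survives from the base object
```

This is what lets the initialized variant swap the top-level type to ModelWithInitializers while keeping the whole layer stack defined in transformer-like.jsonnet.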

Now we can load it similarly to the previous model, but this time the initializers will have been run:

import logging

from cresset.nn import Module

logging.basicConfig(level=logging.INFO)
model = Module.from_config(
    "transformer-like-initialized.jsonnet",
    bindings={"vocab_size": 65535, "log": True},
)

After running the above code, you should see output like this:

INFO:root:Running initializer Ones(regex=re.compile('norm.weight')) for 0.norm.weight
INFO:root:Running initializer Zeros(regex=re.compile('bias')) for 0.norm.bias
INFO:root:Running initializer XavierUniform(regex=re.compile('ffn.hidden.weight'), gain=1) for 0.residual.1.ffn.hidden.weight
INFO:root:Running initializer Zeros(regex=re.compile('bias')) for 0.residual.1.ffn.hidden.bias
INFO:root:Running initializer XavierUniform(regex=re.compile('ffn.output.weight'), gain=1.4142135623730951) for 0.residual.1.ffn.output.weight
INFO:root:Running initializer Zeros(regex=re.compile('bias')) for 0.residual.1.ffn.output.bias

License

This repository is licensed under the MIT license.

SPDX-License-Identifier: MIT

Download files

Download the file for your platform.

Source Distribution

cresset-0.1.1.tar.gz (9.9 kB)

Uploaded Source

Built Distribution

cresset-0.1.1-py3-none-any.whl (12.5 kB)

Uploaded Python 3

File details

Details for the file cresset-0.1.1.tar.gz.

File metadata

  • Download URL: cresset-0.1.1.tar.gz
  • Upload date:
  • Size: 9.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.13

File hashes

Hashes for cresset-0.1.1.tar.gz

  • SHA256: 1f4c1ca13a9eb3d9ebe478f6232a10155b0adf5c450178d2a69dd440f068b76b
  • MD5: 6d446e1c6c399176515419a8911dd5d2
  • BLAKE2b-256: 7c31e59072cadce5584e9d39241508a1683a2e5241044c2dccff8145ca18d4c7


File details

Details for the file cresset-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: cresset-0.1.1-py3-none-any.whl
  • Upload date:
  • Size: 12.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.2 CPython/3.10.13

File hashes

Hashes for cresset-0.1.1-py3-none-any.whl

  • SHA256: ab20c5b103ec2764e926a572cdfd0f94e088b43863275bc0291f5802082d3d15
  • MD5: 36e7a04c4ed6aab515dbd19a8135b8b0
  • BLAKE2b-256: 6cf832bcfb4264926233a3f89f999f5cf0870ca25742a02f02ef1464c091fd53

