
Logit-space logical activation functions for PyTorch

Project description

A PyTorch extension providing functions and classes for logit-space operators equivalent to the probabilistic Boolean logic gates AND, OR, and XNOR, assuming independent probabilities.
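As a sketch of the underlying math (function names here are illustrative, not the package API): for independent probabilities p1 = sigmoid(x) and p2 = sigmoid(y) given in logit space, the exact probability-space gates that the package's activation functions mirror can be written in plain Python:

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def logit(p):
    return math.log(p / (1.0 - p))


# Exact logit-space gates for independent probabilities p1 = sigmoid(x)
# and p2 = sigmoid(y). These helpers are illustrative only.
def and_gate(x, y):
    # P(A and B) = p1 * p2
    return logit(sigmoid(x) * sigmoid(y))


def or_gate(x, y):
    # P(A or B) = 1 - (1 - p1) * (1 - p2)
    return logit(1.0 - (1.0 - sigmoid(x)) * (1.0 - sigmoid(y)))


def xnor_gate(x, y):
    # P(A xnor B) = p1 * p2 + (1 - p1) * (1 - p2)
    p1, p2 = sigmoid(x), sigmoid(y)
    return logit(p1 * p2 + (1.0 - p1) * (1.0 - p2))


# e.g. p1 = 0.9, p2 = 0.8  ->  AND probability 0.72
p_and = sigmoid(and_gate(logit(0.9), logit(0.8)))
```

Each gate takes logits in and returns a logit out, which is what lets them be used as activation functions between linear layers.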

This provides the activation functions used in our paper:

SC Lowe, R Earle, J d’Eon, T Trappenberg, S Oore (2022). Logical Activation Functions: Logit-space equivalents of Probabilistic Boolean Operators. In Advances in Neural Information Processing Systems, volume 35. doi: 10.48550/arxiv.2110.11940.

For your convenience, we provide a copy of this citation in BibTeX format.
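Assembled from the citation above (author names abbreviated as given there), the entry might look like:

```bibtex
@inproceedings{lowe2022logical,
  author    = {Lowe, S. C. and Earle, R. and d'Eon, J. and Trappenberg, T. and Oore, S.},
  title     = {Logical Activation Functions: Logit-space equivalents of Probabilistic Boolean Operators},
  booktitle = {Advances in Neural Information Processing Systems},
  year      = {2022},
  doi       = {10.48550/arxiv.2110.11940},
}
```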

Example usage:

from pytorch_logit_logic import actfun_name2factory
from torch import nn


class MLP(nn.Module):
    """
    A multi-layer perceptron which supports higher-dimensional activations.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    n_layer : int, default=1
        Number of hidden layers.
    hidden_width : int, optional
        Pre-activation width. Default: same as ``in_channels``.
        Note that the actual pre-act width used may differ by rounding to
        the nearest integer that is divisible by the activation function's
        divisor.
    actfun : str, default="ReLU"
        Name of activation function to use.
    actfun_k : int, optional
        Dimensionality of the activation function. Default is the lowest
        ``k`` that the activation function supports, i.e. ``1`` for regular
        1D activation functions like ReLU, and ``2`` for GLU, MaxOut, and
        NAIL_OR.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        n_layer=1,
        hidden_width=None,
        actfun="ReLU",
        actfun_k=None,
    ):
        super().__init__()

        # Create a factory that generates objects that perform this activation
        actfun_factory = actfun_name2factory(actfun, k=actfun_k)
        # Get the divisor and space reduction factors for this activation
        # function. The pre-act needs to be divisible by the divisor, and
        # the activation will change the channel dimension by feature_factor.
        _actfun = actfun_factory()
        divisor = getattr(_actfun, "k", 1)
        feature_factor = getattr(_actfun, "feature_factor", 1)

        if hidden_width is None:
            hidden_width = in_channels

        # Ensure the hidden width is divisible by the divisor
        hidden_width = int(round(hidden_width / divisor)) * divisor

        layers = []
        n_current = in_channels
        for _ in range(n_layer):
            layer = []
            layer.append(nn.Linear(n_current, hidden_width))
            n_current = hidden_width
            layer.append(actfun_factory())
            n_current = int(round(n_current * feature_factor))
            layers.append(nn.Sequential(*layer))
        self.layers = nn.Sequential(*layers)
        self.classifier = nn.Linear(n_current, out_channels)

    def forward(self, x):
        x = self.layers(x)
        x = self.classifier(x)
        return x


model = MLP(
    in_channels=512,
    out_channels=10,
    n_layer=2,
    actfun="nail_or",
)
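To make the width bookkeeping in the example concrete, here is the same rounding arithmetic extracted into a standalone function (a sketch; the example values assume a 2-input activation such as NAIL_OR, which needs an even pre-activation width and halves the channel count):

```python
def layer_widths(hidden_width, n_layer, divisor=1, feature_factor=1.0):
    """Trace (pre-activation, post-activation) widths per hidden layer,
    mirroring the rounding logic in the MLP example above."""
    # Round the requested width to the nearest multiple of the divisor.
    hidden_width = int(round(hidden_width / divisor)) * divisor
    widths = []
    for _ in range(n_layer):
        pre = hidden_width                        # nn.Linear output width
        post = int(round(pre * feature_factor))   # width after the activation
        widths.append((pre, post))
    return widths


# A 2-input gate (e.g. NAIL_OR): pre-activation width must be even,
# and the activation halves the channel dimension.
print(layer_widths(500, 2, divisor=2, feature_factor=0.5))
# -> [(500, 250), (500, 250)]
```

An odd request such as `hidden_width=503` would be rounded to 504 first, so every `nn.Linear` output stays divisible by the activation's divisor.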

Download files

Download the file for your platform.

Source Distribution

pytorch-logit-logic-1.0.0.tar.gz (11.7 kB)

Uploaded: Source

Built Distribution

pytorch_logit_logic-1.0.0-py2.py3-none-any.whl (8.4 kB)

Uploaded: Python 2, Python 3

File details

Details for the file pytorch-logit-logic-1.0.0.tar.gz.

File metadata

  • Download URL: pytorch-logit-logic-1.0.0.tar.gz
  • Upload date:
  • Size: 11.7 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/4.0.1 CPython/3.10.6

File hashes

Hashes for pytorch-logit-logic-1.0.0.tar.gz
  • SHA256: ba8f210f51bdba556c87f4b4f27877c01be8495d370e0939b57a9acaa453a314
  • MD5: df7c73b2cb1cc62c459e7e8368be6fe8
  • BLAKE2b-256: 11ba1ffeeaeebb05dc1674c0b82a58fdd001e423e75b9a7ff2dc3495c4d018ee


File details

Details for the file pytorch_logit_logic-1.0.0-py2.py3-none-any.whl.

File hashes

Hashes for pytorch_logit_logic-1.0.0-py2.py3-none-any.whl
  • SHA256: bac30deb5342f82a7050682b38c331f3058b654e157e2e62de3ca65bbb3b8595
  • MD5: f008cc768e6f9359710285ff6118811f
  • BLAKE2b-256: 4a38ac55c73b36c83d71344e7c56a97ab0b6c6ecf448a905fa29fed087c19836

