PyTorch package for Structured Neural Networks.

Structured Neural Networks

Official implementation of the NeurIPS 2023 paper Structured Neural Networks for Density Estimation and Causal Inference.

Introduction

We introduce the Structured Neural Network (StrNN), a network architecture that enforces functional independence relationships between inputs and outputs via weight masking. This is an efficient way to inject prior assumed structure into arbitrary neural networks, which could lead to applications in improved density estimation, generative modelling, and causal inference.
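To make the idea of weight masking concrete, here is a minimal NumPy sketch of a single masked linear layer. It is an illustration only, not the package's implementation: the weights `W`, the input `x`, and the helper `masked_linear` are hypothetical, and we assume the convention that rows of the adjacency matrix index outputs and columns index inputs.

```python
import numpy as np

# Assumed convention: A[i, j] = 1 means output i may depend on input j.
A = np.array([
    [0, 0, 0, 0],
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [1, 0, 1, 0],
])

# Arbitrary dense weights for a hypothetical linear layer.
W = np.arange(1.0, 17.0).reshape(4, 4)


def masked_linear(x):
    """Linear layer whose weights are multiplied elementwise by the mask A."""
    return (W * A) @ x


x = np.array([1.0, 2.0, 3.0, 4.0])
y = masked_linear(x)

# Perturb input 1. Only output 2 lists input 1 as a parent in A,
# so only y[2] should change.
x_perturbed = x.copy()
x_perturbed[1] += 1.0
changed = ~np.isclose(masked_linear(x_perturbed), y)

print(changed.tolist())  # [False, False, True, False]
```

With more than one layer, a separate mask is needed per layer, and the product of the masks determines which input-to-output paths survive.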

Setup

The StrNN can be installed using pip install strnn.

If you are interested in reproducing our flow experiments, we instead recommend using an editable install. To do so, clone this repository and run pip install -e . from its root directory. The flow experiments in the paper can be reproduced using torch==2.0.1 with CUDA version 11.7.

Quick Start

The StrNN provides a drop-in replacement for a fully connected neural network, allowing it to respect prescribed functional independencies. For example, given an adjacency matrix $A$, we can initialize an StrNN as follows:

import numpy as np
import torch
from strnn.models.strNN import StrNN

A = np.array([
    [0, 0, 0, 0],
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [1, 0, 1, 0]
])

out_dim = A.shape[0]
in_dim = A.shape[1]
hid_dim = (50, 50)

strnn = StrNN(in_dim, hid_dim, out_dim, opt_type="greedy", adjacency=A)

x = torch.randn(in_dim)
y = strnn(x)

The most important parameter to the StrNN class here is opt_type, which specifies how the StrNN factorizes the adjacency matrix into weight masks, dictating which paths are zeroed out in the resulting masked neural network. We currently support three factorization options:

  1. Greedy: Algorithm 1 in the paper; an efficient greedy algorithm that approximates the optimization objective (Equation 2) but keeps zero reconstruction loss when it comes to the sparsity pattern of adjacency matrix $A$.
  2. Zuko: a similar approximate algorithm inspired by this repo.
  3. MADE: baseline algorithm from Germain et al. that only supports the fully autoregressive structure.
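To illustrate what it means for a factorization to keep zero reconstruction loss on the sparsity pattern of $A$, the sketch below builds two masks for a one-hidden-layer network and checks that their product has exactly the same zero pattern as $A$. This is a deliberately naive construction for illustration, not Algorithm 1 from the paper and not the code behind any opt_type; the hidden width `hidden` and the mask names `M1` and `M2` are assumptions.

```python
import numpy as np

A = np.array([
    [0, 0, 0, 0],
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [1, 0, 1, 0],
])
d = A.shape[0]
hidden = 8  # hypothetical width; must be >= d so every row of A is copied at least once

# Input-to-hidden mask: hidden unit k copies the parent set of output k % d.
M1 = A[np.arange(hidden) % d]  # shape (hidden, d)

# Hidden-to-output mask: output i may read hidden unit k only if every input
# feeding unit k is also an allowed parent of output i.
M2 = np.array([
    [int(np.all(M1[k] <= A[i])) for k in range(hidden)]
    for i in range(d)
])  # shape (d, hidden)

# Entry (i, j) of the mask product counts the paths from input j to output i.
paths = M2 @ M1
recovered = (paths > 0).astype(int)

print(np.array_equal(recovered, A))  # True
```

Any factorization whose mask product is zero exactly where $A$ is zero respects the prescribed independencies; the options above differ in how the masks are chosen.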

The StrNN can be used to improve density estimation capabilities when an adjacency structure is known for the task, for example when integrated into normalizing flows, as we show in the experiments below.

Examples

Scripts to reproduce all experiments from the paper can be found in the /experiments folder. Here we provide some code examples on how StrNN can be used.

Binary Density Estimation with StrNN

The StrNN can be directly used to perform density estimation on binary or Gaussian data, similar to how MADE was used in Germain et al. In fact, the StrNN density estimator class in models/strNNDensityEstimator.py can be used for density estimation whenever we have a closed parametric form in mind for the data generating process, although this is seldom the case in practice. The training script and sample experiment configs for this basic density estimator can be found under /experiments/binary_and_gaussian/. Use the following command to reproduce one of the binary experiments from our paper:

python run_binary_experiment.py --experiment_name binary_random_sparse_d15_n2000_StrNN_best --wandb_name strnn
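For intuition on the binary case, the sketch below computes the log-likelihood of binary vectors under a product of Bernoulli conditionals whose logits respect an adjacency matrix. It uses a single masked linear layer as a stand-in for the StrNN and does not use the density estimator class from this package; the weights `W`, bias `b`, and the function `log_prob` are hypothetical.

```python
import itertools

import numpy as np

A = np.array([
    [0, 0, 0, 0],
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [1, 0, 1, 0],
])
d = A.shape[0]

rng = np.random.default_rng(0)
W = rng.normal(size=(d, d)) * A  # masked weights: logit i only sees the parents of x_i
b = rng.normal(size=d)


def log_prob(x):
    """Return sum_i log Bernoulli(x_i | sigmoid(logit_i)) for x of shape (n, d)."""
    logits = x @ W.T + b
    # log sigmoid(l) = -logaddexp(0, -l) and log(1 - sigmoid(l)) = -logaddexp(0, l);
    # this form avoids overflow for large |l|.
    log_p_one = -np.logaddexp(0.0, -logits)
    log_p_zero = -np.logaddexp(0.0, logits)
    return (x * log_p_one + (1.0 - x) * log_p_zero).sum(axis=1)


# Because A is strictly lower triangular, the conditionals define a valid joint
# distribution, so the probabilities of all 2^d binary vectors sum to one.
all_x = np.array(list(itertools.product([0, 1], repeat=d)), dtype=float)
total = np.exp(log_prob(all_x)).sum()
```

Training such an estimator amounts to minimizing the negative of this log-likelihood over the data.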

Density Estimation with Structured Normalizing Flows

The StrNN can be integrated into normalizing flow architectures for more complex density estimation tasks. We insert the StrNN into autoregressive flows and continuous normalizing flows, where it replaces the feed-forward networks used to represent the conditioner and the flow dynamics, respectively, allowing them to respect known structure. An example of how to initialize these flow estimators is shown below. Baseline flows are also implemented.

import numpy as np
import torch

from strnn.models.discrete_flows import AutoregressiveFlowFactory
from strnn.models.continuous_flows import ContinuousFlowFactory

A = np.array([
    [0, 0, 0, 0],
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [1, 0, 1, 0]
])

af_config = {
    "input_dim": A.shape[0],
    "adjacency_matrix": A,
    "base_model": "ANF",
    "opt_type": "greedy",
    "opt_args": {},
    "conditioner_type": "strnn",
    "conditioner_hid_dim": [50, 50],
    "conditioner_act_type": "relu",
    "normalizer_type": "umnn",
    "umnn_int_solver": "CC",
    "umnn_int_step": 20,
    "umnn_int_hid_dim": [50, 50],
    "flow_steps": 10,
    "flow_permute": False,
    "n_param_per_var": 25
}

straf = AutoregressiveFlowFactory(af_config).build_flow()

x = torch.randn(1, A.shape[0])
z, jac = straf(x)
x_bar = straf.invert(z)


cnf_config = {...} # See model_config.yaml for values
strcnf = ContinuousFlowFactory(cnf_config).build_flow()
z, jac = strcnf(x)
x_bar = strcnf.invert(z)

See experiments/synthetic_multimodal/config/model_config.yaml for config values used in the paper experiments.
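To show how a structure-respecting conditioner interacts with the change of variables in an autoregressive flow, here is a self-contained NumPy sketch of a single flow step. It is a simplification under stated assumptions: it uses an affine transformer instead of the UMNN normalizer configured above, a masked linear map instead of an StrNN conditioner, and the names `Ws`, `Wt`, `forward`, and `invert` are hypothetical rather than part of this package.

```python
import numpy as np

A = np.array([
    [0, 0, 0, 0],
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [1, 0, 1, 0],
])
d = A.shape[0]

rng = np.random.default_rng(0)
# Hypothetical masked linear conditioner: row i yields the log-scale and shift
# applied to x_i, computed from the parents of x_i only.
Ws = 0.1 * rng.normal(size=(d, d)) * A
Wt = rng.normal(size=(d, d)) * A


def forward(x):
    """Map x to z and return log|det(dz/dx)|."""
    s = Ws @ x
    t = Wt @ x
    z = x * np.exp(s) + t
    # A is strictly lower triangular, so the Jacobian is lower triangular
    # with diagonal exp(s), and its log-determinant is sum(s).
    return z, s.sum()


def invert(z):
    """Recover x from z one coordinate at a time."""
    x = np.zeros_like(z)
    for i in range(d):
        # Parents of x_i have smaller indices, so they are already recovered.
        s_i = Ws[i] @ x
        t_i = Wt[i] @ x
        x[i] = (z[i] - t_i) * np.exp(-s_i)
    return x


x = rng.normal(size=d)
z, log_det = forward(x)
x_bar = invert(z)
```

The masking means each coordinate is transformed using only its parents in $A$, rather than all preceding coordinates as in a fully autoregressive flow.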

The flow experiments from the paper are reproducible using experiments/synthetic_multimodal/run_experiment_mm.py. After performing an editable install of this repository, we can run:

python run_experiment_mm.py --dataset_name multimodal --model_config straf_best --wandb_name strnn --model_seed 2541  --lr 1e-3 --scheduler plateau

Hyperparameter values and grids are available in the paper appendix. Note that we must modify the adjacency matrix to include the main diagonal for the StrCNF to work well. The confidence intervals are generated using the model seeds [2541, 2547, 412, 411, 321, 3431, 4273, 2526].
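One way to add the main diagonal to an adjacency matrix is sketched below, assuming the adjacency is a 0/1 NumPy array as in the examples above. The variable name `A_cnf` is hypothetical, and the experiment scripts may handle this step differently.

```python
import numpy as np

A = np.array([
    [0, 0, 0, 0],
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [1, 0, 1, 0],
])

# Set every diagonal entry to 1 while leaving off-diagonal entries untouched.
A_cnf = np.maximum(A, np.eye(A.shape[0], dtype=A.dtype))
```

Using np.maximum rather than addition keeps the matrix binary even if some diagonal entries were already 1.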

Citation

Please use the following citation if you use this code or methods in your own work.

@inproceedings{
    chen2023structured,
    title = {Structured Neural Networks for Density Estimation and Causal Inference},
    author = {Asic Q Chen and Ruian Shi and Xiang Gao and Ricardo Baptista and Rahul G Krishnan},
    booktitle = {Thirty-seventh Conference on Neural Information Processing Systems},
    year = {2023},
    url = {https://arxiv.org/abs/2311.02221}
}
