# Poisson Identifiable VAE (pi-VAE)
This is a PyTorch implementation of Poisson Identifiable VAE (pi-VAE), used to construct latent variable models of neural activity while simultaneously modeling the relation between the latent and task variables (non-neural variables, e.g. sensory, motor, and other externally observable states).

The original implementation by Dr. Ding Zhou and Dr. Xue-Xin Wei, in TensorFlow 1.13, is available here.

Another PyTorch implementation, by Dr. Lyndon Duong, is available here.
## Install

```sh
pip install pi-vae-pytorch
```
## Usage

```python
import torch
from pi_vae_pytorch import PiVAE

model = PiVAE(
    x_dim=100,
    u_dim=3,
    z_dim=2,
    discrete_labels=False,
)

x = torch.randn(1, 100)  # Size([n_samples, x_dim])
u = torch.randn(1, 3)    # Size([n_samples, u_dim])

outputs = model(x, u)  # dict
```
## Parameters
- `x_dim` (`int`): Dimension of observation `x`.
- `u_dim` (`int`): Dimension of label `u`.
- `z_dim` (`int`): Dimension of latent `z`.
- `discrete_labels` (`bool`, default: `True`): Flag denoting `u`'s label type: `True` for discrete labels, `False` for continuous labels.
- `encoder_n_hidden_layers` (`int`, default: `2`): Number of hidden layers in the MLP of the model's encoder.
- `encoder_hidden_layer_dim` (`int`, default: `120`): Dimensionality of each hidden layer in the MLP of the model's encoder.
- `encoder_hidden_layer_activation` (`nn.Module`, default: `nn.Tanh`): Activation function applied to the outputs of each hidden layer in the MLP of the model's encoder.
- `decoder_n_gin_blocks` (`int`, default: `2`): Number of GIN blocks used within the model's decoder.
- `decoder_gin_block_depth` (`int`, default: `2`): Number of AffineCouplingLayers which comprise each GIN block.
- `decoder_affine_input_layer_slice_dim` (`int`, default: `None`, which corresponds to `x_dim / 2`): Index at which to split an n-dimensional input `x`.
- `decoder_affine_n_hidden_layers` (`int`, default: `2`): Number of hidden layers in the MLP of each AffineCouplingLayer.
- `decoder_affine_hidden_layer_dim` (`int`, default: `None`, which corresponds to `x_dim / 4`): Dimensionality of each hidden layer in the MLP of each AffineCouplingLayer.
- `decoder_affine_hidden_layer_activation` (`nn.Module`, default: `nn.ReLU`): Activation function applied to the outputs of each hidden layer in the MLP of each AffineCouplingLayer.
- `decoder_nflow_n_hidden_layers` (`int`, default: `2`): Number of hidden layers in the MLP of the decoder's NFlowLayer.
- `decoder_nflow_hidden_layer_dim` (`int`, default: `None`, which corresponds to `x_dim / 4`): Dimensionality of each hidden layer in the MLP of the decoder's NFlowLayer.
- `decoder_nflow_hidden_layer_activation` (`nn.Module`, default: `nn.ReLU`): Activation function applied to the outputs of each hidden layer in the MLP of the decoder's NFlowLayer.
- `decoder_observation_model` (`str`, default: `"poisson"`, one of `"gaussian"` or `"poisson"`): Observation model used by the model's decoder.
- `decoder_fr_clamp_min` (`float`, default: `1e-7`, only applied when `decoder_observation_model="poisson"`): Minimum threshold used when clamping decoded firing rates.
- `decoder_fr_clamp_max` (`float`, default: `1e7`, only applied when `decoder_observation_model="poisson"`): Maximum threshold used when clamping decoded firing rates.
- `z_prior_n_hidden_layers` (`int`, default: `2`, only applied when `discrete_labels=False`): Number of hidden layers in the MLP of the ZPriorContinuous module.
- `z_prior_hidden_layer_dim` (`int`, default: `20`, only applied when `discrete_labels=False`): Dimensionality of each hidden layer in the MLP of the ZPriorContinuous module.
- `z_prior_hidden_layer_activation` (`nn.Module`, default: `nn.Tanh`, only applied when `discrete_labels=False`): Activation function applied to the outputs of each hidden layer in the MLP of the ZPriorContinuous module.
## Returns

A dictionary with the following items.
- `firing_rate` (`Tensor`, Size([n_samples, x_dim])): Predicted firing rates of `z_sample`.
- `lambda_mean` (`Tensor`, Size([n_samples, z_dim])): Mean for each sample using label prior p(z | u).
- `lambda_log_variance` (`Tensor`, Size([n_samples, z_dim])): Log of variance for each sample using label prior p(z | u).
- `posterior_mean` (`Tensor`, Size([n_samples, z_dim])): Mean for each sample using full posterior of q(z | x,u) ~ q(z | x) × p(z | u).
- `posterior_log_variance` (`Tensor`, Size([n_samples, z_dim])): Log of variance for each sample using full posterior of q(z | x,u) ~ q(z | x) × p(z | u).
- `z_mean` (`Tensor`, Size([n_samples, z_dim])): Mean for each sample using approximation of q(z | x).
- `z_log_variance` (`Tensor`, Size([n_samples, z_dim])): Log of variance for each sample using approximation of q(z | x).
- `z_sample` (`Tensor`, Size([n_samples, z_dim])): Generated latents `z`.
## Loss Function
### Poisson observation model

```python
from pi_vae_pytorch.utils import compute_loss

outputs = model(x, u)  # Initialized with decoder_observation_model="poisson"

loss = compute_loss(
    x=x,
    firing_rate=outputs["firing_rate"],
    lambda_mean=outputs["lambda_mean"],
    lambda_log_variance=outputs["lambda_log_variance"],
    posterior_mean=outputs["posterior_mean"],
    posterior_log_variance=outputs["posterior_log_variance"],
    observation_model=model.decoder_observation_model,
)

loss.backward()
```
### Gaussian observation model

```python
from pi_vae_pytorch.utils import compute_loss

outputs = model(x, u)  # Initialized with decoder_observation_model="gaussian"

loss = compute_loss(
    x=x,
    firing_rate=outputs["firing_rate"],
    lambda_mean=outputs["lambda_mean"],
    lambda_log_variance=outputs["lambda_log_variance"],
    posterior_mean=outputs["posterior_mean"],
    posterior_log_variance=outputs["posterior_log_variance"],
    observation_model=model.decoder_observation_model,
    observation_noise_model=model.observation_noise_model,
)

loss.backward()
```
### Parameters
- `x` (`Tensor`, Size([n_samples, x_dim])): Observations `x`.
- `firing_rate` (`Tensor`, Size([n_samples, x_dim])): Predicted firing rate of generated latent `z`.
- `lambda_mean` (`Tensor`, Size([n_samples, z_dim])): Means from label prior p(z | u).
- `lambda_log_variance` (`Tensor`, Size([n_samples, z_dim])): Log of variances from label prior p(z | u).
- `posterior_mean` (`Tensor`, Size([n_samples, z_dim])): Means from full posterior of q(z | x,u) ~ q(z | x) × p(z | u).
- `posterior_log_variance` (`Tensor`, Size([n_samples, z_dim])): Log of variances from full posterior of q(z | x,u) ~ q(z | x) × p(z | u).
- `observation_model` (`str`, one of `"poisson"` or `"gaussian"`): The observation model used by pi-VAE's decoder. Should use the same value passed to `decoder_observation_model` when initializing `PiVAE`.
- `observation_noise_model` (`nn.Module`, default: `None`, only applied when `observation_model="gaussian"`): The noise model used when pi-VAE's decoder utilizes a Gaussian observation model. When `PiVAE` is initialized with `decoder_observation_model="gaussian"`, the model's `observation_noise_model` attribute can be used.
## Citation

```bibtex
@misc{zhou2020learning,
    title={Learning identifiable and interpretable latent models of high-dimensional neural activity using pi-VAE},
    author={Ding Zhou and Xue-Xin Wei},
    year={2020},
    eprint={2011.04798},
    archivePrefix={arXiv},
    primaryClass={stat.ML}
}
```