Statistical distributions for structural biology

rs-distributions

rs-distributions provides statistical tools for structural biologists who model their data using variational inference.

Installation

pip install rs-distributions

Distributions

rs_distributions.distributions provides learnable distributions that are important in structural biology. These distributions follow the conventions in torch.distributions. Here's a small example of distribution matching between a learnable distribution, q, and a target distribution, p. The example works by minimizing the Kullback-Leibler divergence between q and p using gradients calculated by the implicit reparameterization method.

import torch
from rs_distributions import distributions as rsd

target_loc = 4.
target_scale = 2.

loc_initial_guess = 10.
scale_initial_guess = 3.

loc = torch.tensor(loc_initial_guess, requires_grad=True)

scale_transform = torch.distributions.transform_to(
    rsd.FoldedNormal.arg_constraints['scale']
)
# Map the initial scale to unconstrained space, then make it a learnable leaf tensor
unconstrained_scale = scale_transform.inv(
    torch.tensor(scale_initial_guess)
).detach().requires_grad_(True)

p = rsd.FoldedNormal(
    target_loc,
    target_scale,
)

opt = torch.optim.Adam([loc, unconstrained_scale])

steps = 10_000
num_samples = 100
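for i in range(steps):
    opt.zero_grad()
    # Map the unconstrained parameter back to a positive scale
    scale = scale_transform(unconstrained_scale)
    q = rsd.FoldedNormal(loc, scale)
    # rsample keeps the samples differentiable with respect to loc and scale
    z = q.rsample((num_samples,))
    # Monte Carlo estimate of KL(q || p)
    kl_div = (q.log_prob(z) - p.log_prob(z)).mean()
    kl_div.backward()
    opt.step()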

This example uses the folded normal distribution, which is important in X-ray crystallography.
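
For readers unfamiliar with the folded normal, it is the distribution of |x| when x is normally distributed. The following is a minimal sketch of its log-density and sampling written with plain torch.distributions.Normal. The helper names folded_normal_log_prob and folded_normal_sample are hypothetical and are not part of rs-distributions, whose own implementation may differ (for example, in how it provides reparameterized samples).

```python
import torch
from torch.distributions import Normal


def folded_normal_log_prob(y, loc, scale):
    """Log-density of |x| for x ~ Normal(loc, scale), evaluated at y >= 0."""
    base = Normal(loc, scale)
    # The density at y is the sum of the normal densities at +y and -y.
    return torch.logaddexp(base.log_prob(y), base.log_prob(-y))


def folded_normal_sample(loc, scale, sample_shape=()):
    """Draw (non-differentiable) samples as absolute values of normal draws."""
    return Normal(loc, scale).sample(sample_shape).abs()


loc = torch.tensor(4.0, dtype=torch.float64)
scale = torch.tensor(2.0, dtype=torch.float64)

# Numerical sanity check: integrate the density over a wide grid on [0, 40].
y = torch.linspace(0.0, 40.0, 20001, dtype=torch.float64)
density = folded_normal_log_prob(y, loc, scale).exp()
total_mass = torch.trapezoid(density, y).item()
```

Here total_mass should come out very close to 1, which is a quick way to confirm that the two-term density is normalized on the non-negative half-line.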

Modules

Working with PyTorch distributions can be a little verbose, so in addition to the torch.distributions-style implementations, we provide DistributionModule classes that enable learnable distributions with automatic bijections in less code. These DistributionModule classes are subclasses of torch.nn.Module. They automatically instantiate problem parameters as TransformedParameter modules following the constraints in the distribution definition. In the following example, a FoldedNormal DistributionModule is instantiated with an initial location and scale and trained to match a target distribution.

from rs_distributions import modules as rsm
import torch

loc_init = 10.
scale_init = 5.

q = rsm.FoldedNormal(loc_init, scale_init)
p = torch.distributions.HalfNormal(1.)

opt = torch.optim.Adam(q.parameters())

steps = 10_000
num_samples = 256
for i in range(steps):
    opt.zero_grad()
    z = q.rsample((num_samples,))
    kl = (q.log_prob(z) - p.log_prob(z)).mean()
    kl.backward()
    opt.step()
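
To illustrate the idea behind a transformed parameter, here is a rough sketch of a module that stores an unconstrained tensor and maps it through a bijection whenever it is read. The class name ConstrainedParameter and its attributes are hypothetical and chosen for illustration only. This is not the TransformedParameter class shipped with rs-distributions, whose interface may differ.

```python
import torch
from torch import nn
from torch.distributions import constraints, transform_to


class ConstrainedParameter(nn.Module):
    """Hypothetical stand-in for a constrained learnable parameter.

    Stores an unconstrained tensor and maps it through a transform on access,
    so optimizer updates cannot push the value outside the constraint.
    """

    def __init__(self, init_value, constraint=constraints.positive):
        super().__init__()
        self.transform = transform_to(constraint)
        init = torch.as_tensor(init_value, dtype=torch.float32)
        # Invert the transform once so the constrained value starts at init_value.
        self.raw = nn.Parameter(self.transform.inv(init))

    def forward(self):
        return self.transform(self.raw)


scale = ConstrainedParameter(5.0)
opt = torch.optim.SGD(scale.parameters(), lr=1.0)

# Take one gradient step on the unconstrained value;
# the constrained output remains positive afterwards.
opt.zero_grad()
scale().backward()
opt.step()
```

Because the optimizer only ever sees the unconstrained tensor, no clamping or projection step is needed to keep the scale valid.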

License

rs-distributions is distributed under the terms of the MIT license.