PyTorch implementation of the Ricciardi transfer function.

About

An efficient, GPU-friendly, and differentiable PyTorch implementation of the Ricciardi transfer function based on equations and default parameters from Sanzeni et al. (2020).

Figure: plot of the Ricciardi transfer function.

Usage

To use the ricciardi function in your own code, either copy the source file at src/ricciardi/ricciardi.py into your project, or install the package in your Python environment with pip install ricciardi and import the function with from ricciardi import ricciardi. To run the tests, clone the repository, create a new environment, install the necessary packages with pip install -r requirements, and run pytest.

Implementation

The Ricciardi transfer function, in the notation of Sanzeni et al. (2020), is given by

$$ f(\mu) = \left[\tau_{rp} + \tau\sqrt{\pi}\int_{u_\mathrm{min}(\mu)}^{u_\mathrm{max}(\mu)}e^{u^2}(1+\mathrm{erf}(u)) du\right]^{-1} $$

where

$$ u_\mathrm{max}(\mu) = \frac{\theta - \mu}{\sigma}, u_\mathrm{min}(\mu) = \frac{V_r - \mu}{\sigma} $$

The integral can be written in terms of the hypergeometric function ${}_2F_2$. However, there is currently no implementation of this hypergeometric function that is performant enough for large neural network simulations. We therefore compute the integral directly with a fixed-order Gauss-Legendre quadrature rule. We find that an order-5 quadrature (five nodes per evaluation) is sufficient to obtain good numerical accuracy for realistic parameter regimes.
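As a rough illustration of the quadrature idea, the following standard-library sketch applies a 5-point Gauss-Legendre rule to the integrand $e^{u^2}(1+\mathrm{erf}(u))$ and compares it with a fine composite Simpson reference. The helper names (`gauss_legendre_5`, `integrate`, `simpson`) and the parameter values are assumptions made for this example, not the package's code or defaults.

```python
import math

def gauss_legendre_5():
    """Nodes and weights of the 5-point Gauss-Legendre rule on [-1, 1] (closed form)."""
    inner = math.sqrt(5 - 2 * math.sqrt(10 / 7)) / 3
    outer = math.sqrt(5 + 2 * math.sqrt(10 / 7)) / 3
    w_inner = (322 + 13 * math.sqrt(70)) / 900
    w_outer = (322 - 13 * math.sqrt(70)) / 900
    nodes = [-outer, -inner, 0.0, inner, outer]
    weights = [w_outer, w_inner, 128 / 225, w_inner, w_outer]
    return nodes, weights

def integrate(f, a, b):
    """Approximate the integral of f over [a, b] with 5 function evaluations."""
    nodes, weights = gauss_legendre_5()
    half = 0.5 * (b - a)   # scale factor from [-1, 1] to [a, b]
    mid = 0.5 * (a + b)    # shift from [-1, 1] to [a, b]
    return half * sum(w * f(mid + half * x) for x, w in zip(nodes, weights))

def integrand(u):
    """Direct form of the integrand; fine for moderate u (see the note on large negative u)."""
    return math.exp(u * u) * (1.0 + math.erf(u))

def simpson(f, a, b, n=2000):
    """Composite Simpson rule with n (even) panels, used only as a reference value."""
    h = (b - a) / n
    total = f(a) + f(b)
    for i in range(1, n):
        total += (4 if i % 2 else 2) * f(a + i * h)
    return total * h / 3

# Illustrative parameter values (assumed for this example, in volts).
theta, v_r, sigma, mu = 0.02, 0.01, 0.01, 0.015
u_min = (v_r - mu) / sigma
u_max = (theta - mu) / sigma

approx = integrate(integrand, u_min, u_max)
reference = simpson(integrand, u_min, u_max)
print(f"Gauss-Legendre (5 nodes): {approx:.10f}")
print(f"Simpson (2000 panels):    {reference:.10f}")
```

A fixed node count keeps the cost per input constant and maps naturally onto batched tensor operations, which is one plausible reason a fixed-order rule suits large simulations.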

A note on the computation of the integral

Direct computation of $e^{x^2}(1 + \mathrm{erf}(x))$ results in numerical issues for large, negative $x$, since the first factor is huge while the second factor is tiny. To address this, we note that $1 + \mathrm{erf}(x) = 1 - \mathrm{erf}(-x)$, so after the change of variables $u \to -u$ (which negates and swaps the integration limits) we can rewrite the transfer function as

$$ f(\mu) = \left[\tau_{rp} + \tau\sqrt{\pi}\int_{-u_\mathrm{max}(\mu)}^{-u_\mathrm{min}(\mu)} \mathrm{erfcx}(u) du\right]^{-1} $$

where $\mathrm{erfcx}$ is the scaled complementary error function defined by

$$ \mathrm{erfcx}(x) = e^{x^2}(1 - \mathrm{erf}(x)) $$

$\mathrm{erfcx}$ is a native PyTorch function which has high precision for a wide range of inputs, so by using it we avoid the numerical issue mentioned above.
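The reformulation above can be sketched in PyTorch as follows. This is a minimal illustration, not the package's source: `gauss_legendre_5` and `ricciardi_sketch` are hypothetical names, and the default parameter values are assumptions chosen for the example. It first shows the failure of the direct formula at $x = -10$ and then evaluates the erfcx-based integral with a 5-point Gauss-Legendre rule.

```python
import math
import torch

def gauss_legendre_5(dtype, device):
    """5-point Gauss-Legendre nodes and weights on [-1, 1] (closed form)."""
    inner = math.sqrt(5 - 2 * math.sqrt(10 / 7)) / 3
    outer = math.sqrt(5 + 2 * math.sqrt(10 / 7)) / 3
    w_inner = (322 + 13 * math.sqrt(70)) / 900
    w_outer = (322 - 13 * math.sqrt(70)) / 900
    nodes = torch.tensor([-outer, -inner, 0.0, inner, outer], dtype=dtype, device=device)
    weights = torch.tensor([w_outer, w_inner, 128 / 225, w_inner, w_outer],
                           dtype=dtype, device=device)
    return nodes, weights

def ricciardi_sketch(mu, tau=0.02, tau_rp=0.002, sigma=0.01, v_r=0.01, theta=0.02):
    """Transfer function via the erfcx form; parameter defaults are assumed values."""
    nodes, weights = gauss_legendre_5(mu.dtype, mu.device)
    lo = (mu - theta) / sigma   # equals -u_max(mu)
    hi = (mu - v_r) / sigma     # equals -u_min(mu)
    half = 0.5 * (hi - lo)
    mid = 0.5 * (hi + lo)
    # Map the reference nodes onto [lo, hi] for every element of mu at once.
    u = mid.unsqueeze(-1) + half.unsqueeze(-1) * nodes
    integral = half * (torch.special.erfcx(u) * weights).sum(dim=-1)
    return 1.0 / (tau_rp + tau * math.sqrt(math.pi) * integral)

# Direct formula vs. erfcx at a large negative argument (float64).
x = torch.tensor(-10.0, dtype=torch.float64)
naive = torch.exp(x ** 2) * (1 + torch.erf(x))   # 1 + erf(-10) rounds to zero
stable = torch.special.erfcx(-x)                 # same quantity, computed stably
print("naive: ", naive.item())
print("stable:", stable.item())

# The sketch is differentiable through autograd.
mu = torch.linspace(-0.02, 0.04, 7, dtype=torch.float64, requires_grad=True)
rates = ricciardi_sketch(mu)
rates.sum().backward()
print("rates:    ", rates.detach())
print("d rate/d mu:", mu.grad)
```

Because every step is an elementwise or broadcast tensor operation, the same function runs unchanged on a GPU tensor and backpropagates without custom gradient code.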

Benchmark

We compare performance against a naive, linear interpolation-based approach (ricciardi_interp in the results below). The forward pass is slightly faster, and the backward pass is much faster (more than 2x on GPU).

Results on CPU (AMD EPYC 7662, 8 cores) (python benchmark/benchmark.py -N 100000 -r 100):

forward pass, requires_grad=False
ricciardi: median=1.81 ms, min=1.79 ms (100 repeats)
ricciardi_interp: median=1.91 ms, min=1.9 ms (100 repeats)

forward pass, requires_grad=True
ricciardi: median=1.8 ms, min=1.79 ms (100 repeats)
ricciardi_interp: median=2.11 ms, min=1.98 ms (100 repeats)

backward pass
ricciardi: median=786 μs, min=765 μs (100 repeats)
ricciardi_interp: median=1.17 ms, min=1.09 ms (100 repeats)

Results on GPU (Nvidia A40) (python benchmark/benchmark.py -N 100000 -r 100 --device cuda):

forward pass, requires_grad=False
ricciardi: median=451 μs, min=441 μs (100 repeats)
ricciardi_interp: median=455 μs, min=448 μs (100 repeats)

forward pass, requires_grad=True
ricciardi: median=478 μs, min=470 μs (100 repeats)
ricciardi_interp: median=523 μs, min=513 μs (100 repeats)

backward pass
ricciardi: median=486 μs, min=475 μs (100 repeats)
ricciardi_interp: median=1.1 ms, min=1.08 ms (100 repeats)
