PyTorch implementation of the Ricciardi transfer function.
About
An efficient, GPU-friendly, and differentiable PyTorch implementation of the Ricciardi transfer function based on equations and default parameters from Sanzeni et al. (2020).
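For orientation, the Ricciardi transfer function gives the stationary firing rate of a leaky integrate-and-fire neuron driven by white noise: rate = 1 / (tau_rp + tau * sqrt(pi) * I), where I is the integral of exp(u^2) * (1 + erf(u)) from (V_r - mu) / sigma to (theta - mu) / sigma. The sketch below evaluates this formula by brute-force quadrature in plain Python. It is not the package's implementation: the name ricciardi_reference and the parameter values (given in volts and seconds) are assumptions made for illustration and should be checked against Sanzeni et al. (2020).

```python
import math


def ricciardi_reference(mu, tau=0.02, tau_rp=0.002, sigma=0.01,
                        v_reset=0.01, theta=0.02, n_steps=2000):
    """Stationary LIF firing rate (Hz) for a mean input mu (volts).

    Illustrative reference only. The default parameter values are
    assumptions, not read from the package source.
    """
    u_min = (v_reset - mu) / sigma
    u_max = (theta - mu) / sigma

    def integrand(u):
        # exp(u^2) * (1 + erf(u)), written with erfc to limit cancellation
        return math.exp(u * u) * math.erfc(-u)

    # Composite Simpson's rule; n_steps must be even.
    h = (u_max - u_min) / n_steps
    total = integrand(u_min) + integrand(u_max)
    for i in range(1, n_steps):
        total += (4 if i % 2 else 2) * integrand(u_min + i * h)
    integral = total * h / 3

    return 1.0 / (tau_rp + tau * math.sqrt(math.pi) * integral)


if __name__ == "__main__":
    for mu in (0.0, 0.015, 0.03):
        print(f"mu={mu:.3f} V -> {ricciardi_reference(mu):.2f} Hz")
```

This direct quadrature is only suitable for moderate inputs, since exp(u * u) overflows for large |u|, and it provides no gradients. It is meant as a slow cross-check for single values, not as a replacement for the PyTorch function.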
Usage
To use the ricciardi function in your own code, either copy the source file at `src/ricciardi/ricciardi.py` into your project, or install the package in your Python environment with `pip install ricciardi` and import the function with `from ricciardi import ricciardi`. To run the tests, clone the repository, create a new environment, install the necessary packages with `pip install -r requirements`, and run `pytest`.
Benchmark
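Performance is compared against an interpolation-based approach (listed as ricciardi_interp below). The forward pass is slightly slower (for example, a median of 1.86 ms vs. 1.75 ms on CPU), but the backward pass is faster: about 1.4x on CPU (814 μs vs. 1.17 ms) and more than 2x on GPU (463 μs vs. 1.11 ms).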
Results on CPU (AMD EPYC 7662, 8 cores), from `python benchmark/benchmark.py -N 100000 -r 100`:
forward pass, requires_grad=False
ricciardi: median=1.86 ms, min=1.84 ms (100 repeats)
ricciardi_interp: median=1.75 ms, min=1.72 ms (100 repeats)
forward pass, requires_grad=True
ricciardi: median=1.94 ms, min=1.9 ms (100 repeats)
ricciardi_interp: median=1.92 ms, min=1.75 ms (100 repeats)
backward pass
ricciardi: median=814 μs, min=796 μs (100 repeats)
ricciardi_interp: median=1.17 ms, min=1.15 ms (100 repeats)
Results on GPU (Nvidia A40), from `python benchmark/benchmark.py -N 100000 -r 100 --device cuda`:
forward pass, requires_grad=False
ricciardi: median=517 μs, min=508 μs (100 repeats)
ricciardi_interp: median=460 μs, min=453 μs (100 repeats)
forward pass, requires_grad=True
ricciardi: median=556 μs, min=549 μs (100 repeats)
ricciardi_interp: median=527 μs, min=520 μs (100 repeats)
backward pass
ricciardi: median=463 μs, min=364 μs (100 repeats)
ricciardi_interp: median=1.11 ms, min=1.09 ms (100 repeats)
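As a rough illustration of how median and min figures over a fixed number of repeats can be collected, here is a minimal timing sketch using only the standard library. It is not the repository's benchmark/benchmark.py; the helper names time_repeats and format_time are hypothetical, and the timed workload is a stand-in.

```python
import statistics
import time


def time_repeats(fn, repeats=100):
    """Call fn() `repeats` times and return (median, min) wall time in seconds."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return statistics.median(durations), min(durations)


def format_time(seconds):
    """Render a duration in ms or μs with three significant digits."""
    if seconds >= 1e-3:
        return f"{seconds * 1e3:.3g} ms"
    return f"{seconds * 1e6:.3g} μs"


if __name__ == "__main__":
    # Stand-in workload; a real benchmark would call the function under test.
    median, fastest = time_repeats(lambda: sum(range(10_000)), repeats=100)
    print(f"median={format_time(median)}, min={format_time(fastest)} (100 repeats)")
```

On a CUDA device, kernels are launched asynchronously, so a harness like this would also need to synchronize (for example with torch.cuda.synchronize()) before reading the clock; otherwise it mostly measures launch overhead.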
Hashes for ricciardi-0.1.1-py3-none-any.whl
Algorithm | Hash digest
---|---
SHA256 | ee63f396f90f930f3129b9239a9a60ec0682254812ac8a9f84451f71f030dd3d
MD5 | e5db212509273f1ed8e7b7d35d5cffc4
BLAKE2b-256 | 331ecf6e826a689f00c471a7b9f91bdefdc244fb0d28d6939fbb56b813883fc2