GPU-accelerated imaginary error function for real PyTorch tensors

Project description

erfi-pytorch

erfi-pytorch provides a forward-only imaginary error function for real PyTorch tensors:

import torch
from erfi_pytorch import erfi

x = torch.linspace(-4, 4, 1_000_000, device="cuda")
y = erfi(x)

The package supports torch.float32 and torch.float64 and preserves tensor shape, dtype, and device. Its pure-PyTorch graph is compatible with torch.compile(fullgraph=True, backend="eager"). Inductor compilation depends on a working platform compiler or Triton installation and is validated separately on supported Linux CUDA environments.

Backends

  • Pure PyTorch: portable CPU and CUDA implementation.
  • Triton: fused path for contiguous NVIDIA CUDA tensors with at least 65,536 elements, when Triton is available.

Windows and systems without Triton automatically use the pure-PyTorch path. No CUDA toolkit or native compiler is required.
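The backend choice described above reduces to a small predicate. The sketch below is a hypothetical reconstruction from the listed conditions, not erfi-pytorch's actual dispatcher. The names `select_backend`, `triton_available`, and `TRITON_MIN_ELEMENTS` are illustrative. The predicate takes plain values so it runs without PyTorch. With a tensor `x`, you would pass `x.device.type`, `x.is_contiguous()`, and `x.numel()`.

```python
import importlib.util
import sys

# Element-count threshold for the Triton path, as stated in this README.
TRITON_MIN_ELEMENTS = 65_536


def triton_available() -> bool:
    """Hypothetical helper: True if Triton is importable and we are not on Windows."""
    if sys.platform == "win32":
        # Windows always uses the pure-PyTorch path.
        return False
    return importlib.util.find_spec("triton") is not None


def select_backend(device_type: str, is_contiguous: bool, numel: int,
                   has_triton: bool) -> str:
    """Hypothetical dispatch rule mirroring the Backends list above."""
    if (has_triton
            and device_type == "cuda"
            and is_contiguous
            and numel >= TRITON_MIN_ELEMENTS):
        return "triton"
    # CPU tensors, small or non-contiguous CUDA tensors, and Triton-less systems.
    return "pytorch"
```

ROCm builds of PyTorch also report the device type as `"cuda"`. A check restricted to NVIDIA hardware would therefore need an extra condition, such as `torch.version.hip is None`.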

Installation

pip install erfi-pytorch

For development and reference tests (from a clone of the repository):

pip install -e ".[test]"

On Linux, install the optional Triton dependency if it is not already provided by your PyTorch installation:

pip install -e ".[test,triton]"

Numerical method

For real x, the implementation uses

erfi(x) = exp(x^2) Im(w(x)),

where w is the Faddeeva function. Im(w(x)) is evaluated with a Taylor polynomial near zero and a 100-interval table of low-degree polynomial approximations elsewhere. Near floating-point overflow, the final magnitude is reconstructed in the log domain so representable results are not lost to premature overflow in exp(x^2).
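For real x, Im(w(x)) = (2/√π)·D(x), where D is Dawson's function. This makes it possible to sketch each ingredient of the method in float64 using only the Python standard library. This is a minimal illustration, not the package's kernel, and all names are hypothetical:

- `im_w_taylor` sums the Taylor series of Im(w(x)). It is usable only near zero, because its alternating terms cancel badly as |x| grows.
- `erfi_series` is a reference series used to check the identity.
- `im_w_asymptotic` is a four-term large-x expansion. It stands in for the 100-interval table, which this sketch does not reproduce.
- `erfi_log_domain` shows the overflow-safe reconstruction. It forms x² + log Im(w(|x|)) and exponentiates once, so the result stays finite for some x where exp(x²) alone already exceeds the float64 range.

```python
import math
import sys

TWO_OVER_SQRT_PI = 2.0 / math.sqrt(math.pi)
INV_SQRT_PI = 1.0 / math.sqrt(math.pi)
LOG_DBL_MAX = math.log(sys.float_info.max)  # ~709.78: exp() overflows beyond this


def im_w_taylor(x: float, terms: int = 40) -> float:
    """Im w(x) = (2/sqrt(pi)) * D(x), with D(x) = sum (-1)^n 2^n x^(2n+1) / (2n+1)!!.
    Only accurate for small |x|; the alternating terms cancel as |x| grows."""
    term, total = x, 0.0
    for n in range(terms):
        total += term
        term *= -2.0 * x * x / (2 * n + 3)  # ratio of consecutive series terms
    return TWO_OVER_SQRT_PI * total


def im_w_asymptotic(x: float) -> float:
    """Large-x expansion Im w(x) ~ (1/(sqrt(pi) x)) (1 + 1/(2x^2) + 3/(4x^4) + 15/(8x^6)).
    Stand-in for the interval table; not accurate for small x."""
    u = 1.0 / (x * x)
    series = 1.0 + u * (0.5 + u * (0.75 + u * 1.875))  # Horner form in 1/x^2
    return INV_SQRT_PI * series / x


def erfi_series(x: float, terms: int = 200) -> float:
    """Reference: erfi(x) = (2/sqrt(pi)) * sum x^(2n+1) / (n! (2n+1))."""
    a, total = x, 0.0  # a holds x^(2n+1) / n!
    for n in range(terms):
        total += a / (2 * n + 1)
        a *= x * x / (n + 1)
    return TWO_OVER_SQRT_PI * total


def erfi_log_domain(x: float) -> float:
    """erfi(x) = sign(x) * exp(x^2 + log Im w(|x|)) for large |x|.
    Combining the exponents first avoids overflow in the intermediate exp(x^2)."""
    ax = abs(x)
    log_mag = ax * ax + math.log(im_w_asymptotic(ax))
    if log_mag > LOG_DBL_MAX:
        # Genuinely unrepresentable: saturate to +/- infinity.
        return math.copysign(math.inf, x)
    return math.copysign(math.exp(log_mag), x)
```

For example, 26.7² ≈ 712.9 exceeds `LOG_DBL_MAX`, so `math.exp(26.7 * 26.7)` raises `OverflowError`, yet `erfi_log_domain(26.7)` returns a finite float.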

The polynomial coefficients originate from Steven G. Johnson's MIT-licensed Faddeeva implementation. The original license notice is retained in third_party/faddeeva.
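A table of this kind is consumed by locating the interval that contains x and evaluating that interval's polynomial. The sketch below assumes the layout of Johnson's `w_im_y100` routine:

- y100 = 100/(1 + x) for x ≥ 0.
- The interval index is j = ⌊y100⌋.
- The local variable is t = 2·y100 − (2j + 1), which lies in [−1, 1).

Whether erfi-pytorch uses the same mapping is an assumption. The `table` argument is a placeholder, not the real coefficients. This mapping packs narrow intervals near small x and wide ones toward large x.

```python
from typing import Sequence


def interval_coordinates(x: float) -> tuple[int, float]:
    """Map x >= 0 to (interval index j, local variable t in [-1, 1))."""
    y100 = 100.0 / (1.0 + x)
    j = int(y100)  # floor, since y100 > 0
    t = 2.0 * y100 - (2 * j + 1)
    return j, t


def horner(coeffs: Sequence[float], t: float) -> float:
    """Evaluate c0 + c1*t + c2*t^2 + ... with Horner's rule."""
    acc = 0.0
    for c in reversed(coeffs):
        acc = acc * t + c
    return acc


def eval_table(table: Sequence[Sequence[float]], x: float) -> float:
    """Hypothetical lookup: evaluate the polynomial for the interval containing x."""
    j, t = interval_coordinates(x)
    if j >= len(table):
        # In this mapping x = 0 lands one past a 100-entry table,
        # so the near-zero Taylor branch has to cover it.
        raise ValueError("x is in the near-zero region; use the Taylor branch")
    return horner(table[j], t)
```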

Detailed implementation notes are in docs/faddeeva.md.

License

This project is released under the MIT License. The vendored Faddeeva sources and material derived from them retain the original Copyright (c) 2012 Massachusetts Institute of Technology attribution and MIT license notice.

Limitations

  • Inputs must be real torch.float32 or torch.float64 tensors.
  • This release is forward-only: passing an input with requires_grad=True raises an error.
  • Triton acceleration currently targets NVIDIA CUDA.
  • Windows uses the pure-PyTorch CUDA backend because upstream Triton support is not generally available there.

Benchmark

python benchmarks/benchmark_erfi.py --dtype float32

The benchmark covers sizes that are powers of two from 2^10 through 2^24 and reports timings for four variants: eager PyTorch, compiled PyTorch, eager dispatch, and compiled dispatch. Before timing, it compares the operator against scipy.special.erfi and reports the maximum absolute error, the maximum and mean relative error, and the number of infinity mismatches. Use --precision-elements to change the number of sampled comparison points, or --skip-precision to run timing only.
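The precision metrics named above could be computed as in the following hypothetical sketch of the reporting step. It is not code from benchmarks/benchmark_erfi.py, and it rests on these assumptions:

- Infinite values are excluded from the error statistics.
- An infinite value counts as a mismatch only when the two sides disagree.
- Zero reference values are skipped for relative error.

The size sweep is listed as well. Nine of its fifteen sizes reach the 65,536-element Triton threshold quoted under Backends.

```python
import math
from typing import Sequence

# Sizes swept by the benchmark: 2^10 through 2^24 inclusive.
BENCHMARK_SIZES = [2**k for k in range(10, 25)]


def precision_report(actual: Sequence[float], reference: Sequence[float]) -> dict:
    """Hypothetical helper computing the error metrics listed above."""
    if len(actual) != len(reference):
        raise ValueError("actual and reference must have the same length")
    max_abs = 0.0
    rel_errors = []
    inf_mismatches = 0
    for a, r in zip(actual, reference):
        if math.isinf(a) or math.isinf(r):
            # Same-signed infinities agree; anything else is a mismatch.
            if a != r:
                inf_mismatches += 1
            continue
        err = abs(a - r)
        max_abs = max(max_abs, err)
        if r != 0.0:
            rel_errors.append(err / abs(r))
    return {
        "max_abs_error": max_abs,
        "max_rel_error": max(rel_errors, default=0.0),
        "mean_rel_error": sum(rel_errors) / len(rel_errors) if rel_errors else 0.0,
        "inf_mismatches": inf_mismatches,
    }
```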


Download files

Source Distribution

erfi_pytorch-0.1.0.tar.gz (102.3 kB)

Built Distribution

erfi_pytorch-0.1.0-py3-none-any.whl (52.4 kB)

File details

Details for the file erfi_pytorch-0.1.0.tar.gz.

File metadata

  • Download URL: erfi_pytorch-0.1.0.tar.gz
  • Size: 102.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.6

File hashes

Hashes for erfi_pytorch-0.1.0.tar.gz
Algorithm Hash digest
SHA256 913336657c1886cb1f04523ca164e8cb45f59cd91fa6f2ea79b71b23beaa3f84
MD5 69b0d35e0512edce6ced7b2f875cf4be
BLAKE2b-256 b0a8389bef9791670ab47d52a0632e8270e0202077400d77fd681aa805c6b36f

File details

Details for the file erfi_pytorch-0.1.0-py3-none-any.whl.

File metadata

  • Download URL: erfi_pytorch-0.1.0-py3-none-any.whl
  • Size: 52.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.14.6

File hashes

Hashes for erfi_pytorch-0.1.0-py3-none-any.whl
Algorithm Hash digest
SHA256 9d4725362585122d144d70591c6962b2dc65bd804251bf19c324664e9285900c
MD5 17483aa42ca3573367d2938ce2bba390
BLAKE2b-256 5b3d4bf96b93c272a080811bb61d7edabea3c7b57e6058fd6a6dc8890bc9709f
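A downloaded file can be checked against the SHA256 digests listed above using only the standard library. The helper names `sha256_of` and `verify` are hypothetical. The digests are copied from the two tables.

```python
import hashlib
from pathlib import Path

# SHA256 digests copied from the file-hash tables above.
EXPECTED_SHA256 = {
    "erfi_pytorch-0.1.0.tar.gz":
        "913336657c1886cb1f04523ca164e8cb45f59cd91fa6f2ea79b71b23beaa3f84",
    "erfi_pytorch-0.1.0-py3-none-any.whl":
        "9d4725362585122d144d70591c6962b2dc65bd804251bf19c324664e9285900c",
}


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash a file in chunks so large downloads need not fit in memory."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path: Path) -> bool:
    """True only if the file name is listed and its SHA256 matches."""
    expected = EXPECTED_SHA256.get(path.name)
    return expected is not None and sha256_of(path) == expected
```

pip's hash-checking mode (`--require-hashes`) performs the same check at install time. That mode requires pinned hashes for every dependency as well, not just erfi-pytorch.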