
FFT and complex-valued tensor operations for AWS Trainium via NKI


trnfft



Trainium has no native complex number support and ships no FFT library. trnfft fills that gap with split real/imaginary representation, complex neural network layers, and NKI kernels optimized for the NeuronCore architecture.

Incorporates neuron-complex-ops. Part of the trnsci scientific computing suite (github.com/trnsci).

Current phase

trnfft follows the trnsci 5-phase roadmap. Active work is tracked in phase-labeled GitHub issues.

Suite-wide tracker: trnsci/trnsci#1.

Why

NVIDIA has cuFFT, cuBLAS, and native complex64. Trainium has none of these. Every signal processing, speech enhancement, physics simulation, and spectral method workload on Trainium currently falls back to CPU or requires hand-rolling complex arithmetic. trnfft fixes this.

Install

pip install trnfft

# With Neuron hardware support (quote the extra so shells like zsh don't expand the brackets)
pip install "trnfft[neuron]"

Usage

import torch
import trnfft

# Drop-in replacement for torch.fft
signal = torch.randn(1024)
X = trnfft.fft(signal)
recovered = trnfft.ifft(X)

# Real-valued FFT
X = trnfft.rfft(signal)

# 2D FFT
image = torch.randn(256, 256)
F = trnfft.fft2(image)

# STFT (matches torch.stft signature)
waveform = torch.randn(16000)
S = trnfft.stft(waveform, n_fft=512, hop_length=256)
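The output shape of the STFT call above can be worked out by hand. The sketch below is a plain NumPy reference, not trnfft code, and it assumes torch.stft's defaults for window=None (rectangular window), center=True with reflect padding, and one-sided output. With n_fft=512 and hop_length=256, a 16000-sample waveform yields 1 + 16000 // 256 = 63 frames of 512 // 2 + 1 = 257 frequency bins.

```python
import numpy as np

def stft_reference(x, n_fft=512, hop_length=256):
    """Framed-rfft STFT mirroring torch.stft defaults for window=None
    (rectangular), center=True, pad_mode="reflect", onesided=True."""
    pad = n_fft // 2
    x = np.pad(x, pad, mode="reflect")            # center=True padding
    n_frames = 1 + (len(x) - n_fft) // hop_length
    # Row t of `frames` is x[t * hop_length : t * hop_length + n_fft]
    idx = np.arange(n_fft)[None, :] + hop_length * np.arange(n_frames)[:, None]
    frames = x[idx]
    # One-sided spectrum per frame, transposed to (freq, time) like torch.stft
    return np.fft.rfft(frames, axis=-1).T

waveform = np.random.default_rng(0).standard_normal(16000)
S = stft_reference(waveform)
print(S.shape)  # (257, 63)
```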

Complex Neural Network Layers

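import torch
from trnfft import ComplexTensor
from trnfft.nn import ComplexLinear, ComplexConv1d, ComplexModReLU

# Build complex-valued models for speech/audio/physics
# Real and imaginary parts are ordinary real tensors of the same shape
real_part = torch.randn(32, 256)  # batch of 32 samples, 256 features
imag_part = torch.randn(32, 256)
x = ComplexTensor(real_part, imag_part)
layer = ComplexLinear(256, 128)   # 256 complex input features -> 128 outputs
y = layer(x)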
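The math behind a complex linear layer and modReLU can be written with real operations alone. This NumPy sketch uses the standard formulations, y = Wx + b expanded into real and imaginary parts, and modReLU(z) = ReLU(|z| + b) · z / |z|. The function names are illustrative, and trnfft's layers may differ in details such as initialization, bias shape, or epsilon handling.

```python
import numpy as np

def complex_linear(xr, xi, wr, wi, br, bi):
    """y = W x + b with W = wr + i*wi, x = xr + i*xi, b = br + i*bi,
    using only real matmuls. x is (batch, in), W is (out, in)."""
    yr = xr @ wr.T - xi @ wi.T + br
    yi = xr @ wi.T + xi @ wr.T + bi
    return yr, yi

def mod_relu(zr, zi, b, eps=1e-8):
    """modReLU: rescale z by ReLU(|z| + b) / |z|, keeping its phase."""
    mag = np.sqrt(zr**2 + zi**2)
    scale = np.maximum(mag + b, 0.0) / (mag + eps)
    return zr * scale, zi * scale

rng = np.random.default_rng(0)
xr, xi = rng.standard_normal((4, 256)), rng.standard_normal((4, 256))
wr, wi = rng.standard_normal((128, 256)), rng.standard_normal((128, 256))
br, bi = np.zeros(128), np.zeros(128)
yr, yi = complex_linear(xr, xi, wr, wi, br, bi)
ar, ai = mod_relu(yr, yi, b=-1.0)
```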

Architecture

+--------------------------------------------+
|            User Code / Model               |
+--------------------------------------------+
|         trnfft.api (torch.fft API)         |
|   fft()  ifft()  rfft()  stft()  fft2()    |
+------------------------+-------------------+
|   trnfft.fft_core      |  trnfft.nn        |
|   Cooley-Tukey         |  ComplexLinear    |
|   Bluestein            |  ComplexConv1d    |
|   Plan caching         |  ComplexModReLU   |
+------------------------+-------------------+
|       trnfft.nki.dispatch                  |
|   "auto" | "pytorch" | "nki"               |
+-------------------+------------------------+
|  PyTorch ops      |  NKI kernels           |
|  (any device)     |  (Trainium only)       |
|  torch.matmul     |  nisa.nc_matmul        |
|  element-wise     |  Tensor Engine         |
|                   |  Vector Engine         |
|                   |  SBUF ↔ PSUM pipeline  |
+-------------------+------------------------+
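The dispatch layer's three modes can be pictured as a small registry. Everything below (register, set_backend, dispatch, nki_available, and the fallback rule) is a hypothetical sketch of the pattern, not the code in trnfft.nki.dispatch. The fallback for ops without an NKI kernel reflects the roadmap note that ComplexConv1d and ComplexModReLU currently fall back to PyTorch.

```python
from typing import Callable, Dict

# Hypothetical registry: op name -> {backend name -> implementation}
_KERNELS: Dict[str, Dict[str, Callable]] = {}
_BACKEND = "auto"

def register(op: str, backend: str) -> Callable[[Callable], Callable]:
    """Decorator that records `fn` as the `backend` implementation of `op`."""
    def deco(fn: Callable) -> Callable:
        _KERNELS.setdefault(op, {})[backend] = fn
        return fn
    return deco

def set_backend(name: str) -> None:
    global _BACKEND
    if name not in ("auto", "pytorch", "nki"):
        raise ValueError(f"unknown backend: {name!r}")
    _BACKEND = name

def nki_available() -> bool:
    # Placeholder: a real check would probe for Neuron devices / the NKI runtime.
    return False

def dispatch(op: str, *args, **kwargs):
    impls = _KERNELS[op]
    if _BACKEND == "pytorch":
        choice = "pytorch"
    elif _BACKEND == "nki" or nki_available():
        # Prefer the NKI kernel; ops without one fall back to the PyTorch path
        choice = "nki" if "nki" in impls else "pytorch"
    else:
        choice = "pytorch"  # "auto" with no Neuron hardware present
    return impls[choice](*args, **kwargs)

@register("scale", "pytorch")
def _scale_reference(x, s):
    return [v * s for v in x]

print(dispatch("scale", [1, 2], 3))  # [3, 6]
```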

How It Works

No complex dtype? Trainium's NKI doesn't support complex64/complex128. ComplexTensor stores complex values as paired real tensors and decomposes complex arithmetic into real-valued operations.
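A minimal sketch of the split representation, using NumPy arrays in place of torch tensors. SplitComplex is a hypothetical stand-in for illustration, not trnfft's ComplexTensor API. A complex multiply becomes four real multiplies and two adds.

```python
import numpy as np
from dataclasses import dataclass

@dataclass
class SplitComplex:
    """Complex array stored as two real arrays (illustrative stand-in)."""
    real: np.ndarray
    imag: np.ndarray

    def __add__(self, other: "SplitComplex") -> "SplitComplex":
        return SplitComplex(self.real + other.real, self.imag + other.imag)

    def __mul__(self, other: "SplitComplex") -> "SplitComplex":
        # (a + ib)(c + id) = (ac - bd) + i(ad + bc): four real multiplies
        a, b, c, d = self.real, self.imag, other.real, other.imag
        return SplitComplex(a * c - b * d, a * d + b * c)

    def conj(self) -> "SplitComplex":
        return SplitComplex(self.real, -self.imag)

    def abs(self) -> np.ndarray:
        return np.hypot(self.real, self.imag)

rng = np.random.default_rng(0)
x = SplitComplex(rng.standard_normal(4), rng.standard_normal(4))
y = SplitComplex(rng.standard_normal(4), rng.standard_normal(4))
z = x * y + x.conj()
```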

FFT → butterflies → matmul. Each Cooley-Tukey butterfly stage performs complex-multiply-and-add across all groups simultaneously. On NKI, the complex multiply maps to the Tensor Engine (systolic array).

Algorithms:

  • Power-of-2: Cooley-Tukey radix-2 (iterative, decimation-in-time)
  • Arbitrary sizes: Bluestein's chirp-z transform (recasts the DFT as a convolution, zero-padded to a power of 2)
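Bluestein's algorithm rests on the identity jk = (j² + k² − (k − j)²) / 2. This identity turns a length-N DFT into a convolution with a chirp, which can be evaluated with power-of-2 FFTs of length M ≥ 2N − 1. The NumPy sketch below is the textbook formulation, not trnfft's implementation. For brevity the inner power-of-2 FFTs call np.fft.fft, whereas in trnfft they would go through the radix-2 path.

```python
import numpy as np

def bluestein_fft(x: np.ndarray) -> np.ndarray:
    """Length-N DFT for arbitrary N via Bluestein's chirp-z algorithm."""
    n = x.shape[-1]
    k = np.arange(n)
    chirp = np.exp(-1j * np.pi * (k * k) / n)   # w[k] = exp(-i*pi*k^2/N)
    m = 1 << (2 * n - 2).bit_length()           # smallest power of 2 >= 2N - 1
    a = np.zeros(x.shape[:-1] + (m,), dtype=complex)
    a[..., :n] = x * chirp                      # pre-multiply input by the chirp
    b = np.zeros(m, dtype=complex)
    b[:n] = np.conj(chirp)                      # conj chirp at lags 0..N-1
    b[m - n + 1:] = np.conj(chirp[1:])[::-1]    # and at lags -(N-1)..-1, wrapped
    conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(b))  # circular convolution
    return chirp * conv[..., :n]                # post-multiply by the chirp

x = np.random.default_rng(0).standard_normal(12) + 0j  # N = 12: not a power of 2
X = bluestein_fft(x)
```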

NKI complex GEMM uses stationary tile reuse (2 SBUF loads instead of 8) and PSUM accumulation, overlapping Vector Engine negation with Tensor Engine matmul.
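The arithmetic behind a complex GEMM can be sketched in NumPy. Neither function below is trnfft's NKI kernel. The first spells out the four real products, writing the subtraction as an accumulated negated operand (the negate-then-accumulate pattern that lets negation overlap with matmul). The second packs the same product into a single real GEMM on stacked blocks, one common way to reduce the number of separate operand loads. Whether trnfft uses this exact packing is not stated here.

```python
import numpy as np

def complex_gemm_4mm(ar, ai, br, bi):
    """C = A @ B with A = ar + i*ai, B = br + i*bi, as four real matmuls.
    Re(C) = ar@br - ai@bi is accumulated as ar@br + (-ai)@bi."""
    neg_ai = -ai              # the negation step
    cr = ar @ br
    cr += neg_ai @ bi         # accumulate into the real-part output
    ci = ar @ bi
    ci += ai @ br
    return cr, ci

def complex_gemm_block(ar, ai, br, bi):
    """Same product as one real GEMM: [cr; ci] = [[ar, -ai], [ai, ar]] @ [br; bi]."""
    a_blk = np.block([[ar, -ai], [ai, ar]])   # (2M, 2K)
    b_stk = np.vstack([br, bi])                # (2K, N)
    c = a_blk @ b_stk                          # (2M, N)
    m = ar.shape[0]
    return c[:m], c[m:]

rng = np.random.default_rng(0)
ar, ai = rng.standard_normal((3, 5)), rng.standard_normal((3, 5))
br, bi = rng.standard_normal((5, 4)), rng.standard_normal((5, 4))
```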

Hardware compatibility

NKI kernels are validated against Neuron SDK 2.24+ on the Deep Learning AMI Neuron PyTorch 2.9 (Ubuntu 24.04), build 20260410 or later. See docs/installation.md for the full compatibility matrix.

Benchmarks

Benchmarks compare NKI against PyTorch on the same Trainium instance; see the benchmarks page for the latest numbers.

Status

v0.10.0: NKI kernels validated on trn1.2xlarge. For STFT and batched FFT, set_backend("nki") outperforms stock torch.fft.fft. See the benchmarks page for the full picture.

API coverage (13 common functions: 11 from torch.fft plus torch.stft and torch.istft): fft, ifft, rfft, irfft, fft2, rfft2, irfft2, fftn, ifftn, rfftn, irfftn, stft, istft.

Not implemented: hfft and ihfft, the Hermitian-symmetric variants. hfft expects the one-sided half of a conjugate-symmetric signal (X[k] = conj(X[N-k])) and returns a real result; ihfft is its inverse. In practice, Hermitian-symmetric input means you've post-processed an rfft output or are reconstructing from a known real signal's spectrum. Both workflows are easier to express with rfft / irfft plus a manual unpack/pack step. If you have a use case that produces Hermitian-symmetric tensors directly, open an issue with the concrete workload and we'll add them.
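The rfft / irfft route rests on two standard identities that hold under the default "backward" normalization: hfft(x, n) = n · irfft(conj(x), n) and ihfft(y) = conj(rfft(y)) / n. The NumPy sketch below checks both against NumPy's own hfft / ihfft. The helper names are illustrative, and the same identities should carry over to trnfft.rfft / trnfft.irfft, though that is not verified here.

```python
import numpy as np

def hfft_via_irfft(x, n):
    """hfft of the one-sided half x of a Hermitian signal, output length n."""
    return n * np.fft.irfft(np.conj(x), n)

def ihfft_via_rfft(y):
    """Inverse of hfft for a real signal y: conj(rfft(y)) / len(y)."""
    return np.conj(np.fft.rfft(y)) / y.shape[-1]

rng = np.random.default_rng(0)
n = 16
x = rng.standard_normal(n // 2 + 1) + 1j * rng.standard_normal(n // 2 + 1)
y = rng.standard_normal(n)
```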

Roadmap

  • NKI ComplexConv1d / ComplexModReLU kernels (today both fall back to PyTorch on NKI)
  • BF16 / FP16 support across NKI kernels
  • Multi-NeuronCore parallelism (scaffold in trnfft/nki/multicore.py)
  • SBUF-resident dispatch to reduce small-op overhead

Related projects in the trnsci suite

All six siblings are on PyPI, along with the umbrella meta-package:

| Project | What | Latest |
|---|---|---|
| trnsci | Umbrella meta-package pulling the whole suite | v0.1.0 |
| trnblas | BLAS Level 1–3 for Trainium | v0.4.0 |
| trnrand | Philox / Sobol / Halton random number generation | v0.1.0 |
| trnsolver | Linear solvers (CG, GMRES) and eigendecomposition | v0.3.0 |
| trnsparse | Sparse matrix operations | v0.1.1 |
| trntensor | Tensor contractions (einsum, TT/Tucker decompositions) | v0.1.1 |
| neuron-complex-ops | Original proof-of-concept, folded into trnfft | archived |

License

Apache 2.0 — Copyright 2026 Scott Friedman

Acknowledgments

Built on insights from:



Download files


Source Distribution

trnfft-0.11.0.tar.gz (90.4 kB)

Uploaded Source

Built Distribution


trnfft-0.11.0-py3-none-any.whl (33.7 kB)

Uploaded Python 3

File details

Details for the file trnfft-0.11.0.tar.gz.

File metadata

  • Download URL: trnfft-0.11.0.tar.gz
  • Upload date:
  • Size: 90.4 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for trnfft-0.11.0.tar.gz
Algorithm Hash digest
SHA256 29ecb77162b81a2107f9d4c9acc234ec22667d55235f209d6853baa141856ddf
MD5 c46a3439365fdcfc46567e44bdb6db56
BLAKE2b-256 5e8f2c729f8659dbfc468ae3ae8eb3ae78f7e6d00e5068464660fb08913e005d


Provenance

The following attestation bundles were made for trnfft-0.11.0.tar.gz:

Publisher: publish.yml on trnsci/trnfft

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.

File details

Details for the file trnfft-0.11.0-py3-none-any.whl.

File metadata

  • Download URL: trnfft-0.11.0-py3-none-any.whl
  • Upload date:
  • Size: 33.7 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? Yes
  • Uploaded via: twine/6.1.0 CPython/3.13.12

File hashes

Hashes for trnfft-0.11.0-py3-none-any.whl
Algorithm Hash digest
SHA256 2c94712d9fb4bb20c27853746faeac020e1a37ff58dd33acc491f58de323afff
MD5 f039b7f2db3aa37732a5e393e596aa61
BLAKE2b-256 e7794d9fcef0815d6a03f154b62c7ea31e7f3ac556ac3c2b50df4e65840a6174


Provenance

The following attestation bundles were made for trnfft-0.11.0-py3-none-any.whl:

Publisher: publish.yml on trnsci/trnfft

Attestations: Values shown here reflect the state when the release was signed and may no longer be current.
