FFT and complex-valued tensor operations for AWS Trainium via NKI

trnfft

Trainium has no native complex number support and ships no FFT library. trnfft fills that gap with split real/imaginary representation, complex neural network layers, and NKI kernels optimized for the NeuronCore architecture.

Incorporates neuron-complex-ops. Part of the trn-* scientific computing suite by Playground Logic.

Why

NVIDIA has cuFFT, cuBLAS, and native complex64. Trainium has none of these. Every signal processing, speech enhancement, physics simulation, and spectral method workload on Trainium currently falls back to CPU or requires hand-rolling complex arithmetic. trnfft fixes this.

Install

pip install trnfft

# With Neuron hardware support
pip install trnfft[neuron]

Usage

import torch
import trnfft

# Drop-in replacement for torch.fft
signal = torch.randn(1024)
X = trnfft.fft(signal)
recovered = trnfft.ifft(X)

# Real-valued FFT
X = trnfft.rfft(signal)

# 2D FFT
image = torch.randn(256, 256)
F = trnfft.fft2(image)

# STFT (matches torch.stft signature)
waveform = torch.randn(16000)
S = trnfft.stft(waveform, n_fft=512, hop_length=256)

Complex Neural Network Layers

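import torch
from trnfft import ComplexTensor
from trnfft.nn import ComplexLinear, ComplexConv1d, ComplexModReLU

# Build complex-valued models for speech/audio/physics
# Illustrative input (assumed layout): batch of 8 vectors with 256 features each
real_part = torch.randn(8, 256)
imag_part = torch.randn(8, 256)
x = ComplexTensor(real_part, imag_part)
layer = ComplexLinear(256, 128)
y = layer(x)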

Architecture

+--------------------------------------------+
|            User Code / Model               |
+--------------------------------------------+
|         trnfft.api (torch.fft API)         |
|   fft()  ifft()  rfft()  stft()  fft2()    |
+--------------------------------------------+
|   trnfft.fft_core      |  trnfft.nn        |
|   Cooley-Tukey         |  ComplexLinear    |
|   Bluestein            |  ComplexConv1d    |
|   Plan caching         |  ComplexModReLU   |
+------------------------+-------------------+
|       trnfft.nki.dispatch                  |
|   "auto" | "pytorch" | "nki"               |
+--------------------------------------------+
|  PyTorch ops      |  NKI kernels           |
|  (any device)     |  (Trainium only)       |
|  torch.matmul     |  nisa.nc_matmul        |
|  element-wise     |  Tensor Engine         |
|                   |  Vector Engine         |
|                   |  SBUF ↔ PSUM pipeline  |
+-------------------+------------------------+

How It Works

No complex dtype? Trainium's NKI doesn't support complex64/complex128. ComplexTensor stores complex values as paired real tensors and decomposes complex arithmetic into real-valued operations.
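
To make the split representation concrete, here is a minimal sketch of complex arithmetic built from real-valued operations only. The `SplitComplex` class is hypothetical and is not the library's `ComplexTensor`; plain Python lists stand in for tensors so the example runs without any dependencies.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SplitComplex:
    """Hypothetical stand-in for a split real/imaginary complex vector."""
    re: List[float]
    im: List[float]

    def __add__(self, other: "SplitComplex") -> "SplitComplex":
        # Addition is independent per component.
        return SplitComplex(
            [a + c for a, c in zip(self.re, other.re)],
            [b + d for b, d in zip(self.im, other.im)],
        )

    def __mul__(self, other: "SplitComplex") -> "SplitComplex":
        # (a + bi)(c + di) = (ac - bd) + (ad + bc)i, using real ops only.
        quads = list(zip(self.re, self.im, other.re, other.im))
        return SplitComplex(
            [a * c - b * d for a, b, c, d in quads],
            [a * d + b * c for a, b, c, d in quads],
        )


x = SplitComplex(re=[1.0, 0.0], im=[2.0, 1.0])   # [1+2i, i]
y = SplitComplex(re=[3.0, 0.0], im=[-1.0, 1.0])  # [3-i, i]
z = x * y                                        # [5+5i, -1]
```

Every complex operation reduces to a handful of real multiplies and adds, which is what lets the same code path run on hardware with no complex dtype.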

FFT → butterflies → matmul. Each Cooley-Tukey butterfly stage performs complex-multiply-and-add across all groups simultaneously. On NKI, the complex multiply maps to the Tensor Engine (systolic array).
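
The butterfly structure can be sketched in a few lines. The function below is an illustrative iterative radix-2 decimation-in-time FFT on split real/imaginary lists; `fft_radix2` is a hypothetical name, and the scalar loops are an assumption made for readability rather than a description of how the library batches each stage.

```python
import math


def fft_radix2(re, im):
    """Iterative radix-2 DIT FFT on split real/imag lists (illustrative only)."""
    n = len(re)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    re, im = list(re), list(im)

    # Bit-reversal permutation puts the inputs in butterfly order.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j |= bit
        if i < j:
            re[i], re[j] = re[j], re[i]
            im[i], im[j] = im[j], im[i]

    # log2(n) stages; each stage applies n/2 independent butterflies.
    size = 2
    while size <= n:
        half = size // 2
        for start in range(0, n, size):
            for k in range(half):
                angle = -2.0 * math.pi * k / size
                wr, wi = math.cos(angle), math.sin(angle)
                a, b = start + k, start + k + half
                # Complex multiply (twiddle * x[b]) from real operations.
                tr = wr * re[b] - wi * im[b]
                ti = wr * im[b] + wi * re[b]
                # Complex add/subtract completes the butterfly.
                re[b], im[b] = re[a] - tr, im[a] - ti
                re[a], im[a] = re[a] + tr, im[a] + ti
        size *= 2
    return re, im


out_re, out_im = fft_radix2([1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 0.0, 0.0])
```

The butterflies within a stage do not depend on each other, so a stage can be expressed as one batched multiply-and-add instead of the inner loops shown here.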

Algorithms:

  • Power-of-2: Cooley-Tukey radix-2 (iterative, decimation-in-time)
  • Arbitrary sizes: Bluestein's chirp-z transform (pads to power-of-2)
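
As a rough illustration of the arbitrary-size path, the sketch below rewrites a length-n DFT as a circular convolution with a chirp, padded to a power of two so that a radix-2 FFT can evaluate it. The function names are hypothetical, and the code uses Python's built-in complex type for brevity, which is an assumption of this example and not something available on Trainium.

```python
import cmath


def _fft_pow2(x, inverse=False):
    """Recursive radix-2 FFT (unnormalized); len(x) must be a power of two."""
    n = len(x)
    if n == 1:
        return list(x)
    even = _fft_pow2(x[0::2], inverse)
    odd = _fft_pow2(x[1::2], inverse)
    sign = 1.0 if inverse else -1.0
    out = [0j] * n
    for k in range(n // 2):
        t = cmath.exp(sign * 2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + t
        out[k + n // 2] = even[k] - t
    return out


def bluestein_fft(x):
    """DFT of arbitrary length via Bluestein's chirp-z trick (illustrative)."""
    n = len(x)
    if n == 0:
        return []
    # Smallest power of two that holds the full linear convolution.
    m = 1
    while m < 2 * n - 1:
        m *= 2
    # chirp[k] = exp(-i*pi*k^2/n); reducing k^2 mod 2n keeps the angle small.
    chirp = [cmath.exp(-1j * cmath.pi * ((k * k) % (2 * n)) / n) for k in range(n)]
    a = [x[k] * chirp[k] for k in range(n)] + [0j] * (m - n)
    b = [0j] * m
    b[0] = chirp[0].conjugate()
    for k in range(1, n):
        # The conjugate chirp is symmetric, so wrap it around for circularity.
        b[k] = b[m - k] = chirp[k].conjugate()
    fa, fb = _fft_pow2(a), _fft_pow2(b)
    conv = _fft_pow2([p * q for p, q in zip(fa, fb)], inverse=True)
    # Divide by m to normalize the inverse FFT, then apply the output chirp.
    return [chirp[k] * conv[k] / m for k in range(n)]


spectrum = bluestein_fft([1, 2, 3, 4, 5])  # length 5 is not a power of two
```

The cost is three power-of-two FFTs of length m, which is why an arbitrary-size transform can reuse the same radix-2 machinery.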

NKI complex GEMM uses stationary tile reuse (2 SBUF loads instead of 8) and PSUM accumulation, overlapping Vector Engine negation with Tensor Engine matmul.
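
The arithmetic behind a complex GEMM can be sketched independently of the hardware. In the example below, one complex matrix product becomes four real products, with the imaginary part of the left operand negated once and reused so that both outputs are formed by accumulation. The helper names are hypothetical, and nested lists stand in for tiles; this shows only the math, not the NKI kernel or its memory traffic.

```python
def matmul(a, b):
    """Real matrix product of nested lists (stand-in for a hardware matmul)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b):
    """Element-wise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def complex_matmul(a_re, a_im, b_re, b_im):
    """(A_re + i A_im)(B_re + i B_im) from four real matmuls."""
    # Negate once so the real part is a pure accumulation of two products.
    neg_a_im = [[-v for v in row] for row in a_im]
    c_re = add(matmul(a_re, b_re), matmul(neg_a_im, b_im))
    c_im = add(matmul(a_re, b_im), matmul(a_im, b_re))
    return c_re, c_im


# A = [[1+2i, 0], [0, i]],  B = [[3-i, 1], [2, i]]
a_re, a_im = [[1.0, 0.0], [0.0, 0.0]], [[2.0, 0.0], [0.0, 1.0]]
b_re, b_im = [[3.0, 1.0], [2.0, 0.0]], [[-1.0, 0.0], [0.0, 1.0]]
c_re, c_im = complex_matmul(a_re, a_im, b_re, b_im)
```

Writing the real part as a sum instead of a difference means both outputs follow the same multiply-accumulate pattern, and the negation is a separate element-wise step that does not depend on the matmuls.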

Hardware compatibility

NKI kernels are validated against Neuron SDK 2.24+ using the Deep Learning AMI Neuron PyTorch 2.9 (Ubuntu 24.04), release 20260410 or later. See docs/installation.md for the full compatibility matrix.

Status

v0.1.0 — CPU fallback works, NKI kernels scaffolded for on-hardware validation.

  • ComplexTensor with full arithmetic
  • Complex matmul (4 real matmuls)
  • 1D FFT/IFFT (power-of-2, Cooley-Tukey)
  • Bluestein (arbitrary sizes)
  • rfft/irfft
  • 2D FFT
  • STFT
  • Complex NN layers (Linear, Conv1d, BatchNorm, ModReLU)
  • NKI dispatch layer (auto/pytorch/nki)
  • Plan caching
  • NKI butterfly kernel validation on trn1/trn2
  • NKI GEMM kernel validation
  • Multi-NeuronCore parallelism
  • Benchmarks vs cuFFT
  • Inverse STFT
  • N-D FFT

Related Projects

| Project | What |
|---|---|
| neuron-complex-ops | Original proof-of-concept (now folded into this library) |
| trnblas | BLAS for Trainium (Level 1-3, DF-MP2 use case) |
| trnsolver (planned) | Linear solvers and eigendecomposition for Trainium |

License

Apache 2.0 — Playground Logic LLC

Acknowledgments

Built on insights from:
