
Non-negative Stiefel Approximating Flow for interpretable representation learning


🧠 NSA-Flow: Non-negative Stiefel Approximating Flow

NSA-Flow is a general-purpose optimization framework for interpretable representation learning.
It unifies sparse matrix factorization, orthogonalization, and manifold constraints into a single, differentiable algorithm that operates near the Stiefel manifold.

The NSA-Flow framework

Documentation of functions here

Download the Project Slides


✨ Overview

Interpretable representation learning remains a core challenge in high-dimensional domains such as neuroimaging, genomics, and text analysis.
NSA-Flow provides a smooth geometric mechanism for balancing reconstruction fidelity and column-wise decorrelation, producing sparse, stable, and interpretable representations.
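One way to make "column-wise decorrelation" concrete is a scale-invariant orthogonality defect. The sketch below is an assumption for illustration, not the formula behind nsa_flow.invariant_orthogonality_defect. It normalizes each column to unit length and sums the squared off-diagonal entries of the resulting Gram matrix. The result is zero exactly when the (nonzero) columns are mutually orthogonal, and rescaling any column leaves it unchanged. The name orthogonality_defect_sketch is hypothetical.

```python
import numpy as np


def orthogonality_defect_sketch(Y, eps=1e-12):
    """Hypothetical scale-invariant orthogonality defect (not the nsa_flow formula).

    Columns are rescaled to unit norm, so only the angles between columns matter;
    the sum of squared off-diagonal Gram entries is 0 iff the columns are orthogonal.
    """
    norms = np.maximum(np.linalg.norm(Y, axis=0, keepdims=True), eps)
    Yn = Y / norms                      # unit-norm columns
    G = Yn.T @ Yn                       # cosine-similarity (Gram) matrix
    off = G - np.diag(np.diag(G))       # keep only cross-column terms
    return float(np.sum(off ** 2))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    Q, _ = np.linalg.qr(rng.standard_normal((120, 10)))        # orthonormal columns
    print(orthogonality_defect_sketch(Q))                       # near zero
    print(orthogonality_defect_sketch(rng.standard_normal((120, 10)) + 1))  # clearly positive
```

A matrix with orthonormal columns, such as the Q factor from np.linalg.qr, gives a value near zero. A random matrix with positive mean, like the Quick Start input, gives a clearly positive value because its columns are strongly correlated.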

NSA-Flow enforces structured sparsity via a single tunable weight parameter, combining:

  • Continuous orthogonality control via manifold retraction (e.g., soft-polar, polar)
  • Non-negativity via proximal updates
  • Adaptive gradient scaling and learning-rate control
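The first two ingredients listed above can be sketched with standard linear algebra. This is a minimal NumPy illustration under stated assumptions, not the nsa_flow implementation, and adaptive learning-rate control is omitted.

- polar_retraction is the textbook polar factor U @ Vt from a thin SVD, which is the nearest matrix with orthonormal columns in Frobenius norm.
- soft_polar_sketch is a hypothetical linear interpolation toward that factor. It may differ from the library's soft-polar retraction.
- nonneg_prox is the proximal operator of the non-negativity constraint, which clips negative entries to zero.

```python
import numpy as np


def polar_retraction(Y):
    """Nearest matrix with orthonormal columns (Frobenius sense): U @ Vt of the thin SVD."""
    U, _, Vt = np.linalg.svd(Y, full_matrices=False)
    return U @ Vt


def soft_polar_sketch(Y, alpha):
    """Hypothetical soft retraction: move a fraction alpha of the way toward the polar factor.

    alpha = 0 returns Y unchanged; alpha = 1 returns the full polar retraction.
    """
    return (1.0 - alpha) * Y + alpha * polar_retraction(Y)


def nonneg_prox(Y):
    """Proximal operator of the non-negativity indicator, i.e. projection onto Y >= 0."""
    return np.maximum(Y, 0.0)
```

One plausible per-iteration composition is nonneg_prox(soft_polar_sketch(Y - lr * grad, alpha)). It keeps Y non-negative while pulling it toward, but not exactly onto, the Stiefel manifold. The order and parameterization NSA-Flow actually uses may differ.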

🧩 Key Features

  • ⚙️ Continuous flow near the Stiefel manifold
  • 🧮 Non-negative and orthogonal constraints
  • 🧠 Interpretable latent representations
  • 🚀 Compatible with PyTorch optimization routines
  • 🧬 Validated on neuroimaging and genomics datasets

📦 Installation

Install from PyPI:

pip install nsa_flow

Or install the latest development version directly from GitHub:

pip install git+https://github.com/stnava/nsa_flow.git

🧰 Dependencies

  • Python ≥ 3.9
  • PyTorch ≥ 2.0
  • NumPy ≥ 1.23
  • Matplotlib (for optional visualization)
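A small pre-flight check can confirm that an environment meets these minimums before installing. This is a convenience sketch, not part of nsa_flow. It reads installed versions with the standard-library importlib.metadata and compares only the leading numeric part of each version string, so local tags such as +cu118 are ignored. A missing package counts as a failure. Matplotlib is left out because it is optional. The names check_dependencies, parse_version and meets are hypothetical.

```python
import re
import sys
from importlib import metadata

# Minimum versions taken from the Dependencies list (Matplotlib is optional, so it is skipped).
MIN_PYTHON = (3, 9)
REQUIREMENTS = {"torch": (2, 0), "numpy": (1, 23)}


def parse_version(text):
    """Return the leading numeric part of a version string, e.g. '2.1.0+cu118' -> (2, 1, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group(0).split("."))


def meets(installed, required):
    """Compare two version tuples after zero-padding them to the same length."""
    width = max(len(installed), len(required))

    def pad(version):
        return tuple(version) + (0,) * (width - len(version))

    return pad(installed) >= pad(required)


def check_dependencies():
    """Map each requirement name to True if it is installed and new enough."""
    report = {"python": meets(tuple(sys.version_info[:3]), MIN_PYTHON)}
    for name, minimum in REQUIREMENTS.items():
        try:
            report[name] = meets(parse_version(metadata.version(name)), minimum)
        except metadata.PackageNotFoundError:
            report[name] = False  # package is not installed
    return report


if __name__ == "__main__":
    for name, ok in check_dependencies().items():
        print(f"{name}: {'ok' if ok else 'missing or too old'}")
```

Zero-padding matters because plain tuple comparison treats (1, 23) as smaller than (1, 23, 0).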

🚀 Quick Start

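import torch
import nsa_flow

torch.manual_seed(42)

# Random 120 x 200 input matrix, shifted by +1 so most entries are positive
Y = torch.randn(120, 200) + 1
print("Initial orthogonality defect:", nsa_flow.invariant_orthogonality_defect(Y))

# Run NSA-Flow optimization
result = nsa_flow.nsa_flow_orth(
    Y,
    w=0.5,                      # tunable weight balancing fidelity and orthogonality
    retraction="soft_polar",    # manifold retraction (e.g., soft-polar, polar)
    optimizer="asgd",
    max_iter=5000,
    record_every=1,             # record trace values every iteration
    tol=1e-8,                   # convergence tolerance
    initial_learning_rate=None,
    lr_strategy="bayes",        # learning-rate selection strategy
    warmup_iters=10,
    verbose=False,
)

# Plot the recorded optimization traces (requires Matplotlib)
nsa_flow.plot_nsa_trace(result["traces"])
print("Final orthogonality defect:", nsa_flow.invariant_orthogonality_defect(result["Y"]))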

📖 Documentation

NSA-Flow exposes a small set of high-level functions:

  • nsa_flow(): main optimization loop balancing fidelity and orthogonality
  • nsa_flow_retract_auto(): retraction operator enforcing manifold constraints
  • invariant_orthogonality_defect(): computes the orthogonality defect measure
  • defect_fast(): fast approximate defect metric
  • nsa_flow_autograd(): autograd-compatible variant for joint optimization
  • get_torch_optimizer(): returns a configured PyTorch optimizer
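To show how the fidelity/orthogonality balance interacts with ordinary PyTorch optimizers, here is a toy loop built only on torch.optim.Adam. It does not call nsa_flow, and it is not the NSA-Flow algorithm. The objective, the Adam optimizer and the clamp-after-step projection are all assumptions chosen for illustration. The objective mixes a relative reconstruction error with a column-normalized defect, weighted by w. The names toy_flow and defect are hypothetical.

```python
import torch


def defect(Y, eps=1e-12):
    """Hypothetical scale-invariant defect: off-diagonal energy of the normalized Gram matrix."""
    Yn = Y / torch.linalg.vector_norm(Y, dim=0, keepdim=True).clamp_min(eps)
    G = Yn.T @ Yn
    return ((G - torch.diag(torch.diag(G))) ** 2).sum()


def toy_flow(X, w=0.5, steps=500, lr=1e-2):
    """Toy objective: minimize (1 - w) * fidelity + w * defect over a non-negative Y."""
    Y = X.clamp_min(0).requires_grad_(True)   # non-negative starting point
    opt = torch.optim.Adam([Y], lr=lr)
    scale = X.pow(2).sum()
    for _ in range(steps):
        opt.zero_grad()
        fidelity = (Y - X).pow(2).sum() / scale  # relative reconstruction error
        loss = (1 - w) * fidelity + w * defect(Y)
        loss.backward()
        opt.step()
        with torch.no_grad():
            Y.clamp_(min=0.0)  # project back onto the non-negative orthant
    return Y.detach()


if __name__ == "__main__":
    torch.manual_seed(42)
    X = torch.randn(120, 10) + 1
    Y = toy_flow(X, w=0.5)
    print("defect before:", defect(X.clamp_min(0)).item(), "after:", defect(Y).item())
```

In this toy, raising w toward 1 should push the columns toward disjoint non-negative supports (lower defect) at the cost of fidelity. Setting w = 0 should leave Y close to the clipped input.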

🧪 Validation

NSA-Flow has been validated on:

  • the Golub leukemia gene expression dataset
  • the Alzheimer’s Disease Neuroimaging Initiative (ADNI) dataset

In these evaluations, the NSA-Flow constraints maintained or improved performance while simplifying the latent representations and improving interpretability.

There is also an experimental layer that could potentially be included in deep learning pipelines; see tests/test_nsaf_layer.py. It has not yet been used or tested in practice.

🧑‍💻 Citation

If you use NSA-Flow in research, please cite:

Stnava et al. (2025). NSA-Flow: Non-negative Stiefel Approximating Flow for Interpretable Representation Learning.

⚖️ License

MIT License © 2025

📫 Contact

For issues, feature requests, or contributions, open an issue on GitHub.


To publish a release

Before doing this, make sure you have recently run: pip-compile pyproject.toml

rm -r -f build/ nsa_flow.egg-info/ dist/
python -m build .
python -m pip install --upgrade twine
python -m twine upload --repository nsa_flow dist/*

Download files

Source Distribution

nsa_flow-1.1.0.tar.gz (52.6 kB)

Built Distribution


nsa_flow-1.1.0-py3-none-any.whl (57.1 kB)

File details

Details for the file nsa_flow-1.1.0.tar.gz.

File metadata

  • Download URL: nsa_flow-1.1.0.tar.gz
  • Upload date:
  • Size: 52.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.12

File hashes

Hashes for nsa_flow-1.1.0.tar.gz
  • SHA256: 3735190f0a85e8b3d2af1b66bb346debbbd0a08611854b909c8d3c84f658424b
  • MD5: 44f032df330cf0d20c75e12a441a3801
  • BLAKE2b-256: 7e8c660de3698818055296b2a498126ac64d60ad04af1829a12b625076da6dd2
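The published digests can be checked locally before installing a downloaded file. This sketch uses only the standard library (hashlib, pathlib). The usage line assumes the source distribution was saved to the current directory, and sha256_of and verify are hypothetical helper names.

```python
import hashlib
from pathlib import Path

# SHA256 digest listed above for nsa_flow-1.1.0.tar.gz.
EXPECTED_SDIST_SHA256 = "3735190f0a85e8b3d2af1b66bb346debbbd0a08611854b909c8d3c84f658424b"


def sha256_of(path, chunk_size=1 << 20):
    """Hex SHA256 of a file, read in 1 MiB chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with Path(path).open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify(path, expected):
    """True if the file's SHA256 matches the expected hex digest (case-insensitive)."""
    return sha256_of(path) == expected.strip().lower()
```

For an intact download, verify("nsa_flow-1.1.0.tar.gz", EXPECTED_SDIST_SHA256) should return True. pip can also enforce digests itself in hash-checking mode: run pip install --require-hashes -r requirements.txt, with lines such as nsa_flow==1.1.0 --hash=sha256:<digest>. In that mode, every dependency must then be pinned and hashed as well.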


File details

Details for the file nsa_flow-1.1.0-py3-none-any.whl.

File metadata

  • Download URL: nsa_flow-1.1.0-py3-none-any.whl
  • Upload date:
  • Size: 57.1 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.12

File hashes

Hashes for nsa_flow-1.1.0-py3-none-any.whl
  • SHA256: 45e2d6741bea43f5d1b256827ed83a12d494cf2b365c1bec36749e0e81675bf1
  • MD5: 1278afe7d708e566293d87f3fe346504
  • BLAKE2b-256: 565d6dea9b93d90bdeda2abe3ee1b61a8f6b8b41b2250c5e265781eaee60c2c4

