Non-negative Stiefel Approximating Flow for interpretable representation learning

🧠 NSA-Flow: Non-negative Stiefel Approximating Flow

NSA-Flow is a general-purpose optimization framework for interpretable representation learning.
It unifies sparse matrix factorization, orthogonalization, and manifold constraints into a single, differentiable algorithm that operates near the Stiefel manifold.


✨ Overview

Interpretable representation learning remains a core challenge in high-dimensional domains such as neuroimaging, genomics, and text analysis.
NSA-Flow provides a smooth geometric mechanism for balancing reconstruction fidelity and column-wise decorrelation, producing sparse, stable, and interpretable representations.

NSA-Flow enforces structured sparsity via a single tunable weight parameter, combining:

  • Continuous orthogonality control via manifold retraction (e.g., soft-polar, polar)
  • Non-negativity via proximal updates
  • Adaptive gradient scaling and learning-rate control
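The first two ingredients above can be sketched in a few lines of NumPy. This is a minimal illustration rather than the package's implementation: the function names `polar_retraction`, `soft_polar`, and `nonneg_prox` are hypothetical, and the linear blend used for the "soft" variant is an assumption about what a soft-polar step could look like. The polar factor itself is standard: it is the closest matrix with orthonormal columns in Frobenius norm.

```python
import numpy as np

def polar_retraction(Y):
    """Polar factor of Y: the nearest matrix with orthonormal columns."""
    U, _, Vt = np.linalg.svd(Y, full_matrices=False)
    return U @ Vt

def soft_polar(Y, w):
    """Hypothetical soft retraction: blend Y with its polar factor.

    w = 0 leaves Y unchanged, w = 1 gives the full polar retraction.
    The blend is an assumption, not the package's documented rule.
    """
    return (1.0 - w) * Y + w * polar_retraction(Y)

def nonneg_prox(Y):
    """Proximal step for the non-negativity constraint: clip negatives to zero."""
    return np.maximum(Y, 0.0)

rng = np.random.default_rng(0)
Y = rng.standard_normal((50, 10))

# A full polar retraction yields (numerically) orthonormal columns.
Q = polar_retraction(Y)
gram_error = np.linalg.norm(Q.T @ Q - np.eye(10))

# One illustrative update: move toward the manifold, then clip negatives.
Y_next = nonneg_prox(soft_polar(Y, w=0.8))
```

Clipping after the retraction generally breaks exact orthogonality, which is one way to read "near the Stiefel manifold": the iterate is pulled toward orthonormal columns without being forced onto the manifold.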

🧩 Key Features

  • ⚙️ Continuous flow near the Stiefel manifold
  • 🧮 Non-negative and orthogonal constraints
  • 🧠 Interpretable latent representations
  • 🚀 Compatible with PyTorch optimization routines
  • 🧬 Validated on neuroimaging and genomics datasets

📦 Installation

Install from PyPI:

pip install nsa_flow

Or install the latest development version directly from GitHub:

pip install git+https://github.com/stnava/nsa_flow.git


⸻

🧰 Dependencies
  • Python ≥ 3.9
  • PyTorch ≥ 2.0
  • NumPy ≥ 1.23
  • Matplotlib (optional, for visualization)

⸻

🚀 Quick Start

import torch
from nsa_flow import nsa_flow, invariant_orthogonality_defect

torch.manual_seed(42)

# Random initialization
Y = torch.randn(50, 10)
X0 = torch.randn_like(Y)

# Run NSA-Flow optimization
result = nsa_flow(
    Y,
    X0=X0,
    w=0.8,
    retraction="soft_polar",
    optimizer="sgdp",
    max_iter=50,
    record_every=1,
    tol=1e-8,
    initial_learning_rate=1e-2,
    verbose=True,
)

print("Initial orthogonality defect:", invariant_orthogonality_defect(Y))
print("Final orthogonality defect:", invariant_orthogonality_defect(result["Y"]))
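Beyond the orthogonality defect, it can be useful to check how sparse and how non-negative the returned matrix is. The helper below is a hypothetical sketch and not part of the package. It works on any array-like input, and it assumes that `result["Y"]` is a PyTorch tensor that can be converted with `.detach().cpu().numpy()` before being passed in.

```python
import numpy as np

def summarize_factor(M, zero_tol=1e-6):
    """Report simple sparsity and sign statistics for a factor matrix."""
    M = np.asarray(M, dtype=float)
    return {
        "min_entry": float(M.min()),                                 # negative if any entry is negative
        "fraction_near_zero": float(np.mean(np.abs(M) <= zero_tol)),  # crude sparsity measure
        "column_norms": np.linalg.norm(M, axis=0),                   # scale of each column
    }

# Small hand-made example so the numbers are easy to verify by eye.
demo = np.array([[0.0, 2.0],
                 [3.0, 0.0],
                 [0.0, 0.0]])
summary = summarize_factor(demo)
```

For the Quick Start output, the call would look like `summarize_factor(result["Y"].detach().cpu().numpy())`, under the tensor assumption stated above.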


⸻

📖 Documentation

NSA-Flow exposes a small set of high-level functions:

Function	Description
nsa_flow()	Main optimization loop balancing fidelity and orthogonality
nsa_flow_retract_auto()	Retraction operator enforcing manifold constraints
invariant_orthogonality_defect()	Computes orthogonality defect measure
defect_fast()	Fast approximate defect metric
nsa_flow_autograd()	Autograd-compatible variant for joint optimization
get_torch_optimizer()	Returns a configured PyTorch optimizer
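
To make the phrases "orthogonality defect" and "balancing fidelity and orthogonality" concrete, here is one plausible formulation in NumPy. Both functions are assumptions made for illustration: `scale_invariant_defect` and `weighted_objective` are hypothetical names, and the exact formulas used by `invariant_orthogonality_defect()` and `nsa_flow()` may differ.

```python
import numpy as np

def scale_invariant_defect(Y, eps=1e-12):
    """Off-diagonal energy of the Gram matrix, normalized by ||Y||_F^2.

    Illustrative definition only; the package's metric may differ.
    """
    gram = Y.T @ Y / (np.sum(Y * Y) + eps)
    off_diagonal = gram - np.diag(np.diag(gram))
    return float(np.sum(off_diagonal ** 2))

def weighted_objective(Y, X0, w):
    """Hypothetical trade-off: (1 - w) * fidelity + w * defect, with 0 <= w <= 1."""
    fidelity = 0.5 * np.sum((Y - X0) ** 2) / Y.size
    return float((1.0 - w) * fidelity + w * scale_invariant_defect(Y))

rng = np.random.default_rng(0)
Y = rng.standard_normal((50, 10))
Q, _ = np.linalg.qr(Y)           # orthonormal columns
duplicated = np.ones((5, 2))     # two identical columns

defect_random = scale_invariant_defect(Y)
defect_orthonormal = scale_invariant_defect(Q)           # close to zero
defect_duplicated = scale_invariant_defect(duplicated)   # close to 0.5
defect_rescaled = scale_invariant_defect(3.0 * Y)        # matches defect_random
```

Dividing the Gram matrix by the squared Frobenius norm makes the measure insensitive to a global rescaling of the input, which is the sense in which such a defect can be called invariant. In this sketch, `w` near 0 favors staying close to `X0`, while `w` near 1 favors decorrelated columns.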


⸻

🧪 Validation

NSA-Flow has been validated on:

  • Golub leukemia gene expression dataset
  • Alzheimer’s Disease Neuroimaging Initiative (ADNI) dataset

On these datasets, NSA-Flow constraints maintain or improve performance while yielding simpler, more interpretable latent representations.

An experimental layer is also provided that could potentially be included in deep learning tools; see `tests/test_nsaf_layer.py`.
This layer has not yet been extensively used or tested.

⸻

🧑‍💻 Citation

If you use NSA-Flow in research, please cite:

Stnava et al. (2025). NSA-Flow: Non-negative Stiefel Approximating Flow for Interpretable Representation Learning.

⸻

⚖️ License

MIT License © 2025 

⸻

📫 Contact

For issues, feature requests, or contributions, open an issue on GitHub.
