Matrix decomposition of neurophysiological event waveform populations

subwave

Data-driven decomposition of neurophysiological event waveform populations.


Five principal waveform templates extracted from 821 sleep spindles via SVD.


200 sigma-filtered spindle waveforms overlaid, with the population mean in black.


Left: all events sorted by Component 1 loading. Right: k-means clustering in the first two component scores.

Install

pip install subwave

Optional extras:

pip install "subwave[mne]"          # MNE-Python and Luna I/O
pip install "subwave[yasa]"         # YASA spindle detection I/O (pulls in mne)
pip install "subwave[scattering]"   # Wavelet scattering decomposition (kymatio)
pip install "subwave[tensor]"       # Multi-channel tensor decomposition (tensorly)
pip install "subwave[all]"          # All optional backends

Quickstart

import numpy as np
import subwave as sw

rng = np.random.default_rng(0)
t = np.linspace(0, 1, 256)
X = np.stack([np.sin(2 * np.pi * 13 * t) + rng.normal(0, 0.1, 256) for _ in range(100)])

result = sw.decompose(X, method="svd", n_components=5)
result.plot_templates(n=3)

Each template is a basis waveform shape, and each loading measures how strongly a given event expresses that template. Multiplying the loadings by the templates reconstructs the original population, approximately when only the leading components are kept.

Or let subwave choose the number of components:

result = sw.decompose(X, n_components="auto")
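
To make the template/loading picture concrete, here is a minimal sketch in plain NumPy rather than through subwave. It assumes an uncentered thin SVD; whether subwave centers or scales the data internally is not stated here, so treat the variable names and the exact scaling of the loadings as illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
t = np.linspace(0, 1, 256)
# 100 noisy 13 Hz events, same construction as the quickstart
X = np.stack([np.sin(2 * np.pi * 13 * t) + rng.normal(0, 0.1, 256) for _ in range(100)])

# Thin SVD: X = U @ diag(s) @ Vt
U, s, Vt = np.linalg.svd(X, full_matrices=False)

k = 5
templates = Vt[:k]            # (k, n_samples) orthonormal basis waveforms
loadings = U[:, :k] * s[:k]   # (n_events, k) weight of each template in each event

# Rank-k reconstruction and its relative error
X_hat = loadings @ templates
rel_error = np.linalg.norm(X - X_hat) / np.linalg.norm(X)

# Share of total (uncentered) energy carried by each component
explained = s**2 / np.sum(s**2)
```

With this toy population nearly all of the energy sits in the first component, because every event is the same sinusoid plus noise.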

ECG example

import numpy as np
import subwave as sw

# Stack detected QRS complexes into an event matrix
X_beats = np.load("heartbeats.npy")     # (n_beats, n_samples)
result = sw.decompose(X_beats, method="svd", n_components=3)

# Outlier detection flags ectopic beats
scores = result.outlier_scores()
ectopic = np.where(scores > np.percentile(scores, 99))[0]
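
The idea behind reconstruction-error outlier scores can be sketched without subwave. The example below is an assumption-laden toy: the beats are synthetic Gaussian bumps, the "ectopic" shape is invented, and a rank of 1 is chosen deliberately so that the rare shape cannot be absorbed by a later component.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 128
t = np.linspace(-1, 1, n_samples)

# 300 "normal" beats: one bump shape with varying amplitude
bump = np.exp(-t**2 / (2 * 0.1**2))
normal = rng.uniform(0.8, 1.2, size=(300, 1)) * bump

# 3 "ectopic" beats: a wider, inverted, shifted bump (invented shape)
odd = -0.8 * np.exp(-(t - 0.3) ** 2 / (2 * 0.2**2))
ectopic = np.tile(odd, (3, 1))

X_beats = np.vstack([normal, ectopic]) + rng.normal(0, 0.05, size=(303, n_samples))

# Rank-1 subspace from the SVD, then the per-event residual norm as the score
k = 1
_, _, Vt = np.linalg.svd(X_beats, full_matrices=False)
basis = Vt[:k]
residual = X_beats - (X_beats @ basis.T) @ basis
scores = np.linalg.norm(residual, axis=1)

flagged = np.where(scores > np.percentile(scores, 99))[0]
```

A percentile threshold always flags a fixed fraction of events, so the flagged set can include borderline normal beats alongside the true outliers. If the rank is high enough for a component to fit the rare shape, its residual shrinks and the score becomes less informative.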

Loading data

sw.from_array(X, sfreq=256)                              # plain numpy
sw.from_npz("spindles.npz")                              # Lunascope format
sw.from_mne(epochs)                                       # MNE Epochs
sw.from_yasa(spindles_df, raw_signal, sfreq=256)          # YASA output
sw.from_luna("spindles.txt", "recording.edf", sfreq=256)  # Luna output + EDF
sw.from_edf_batch(["s1.edf", "s2.edf"], channel="C3")     # detect + pool across files

from_edf_batch runs a detector (default YASA) on every EDF, extracts windows around each peak, and returns a single EventMatrix with a .meta DataFrame recording subject, file, event_index, and peak_sec for each pooled event.
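
The window-extraction step can be pictured with a short NumPy sketch. `extract_windows` is a hypothetical helper written for this illustration, not part of subwave, and the 1 s window and the rule of dropping events too close to the recording edges are assumptions.

```python
import numpy as np

def extract_windows(signal, peak_sec, sfreq, window_sec=1.0):
    """Cut a fixed-length window around each peak time (hypothetical helper).

    Returns the (n_kept, n_samples) event matrix and a boolean mask marking
    which peaks were far enough from the edges to be kept.
    """
    half = int(round(window_sec * sfreq / 2))
    peaks = np.round(np.asarray(peak_sec) * sfreq).astype(int)
    keep = (peaks - half >= 0) & (peaks + half <= signal.size)
    # Broadcast peak positions against sample offsets to index all windows at once
    idx = peaks[keep, None] + np.arange(-half, half)[None, :]
    return signal[idx], keep

sfreq = 256
signal = np.arange(10 * sfreq, dtype=float)   # 10 s ramp, so values equal sample indices
peak_sec = [0.2, 3.0, 7.5, 9.9]               # first and last are too close to the edges

X, keep = extract_windows(signal, peak_sec, sfreq)
```

Pooling across files then amounts to stacking the per-file matrices and recording, for each row, which file and peak it came from.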

Decomposition methods

  • SVD / PCA (method='svd') — default, optimal low-rank approximation. Uses randomized SVD for >5000 events.
  • NMF (method='nmf') — parts-based, requires non-negative input.
  • Dictionary learning (method='dictlearn') — sparse atoms.
  • Fourier-then-SVD (method='fourier_svd') — shift-invariant: SVD on per-event magnitude rFFT spectra. Robust to small temporal jitter that corrupts plain SVD. Templates live in frequency space (config['domain'] = 'frequency').
  • Scattering-then-SVD (method='scattering_svd') — locally translation-invariant and stable to deformation via 1-D wavelet scattering (kymatio). Templates live in scattering-coefficient space. Requires subwave[scattering].
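
Why the Fourier-then-SVD route tolerates jitter can be checked in a few lines of NumPy. This is a sketch of the principle only, not of subwave's implementation, and it uses circular shifts, for which the magnitude spectrum is exactly invariant. Real jitter within a finite window is only approximately circular.

```python
import numpy as np

rng = np.random.default_rng(0)
n_events, n_samples = 50, 256
t = np.arange(n_samples) / n_samples

# One Gaussian-windowed 13 Hz burst, copied with random circular jitter
base = np.sin(2 * np.pi * 13 * t) * np.exp(-((t - 0.5) ** 2) / (2 * 0.1**2))
shifts = rng.integers(-10, 11, size=n_events)
X = np.stack([np.roll(base, s) for s in shifts])

def effective_rank(M, tol=1e-8):
    """Number of singular values above tol times the largest one."""
    s = np.linalg.svd(M, compute_uv=False)
    return int(np.sum(s > tol * s[0]))

# Time domain: jittered copies of one shape need several components
rank_time = effective_rank(X)

# Frequency domain: magnitude spectra of shifted copies coincide
mag = np.abs(np.fft.rfft(X, axis=1))
rank_freq = effective_rank(mag)
```

The price of the invariance is that phase is discarded, which is why the resulting templates are spectra rather than waveforms.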

Tensor decomposition (multi-channel)

For 3-D event tensors (n_events, n_samples, n_channels), tensor_decompose factors temporal and spatial structure separately instead of flattening:

result = sw.tensor_decompose(X, method="cp", rank=3)        # PARAFAC / CP
result = sw.tensor_decompose(X, method="tucker", rank=3)    # Tucker (with core)

result.event_factors        # (n_events, rank)
result.temporal_factors     # (n_samples, rank)
result.spatial_factors      # (n_channels, rank)
result.reconstruction_error # ||X - X_hat|| / ||X||
result.component_waveform(0, channel=1)   # template scaled by channel weight

Requires subwave[tensor].
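
As a reading aid for the factor shapes above, the following NumPy sketch builds a tensor from known CP factors and evaluates the same relative error. It does not fit anything and does not call tensorly or subwave; the factor values are synthetic and the names `A`, `B`, `C` are chosen for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_events, n_samples, n_channels, rank = 40, 128, 4, 2
t = np.arange(n_samples) / n_samples

A = rng.normal(size=(n_events, rank))                    # event factors
B = np.stack([np.sin(2 * np.pi * 13 * t),
              np.sin(2 * np.pi * 11 * t)], axis=1)       # temporal factors
C = rng.uniform(0.5, 1.5, size=(n_channels, rank))       # spatial factors

# CP model: X[e, s, c] = sum_r A[e, r] * B[s, r] * C[c, r]
X_hat = np.einsum("er,sr,cr->esc", A, B, C)
X = X_hat + rng.normal(0, 0.05, size=X_hat.shape)

rel_err = np.linalg.norm(X - X_hat) / np.linalg.norm(X)

# Component 0 as it appears on channel 1: temporal template times channel weight
waveform = B[:, 0] * C[1, 0]

# Parameter counts: separate factors versus flattening samples and channels together
n_params_cp = rank * (n_events + n_samples + n_channels)
n_params_flat = rank * (n_events + n_samples * n_channels)
```

Keeping time and space as separate factors uses far fewer parameters than a flattened matrix factorization of the same rank, at the cost of assuming each component has a single waveform shared across channels.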

Working with results

result.templates                        # (n_components, n_samples) basis waveforms
result.loadings                         # (n_events, n_components) per-event scores
result.explained_variance_ratio         # variance captured per component
result.singular_values                  # singular values
result.factor_tables["instance"]        # DataFrame: instance_id, score_1…k, recon_error

result.reconstruct(n_components=3)      # rank-k reconstruction
result.project(new_X)                   # project new events onto learned subspace
result.outlier_scores()                 # per-event reconstruction error

Component selection

k = sw.parallel_analysis(X)             # Horn's parallel analysis
k = sw.elbow(result.singular_values)    # Kneedle elbow detection
k = sw.kaiser(result)                   # Kaiser rule
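
For readers unfamiliar with Horn's parallel analysis, here is a minimal NumPy sketch of one common variant. It is not subwave's implementation: the column-permutation null, the 95th percentile cutoff, and the function name `parallel_analysis_sketch` are all assumptions made for illustration.

```python
import numpy as np

def parallel_analysis_sketch(X, n_iter=50, percentile=95, random_state=0):
    """Count leading components whose eigenvalue beats a permutation null."""
    rng = np.random.default_rng(random_state)
    Xc = X - X.mean(axis=0)
    observed = np.linalg.svd(Xc, compute_uv=False) ** 2
    null = np.empty((n_iter, observed.size))
    for i in range(n_iter):
        # Shuffle each column independently: keeps per-sample marginals,
        # destroys the correlation structure between samples
        Xp = rng.permuted(Xc, axis=0)
        null[i] = np.linalg.svd(Xp, compute_uv=False) ** 2
    threshold = np.percentile(null, percentile, axis=0)
    above = observed > threshold
    # Number of leading components above the null, stopping at the first failure
    return above.size if above.all() else int(np.argmin(above))

rng = np.random.default_rng(1)
n_events, n_samples = 200, 128
t = np.arange(n_samples) / n_samples
templates = np.stack([np.sin(2 * np.pi * 13 * t), np.sin(2 * np.pi * 11 * t)])
weights = rng.normal(size=(n_events, 2))
X = weights @ templates + rng.normal(0, 0.1, size=(n_events, n_samples))

k = parallel_analysis_sketch(X)
```

The synthetic population is built from two templates with independent weights, so a sensible selection rule should settle on two components.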

Spectral characterization

freqs, powers = result.template_spectrum(sfreq=256)
peak_hz = result.template_peak_freq(sfreq=256)     # e.g. [13.2, 11.1] Hz
bw_hz = result.template_bandwidth(sfreq=256)
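
A peak frequency like the one above can be read off a template's power spectrum. The sketch below uses NumPy's real FFT on two synthetic templates. How subwave computes its spectra (windowing, zero-padding, interpolation between bins) is not specified here, so this shows only the basic idea.

```python
import numpy as np

sfreq = 256
t = np.arange(sfreq) / sfreq                      # 1 s of samples
window = np.exp(-((t - 0.5) ** 2) / (2 * 0.15**2))

# Two synthetic templates: Gaussian-windowed 13 Hz and 11 Hz bursts
templates = np.stack([
    np.sin(2 * np.pi * 13 * t) * window,
    np.sin(2 * np.pi * 11 * t) * window,
])

freqs = np.fft.rfftfreq(templates.shape[1], d=1 / sfreq)   # bin centers in Hz
powers = np.abs(np.fft.rfft(templates, axis=1)) ** 2       # (n_templates, n_freqs)

# Peak frequency: bin with maximum power for each template
peak_hz = freqs[np.argmax(powers, axis=1)]
```

With 1 s templates the frequency resolution is 1 Hz, so sub-Hz peak estimates require zero-padding or interpolation around the maximum.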

Clustering

cr = result.cluster(method="kmeans", n_clusters=2)
cr["labels"]                            # cluster assignments
result.cluster_templates(n_clusters=2)  # mean waveform per cluster

Group comparison

perm = sw.permutation_test(X, groups, n_components=3, n_perm=500)
perm.p_value                            # do two groups span different subspaces?

df = sw.comparison.loading_test(result, groups, n_perm=1000)
# Columns: component, observed_diff, p_value, cohens_d, p_corrected
# cohens_d:    standardized effect size (pooled SD)
# p_corrected: Benjamini-Hochberg FDR-adjusted p-values across components
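
The columns listed above can be reproduced in outline with NumPy. This is a hedged sketch of a two-group label-permutation test with a pooled-SD Cohen's d and Benjamini-Hochberg adjustment. The function names and the two-sided mean-difference statistic are assumptions, not subwave's code.

```python
import numpy as np

def cohens_d(a, b):
    """Standardized mean difference using the pooled standard deviation."""
    pooled_var = ((a.size - 1) * a.var(ddof=1) + (b.size - 1) * b.var(ddof=1)) / (a.size + b.size - 2)
    return (a.mean() - b.mean()) / np.sqrt(pooled_var)

def benjamini_hochberg(p):
    """FDR-adjusted p-values, returned in the original order."""
    p = np.asarray(p, dtype=float)
    m = p.size
    order = np.argsort(p)
    ranked = p[order] * m / np.arange(1, m + 1)
    # Enforce monotonicity, working down from the largest p-value
    ranked = np.minimum.accumulate(ranked[::-1])[::-1]
    adjusted = np.empty(m)
    adjusted[order] = np.clip(ranked, 0, 1)
    return adjusted

def loading_test_sketch(loadings, groups, n_perm=1000, random_state=0):
    rng = np.random.default_rng(random_state)
    groups = np.asarray(groups, dtype=bool)
    observed = loadings[groups].mean(axis=0) - loadings[~groups].mean(axis=0)
    exceed = np.zeros(loadings.shape[1])
    for _ in range(n_perm):
        g = rng.permutation(groups)               # shuffle group labels
        diff = loadings[g].mean(axis=0) - loadings[~g].mean(axis=0)
        exceed += np.abs(diff) >= np.abs(observed)
    p = (exceed + 1) / (n_perm + 1)               # add-one rule avoids p = 0
    d = np.array([cohens_d(loadings[groups, j], loadings[~groups, j])
                  for j in range(loadings.shape[1])])
    return observed, p, d, benjamini_hochberg(p)

rng = np.random.default_rng(0)
groups = np.arange(100) < 50
loadings = rng.normal(size=(100, 3))
loadings[groups, 0] += 1.5                        # only component 0 differs by group

observed, p, d, p_corrected = loading_test_sketch(loadings, groups)
```

The smallest attainable p-value is 1 / (n_perm + 1), which is one reason to raise n_perm when many components are corrected together.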

Spindle helpers

Convenience routines for sleep-spindle waveforms (sigma-band filtering, envelope-based alignment, canonical templates).

from subwave.spindles import (
    sigma_filter,             # 9–16 Hz Butterworth bandpass (sosfiltfilt)
    align_by_envelope_peak,   # circularly shift so each event's sigma-envelope peak is centered
    CANONICAL_FAST,           # ~13.5 Hz Gaussian-windowed sinusoid (256 Hz, 1 s)
    CANONICAL_SLOW,           # ~11 Hz   Gaussian-windowed sinusoid (256 Hz, 1 s)
)

X_filt = sigma_filter(X, sfreq=256)
X_aligned = align_by_envelope_peak(X, sfreq=256)
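
Envelope-based alignment can be sketched in plain NumPy. The version below is an illustration under assumptions: it takes the envelope from an FFT-based analytic signal, skips the sigma-band filtering step, and uses hypothetical function names. The circular shift matches the description above, so samples that leave one end of the window wrap around to the other.

```python
import numpy as np

def analytic_envelope(x):
    """Amplitude envelope along the last axis via an FFT-based analytic signal."""
    n = x.shape[-1]
    spectrum = np.fft.fft(x, axis=-1)
    gain = np.zeros(n)
    gain[0] = 1.0                      # keep DC
    if n % 2 == 0:
        gain[n // 2] = 1.0             # keep Nyquist
        gain[1:n // 2] = 2.0           # double positive frequencies
    else:
        gain[1:(n + 1) // 2] = 2.0
    return np.abs(np.fft.ifft(spectrum * gain, axis=-1))

def align_to_center(X):
    """Circularly shift each row so its envelope peak lands on the center sample."""
    n = X.shape[1]
    peaks = np.argmax(analytic_envelope(X), axis=1)
    shifts = n // 2 - peaks
    return np.stack([np.roll(row, s) for row, s in zip(X, shifts)])

n = 256
t = np.arange(n) / n
base = np.sin(2 * np.pi * 13 * t) * np.exp(-((t - 0.5) ** 2) / (2 * 0.1**2))

# The same burst, displaced by a different number of samples in each event
offsets = np.array([-30, -5, 0, 12, 40])
X = np.stack([np.roll(base, k) for k in offsets])
X_aligned = align_to_center(X)
```

After alignment every row coincides with the undisplaced burst, which is the situation plain SVD handles best.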

Validation

Tools for checking decompositions against ground truth and for estimating component reliability.

from subwave.validation import (
    synthetic_population,   # generate events from known templates with noise / jitter / amplitude variability
    recovery_score,         # cosine similarity between true and recovered templates (Hungarian-matched)
    cluster_recovery_ari,   # Adjusted Rand Index between true and recovered cluster labels
    bootstrap_stability,    # mean cosine similarity of bootstrap-resampled templates to full-data templates
    split_half,             # template & loading reproducibility on odd/even split
)

# Ground-truth recovery on synthetic data
truth = synthetic_population(templates, n_events=200, noise_std=0.1, random_state=0)
result = sw.decompose(truth["X"], method="svd", n_components=truth["templates"].shape[0])
score = recovery_score(truth["templates"], result.templates)
score["mean_score"]            # cosine similarity (1.0 = perfect)

# Reliability of a real decomposition
boot = bootstrap_stability(X, n_components=3, n_boot=100, random_state=0)
boot["stability_scores"]       # per-component mean similarity across bootstraps

sh = split_half(X, n_components=3, random_state=0)
sh["template_similarity"]      # per-component template reproducibility
sh["loading_correlation"]      # correlation of projected vs native loadings
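
Matched cosine similarity, the quantity behind the recovery score, can be sketched as follows. To stay dependency-free this version searches all permutations instead of running the Hungarian algorithm, so it is only practical for a handful of components. The absolute value, which ignores the sign ambiguity of SVD templates, is an assumption, and `match_and_score` is a hypothetical name.

```python
import itertools
import numpy as np

def match_and_score(true_templates, recovered_templates):
    """Best one-to-one matching of templates by absolute cosine similarity."""
    T = true_templates / np.linalg.norm(true_templates, axis=1, keepdims=True)
    R = recovered_templates / np.linalg.norm(recovered_templates, axis=1, keepdims=True)
    sim = np.abs(T @ R.T)                         # (n_true, n_recovered)
    k = sim.shape[0]
    rows = np.arange(k)
    # Exhaustive search over assignments (fine for small k)
    best = max(itertools.permutations(range(k)),
               key=lambda perm: sim[rows, list(perm)].sum())
    scores = sim[rows, list(best)]
    return {"matching": list(best), "scores": scores, "mean_score": float(scores.mean())}

rng = np.random.default_rng(0)
t = np.arange(256) / 256
true = np.stack([np.sin(2 * np.pi * f * t) for f in (13, 11, 9)])

# "Recovered" templates: reordered, one sign-flipped, slightly noisy
recovered = np.stack([-true[2], true[0], true[1]]) + rng.normal(0, 0.01, size=true.shape)

score = match_and_score(true, recovered)
```

Here `score["matching"][i]` is the index of the recovered template paired with true template `i`.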

Serialization

result.save("result.npz")
result = sw.load_result("result.npz")
df = result.to_dataframe()              # flat DataFrame for R/Stata

Plots

result.plot_spectrum()                  # singular value scree plot
result.plot_templates(n=5)              # basis waveforms
result.plot_template_spectra(sfreq=256) # power spectrum of each template
result.plot_scatter(x=0, y=1)           # component 0 vs 1 (supports color= for clusters)
result.plot_heatmap(comp=0)             # events × samples sorted by loading
result.plot_waterfall(n=100)            # overlaid waveforms with bold mean
result.plot_mean_pm(comp=0)             # mean ± component
result.plot_sorted_grid(comp=0)         # events sorted by score
result.plot_residual_hist()             # reconstruction error distribution
result.plot_cumulative_variance()       # cumulative EVR curve
result.plot_reconstruction(event_idx=0) # original vs reconstruction
result.plot_loadings_by_group(groups)   # box/violin by group
result.plot_loadings_over_time(times)   # loading drift across time

The docstrings (for example, help(sw.decompose)) list the full options. Other available functions include cluster_sweep, loading_test, subspace_angles, scatter_colored_by, and loadings_correlated_with, among others.

Citation

If you use subwave in published work, please cite it using the entry in CITATION.cff in the repository.

License

MIT
