Frequency-domain model explanation (IG) package

Project description

freqIG

Overview

This repository contains the implementation of freqIG, a method based on the FLEX (Frequency Layer Explanation) principle [1], designed to explain the predictions of deep neural networks (DNNs) on time-series classification tasks. freqIG combines Integrated Gradients (IG) with a frequency-domain transform, the real fast Fourier transform (RFFT), to produce frequency-based attribution scores.

The method is generally useful for understanding how different frequency components of a time-series input influence the predictions of a DNN, thus enhancing model interpretability.

For details on the general concept, see [1]: "Using EEG Frequency Attributions to Explain the Classifications of a Deep Neural Network for Sleep Staging" (Paul Gräve et al.).


Features

  • RFFT Transformation: Input time-series data are transformed into the frequency domain using the RFFT.
  • iRFFT Transformation: The inverse RFFT (iRFFT) is implemented as the first layer in the DNN to process frequency-domain inputs.
  • Integrated Gradients Attribution: Captum's IG method is used to compute relevance scores for frequency bands, providing insights into the features contributing to the model's predictions.
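The first two features form a lossless transform pair; a minimal sketch using PyTorch's FFT API (the length-10 signal is an arbitrary illustration, not taken from the package):

```python
import torch

# A real signal of length n has n // 2 + 1 RFFT coefficients.
x = torch.arange(10, dtype=torch.float32)
X = torch.fft.rfft(x)               # complex frequency-domain representation
x_back = torch.fft.irfft(X, n=10)   # inverse transform back to the time domain

print(X.shape)                      # torch.Size([6])
assert torch.allclose(x, x_back, atol=1e-5)
```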

Definition (FLEX principle)

Let $F$ be our model (a DNN) and $x$ our input (time-series data). With $\bar{F} = F \circ iRFFT$ and $\bar{x} = RFFT(x)$, we get
$$FLEX_i(F,x) = IG_i(\bar{F},\bar{x}),$$
where $FLEX(F,x) = (FLEX_1(F,x), \ldots, FLEX_n(F,x))$ for $x \in \mathbb{R}^n$.


Installation

Requirements

  • Python 3.8+
  • Required libraries:
    • numpy
    • torch
    • captum

Install Dependencies

You can install the required Python libraries using pip:

pip install numpy torch captum

Documentation

freqIG.attribute

Compute frequency-based attribution scores for a model predicting on time-series data.

freqIG.attribute(
    input: Union[np.ndarray, list, torch.Tensor],
    model: Any,
    target: Optional[int] = None,
    baseline: Optional[Union[np.ndarray, list, torch.Tensor]] = None,
    n_steps: int = 50,
    segment: Optional[Union[np.ndarray, list, torch.Tensor]] = None,
    start_idx: Optional[int] = None,
    additional_forward_args: Optional[Any] = None
) -> np.ndarray

Parameters

  • input : array-like or torch.Tensor
    The input time-series data.

  • model : callable
    The (frequency-domain) model to explain.

  • target : int, optional
    Index of the class to explain. If None, explains the model's predicted class.

  • baseline : array-like or torch.Tensor, optional
    Baseline input for Integrated Gradients. Defaults to zero input.

  • n_steps : int, default=50
    Number of steps in the IG path.

  • segment : array-like or torch.Tensor, optional
    Segment of the input for localized attribution.

  • start_idx : int, optional
    Start index of the segment within the original input.

  • additional_forward_args : Any, optional
    Additional arguments passed to the model during attribution.

Returns

  • np.ndarray
    Array containing the frequency attribution scores.

Raises

  • ValueError
    If segment is provided but start_idx is missing, or if the segment exceeds the bounds of the input.
  • ValueError
    If baseline is provided but its shape does not match the input.
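The two error conditions can be mirrored by a small validation routine. The helper below is a hypothetical re-implementation for illustration, not the package's actual internals:

```python
import numpy as np

def _validate(input, baseline=None, segment=None, start_idx=None):
    """Illustrative sketch of the checks described above."""
    input = np.asarray(input)
    if segment is not None:
        if start_idx is None:
            raise ValueError("start_idx is required when segment is given")
        if start_idx + np.asarray(segment).shape[-1] > input.shape[-1]:
            raise ValueError("segment exceeds the bounds of the input")
    if baseline is not None and np.asarray(baseline).shape != input.shape:
        raise ValueError("baseline shape must match the input shape")

_validate(np.zeros(10), segment=np.zeros(4), start_idx=6)  # OK: 6 + 4 <= 10
```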

Notes

This function applies Integrated Gradients in the frequency domain to provide frequency-wise attributions for any model acting on time-series data, following the FLEX [1] principle.

References

[1] Using EEG Frequency Attributions to Explain the Classifications of a Deep Neural Network for Sleep Staging
P. Gräve, T. Steinbrinker, F. Ehrlich, P. Hempel, P. Zaschke, D. Krefting, N. Spicher; 2025.

Examples

import numpy as np
import torch

from freqIG import attribute

# Generate dummy time-series data: 20 samples, each of length 10
np.random.seed(0)
n_samples = 20
n_features = 10
X = np.linspace(0, 1, n_features)[None, :] + 0.1 * np.random.randn(n_samples, n_features)
y = np.random.randint(0, 2, size=(n_samples,))  # Binary classification

X_torch = torch.tensor(X, dtype=torch.float32)
y_torch = torch.tensor(y, dtype=torch.long)

# Simple frequency-domain model (2-class classifier)
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(n_features, 2)
    def forward(self, x):
        return self.fc(x)

model = SimpleModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# VERY simple training loop (just for demonstration!)
model.train()
for epoch in range(30):
    optimizer.zero_grad()
    outputs = model(X_torch)
    loss = criterion(outputs, y_torch)
    loss.backward()
    optimizer.step()
model.eval()

# Pick one sample for explanation
sample = X[0:1]

# Run freqIG.attribute to get attributions
attr_scores = attribute(
    input=sample,    # shape (1, 10)
    model=model,
    target=0,        # Explain class 0
    n_steps=30
)

print("Attribution scores:", attr_scores)
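To interpret the scores, note that each index corresponds to an RFFT frequency bin; assuming a sampling rate, the bins can be labelled with numpy.fft.rfftfreq (a generic NumPy utility, not part of the freqIG API, and fs = 100 Hz is an assumed value):

```python
import numpy as np

fs = 100.0                              # assumed sampling rate in Hz
n = 10                                  # signal length, as in the example above
freqs = np.fft.rfftfreq(n, d=1.0 / fs)  # frequency of each RFFT bin
print(freqs)                            # [ 0. 10. 20. 30. 40. 50.]
```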


Download files

Download the file for your platform.

Source Distribution

freqig-0.1.1.tar.gz (6.1 kB)


Built Distribution


freqig-0.1.1-py3-none-any.whl (6.6 kB)


File details

Details for the file freqig-0.1.1.tar.gz.

File metadata

  • Download URL: freqig-0.1.1.tar.gz
  • Size: 6.1 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.5

File hashes

Hashes for freqig-0.1.1.tar.gz

  • SHA256: a2c3bc74b9413b4021b9a2c8e78d7757082dce81d92ba9e3d65b5bb9f777e600
  • MD5: 633b01e4fac586b7ab1ae1d6780cebec
  • BLAKE2b-256: 2057297820dce2c76fbf0740c5dfabc18b5c5032c3f297d9130dc82fb91cd82f


File details

Details for the file freqig-0.1.1-py3-none-any.whl.

File metadata

  • Download URL: freqig-0.1.1-py3-none-any.whl
  • Size: 6.6 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.13.5

File hashes

Hashes for freqig-0.1.1-py3-none-any.whl

  • SHA256: 8fb4f349f6a3c24d45d8bdba09724b4ac49a115de284c26e77169026126e70c1
  • MD5: 527a1ab4ae9f4593b050afc532b67db2
  • BLAKE2b-256: a09981274f240b00b0b1f04ced598732d4988f11c1ae1a96368a7aba60e5f85b

