
ENGAGE: Explainable Neural GP with Active Guided Exploration — human-guided, interpretable deep kernel Gaussian processes with active learning


ENGAGE

Explainable Neural GP with Active Guided Exploration

ENGAGE combines deep neural networks with Gaussian Processes for regression, classification, and preference learning — designed around three core principles:

  • Human-guided — pairwise preference learning lets users steer the model through comparisons rather than numeric labels; Bayesian optimization acquisition functions support active experimental design
  • Explainable — attention-based extractors expose what the model focuses on; per-feature importance maps and head×head attention matrices are first-class outputs, not afterthoughts
  • Active learning — uncertainty estimates from the GP drive principled query selection; sample weighting identifies unreliable measurements automatically

The neural network learns a compressed feature representation; the GP is fit on top for calibrated uncertainty-aware predictions.
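The two-stage idea above can be illustrated without the library. The following is a minimal NumPy sketch, not ENGAGE's implementation: the helper names (`extract_features`, `rbf_kernel`, `gp_posterior`), the fixed random network weights, and the RBF kernel with unit lengthscale are all assumptions made for illustration. In an actual deep kernel model the network weights would be learned during training rather than drawn at random.

```python
import numpy as np

def rbf_kernel(a, b, lengthscale=1.0):
    """Squared-exponential kernel between the rows of a and the rows of b."""
    sq_dist = (np.sum(a**2, axis=1)[:, None]
               + np.sum(b**2, axis=1)[None, :]
               - 2.0 * a @ b.T)
    return np.exp(-0.5 * np.maximum(sq_dist, 0.0) / lengthscale**2)

def extract_features(x, w1, w2):
    """Stand-in for a feature extractor: one tanh hidden layer, then a linear map."""
    return np.tanh(x @ w1) @ w2

def gp_posterior(z_train, y_train, z_test, noise=1e-2):
    """Exact GP regression posterior (mean, std) computed in feature space."""
    k_tt = rbf_kernel(z_train, z_train) + noise * np.eye(len(z_train))
    k_st = rbf_kernel(z_test, z_train)
    chol = np.linalg.cholesky(k_tt)
    alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, y_train))
    mean = k_st @ alpha
    v = np.linalg.solve(chol, k_st.T)
    var = 1.0 - np.sum(v**2, axis=0)  # prior variance k(z, z) is 1 for this kernel
    return mean, np.sqrt(np.maximum(var, 0.0))

rng = np.random.default_rng(0)
x_train = rng.normal(size=(40, 10))
y_train = x_train[:, 0] + 0.05 * rng.normal(size=40)
x_test = rng.normal(size=(5, 10))

# Hypothetical, untrained weights: 10 inputs -> 32 hidden units -> 4 features
w1 = rng.normal(size=(10, 32)) / np.sqrt(10)
w2 = rng.normal(size=(32, 4)) / np.sqrt(32)

mean, std = gp_posterior(extract_features(x_train, w1, w2), y_train,
                         extract_features(x_test, w1, w2))
```

The GP only ever sees the 4-dimensional features, which is why the quality of the learned representation determines how useful the uncertainty estimates are.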

Version: 0.2.0
Install: pip install engage-gp
Import: import engagegp

engage-gp/   (import as: import engagegp)
engagegp/
├── models.py           # Feature extractors
├── gpr.py              # GP Regression + Bayesian optimization
├── gpc.py              # GP Classification
├── gppw.py             # GP Pairwise (preference learning)
├── sample_weighting.py # Learnable sample weights
├── utils.py            # Attention analysis, model I/O, utilities
└── __init__.py         # Package exports

Feature Extractors

Available Extractors

Type                Class                       Best for
fc                  FCFeatureExtractor          Quick prototyping, small datasets
fcbn                FCBNFeatureExtractor        General use (default), regularized
resnet              ResNetFeatureExtractor      Deeper networks, gradient stability
attention           AttentionFeatureExtractor   Feature interactions, relational data
direct_attention    DirectAttentionExtractor    Attention on raw inputs (e.g., spectroscopy wavelengths)
attention_weighted  AttentionWeightedExtractor  Interpretable feature importance weights
custom              any nn.Module               User-provided architecture

Attention Extractors: Key Differences

All three attention-based extractors use the self-attention mechanism differently. Choosing the right one depends on what you want to learn and interpret.

Aspect                attention                                     direct_attention                          attention_weighted
What attends to what  Heads attend to each other (in hidden space)  Raw input features attend to each other   Each input feature gets a scalar importance score
Attention map shape   (batch, num_heads, num_heads)                 (batch, num_heads, input_dim, input_dim)  (batch, input_dim)
Projection            Input → hidden → multi-head                   Input directly split into heads           Input → attention scores (no projection)
Interpretation        Which learned perspectives are correlated     Which input features interact with which  Which input features matter most
Computational cost    Low (heads × heads)                           High (input_dim²)                         Low
Best for              General relational data                       Spectroscopy, signals, spatial data       Feature selection, interpretability

attention (AttentionFeatureExtractor): Projects the input into a hidden space, then applies multi-head self-attention. The attention is computed between the num_heads learned projections of the input, not between input features directly. The resulting map, of shape (batch, num_heads, num_heads), tells you how the model's internal "perspectives" relate to each other. A good general-purpose choice when you want attention-based representation learning without the cost of input×input maps.
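To make the head-to-head map concrete, here is a minimal NumPy sketch of attention computed between heads rather than between input features. It is an illustration under stated assumptions, not the AttentionFeatureExtractor code: `head_attention_map` and `w_proj` are hypothetical names, and the sketch uses the same vectors as queries and keys instead of separate learned projections.

```python
import numpy as np

def softmax(a, axis=-1):
    """Numerically stable softmax along the given axis."""
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def head_attention_map(x, w_proj, num_heads):
    """Project to hidden space, split into num_heads tokens, attend across tokens."""
    batch = x.shape[0]
    hidden = x @ w_proj                            # (batch, hidden_dim)
    tokens = hidden.reshape(batch, num_heads, -1)  # (batch, num_heads, head_dim)
    head_dim = tokens.shape[-1]
    scores = tokens @ tokens.transpose(0, 2, 1) / np.sqrt(head_dim)
    return softmax(scores, axis=-1)                # (batch, num_heads, num_heads)

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 100))                      # 8 samples, input_dim=100
w_proj = rng.normal(size=(100, 128)) / 10.0        # hypothetical projection, hidden_dim=128
attn = head_attention_map(x, w_proj, num_heads=4)
```

The map has only num_heads × num_heads entries per sample, which is why its cost does not grow with input_dim.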

direct_attention (DirectAttentionExtractor): Applies attention directly to the raw input features, with no projection into hidden space first. Each head computes an input_dim × input_dim attention map, capturing which input positions (e.g., wavelengths, pixels, time steps) attend to which others. The map, of shape (batch, num_heads, input_dim, input_dim), is interpretable as a feature-to-feature relationship matrix. Use this when the spatial or sequential structure of the input itself is meaningful.
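One way to obtain a feature-to-feature map of this shape is sketched below in NumPy. This is a hypothetical construction for illustration, not the DirectAttentionExtractor implementation: each scalar input feature is scaled by a small per-head query vector and key vector (`w_q`, `w_k`, both assumed names), and attention is taken over input positions.

```python
import numpy as np

def softmax(a, axis=-1):
    """Numerically stable softmax along the given axis."""
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def direct_attention_map(x, w_q, w_k):
    """x: (batch, input_dim); w_q, w_k: (num_heads, d). Returns one map per head."""
    q = x[:, None, :, None] * w_q[None, :, None, :]  # (batch, heads, input_dim, d)
    k = x[:, None, :, None] * w_k[None, :, None, :]  # (batch, heads, input_dim, d)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(w_q.shape[1])
    return softmax(scores, axis=-1)                  # (batch, heads, input_dim, input_dim)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 16))      # 2 samples, 16 input features (e.g., wavelengths)
w_q = rng.normal(size=(4, 8))     # 4 heads, embedding size 8 (illustrative values)
w_k = rng.normal(size=(4, 8))
attn = direct_attention_map(x, w_q, w_k)
```

Entry `attn[b, h, i, j]` reads as how strongly feature i attends to feature j in head h for sample b. The map has input_dim² entries per head, which explains the higher cost for wide inputs.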

attention_weighted (AttentionWeightedExtractor): Does not compute pairwise attention between features. Instead, it learns a single importance weight per input feature, acting as a soft feature selection mask. The weights have shape (batch, input_dim), sum to 1.0 across features for each sample, and are applied element-wise before the input passes through a base extractor (default: fcbn). This is the most interpretable option: you can directly see which input dimensions the model relies on. Use this when you want to know which features matter, not how they interact.
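The soft feature mask can be sketched in a few lines of NumPy. This is an illustrative sketch only: the scoring step (a single linear map `w_score`) and the function name `feature_importance` are assumptions, and the library may compute its scores differently. What the sketch does show is the softmax normalization that makes the weights sum to 1.0 and the element-wise reweighting of the input.

```python
import numpy as np

def softmax(a, axis=-1):
    """Numerically stable softmax along the given axis."""
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def feature_importance(x, w_score):
    """Return per-feature weights (batch, input_dim) and the reweighted input."""
    scores = x @ w_score                 # (batch, input_dim), hypothetical scoring layer
    weights = softmax(scores, axis=-1)   # each row sums to 1.0
    return weights, x * weights          # element-wise soft feature selection

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 12))                         # 4 samples, 12 input features
w_score = rng.normal(size=(12, 12)) / np.sqrt(12)
weights, x_weighted = feature_importance(x, w_score)

# Indices of the three most important features for the first sample
top_features = np.argsort(weights[0])[::-1][:3]
```

The reweighted input `x_weighted` is what a base extractor would receive in this sketch.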


Factory Function

from engagegp import get_feature_extractor

# Simple FC
extractor = get_feature_extractor('fc', input_dim=100, feature_dim=16)

# FC + BatchNorm (recommended default)
extractor = get_feature_extractor('fcbn', input_dim=100, feature_dim=16,
                                  hidden_dims=[512, 256, 128], dropout=0.3)

# ResNet
extractor = get_feature_extractor('resnet', input_dim=100, feature_dim=16,
                                  hidden_dim=256, num_blocks=3)

# Self-attention
extractor = get_feature_extractor('attention', input_dim=100, feature_dim=16,
                                  hidden_dim=128, num_heads=4)

# Direct attention on raw features (e.g., 256 wavelengths → wavelength-to-wavelength map)
extractor = get_feature_extractor('direct_attention', input_dim=256, feature_dim=16,
                                  num_heads=4)

# Attention-weighted (interpretable per-feature importance scores)
extractor = get_feature_extractor('attention_weighted', input_dim=256, feature_dim=16,
                                  base_extractor='fcbn')

# Custom nn.Module
import torch.nn as nn
my_net = nn.Sequential(nn.Linear(100, 64), nn.ReLU(), nn.Linear(64, 16))
extractor = get_feature_extractor('custom', custom_extractor=my_net)

Regression

from engagegp import fit_dkgp, predict_dkgpr

# Fit (default: FC + BatchNorm extractor)
mll, gp, dkl, losses = fit_dkgp(X_train, y_train, feature_dim=16)

# Predict
mean, std = predict_dkgpr(dkl, X_test, return_std=True)

With confidence weights (heteroscedastic data)

import numpy as np

# One weight per training sample; "..." stands for the remaining entries
weights = np.array([1.0, 0.5, 1.0, ...])  # Lower weight for noisy samples
mll, gp, dkl, losses = fit_dkgp(X_train, y_train, confidence_weights=weights)

Full example

import numpy as np
from engagegp import fit_dkgp, predict_dkgpr

np.random.seed(42)
X_train = np.random.randn(200, 50)
y_train = np.sum(X_train[:, :5], axis=1) + 0.1 * np.random.randn(200)
X_test  = np.random.randn(50, 50)

mll, gp, dkl, losses = fit_dkgp(
    X_train, y_train,
    feature_dim=16,
    extractor_type='resnet',
    extractor_kwargs={'hidden_dim': 128, 'num_blocks': 2},
    num_epochs=1000,
    lr_features=1e-4,
    lr_gp=1e-2
)

mean, std = predict_dkgpr(dkl, X_test, return_std=True)

Classification

from engagegp import fit_dkgp_classifier, predict_classifier

# Fit (num_classes auto-detected if not specified)
model, losses = fit_dkgp_classifier(X_train, y_train, num_classes=4)

# Predict labels
y_pred = predict_classifier(model, X_test)

# Predict probabilities
y_proba = predict_classifier(model, X_test, return_proba=True)

# Batched prediction
y_pred = predict_classifier(model, X_test, batch_size=256)

With confidence weights

weights = np.array([1.0, 0.8, 1.0, 0.5, ...])  # One weight per training sample; "..." elides the rest
model, losses = fit_dkgp_classifier(X_train, y_train, confidence_weights=weights)

Bayesian Optimization

After fitting a regression model, use acquisition functions to guide optimization:

from engagegp import (
    expected_improvement,
    upper_confidence_bound,
    probability_of_improvement,
    thompson_sampling,
    expected_improvement_with_constraints,
)

# Expected Improvement
ei = expected_improvement(dkl, X_candidates, best_observed=y_train.max())

# Upper Confidence Bound
ucb = upper_confidence_bound(dkl, X_candidates, beta=2.0)

# Probability of Improvement
pi = probability_of_improvement(dkl, X_candidates, best_observed=y_train.max())

# Thompson Sampling
sample = thompson_sampling(dkl, X_candidates)

# EI with constraints
ei_c = expected_improvement_with_constraints(dkl, X_candidates, constraint_models=[...])
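For reference, the standard closed-form expressions behind three of these acquisition functions can be written directly from a posterior mean and standard deviation. The sketch below is a plain NumPy version for a maximization problem. The function names are hypothetical and are not the library's signatures, and the UCB convention used here (mean + beta × std) is an assumption, since some libraries scale by the square root of beta instead.

```python
import math
import numpy as np

def normal_pdf(z):
    """Standard normal density."""
    return np.exp(-0.5 * z**2) / math.sqrt(2.0 * math.pi)

def normal_cdf(z):
    """Standard normal cumulative distribution, via the error function."""
    return 0.5 * (1.0 + np.vectorize(math.erf)(z / math.sqrt(2.0)))

def ei_from_posterior(mean, std, best_observed):
    """Expected improvement over best_observed (maximization)."""
    std = np.maximum(std, 1e-12)  # guard against division by zero
    z = (mean - best_observed) / std
    return (mean - best_observed) * normal_cdf(z) + std * normal_pdf(z)

def pi_from_posterior(mean, std, best_observed):
    """Probability that a candidate exceeds best_observed."""
    std = np.maximum(std, 1e-12)
    return normal_cdf((mean - best_observed) / std)

def ucb_from_posterior(mean, std, beta=2.0):
    """Upper confidence bound; larger beta favors exploration."""
    return mean + beta * std

mean = np.array([0.0, 1.0, 2.0])
std = np.array([1.0, 1.0, 1.0])
ei_vals = ei_from_posterior(mean, std, best_observed=1.0)
pi_vals = pi_from_posterior(mean, std, best_observed=1.0)
ucb_vals = ucb_from_posterior(mean, std, beta=2.0)
next_idx = int(np.argmax(ei_vals))  # candidate to evaluate next under EI
```

All three trade off a high predicted mean against high uncertainty; they differ in how strongly uncertainty is rewarded.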

Pairwise GP (Preference Learning)

Learn from pairwise comparisons (A preferred over B) rather than absolute values:

from engagegp import (
    fit_dkgppw,
    predict_utility,
    dkgppw_eubo,   # Expected Utility of Best Option acquisition
    dkgppw_ucb,    # UCB acquisition
    acquire_preference,
    sample_comparison_pairs,
    get_simulated_preference,
)

# Fit from comparison pairs
model = fit_dkgppw(X, comparisons)  # comparisons: (n_pairs, 2) array of [winner_idx, loser_idx]

# Predict utility scores
utility_mean, utility_std = predict_utility(model, X_candidates)

# Acquire next pair to compare (active preference learning)
next_pair = acquire_preference(model, X_candidates, method='eubo')
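The comparison format described above, an (n_pairs, 2) array of [winner_idx, loser_idx], can be produced for testing from a known utility function. The sketch below is a self-contained NumPy illustration. `simulate_comparisons` is a hypothetical helper, not the library's sample_comparison_pairs or get_simulated_preference, and it assumes noise-free preferences where the higher-utility item always wins.

```python
import numpy as np

def simulate_comparisons(utility, n_pairs, rng):
    """Draw random index pairs and order each as [winner_idx, loser_idx]."""
    n = len(utility)
    pairs = np.empty((n_pairs, 2), dtype=int)
    for row in range(n_pairs):
        i, j = rng.choice(n, size=2, replace=False)  # two distinct items
        pairs[row] = (i, j) if utility[i] >= utility[j] else (j, i)
    return pairs

rng = np.random.default_rng(0)
X = rng.normal(size=(30, 5))        # 30 candidate items, 5 features each
true_utility = -np.sum(X**2, axis=1)  # hidden utility: items near the origin are preferred
comparisons = simulate_comparisons(true_utility, n_pairs=50, rng=rng)
```

An array built this way has the layout that the `comparisons` argument is documented to expect, so it can stand in for human judgments during development.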

Sample Weighting

Learnable per-sample weights for robust training against noisy labels:

from engagegp import SampleWeightModule, analyze_sample_weights

# Create weight module
weight_module = SampleWeightModule(n_samples=len(X_train))

# Get learned weights after training
weights = weight_module.get_weights().detach().numpy()  # shape (n_samples,)

# Analyze which samples are noisy
# train_preds: model predictions on X_train, computed beforehand (not shown here)
analysis = analyze_sample_weights(
    sample_weights=weights,
    y_train=y_train,
    predictions=train_preds,
    threshold=0.5
)
# Returns: {'noisy_indices', 'clean_indices', 'weight_stats', 'noisy_stats'}

Use SampleWeightedMLL in gpr.py to incorporate learned sample weights into training.
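The thresholding step behind this kind of analysis can be sketched as follows. This is an assumption-laden illustration rather than the analyze_sample_weights implementation: the helper name `split_by_weight` is hypothetical, and it assumes that samples whose learned weight falls below the threshold are the ones flagged as noisy.

```python
import numpy as np

def split_by_weight(sample_weights, threshold=0.5):
    """Partition sample indices by whether their learned weight is below the threshold."""
    w = np.asarray(sample_weights, dtype=float)
    return {
        "noisy_indices": np.flatnonzero(w < threshold),
        "clean_indices": np.flatnonzero(w >= threshold),
        "weight_stats": {"mean": float(w.mean()), "min": float(w.min()),
                         "max": float(w.max())},
    }

learned = np.array([1.0, 0.2, 0.9, 0.4])  # illustrative weights, not real output
result = split_by_weight(learned, threshold=0.5)
```

Inspecting the flagged indices against the raw measurements is a quick check on whether the down-weighted samples are genuinely unreliable.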


Attention Analysis

Inspect what the model has learned to focus on:

from engagegp import (
    get_attention_scores,
    get_attention_for_sample,
    analyze_attention_locality,
    summarize_attention,
)

# Attention scores for all training data
scores = get_attention_scores(model, X_train)

# Attention for a single sample
attn = get_attention_for_sample(model, X_train[0], average_heads=True)

# Analyze how localized the attention is
locality = analyze_attention_locality(scores)

# Full summary report
summarize_attention(model, X_train, sample_idx=0)

For DirectAttentionExtractor — direct feature-to-feature maps:

attention_map = extractor.get_attention_map(x)
# shape: (num_heads, input_dim, input_dim)
# e.g., for spectroscopy: which wavelengths attend to which

For AttentionWeightedExtractor — per-feature importance scores:

weights = extractor.get_attention_weights(x)
# shape: (input_dim,), sums to 1.0
top_features = np.argsort(weights)[-10:][::-1]

Model Persistence

from engagegp import save_model, load_model

save_model(dkl, 'my_model.pt')
dkl = load_model('my_model.pt')

Custom Hybrid Extractor

import torch
import torch.nn as nn
from engagegp import get_feature_extractor, fit_dkgp

class HybridExtractor(nn.Module):
    def __init__(self, input_dim, feature_dim):
        super().__init__()
        self.resnet   = get_feature_extractor('resnet',   input_dim, feature_dim // 2)
        self.attention = get_feature_extractor('attention', input_dim, feature_dim // 2)
        self.input_dim  = input_dim
        self.feature_dim = feature_dim

    def forward(self, x):
        return torch.cat([self.resnet(x), self.attention(x)], dim=-1)

hybrid = HybridExtractor(input_dim=100, feature_dim=16)
# X: inputs of shape (n_samples, 100); y: regression targets of shape (n_samples,)
mll, gp, dkl, losses = fit_dkgp(
    X, y,
    extractor_type='custom',
    extractor_kwargs={'custom_extractor': hybrid}
)

API Reference

Regression (gpr.py)

Symbol Description
DeepKernelGP Main regression model class
ConfidenceWeightedMLL Weighted MLL for heteroscedastic data
SampleWeightedMLL MLL with learnable per-sample weights
train_dkgp() Low-level training function
fit_dkgp() High-level training interface
predict_dkgpr() Prediction (mean, std, full distribution)
expected_improvement() EI acquisition function
upper_confidence_bound() UCB acquisition function
probability_of_improvement() PI acquisition function
thompson_sampling() Thompson sampling acquisition
expected_improvement_with_constraints() Constrained EI

Classification (gpc.py)

Symbol Description
DeepKernelGPClassifier Main classification model class
BinaryGPClassificationModel Variational GP for binary classification
MultiClassGPClassificationModel Variational GP for multi-class
ConfidenceWeightedELBO Weighted ELBO loss
train_dkgp_classifier() Low-level training function
fit_dkgp_classifier() High-level training interface
predict_classifier() Prediction (labels or probabilities)

Pairwise GP (gppw.py)

Symbol Description
DeepKernelPairwiseGP Main pairwise GP model class
fit_dkgppw() High-level training interface
train_dkgppw() Low-level training function
predict_utility() Predict utility scores
dkgppw_eubo() EUBO acquisition function
dkgppw_ucb() UCB acquisition function
acquire_preference() Select next pair to compare
sample_comparison_pairs() Sample random pairs
get_simulated_preference() Simulate preferences from a true function
get_user_preference() Collect preference from user interactively
plot_option() Plot a candidate option
plot_predictions() Plot predicted utilities

Feature Extractors (models.py)

Symbol Description
FCFeatureExtractor Simple FC
FCBNFeatureExtractor FC + BatchNorm + Dropout (default)
ResNetFeatureExtractor ResNet-style with skip connections
AttentionFeatureExtractor Self-attention (attends in hidden space)
DirectAttentionExtractor Self-attention on raw input features
AttentionWeightedExtractor Per-feature importance weighting
get_feature_extractor() Factory function

Sample Weighting (sample_weighting.py)

Symbol Description
SampleWeightModule Learnable per-sample weight module
analyze_sample_weights() Identify noisy/outlier samples

Utilities (utils.py)

Symbol Description
get_attention_scores() Attention scores across dataset
get_attention_for_sample() Attention for a single input
analyze_attention_locality() Measure attention localization
summarize_attention() Print attention summary report
save_model() Save model to file
load_model() Load model from file
split_train_test() Train/test split utility
get_grid_coords() Grid coordinates for image data
get_subimages() Extract subimage patches

License

MIT
