ENGAGE: Explainable Neural GP with Active Guided Exploration — human-guided, interpretable deep kernel Gaussian processes with active learning
ENGAGE
Explainable Neural GP with Active Guided Exploration
ENGAGE combines deep neural networks with Gaussian Processes for regression, classification, and preference learning — designed around three core principles:
- Human-guided — pairwise preference learning lets users steer the model through comparisons rather than numeric labels; Bayesian optimization acquisition functions support active experimental design
- Explainable — attention-based extractors expose what the model focuses on; per-feature importance maps and head×head attention matrices are first-class outputs, not afterthoughts
- Active learning — uncertainty estimates from the GP drive principled query selection; sample weighting identifies unreliable measurements automatically
The neural network learns a compressed feature representation; the GP is fit on top for calibrated uncertainty-aware predictions.
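As a rough illustration of this two-stage idea, the sketch below uses plain NumPy rather than engagegp. A fixed random tanh layer stands in for the neural network, and an exact GP with an RBF kernel is evaluated on its outputs. The names `feature_map`, `rbf_kernel`, and `deep_kernel_gp_predict` are hypothetical and not part of the package, and the library trains its extractor, which this sketch does not.

```python
import numpy as np

def feature_map(X, W):
    """Stand-in for the neural feature extractor: a single fixed tanh layer."""
    return np.tanh(X @ W)

def rbf_kernel(A, B, lengthscale=1.0):
    """Squared-exponential kernel between the rows of A and the rows of B."""
    sq_dists = (A**2).sum(1)[:, None] + (B**2).sum(1)[None, :] - 2.0 * A @ B.T
    return np.exp(-0.5 * np.maximum(sq_dists, 0.0) / lengthscale**2)

def deep_kernel_gp_predict(X_train, y_train, X_test, W, noise=1e-2):
    """GP posterior mean and std, with the kernel applied to extracted features."""
    Z_train, Z_test = feature_map(X_train, W), feature_map(X_test, W)
    K = rbf_kernel(Z_train, Z_train) + noise * np.eye(len(Z_train))
    K_s = rbf_kernel(Z_test, Z_train)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))
    mean = K_s @ alpha
    v = np.linalg.solve(L, K_s.T)
    var = 1.0 - (v**2).sum(0)  # the RBF prior variance is 1 at every point
    return mean, np.sqrt(np.maximum(var, 0.0))

rng = np.random.default_rng(0)
X_train = rng.normal(size=(40, 10))
y_train = X_train[:, 0] + 0.05 * rng.normal(size=40)
X_test = rng.normal(size=(5, 10))
W = rng.normal(size=(10, 3)) / np.sqrt(10)  # 10 inputs compressed to 3 features
mean, std = deep_kernel_gp_predict(X_train, y_train, X_test, W)
```

The GP only ever sees the 3-dimensional features, which is what keeps the kernel computation cheap when the raw input is wide.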
Version: 0.2.0
Install: pip install engage-gp
Import: import engagegp
engage-gp/ (import as: import engagegp)
engagegp/
├── models.py # Feature extractors
├── gpr.py # GP Regression + Bayesian optimization
├── gpc.py # GP Classification
├── gppw.py # GP Pairwise (preference learning)
├── sample_weighting.py # Learnable sample weights
├── utils.py # Attention analysis, model I/O, utilities
└── __init__.py # Package exports
Feature Extractors
Available Extractors
| Type | Class | Best for |
|---|---|---|
| fc | FCFeatureExtractor | Quick prototyping, small datasets |
| fcbn | FCBNFeatureExtractor | General use (default), regularized |
| resnet | ResNetFeatureExtractor | Deeper networks, gradient stability |
| attention | AttentionFeatureExtractor | Feature interactions, relational data |
| direct_attention | DirectAttentionExtractor | Attention on raw inputs (e.g., spectroscopy wavelengths) |
| attention_weighted | AttentionWeightedExtractor | Interpretable feature importance weights |
| custom | any nn.Module | User-provided architecture |
Attention Extractors: Key Differences
The three attention-based extractors apply attention in different ways, so the right choice depends on what you want to learn and interpret.
| | attention | direct_attention | attention_weighted |
|---|---|---|---|
| What attends to what | Heads attend to each other (in hidden space) | Raw input features attend to each other | Each input feature gets a scalar importance score |
| Attention map shape | (batch, num_heads, num_heads) | (batch, num_heads, input_dim, input_dim) | (batch, input_dim) |
| Projection | Input → hidden → multi-head | Input directly split into heads | Input → attention scores (no projection) |
| Interpretation | Which learned perspectives are correlated | Which input features interact with which | Which input features matter most |
| Computational cost | Low (heads × heads) | High (input_dim²) | Low |
| Best for | General relational data | Spectroscopy, signals, spatial data | Feature selection, interpretability |
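To make the cost row concrete, the following sketch counts the entries in each attention map using the shapes listed in the table. The helper `attention_map_shapes` and the example sizes (batch of 32, 4 heads, 256 input features) are illustrative assumptions, not part of engagegp.

```python
from math import prod

def attention_map_shapes(batch, num_heads, input_dim):
    """Attention map shapes as listed in the comparison table."""
    return {
        "attention": (batch, num_heads, num_heads),
        "direct_attention": (batch, num_heads, input_dim, input_dim),
        "attention_weighted": (batch, input_dim),
    }

shapes = attention_map_shapes(batch=32, num_heads=4, input_dim=256)
sizes = {name: prod(shape) for name, shape in shapes.items()}
for name, shape in shapes.items():
    print(f"{name:20s} shape={shape} entries={sizes[name]:,}")
```

With these sizes the direct_attention map holds 8,388,608 entries per batch, against 512 for attention and 8,192 for attention_weighted, which is why the quadratic input_dim term dominates for wide inputs.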
attention (AttentionFeatureExtractor)
Projects the input into a hidden space, then applies multi-head self-attention. The attention is computed between the num_heads learned projections of the input — not between input features directly. The resulting map (batch, num_heads, num_heads) tells you how the model's internal "perspectives" relate to each other. Good general-purpose choice when you want attention-based representation learning without the cost of input×input maps.
direct_attention (DirectAttentionExtractor)
Applies attention directly to the raw input features — no projection into hidden space first. Each head computes an input_dim × input_dim attention map, capturing which input positions (e.g., wavelengths, pixels, time steps) attend to which others. The map shape (batch, num_heads, input_dim, input_dim) is interpretable as a feature-to-feature relationship matrix. Use this when the spatial or sequential structure of the input itself is meaningful.
attention_weighted (AttentionWeightedExtractor)
Does not compute pairwise attention between features. Instead, it learns a single importance weight per input feature — a soft feature selection mask. The weights (batch, input_dim) sum to 1.0 and are applied element-wise before passing through a base extractor (default: fcbn). This is the most interpretable option: you can directly see which input dimensions the model relies on. Use this when you want to know which features matter, not how they interact.
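A minimal NumPy sketch of this soft feature mask is shown here. It assumes the importance scores come from a single linear layer followed by a softmax; `W_score` is a hypothetical random matrix standing in for learned parameters, and the library's actual scoring network may differ.

```python
import numpy as np

def softmax(scores, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = scores - scores.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 6))        # (batch, input_dim)
W_score = rng.normal(size=(6, 6))  # hypothetical scoring layer (assumption)

weights = softmax(x @ W_score)     # (batch, input_dim); each row sums to 1.0
x_weighted = weights * x           # element-wise mask, then passed to the base extractor

# Rank features by their average importance over the batch.
top_features = np.argsort(weights.mean(axis=0))[::-1][:3]
```

Because each row of `weights` is a probability vector, the per-feature values can be compared directly across features and averaged across samples.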
Factory Function
from engagegp import get_feature_extractor
# Simple FC
extractor = get_feature_extractor('fc', input_dim=100, feature_dim=16)
# FC + BatchNorm (recommended default)
extractor = get_feature_extractor('fcbn', input_dim=100, feature_dim=16,
                                  hidden_dims=[512, 256, 128], dropout=0.3)
# ResNet
extractor = get_feature_extractor('resnet', input_dim=100, feature_dim=16,
                                  hidden_dim=256, num_blocks=3)
# Self-attention
extractor = get_feature_extractor('attention', input_dim=100, feature_dim=16,
                                  hidden_dim=128, num_heads=4)
# Direct attention on raw features (e.g., 256 wavelengths → wavelength-to-wavelength map)
extractor = get_feature_extractor('direct_attention', input_dim=256, feature_dim=16,
                                  num_heads=4)
# Attention-weighted (interpretable per-feature importance scores)
extractor = get_feature_extractor('attention_weighted', input_dim=256, feature_dim=16,
                                  base_extractor='fcbn')
# Custom nn.Module
import torch.nn as nn
my_net = nn.Sequential(nn.Linear(100, 64), nn.ReLU(), nn.Linear(64, 16))
extractor = get_feature_extractor('custom', custom_extractor=my_net)
Regression
from engagegp import fit_dkgp, predict_dkgpr
# Fit (default: FC + BatchNorm extractor)
mll, gp, dkl, losses = fit_dkgp(X_train, y_train, feature_dim=16)
# Predict
mean, std = predict_dkgpr(dkl, X_test, return_std=True)
With confidence weights (heteroscedastic data)
weights = np.array([1.0, 0.5, 1.0, ...]) # Lower weight for noisy samples
mll, gp, dkl, losses = fit_dkgp(X_train, y_train, confidence_weights=weights)
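One way to build such a weight vector, sketched under the assumption that a per-sample noise estimate is available, is to scale inverse variances so the cleanest samples get weight 1.0. The `noise_std` array here is synthetic, and whether engagegp expects weights on this scale is not confirmed by the text above.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 200

# Synthetic per-sample noise estimates: most samples are clean, 20 are noisy.
noise_std = np.full(n_samples, 0.1)
noisy_idx = rng.choice(n_samples, size=20, replace=False)
noise_std[noisy_idx] = 0.4

# Inverse-variance weights, normalised so the least noisy samples get 1.0.
weights = (noise_std.min() / noise_std) ** 2
```

The result has one entry per training sample, with 1.0 for the clean samples and 0.0625 for the noisy ones.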
Full example
import numpy as np
from engagegp import fit_dkgp, predict_dkgpr
np.random.seed(42)
X_train = np.random.randn(200, 50)
y_train = np.sum(X_train[:, :5], axis=1) + 0.1 * np.random.randn(200)
X_test = np.random.randn(50, 50)
mll, gp, dkl, losses = fit_dkgp(
    X_train, y_train,
    feature_dim=16,
    extractor_type='resnet',
    extractor_kwargs={'hidden_dim': 128, 'num_blocks': 2},
    num_epochs=1000,
    lr_features=1e-4,
    lr_gp=1e-2
)
mean, std = predict_dkgpr(dkl, X_test, return_std=True)
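A predictive mean and standard deviation can be turned into interval estimates and an uncertainty-driven query choice, as in this sketch. It assumes a Gaussian predictive distribution and uses small stand-in arrays in place of real `predict_dkgpr` output; `summarize_predictions` is a hypothetical helper, not a library function.

```python
import numpy as np

def summarize_predictions(mean, std, z=1.96):
    """Return approximate 95% interval bounds and the index of the most uncertain point."""
    lower, upper = mean - z * std, mean + z * std
    next_query = int(np.argmax(std))
    return lower, upper, next_query

# Stand-in values for illustration only.
mean = np.array([0.2, 1.5, -0.7, 0.9])
std = np.array([0.1, 0.6, 0.3, 0.05])

lower, upper, next_query = summarize_predictions(mean, std)
```

Picking the point with the largest standard deviation is the simplest active learning rule; the acquisition functions described under Bayesian Optimization trade uncertainty against predicted value instead.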
Classification
from engagegp import fit_dkgp_classifier, predict_classifier
# Fit (num_classes auto-detected if not specified)
model, losses = fit_dkgp_classifier(X_train, y_train, num_classes=4)
# Predict labels
y_pred = predict_classifier(model, X_test)
# Predict probabilities
y_proba = predict_classifier(model, X_test, return_proba=True)
# Batched prediction
y_pred = predict_classifier(model, X_test, batch_size=256)
With confidence weights
weights = np.array([1.0, 0.8, 1.0, 0.5, ...])
model, losses = fit_dkgp_classifier(X_train, y_train, confidence_weights=weights)
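Class probabilities can also drive query selection. The sketch below ranks samples by predictive entropy, assuming the probability output has shape (n_samples, num_classes). The `y_proba` values are stand-ins and `predictive_entropy` is a hypothetical helper rather than part of engagegp.

```python
import numpy as np

def predictive_entropy(proba, eps=1e-12):
    """Shannon entropy of each row of class probabilities (in nats)."""
    p = np.clip(proba, eps, 1.0)
    return -(p * np.log(p)).sum(axis=1)

# Stand-in probabilities for three samples and four classes.
y_proba = np.array([
    [0.97, 0.01, 0.01, 0.01],
    [0.25, 0.25, 0.25, 0.25],
    [0.60, 0.30, 0.05, 0.05],
])

entropy = predictive_entropy(y_proba)
labels = y_proba.argmax(axis=1)
query_order = np.argsort(entropy)[::-1]  # most uncertain sample first
```

The uniform row has the maximum entropy of log(4) and is queried first, while the confident first row comes last.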
Bayesian Optimization
After fitting a regression model, use acquisition functions to guide optimization:
from engagegp import (
    expected_improvement,
    upper_confidence_bound,
    probability_of_improvement,
    thompson_sampling,
    expected_improvement_with_constraints,
)
# Expected Improvement
ei = expected_improvement(dkl, X_candidates, best_observed=y_train.max())
# Upper Confidence Bound
ucb = upper_confidence_bound(dkl, X_candidates, beta=2.0)
# Probability of Improvement
pi = probability_of_improvement(dkl, X_candidates, best_observed=y_train.max())
# Thompson Sampling
sample = thompson_sampling(dkl, X_candidates)
# EI with constraints
ei_c = expected_improvement_with_constraints(dkl, X_candidates, constraint_models=[...])
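For reference, the standard closed-form expressions behind EI, UCB, and PI can be written directly in terms of a Gaussian predictive mean and standard deviation. The sketch assumes maximization and the convention UCB = mean + beta * std; some libraries use sqrt(beta) instead, and the exact convention in engagegp is not stated above. The function names are hypothetical.

```python
import math
import numpy as np

def norm_pdf(z):
    """Standard normal density."""
    return np.exp(-0.5 * z**2) / math.sqrt(2.0 * math.pi)

def norm_cdf(z):
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + np.vectorize(math.erf)(z / math.sqrt(2.0)))

def ei_from_moments(mean, std, best_observed):
    """Expected improvement over best_observed (maximization)."""
    std = np.maximum(std, 1e-12)  # guard against division by zero
    z = (mean - best_observed) / std
    return (mean - best_observed) * norm_cdf(z) + std * norm_pdf(z)

def ucb_from_moments(mean, std, beta=2.0):
    """Upper confidence bound, assuming the mean + beta * std convention."""
    return mean + beta * std

def pi_from_moments(mean, std, best_observed):
    """Probability that a candidate exceeds best_observed."""
    std = np.maximum(std, 1e-12)
    return norm_cdf((mean - best_observed) / std)

# Stand-in predictive moments for three candidates.
mean = np.array([1.0, 0.0, 2.0])
std = np.array([1.0, 1.0, 0.5])
ei_vals = ei_from_moments(mean, std, best_observed=1.0)
ucb_vals = ucb_from_moments(mean, std, beta=2.0)
pi_vals = pi_from_moments(mean, std, best_observed=1.0)
next_idx = int(np.argmax(ei_vals))  # candidate to evaluate next under EI
```

A candidate whose mean equals the incumbent has PI of 0.5 and EI equal to std times the standard normal density at zero, which is a quick sanity check for any implementation.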
Pairwise GP (Preference Learning)
Learn from pairwise comparisons (A preferred over B) rather than absolute values:
from engagegp import (
    fit_dkgppw,
    predict_utility,
    dkgppw_eubo,  # Expected Utility of Best Option acquisition
    dkgppw_ucb,  # UCB acquisition
    acquire_preference,
    sample_comparison_pairs,
    get_simulated_preference,
)
# Fit from comparison pairs
model = fit_dkgppw(X, comparisons) # comparisons: (n_pairs, 2) array of [winner_idx, loser_idx]
# Predict utility scores
utility_mean, utility_std = predict_utility(model, X_candidates)
# Acquire next pair to compare (active preference learning)
next_pair = acquire_preference(model, X_candidates, method='eubo')
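The `comparisons` array can be prototyped without a human in the loop by simulating preferences from a hidden utility. This NumPy sketch builds an (n_pairs, 2) array of [winner_idx, loser_idx] rows. It is a stand-in for illustration and does not use the package's own `sample_comparison_pairs` or `get_simulated_preference`, whose signatures are not shown here.

```python
import numpy as np

def simulate_comparisons(utility, n_pairs, rng):
    """Draw random index pairs and order each row as [winner_idx, loser_idx]."""
    n = len(utility)
    pairs = np.array([rng.choice(n, size=2, replace=False) for _ in range(n_pairs)])
    first_wins = utility[pairs[:, 0]] > utility[pairs[:, 1]]
    winners = np.where(first_wins, pairs[:, 0], pairs[:, 1])
    losers = np.where(first_wins, pairs[:, 1], pairs[:, 0])
    return np.stack([winners, losers], axis=1)

rng = np.random.default_rng(0)
X = rng.normal(size=(30, 5))
true_utility = -np.sum(X**2, axis=1)  # hidden utility, highest near the origin
comparisons = simulate_comparisons(true_utility, n_pairs=50, rng=rng)
```

Only the ordering within each pair is recorded, so the model never sees the numeric utility values.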
Sample Weighting
Learnable per-sample weights for robust training against noisy labels:
from engagegp import SampleWeightModule, analyze_sample_weights
# Create weight module
weight_module = SampleWeightModule(n_samples=len(X_train))
# Get learned weights after training
weights = weight_module.get_weights().detach().numpy() # shape (n_samples,)
# Analyze which samples are noisy
analysis = analyze_sample_weights(
    sample_weights=weights,
    y_train=y_train,
    predictions=train_preds,
    threshold=0.5
)
# Returns: {'noisy_indices', 'clean_indices', 'weight_stats', 'noisy_stats'}
Use SampleWeightedMLL in gpr.py to incorporate learned sample weights into training.
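The threshold-based split can be pictured with a few lines of NumPy. This is only a sketch of the idea, assuming that a sample counts as noisy when its learned weight falls below the threshold; `split_by_weight` is a hypothetical helper, and `analyze_sample_weights` may apply different or additional criteria.

```python
import numpy as np

def split_by_weight(sample_weights, threshold=0.5):
    """Partition sample indices by comparing learned weights to a threshold."""
    sample_weights = np.asarray(sample_weights, dtype=float)
    noisy = np.flatnonzero(sample_weights < threshold)
    clean = np.flatnonzero(sample_weights >= threshold)
    stats = {
        "mean": float(sample_weights.mean()),
        "min": float(sample_weights.min()),
        "max": float(sample_weights.max()),
    }
    return {"noisy_indices": noisy, "clean_indices": clean, "weight_stats": stats}

# Stand-in learned weights for five training samples.
weights = np.array([0.95, 0.10, 0.80, 0.45, 0.99])
result = split_by_weight(weights, threshold=0.5)
```

Samples 1 and 3 fall below the threshold here and would be the first candidates for re-measurement or manual inspection.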
Attention Analysis
Inspect what the model has learned to focus on:
from engagegp import (
    get_attention_scores,
    get_attention_for_sample,
    analyze_attention_locality,
    summarize_attention,
)
# Attention scores for all training data
scores = get_attention_scores(model, X_train)
# Attention for a single sample
attn = get_attention_for_sample(model, X_train[0], average_heads=True)
# Analyze how localized the attention is
locality = analyze_attention_locality(scores)
# Full summary report
summarize_attention(model, X_train, sample_idx=0)
For DirectAttentionExtractor — direct feature-to-feature maps:
attention_map = extractor.get_attention_map(x)
# shape: (num_heads, input_dim, input_dim)
# e.g., for spectroscopy: which wavelengths attend to which
For AttentionWeightedExtractor — per-feature importance scores:
weights = extractor.get_attention_weights(x)
# shape: (input_dim,), sums to 1.0
top_features = np.argsort(weights)[-10:][::-1]
Model Persistence
from engagegp import save_model, load_model
save_model(dkl, 'my_model.pt')
dkl = load_model('my_model.pt')
Custom Hybrid Extractor
import torch
import torch.nn as nn
from engagegp import get_feature_extractor, fit_dkgp
class HybridExtractor(nn.Module):
    def __init__(self, input_dim, feature_dim):
        super().__init__()
        # Each branch produces half of the final feature vector
        self.resnet = get_feature_extractor('resnet', input_dim, feature_dim // 2)
        self.attention = get_feature_extractor('attention', input_dim, feature_dim // 2)
        self.input_dim = input_dim
        self.feature_dim = feature_dim
    def forward(self, x):
        # Concatenate the two branch outputs along the feature axis
        return torch.cat([self.resnet(x), self.attention(x)], dim=-1)
hybrid = HybridExtractor(input_dim=100, feature_dim=16)
mll, gp, dkl, losses = fit_dkgp(
    X, y,
    extractor_type='custom',
    extractor_kwargs={'custom_extractor': hybrid}
)
API Reference
Regression (gpr.py)
| Symbol | Description |
|---|---|
| DeepKernelGP | Main regression model class |
| ConfidenceWeightedMLL | Weighted MLL for heteroscedastic data |
| SampleWeightedMLL | MLL with learnable per-sample weights |
| train_dkgp() | Low-level training function |
| fit_dkgp() | High-level training interface |
| predict_dkgpr() | Prediction (mean, std, full distribution) |
| expected_improvement() | EI acquisition function |
| upper_confidence_bound() | UCB acquisition function |
| probability_of_improvement() | PI acquisition function |
| thompson_sampling() | Thompson sampling acquisition |
| expected_improvement_with_constraints() | Constrained EI |
Classification (gpc.py)
| Symbol | Description |
|---|---|
| DeepKernelGPClassifier | Main classification model class |
| BinaryGPClassificationModel | Variational GP for binary classification |
| MultiClassGPClassificationModel | Variational GP for multi-class |
| ConfidenceWeightedELBO | Weighted ELBO loss |
| train_dkgp_classifier() | Low-level training function |
| fit_dkgp_classifier() | High-level training interface |
| predict_classifier() | Prediction (labels or probabilities) |
Pairwise GP (gppw.py)
| Symbol | Description |
|---|---|
| DeepKernelPairwiseGP | Main pairwise GP model class |
| fit_dkgppw() | High-level training interface |
| train_dkgppw() | Low-level training function |
| predict_utility() | Predict utility scores |
| dkgppw_eubo() | EUBO acquisition function |
| dkgppw_ucb() | UCB acquisition function |
| acquire_preference() | Select next pair to compare |
| sample_comparison_pairs() | Sample random pairs |
| get_simulated_preference() | Simulate preferences from a true function |
| get_user_preference() | Collect preference from user interactively |
| plot_option() | Plot a candidate option |
| plot_predictions() | Plot predicted utilities |
Feature Extractors (models.py)
| Symbol | Description |
|---|---|
| FCFeatureExtractor | Simple FC |
| FCBNFeatureExtractor | FC + BatchNorm + Dropout (default) |
| ResNetFeatureExtractor | ResNet-style with skip connections |
| AttentionFeatureExtractor | Self-attention (attends in hidden space) |
| DirectAttentionExtractor | Self-attention on raw input features |
| AttentionWeightedExtractor | Per-feature importance weighting |
| get_feature_extractor() | Factory function |
Sample Weighting (sample_weighting.py)
| Symbol | Description |
|---|---|
| SampleWeightModule | Learnable per-sample weight module |
| analyze_sample_weights() | Identify noisy/outlier samples |
Utilities (utils.py)
| Symbol | Description |
|---|---|
| get_attention_scores() | Attention scores across dataset |
| get_attention_for_sample() | Attention for a single input |
| analyze_attention_locality() | Measure attention localization |
| summarize_attention() | Print attention summary report |
| save_model() | Save model to file |
| load_model() | Load model from file |
| split_train_test() | Train/test split utility |
| get_grid_coords() | Grid coordinates for image data |
| get_subimages() | Extract subimage patches |
License
MIT