Jensen-Shannon divergence estimation for tabular data using discriminator-based methods
JensenShannonDivergence
Python package for Jensen-Shannon divergence estimation on tabular data.
The core API works with NumPy arrays, pandas DataFrames, PyTorch tensors, or numeric array-like values.
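For reference, the quantity being estimated has the standard definition below, written for two distributions P (reference) and Q (synthetic) with mixture M. The source does not state which logarithm base the package uses; with the natural logarithm the value lies between 0 and log 2.

```latex
\mathrm{JSD}(P \,\|\, Q)
  = \tfrac{1}{2}\,\mathrm{KL}(P \,\|\, M) + \tfrac{1}{2}\,\mathrm{KL}(Q \,\|\, M),
\qquad M = \tfrac{1}{2}\,(P + Q),
\qquad \mathrm{KL}(P \,\|\, M) = \mathbb{E}_{x \sim P}\!\left[\log \frac{p(x)}{m(x)}\right]
```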
Installation
From GitHub
If the package is only published on GitHub, install it directly from the repository:
pip install "git+https://github.com/AlbaGarridoLopezz/jensenshannondivergence.git"
To install a specific branch or tag:
pip install "git+https://github.com/AlbaGarridoLopezz/jensenshannondivergence.git@main"
From PyPI
Install the released package from PyPI:
pip install jensenshannondivergence
Optional extras can be installed with:
pip install "jensenshannondivergence[all]"
Local Development
Clone the repository and install it in editable mode:
python -m venv venv
source venv/bin/activate
pip install -e .
If you want the optional extras:
pip install -e ".[all]"
What Users Should Import
The library is custom-first: call it with your own real/reference samples and generated/synthetic samples.
Simple API
estimate_jensen_shannon returns a float with the estimated Jensen-Shannon divergence.
import numpy as np
from jensenshannondivergence import estimate_jensen_shannon
x_reference = np.random.normal(size=(1000, 10))
x_synthetic = np.random.normal(loc=0.2, size=(1000, 10))
js = estimate_jensen_shannon(
    x_reference,
    x_synthetic,
    discriminator_type="MLP",  # MLP, RF, XGBoost, LogReg, LogRegPol, TabPFN
    n_iter=30,  # used by RF/XGBoost/LogReg optimizers
    seed=0,
)
print(js)
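To give some intuition for what a discriminator-based estimate computes, here is a minimal standard-library sketch. It is not the package's implementation. It uses the well-known identity that, for the Bayes-optimal discriminator D(x) = p(x) / (p(x) + q(x)), the Jensen-Shannon divergence equals log 2 + 0.5 E_P[log D] + 0.5 E_Q[log(1 - D)]. The helper names (`js_from_discriminator`, `optimal_discriminator`) are hypothetical, and the discriminator is computed analytically for two unit-variance Gaussians instead of being trained.

```python
import math
import random


def js_from_discriminator(samples_p, samples_q, discriminator):
    """Plug-in estimate: log 2 + 0.5*mean_P[log D] + 0.5*mean_Q[log(1 - D)]."""
    eps = 1e-12  # guards against log(0) if the discriminator saturates
    term_p = sum(math.log(max(discriminator(x), eps)) for x in samples_p)
    term_q = sum(math.log(max(1.0 - discriminator(x), eps)) for x in samples_q)
    return (
        math.log(2.0)
        + 0.5 * term_p / len(samples_p)
        + 0.5 * term_q / len(samples_q)
    )


def optimal_discriminator(mean_p, mean_q):
    """Bayes-optimal D(x) = p(x) / (p(x) + q(x)) for two unit-variance Gaussians.

    Written as a logistic function of the log density ratio log q(x) - log p(x).
    """
    def d(x):
        log_ratio = 0.5 * ((x - mean_p) ** 2 - (x - mean_q) ** 2)
        return 1.0 / (1.0 + math.exp(log_ratio))
    return d


rng = random.Random(0)
x_p = [rng.gauss(0.0, 1.0) for _ in range(5000)]  # stand-in "reference" samples
x_q = [rng.gauss(1.0, 1.0) for _ in range(5000)]  # stand-in "synthetic" samples

# Shifted distributions give a positive value below log 2 (natural log).
js_shifted = js_from_discriminator(x_p, x_q, optimal_discriminator(0.0, 1.0))
# Identical distributions give D(x) = 0.5 everywhere, so the estimate is ~0.
js_same = js_from_discriminator(x_p, x_p, optimal_discriminator(0.0, 0.0))
print(js_shifted, js_same)
```

In practice the optimal discriminator is unknown, which is why the library trains a classifier (MLP, RF, XGBoost, and so on) on the two samples; the quality of the estimate then depends on how well that classifier approximates the optimal one.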
Use return_result=True if you need the full evaluator and output path:
from jensenshannondivergence import estimate_jensen_shannon
result = estimate_jensen_shannon(
    x_reference,
    x_synthetic,
    discriminator_type="RF",
    m=500,
    l=250,
    n_iter=20,
    return_result=True,
)
print(result.evaluator.disc_js)
print(result.results_path)
Predefined experiments (repository)
Predefined experiment loaders live under the experiments/ folder and are intended
for repository-based development. The library package no longer exposes a
use_predefined runtime option. To run a predefined use case from the repo,
load the experiment data and call the library API with the returned tensors:
# run from the repository root (so `experiments` is importable)
from experiments import data as exp_data
from jensenshannondivergence import estimate_jensen_shannon
# load tensors for a use case
x_r, x_s, dist_r, dist_s = exp_data.load_data('use_case_7', n=10, m=2000, l=2000, seed=0)
js = estimate_jensen_shannon(
    x_r,
    x_s,
    discriminator_type='MLP',
    m=2000,
    l=2000,
    seed=0,
)
If you prefer orchestration, use main_experiments.py, which loads experiment
data and calls the library for you (see the Experiments CLI section below).
CLI For Your Own Data
After installation, jsd-estimate estimates the Jensen-Shannon divergence directly from two CSV files:
jsd-estimate --x-p real_data.csv --x-q gen_data.csv --discriminator MLP --epochs 100
Useful arguments:
- --discriminator / --classifier
- --m, --l
- --n-iter
- --epochs
- --ratio-correction-mode
- --results-root
- --save-plots
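To try the CLI without real data, the following sketch writes two small numeric CSV files with the standard library. The source does not describe the expected CSV layout, so the header row, the `feature_*` column names, and the helper `write_gaussian_csv` are assumptions made for illustration; adjust them to whatever layout jsd-estimate actually accepts.

```python
import csv
import random
import tempfile
from pathlib import Path


def write_gaussian_csv(path, n_rows, n_cols, mean, seed):
    """Write n_rows x n_cols Gaussian samples to a CSV file with a header row."""
    rng = random.Random(seed)
    header = [f"feature_{j}" for j in range(n_cols)]  # assumed column naming
    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        for _ in range(n_rows):
            writer.writerow([rng.gauss(mean, 1.0) for _ in range(n_cols)])
    return Path(path)


# Write both files to a temporary directory to avoid cluttering the project.
out_dir = Path(tempfile.mkdtemp())
real_path = write_gaussian_csv(out_dir / "real_data.csv", 200, 3, mean=0.0, seed=0)
gen_path = write_gaussian_csv(out_dir / "gen_data.csv", 200, 3, mean=0.2, seed=1)

# Command line to run afterwards, using only flags shown in this README.
print(f"jsd-estimate --x-p {real_path} --x-q {gen_path} --discriminator MLP --epochs 100")
```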
Experiments CLI
main_experiments.py is only for predefined experiments and is intended for repository development runs.
List available use cases:
python main_experiments.py list
Train selected experiments:
python main_experiments.py train --discriminators MLP RF --experiments use_case_1 use_case_3 --n-iter 30
Run discriminator tests across classifiers:
python main_experiments.py test-discriminators --experiments use_case_1 --n-iter 20
Use --models, --datasets-use-case-4, and --datasets-use-case-11 for the real-data use cases.
Interactive Tutorial Notebook
Use the tutorial notebook for a minimal end-to-end run:
Tutorial.ipynb
The tutorial saves run outputs to:
experiments/tutorials_outputs/
If you edit code under src/, restart the notebook kernel before re-running cells.
Experiments Path Convention
All read/write experiment paths are centralized under experiments/, including:
- experiments/data/
- experiments/results_MLP/
- experiments/results_RF/
- experiments/results_XGBoost/
- experiments/results_LogReg/
- experiments/results_LogRegPol/
- experiments/results_TabPFN/
- experiments/results_discriminators/
- experiments/calibration_audit_results/
- experiments/tutorials_outputs/
You can override this root with the environment variable JSD_EXPERIMENTS_ROOT.
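As an illustration of how such an override typically behaves, the sketch below sets JSD_EXPERIMENTS_ROOT from Python and resolves a results folder against it. The helper `resolve_experiments_root` is hypothetical and is not the package's own code, the fallback to `experiments` is assumed from the path convention above, and `/tmp/jsd_runs` is only a placeholder path.

```python
import os
from pathlib import Path


def resolve_experiments_root(default="experiments"):
    """Hypothetical helper: prefer JSD_EXPERIMENTS_ROOT, else the default folder."""
    return Path(os.environ.get("JSD_EXPERIMENTS_ROOT", default))


# Set the variable before running experiment code so that it is picked up
# (assumption: the variable is read when paths are resolved).
os.environ["JSD_EXPERIMENTS_ROOT"] = "/tmp/jsd_runs"

root = resolve_experiments_root()
results_mlp = root / "results_MLP"  # mirrors experiments/results_MLP/
print(root, results_mlp)
```

The same effect can be achieved by exporting the variable in the shell before launching main_experiments.py or the notebook.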
Notes on results and plotting behavior: library functions do not produce plots by default; plotting is performed by the experiment scripts and notebooks. Experiment outputs (tables, CSVs, and optional plots) are written under the experiments root. Control plot saving with the save_plots argument in the experiments CLI or by setting save_plots=True in the notebooks.
Notes
- For best performance, run with GPU when available.
- Some baselines require optional dependencies (syndat, synthcity, tabpfn).
- If using TabPFN, make sure your PyTorch version is compatible.