
An implementation of Pointwise Sliced Mutual Information (PSMI) for machine learning


Numerical estimator of Pointwise Sliced Mutual Information (PSMI).

Jérémie Dentan, École Polytechnique

Setup

Our library is published on PyPI!

pip install psmi

Usage

We implement a PSMI class that is used like a scikit-learn estimator, with fit, transform, and fit_transform methods.

We only implement PSMI between real-valued features and integer labels belonging to a finite number of classes. We use Algorithm 1 in [1] to estimate PSMI. The only hyperparameter of this algorithm is the number of estimators (i.e. the number of random directions sampled to estimate PSMI).
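To make the estimation concrete, here is a minimal, illustrative sketch of a single-direction estimate in the spirit of Algorithm 1 of [1], assuming Gaussian class-conditional densities on the 1D projection. This is not the library's internal code, and the function name is hypothetical:

import numpy as np

def psmi_one_direction(features, labels, rng=None):
    # Illustrative sketch: PSMI of each sample along ONE random direction,
    # assuming Gaussian class-conditional densities on the projection.
    rng = np.random.default_rng() if rng is None else rng

    # Project every sample onto a random unit direction (the "slice")
    theta = rng.normal(size=features.shape[1])
    theta /= np.linalg.norm(theta)
    t = features @ theta

    classes, counts = np.unique(labels, return_counts=True)
    priors = counts / len(labels)

    def gaussian_pdf(x, mu, sigma):
        return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

    # One row per class: p(t | y = c) evaluated at every projection
    cond = np.stack([
        gaussian_pdf(t, t[labels == c].mean(), t[labels == c].std() + 1e-12)
        for c in classes
    ])
    marginal = priors @ cond  # mixture density p(t)

    # Pointwise log-ratio log p(t_i | y_i) - log p(t_i) for every sample
    own = cond[np.searchsorted(classes, labels), np.arange(len(labels))]
    return np.log(own + 1e-12) - np.log(marginal + 1e-12)

Averaging such per-direction estimates over n_estimators random directions is, presumably, what produces the psmi_mean, psmi_std and psmi_full outputs shown in the example below.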

We propose two approaches to compute PSMI.

  1. Manual. You simply pass the n_estimators argument with the desired value.

  2. Automatic. In that case, you pass n_estimators="auto" and an algorithm determines a suitable value for n_estimators.

Example

import numpy as np
from psmi import PSMI

# Generating data
n_samples, dim, n_labels = 100, 1024, 5
features = np.random.random((n_samples, dim))
labels = np.random.randint(n_labels, size=n_samples)

# Manual number of estimators
psmi_estimator = PSMI(n_estimators=500)
psmi_mean, psmi_std, psmi_full = psmi_estimator.fit_transform(features, labels)
print(f"psmi_mean: {psmi_mean.shape}")  # Should be (100,)
print(f"psmi_std: {psmi_std.shape}")  # Should be (100,)
print(f"psmi_full: {psmi_full.shape}")  # Should be (100,500)
print(f"Num of estimator: {psmi_estimator.n_estimators}")  # Should be 500

# Automatic number of estimators
psmi_estimator = PSMI()
psmi_mean, psmi_std, psmi_full = psmi_estimator.fit_transform(features, labels)
print(f"psmi_mean: {psmi_mean.shape}")  # Should be (100,)
print(f"psmi_std: {psmi_std.shape}")  # Should be (100,)
print(f"psmi_full: {psmi_full.shape}")  # Should be (100,<psmi_estimator.n_estimators>)
print(f"Num of estimator: {psmi_estimator.n_estimators}")

# You can also separate the fit and transform steps
n_test = 5
features_test = np.random.random((n_test, dim))
labels_test = np.random.randint(n_labels, size=n_test)
psmi_estimator = PSMI()
psmi_estimator.fit(features, labels)
psmi_mean, psmi_std, psmi_full = psmi_estimator.transform(features_test, labels_test)
print(f"psmi_mean: {psmi_mean.shape}")  # Should be (5,)
print(f"psmi_std: {psmi_std.shape}")  # Should be (5,)
print(f"psmi_full: {psmi_full.shape}")  # Should be (5,<psmi_estimator.n_estimators>)
print(f"Num of estimator: {psmi_estimator.n_estimators}")

Details on the auto mode

In auto mode, we iteratively add estimators in blocks of min_n_estimators. We stop when the PSMI of the elements with the lowest PSMI has barely changed between the current step and the step with half as many estimators.

More precisely, if lowest_psmi_quantile=0.05, we consider the 5% of elements with the lowest PSMI at the current step. We compare their PSMI to the PSMI of these same elements computed using only the first int(n*milestone) blocks of estimators, where n is the current number of blocks. We then take the absolute value of this variation divided by the PSMI at the current step: if it is below max_variation_of_the_lowest, we stop; otherwise, we add another block of min_n_estimators estimators.

For example, the default values correspond to blocks of 500 estimators. We keep adding blocks until the PSMI of the 5% of elements with the lowest PSMI has varied by less than 5% between the current step and the step with half as many blocks.
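As a rough sketch of this stopping rule (the helper below is illustrative and not part of the library's API; it assumes the PSMI of the lowest-quantile elements is summarized by its mean):

import numpy as np

def should_stop(psmi_blocks, lowest_psmi_quantile=0.05, milestone=0.5,
                max_variation_of_the_lowest=0.05):
    # psmi_blocks: one per-sample PSMI array per block of min_n_estimators
    # directions computed so far.
    n_blocks = len(psmi_blocks)
    n_past = int(n_blocks * milestone)
    if n_past < 1 or n_past == n_blocks:
        return False  # not enough blocks yet to compare

    current = np.mean(psmi_blocks, axis=0)        # PSMI using all blocks
    past = np.mean(psmi_blocks[:n_past], axis=0)  # PSMI using the first blocks only

    # Elements with the lowest PSMI at the current step
    lowest = current <= np.quantile(current, lowest_psmi_quantile)

    # Relative variation of their mean PSMI between the two steps
    variation = abs(current[lowest].mean() - past[lowest].mean())
    variation /= abs(current[lowest].mean()) + 1e-12
    return variation < max_variation_of_the_lowest

In a fitting loop, one would draw a new block of min_n_estimators directions, append its per-sample PSMI to psmi_blocks, and stop as soon as should_stop returns True.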

[1] Shelvia Wongso et al. Pointwise Sliced Mutual Information for Neural Network Explainability. IEEE International Symposium on Information Theory (ISIT). 2023. DOI: 10.1109/ISIT54713.2023.10207010

Contributing

You are welcome to submit pull requests! Please use pre-commit to correctly format your code:

pip install -r .github/dev-requirements.txt
pre-commit install

Please test your code:

pytest

License and Copyright

Copyright 2024-present Laboratoire d'Informatique de Polytechnique. This project is licensed under the GNU Lesser General Public License v3.0. See the LICENSE file for details.

Please cite this work as follows:

@misc{dentan_predicting_2024,
	title = {Predicting and analysing memorization within fine-tuned Large Language Models},
	url = {https://arxiv.org/abs/2409.18858},
	author = {Dentan, Jérémie and Buscaldi, Davide and Shabou, Aymen and Vanier, Sonia},
	month = sep,
	year = {2024},
}

Acknowledgements

This work received financial support from Crédit Agricole SA through the research chair "Trustworthy and responsible AI" with École Polytechnique.
