
Calculating exact and approximate confidence and information metrics for differentiable functions.


🧐 duvida


duvida (Portuguese for doubt) is a suite of Python tools for calculating confidence and information metrics for deep learning. It provides lower-level function transforms for exact and approximate Hessian diagonals in JAX and PyTorch.

Installation

The easy way

You can install the precompiled version directly using pip. You need to specify the machine learning framework that you want to use:

$ pip install duvida[jax]
# or
$ pip install duvida[jax_cuda12]  # for JAX installing CUDA 12 for GPU support
# or
$ pip install duvida[jax_cuda12_local]  # for JAX using a locally-installed CUDA 12
# or
$ pip install duvida[torch]
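
Since each extra pulls in a different framework, it can help to check which one is actually importable in the current environment before picking a backend. The snippet below is a minimal sketch using only the Python standard library; `available_backends` is a hypothetical helper written for this illustration and is not part of duvida.

```python
from importlib.util import find_spec


def available_backends(candidates=("jax", "torch")):
    """Return {package_name: bool}, True when the package can be imported.

    Hypothetical helper for illustration only; not part of duvida.
    """
    return {name: find_spec(name) is not None for name in candidates}


backends = available_backends()
print(backends)
```

If neither framework is reported as available, re-run one of the install commands above with the extra you need.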

duvida implements functional transformations in JAX and PyTorch for approximate and exact Hessian diagonals, as well as for doubtscore and information sensitivity. These can be used with JAX- and PyTorch-based frameworks.

From source

Clone the repository, then cd into it. Then run:

$ pip install -e .[torch]

Python API

duvida provides functional transforms for JAX and PyTorch that calculate either exact or approximate Hessian diagonals.

You can check which backend you're using:

>>> from duvida.stateless.config import config
>>> config
Config(backend='jax', precision='double', fallback=True)

You can change the backend:

>>> config.set_backend("torch")
'torch'
>>> config
Config(backend='torch', precision='double', fallback=True)

Now you can calculate exact Hessian diagonals without calculating the full matrix:

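>>> from duvida.stateless.utils import hessian
>>> from duvida.stateless.hessians import get_approximators
>>> import duvida.stateless.numpy as dnp
>>> # 'exact_diagonal' is one of the names listed by get_approximators()
>>> exact_diagonal = get_approximators("exact_diagonal")
>>> f = lambda x: dnp.sum(x ** 3. + x ** 2. + 4.)
>>> a = dnp.array([1., 2.])
>>> exact_diagonal(f)(a) == dnp.diag(hessian(f)(a))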
Array([ True,  True], dtype=bool)
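
To see what the diagonal should be for this `f`, note that each term depends on a single `x_i`, so the second derivative is `6 * x_i + 2`, giving `[8, 14]` at `a = [1, 2]`. The following is a minimal pure-Python sketch that checks this with central second differences. It is an independent illustration, not how duvida computes the diagonal, and `fd_hessian_diagonal` is a hypothetical helper name.

```python
def f(x):
    # Same function as the example above: sum(x**3 + x**2 + 4)
    return sum(xi ** 3 + xi ** 2 + 4.0 for xi in x)


def fd_hessian_diagonal(func, x, h=1e-3):
    """Approximate d^2 func / dx_i^2 for each i using central differences.

    Hypothetical helper for illustration only; not part of duvida.
    """
    f0 = func(x)
    diag = []
    for i in range(len(x)):
        up = list(x)
        up[i] += h
        down = list(x)
        down[i] -= h
        diag.append((func(up) - 2.0 * f0 + func(down)) / h ** 2)
    return diag


a = [1.0, 2.0]
analytic = [6.0 * xi + 2.0 for xi in a]  # second derivative of x^3 + x^2 + 4
numeric = fd_hessian_diagonal(f, a)
print(analytic, numeric)
```

The finite-difference version needs two extra function evaluations per parameter, which is part of why transforms that avoid both the full matrix and per-parameter loops are attractive for large models.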

Several approximations are also available.

>>> from duvida.stateless.hessians import get_approximators
>>> get_approximators()  # Use no arguments to show what's available
('squared_jacobian', 'exact_diagonal', 'bekas', 'rough_finite_difference')

Pass a name to get the corresponding approximator, then apply it:

>>> approx_hessian_diag = get_approximators("bekas")
>>> g = lambda x: dnp.sum(dnp.sum(x) ** 3. + x ** 2. + 4.)
>>> a = dnp.array([1., 2.])
>>> dnp.diag(hessian(g)(a))  # Exact
Array([38., 38.], dtype=float64)
>>> approx_hessian_diag(g, n=1000)(a)  # Less accurate when parameters interact
Array([38.52438307, 38.49679655], dtype=float64)
>>> approx_hessian_diag(g, n=1000, seed=1)(a)  # Change the seed to alter the outcome
Array([39.07878869, 38.97796601], dtype=float64)
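
The idea behind a Bekas-style stochastic diagonal estimate can be sketched in pure Python: probe the Hessian with random vectors `v` and average `v * (H v)` elementwise, divided by the average of `v * v`. The sketch below assumes Rademacher (±1) probes and a hand-written Hessian for `g` at `a = [1, 2]`; duvida's own `bekas` approximator may use different probes or a different Hessian-vector product, so the numbers will not match the output above. `bekas_diagonal` and `hvp` are hypothetical names for this illustration.

```python
import random


def bekas_diagonal(hvp, dim, n=1000, seed=0):
    """Estimate diag(H) as sum_k v_k * (H v_k) / sum_k v_k * v_k, elementwise.

    Hypothetical helper for illustration only; not duvida's implementation.
    """
    rng = random.Random(seed)
    num = [0.0] * dim
    den = [0.0] * dim
    for _ in range(n):
        v = [rng.choice((-1.0, 1.0)) for _ in range(dim)]  # Rademacher probe
        hv = hvp(v)
        for i in range(dim):
            num[i] += v[i] * hv[i]
            den[i] += v[i] * v[i]
    return [num[i] / den[i] for i in range(dim)]


# Hessian of g(x) = sum(sum(x)**3 + x**2 + 4) at x = [1, 2], written out by hand:
# every entry gets 2 * 6 * sum(x) = 36, and the diagonal gets an extra 2.
H = [[38.0, 36.0], [36.0, 38.0]]


def hvp(v):
    # Hessian-vector product using the explicit matrix above
    return [sum(H[i][j] * v[j] for j in range(len(v))) for i in range(len(H))]


estimate = bekas_diagonal(hvp, dim=2, n=1000)
print(estimate)
```

Each sample contributes the true diagonal entry plus a cross term from the off-diagonal entries (here 36 times the product of the two probe signs). That cross term only averages out as `n` grows, which is why the estimate is less accurate when parameters interact, and it vanishes entirely when the Hessian is diagonal.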

Issues, problems, suggestions

Please add them to the issue tracker.

Documentation

(To come at ReadTheDocs.)
