Calculating exact and approximate confidence and information metrics for differentiable functions.
🧐 duvida
duvida (Portuguese for "doubt") is a suite of Python tools for calculating confidence and information metrics for deep learning. It provides lower-level functional transforms for exact and approximate Hessian diagonals in JAX and PyTorch.
Installation
The easy way
You can install the precompiled package directly with pip, specifying the machine learning framework you want to use:
$ pip install duvida[jax]
# or
$ pip install duvida[jax_cuda12] # for JAX, with CUDA 12 installed via pip for GPU support
# or
$ pip install duvida[jax_cuda12_local] # for JAX using a locally-installed CUDA 12
# or
$ pip install duvida[torch]
duvida implements JAX and PyTorch functional transforms for exact and approximate Hessian diagonals, as well as for doubtscore and information sensitivity. These can be used with JAX- and PyTorch-based frameworks.
From source
Clone the repository, cd into it, and install in editable mode with the extra for your framework (here, torch):
$ pip install -e .[torch]
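After installing, a quick way to confirm which framework extra actually landed in your environment is to ask Python whether jax or torch can be imported. This is a minimal sketch using only the standard library; the helper name available_backends is hypothetical and is not part of duvida's API.

```python
# Sketch: report which optional ML frameworks are importable in the current
# environment. Uses only the standard library (not duvida itself);
# available_backends is a hypothetical helper name.
from importlib.util import find_spec


def available_backends() -> dict:
    """Map each framework module name to whether it can be imported."""
    return {name: find_spec(name) is not None for name in ("jax", "torch")}


if __name__ == "__main__":
    print(available_backends())
```

If the framework you asked pip for shows False, the extra was probably not installed into the interpreter you are running.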
Python API
duvida provides functional transforms for JAX and PyTorch that calculate either exact or approximate Hessian diagonals.
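For reference, for a scalar function f of n inputs, the Hessian diagonal collects the pure second derivatives; each entry is also the quadratic form of the Hessian with a standard basis vector e_i. The diagonal has n entries, whereas the full Hessian has n^2, which is why computing the diagonal alone is attractive for models with many parameters.

```latex
\bigl[\operatorname{diag}(\nabla^2 f)(x)\bigr]_i
  = \frac{\partial^2 f}{\partial x_i^2}(x)
  = e_i^\top \, \nabla^2 f(x) \, e_i,
  \qquad i = 1, \dots, n
```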
You can check which backend you're using:
>>> from duvida.stateless.config import config
>>> config
Config(backend='jax', precision='double', fallback=True)
To switch backends:
>>> config.set_backend("torch")
'torch'
>>> config
Config(backend='torch', precision='double', fallback=True)
Now you can calculate exact Hessian diagonals without calculating the full matrix:
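>>> from duvida.stateless.utils import hessian
>>> from duvida.stateless.hessians import get_approximators
>>> import duvida.stateless.numpy as dnp
>>> exact_diagonal = get_approximators("exact_diagonal")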
>>> f = lambda x: dnp.sum(x ** 3. + x ** 2. + 4.)
>>> a = dnp.array([1., 2.])
>>> exact_diagonal(f)(a) == dnp.diag(hessian(f)(a))
Array([ True, True], dtype=bool)
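The example function is easy to check by hand: each term of f depends on a single coordinate, so the Hessian is diagonal with entries 6x_i + 2, giving [8, 14] at a = [1, 2]. The sketch below cross-checks that with central second differences in plain Python. It is only a numerical illustration, not how duvida computes exact diagonals, and fd_hessian_diagonal is a hypothetical helper.

```python
# Sketch: numerically cross-check the Hessian diagonal of the README's example
# f(x) = sum(x**3 + x**2 + 4), whose exact diagonal is 6*x + 2.
# Central second differences are an illustration only; this is NOT how duvida
# computes exact diagonals. fd_hessian_diagonal is a hypothetical helper.


def f(x):
    return sum(xi ** 3 + xi ** 2 + 4.0 for xi in x)


def fd_hessian_diagonal(func, x, h=1e-3):
    """Estimate d^2 func / dx_i^2 for each i with a central second difference."""
    fx = func(x)
    diag = []
    for i in range(len(x)):
        up = list(x)
        down = list(x)
        up[i] += h
        down[i] -= h
        diag.append((func(up) - 2.0 * fx + func(down)) / h ** 2)
    return diag


if __name__ == "__main__":
    a = [1.0, 2.0]
    print(fd_hessian_diagonal(f, a))     # close to [8.0, 14.0]
    print([6.0 * ai + 2.0 for ai in a])  # analytic: [8.0, 14.0]
```

For a cubic, the central second difference has no truncation error, so any discrepancy is floating-point round-off; for general functions the step h trades truncation error against round-off.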
Several approximations are also available:
>>> from duvida.stateless.hessians import get_approximators
>>> get_approximators() # Use no arguments to show what's available
('squared_jacobian', 'exact_diagonal', 'bekas', 'rough_finite_difference')
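The list above gives names only. As an illustration of one of them, a "squared Jacobian" approximation commonly means the Gauss-Newton diagonal: for a least-squares loss L(x) = 0.5 * sum_k r_k(x)^2, the exact Hessian is J^T J plus a term weighted by the residuals' own curvature, and dropping that second term leaves a diagonal equal to the column-wise sum of squared Jacobian entries. We assume, without confirmation from this page, that duvida's squared_jacobian follows this idea. The helper gauss_newton_diagonal below is hypothetical and written in plain Python so it runs without JAX or PyTorch.

```python
# Sketch of a Gauss-Newton ("squared Jacobian") Hessian-diagonal approximation.
# ASSUMPTION: this illustrates the generic technique, not duvida's own code.
# For L(x) = 0.5 * sum_k r_k(x)**2, H ~= J^T J, so H_ii ~= sum_k J[k][i]**2,
# where J[k][i] = d r_k / d x_i.


def gauss_newton_diagonal(jacobian):
    """Column-wise sum of squared Jacobian entries (rows = residuals)."""
    n_params = len(jacobian[0])
    return [sum(row[i] ** 2 for row in jacobian) for i in range(n_params)]


if __name__ == "__main__":
    # Linear residuals r(x) = A x - b have Jacobian A, and the exact Hessian
    # of 0.5 * ||A x - b||**2 is A^T A, so the approximation is exact here.
    A = [[1.0, 2.0],
         [3.0, 4.0]]
    print(gauss_newton_diagonal(A))  # [10.0, 20.0]
```

For linear residuals the dropped curvature term is zero, so the result matches the diagonal of A^T A exactly; for strongly non-linear residuals, or losses that are not sums of squares, it can be far from the true diagonal.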
To use one, pass its name to get_approximators and apply the result to a function:
>>> approx_hessian_diag = get_approximators("bekas")
>>> g = lambda x: dnp.sum(dnp.sum(x) ** 3. + x ** 2. + 4.)
>>> a = dnp.array([1., 2.])
>>> dnp.diag(hessian(g)(a)) # Exact
Array([38., 38.], dtype=float64)
>>> approx_hessian_diag(g, n=1000)(a) # Less accurate when parameters interact
Array([38.52438307, 38.49679655], dtype=float64)
>>> approx_hessian_diag(g, n=1000, seed=1)(a) # Change the seed to alter the outcome
Array([39.07878869, 38.97796601], dtype=float64)
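The "bekas" name presumably refers to the stochastic diagonal estimator of Bekas, Kokiopoulou and Saad, which averages v * (H v) elementwise over random probe vectors v. The exact values above follow from g(x) = 2s^3 + sum_i x_i^2 + 8 with s = sum_i x_i (for two inputs), so d^2 g / dx_i dx_j = 12s + 2 if i = j and 12s otherwise; at a = [1, 2], s = 3, giving 38 on the diagonal and 36 off it. Those large off-diagonal terms enter each probe's estimate as noise, which is why the approximation is less accurate when parameters interact. Below is a minimal pure-Python sketch, not duvida's implementation: the function names are hypothetical, the probes are Rademacher (+1/-1) vectors by assumption, and the Hessian-vector product for g is derived by hand, whereas an autodiff framework such as JAX or PyTorch would normally supply it.

```python
# Sketch of the Bekas-Kokiopoulou-Saad stochastic estimator of a Hessian
# diagonal, using Rademacher (+1/-1) probe vectors:
#     diag(H) ~= sum_k v_k * (H v_k) / sum_k v_k * v_k   (elementwise)
# ASSUMPTION: illustrative only; duvida's "bekas" approximator may differ
# (e.g. in probe distribution). Function names here are hypothetical.
import random


def bekas_diagonal(hvp, dim, n=1000, seed=0):
    """Estimate diag(H) from a Hessian-vector-product callable hvp(v) -> H v."""
    rng = random.Random(seed)
    numerator = [0.0] * dim
    denominator = [0.0] * dim
    for _ in range(n):
        v = [rng.choice((-1.0, 1.0)) for _ in range(dim)]
        hv = hvp(v)
        for i in range(dim):
            numerator[i] += v[i] * hv[i]
            denominator[i] += v[i] * v[i]
    return [num / den for num, den in zip(numerator, denominator)]


def hvp_g_at(x):
    """Hand-derived H v for the README's g(x) = sum(sum(x)**3 + x**2 + 4).

    With d = len(x) and s = sum(x), g = d*s**3 + sum(x**2) + 4*d, so
    H = 6*d*s * ones + 2*I and H v = 6*d*s*sum(v) * ones + 2*v.
    """
    d, s = len(x), sum(x)

    def hvp(v):
        coupling = 6.0 * d * s * sum(v)
        return [coupling + 2.0 * vi for vi in v]

    return hvp


if __name__ == "__main__":
    a = [1.0, 2.0]
    print(bekas_diagonal(hvp_g_at(a), dim=2, n=1000))          # near [38, 38]
    print(bekas_diagonal(hvp_g_at(a), dim=2, n=1000, seed=1))  # another draw
```

With +1/-1 probes each v_i^2 is exactly 1, so a Hessian with no off-diagonal terms is recovered exactly even from a single probe; all of the estimator's error comes from coupling between parameters, and it shrinks roughly as 1/sqrt(n).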
Issues, problems, suggestions
Please report them on the issue tracker.
Documentation
(Coming soon on ReadTheDocs.)
Download files
File details
Details for the file duvida-0.0.3.post1.tar.gz.
File metadata
- Download URL: duvida-0.0.3.post1.tar.gz
- Upload date:
- Size: 14.8 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.24
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2a1b6e682af3a4a0ca11c741151b590df17802ba4bf37eada20387cbe1d1110b |
| MD5 | ebb7a1d66a05d968ff6decdbeb121e4b |
| BLAKE2b-256 | 76855c1ba716fec8339e7359740c22da8cfde568dd6bd75cd2081a86da7be2c7 |
File details
Details for the file duvida-0.0.3.post1-py3-none-any.whl.
File metadata
- Download URL: duvida-0.0.3.post1-py3-none-any.whl
- Upload date:
- Size: 15.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/6.2.0 CPython/3.9.24
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 2f9acb206bd55e790f12f40b1736a21fdbfc52204e9cbb3bf81da6762e1af211 |
| MD5 | 45e647891e7d289f94513117176eef03 |
| BLAKE2b-256 | 99a787b3cb77ccf6f64cdc4e49b2668612d97f9b1115b75558edbc7c5495825f |