ml_dtypes
ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries, including:
- bfloat16: an alternative to the standard float16 format
- float8_*: several experimental 8-bit floating point representations including:
  - float8_e4m3b11fnuz
  - float8_e4m3fn
  - float8_e4m3fnuz
  - float8_e5m2
  - float8_e5m2fnuz
- int4 and uint4: low precision integer types.
See below for specifications of these number formats.
Installation
The ml_dtypes package is tested with Python versions 3.8-3.11, and can be installed with the following command:
pip install ml_dtypes
To test your installation, you can run the following:
pip install absl-py pytest
pytest --pyargs ml_dtypes
To build from source, clone the repository and run:
git submodule init
git submodule update
pip install .
Example Usage
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)
Importing ml_dtypes also registers the data types with numpy, so that they may be referred to by their string name:
>>> np.dtype('bfloat16')
dtype(bfloat16)
>>> np.dtype('float8_e5m2')
dtype(float8_e5m2)
Specifications of implemented floating point formats
bfloat16
A bfloat16 number is a single-precision float truncated at 16 bits.
Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.
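The "truncated single-precision float" description can be illustrated with a minimal sketch that keeps only the upper 16 bits of a float32 encoding. The helper names below are hypothetical and not part of ml_dtypes, and the sketch truncates rather than rounds, so it may differ from the library's own conversion in the last bit:

```python
import struct

def float32_to_bfloat16_bits(x: float) -> int:
    """Return the upper 16 bits of the float32 encoding of x (truncation, no rounding)."""
    (bits32,) = struct.unpack(">I", struct.pack(">f", x))
    return bits32 >> 16

def bfloat16_bits_to_float(bits16: int) -> float:
    """Expand a 16-bit bfloat16 pattern back to a float by zero-filling the low 16 bits."""
    (value,) = struct.unpack(">f", struct.pack(">I", (bits16 & 0xFFFF) << 16))
    return value

bits = float32_to_bfloat16_bits(3.14159)   # 0x4049: sign 0, exponent 0x80, mantissa 0x49
approx = bfloat16_bits_to_float(bits)      # 3.140625
```

Only 7 mantissa bits survive, which is why 3.14159 comes back as 3.140625.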
float8_e4m3b11fnuz
Exponent: 4, Mantissa: 3, bias: 11.
Extended range: no inf, NaN represented by 0b1000'0000.
float8_e4m3fn
Exponent: 4, Mantissa: 3, bias: 7.
Extended range: no inf, NaN represented by 0bS111'1111.
The fn suffix is for consistency with the corresponding LLVM/MLIR type, signaling that this type is not consistent with IEEE-754. The f indicates that it holds finite values only. The n indicates that it includes NaNs, but only at the outer range.
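As a rough illustration of this layout, the sketch below decodes a float8_e4m3fn bit pattern in pure Python. The function name is hypothetical, and the subnormal handling (exponent field 0) follows the usual floating point convention, which is an assumption rather than something stated above:

```python
import math

def decode_e4m3fn(byte: int) -> float:
    """Decode an 8-bit pattern as 1 sign, 4 exponent, 3 mantissa bits with bias 7."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan                      # 0bS111'1111 is the only NaN pattern
    if exp == 0:
        return sign * (man / 8) * 2.0 ** (1 - 7)   # subnormal (assumed convention)
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)

largest = decode_e4m3fn(0b0111_1110)   # 448.0
one = decode_e4m3fn(0b0011_1000)       # 1.0
```

Because the all-ones exponent is not reserved for infinities, the largest finite value is 448 rather than 240, which is the "extended range" mentioned above.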
float8_e4m3fnuz
8-bit floating point with 3 bit mantissa.
An 8-bit floating point type with 1 sign bit, 4 exponent bits and 3 mantissa bits. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences to IEEE floating point conventions: F is for "finite" (no infinities), N is for a special NaN encoding, and UZ is for unsigned zero.
This type has the following characteristics:
- bit encoding: S1E4M3 - 0bSEEEEMMM
- exponent bias: 8
- infinities: Not supported
- NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - 0b10000000
- denormals when exponent is 0
float8_e5m2
Exponent: 5, Mantissa: 2, bias: 15. IEEE 754, with NaN and inf.
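Since float8_e5m2 has the same exponent width and bias as IEEE float16, one way to picture it is as the top byte of a float16 encoding. The sketch below uses that observation with Python's struct module; the function name is hypothetical and this is only an illustration of the bit layout, not the library's implementation:

```python
import struct

def decode_e5m2(byte: int) -> float:
    """Decode an e5m2 pattern by padding it with a zero low byte and reading a float16."""
    (value,) = struct.unpack(">e", bytes([byte & 0xFF, 0x00]))
    return value

one = decode_e5m2(0x3C)          # 1.0
largest = decode_e5m2(0x7B)      # 57344.0
infinity = decode_e5m2(0x7C)     # inf
```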
float8_e5m2fnuz
8-bit floating point with 2 bit mantissa.
An 8-bit floating point type with 1 sign bit, 5 exponent bits and 2 mantissa bits. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences to IEEE floating point conventions: F is for "finite" (no infinities), N is for a special NaN encoding, and UZ is for unsigned zero.
This type has the following characteristics:
- bit encoding: S1E5M2 - 0bSEEEEEMM
- exponent bias: 16
- infinities: Not supported
- NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - 0b10000000
- denormals when exponent is 0
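The fnuz formats described above differ only in field widths and bias, so a single decoder sketch can cover them. The function below is a hypothetical helper written from the listed characteristics (one NaN at 0b10000000, no infinities, no negative zero, denormals when the exponent field is 0); it is not taken from the ml_dtypes source:

```python
import math

def decode_fnuz(byte: int, exp_bits: int, man_bits: int, bias: int) -> float:
    """Decode an 8-bit fnuz pattern with the given exponent/mantissa widths and bias."""
    if byte == 0x80:
        return math.nan                      # the single NaN encoding; no negative zero
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)
    scale = float(1 << man_bits)
    if exp == 0:
        return sign * (man / scale) * 2.0 ** (1 - bias)   # denormal
    return sign * (1 + man / scale) * 2.0 ** (exp - bias)

max_e4m3fnuz = decode_fnuz(0x7F, exp_bits=4, man_bits=3, bias=8)      # 240.0
max_e5m2fnuz = decode_fnuz(0x7F, exp_bits=5, man_bits=2, bias=16)     # 57344.0
max_e4m3b11fnuz = decode_fnuz(0x7F, exp_bits=4, man_bits=3, bias=11)  # 30.0
```

Every pattern other than 0b10000000 decodes to a finite number, so the all-ones exponent contributes ordinary values instead of infinities and NaNs.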
int4 and uint4
4-bit integer types, where each element is represented unpacked (i.e., padded up to a byte in memory).
NumPy does not support types smaller than a single byte. For example, the distance between adjacent elements in an array (.strides) is expressed in bytes. Relaxing this restriction would be a considerable engineering project. The int4 and uint4 types therefore use an unpacked representation, where each element of the array is padded up to a byte in memory. The lower four bits of each byte contain the representation of the number, whereas the upper four bits are ignored.
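The unpacked layout can be mimicked with plain uint8 arrays. The sketch below reads the lower four bits of each byte as an unsigned or signed 4-bit value; the helper names are hypothetical, and the two's complement interpretation for int4 is an assumption rather than something stated above:

```python
import numpy as np

def low_nibble_as_uint4(raw: np.ndarray) -> np.ndarray:
    """Keep the lower four bits of each byte; the upper four bits are ignored."""
    return raw.astype(np.uint8) & 0x0F

def low_nibble_as_int4(raw: np.ndarray) -> np.ndarray:
    """Sign-extend the lower four bits of each byte (two's complement assumed)."""
    nibble = (raw.astype(np.uint8) & 0x0F).astype(np.int8)
    return (nibble ^ 8) - 8

raw = np.array([0x07, 0x0F, 0xF8, 0x31], dtype=np.uint8)
unsigned = low_nibble_as_uint4(raw)   # values 7, 15, 8, 1
signed = low_nibble_as_int4(raw)      # values 7, -1, -8, 1
byte_stride = raw.strides             # (1,): one full byte per element
```

The stride of one byte per element is the reason an array of 4-bit values takes as much memory as an array of 8-bit values in this representation.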
Quirks of low-precision Arithmetic
If you're exploring the use of low-precision dtypes in your code, you should be careful to anticipate when the precision loss might lead to surprising results. One example is the behavior of aggregations like sum; consider this bfloat16 summation in NumPy (run with version 1.24.2):
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> rng = np.random.default_rng(seed=0)
>>> vals = rng.uniform(size=10000).astype(bfloat16)
>>> vals.sum()
256
The true sum should be close to 5000, but numpy returns exactly 256: this is because bfloat16 does not have the precision to increment 256 by values less than 1:
>>> bfloat16(256) + bfloat16(1)
256
After 256, the next representable value in bfloat16 is 258:
>>> np.nextafter(bfloat16(256), bfloat16(np.inf))
258
For better results you can specify that the accumulation should happen in a higher-precision type like float32:
>>> vals.sum(dtype='float32').astype(bfloat16)
4992
In contrast to NumPy, projects like JAX, which support low-precision arithmetic more natively, will often do these kinds of higher-precision accumulations automatically:
>>> import jax.numpy as jnp
>>> jnp.array(vals).sum()
Array(4992, dtype=bfloat16)
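The stagnation at 256 can also be reproduced without NumPy or ml_dtypes. The sketch below emulates bfloat16 rounding with the struct module and compares a running sum that is rounded after every addition against a sum accumulated in higher precision. It assumes plain left-to-right accumulation with round-to-nearest-even, handles only finite inputs, and uses hypothetical names throughout:

```python
import random
import struct

def round_to_bfloat16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    (bits,) = struct.unpack(">I", struct.pack(">f", x))
    bits += 0x7FFF + ((bits >> 16) & 1)   # round to nearest even on the low 16 bits
    bits &= 0xFFFF0000                    # keep only the bfloat16 part
    return struct.unpack(">f", struct.pack(">I", bits))[0]

rng = random.Random(0)
vals = [round_to_bfloat16(rng.random()) for _ in range(10_000)]

# Accumulate in emulated bfloat16: every partial sum is rounded back to 16 bits.
low_precision_total = 0.0
for v in vals:
    low_precision_total = round_to_bfloat16(low_precision_total + v)

# Accumulate in Python floats and round only the final result.
high_precision_total = round_to_bfloat16(sum(vals))
```

Once the running total reaches 256, adding any value of at most 1 rounds back to 256, so low_precision_total stays there, while high_precision_total lands near 5000.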
License
This is not an officially supported Google product.
The ml_dtypes source code is licensed under the Apache 2.0 license (see LICENSE). Pre-compiled wheels are built with the EIGEN project, which is released under the MPL 2.0 license (see LICENSE.eigen).