ml_dtypes
ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries, including:
- bfloat16: an alternative to the standard float16 format
- float8_*: several experimental 8-bit floating point representations including:
  - float8_e4m3b11fnuz
  - float8_e4m3fn
  - float8_e4m3fnuz
  - float8_e5m2
  - float8_e5m2fnuz
- int4 and uint4: low precision integer types.
See below for specifications of these number formats.
Installation
The ml_dtypes package is tested with Python versions 3.9-3.12, and can be installed with the following command:
pip install ml_dtypes
To test your installation, you can run the following:
pip install absl-py pytest
pytest --pyargs ml_dtypes
To build from source, clone the repository and run:
git submodule init
git submodule update
pip install .
Example Usage
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)
Importing ml_dtypes also registers the data types with NumPy, so that they may be referred to by their string name:
>>> np.dtype('bfloat16')
dtype(bfloat16)
>>> np.dtype('float8_e5m2')
dtype(float8_e5m2)
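As a small illustration of the string names in use, the sketch below casts a float32 array to a few of the registered dtypes and back. The input values are chosen (as an assumption of this example) to be exactly representable in all three formats, so the round trip is lossless; other values would generally be rounded.

```python
import numpy as np
import ml_dtypes  # noqa: F401  (the import registers the dtype names with NumPy)

# Values picked so that each one fits exactly in bfloat16, float8_e4m3fn and float8_e5m2.
x = np.array([1.0, 1.5, -2.0, 448.0], dtype=np.float32)

for name in ("bfloat16", "float8_e4m3fn", "float8_e5m2"):
    dt = np.dtype(name)          # look the dtype up by its string name
    y = x.astype(dt)             # cast down to the low-precision type
    back = y.astype(np.float32)  # cast back up for inspection
    print(name, dt.itemsize, back)
```

The itemsize printed for each dtype is the number of bytes per element: 2 for bfloat16 and 1 for the 8-bit float types.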
Specifications of implemented floating point formats
bfloat16
A bfloat16 number is a single-precision float truncated at 16 bits.
Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.
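The "truncated float32" description can be made concrete with a few lines of standard-library Python. This is only a sketch of the bit layout: the helper names are hypothetical, and it truncates the low 16 bits, whereas a real conversion would normally round instead.

```python
import struct

def float32_to_bfloat16_bits(x: float) -> int:
    """Return the upper 16 bits of x's float32 encoding (truncation, no rounding)."""
    (bits32,) = struct.unpack("<I", struct.pack("<f", x))
    return bits32 >> 16

def bfloat16_bits_to_float(bits16: int) -> float:
    """Decode a bfloat16 bit pattern by padding it with 16 zero bits."""
    (value,) = struct.unpack("<f", struct.pack("<I", (bits16 & 0xFFFF) << 16))
    return value

bits = float32_to_bfloat16_bits(3.14159)
print(hex(bits))                     # 0x4049
print(bfloat16_bits_to_float(bits))  # 3.140625
```

Only 7 mantissa bits survive, so 3.14159 comes back as 3.140625, while the 8 exponent bits keep the full float32 range.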
float8_e4m3b11fnuz
Exponent: 4, Mantissa: 3, bias: 11.
Extended range: no inf, NaN represented by 0b1000'0000.
float8_e4m3fn
Exponent: 4, Mantissa: 3, bias: 7.
Extended range: no inf, NaN represented by 0bS111'1111.
The fn suffix is for consistency with the corresponding LLVM/MLIR type, signaling that this type is not consistent with IEEE 754. The f indicates that it represents finite values only. The n indicates that it includes NaNs, but only at the outer range.
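To see how these parameters determine the values, here is a minimal decoder for a single float8_e4m3fn byte. The function name is hypothetical, and the sketch assumes the usual IEEE-style handling of subnormals (exponent field 0); it is an illustration of the encoding, not the library's implementation.

```python
import math

def decode_float8_e4m3fn(byte: int) -> float:
    """Decode one float8_e4m3fn bit pattern (0bSEEEEMMM, bias 7) to a Python float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0xF
    mantissa = byte & 0x7
    if exponent == 0xF and mantissa == 0x7:
        return math.nan                                # 0bS111'1111 is the only NaN pattern
    if exponent == 0:
        return sign * (mantissa / 8) * 2.0 ** (1 - 7)  # subnormal: no implicit leading 1
    return sign * (1 + mantissa / 8) * 2.0 ** (exponent - 7)

print(decode_float8_e4m3fn(0x38))  # 1.0
print(decode_float8_e4m3fn(0x7E))  # 448.0, the largest finite value
print(decode_float8_e4m3fn(0x7F))  # nan
```

Because the all-ones exponent is not reserved for infinities, the patterns just below the NaN encoding hold ordinary finite values, which is where the extended range comes from.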
float8_e4m3fnuz
8-bit floating point with 3 bit mantissa.
An 8-bit floating point type with 1 sign bit, 4 exponent bits and 3 mantissa bits. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences to IEEE floating point conventions: F is for "finite" (no infinities), N is for a special NaN encoding, and UZ is for unsigned zero.
This type has the following characteristics:
- bit encoding: S1E4M3 - 0bSEEEEMMM
- exponent bias: 8
- infinities: Not supported
- NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - 0b10000000
- denormals when exponent is 0
float8_e5m2
Exponent: 5, Mantissa: 2, bias: 15. IEEE 754, with NaN and inf.
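Since float8_e5m2 has the same exponent width and bias as IEEE float16, a float8_e5m2 byte can be read as the high byte of a float16 whose low byte is zero. The sketch below uses that observation to decode a byte with the standard library; the function name is hypothetical.

```python
import math
import struct

def decode_float8_e5m2(byte: int) -> float:
    """Decode a float8_e5m2 byte by treating it as the high byte of a float16."""
    # Little-endian float16: low byte first (all zeros), then the float8 byte.
    (value,) = struct.unpack("<e", bytes([0x00, byte & 0xFF]))
    return value

print(decode_float8_e5m2(0x3C))  # 1.0
print(decode_float8_e5m2(0x7B))  # 57344.0, the largest finite value
print(decode_float8_e5m2(0x7C))  # inf
print(decode_float8_e5m2(0x7E))  # nan
```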
float8_e5m2fnuz
8-bit floating point with 2 bit mantissa.
An 8-bit floating point type with 1 sign bit, 5 exponent bits and 2 mantissa bits. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences to IEEE floating point conventions: F is for "finite" (no infinities), N is for a special NaN encoding, and UZ is for unsigned zero.
This type has the following characteristics:
- bit encoding: S1E5M2 - 0bSEEEEEMM
- exponent bias: 16
- infinities: Not supported
- NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - 0b10000000
- denormals when exponent is 0
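The three fnuz formats in this section differ only in how the 7 non-sign bits are split and in the exponent bias, so one small decoder can cover all of them. The following is a sketch under the characteristics listed above (single NaN at 0b10000000, no infinities, denormals when the exponent field is 0); the function and the FORMATS table are hypothetical names for this example.

```python
import math

def decode_fnuz(byte: int, exp_bits: int, man_bits: int, bias: int) -> float:
    """Decode one 8-bit fnuz pattern given its exponent/mantissa widths and bias."""
    if byte == 0x80:
        return math.nan  # the "negative zero" pattern is reused as the only NaN
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> man_bits) & ((1 << exp_bits) - 1)
    mantissa = byte & ((1 << man_bits) - 1)
    scale = 1 << man_bits
    if exponent == 0:
        return sign * (mantissa / scale) * 2.0 ** (1 - bias)  # denormal
    return sign * (1 + mantissa / scale) * 2.0 ** (exponent - bias)

# (exponent bits, mantissa bits, bias) as given in the specifications above
FORMATS = {
    "float8_e4m3b11fnuz": (4, 3, 11),
    "float8_e4m3fnuz": (4, 3, 8),
    "float8_e5m2fnuz": (5, 2, 16),
}

for name, params in FORMATS.items():
    # 0x7F is the largest positive bit pattern in each format
    print(name, decode_fnuz(0x7F, *params))
# float8_e4m3b11fnuz 30.0
# float8_e4m3fnuz 240.0
# float8_e5m2fnuz 57344.0
```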
int4 and uint4
4-bit integer types, where each element is represented unpacked (i.e., padded up to a byte in memory).
NumPy does not support types smaller than a single byte. For example, the distance between adjacent elements in an array (.strides) is expressed in bytes. Relaxing this restriction would be a considerable engineering project.
The int4 and uint4 types therefore use an unpacked representation, where each element of the array is padded up to a byte in memory. The lower four bits of each byte contain the representation of the number, whereas the upper four bits are ignored.
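The unpacked layout can be illustrated in plain Python. This sketch assumes int4 uses two's complement in the low nibble (range -8 to 7); the helper names are hypothetical and are not part of the ml_dtypes API.

```python
def decode_uint4(byte: int) -> int:
    """Read the low four bits of a byte as an unsigned value (0..15)."""
    return byte & 0x0F

def decode_int4(byte: int) -> int:
    """Read the low four bits of a byte as a two's-complement value (-8..7)."""
    nibble = byte & 0x0F
    return nibble - 16 if nibble >= 8 else nibble

# The upper four bits do not affect the decoded value.
print(decode_int4(0x0F), decode_int4(0xFF))    # -1 -1
print(decode_int4(0x07), decode_int4(0x08))    # 7 -8
print(decode_uint4(0x0F), decode_uint4(0xA3))  # 15 3
```

One consequence of this layout is that an array of N int4 values occupies N bytes, not N/2.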
Quirks of low-precision Arithmetic
If you're exploring the use of low-precision dtypes in your code, you should be careful to anticipate when the precision loss might lead to surprising results. One example is the behavior of aggregations like sum; consider this bfloat16 summation in NumPy (run with version 1.24.2):
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> rng = np.random.default_rng(seed=0)
>>> vals = rng.uniform(size=10000).astype(bfloat16)
>>> vals.sum()
256
The true sum should be close to 5000, but NumPy returns exactly 256: this is because bfloat16 does not have the precision to increment 256 by values less than 1:
>>> bfloat16(256) + bfloat16(1)
256
After 256, the next representable value in bfloat16 is 258:
>>> np.nextafter(bfloat16(256), bfloat16(np.inf))
258
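The gap of 2 follows directly from the 7-bit mantissa: within a binade [2^k, 2^(k+1)), adjacent bfloat16 values are 2^(k-7) apart. A short sketch of that arithmetic, using a hypothetical helper that is valid for normal, finite, nonzero inputs only:

```python
import math

def bfloat16_spacing(x: float) -> float:
    """Gap between adjacent bfloat16 values in the binade containing x."""
    _, e = math.frexp(abs(x))  # abs(x) = m * 2**e with 0.5 <= m < 1
    return 2.0 ** (e - 1 - 7)  # 7 explicit mantissa bits

print(bfloat16_spacing(1.0))    # 0.0078125
print(bfloat16_spacing(255.0))  # 1.0
print(bfloat16_spacing(256.0))  # 2.0
```

Below 256 the spacing is at most 1, so adding values in [0, 1) can still move the running total; from 256 onward every such addend is less than half the spacing and is rounded away.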
For better results you can specify that the accumulation should happen in a higher-precision type like float32:
>>> vals.sum(dtype='float32').astype(bfloat16)
4992
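The difference between the two accumulation strategies can be reproduced without NumPy by emulating bfloat16 rounding in pure Python. This is a sketch under stated assumptions: round_to_bfloat16 is a hypothetical helper implementing round-to-nearest-even on finite values, the accumulation is a simple left-to-right loop, and the input is 10000 copies of 0.75 rather than the random data used above.

```python
import struct

def round_to_bfloat16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    bits += 0x7FFF + ((bits >> 16) & 1)  # round half to even on the low 16 bits
    bits &= 0xFFFF0000                   # drop the low 16 bits
    (value,) = struct.unpack("<f", struct.pack("<I", bits))
    return value

vals = [0.75] * 10000  # 0.75 is exactly representable in bfloat16

# Accumulate in bfloat16: every partial sum is rounded back to bfloat16.
acc = 0.0
for v in vals:
    acc = round_to_bfloat16(acc + v)
print(acc)  # 256.0

# Accumulate in higher precision and round only the final result.
print(round_to_bfloat16(sum(vals)))  # 7488.0
```

The exact sum is 7500; rounding once at the end lands on the nearest bfloat16 value, 7488, while rounding after every addition stalls at 256.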
In contrast to NumPy, projects like JAX, which support low-precision arithmetic more natively, will often do these kinds of higher-precision accumulations automatically:
>>> import jax.numpy as jnp
>>> jnp.array(vals).sum()
Array(4992, dtype=bfloat16)
License
This is not an officially supported Google product.
The ml_dtypes source code is licensed under the Apache 2.0 license (see LICENSE). Pre-compiled wheels are built with the EIGEN project, which is released under the MPL 2.0 license (see LICENSE.eigen).