# bfloat16 in numpy: `ml_dtypes`

- Code repository: http://github.com/jax-ml/ml_dtypes
- Tracking issue: https://github.com/numpy/numpy/issues/19808

```python=
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)
```

## Goals

- A single dependency for JAX, TensorFlow, and possibly other libraries in the future
- Self-contained and permissively licensed, with an eye toward possible future adoption in numpy core

## Problems

- There is no way to register new dtypes with `numpy.finfo`.
- There is no way to make `bfloat16` a true subtype of `numpy.floating`, because numpy hard-codes its logic about float widths.
- Saving as `*.npy` does not work, because there is no unique char code for bfloat16.
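A bfloat16 value has the same layout as the upper 16 bits of an IEEE float32: 1 sign bit, 8 exponent bits, and 7 mantissa bits. That suggests one possible stopgap for the `.npy` limitation: store the raw bit patterns under the built-in `uint16` dtype, which `.npy` already supports, and reinterpret them after loading. The sketch below uses plain numpy only, so it does not depend on `ml_dtypes`. The helper names `float32_to_bfloat16_bits` and `bfloat16_bits_to_float32` are hypothetical. Round-to-nearest-even is an assumed rounding mode for this illustration, not a statement about how `ml_dtypes` converts values.

```python
import io

import numpy as np


def float32_to_bfloat16_bits(x):
    """Round float32 values to bfloat16 and return the raw 16-bit patterns as uint16.

    Uses round-to-nearest-even on the 16 discarded low bits (assumed rounding mode).
    """
    f = np.asarray(x, dtype=np.float32)
    bits = f.view(np.uint32)
    upper = bits >> np.uint32(16)
    # Adding 0x7FFF plus the lowest kept bit rounds to nearest, ties to even.
    rounding_bias = np.uint32(0x7FFF) + (upper & np.uint32(1))
    rounded = (bits + rounding_bias) >> np.uint32(16)
    # Rounding a NaN could carry into the exponent and yield infinity,
    # so NaNs keep their upper bits with the quiet-NaN bit forced on.
    quiet_nan = upper | np.uint32(0x0040)
    return np.where(np.isnan(f), quiet_nan, rounded).astype(np.uint16)


def bfloat16_bits_to_float32(b):
    """Widen bfloat16 bit patterns back to float32 by zero-filling the low 16 bits."""
    wide = np.asarray(b, dtype=np.uint16).astype(np.uint32) << np.uint32(16)
    return wide.view(np.float32)


# Round-trip through the .npy format using only the standard uint16 dtype.
data = np.array([1.0, 1.00390625, 1.01171875, -2.5, np.nan], dtype=np.float32)
buf = io.BytesIO()
np.save(buf, float32_to_bfloat16_bits(data))
buf.seek(0)
restored = bfloat16_bits_to_float32(np.load(buf))
print(restored)
```

The cost of this workaround is that the file no longer records that the data is bfloat16, so the reader must know to reinterpret it. That is exactly the metadata a unique char code would carry. With `ml_dtypes` installed, the same idea should reduce to viewing a bfloat16 array as `np.uint16` before `np.save` and viewing it back as `bfloat16` after `np.load`.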