bfloat16 in numpy: ml_dtypes

>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)

Goals

  • A single shared dependency for JAX, TensorFlow, and possibly other libraries in the future
  • Self-contained and permissively licensed, with an eye toward possible future adoption in numpy core

Problems

  • No way to register new dtypes with numpy.finfo
  • No way to make bfloat16 a true subtype of numpy.floating, because numpy hard-codes its logic about float widths
  • Saving arrays as *.npy files does not work, because there is no unique char code for bfloat16
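One possible workaround for the *.npy problem, sketched here under stated assumptions and not taken from ml_dtypes: store the raw 16-bit patterns as uint16, which the .npy format handles natively, and reinterpret them on load. To stay runnable with NumPy alone, the sketch emulates bfloat16 as the upper 16 bits of a float32 (bfloat16 keeps float32's sign bit and 8-bit exponent but only 7 mantissa bits), and it converts by plain truncation rather than rounding. The helper names `float32_to_bf16_bits`, `bf16_bits_to_float32`, `save_bf16_npy`, and `load_bf16_npy` are hypothetical.

```python
import io

import numpy as np


def float32_to_bf16_bits(x):
    """Hypothetical helper: keep the top 16 bits of each float32 (truncation, no rounding)."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits >> 16).astype(np.uint16)


def bf16_bits_to_float32(bits):
    """Hypothetical helper: widen raw bfloat16 bit patterns back to float32 (exact)."""
    wide = np.asarray(bits, dtype=np.uint16).astype(np.uint32) << 16
    return wide.view(np.float32)


def save_bf16_npy(file, x):
    """Hypothetical helper: write the bit patterns as an ordinary uint16 .npy array."""
    np.save(file, float32_to_bf16_bits(x))


def load_bf16_npy(file):
    """Hypothetical helper: read uint16 bit patterns and reinterpret them as float32 values."""
    return bf16_bits_to_float32(np.load(file))


# Usage: round-trip through an in-memory .npy file.
buf = io.BytesIO()
save_bf16_npy(buf, np.array([1.0, -2.5, 3.140625], dtype=np.float32))
buf.seek(0)
print(load_bf16_npy(buf))
```

Widening back to float32 is exact because every bfloat16 value is representable as a float32; only the save direction loses precision. With ml_dtypes installed, the loaded float32 array could presumably be cast back with `.astype(bfloat16)`, but that step is an assumption not exercised here.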