Source code for gfloat.block

# Copyright (c) 2024 Graphcore Ltd. All rights reserved.

# Block floating point formats
# https://en.wikipedia.org/wiki/Block_floating_point

from dataclasses import dataclass
from typing import Callable, Iterable

import numpy as np
import numpy.typing as npt

from .decode import decode_float
from .round import RoundMode, encode_float, round_float
from .types import FormatInfo


[docs] @dataclass class BlockFormatInfo: #: Short name for the format, e.g. BlockFP8 name: str #: Element data type etype: FormatInfo #: Scaling block size k: int #: Scale datatype stype: FormatInfo #: ## Derived values @property def element_bits(self) -> int: """The number of bits in each element, d""" return self.etype.k @property def scale_bits(self) -> int: """The number of bits in the scale, w""" return self.stype.k @property def block_size_bytes(self) -> int: """The number of bytes in a block""" bits = self.element_bits * self.k + self.scale_bits assert bits % 8 == 0 return bits // 8 @property def __name__(self) -> str: return self.name def __str__(self) -> str: return f"BlockFormatInfo:{self.name})"
[docs] def decode_block(fi: BlockFormatInfo, block: Iterable[int]) -> Iterable[float]: """ Decode a :paramref:`block` of integer codepoints in Block Format :paramref:`fi` The scale is encoded in the first value of :paramref:`block`, with the remaining values encoding the block elements. The size of the iterable is not checked against the format descriptor. Args: fi (BlockFormatInfo): Describes the block format block (Iterable[int]): Input block Returns: A sequence of floats representing the encoded values. """ it = iter(block) scale_encoding = next(it) scale = decode_float(fi.stype, scale_encoding).fval for val_encoding in it: val = scale * decode_float(fi.etype, val_encoding).fval yield val
# TODO: Assert length of block was k+1? Messy unless block is len()able
[docs] def encode_block( fi: BlockFormatInfo, scale: float, vals: Iterable[float], round: RoundMode = RoundMode.TiesToEven, ) -> Iterable[int]: """ Encode float :paramref:`vals` into block Format described by :paramref:`fi` The :paramref:`scale` is explicitly passed, and the :paramref:`vals` are assumed to already be multiplied by `1/scale`. That is, this is pure encoding, scaling is computed and applied elsewhere (see e.g. :func:`quantize_block`). It is checked for overflow in the target format, and will raise an exception if it does. Args: fi (BlockFormatInfo): Describes the target block format scale (float): Scale to be recorded in the block vals (Iterable[float]): Input block round (RoundMode): Rounding mode to use, defaults to `TiesToEven` Returns: A sequence of ints representing the encoded values. Raises: ValueError: The scale overflows the target scale encoding format. """ if scale > fi.stype.max or scale < fi.stype.min: raise ValueError(f"Scaled {scale} out of range for {fi.stype}") sat = True # Saturate elements if out of range def enc(ty: FormatInfo, x: float) -> int: return encode_float(ty, round_float(ty, x, round, sat)) yield enc(fi.stype, scale) for val in vals: yield enc(fi.etype, val)
ComputeScaleCallable = Callable[[float, npt.ArrayLike], float]
[docs] def compute_scale_amax(emax: float, vals: npt.ArrayLike) -> float: """ Compute a scale factor such that :paramref:`vals` can be scaled to the range [0, 2**emax]. That is, `scale` is computed such that the largest exponent in the array `vals * scale` will be `emax`. The scale is clipped to the range 2**[-127, 127]. If all values are zero, any scale value smaller than emax would be accurate, but returning the smallest possible means that quick checks on the magnitude to identify near-zero blocks will also find the all-zero blocks. Args: emax (float): Maximum exponent to appear in `vals * scale` vals (ArrayLike): Input block Returns: A float such that `vals * scale` has exponents less than or equal to `emax`. Note: If all vals are zero, 1.0 is returned. """ amax = np.max(np.abs(vals)) if amax == 0.0: q_log2scale = -127.0 else: q_log2scale = np.floor(np.log2(amax)) - emax q_log2scale = np.clip(q_log2scale, -127.0, 127.0) return 2.0**q_log2scale
[docs] def quantize_block( fi: BlockFormatInfo, vals: npt.NDArray[np.float64], compute_scale: ComputeScaleCallable, round: RoundMode = RoundMode.TiesToEven, ) -> npt.NDArray[np.float64]: """ Encode and decode a block of :paramref:`vals` of bytes into block format described by :paramref:`fi` Args: fi (BlockFormatInfo): Describes the target block format vals (numpy.array): Input block compute_scale ((float, ArrayLike) -> float): Callable to compute the scale, defaults to :func:`compute_scale_amax` round (RoundMode): Rounding mode to use, defaults to `TiesToEven` Returns: An array of floats representing the quantized values. Raises: ValueError: The scale overflows the target scale encoding format. """ q_scale = compute_scale(fi.etype.emax, vals) scaled_vals = vals / q_scale enc = encode_block(fi, q_scale, scaled_vals, round) return np.fromiter(decode_block(fi, enc), float)