Created
October 13, 2022 15:43
-
-
Save thecharlieblake/62fe60553050d078caf04764a2228681 to your computer and use it in GitHub Desktop.
A representation of the IEEE754 binary format, as well as variants for FP8.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""A representation of the IEEE754 binary format, as well as variants for FP8.""" | |
from dataclasses import dataclass | |
from functools import total_ordering | |
from typing import Any, ClassVar | |
@dataclass
class IEEE754BinaryFormat:
    """Describes a binary floating-point format in the style of IEEE754.

    A format is fully specified by the widths (in bits) of its exponent and mantissa
    fields; the sign always occupies a single bit. Concrete bit patterns interpreted
    in a given format are represented by `FloatInstance`.

    The `*_encoding` class variables are short human-readable descriptions of which
    bit patterns encode the special values (S = sign, E = exponent, M = mantissa).
    """

    e_width: int  # number of exponent bits
    m_width: int  # number of explicitly stored mantissa (fraction) bits

    # Special-value encodings for standard IEEE754 formats; subclasses override these.
    inf_encoding: ClassVar[str] = "E=1s M=0s"
    nan_encoding: ClassVar[str] = "E=1s M≠0s"
    zero_encoding: ClassVar[str] = "S=0/1 E=0s M=0s"
    subnormal_encoding: ClassVar[str] = "E=0s"

    @property
    def bias(self) -> int:
        """Offset subtracted from the raw exponent field: ``2**(e_width - 1) - 1``."""
        half_exponent_range = 2 ** (self.e_width - 1)
        return int(half_exponent_range) - 1

    @property
    def min_e(self) -> int:
        """Smallest unbiased exponent of a normal number.

        The raw exponent 0 is reserved for zero/subnormals, so the lowest usable raw
        exponent is 1, giving ``1 - bias`` rather than ``-bias``.
        """
        return -self.bias + 1

    @property
    def max_e(self) -> int:
        """Largest unbiased exponent of a finite number.

        The all-1s raw exponent is reserved for Inf/NaN, so the highest usable raw
        exponent is ``2**e_width - 2``.
        """
        largest_raw_exponent = int(2**self.e_width) - 2
        return largest_raw_exponent - self.bias

    @property
    def abs_min_normal(self) -> float:
        """Smallest positive normal (i.e. non-subnormal) representable magnitude."""
        smallest_normal = 2**self.min_e
        return float(smallest_normal)

    @property
    def abs_min(self) -> float:
        """Smallest positive representable magnitude (a subnormal with mantissa 1)."""
        lowest_exponent = self.min_e - self.m_width
        return float(2**lowest_exponent)

    @property
    def abs_max(self) -> float:
        """Largest finite representable magnitude: ``2**max_e * (2 - 2**-m_width)``."""
        largest_significand = 2 - 2**-self.m_width
        return float(2**self.max_e * largest_significand)
@dataclass
class GAQProposedFormat(IEEE754BinaryFormat):
    """Used for both of the FP8 formats proposed by Graphcore, AMD and Qualcomm (GAQ).

    Differences from IEEE754:

    - the exponent bias is supplied explicitly (`custom_bias`) instead of being
      derived from `e_width`;
    - there are no infinities, so the all-1s exponent holds ordinary values;
    - NaN is the bit pattern IEEE754 would use for negative zero, leaving only a
      positive zero.
    """

    custom_bias: int  # exponent bias used in place of the IEEE754 default

    inf_encoding: ClassVar[str] = "N/A"
    nan_encoding: ClassVar[str] = "S=1 E=0s M=0s"
    zero_encoding: ClassVar[str] = "S=0 E=0s M=0s"

    @property
    def bias(self) -> int:
        """Exponent bias; GAQ chooses it explicitly rather than using the IEEE754 rule."""
        return self.custom_bias

    @property
    def max_e(self) -> int:
        """Largest unbiased exponent.

        Inf/NaN no longer occupy the all-1s exponent, so the range extends one step
        beyond the IEEE754 maximum for the same bias.
        """
        ieee_max_e = super().max_e
        return ieee_max_e + 1
@dataclass
class NAIProposedFormat(IEEE754BinaryFormat):
    """The E4M3 FP8 format proposed by Nvidia, ARM and Intel (NAI).

    Keeps the standard IEEE754 bias but has no infinities. The all-1s exponent holds
    ordinary values, except for the all-1s exponent + all-1s mantissa pattern, which
    is NaN.
    """

    inf_encoding: ClassVar[str] = "N/A"
    nan_encoding: ClassVar[str] = "E=1s M=1s"

    @property
    def max_e(self) -> int:
        """Largest unbiased exponent.

        The all-1s exponent is no longer reserved for Inf/NaN, so this is one above
        the IEEE754 maximum.
        """
        return 1 + super().max_e

    @property
    def abs_max(self) -> float:
        """Largest finite representable magnitude.

        The inherited formula assumes every mantissa pattern in the top binade is a
        number. Here the all-1s mantissa in that binade is NaN, so one unit in the
        last place of the top binade (``2**(max_e - m_width)``) is removed.
        """
        top_binade_ulp = 2 ** (self.max_e - self.m_width)
        return float(super().abs_max - top_binade_ulp)
@dataclass
class FloatInstance:
    """An instance of a floating point number, defined with reference to an instance
    or subclass of `IEEE754BinaryFormat`.

    The number is stored as its raw bit fields: sign bit `s` (0 or 1), biased
    exponent field `e` and mantissa field `m`, each an unsigned integer. The decoded
    number is available as `value`.

    Comparison operators act on `value`. An instance can therefore be compared with
    plain numbers and with other `FloatInstance` objects, including ones in a
    different format. NaN follows float semantics: it is unordered and unequal to
    everything. Because `__eq__` is defined without `__hash__`, instances are
    unhashable.
    """

    format: IEEE754BinaryFormat
    s: int
    e: int
    m: int

    def __post_init__(self) -> None:
        """Derive the field limits and check that the bit fields fit the format.

        Raises:
            AssertionError: if `s` is not 0/1 or `e`/`m` do not fit their field
                widths. NOTE: these checks are skipped when Python runs with `-O`.
        """
        assert self.s in [0, 1], self.s
        # Largest raw value each field can hold, i.e. its all-1s bit pattern.
        self.e_limit = int(2**self.format.e_width) - 1
        self.m_limit = int(2**self.format.m_width) - 1
        assert (
            0 <= self.e <= self.e_limit
        ), f"Exponent {self.e} outside range: [0, {self.e_limit}]"
        assert (
            0 <= self.m <= self.m_limit
        ), f"Mantissa {self.m} outside range: [0, {self.m_limit}]"

    @property
    def value(self) -> float:
        """The numerical value of the bitstring, as defined by the supplied format."""
        if self._is_inf():
            return float("inf") * int((-1) ** self.s)
        # NaN is tested before subnormals because the "S=1 E=0s M=0s" NaN encoding
        # has a zero exponent field, which would otherwise decode as a subnormal.
        if self._is_nan():
            return float("nan")
        if self._is_subnormal():
            return self._subnormal_val()
        return self._normal_val()

    def _normal_val(self) -> float:
        """Decode as ``(-1)**s * 2**(e - bias) * (1 + m / 2**m_width)``."""
        e = self.e - self.format.bias
        m = 1 + (self.m / (self.m_limit + 1))
        return float(((-1) ** self.s) * (2**e) * m)

    def _subnormal_val(self) -> float:
        """Decode as ``(-1)**s * 2**(1 - bias) * (m / 2**m_width)``, with no implicit 1."""
        e = 1 - self.format.bias
        m = self.m / (self.m_limit + 1)
        return float(((-1) ** self.s) * (2**e) * m)

    def _is_subnormal(self) -> bool:
        """True when the exponent field is all 0s (this also covers zero)."""
        return self.e == 0

    def _is_nan(self) -> bool:
        """True when the bits match the format's `nan_encoding`.

        Raises:
            AssertionError: if the format uses an unrecognised NaN encoding.
        """
        if self.format.nan_encoding == "E=1s M≠0s":
            return self.e == self.e_limit and self.m != 0
        if self.format.nan_encoding == "E=1s M=1s":
            return self.e == self.e_limit and self.m == self.m_limit
        assert (
            self.format.nan_encoding == "S=1 E=0s M=0s"
        ), f"NaN encoding `'{self.format.nan_encoding}' not recognised"
        return self.s == 1 and self.e == 0 and self.m == 0

    def _is_inf(self) -> bool:
        """True when the bits match the format's `inf_encoding` ("N/A" means never).

        Raises:
            AssertionError: if the format uses an unrecognised Inf encoding.
        """
        if self.format.inf_encoding == "E=1s M=0s":
            return self.e == self.e_limit and self.m == 0
        assert (
            self.format.inf_encoding == "N/A"
        ), f"Inf encoding `'{self.format.inf_encoding}' not recognised"
        return False

    def __repr__(self) -> str:
        return str(self.value)

    @staticmethod
    def _as_comparable(other: Any) -> Any:
        """Decode `other` if it is a `FloatInstance`; otherwise return it unchanged.

        Without this, `float.__eq__`/`float.__lt__` return NotImplemented for a
        `FloatInstance` operand. Two equal instances would then compare unequal, and
        ordering two instances would raise TypeError.
        """
        if isinstance(other, FloatInstance):
            return other.value
        return other

    # All orderings are defined explicitly rather than via `functools.total_ordering`.
    # That helper derives `>`/`>=` by negating `<`, which wrongly makes NaN compare
    # greater. Delegating to float's dunder methods keeps IEEE754 NaN semantics. It
    # also lets unsupported operand types return NotImplemented so Python can try the
    # reflected operation.

    def __eq__(self, other: Any) -> bool:
        return self.value.__eq__(self._as_comparable(other))

    def __lt__(self, other: Any) -> bool:
        return self.value.__lt__(self._as_comparable(other))

    def __le__(self, other: Any) -> bool:
        return self.value.__le__(self._as_comparable(other))

    def __gt__(self, other: Any) -> bool:
        return self.value.__gt__(self._as_comparable(other))

    def __ge__(self, other: Any) -> bool:
        return self.value.__ge__(self._as_comparable(other))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment