Skip to content

Instantly share code, notes, and snippets.

@liangfu
Created January 17, 2025 23:52
Show Gist options
  • Save liangfu/1f644bcaa07433eadb6ec89e8566caba to your computer and use it in GitHub Desktop.
Save liangfu/1f644bcaa07433eadb6ec89e8566caba to your computer and use it in GitHub Desktop.
Evaluate stochastic rounding
import os
import time
import torch
import torch_xla.core.xla_model as xm
# Number of columns in the test tensor built by main(); main() also divides
# the tensor by this value when rescaling it.
N = 16
def main():
    """Print a scaled ramp tensor as float32, bfloat16 and float8_e4m3fn.

    Builds a 32 x N integer tensor on the XLA device in which every row is
    ``0, 2, ..., 2 * (N - 1)``, divides it by N and multiplies by 3.11111 to
    obtain a float32 tensor, then prints the result of casting that float32
    tensor to bfloat16 and to float8_e4m3fn.

    Two Neuron environment variables are overwritten before the device is
    requested.
    """
    # NOTE: XLA_USE_BF16 is deliberately not set here; the assignment was
    # already disabled in the original experiment.
    # NOTE(review): the variable name suggests this switches on stochastic
    # rounding in the Neuron runtime -- confirm against the Neuron docs.
    os.environ.update(
        {
            "NEURON_RT_STOCHASTIC_ROUNDING_EN": "1",
            "NEURON_CC_FLAGS": " --internal-hlo2tensorizer-options=--experimental-unsafe-fp8e4m3fn-as-fp8e4m3 --execute-repetition=1 ",
        }
    )
    xla_dev = xm.xla_device()

    # All 32 rows are identical; presumably so that differing roundings of
    # the same input value are easy to spot -- TODO confirm.
    ramp = torch.arange(N).reshape(1, N)
    data = ramp.expand(32, N).to(device=xla_dev) * 2
    print(f"{data=}")

    # An earlier variant (disabled in the original) cast ``data / float(N)``
    # straight to float8_e4m3fn; the float32 intermediate is kept instead.
    scaled = data / float(N)
    output_fp32 = scaled.to(dtype=torch.float32) * 3.11111
    print(f"{output_fp32=}")

    # Both narrow casts start from the same float32 tensor.
    output_bf16 = output_fp32.to(dtype=torch.bfloat16)
    print(f"{output_bf16=}")

    output_fp8e4m3 = output_fp32.to(dtype=torch.float8_e4m3fn)
    print(f"{output_fp8e4m3=}")
if __name__ == "__main__":
    # Run the experiment only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment