Ref: https://x.com/ezyang/status/2048485559576789083
I think one way to think about fine-grained precision APIs is that we are exposing a little of the underlying memory hierarchy to the user. For single-node ops, the most important question is whether a value is kept in memory or not. For collectives, the precision the comms are actually done in is another dimension.
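For example, you can already pick the communication dtype independently of the tensor dtype by casting around the collective. A toy sketch (the single-process gloo group and the fp16 comms dtype are illustrative choices, not anything prescribed here):

```python
import torch
import torch.distributed as dist

# Single-process gloo group purely so the sketch runs end to end;
# in real training you would already have a process group.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

grad = torch.randn(1024, dtype=torch.float32)

# The tensor lives in fp32, but the communication itself is done in fp16:
# cast down, reduce, cast back. The reduction happens on fp16 values.
buf = grad.to(torch.float16)
dist.all_reduce(buf)
grad.copy_(buf.to(torch.float32))

dist.destroy_process_group()
```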
PyTorch generally has these rules:
- If hardware is involved (e.g., tensor cores), defer to the hardware
- Always do accumulation in fp32 (this is formalized as acc_dtype, which is not exposed to users but is hard-coded per dtype)
- We don't always do accumulation in fp32 for matmuls. This is generally controlled by global knobs (see the sketch after this list): torch.backends.cuda.matmul.fp32_precision, torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction, torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction, torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp
- Output dtype defaults to the same as the input; out_dtype controls how to convert the accumulator.
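A minimal sketch of how the matmul knobs and the output-dtype rule show up in practice (assumes a CUDA device on a recent PyTorch; whether a given setting changes the kernel that actually gets picked is hardware-dependent, and the accepted values for fp32_precision here are my assumption):

```python
import torch

a = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)

# Output dtype defaults to the input dtype; the accumulation dtype is a
# separate, mostly hidden choice.
out = a @ b
assert out.dtype == torch.bfloat16

# Disallow reduced-precision reductions so bf16 matmuls accumulate in fp32.
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False

# Control what fp32 matmuls are internally computed in ("ieee" vs "tf32").
torch.backends.cuda.matmul.fp32_precision = "tf32"
c = torch.randn(1024, 1024, device="cuda")
d = torch.randn(1024, 1024, device="cuda")
out32 = c @ d  # may use tf32 tensor cores internally; output stays fp32
assert out32.dtype == torch.float32
```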
