Ref: https://x.com/ezyang/status/2048485559576789083
I think one way to think about fine-grained precision APIs is that we are exposing a little of the underlying memory hierarchy to the user. For single-node ops, the most important distinction is whether a value is materialized to memory or not, and in what dtype. For collectives, the dtype the comms are actually done in is another dimension.
PyTorch generally has these rules:
- If hardware is involved (e.g., tensor cores), defer to the hardware
- Always do accumulation in fp32 (this is formalized as `acc_dtype`, which is not exposed to users but is hard coded per dtype)
- We don't always do accumulation in fp32 for matmuls. This is generally controlled by global knobs (see the sketch after this list): `torch.backends.cuda.matmul.fp32_precision`, `torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction`, `torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction`, `torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp`
- Output dtype defaults to the same as the input. `out_dtype` controls how to convert the accumulator before writing it to memory: it can be used to avoid loss of precision from the accumulator (sum) or to reduce output memory bandwidth (mm)
- We don't have special low precision fused collectives, so FSDP2/DTensor control the collective precision by casting before the collective. However, this still needs to be baked into the API, as you need asymmetric casts in backwards.
- There is very limited support for relaxing the dtype / grad dtype matching invariant via `grad_dtype` on parameters
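
A minimal sketch of what those knobs and the reduction `out_dtype` story look like in practice (assumes a CUDA device; `allow_bf16_reduced_precision_reduction` and `Tensor.sum(dtype=...)` are existing PyTorch APIs, and the comments about accumulation just restate the rules above):

```python
import torch

x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
y = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

# Matmul accumulation precision is a global knob, not a per-op argument:
# disallowing the reduced-precision (split) reduction forces fp32 accumulation
# for bf16 matmuls.
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
z = x @ y  # bf16 inputs and output, fp32 accumulation

# Reductions already accumulate in fp32 (acc_dtype); dtype= only controls the
# output dtype, i.e. how the accumulator is converted before being written out.
s_bf16 = x.sum()                     # out_dtype defaults to in_dtype (bf16)
s_fp32 = x.sum(dtype=torch.float32)  # keep the accumulator's precision
```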
Some observations:
- The FSDP all-gather situation is very reminiscent of the autograd.Function compositionality problem. With the simple dtype = grad dtype invariant, if we temporarily want to rescind this we have to pack it into an autograd.Function. This is fine and keeps things simple, but we could also make it possible to explicitly control both the dtype and grad dtype for all intermediate tensors, thereby allowing more fine grained control here.
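
A minimal sketch of that "pack it into an autograd.Function" pattern: cast to a low-precision forward dtype on the way in, and cast the incoming gradient to a separately chosen dtype on the way out. `AsymmetricCast` is an illustrative name, not an existing PyTorch API.

```python
import torch

class AsymmetricCast(torch.autograd.Function):
    """Temporarily rescind the dtype = grad dtype invariant inside one Function:
    forward casts to forward_dtype, backward casts the grad to grad_dtype."""

    @staticmethod
    def forward(ctx, x, forward_dtype, grad_dtype):
        ctx.grad_dtype = grad_dtype
        return x.to(forward_dtype)

    @staticmethod
    def backward(ctx, grad_out):
        # The backward cast need not mirror the forward cast (asymmetric).
        return grad_out.to(ctx.grad_dtype), None, None


param = torch.randn(8, requires_grad=True)                     # fp32 master weight
low = AsymmetricCast.apply(param, torch.bfloat16, torch.float32)
low.sum().backward()                                           # bf16 compute
print(low.dtype, param.grad.dtype)  # torch.bfloat16 torch.float32
```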
Main points:
- Adding `forward_dtype` and `backward_dtype` is annoying, as every autograd function has to be modified to support it (so redistribute can do it uniformly)
- Need to understand how similar/different this is to the dtype/out_dtype seen on reduction/matmul.
Pseudocode algorithms for reduction, reduce-scatter, and all-gather:
# reduction
input: in_dtype
acc: acc_dtype = 0
for i in range(N):
    val = load(input, i).to(acc_dtype)
    acc += val
store(out, acc.to(out_dtype))
# reduction backward
grad_out: out_dtype
grad_in = grad_out.expand(in_shape).to(in_dtype)
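
To make the dtype roles above concrete, here is a runnable (deliberately naive, loop-based) eager mirror of the reduction pseudocode; the function name and the explicit `acc_dtype`/`out_dtype` arguments are illustrative, not a PyTorch API:

```python
import torch

def reduce_with_dtypes(x, acc_dtype=torch.float32, out_dtype=None):
    """Load in in_dtype, accumulate in acc_dtype, convert to out_dtype on store."""
    out_dtype = out_dtype or x.dtype          # output dtype defaults to input dtype
    acc = torch.zeros((), dtype=acc_dtype)
    for i in range(x.numel()):
        acc += x[i].to(acc_dtype)
    return acc.to(out_dtype)

x = torch.full((4096,), 0.01, dtype=torch.bfloat16)
print(reduce_with_dtypes(x))                            # fp32 accumulation, bf16 output
print(reduce_with_dtypes(x, acc_dtype=torch.bfloat16))  # bf16 accumulation visibly drifts
print(reduce_with_dtypes(x, out_dtype=torch.float32))   # keep the accumulator's precision
```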
# reduce-scatter
input: in_dtype
buf = input.to(forward_dtype)
for i in range(WORLD_SIZE):
    send_i, rec_i = schedule(i)
    send(buf[send_i])
    # there's some sort of tile level fusion that I'm not expressing correctly here
    val = recv()
    buf[rec_i] += val  # NB: a custom reduce algorithm could choose to change dtype here; e.g., as seen in DSv4
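
For reference, this is the knob FSDP2 exposes for exactly this today: `param_dtype` is the dtype parameters are cast to before the all-gather, and `reduce_dtype` is the dtype the gradient reduce-scatter is done in. A sketch only: it needs an initialized process group (e.g. under torchrun), and on older PyTorch releases these names live under `torch.distributed._composable.fsdp`.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

model = nn.Linear(1024, 1024)

# The collective precision is chosen by casting before the collective:
# all-gather params in bf16, reduce-scatter grads in fp32.
mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
fully_shard(model, mp_policy=mp)
```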
# all-gather (V -> R)
input: in_dtype
out = torch.empty((WORLD_SIZE, *input.shape), dtype=forward_dtype)
out[schedule(0)].copy_(input)
for i in range(WORLD_SIZE - 1):
    send(out[schedule(i)])
    recv(out=out[schedule(i + 1)])
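
The same cast-before-collective pattern written against the real torch.distributed all-gather (a sketch assuming an already-initialized process group; the helper name is mine):

```python
import torch
import torch.distributed as dist

def all_gather_param(shard: torch.Tensor, forward_dtype: torch.dtype) -> torch.Tensor:
    """All-gather a sharded parameter in forward_dtype.

    The cast happens before the collective, so both the bytes on the wire and
    the gathered output are forward_dtype (there is no fused cast + gather)."""
    world_size = dist.get_world_size()
    buf = shard.to(forward_dtype)             # explicit pre-collective cast
    out = torch.empty((world_size, *buf.shape), dtype=forward_dtype, device=buf.device)
    dist.all_gather_into_tensor(out, buf)
    return out
```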