@ezyang
ezyang / prec.md
Last active April 28, 2026 02:54

Ref: https://x.com/ezyang/status/2048485559576789083

I think a useful way to view fine-grained precision APIs is that they expose a little of the underlying memory hierarchy to the user. For single-node ops, the most important distinction is whether a value is materialized in memory or not (an input or output versus an intermediate accumulator). For collectives, the dtype the communication is actually done in is another dimension.

PyTorch generally has these rules:

  • If hardware is involved (e.g., tensor cores), defer to the hardware
  • Always do accumulation in fp32 (this is formalized as acc_dtype, which is not exposed to users but is hard-coded per dtype)
    • We don't always do accumulation in fp32 for matmuls. This is generally controlled by global knobs: torch.backends.cuda.matmul.fp32_precision, torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction, torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction, torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp
  • Output dtype defaults to the input dtype; out_dtype controls what dtype the accumulator is converted to
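A small CPU-side sketch of the acc_dtype and output-dtype rules, assuming a standard PyTorch build. It sums 1000 bf16 ones three ways: the built-in reduction (wide accumulator, bf16 result), `torch.sum` with `dtype=torch.float32`, and a naive loop that rounds to bf16 after every add. The loop is only a stand-in for what a low-precision accumulator would do, not how PyTorch implements anything. Note that `dtype=` on `torch.sum` casts the input before reducing, which is related to, but not the same as, an `out_dtype` that converts the accumulator.

```python
import torch

# 1000 bf16 ones. bf16 has 8 significant bits, so integers above 256 are
# spaced 2 apart and 256 + 1 rounds back to 256 (round half to even).
x = torch.ones(1000, dtype=torch.bfloat16)

# Built-in reduction: accumulates in a wider type internally, then returns
# the result in the input dtype (bf16). 1000 is exactly representable in bf16.
builtin = x.sum()

# `dtype=` upcasts the input to fp32 before reducing; the result stays fp32.
wide = x.sum(dtype=torch.float32)

# Naive simulation of accumulating in bf16 itself: round after every add.
acc = torch.zeros((), dtype=torch.bfloat16)
for v in x:
    acc = acc + v

print(builtin.dtype, builtin.item())
print(wide.dtype, wide.item())
print(acc.dtype, acc.item())
```

The naive loop stalls at 256, which is exactly the failure the hard-coded fp32 accumulator avoids.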
#!/usr/bin/env bash
# Claude Code status line - shows context window remaining percentage
input=$(cat)
remaining=$(echo "$input" | jq -r '.context_window.remaining_percentage // empty')
if [ -n "$remaining" ]; then
  printf 'Context: %s%% remaining' "$remaining"
fi
ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIKeKVvH67TN+aMN0jjau8SCHQo5XcniG73fxKc32aF6I ezyang@ezyang-mac
https://github.com/pytorch/pytorch/issues/163449
https://github.com/pytorch/pytorch/issues/163457
https://github.com/pytorch/pytorch/issues/163420
https://github.com/pytorch/pytorch/issues/163300
https://github.com/pytorch/pytorch/issues/162723
import torch
import unittest
from torch import Tensor
from torch.distributed.tensor import (
    DTensor,
    DeviceMesh,
    distribute_tensor,
    init_device_mesh,
    Partial,
    Replicate,
)
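# Assumes R = Replicate(), S = Shard, a 3-D `mesh` with dim names ("m", "n", "k"),
# and an `arange_nd` helper defined elsewhere in the gist.
x = DTensor.from_local(arange_nd(15), mesh["m", "n", "k"], [R, R, R])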
# Eliminate M
x = DTensor.from_local(x.redistribute(placements=[R, R, S(0)]).to_local(), mesh["m", "n"]) # shard K
x = DTensor.from_local(x.redistribute(placements=[R, S(0)]).to_local(), mesh["m"]) # shard N
x = x.redistribute(placements=[S(0)]).to_local() # shard M
x = DTensor.from_local(x, mesh["n"], [S(0)]).redistribute(placements=[R]) # unshard N
x = DTensor.from_local(x.to_local(), mesh["n", "k"], [R, S(0)]).redistribute(placements=[R, R]) # unshard K
# Eliminate N
x = DTensor.from_local(x.redistribute(placements=[R, S(0)]).to_local(), mesh["n"]) # shard K
x = x.redistribute(placements=[S(0)]).to_local() # shard N
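To see what a chain of "shard K, then shard N, then shard M" does to data placement, here is a single-process simulation that does not use DTensor at all. It assumes `Shard(0)` splits a tensor like `torch.chunk` along dim 0, that all sizes divide evenly, and uses hypothetical mesh sizes M = N = K = 2 with 3 elements per final shard. Under those assumptions, rank (m, n, k) ends up holding flat block `k*N*M + n*M + m`, and gathering back over n alone collects blocks that are M apart rather than one contiguous slice.

```python
import torch

# Hypothetical sizes for mesh dims "m", "n", "k", chosen to divide evenly.
M, N, K = 2, 2, 2
C = 3  # elements per final shard (hypothetical)
x = torch.arange(M * N * K * C)

def local_after_sharding(m, n, k):
    """Local data on rank (m, n, k) after shard K, then shard N, then shard M.

    Models Shard(0) as torch.chunk along dim 0 (an assumption).
    """
    t = x.chunk(K)[k]  # shard K: outermost split
    t = t.chunk(N)[n]  # shard N: splits each K-shard
    t = t.chunk(M)[m]  # shard M: innermost split
    return t

# The nested chunking is k-major: rank (m, n, k) holds flat block k*N*M + n*M + m.
for m in range(M):
    for n in range(N):
        for k in range(K):
            block = k * N * M + n * M + m
            assert torch.equal(local_after_sharding(m, n, k), x[block * C:(block + 1) * C])

def gathered_over_n(m, k):
    """'Unshard N': concatenate local shards across n, holding m and k fixed."""
    return torch.cat([local_after_sharding(m, n, k) for n in range(N)])

print(gathered_over_n(0, 0))
```

With M > 1 the gathered result skips over the blocks owned by other m coordinates, so it is not a contiguous slice of `x`.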
ezyang / gist:15791ae363900f42c704c09ca34346e3
Created October 29, 2025 19:02
Matrix-of-matrices tensor render
def render(tensor, cell_width=None):
    """
    Print a tensor following the matrix-of-matrices algorithm.

    Args:
        tensor: A tensor-like object with .shape attribute and indexing
        cell_width: Width for each cell (calculated globally if None)

    Returns:
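The gist body is truncated here, so the following is not the author's implementation. It is a hypothetical sketch of one way a matrix-of-matrices layout could work for the 4-D case: shape (A, B, C, D) is drawn as an A x B grid of C x D matrices, using a single global cell width as the docstring describes for `cell_width=None`. The name `render_4d`, the `" | "` column separator, and the blank line between grid rows are all made up for illustration.

```python
from typing import Optional

import torch

def render_4d(t: torch.Tensor, cell_width: Optional[int] = None) -> str:
    """Hypothetical sketch: lay out a 4-D tensor of shape (A, B, C, D) as an
    A x B grid whose cells are C x D matrices, returned as a string."""
    assert t.dim() == 4, "this sketch only handles 4-D tensors"
    A, B, C, D = t.shape
    if cell_width is None:
        # One global width so every column of every inner matrix lines up.
        cell_width = max(len(str(v)) for v in t.flatten().tolist())
    lines = []
    for a in range(A):
        for c in range(C):
            # Row c of each inner matrix in grid row a, side by side.
            parts = []
            for b in range(B):
                row = t[a, b, c].tolist()
                parts.append(" ".join(str(v).rjust(cell_width) for v in row))
            lines.append(" | ".join(parts))
        if a != A - 1:
            lines.append("")  # blank line between grid rows
    return "\n".join(lines)

print(render_4d(torch.arange(16).reshape(2, 2, 2, 2)))
```

Computing the width once over the whole tensor, rather than per cell, keeps columns aligned across every inner matrix at the cost of padding small values.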
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard
from torch.testing._internal.distributed.fake_pg import FakeStore
world_size = 4
(verl) [ezyang@devgpu086.cco2 ~/local/verl/verl/examples/ppo_trainer (main)]$ pp bash run_deepseek7b_llm.sh
+ python3 -m verl.trainer.main_ppo algorithm.adv_estimator=gae data.train_files=/home/ezyang/local/data/gsm8k/train.parquet data.val_files=/home/ezyang/local/data/gsm8k/test.parquet data.train_batch_size=1024 data.max_prompt_length=512 data.max_response_length=512 data.filter_overlong_prompts=True data.truncation=error actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat actor_rollout_ref.actor.optim.lr=1e-6 actor_rollout_ref.model.use_remove_padding=True actor_rollout_ref.actor.ppo_mini_batch_size=256 actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 actor_rollout_ref.actor.fsdp_config.param_offload=False actor_rollout_ref.actor.fsdp_config.optimizer_offload=False actor_rollout_ref.actor.use_kl_loss=False actor_rollout_ref.model.enable_gradient_checkpointing=True actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 actor_rollout_ref.rollout.tensor_model_parallel_size=4 actor_ro

I really like Scuba (Meta's internal real-time database system). The distributed, real-time database part of Scuba is quite difficult (and expensive) to replicate, but I also really like Scuba's UI for doing queries, and I have found myself wishing I had access to it even for "small" databases, e.g., when I have a SQLite dataset I want to explore.

Here's a screenshot of this UI from https://research.facebook.com/publications/scuba-diving-into-data-at-facebook/:

[Screenshot: Scuba query UI]

Pivotal ideas:

  • Time series by default. In the dedicated "time series" view, many features are specifically oriented toward tables that represent events that occurred over time: the start, end, compare, aggregate and granularity fields all specially privilege the timestamp field. In fact, you can't log events to Scuba's backing data store without a timestamp; every event comes with one. (Scuba a