Skip to content

Instantly share code, notes, and snippets.

@andreaskoepf
Created April 20, 2026 13:26
Show Gist options
  • Select an option

  • Save andreaskoepf/814c78c5e7c99553bc7dcf70a047970c to your computer and use it in GitHub Desktop.

Select an option

Save andreaskoepf/814c78c5e7c99553bc7dcf70a047970c to your computer and use it in GitHub Desktop.

FastWAM++ Inference Optimization on NVIDIA Jetson Thor

A 2.14× speed-up, no architecture changes, auto-researched in ~11 hours wall-clock (~4 h hot phase)

Platform: Jetson Thor (Blackwell sm_110, 20 SMs, 32 MB L2, CUDA 13, torch 2.10+cu130, bf16 weights) Workload: FastWAM — Wan2.2-5B backbone (30 blocks) + action expert (30 blocks × 10-step Euler diffusion) + Wan2.2 VAE encode. Action chunk out [1, 32, 14], batch 1, single-robot inference. Budget: ~11 hours of auto-research wall-clock (2026-04-19, 11:23 → 22:27 UTC; the optimization-heavy phase from first shipped patch to current_best took ~4 hours), 130 benchmark runs, 24 shipped optimizations, 25 discarded hypotheses. Full log in experiments.jsonl / IDEAS.md. Correctness gate: max_abs_diff < 0.05 for exact methods, MSE < 5e-3 for lossy. Enforced automatically in benchmark.py; every KEEP has hard numbers attached.


TL;DR

stage mean_ms cumulative Δ comment
eager bf16, no graphs 691.5 starting point
+ 3 CUDA graphs + dtype_fix (pre-existing) 359.5 −332.0 ms S-001..S-008, shipped in the fastwam repo
+ FP8 E4M3 on all backbone + AE FFN linears 347.3 −12.2 ms S-009 / S-010
+ pre-alloc KV + cuDNN SDPA + bf16 modulation stack 311.6 −35.7 ms S-011..S-013 (three "fails alone""wins stacked")
+ FP8 fixed amax=64, no clamp, bf16 pre-quant 254.8 −56.8 ms S-017..S-019 — the biggest single lever
+ Liger RMSNorm (Triton, unblocked on sm_110) 225.8 −29.0 ms S-020
+ torch.compile on AE pre/post_mot → QKV → full block → backbone 168.0 −57.8 ms S-021..S-024 — the other big lever
current best 168.0 ms (p99 168.9) −191.5 ms vs. repo baseline patches/current_best.py

End-to-end on-robot latency: 580 ms/trigger → ~280 ms p50, sustained at 60 Hz dispatch on the DK-1 with prefix-conditioned RTC, 500/500 cycles, correct task execution (pick-and-place duplo).

journey

all runs


1. Where the time was actually going

Baseline profile (outputs/profile_top_kernels.md, 5 iterations of _timed_inference with all CUDA graphs on):

rank kernel share what it is
1 CatArrayBatchedCopy ~19 % torch.cat([cached_kv, new]) — 600 calls/inference in the AE self-attn KV update
2 nvjet_sm110 GEMM ~18 % bf16 cuBLAS linear projections (Q/K/V/O + FFN) across backbone + AE
3 pytorch_flash::flash_fwd_kernel ~16 % SDPA backend for action-expert self-attn
4 fp32 elementwise / upcasts ~10 % x.float() * (1+scale) + shift inside pre/post_mot modulation
5 fmha_cutlassF_bf16_aligned_64x128_rf_sm80 ~5 % mem-efficient attention — cross-attn
6 RMSNorm tail (pow, mean, layer_norm) ~4 % our custom 5-kernel RMSNorm

Top 3 buckets = ~53 % of GPU time. Everything below went through the auto-research loop against these three names.


2. The optimization stack (what ended up in current_best)

2.1 FP8 — the single biggest lever (−56.8 ms alone, ~30 % of the savings)

  • Blackwell has native FP8 E4M3 tensor cores. torch._scaled_mm(x_fp8, w_fp8, out_dtype=bf16) works on sm_110 at our shapes.
  • First attempt failed. Dynamic per-call amax quantization on AE linears at M=44 cost more than the _scaled_mm saved (F-006, F-007 — +0.65 % to +13 ms slower).
  • Fix: apply FP8 to the large-M backbone first (S-009, S-010). Backbone has M=960, where _scaled_mm amortizes the amax reduction.
  • Real win (S-017): replace dynamic amax with a fixed constant amax=64 (scale = 64/448 ≈ 0.143) on every linear including AE Q/K/V/O. Gives 28× headroom over typical post-LayerNorm bf16 ranges; MSE stays ≤ 1.3 e-4, far under the 5 e-3 gate. Eliminates hundreds of reduction kernels per inference. −15.7 ms.
  • Polish (S-018): at amax=64 the .clamp(±448) op is a no-op; deleting it saves ~600 kernel launches. −8.2 ms paired.
  • Polish (S-019): the (x.float() / scale).to(fp8) pre-quant upcast was dead cycles — bf16's 7-bit mantissa is more than enough precision for a value about to land in FP8 E4M3 (3-bit mantissa). Pre-compute inv_scale_bf16, replace 3 kernels with 2. −16.8 ms paired.

2.2 torch.compile — almost tied for biggest lever (−57.8 ms, ~30 %)

  • Initially marked "broken on Thor" (F-002): Triton emitted sm_110a assembly which ptxas-blackwell bundled in Triton's wheel rejected.
  • Unblocked 2026-04-19: pointing Triton at the system CUDA 13 ptxas with TRITON_PTXAS_BLACKWELL_PATH=/usr/local/cuda-13.0/bin/ptxas works. This single-line env fix unlocked four separate optimizations that had been shelved for months.
  • Applied in four layers (S-021 → S-024), each compiled separately so inductor can see across module boundaries:
    1. AE pre_mot / post_mot — fuses 6 pointwise ops (norm + modulate + residual + cross-attn dispatch + gate) into 1–2 Triton kernels. −18.3 ms.
    2. AE compute_mot_qkv — fuses FP8 _scaled_mm + Liger RMSNorm + rearrange. −3.8 ms.
    3. Full AE block forward — lets inductor merge across module boundaries including SDPA. −16.3 ms.
    4. Backbone block forward — 30 blocks × one compiled forward each. −19.7 ms.
  • Must stay on default mode: mode="reduce-overhead" installs its own cudagraph manager and conflicts with our outer manual torch.cuda.CUDAGraph (F-025).

2.3 Liger RMSNorm — unblocked by the same Triton fix (−29.0 ms)

  • Our custom RMSNorm was 5 separate kernels (pow → mean → rsqrt → mul → cast). Liger fuses into 1 Triton kernel.
  • Microbench at our shapes: 10× faster at [1, 960, 3072] (backbone), ~3× faster at [1, 44, 3072] (AE).
  • End-to-end: −29.0 ms paired → 225.8 ms.

2.4 KV-cache + attention backend — the "wins only when stacked" category

Three hypotheses that were discarded in isolation but became wins once FP8 changed the workload composition:

id isolated stacked on FP8 mechanism
S-011 pre-alloc KV +7 ms slower −7 ms with FP8 GEMMs shrunk, the torch.cat relative cost rose; .copy_() into a persistent buffer now wins
S-012 cuDNN SDPA pin +7 to +21 ms slower −17 ms attention's share of wall time rises after FP8; cuDNN's faster kernel now beats dispatch overhead
S-013 bf16-native modulation +12 ms slower −5 ms fewer fp32 upcast kernels pays off only once the surrounding GEMMs are the same cost order

Lesson learned: run paired benches against the current candidate layer, not against the raw baseline. Isolated DISCARDs can flip to KEEPs when the workload composition changes. This was coded into the experimental protocol (see IDEAS.md "thermal-drift finding" section).

2.5 Smaller wins

  • HF kernels-community rotary (S-014, −5.4 ms): lean variant that skips .contiguous() and uses strided output slices.
  • FP8 on AE cross_Q / cross_O (S-015, −4.3 ms): extends FP8 coverage to another 600 calls/inference.
  • Prebaked SDPA scale (S-016, −26.3 ms — the brainstorm underestimated this by 10×): pre-multiply Q by 1/sqrt(head_dim) before SDPA and pass scale=1.0. Saves SDPA's internal elementwise on every call.

3. Things that didn't work (the 25 discarded hypotheses)

The full list is in IDEAS.md under "ALREADY TRIED — FAILED". Highlights:

id what why it failed
F-005 Fused QKV on action expert (all three projections in one GEMM) +24 ms slower — at M=44, cuBLAS picks a worse tile for [44 × 9216] than three sequential [44 × 3072] calls on Thor's 20 SMs
F-010 HF Hub kernels-community/flash-attn wheel loads, but cudaErrorNoKernelImageForDevice — wheel ships SASS for sm_80/90/100/120, no sm_110, no PTX
F-015 flash-attn 2.8.4 source-built for sm_110 compiles, imports, runs — 1.2× to 2.9× slower than cuDNN-SDPA. fa2 kernels are sm_80-era, compiled for sm_110 but not tile-tuned for it
F-023 FA4 CuTeDSL on Thor per-shape microbench: +117 % slower than cuDNN on the big backbone shape; FP8-attention is upstream-gated out for sm_11x (1 LOC away). Report at outputs/fa4_thor_assessment.md
F-012 torch._scaled_mm mixed bf16-act × fp8-weight RuntimeError: Invalid scaling configuration. Weight-only FP8 is not supported on Thor yet
F-021 fold 1/sqrt(d) into norm_q weights to eliminate a 660-site post-RoPE multiply +0.9 ms noise — kernels are free at CUDA-graph replay
F-022 bf16 backbone modulation (retry under current_best) 0 ms — the fp32 elementwise casts in the backbone's forward are also kernel-free at replay

Pattern: "kernel counts" and "microbench speed-ups" routinely failed to survive end-to-end benchmark, because inside a captured CUDA graph the per-launch costs vanish and the only thing that matters is the aggregate compute. The only reliable verdict is the full _timed_inference n=50 number.


4. The thermal-drift finding (why paired benchmarking matters)

Early in the run, Jetson Thor's baseline drifted from 359.5 ms (cold, clean) to ~371–376 ms after ~2 hours of bench load. Five DISCARD verdicts recorded against the stale 359 ms were re-run paired — four of them flipped to KEEP:

hypothesis unpaired vs stale 359.5 ms paired delta new verdict
H-002 FP8 AE FFN only +0.65 % "slower" −8.5 ms re-opened → S-009
H-001bb FP8 backbone FFN +19 ms "slower" −9.1 ms re-opened → S-009
H-014 pre-alloc KV +7 ms "slower" −7 ms stacked → S-011
H-016 cuDNN SDPA +7-21 ms "slower" −17 ms stacked → S-012

Protocol going forward: fresh baseline run immediately before each candidate (within seconds). The pair_* entries in experiments.jsonl are all A/B pairs produced under that rule.


5. Correctness

Two gates are run on every patch:

  1. Random-input gate (outputs/baseline_chunk.pt): max_abs_diff and mse against the saved action-chunk tensor. Backward-compat.
  2. Real-input gate (outputs/reference_outputs_16k.pt, 5 real samples × 2 modes): verifies the patched path against an independent eager-mode reference built from dataset frames. Also pins prefix_pin_max_abs_diff for the RTC prefix-conditioned path.

current_best against the real-robot-checkpoint (step_16000):

  • max_abs_diff ≤ 0.0215 (gate 0.05 — 2.3× under)
  • mse ≤ 2.28 e-5 (gate 5 e-3 — 220× under)
  • prefix_pin_max_abs_diff = 0.00000 on all 5 prefix rows (bit-exact)
  • nan=0, inf=0 → VERDICT: PASS

Cross-finetune: FP8 fixed-amax=64 transferred cleanly from the step_20000 benchmark checkpoint to step_16000 (robot) and to the H=50 duplo-finetune-h50/step_8000 run — as hypothesized, since the constant is calibrated against typical post-LayerNorm bf16 ranges, not a specific checkpoint.


6. On-robot outcome

Running eval/real_robot_inference_rtc.py --optimizations current_best on the DK-1 arm with the H=50 finetune:

dispatch rate cycles completed p50 end-to-end splice failures
30 Hz 500/500 280 ms 0
40 Hz 500/500 327 ms 2/500
50 Hz (12 steps) 500/500 297 ms 4/500
60 Hz 500/500 291 ms 10/500 (none consecutive)

prefix_diff_max = 0.0 e+00 on all 2000 completions. Pre-optimization on-robot baseline was ~580 ms/trigger at 20 Hz. Task execution (pick-and-place duplo) is unchanged.


7. What's left on the table

See outputs/future_directions.md. Short version:

  • FA4 / SageAttention for sm_110: waiting on upstream Blackwell-small tuning. Both compiled cleanly; neither is faster than cuDNN-SDPA on Thor's 20 SMs at our shapes.
  • VAE quantization: VAE is 161 ms on-robot, conv3d-dominated, still untouched.
  • Cross-stream overlap (H-013): VAE + backbone + diffusion run serially; overlap could hide ~150 ms on back-to-back triggers. Deeper change in eval/rtc/.
  • torchao INT8 W8A8: available, not yet benchmarked.

8. Reproducing

cd /home/koepf/rclaw/fastwam-inference-opt
source /home/koepf/robotics/trlc-dk1/.venv/bin/activate

# baseline
python benchmark.py \
    --checkpoint /home/koepf/robotics/fastwam/checkpoints/dk1-fastwam-duplo-finetune/step_20000 \
    --experiment-id repo_baseline --n-runs 50

# current_best (all 24 shipped optimizations)
python benchmark.py \
    --checkpoint /home/koepf/robotics/fastwam/checkpoints/dk1-fastwam-duplo-finetune/step_20000 \
    --experiment-id current_best --n-runs 50 \
    --apply-patch patches/current_best.py

Requires TRITON_PTXAS_BLACKWELL_PATH=/usr/local/cuda-13.0/bin/ptxas for the Liger + torch.compile patches.

Plot regeneration: python scripts/make_report_plot.py.


This report was auto-generated from experiments.jsonl (130 runs) and IDEAS.md (24 SHIPPED, 25 TRIED-FAILED). Every number above is backed by an entry in one of those two files.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment