Skip to content

Instantly share code, notes, and snippets.

@rrbutani
Last active January 15, 2026 04:04
Show Gist options
  • Select an option

  • Save rrbutani/b6582736cb0341b503eb4a232906d020 to your computer and use it in GitHub Desktop.

Select an option

Save rrbutani/b6582736cb0341b503eb4a232906d020 to your computer and use it in GitHub Desktop.
❯ pytest --chrome-trace parallel-baseline.json -n auto --dist=worksteal
664 passed, 24 warnings in 9.08s

Chrome trace of the pytest run using xdist's work-stealing strategy (`--dist=worksteal`).

❯ pytest --chrome-trace parallel-reordered.json -n auto --dist=worksteal --expensive-tests-first
664 passed, 24 warnings in 7.33s

Chrome trace of the same run with `--expensive-tests-first`.

Caution

Watch out for `--dist=worksteal`'s minimum of two jobs in every worker's queue: if you get unlucky, this can lead to long tails.

use flake
/.direnv
__pycache__
/*.json
import time
import pytest
# Scale factor for how many parametrized cases each test below generates
# (`test_slow`: N, `test_fast`: 40 * N, `TestsInAScope.test_med`: 10 * N).
N = 13
# Argument values are materialized as a list rather than a generator: a
# generator stored in the mark is consumed by the first collection of this
# function, so any later collection would see an empty parameter set.
@pytest.mark.expensive(seconds=1)
@pytest.mark.parametrize("sleep", [n / 10 for n in range(10, 10 + N)])
def test_slow(sleep: float):
    """Sleep for ``n / 10`` seconds, for each ``n`` in ``[10, 10 + N)``."""
    # NOTE(review): the marker declares 1 s, but the sleeps range from 1.0 s up
    # to (9 + N) / 10 s; the declared cost only drives scheduling order, so
    # confirm the under-estimate is intentional.
    time.sleep(sleep)
# A list (not a one-shot generator) so re-collecting this function still sees
# every parameter.
@pytest.mark.parametrize("sleep", [n / 10000 for n in range(10, 10 + N * 40)])
def test_fast(sleep: float):
    """Sleep for ``n / 10000`` seconds, for each ``n`` in ``[10, 10 + 40 * N)``.

    These cases carry no ``expensive`` marker.
    """
    time.sleep(sleep)
class TestsInAScope:
    """Marked tests collected inside a class (node ids ``TestsInAScope::...``)."""

    @pytest.mark.expensive(seconds=3)
    def test_quite_slow(self):
        """Sleep for 2 seconds."""
        # NOTE(review): the marker declares 3 s while the test sleeps 2 s;
        # confirm the over-estimate is intentional.
        time.sleep(2)

    # A list (not a one-shot generator): a subclass inheriting this method
    # would otherwise re-collect it against an already-exhausted generator.
    @pytest.mark.expensive(seconds=0.1)
    @pytest.mark.parametrize("sleep", [n / 1000 for n in range(10, 10 + N * 10)])
    def test_med(self, sleep: float):
        """Sleep for ``n / 1000`` seconds, for each ``n`` in ``[10, 10 + 10 * N)``."""
        time.sleep(sleep)
from collections.abc import Generator, Iterable
import itertools
import json
import os
from typing import Any
import pytest
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register this plugin's command-line options.

    - ``--chrome-trace PATH``: write a Chrome trace of test durations to PATH.
    - ``--expensive-tests-first`` (dest ``expensive_first``, default False):
      schedule tests marked ``@pytest.mark.expensive`` ahead of the rest.
    """
    parser.addoption(
        "--chrome-trace",
        action="store",
        default=None,
        help="Path to write a Chrome trace containing pytest test durations",
    )
    parser.addoption(
        "--expensive-tests-first",
        action="store_true",
        dest="expensive_first",
        default=False,
        help=(
            "schedule tests annotated with `@pytest.mark.expensive` first; "
            "useful for reducing aggregate test runtime when using `xdist`"
        ),
    )


def pytest_configure(config: pytest.Config) -> None:
    """Register the ``expensive`` marker used by ``--expensive-tests-first``.

    Without registration, every use of ``@pytest.mark.expensive`` emits a
    ``PytestUnknownMarkWarning`` (once per xdist worker), and it is a hard
    error under ``--strict-markers``.
    """
    config.addinivalue_line(
        "markers",
        "expensive(hours=0, minutes=0, seconds=0, milliseconds=0): "
        "estimated runtime of a slow test; with --expensive-tests-first, "
        "marked tests are scheduled first, highest estimate first",
    )
def _build_trace_events(stats: dict[str, list[Any]]) -> list[dict[str, Any]]:
events: list[dict[str, Any]] = []
worker_ids = itertools.count()
workers = dict[str, int]()
def worker_id(key: None | str) -> int:
if key is None:
key = "<main thread>"
key = str(key)
if key not in workers:
workers[key] = next(worker_ids)
return workers[key]
for bucket, reports in stats.items():
for report in reports:
start = getattr(report, "start", None)
stop = getattr(report, "stop", None)
if start is None or stop is None:
continue
ts_us = int(start * 1_000_000)
dur_us = int(max(stop - start, 0) * 1_000_000)
kind = getattr(report, "when", "<unknown>")
events.append(
{
"name": getattr(report, "nodeid", bucket),
"cat": kind,
"ph": "X",
"ts": ts_us,
"dur": dur_us,
"pid": 0,
"tid": worker_id(getattr(report, "node", None)),
"args": dict(
keywords=getattr(report, "keywords", {}),
location=getattr(report, "location", {}),
),
}
)
# name events for worker threads:
for worker_name, w_id in workers.items():
events.append(
{
"name": "thread_name",
"ph": "M",
"pid": 0,
"tid": w_id,
"args": dict(name=worker_name),
}
)
return events
def pytest_terminal_summary(
    terminalreporter: pytest.TerminalReporter,
    exitstatus: pytest.ExitCode,
    config: pytest.Config,
) -> None:
    """Dump a Chrome trace of the session when ``--chrome-trace`` is given.

    The trace, ``{"traceEvents": [...], "displayTimeUnit": "ms"}``, is built
    from ``terminalreporter.stats``. It is written to the requested path as
    UTF-8 JSON with sorted keys, and the path's parent directory is created if
    needed. ``exitstatus`` is part of the hook signature and is not used.
    """
    out_path = config.getoption("--chrome-trace")
    if out_path:
        payload: dict[str, Any] = {
            "traceEvents": _build_trace_events(terminalreporter.stats),
            "displayTimeUnit": "ms",
        }
        # a bare file name has an empty dirname; fall back to the cwd
        parent_dir = os.path.dirname(out_path) or "."
        os.makedirs(parent_dir, exist_ok=True)
        with open(out_path, "w", encoding="utf-8") as out:
            json.dump(payload, out, sort_keys=True)
# https://github.com/pytest-dev/pytest-xdist/blob/9329e6d2144fc92b91ead2680ba2241fd1171cc9/src/xdist/plugin.py#L299-L301
# see `xdist` for the scheme used to partition tests across workers (when using
# the workstealing strategy):
@pytest.hookimpl(wrapper=True)
def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]):
try:
return (yield)
finally:
# after all other `modifyitems` hooks have run:
# if we were asked to execute tests marked as `expensive` first, reorder
# `items`:
if config.getoption("expensive_first"):
expensive, normal = [], []
for item in items:
# TODO: ideally we'd sort by expected runtime (schedule longest running
# first) but this is good enough for now
if (mark := next(item.iter_markers("expensive"), None)) is not None:
cost = (
(60 * 60) * float(mark.kwargs.get("hours", 0))
+ 60 * float(mark.kwargs.get("minutes", 0))
+ 1 * float(mark.kwargs.get("seconds", 0))
+ (1 / 1000) * float(mark.kwargs.get("milliseconds", 0))
)
expensive.append((item, cost))
else:
normal.append(item)
# sort expensive items by time to run (highest first):
expensive.sort(key=lambda tup: tup[1], reverse=True)
# hoist the expensive tests to the front:
items[:] = [item for item, _ in expensive] + normal
# if running under `xdist`, we want to stripe the tests the workers such
# that — across the workers — the expensive tests are run first
#
# to do so we need to rely on internal details about how `pytest-xdist`
# partitions tests across workers (when using the workstealing strategy)
#
# see:
# - https://github.com/pytest-dev/pytest-xdist/blob/9329e6d2144fc92b91ead2680ba2241fd1171cc9/src/xdist/remote.py#L418
if worker_info := getattr(config, "workerinput", None):
num_workers = worker_info["workercount"]
# the scheme used divides up the tests roughly equally into buckets
# using the worker index number as a way to handle uneven splits:
# - https://github.com/pytest-dev/pytest-xdist/blob/9329e6d2144fc92b91ead2680ba2241fd1171cc9/src/xdist/scheduler/worksteal.py#L212-L216
# - https://github.com/pytest-dev/pytest-xdist/blob/9329e6d2144fc92b91ead2680ba2241fd1171cc9/src/xdist/scheduler/worksteal.py#L312-L317
items_remaining, worker_bucket_sizes = len(items), []
for i in range(num_workers):
workers_remaining = num_workers - i
bucket_size = items_remaining // workers_remaining
items_remaining -= bucket_size
worker_bucket_sizes.append(bucket_size)
assert items_remaining == 0
# we should evenly spread our tests (ordered so that more expensive ones
# are first) across these buckets:
def interperse[T](iter: list[Iterable[T]]) -> Generator[T]:
seen_any = True
while seen_any:
seen_any = False
for it in iter:
try:
item = next(it)
seen_any = True
yield item
except StopIteration:
continue
worker_buckets = [[] for _ in worker_bucket_sizes]
worker_indices = interperse(
[
itertools.repeat(w_idx, size)
for w_idx, size in enumerate(worker_bucket_sizes)
]
)
for worker_idx, item in zip(worker_indices, items, strict=True):
worker_buckets[worker_idx].append(item)
# concatenate buckets to produce the overall items list:
items[:] = [item for bucket in worker_buckets for item in bucket]
{
"nodes": {
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1768364046,
"narHash": "sha256-PDFfpswLiuG/DcadTBb7dEfO3jX1fcGlCD4ZKSkC0M8=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "ea30586ee015f37f38783006a9bc9e4aa64d7d61",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixpkgs-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"flake-utils": "flake-utils",
"nixpkgs": "nixpkgs"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
}
},
"root": "root",
"version": 7
}
{
  inputs = {
    nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable";
    flake-utils.url = "github:numtide/flake-utils";
  };

  # One dev shell per default system, providing Python 3.13 (with pytest and
  # pytest-xdist) plus ruff.
  outputs = { nixpkgs, flake-utils, ... }:
    flake-utils.lib.eachDefaultSystem (system:
      let
        pkgs = nixpkgs.legacyPackages.${system};
        python = pkgs.python313.withPackages (ps: [ ps.pytest ps.pytest-xdist ]);
      in
      {
        devShells.default = pkgs.mkShell {
          packages = [ python pkgs.ruff ];
        };
      });
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment