Skip to content

Instantly share code, notes, and snippets.

@9999years
Last active June 4, 2025 18:53
Show Gist options
  • Save 9999years/0030984947b1a1b8a684292966961412 to your computer and use it in GitHub Desktop.
Save 9999years/0030984947b1a1b8a684292966961412 to your computer and use it in GitHub Desktop.
Track completion of a set number of tasks in Python with time estimates.
"""
When writing Python scripts, you often want to track how long a block of code
takes. The `timer` context manager provides a simple way to do this, with
automatic logging.
The `progress` context manager extends this functionality to track completion
of a set integer number of tasks which take roughly equal time to complete.
Additionally, tasks can be skipped without skewing the average time per task.
You tell it when you start and when you complete a task, and it will give you
useful progress information including an ETA.
"""
import logging
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Iterator
# Module-level logger; `timer` reports its start/finish messages through it
# at DEBUG level.
logger = logging.getLogger(__name__)
@dataclass
class Timer:
    """Stopwatch anchored at a fixed starting instant."""

    start: datetime = field(default_factory=datetime.now)
    """Instant the timer began; defaults to the moment of construction."""

    def elapsed(self) -> timedelta:
        """Return how much time has passed since `start`."""
        now = datetime.now()
        return now - self.start
@contextmanager
def timer(description: str | None = None) -> Iterator[Timer]:
    """Time a block of code, yielding the running `Timer`.

    When `description` is given, a "Starting ..." message is logged at DEBUG
    level before the block runs and a "Finished ... in ..." message with the
    formatted elapsed time is logged when the block exits, including when it
    exits by raising. Without a `description` nothing is logged.
    """
    should_log = description is not None
    if should_log:
        logger.debug(f"Starting {description}")

    stopwatch = Timer()
    try:
        yield stopwatch
    finally:
        # Runs on normal exit and on exceptions alike.
        if should_log:
            took = format_timedelta(stopwatch.elapsed())
            logger.debug(f"Finished {description} in {took}")
@dataclass
class Progress:
    """Track progress of a set number of tasks.

    This class estimates time remaining based on the total elapsed time and
    provides pretty progress output. Call `increment()` when a task is
    completed and `skip()` when a task is skipped; both return the same
    human-readable summary that `str()` produces.
    """

    total: int
    """Total number of items to process."""

    count: int = field(default=0)
    """Number of items processed and not skipped.

    This is increased when you call `increment()`.
    """

    skipped_count: int = field(default=0)
    """Count of items skipped.

    This is increased when you call `skip()`.
    This is stored separately from `count` because skipped items tend to be
    processed much more quickly, throwing off averages.
    """

    skipped_time: timedelta = field(default=timedelta(0))
    """Total time spent processing skipped items.

    This is increased when you call `skip()`.
    """

    last_checkin: datetime = field(default_factory=datetime.now)
    """Last checkin time.

    This is updated when you call `checkin()`, `increment()`, or `skip()`.
    This is used to determine the total `skipped_time`.
    """

    timer: Timer = field(default_factory=Timer)
    """A timer storing elapsed time since the start of the progress."""

    def checkin(self) -> timedelta:
        """Update the checkin time, returning the elapsed time since the last checkin."""
        now = datetime.now()
        elapsed = now - self.last_checkin
        self.last_checkin = now
        return elapsed

    def increment(self) -> str:
        """Record one completed item and describe the progress."""
        self.count += 1
        self.checkin()
        return str(self)

    def skip(self) -> str:
        """Record one skipped item and describe the progress.

        The time since the last checkin is attributed to skipped items so it
        does not inflate the per-item average of real work.
        """
        self.skipped_count += 1
        self.skipped_time += self.checkin()
        return str(self)

    def elapsed_per_count(self) -> timedelta:
        """Determine the average elapsed time per non-skipped item.

        Returns a zero duration while no item has been completed.
        """
        if self.count == 0:
            return timedelta(0)
        return (self.timer.elapsed() - self.skipped_time) / self.count

    def elapsed_per_skip(self) -> timedelta:
        """Determine the average elapsed time per skipped item.

        Returns a zero duration while no item has been skipped.
        """
        if self.skipped_count == 0:
            return timedelta(0)
        return self.skipped_time / self.skipped_count

    def elapsed_per_processed(self) -> timedelta:
        """Determine the average elapsed time per processed item, weighted
        across skipped and non-skipped items.
        """
        skipped_ratio = self.skipped_ratio()
        return (
            self.elapsed_per_count() * (1 - skipped_ratio)
            + self.elapsed_per_skip() * skipped_ratio
        )

    def skipped_ratio(self) -> float:
        """Determine the portion of processed items that were skipped.

        Returns 0.0 while nothing has been processed.
        """
        processed = self.processed_count()
        if processed == 0:
            return 0.0
        return self.skipped_count / processed

    def estimate_remaining(self) -> timedelta:
        """Estimate the remaining time based on the average elapsed time per processed item."""
        return self.elapsed_per_processed() * (self.total - self.processed_count())

    def eta(self) -> datetime:
        """Estimate the completion time based on the current time and the estimated remaining time."""
        return datetime.now() + self.estimate_remaining()

    def processed_count(self) -> int:
        """Get the total number of processed items, including skipped items."""
        return self.count + self.skipped_count

    def __str__(self) -> str:
        """Describe the progress: counts, percentage, per-item time, time left and ETA."""
        processed = self.processed_count()
        # An empty workload (total == 0) has nothing left to do; report it as
        # complete instead of dividing by zero.
        fraction = processed / self.total if self.total else 1.0
        # Compute the average once and derive both "left" and "eta" from it so
        # the two figures in the summary always agree with each other.
        per_item = self.elapsed_per_processed()
        remaining = per_item * (self.total - processed)
        # %I (12-hour clock) pairs with the %p AM/PM marker.
        eta = (datetime.now() + remaining).strftime("%Y-%m-%d %I:%M %p")
        return (
            f"{processed}/{self.total} ({fraction:.2%}), "
            f"≈{format_timedelta(per_item)} each\n"
            f"{format_timedelta(remaining)} left -> eta {eta}"
        )
@contextmanager
def progress(total: int, description: str | None = None) -> Iterator[Progress]:
    """Track progress through `total` tasks for the duration of a block.

    The yielded `Progress` shares its timer with an enclosing `timer`
    context, and its first checkin is measured from that timer's start.
    If a `description` is provided, the elapsed time will be logged when the
    context manager is closed.
    """
    with timer(description) as stopwatch:
        yield Progress(
            total=total,
            last_checkin=stopwatch.start,
            timer=stopwatch,
        )
def format_timedelta(delta: timedelta) -> str:
    """Pretty-format a `timedelta` object.

    Durations of one second or more are rounded to the nearest hundredth of a
    second and always shown with two decimals (e.g. `0:05:46.30`,
    `0:00:05.00`). Durations under one second are formatted in milliseconds
    (e.g. `16.67ms`). Negative durations are formatted as their absolute
    value with a leading `-`.
    """
    if delta < timedelta(0):
        return "-" + format_timedelta(-delta)
    if delta < timedelta(seconds=1):
        return format(delta.microseconds / 1_000, ".2f") + "ms"

    rounded = timedelta(
        days=delta.days,
        seconds=delta.seconds,
        # Round to hundredths (10ms = 0.01s). A result of 1_000_000 is carried
        # into the seconds field by `timedelta` normalization.
        microseconds=round(delta.microseconds, -4),
    )
    text = str(rounded)
    if rounded.microseconds == 0:
        # `str(timedelta)` omits the fraction entirely for whole seconds; add
        # it back so the output always has two decimals.
        return text + ".00"
    # Otherwise the fraction is six digits ending in "0000"; keep the first two.
    return text.removesuffix("0000")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment