Last active
June 4, 2025 18:53
-
-
Save 9999years/0030984947b1a1b8a684292966961412 to your computer and use it in GitHub Desktop.
Track completion of a set number of tasks in Python with time estimates.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
When writing Python scripts, you often want to track how long a block of code | |
takes. The `timer` context manager provides a simple way to do this, with | |
automatic logging. | |
The `progress` context manager extends this functionality to track completion | |
of a set integer number of tasks which take roughly equal time to complete. | |
Additionally, tasks can be skipped without skewing the average time per task. | |
You tell it when you start and when you complete a task, and it will give you | |
useful progress information including an ETA. | |
""" | |
import logging | |
from contextlib import contextmanager | |
from dataclasses import dataclass, field | |
from datetime import datetime, timedelta | |
from typing import Iterator | |
# Module-level logger named after this module. `timer` (and therefore
# `progress`, which wraps `timer`) emits its "Starting ..." / "Finished ..."
# debug messages through it.
logger = logging.getLogger(__name__)
@dataclass
class Timer:
    """Wall-clock stopwatch anchored at a fixed starting instant."""

    start: datetime = field(default_factory=datetime.now)
    """Instant the timer counts from; defaults to when the timer is created."""

    def elapsed(self) -> timedelta:
        """Return the time that has passed between `start` and now."""
        current = datetime.now()
        return current - self.start
@contextmanager
def timer(description: str | None = None) -> Iterator[Timer]:
    """Track elapsed time for a block of code.

    Yields a fresh `Timer` whose `start` is the moment the block is entered.

    If a `description` is provided, a debug message is logged on entry, and
    another one including the elapsed time is logged when the context manager
    is closed. The closing message is emitted from `finally`, so it is also
    logged when the block raises.
    """
    if description is not None:
        # Lazy %-style arguments: logging only builds the message when a
        # handler actually accepts DEBUG records.
        logger.debug("Starting %s", description)
    timer_ = Timer()
    try:
        yield timer_
    finally:
        if description is not None:
            elapsed = format_timedelta(timer_.elapsed())
            logger.debug("Finished %s in %s", description, elapsed)
@dataclass
class Progress:
    """Track progress of a set number of tasks.

    This class estimates time remaining based on the total elapsed time and
    provides pretty progress output. Skipped items are timed separately from
    completed ones so that fast skips do not distort the per-item average.
    """

    total: int
    """Total number of items to process."""

    count: int = field(default=0)
    """Number of items processed and not skipped.

    This is increased when you call `increment()`.
    """

    skipped_count: int = field(default=0)
    """Count of items skipped.

    This is increased when you call `skip()`.

    This is stored separately from `count` because skipped items tend to be
    processed much more quickly, throwing off averages.
    """

    skipped_time: timedelta = field(default=timedelta(0))
    """Total time spent processing skipped items.

    This is increased when you call `skip()`.
    """

    last_checkin: datetime = field(default_factory=datetime.now)
    """Last checkin time.

    This is updated when you call `checkin()`, `increment()`, or `skip()`.
    This is used to determine the total `skipped_time`.
    """

    timer: Timer = field(default_factory=Timer)
    """A timer storing elapsed time since the start of the progress."""

    def checkin(self) -> timedelta:
        """Update the checkin time, returning the elapsed time since the last checkin."""
        now = datetime.now()
        elapsed = now - self.last_checkin
        self.last_checkin = now
        return elapsed

    def increment(self) -> str:
        """Record one completed (non-skipped) item and describe the progress."""
        self.count += 1
        # The interval is intentionally discarded: time spent on completed
        # items is derived in `elapsed_per_count()` as total elapsed time
        # minus `skipped_time`.
        self.checkin()
        return str(self)

    def skip(self) -> str:
        """Record one skipped item and describe the progress."""
        self.skipped_count += 1
        self.skipped_time += self.checkin()
        return str(self)

    def elapsed_per_count(self) -> timedelta:
        """Determine the average elapsed time per non-skipped item.

        Returns zero before any item has been completed. The total is measured
        up to *now*, so time spent on an item still in flight is included.
        """
        if self.count == 0:
            return timedelta(0)
        return (self.timer.elapsed() - self.skipped_time) / self.count

    def elapsed_per_skip(self) -> timedelta:
        """Determine the average elapsed time per skipped item (zero if none)."""
        if self.skipped_count == 0:
            return timedelta(0)
        return self.skipped_time / self.skipped_count

    def elapsed_per_processed(self) -> timedelta:
        """Determine the average elapsed time per processed item, weighted
        across skipped and non-skipped items by the observed skip ratio.
        """
        skipped_ratio = self.skipped_ratio()
        return (
            self.elapsed_per_count() * (1 - skipped_ratio)
            + self.elapsed_per_skip() * skipped_ratio
        )

    def skipped_ratio(self) -> float:
        """Determine the portion of processed items that were skipped (0.0 if none)."""
        processed = self.processed_count()
        if processed == 0:
            return 0.0
        return self.skipped_count / processed

    def estimate_remaining(self) -> timedelta:
        """Estimate the remaining time based on the average elapsed time per processed item.

        Negative if more items were processed than `total`.
        """
        return self.elapsed_per_processed() * (self.total - self.processed_count())

    def eta(self) -> datetime:
        """Estimate the completion time based on the current time and the estimated remaining time."""
        return datetime.now() + self.estimate_remaining()

    def processed_count(self) -> int:
        """Get the total number of processed items, including skipped items."""
        return self.count + self.skipped_count

    def __str__(self) -> str:
        """Two-line summary: counts, percentage, per-item time, time left and ETA."""
        processed = self.processed_count()
        # An empty job (total == 0) is trivially complete; this also avoids a
        # ZeroDivisionError.
        fraction = processed / self.total if self.total else 1.0
        # Compute each estimate once (same formulas as `elapsed_per_processed`,
        # `estimate_remaining` and `eta`).
        per_item = self.elapsed_per_processed()
        remaining = per_item * (self.total - processed)
        # 12-hour clock with AM/PM. Pairing `%H` (24-hour) with `%p` would
        # print contradictory times such as "18:53 PM".
        eta = (datetime.now() + remaining).strftime("%Y-%m-%d %I:%M %p")
        return (
            f"{processed}/{self.total} ({fraction:.2%}), "
            f"≈{format_timedelta(per_item)} each\n"
            f"{format_timedelta(remaining)} left -> eta {eta}"
        )
@contextmanager
def progress(total: int, description: str | None = None) -> Iterator[Progress]:
    """Track progress of a `total` number of tasks.

    Yields a `Progress` that shares its clock with an enclosing `timer`: the
    tracker's `timer` is that `Timer`, and its first checkin is the timer's
    start instant.

    If a `description` is provided, the elapsed time will be logged when the
    context manager is closed.
    """
    with timer(description) as clock:
        tracker = Progress(total=total, timer=clock, last_checkin=clock.start)
        yield tracker
def format_timedelta(delta: timedelta) -> str:
    """Pretty-format a `timedelta` object.

    Durations that round to at least one second are rounded to the nearest
    hundredth of a second and always shown with two decimals (e.g.
    `0:05:46.30`, `0:00:02.00`, `1 day, 0:00:05.12`). Shorter durations are
    formatted in milliseconds (e.g. `16.67ms`). Negative durations are
    prefixed with `-`.
    """
    if delta < timedelta(0):
        return "-" + format_timedelta(-delta)
    # 999.995ms and above would print as "1000.00ms", so those values go
    # through the long form instead and become "0:00:01.00".
    if delta < timedelta(microseconds=999_995):
        # Here days == seconds == 0, so `microseconds` is the whole duration.
        return format(delta.microseconds / 1_000, ".2f") + "ms"
    total_us = delta // timedelta(microseconds=1)
    # Round to hundredths (10ms = 10_000µs) in exact integer arithmetic;
    # int rounding is round-half-to-even, as in the original implementation.
    whole_seconds, micro = divmod(round(total_us, -4), 1_000_000)
    return f"{timedelta(seconds=whole_seconds)}.{micro // 10_000:02d}"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment