Skip to content

Instantly share code, notes, and snippets.

@sebastianknopf
Created May 5, 2025 11:57
Show Gist options
  • Save sebastianknopf/446da12b88e3b5adb08d16b2ec1ce3ad to your computer and use it in GitHub Desktop.
Save sebastianknopf/446da12b88e3b5adb08d16b2ec1ce3ad to your computer and use it in GitHub Desktop.
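Python implementation of the set-cover problem for public transport: what is the smallest set of stations that serves every trip of a GTFS dataset at least once? (Solved with the greedy heuristic, so the result is a close approximation rather than a guaranteed minimum.)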
import csv
import os
import re
import sys
from collections import defaultdict
from datetime import datetime
def reduce_ifopt(ifopt_id: str) -> str:
    """Shorten an IFOPT-style ID to its first three colon-separated parts.

    An ID that fully matches the accepted pattern is cut down to
    ``<country>:<number>:<segment>``. The pattern is: two lowercase letters,
    a five-digit number, then one to three segments of word characters.
    For example, ``"de:08212:1234:1:2"`` becomes ``"de:08212:1234"``.

    Any ID that does not match is returned unchanged. This includes uppercase
    country codes, non-IFOPT IDs and IDs with more than five parts.
    Presumably the three-part prefix is the station level of the IFOPT/DHID
    hierarchy (the caller groups departures by it) -- TODO confirm.
    """
    if not re.fullmatch(r"^[a-z]{2}:\d{5}:[\w\d]+(:[\w\d]+){0,2}$", ifopt_id):
        return ifopt_id
    country, area, stop, *_ = ifopt_id.split(":")
    return f"{country}:{area}:{stop}"
def load_valid_service_ids(gtfs_path: str) -> set:
current_date: int = int(datetime.today().strftime('%Y%m%d'))
current_weekday: str = datetime.today().strftime('%A').lower()
valid_service_ids: set = set()
# load calendar.txt
calendar_path: str = f"{gtfs_path}/calendar.txt"
if os.path.exists(calendar_path):
with open(calendar_path, encoding='utf-8-sig') as f:
reader = csv.DictReader(f, delimiter=',')
for row in reader:
if row[current_weekday] == '1' and int(row['start_date']) <= current_date and int(row['end_date']) >= current_date:
valid_service_ids.add(row['service_id'])
# load calendar_dates.txt
calendar_dates_path: str = f"{gtfs_path}/calendar_dates.txt"
if os.path.exists(calendar_dates_path):
with open(calendar_dates_path, encoding='utf-8-sig') as f:
reader = csv.DictReader(f, delimiter=',')
for row in reader:
service_id: str = row['service_id']
exception_type: str = row['exception_type']
if int(row['date']) == current_date:
if exception_type == '1':
valid_service_ids.add(service_id)
elif exception_type == '2':
valid_service_ids.discard(service_id)
return valid_service_ids
def load_trip_ids_for_service_ids(gtfs_path: str, valid_service_ids: set, filter_route_ids: list | None = None) -> set:
valid_trips: set = set()
# load trips.txt
trips_path: str = f"{gtfs_path}/trips.txt"
with open(trips_path, encoding='utf-8-sig') as f:
reader = csv.DictReader(f, delimiter=',')
for row in reader:
if row['service_id'] in valid_service_ids:
if filter_route_ids is not None:
if any(r in row['route_id'] for r in filter_route_ids):
valid_trips.add(row['trip_id'])
else:
valid_trips.add(row['trip_id'])
return valid_trips
def load_trips_per_stop(gtfs_path: str, valid_trip_ids: set) -> dict:
    """Group the given trips by the (reduced) stop IDs they call at.

    Reads stop_times.txt and ignores rows whose trip_id is not in
    ``valid_trip_ids``. Each remaining row's stop_id is passed through
    ``reduce_ifopt``, so all stop points sharing the same three-part IFOPT
    prefix end up under one key.

    Args:
        gtfs_path: Directory containing the GTFS text files; stop_times.txt must exist.
        valid_trip_ids: Trip IDs to include.

    Returns:
        A ``defaultdict(set)`` mapping each reduced stop ID to the set of trip
        IDs that call there.
    """
    grouped: dict = defaultdict(set)

    with open(f"{gtfs_path}/stop_times.txt", encoding='utf-8-sig') as stop_times_file:
        for stop_time in csv.DictReader(stop_times_file, delimiter=','):
            trip_id = stop_time['trip_id']
            if trip_id not in valid_trip_ids:
                continue
            station_key = reduce_ifopt(stop_time['stop_id'])
            grouped[station_key].add(trip_id)

    return grouped
def find_minimal_stop_set(trips_per_stop):
    """Greedily pick stops until every trip is served by at least one of them.

    This is the classic greedy heuristic for the set-cover problem. Each round
    selects the stop that covers the most not-yet-covered trips. Ties go to
    the stop that comes first in ``trips_per_stop``'s iteration order.

    The result always covers every trip. It is an approximation, though: it is
    not guaranteed to be the smallest possible set of stops.

    Args:
        trips_per_stop: Mapping of stop ID to an iterable of the trip IDs that
            serve that stop (e.g. the output of ``load_trips_per_stop``).
            It is not modified.

    Returns:
        Set of selected stop IDs; empty when ``trips_per_stop`` is empty.
    """
    # Normalise each stop's trips to a set once, so the greedy loop can use
    # set operators directly instead of re-copying every collection each round.
    candidates = {stop: set(trips) for stop, trips in trips_per_stop.items()}
    uncovered: set = set().union(*candidates.values())
    selected_stops: set = set()

    while uncovered:
        # Gain of a stop = number of still-uncovered trips it serves.
        best_stop, best_trips = max(
            candidates.items(),
            key=lambda item: len(item[1] & uncovered),
        )
        selected_stops.add(best_stop)
        uncovered -= best_trips

    return selected_stops
# Sample usage
# Analyze the whole GTFS network:
#   set-cover-gtfs.py ./data
# Analyze only certain route IDs of the GTFS network (each entry is matched as
# a substring of trips.txt's route_id; blank entries are ignored):
#   set-cover-gtfs.py ./data 04720,04721,04722
#
def main(argv: list | None = None) -> int:
    """Run the station set-cover analysis for the GTFS directory on the command line.

    Args:
        argv: Argument vector shaped like ``sys.argv`` (script name first,
            then the GTFS directory, then an optional comma-separated list of
            route ID fragments); defaults to ``sys.argv``.

    Returns:
        Process exit code: 0 on success, 2 on a usage error.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        print('usage: set-cover-gtfs.py GTFS_DIR [ROUTE_ID,...]', file=sys.stderr)
        return 2
    gtfs_path = argv[1]
    if not os.path.isdir(gtfs_path):
        print(f"error: GTFS directory not found: {gtfs_path}", file=sys.stderr)
        return 2

    print('Loading valid service IDs for this operation day...')
    service_ids = load_valid_service_ids(gtfs_path)
    print(f"Found {len(service_ids)} service IDs")
    print()

    print('Loading valid trip IDs for this operation day...')
    filter_route_ids = None
    if len(argv) > 2:
        # Drop blank entries (e.g. from a trailing comma). The route filter
        # matches substrings, and an empty string matches every route_id, so a
        # blank entry would silently disable the filter. If nothing is left,
        # run unfiltered.
        filter_route_ids = [r.strip() for r in argv[2].split(',') if r.strip()] or None
        if filter_route_ids is not None:
            print(f"Using route ID filter {filter_route_ids}")
    trip_ids = load_trip_ids_for_service_ids(gtfs_path, service_ids, filter_route_ids)
    print(f"Found {len(trip_ids)} trip IDs")
    print()

    print('Loading every departures grouped by station IFOPT ...')
    trips_per_stop = load_trips_per_stop(gtfs_path, trip_ids)
    print(f"Found {len(trips_per_stop)} stations")
    print()

    print('Calculating minimal set of required stations to match each trip at least once ...')
    minimal_stop_set = find_minimal_stop_set(trips_per_stop)
    print(f"Found {len(minimal_stop_set)} minimal stations")
    print()
    # Sort so the listing is stable between runs; set iteration order of
    # strings varies with hash randomization.
    for stop_id in sorted(minimal_stop_set):
        print(stop_id)
    return 0


if __name__ == '__main__':
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment