Created
May 5, 2025 11:57
-
-
Save sebastianknopf/446da12b88e3b5adb08d16b2ec1ce3ad to your computer and use it in GitHub Desktop.
Python implementation of a set cover problem: what is the minimum set of stations needed to meet every trip of a GTFS dataset at least once?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv | |
import os | |
import re | |
import sys | |
from collections import defaultdict | |
from datetime import datetime | |
def reduce_ifopt(ifopt_id: str) -> str:
    """Shorten an IFOPT-style ID to its first three colon-separated parts.

    IDs that look like ``cc:ddddd:part`` followed by up to two further
    ``:part`` segments are cut down to ``cc:ddddd:part``. Any other input is
    returned unchanged.
    """
    pattern = r"^[a-z]{2}:\d{5}:[\w\d]+(:[\w\d]+){0,2}$"

    # Guard clause: anything that is not a recognised ID passes through as-is.
    if re.fullmatch(pattern, ifopt_id) is None:
        return ifopt_id

    # A successful match guarantees at least three segments.
    first, second, third, *_ = ifopt_id.split(":")
    return f"{first}:{second}:{third}"
def load_valid_service_ids(gtfs_path: str) -> set:
    """Return the service IDs that are active today in a GTFS directory.

    Reads ``calendar.txt`` and ``calendar_dates.txt`` from *gtfs_path*; either
    file is skipped when it does not exist. A service from ``calendar.txt`` is
    active when today's weekday column is ``'1'`` and today lies within
    ``start_date``..``end_date`` (inclusive). Rows of ``calendar_dates.txt``
    dated today then add the service (``exception_type`` ``'1'``) or remove it
    (``exception_type`` ``'2'``).

    Args:
        gtfs_path: Directory containing the GTFS text files.

    Returns:
        The set of service IDs valid for the current local date.

    Raises:
        KeyError: If an expected column is missing from one of the files.
        ValueError: If a date column does not contain an integer.
    """
    # Weekday column names of calendar.txt, indexed by datetime.weekday()
    # (Monday == 0). Using a fixed table instead of strftime('%A') keeps the
    # lookup independent of the process locale.
    weekday_columns = (
        'monday', 'tuesday', 'wednesday', 'thursday',
        'friday', 'saturday', 'sunday',
    )

    # Take a single snapshot of "now" so that date and weekday can never
    # disagree when the script happens to run across midnight.
    today = datetime.today()
    current_date: int = int(today.strftime('%Y%m%d'))
    current_weekday: str = weekday_columns[today.weekday()]

    valid_service_ids: set = set()

    # Regular weekly service patterns.
    calendar_path: str = os.path.join(gtfs_path, 'calendar.txt')
    if os.path.exists(calendar_path):
        with open(calendar_path, encoding='utf-8-sig') as f:
            for row in csv.DictReader(f, delimiter=','):
                runs_today = row[current_weekday] == '1'
                if runs_today and int(row['start_date']) <= current_date <= int(row['end_date']):
                    valid_service_ids.add(row['service_id'])

    # Per-date exceptions are applied after the weekly patterns so that they
    # can both add to and remove from the set built above.
    calendar_dates_path: str = os.path.join(gtfs_path, 'calendar_dates.txt')
    if os.path.exists(calendar_dates_path):
        with open(calendar_dates_path, encoding='utf-8-sig') as f:
            for row in csv.DictReader(f, delimiter=','):
                service_id: str = row['service_id']
                exception_type: str = row['exception_type']
                if int(row['date']) != current_date:
                    continue
                if exception_type == '1':
                    valid_service_ids.add(service_id)
                elif exception_type == '2':
                    valid_service_ids.discard(service_id)

    return valid_service_ids
def load_trip_ids_for_service_ids(gtfs_path: str, valid_service_ids: set, filter_route_ids: list | None = None) -> set: | |
valid_trips: set = set() | |
# load trips.txt | |
trips_path: str = f"{gtfs_path}/trips.txt" | |
with open(trips_path, encoding='utf-8-sig') as f: | |
reader = csv.DictReader(f, delimiter=',') | |
for row in reader: | |
if row['service_id'] in valid_service_ids: | |
if filter_route_ids is not None: | |
if any(r in row['route_id'] for r in filter_route_ids): | |
valid_trips.add(row['trip_id']) | |
else: | |
valid_trips.add(row['trip_id']) | |
return valid_trips | |
def load_trips_per_stop(gtfs_path: str, valid_trip_ids: set) -> dict:
    """Group the valid trips of ``stop_times.txt`` by reduced stop ID.

    Args:
        gtfs_path: Directory containing ``stop_times.txt``.
        valid_trip_ids: Rows whose ``trip_id`` is not in this set are ignored.

    Returns:
        A ``defaultdict(set)`` mapping ``reduce_ifopt(stop_id)`` to the set of
        trip IDs that have a stop time there.
    """
    grouped: dict = defaultdict(set)

    with open(f"{gtfs_path}/stop_times.txt", encoding='utf-8-sig') as stop_times_file:
        for record in csv.DictReader(stop_times_file, delimiter=','):
            trip_id = record['trip_id']
            if trip_id not in valid_trip_ids:
                continue
            # Several stop IDs can collapse into one key after reduction.
            station_key = reduce_ifopt(record['stop_id'])
            grouped[station_key].add(trip_id)

    return grouped
def find_minimal_stop_set(trips_per_stop):
    """Greedily pick stops until every trip is served by at least one of them.

    In each round the stop serving the most not-yet-covered trips is chosen;
    on a tie the stop that comes first in the mapping's iteration order wins.

    Args:
        trips_per_stop: Mapping of stop ID to an iterable of trip IDs.

    Returns:
        The set of selected stop IDs (empty when there are no trips).
    """
    # Everything that still needs to be covered.
    uncovered: set = set()
    for trips in trips_per_stop.values():
        uncovered.update(trips)

    chosen: set = set()
    while uncovered:
        winner = None
        winner_gain = -1
        for stop, trips in trips_per_stop.items():
            gain = len(uncovered.intersection(trips))
            # Strict comparison keeps the first stop on equal gain.
            if gain > winner_gain:
                winner, winner_gain = stop, gain
        chosen.add(winner)
        uncovered.difference_update(trips_per_stop[winner])

    return chosen
# Sample Usage
# Analyze the whole GTFS network:
# set-cover-gtfs.py ./data
# Analyze only certain route IDs of the GTFS network:
# set-cover-gtfs.py ./data 04720,04721,04722
#
if __name__ == '__main__':
    # Command-line entry point:
    #   argv[1] - directory containing the GTFS text files (required)
    #   argv[2] - optional comma-separated list of route ID filters

    # Exit with a usage hint instead of an IndexError when the path is missing.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} <gtfs_directory> [route_id[,route_id...]]")

    gtfs_path = sys.argv[1]

    print('Loading valid service IDs for this operation day...')
    service_ids = load_valid_service_ids(gtfs_path)
    print(f"Found {len(service_ids)} service IDs")
    print()

    print('Loading valid trip IDs for this operation day...')
    filter_route_ids = None
    if len(sys.argv) > 2:
        # Drop empty entries (e.g. from a trailing comma): the route filter is
        # a substring test, and an empty string is contained in every route
        # ID, so it would silently match all routes. If nothing remains, run
        # without a filter.
        cleaned = [r.strip() for r in sys.argv[2].split(',')]
        filter_route_ids = [r for r in cleaned if r] or None
    if filter_route_ids is not None:
        print(f"Using route ID filter {filter_route_ids}")

    trip_ids = load_trip_ids_for_service_ids(gtfs_path, service_ids, filter_route_ids)
    print(f"Found {len(trip_ids)} trip IDs")
    print()

    print('Loading every departures grouped by station IFOPT ...')
    trips_per_stop = load_trips_per_stop(gtfs_path, trip_ids)
    print()

    print('Calculating minimal set of required stations to match each trip at least once ...')
    minimal_stop_set = find_minimal_stop_set(trips_per_stop)
    print(f"Found {len(minimal_stop_set)} minimal stations")
    print()

    # Sort so that repeated runs on the same data print in the same order
    # (plain set iteration order is arbitrary).
    for stop_id in sorted(minimal_stop_set):
        print(stop_id)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment