Created
January 2, 2023 12:03
-
-
Save akx/d8af21b79189cbb1a6dfbfad9abab0a5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import statistics | |
import time | |
def assign_into_groups(
    n_people: int,
    group_max_sizes: dict[int, int],
) -> dict[int, frozenset[int]]:
    """Randomly distribute people ``0 .. n_people - 1`` among groups.

    People are placed in ID order; each goes into a group chosen uniformly at
    random among the groups that still have room. A group stops accepting
    people once it holds ``group_max_sizes[gid]`` members.

    A max size of ``0`` (or less) never closes its group, so such a group has
    no upper limit. The example in ``main()`` relies on this: its positive
    capacities provide fewer seats than it has people.

    Args:
        n_people: Number of people to assign; their IDs are ``range(n_people)``.
        group_max_sizes: Maps each group ID to its maximum size (``<= 0``
            means unlimited). Group IDs need not be ``0 .. len - 1``.

    Returns:
        A mapping from every group ID in ``group_max_sizes`` (in its iteration
        order) to the frozen set of person IDs placed in that group.

    Raises:
        ValueError: If no group is unlimited and the total capacity is smaller
            than ``n_people``.
    """
    has_unlimited = any(size <= 0 for size in group_max_sizes.values())
    capacity = sum(size for size in group_max_sizes.values() if size > 0)
    if not has_unlimited and capacity < n_people:
        raise ValueError(
            f"Cannot place {n_people} people into groups with a total "
            f"capacity of {capacity}"
        )

    groups: dict[int, set[int]] = {gid: set() for gid in group_max_sizes}
    # Group IDs that can still accept people. It is a list (not a set) so
    # random.choice can index into it and iteration order stays deterministic.
    open_gids = list(group_max_sizes)
    for person_id in range(n_people):
        gid = random.choice(open_gids)
        members = groups[gid]
        members.add(person_id)
        max_size = group_max_sizes[gid]
        # Non-positive sizes are "unlimited" and never close.
        if 0 < max_size <= len(members):
            open_gids.remove(gid)
    return {gid: frozenset(members) for gid, members in groups.items()}
def score_solution(
    solution: dict[int, frozenset[int]],
    group_max_sizes: dict[int, int],
    n_people: int,
) -> float:
    """Score a group assignment; ``0`` means a hard constraint is violated.

    Hard constraints (any violation scores 0):

    * ``solution`` has exactly the group IDs of ``group_max_sizes``.
    * Every person ``0 .. n_people - 1`` appears in exactly one group, and no
      other person IDs appear.
    * Persons with adjacent IDs are not in the same group.
    * Persons 12 and 13 are not in the same group (when both exist).

    A valid solution scores ``avg_size / (1 + variance)``, where ``variance``
    is the sample variance of the group sizes (taken as 0 for a single group).
    More evenly sized groups therefore score higher.

    Args:
        solution: Maps group ID to the person IDs in that group.
        group_max_sizes: Maps group ID to its max size; only its keys are
            used here.
        n_people: Total number of people that must be assigned.

    Returns:
        ``0.0`` for an invalid solution, otherwise a positive score (bigger is
        better).
    """
    # Must have exactly the declared groups (not merely the same count).
    if solution.keys() != group_max_sizes.keys():
        return 0.0

    group_sizes = [len(group) for group in solution.values()]
    person_to_group = {
        person_id: gid for gid, group in solution.items() for person_id in group
    }
    # Must have every person exactly once. The dict silently collapses a
    # person listed in two groups, so the raw membership count is checked too.
    if (
        sum(group_sizes) != n_people
        or person_to_group.keys() != set(range(n_people))
    ):
        return 0.0

    # Example: persons with adjacent IDs can't be in the same group.
    for person_id, group_id in person_to_group.items():
        if person_to_group.get(person_id + 1) == group_id:
            return 0.0
    # Persons 12 and 13 can't stand each other. Only check once both exist;
    # indexing them unconditionally raises KeyError for n_people < 14.
    if n_people > 13 and person_to_group[12] == person_to_group[13]:
        return 0.0
    # etc. other constraints here...

    # Figure out a final score metric; the bigger the better.
    # This example will try to minimize the variance of group sizes.
    if not group_sizes:
        return 0.0  # No groups at all: statistics.mean() would raise.
    avg_group_size = statistics.mean(group_sizes)
    # statistics.variance() needs at least two data points; one group has no spread.
    group_size_variance = (
        statistics.variance(group_sizes, avg_group_size)
        if len(group_sizes) > 1
        else 0.0
    )
    return (1 / (1 + group_size_variance)) * avg_group_size
def main(max_attempts: int = 100_000) -> None:
    """Random-search for a good group assignment and print the best one found.

    Repeatedly generates a random assignment with ``assign_into_groups``,
    scores it with ``score_solution`` and keeps the highest-scoring valid
    (non-zero) one. Ctrl+C stops the search early; whatever was found so far
    is still reported.

    Args:
        max_attempts: Maximum number of random assignments to try.
    """
    # Just random examples; this should come from your problem domain.
    person_names = {pid: f"Person {pid + 1}" for pid in range(90)}
    group_max_sizes = {gid: (1 + gid) % 6 for gid in range(25)}
    n_people = len(person_names)

    best_solution = None
    best_score = None
    best_solution_attempt = 0
    attempt = 0
    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    try:
        for attempt in range(max_attempts):
            solution = assign_into_groups(
                n_people=n_people, group_max_sizes=group_max_sizes
            )
            score = score_solution(
                solution, group_max_sizes=group_max_sizes, n_people=n_people
            )
            # A score of 0 marks an invalid solution, so it is never "best".
            if score and (best_score is None or score > best_score):
                best_solution = solution
                best_score = score
                best_solution_attempt = attempt
                print(f"Attempt {attempt}: {best_score}")
    except KeyboardInterrupt:
        pass  # Ctrl+C just ends the search early; report what we have.
    dur = time.perf_counter() - t0
    # Guard against a zero-length run (e.g. max_attempts=0 on a coarse clock).
    rate = attempt / dur if dur > 0 else 0.0
    print(
        f"\nStopped at {attempt=} (at {dur=:.2f} seconds; {rate:.1f} attempts per second), \n"
        f"best solution found at {best_solution_attempt=}, {best_score=}:"
    )
    if best_solution is None:
        print("No valid solution found.")
        return
    for group_id, person_ids in sorted(best_solution.items()):
        print(f"Group {group_id:2d}: {sorted(person_names[pid] for pid in person_ids)}")
# Run the random search only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment