Created July 24, 2020 14:54
-
-
Save mrphlip/9e0e0913ef4610e2a207ee48758bc318 to your computer and use it in GitHub Desktop.
A numerical solver for the Riddler Pinball puzzle - https://fivethirtyeight.com/features/are-you-a-pinball-wizard/
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
from vector import Vector | |
from math import sqrt | |
import sys | |
class NoBounce(Exception):
    """Signal that the ball's next bounce does not happen and it escapes.

    Attributes:
        side: truthy when the escape is reported as "left", falsy for "right"
            (that is how ``run_calc`` interprets it).
    """

    def __init__(self, side):
        super().__init__(side)
        self.side = side
def bounce_line(p, v):
    """Move the ball up to the horizontal wall y = 2 and reflect it there.

    Args:
        p: current position, indexable as (x, y).
        v: current velocity, indexable as (vx, vy).

    Returns:
        ``(hit_point, reflected_velocity)`` as ``Vector`` objects. The
        reflection negates only the vertical component.

    Raises:
        NoBounce: if the ball is not moving upward (``vy <= 0``). ``side`` is
            true when the ball is moving left.
    """
    vx, vy = v[0], v[1]
    if vy <= 0:
        raise NoBounce(vx < 0)
    # horizontal run covered while rising to y = 2; evaluated as (vx * rise) / vy
    rise = 2 - p[1]
    run = vx * rise / vy
    hit_point = Vector(p[0] + run, 2)
    return hit_point, Vector(vx, -vy)
def bounce_circle(p, v):
    """Move the ball to the unit circle centred at the origin and reflect it.

    NOTE(review): assumes ``a.project(b)`` returns the projection of ``b``
    onto ``a``, that ``squaredist()`` is the squared length, and that
    ``norm(d)`` rescales to length ``d``. This matches how they are used
    below; confirm against the ``vector`` module.

    Args:
        p: current position (Vector).
        v: current velocity (Vector); presumably must be non-zero.

    Returns:
        ``(hit_loc, new_v)``: the point where the ball strikes the circle,
        and the velocity mirrored about the surface normal at that point.

    Raises:
        NoBounce: if the ball's line of travel misses the circle, or the ball
            is already at or past its closest approach to the origin (moving
            away). ``side`` is true when the line of travel crosses y = 0 to
            the left of the origin. For purely horizontal travel it is true
            when the ball moves left.
    """
    # Project p onto the perpendicular of v. This gives the point on the
    # ball's (infinite) line of travel that is closest to the origin.
    approach = Vector(v[1], -v[0]).project(p)
    approach_dist_sq = approach.squaredist()
    # Signed progress from p toward the closest-approach point. A value <= 0
    # means the ball is already past it, so it cannot hit the circle ahead.
    to_approach = approach - p
    ahead = to_approach[0] * v[0] + to_approach[1] * v[1]
    if approach_dist_sq > 1 or ahead <= 0:
        if v[1] == 0:
            # A horizontal path never crosses y = 0; judge by direction instead.
            raise NoBounce(v[0] < 0)
        dy = 0 - p[1]
        dx = v[0] * dy / v[1]
        raise NoBounce(p[0] + dx < 0)
    # Use Pythagoras to get how far before the closest approach the ball
    # meets the circle.
    hit_dist = sqrt(1 - approach_dist_sq)
    # Step back from the approach point toward p (p != approach here because
    # ahead > 0).
    hit_loc = (p - approach).norm(hit_dist) + approach
    # Reflect: remove twice the velocity component along the surface normal,
    # which for a circle centred at the origin is the direction of hit_loc.
    new_v = v - 2 * hit_loc.project(v)
    return hit_loc, new_v
def run_calc(p, v, trace=True):
    """Bounce alternately off the wall and the circle until a bounce misses.

    Args:
        p: starting position (Vector).
        v: starting velocity (Vector).
        trace: if true, print the position and velocity before every bounce.

    Returns:
        The ``side`` carried by the terminating ``NoBounce`` (truthy means the
        escape is reported as "left"). A summary line with the bounce count is
        always printed.
    """
    bounces = 0
    try:
        while True:
            for bounce in (bounce_line, bounce_circle):
                if trace:
                    print(p, v)
                p, v = bounce(p, v)
                bounces += 1
    except NoBounce as escape:
        direction = "left" if escape.side else "right"
        print("No Bounce after %d bounces, escaping to the %s" % (bounces, direction))
        return escape.side
def do_test(x, trace=True):
    """Launch the ball from (-2, 0) with velocity (2 - x, 2) and run it out.

    With that velocity the ball first reaches the wall y = 2 at roughly
    (-x, 2). Returns whatever ``run_calc`` returns (truthy = escaped left).
    """
    return run_calc(Vector(-2, 0), Vector(2 - x, 2), trace=trace)
def binarysearch(xmin=0, xmax=2, predicate=None):
    """Bisect for the launch offset at which the escape side flips.

    The bracket is kept so that ``predicate`` is false at ``xmin`` and true at
    ``xmax``. The initial endpoints are assumed to satisfy this; they are not
    evaluated, so no extra runs are spent. Bisection stops once the midpoint
    rounds onto an endpoint, i.e. the bracket has shrunk to adjacent floats.
    Each probed ``x`` is printed (followed by ``': '``) as progress.

    Args:
        xmin: lower end of the bracket (default 0).
        xmax: upper end of the bracket (default 2).
        predicate: callable taking ``x`` and returning a truthy/falsy value.
            Defaults to ``do_test(x, False)``, which is truthy when the ball
            escapes to the left.

    Returns:
        The final ``(xmin, xmax)`` bracket around the flip point.
    """
    if predicate is None:
        def predicate(x):
            return do_test(x, False)
    while True:
        x = (xmin + xmax) / 2
        print(x, end=': ')
        if predicate(x):
            # midpoint rounded onto xmax: no representable x left in between
            if x >= xmax:
                break
            xmax = x
        else:
            if x <= xmin:
                break
            xmin = x
    return xmin, xmax
if __name__ == '__main__':
    # With a launch offset on the command line, trace that single run;
    # without one, bisect for the critical offset instead.
    cli_args = sys.argv[1:]
    if cli_args:
        do_test(float(cli_args[0]))
    else:
        binarysearch()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.