Last active
September 26, 2019 15:01
-
-
Save aswild/c006956299552298b70a7c964a8bfda1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import argparse | |
import inspect | |
from io import StringIO | |
from itertools import product | |
import random | |
import re | |
import sys | |
class DiceSet:
    """A set of dice (e.g. "2d6+1d4") plus an integer constant modifier.

    Dice are stored as a dict mapping the number of sides of a die to how many
    of that die are in the set; "2d6+1d4+1" is stored as {6: 2, 4: 1} with a
    constant of 1.
    """

    # Matches a term ending in "-K" (e.g. the "-2" in "d20-2"). The greedy ".*"
    # finds the LAST minus, so matching repeatedly peels constants off right
    # to left ("d20-1-2" -> "d20-1" -> "d20").
    _NEG_CONSTANT_RE = re.compile(r'^.*(?P<minus>-)\s*(?P<constant>\d+)$')
    # Matches "N", "dS", or "NdS": an optional count and an optional die size.
    _TERM_RE = re.compile(r'^(?P<count>\d*)(?:[dD](?P<die>\d+))?$')

    @staticmethod
    def parse_dice_str(dice_str):
        """Parse a dice string such as "2d6+1d4-1" into (dice, constant).

        Terms are separated by '+'. Each term is an integer constant, "dS"
        (one S-sided die) or "NdS" (N S-sided dice), optionally followed by one
        or more "-K" constants to subtract. The 'd' may be upper or lower case.

        Returns a 2-tuple (dice, constant): dice maps number of sides to die
        count, constant is the summed integer modifier.

        Raises ValueError if the string is empty, a term is malformed, or a
        die has fewer than one side.
        """
        if not dice_str:
            raise ValueError('dice is empty')
        dice = {}
        constant = 0
        for s in (x.strip() for x in dice_str.split('+')):
            # A plain (possibly negative) integer term just adds to the constant.
            try:
                constant += int(s)
                continue
            except ValueError:
                pass
            # Negative constants are trickier since we only split on '+': peel
            # off every trailing "-K" and subtract it from the constant.
            m = DiceSet._NEG_CONSTANT_RE.match(s)
            while m:
                constant -= int(m.group('constant'))
                s = s[:m.start('minus')].strip()
                m = DiceSet._NEG_CONSTANT_RE.match(s)
            m = DiceSet._TERM_RE.match(s)
            if not m:
                raise ValueError('Invalid roll "%s"'%s)
            count = m.group('count')
            die = m.group('die')
            if not die:
                # No die mentioned (e.g. the "5" of "5-2"): the count is a constant.
                if not count:
                    raise ValueError('Invalid roll "%s"'%s)
                constant += int(count)
                continue
            die = int(die)
            if die < 1:
                # Reject here; a zero-sided die can neither be rolled nor histogrammed.
                raise ValueError('Invalid die "d%d": a die needs at least one side'%die)
            count = int(count) if count else 1
            dice[die] = dice.get(die, 0) + count
        return dice, constant

    @staticmethod
    def get_roll_total(roll):
        """Return the sum of the values in a roll() result (including any constant)."""
        return sum(value for _, value in roll)

    def __init__(self, dice, constant=0):
        """Create a dice set.

        dice is either a dice string accepted by parse_dice_str() or a dict
        mapping die size to count (keys and values are converted with int()).
        constant is added to any constant parsed from a dice string.

        Raises ValueError for a die with fewer than one side or a negative
        count, and TypeError if dice is neither a str nor a dict.
        """
        self._constant = constant
        if isinstance(dice, str):
            self._dice, parsed_constant = self.parse_dice_str(dice)
            self._constant += parsed_constant
        elif isinstance(dice, dict):
            self._dice = {}
            for k, v in dice.items():
                die, count = int(k), int(v)
                if die < 1:
                    raise ValueError('Invalid die "d%d": a die needs at least one side'%die)
                if count < 0:
                    raise ValueError('Invalid count %d for d%d'%(count, die))
                self._dice[die] = count
        else:
            # Fail fast rather than leaving self._dice unset for a later AttributeError.
            raise TypeError('dice must be a str or dict, not %s'%type(dice).__name__)
        # OS-backed random source: unbiased for any die size, and unlike an
        # open /dev/urandom file object there is no handle to leak or close.
        self._rng = random.SystemRandom()

    def __str__(self):
        """Format as a dice string that parse_dice_str() accepts, e.g. "2d6+1d4-1"."""
        parts = ['%dd%d'%(self._dice[die], die) for die in sorted(self._dice, reverse=True)]
        if not parts:
            # Constant only: no leading '+', which would not parse back.
            return '%d'%self._constant
        text = '+'.join(parts)
        if self._constant:
            text += '%+d'%self._constant
        return text

    def __repr__(self):
        return '<DiceSet: %s>'%self.__str__()

    @property
    def dice(self):
        """Dict mapping die size (number of sides) to how many of that die."""
        return self._dice

    @property
    def constant(self):
        """Integer modifier added to every roll."""
        return self._constant

    def _roll_one(self, n):
        """Roll one n-sided die, returning an int in [1, n].

        Uses the OS randomness source via random.SystemRandom, falling back to
        the module-level random.randint if that source is unavailable.
        """
        try:
            return self._rng.randint(1, n)
        except NotImplementedError:
            return random.randint(1, n)

    def roll(self):
        """Roll the dice and return a list of roll results, where each result
        is a 2-tuple of the form (die, result). If a constant value is added,
        it's returned as the 2-tuple (0, constant). For example, a 2d6+1d4+1
        roll may return [(6, 3), (6, 6), (4, 2), (0, 1)]. Dice are listed
        from largest to smallest."""
        result = [(die, self._roll_one(die))
                  for die in sorted(self._dice, reverse=True)
                  for _ in range(self._dice[die])]
        if self._constant:
            result.append((0, self._constant))
        return result

    def roll_total(self):
        """Roll the dice and return the total value. Convenience wrapper for
            roll = dice_set.roll()
            get_roll_total(roll)
        """
        return self.get_roll_total(self.roll())

    def histogram(self):
        """Return a histogram representing the probability of rolling each
        possible value. Data is returned as a dict of value:count pairs, with
        value representing the roll total (including constant) and count
        representing the number of rolls which generate that value out of the
        total possible number of rolls. For example, the histogram of 2d6+1
        returns {3: 1, 4: 2, 5: 3, 6: 4, 7: 5, 8: 6, 9: 5, 10: 4, 11: 3, 12: 2, 13: 1}.

        To get the probability of a particular value, divide its count by the
        total number of values (i.e. sum(hist.values())).

        Keys are inserted in ascending order, so on Python 3.7+ (where dict
        insertion order is preserved) iteration is ascending.
        """
        # Build the distribution of the dice sum by convolving in one die at a
        # time. This is polynomial in the number of dice, whereas enumerating
        # every combination of faces is exponential (20**10 tuples for 10d20).
        dist = {0: 1}
        for die, count in self._dice.items():
            for _ in range(count):
                new_dist = {}
                for subtotal, ways in dist.items():
                    for face in range(1, die + 1):
                        key = subtotal + face
                        new_dist[key] = new_dist.get(key, 0) + ways
                dist = new_dist
        return {subtotal + self._constant: dist[subtotal] for subtotal in sorted(dist)}

    def _count_in_range(self, low, high, normalize):
        """Count histogram rolls with low <= value <= high; None means unbounded.

        Returns a float probability in [0.0, 1.0] instead when normalize is True.
        Raises ValueError if both bounds are given and low > high.
        """
        if low is not None and high is not None and low > high:
            raise ValueError('low cannot be greater than high')
        hist = self.histogram()
        total = sum(count for value, count in hist.items()
                    if (low is None or value >= low) and (high is None or value <= high))
        return (total / sum(hist.values())) if normalize else total

    def roll_prob(self, low, high, normalize=False):
        """Return the number of rolls in this DiceSet's histogram with a value
        between [low, high], inclusive. If low is 0 (or None), then consider
        all rolls <= high. If high is 0 (or None), consider all rolls >= low.
        If normalize is True, return the floating point probability in range
        [0.0, 1.0] rather than the number of rolls which would produce that
        value.

        Because 0 means "unbounded" here, use roll_prob_str() to ask about an
        exact bound of 0 (relevant for dice sets with negative constants).
        """
        if not (low or high):
            raise ValueError('At least low or high must be specified')
        return self._count_in_range(low or None, high or None, normalize)

    def roll_prob_str(self, spec, normalize=False):
        """Parse spec in one of the forms 'a' (exactly a), 'a-b' (inclusive
        range), 'a+' (a or more) or 'a-' (a or less) and return the matching
        roll count (or probability if normalize is True). Numbers may be
        negative, and 0 is treated as an ordinary value, not "unbounded".

        Raises ValueError for a malformed spec or a range with a > b.
        """
        try:
            # if spec is just one number, then it's easy, and short-circuit the regex stuff
            x = int(spec)
        except ValueError:
            pass
        else:
            return self._count_in_range(x, x, normalize)
        m = re.match(r'(-?\d+)([+-])(-?\d+)?$', spec)
        if m is None:
            raise ValueError(f'Invalid spec: "{spec}"')
        a = int(m.group(1))
        plus = m.group(2) == '+'
        b = int(m.group(3)) if m.group(3) else None
        if plus and b is not None:
            raise ValueError(f'Invalid spec: "{spec}"')
        if plus:
            return self._count_in_range(a, None, normalize)
        if b is None:
            return self._count_in_range(None, a, normalize)
        return self._count_in_range(a, b, normalize)

    @staticmethod
    def _ndigits(n):
        """Helper function for print_histogram.

        Return the number of characters needed to display an integer,
        including a leading '-' for negatives (so 0 -> 1, -12 -> 3). Uses the
        decimal string form rather than floor(log10(n))+1, which is unreliable
        due to floating-point inexactness and undefined for n <= 0.
        """
        if not isinstance(n, int):
            raise TypeError('expected an int, got %s'%type(n).__name__)
        return len(str(n))

    def print_histogram(self):
        """Print a table of roll value, count, probability (%) and a '#' bar
        of length count to stdout, in ascending order of roll value."""
        hist = self.histogram()
        # check width of max and min roll values, in case of negatives
        value_width = max(self._ndigits(max(hist.keys())), self._ndigits(min(hist.keys())))
        value_width = max(value_width, len('Roll'))
        # value counts are always positive
        count_width = self._ndigits(max(hist.values()))
        count_width = max(count_width, len('Count'))
        print('%*s %*s Probability'%(value_width, 'Roll', count_width, 'Count'))
        total_count = sum(hist.values())
        for value in sorted(hist):
            count = hist[value]
            prob = (count / total_count) * 100.0
            print('%*d %*d %4.1f%% %s'%(value_width, value, count_width, count, prob, '#'*count))
class DiceRollCLI:
    """Command-line interface for DiceSet.

    Subcommands are nested CmdBase subclasses with a non-empty NAME, discovered
    by get_cmds(). If the first argument parses as a dice string, the "roll"
    command is run implicitly.
    """

    class CmdBase:
        """Abstract base class for a subcommand."""
        # subclasses should set this to the command name and help text description
        NAME = ''
        HELP = ''

        @classmethod
        def populate_parser(cls, parser):
            """Take the given argparse.ArgumentParser object and add arguments to it."""
            raise NotImplementedError('unimplemented abstract method in %s'%cls)

        @classmethod
        def run(cls, args):
            """Run the command with the given args Namespace from argparse."""
            raise NotImplementedError('unimplemented abstract method in %s'%cls)

    class CmdRollBase(CmdBase):
        """Common code for roll and qroll (empty NAME, so not a command itself)."""
        NAME = ''
        HELP = ''
        # Subclasses override; when True only totals are printed and no -q option exists.
        _always_quiet = False

        @classmethod
        def populate_parser(cls, parser):
            parser.add_argument('-c', '--count', type=int, default=1,
                                help='Number of times to roll this dice set')
            if not cls._always_quiet:
                parser.add_argument('-q', '--quiet', action='store_true',
                                    help='Quiet mode, print only the total, not the result of each die.')
            parser.add_argument('dice', help='Set of dice to roll, e.g. "d20", "2d6", or "2d4+1"')

        @classmethod
        def run(cls, args):
            """Roll args.dice args.count times, printing each total and,
            unless quiet, the individual (die, result) pairs."""
            if args.count < 1:
                sys.exit('Error: roll count must be positive')
            dice = DiceSet(args.dice)
            # Short-circuit: qroll's Namespace has no 'quiet' attribute.
            quiet = cls._always_quiet or args.quiet
            for _ in range(args.count):
                roll = dice.roll()
                total = dice.get_roll_total(roll)
                if quiet:
                    print(total)
                else:
                    print('%s: %s'%(total, roll))

    class CmdRoll(CmdRollBase):
        NAME = 'roll'
        HELP = 'Roll some dice (the default)'
        _always_quiet = False

    class CmdRollQuiet(CmdRollBase):
        NAME = 'qroll'
        HELP = 'Roll some dice, only print the result (shortcut for "roll -q")'
        _always_quiet = True

    class CmdHistogram(CmdBase):
        NAME = 'histogram'
        HELP = 'Display a roll probability histogram'

        @classmethod
        def populate_parser(cls, parser):
            parser.add_argument('dice', help='Set of dice to roll, e.g. "d20", "2d6", or "2d4+1"')

        @classmethod
        def run(cls, args):
            dice = DiceSet(args.dice)
            dice.print_histogram()

    class CmdProb(CmdBase):
        NAME = 'prob'
        HELP = 'Get the probability of a given roll.'

        @classmethod
        def populate_parser(cls, parser):
            parser.add_argument('dice', help='Set of dice to roll, e.g. "d20", "2d6", or "2d4+1"')
            parser.add_argument('value', help='Target roll value, can be of the form "2" (exactly 2), '+
                                '"6-" (6 or less), "7-9" (inclusive range), or "10+" (at least 10)')

        @classmethod
        def run(cls, args):
            dice = DiceSet(args.dice)
            prob = dice.roll_prob_str(args.value, normalize=True)
            print('%.1f%%'%(prob * 100.0))

    @classmethod
    def get_cmds(cls):
        """Return a dict mapping command name to command class (cached on cls)."""
        try:
            return cls._cmds
        except AttributeError:
            pass
        cmds = {}
        for _, cmdclass in inspect.getmembers(cls, inspect.isclass):
            # filter on subclasses of CmdBase and a non-empty name (to distinguish from the base classes)
            if issubclass(cmdclass, cls.CmdBase) and cmdclass.NAME:
                cmds[cmdclass.NAME] = cmdclass
        cls._cmds = cmds
        return cls._cmds

    @staticmethod
    def cmd_search(cmd_names, cmd):
        """Search for cmd in cmd_names, allowing for partial (prefix) matches.

        If an exact match is found, or only one partial match, return it.
        If multiple partial matches found, return a list of them.
        If no matches found, return None.
        """
        if cmd in cmd_names:
            return cmd
        partial_matches = [c for c in cmd_names if c.startswith(cmd)]
        if len(partial_matches) == 1:
            return partial_matches[0]
        if partial_matches:
            return partial_matches
        return None

    @classmethod
    def main(cls, argv=None):
        """Run the CLI on argv (default: sys.argv[1:]).

        Returns None on success or an error-message string, suitable for
        sys.exit(). argparse usage errors and the roll-count check exit
        directly via SystemExit.
        """
        if argv is None:
            argv = sys.argv[1:]
        # Hack around argparse so that roll can be the default command: if the
        # first argument parses as a dice string, run "roll" on the whole argv.
        # Only the parse is guarded, so a ValueError raised while rolling is not
        # misread as "not a dice string".
        if argv:
            try:
                DiceSet.parse_dice_str(argv[0])
            except ValueError:
                pass
            else:
                return cls.main_roll(argv)
        cmds = cls.get_cmds()
        parser = argparse.ArgumentParser(description='Roll some dice!',
                                         epilog='Valid commands are: %s'%', '.join(cmds.keys()))
        parser.add_argument('command', help='What to do.')
        parser.add_argument('command_args', nargs=argparse.REMAINDER,
                            help='Arguments for the command. Use "%s COMMAND -h" for command help'%parser.prog)
        # Parse the argv we were given; parse_args() with no argument would
        # silently read sys.argv instead.
        args = parser.parse_args(argv)
        cmd_name = cls.cmd_search(cmds.keys(), args.command)
        if cmd_name is None:
            return 'Error: unknown command: ' + args.command
        if isinstance(cmd_name, list):
            return 'Error: multiple possible commands, be more specific: ' + ', '.join(cmd_name)
        cmdclass = cmds[cmd_name]
        cmdparser = argparse.ArgumentParser(prog='%s %s'%(parser.prog, cmdclass.NAME), description=cmdclass.HELP)
        cmdclass.populate_parser(cmdparser)
        cmdargs = cmdparser.parse_args(args.command_args)
        try:
            cmdclass.run(cmdargs)
        except Exception as e:
            return f'Error: {e}'

    @classmethod
    def main_roll(cls, argv):
        """Run the "roll" command directly on argv (which has no command name).

        Returns None on success or an error-message string, like main().
        """
        cmd = cls.get_cmds()['roll']
        # argparse's default prog is the script's basename; reuse it so usage
        # text matches main()'s "<prog> roll" rather than showing a full path.
        prog = argparse.ArgumentParser().prog
        parser = argparse.ArgumentParser(prog=prog+' roll', description='Roll some dice!')
        cmd.populate_parser(parser)
        args = parser.parse_args(argv)
        try:
            cmd.run(args)
        except Exception as e:
            return f'Error: {e}'
if __name__ == '__main__':
    # main() returns None on success or an error-message string; SystemExit
    # turns that into the process exit status (printing the string to stderr).
    raise SystemExit(DiceRollCLI.main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment