Created
September 7, 2012 18:11
Predict plays
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
from nflgame.statmap import idmap | |
from sqlalchemy import Column, String, Integer, Boolean | |
from sqlalchemy import create_engine | |
from sqlalchemy import func | |
from sqlalchemy.ext.declarative import declarative_base | |
from sqlalchemy.orm import sessionmaker | |
import nflgame | |
import simplejson | |
import sys | |
# Database plumbing shared by the whole script.
# The SQLite file "plays.db" is resolved relative to the current working directory.
# echo=False keeps SQLAlchemy from logging every statement.
engine = create_engine('sqlite:///plays.db', echo=False)

# Session factory; every Session() it produces is bound to the engine above.
Session = sessionmaker()
Session.configure(bind=engine)

# Declarative base class for the ORM model(s) defined below.
Base = declarative_base()
class Play(Base):
    """One scrimmage play, stored as a row of the ``plays`` table.

    Rows are written by ``buildstats`` and aggregated by ``predictplay``.
    """
    __tablename__ = 'plays'

    id = Column(Integer, primary_key=True)
    # Team in possession (filled from the feed's 'posteam' field).
    team = Column(String)
    down = Column(Integer)
    # Yards needed for a first down (filled from the feed's 'ydstogo' field).
    distance = Column(Integer)
    # Yards between the ball and the opponent's end zone.
    yardstoendzone = Column(Integer)
    playcall = Column(Integer)  # 1=pass, 2=rush, 0=everything else
    desc = Column(String)
    # Spelling kept as-is: it is the column name in existing plays.db files.
    succesful = Column(Boolean)

    def __repr__(self):
        # Previously this referenced ``self.ispass``, which is not an attribute
        # of this model, so repr() always raised AttributeError.
        # %s (not %d) is used so unsaved or partially filled rows, whose
        # columns are None, can still be printed.
        return u"<Play: %s %s & %s with %s to go, playcall=%s - %s (succes=%s)>" % (
            self.team, self.down, self.distance, self.yardstoendzone,
            self.playcall, self.desc, self.succesful,
        )
# statmap fields whose presence marks a play as successful: it either moved
# the chains (first down, converted 3rd/4th down) or scored a touchdown.
_SUCCESS_FIELDS = frozenset((
    'first_down', 'third_down_conv', 'fourth_down_conv',
    'rushing_tds', 'passing_tds', 'receiving_tds',
))


def _classify_play(players):
    """Return ``(playcall, succesful)`` for one play's ``players`` mapping.

    ``players`` maps a player id to a list of stat dicts, each carrying a
    ``statId`` that is looked up in nflgame's ``idmap``. Ids missing from
    ``idmap`` are ignored.

    ``playcall`` is 1 if any stat is in the 'passing' category, otherwise 2
    if any is in 'rushing', otherwise 0. ``succesful`` is True if any stat
    has a field listed in ``_SUCCESS_FIELDS``.
    """
    cats = set()
    succesful = False
    for stats in players.values():
        for stat in stats:
            if stat['statId'] not in idmap:
                continue
            info = idmap[stat['statId']]
            cats.add(info['cat'])
            if any(field in _SUCCESS_FIELDS for field in info['fields']):
                succesful = True
    # The previous version stopped scanning a player's stats at the first
    # pass/rush stat. That made the result depend on dict iteration order and
    # skipped the success check for the rest of that player's stats.
    # An explicit precedence (pass over rush) gives a deterministic answer.
    if 'passing' in cats:
        return 1, succesful
    if 'rushing' in cats:
        return 2, succesful
    return 0, succesful


def _yards_to_endzone(yrdln, posteam):
    """Convert a feed yard line into yards to the opponent's end zone.

    ``yrdln`` is either '50' or '<HALF> <N>'. <HALF> is compared with
    ``posteam`` to decide whose half of the field the ball is in.
    Returns None when ``yrdln`` is empty or missing. It is not confirmed that
    the feed ever sends a blank value; the guard avoids crashing the build
    if it does.
    """
    if not yrdln:
        return None
    if yrdln == '50':
        return 50
    half, yards = yrdln.split(" ")
    yards = int(yards)
    if half == posteam:
        # From its own N yard line, the team is 100 - N yards from the far
        # goal line. The previous code returned N + 50 here (own 30 -> 80
        # instead of 70). Databases built with it must be rebuilt.
        return 100 - yards
    return yards


def buildstats(year):
    """Rebuild the ``plays`` table from every game nflgame has for ``year``.

    All existing rows are deleted first. Plays are skipped when they have no
    player stats, are recorded with down 0, or have no usable yard line.
    """
    Base.metadata.create_all(engine)
    session = Session()
    try:
        session.query(Play).delete()
        session.commit()
        for game in nflgame.games(year=year):
            data = simplejson.loads(game.rawData)
            for drive in data[game.eid]['drives'].values():
                # Not every value under 'drives' is a drive dict; skip the rest.
                if not isinstance(drive, dict):
                    continue
                for play in drive['plays'].values():
                    if not play['players'] or play['down'] == 0:
                        continue
                    yardstoendzone = _yards_to_endzone(play.get('yrdln'),
                                                       play['posteam'])
                    if yardstoendzone is None:
                        continue
                    playcall, succesful = _classify_play(play['players'])
                    session.add(Play(
                        team=play['posteam'],
                        down=play['down'],
                        distance=play['ydstogo'],
                        yardstoendzone=yardstoendzone,
                        playcall=playcall,
                        desc=play['desc'],
                        succesful=succesful,
                    ))
        session.commit()
    finally:
        session.close()
def predictplay(team, down, distance, fromposition, toposition):
    """Summarise how ``team`` called plays in a given game situation.

    Arguments may be strings, as passed from the command line.
    ``distance`` is either an exact yardage ("10") or a minimum followed by
    '+' ("10+"). ``fromposition`` and ``toposition`` bound
    ``Play.yardstoendzone`` inclusively.

    Returns the tuple ``(team, down, distance, fromposition, toposition,
    totalplays, pass_fraction, run_fraction, [], [])``.
    Both fractions are 0.0 when no plays match. The two trailing empty lists
    are kept only to preserve the existing tuple shape.
    """
    down = int(down)
    fromposition = int(fromposition)
    toposition = int(toposition)
    distance_text = str(distance)

    session = Session()
    try:
        query = session.query(Play.playcall, func.count(Play.id)).filter(
            Play.team == team,
            Play.down == down,
            Play.yardstoendzone >= fromposition,
            Play.yardstoendzone <= toposition,
        )
        if distance_text.endswith('+'):
            query = query.filter(Play.distance >= int(distance_text[:-1]))
        else:
            query = query.filter(Play.distance == int(distance_text))
        # Materialise the grouped rows once. The previous version iterated the
        # Query twice (for the sum and for dict()), running the SQL twice, and
        # applied group_by twice.
        counts = dict(query.group_by(Play.playcall).all())
    finally:
        session.close()

    totalplays = sum(counts.values())
    if totalplays:
        passshare = float(counts.get(1, 0)) / totalplays
        runshare = float(counts.get(2, 0)) / totalplays
    else:
        # Previously this divided by zero. Returning 0.0 lets callers detect
        # "no plays" through totalplays instead of crashing.
        passshare = runshare = 0.0
    return (team, down, distance, fromposition, toposition,
            totalplays, passshare, runshare, [], [])
def printprediction(p):
    """Print a prediction tuple, as returned by predictplay, as a short report.

    ``p`` is indexed as (team, down, distance, from, to, count, pass
    fraction, run fraction, ...). When the count ``p[5]`` is zero, this
    writes "No matching plays found" to stderr and exits with status 1.
    """
    if p[5] == 0:
        sys.stderr.write("No matching plays found\n")
        sys.exit(1)
    # The range is shown as inclusive because predictplay filters with >= and <=.
    report = "\n".join((
        "Team: {0}",
        "Down: {1}",
        "Distance: {2}",
        "Ball on: {3} <= x <= {4}",
        "{0} faced this situation {5} times",
        "They passed the ball: {6:.0%}",
        "They ran the ball: {7:.0%}",
    )).format(*p)
    # A single parenthesised argument prints the same under Python 2's print
    # statement and Python 3's print() function.
    print(report)
if __name__ == "__main__":
    # Command-line entry point:
    #   build [YEAR]                   rebuild plays.db from one season (default 2011)
    #   TEAM DOWN DISTANCE[+] FROM TO  print a prediction from plays.db
    # Previously, running with no arguments raised IndexError, and a wrong
    # argument count raised TypeError; both now print usage instead.
    args = sys.argv[1:]
    if args and args[0] == "build" and len(args) <= 2:
        buildstats(int(args[1]) if len(args) == 2 else 2011)
    elif len(args) == 5:
        printprediction(predictplay(*args))
    else:
        sys.stderr.write(
            "usage: %s build [YEAR]\n"
            "       %s TEAM DOWN DISTANCE[+] FROM TO\n" % (sys.argv[0], sys.argv[0]))
        sys.exit(2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment