Last active
January 12, 2021 12:36
-
-
Save ghamerly/723c6bad926d6c1523c094a3b6a3eb6a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
*.csv | |
*.png | |
*.svg |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
''' | |
This script uses data from https://github.com/CSSEGISandData/COVID-19/ and: | |
- grabs it (using the requests package) | |
- filters it (using the command line arguments and only a recent number of days) | |
- fits an exponential growth model to the data, | |
- plots the data and the growth curve, | |
- makes predictions of the numbers using the growth curve. | |
It can plot each type of data found under the subdirectory | |
csse_covid_19_data/csse_covid_19_time_series (i.e. per-country, or | |
per-USA-county). | |
''' | |
import argparse | |
import csv | |
import hashlib | |
import os | |
import sys | |
import time | |
import matplotlib.pyplot | |
import requests | |
import numpy.linalg | |
class Constants:  # pylint: disable=too-few-public-methods
    '''Static configuration shared by the rest of the script.'''

    # CSV column names that identify a row in the worldwide files; 'US' is
    # used as the major region when none is requested.
    _GLOBAL_HEADERS = dict(
        major='Country/Region',
        major_default='US',
        minor='Province/State',
    )

    # CSV column names that identify a row in the USA-only files; there is no
    # default major region for these.
    _USA_HEADERS = dict(
        major='Province_State',
        major_default=None,
        minor='Admin2',
    )

    # common URL prefix of every time-series file (built from two pieces to
    # keep the line short)
    DATA_URL_BASE = (
        'https://raw.githubusercontent.com/CSSEGISandData/COVID-19/'
        + 'master/csse_covid_19_data/csse_covid_19_time_series/'
    )

    # one entry per value accepted by --which: the file to download, the
    # header names used to locate the row, and a label used in the output
    DATA_OPTIONS = dict(
        deaths=dict(
            file='time_series_covid19_deaths_global.csv',
            headers=_GLOBAL_HEADERS,
            description='deaths',
        ),
        infections=dict(
            file='time_series_covid19_confirmed_global.csv',
            headers=_GLOBAL_HEADERS,
            description='infections',
        ),
        recovered=dict(
            file='time_series_covid19_recovered_global.csv',
            headers=_GLOBAL_HEADERS,
            description='recovered',
        ),
        usa_infections=dict(
            file='time_series_covid19_confirmed_US.csv',
            headers=_USA_HEADERS,
            description='infections',
        ),
        usa_deaths=dict(
            file='time_series_covid19_deaths_US.csv',
            headers=_USA_HEADERS,
            description='deaths',
        ),
    )
def parse_args():
    '''Parse the command line arguments.

    Besides the options given on the command line, the returned namespace
    carries three fields looked up from Constants.DATA_OPTIONS for the chosen
    series: "url" (where to download it), "headers" (the CSV column names that
    identify the major/minor region) and "description" (a label used in
    messages and on the plot).

    When no major region is given, the default of the chosen data set is used.
    The parser exits with a usage error if that leaves no major region at all
    (the USA series have no default), or if --days/--test cannot produce a
    non-empty training set.
    '''
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--major', default='', help='Major region (country or USA state)')
    parser.add_argument('--minor', default='', help='Minor region (state/province or USA county)')
    parser.add_argument('--days', type=int, default=20, help='How many days to look at')
    parser.add_argument('--test', type=int, default=0,
                        help='How many recent examples to use for testing (and not train)')
    parser.add_argument('--which', choices=list(Constants.DATA_OPTIONS), default='infections',
                        help='Which series to plot')
    parser.add_argument('--linear_y', action='store_true', default=False,
                        help='Use linear scale for y-axis')
    parser.add_argument('--save', help='Filename of saved plot (if not specified, do not save)')
    args = parser.parse_args()

    # a non-positive --days would slice the wrong columns, and a --test
    # outside [0, days) would leave nothing to train on
    if args.days < 1:
        parser.error('--days must be at least 1')
    if not 0 <= args.test < args.days:
        parser.error('--test must be at least 0 and smaller than --days')

    selected = Constants.DATA_OPTIONS[args.which]
    args.url = Constants.DATA_URL_BASE + selected['file']
    args.headers = selected['headers']
    args.description = selected['description']

    # fall back to the data set's default region; some data sets have none, in
    # which case the region must be given explicitly
    args.major = args.major or args.headers['major_default']
    if not args.major:
        parser.error(f'--major is required when --which is {args.which}')

    return args
def get_data(args):
    '''Return the CSV row for the requested region, or None if there is none.

    The CSV at args.url is downloaded with requests and cached in a file named
    after the MD5 of the URL (relative to the current directory). A cache file
    younger than 10 minutes is used instead of downloading again.

    Arguments:
        args -- namespace providing "url", "headers" (a mapping with the
                "major" and "minor" column names), "major" and "minor" (the
                region names to look for, compared case-insensitively; None
                is treated like an empty string).

    Returns the first matching row as produced by csv.DictReader, or None.

    Raises whatever requests raises for a failed download (including an HTTP
    error status), and KeyError if the CSV lacks the expected columns.
    '''
    import io  # local import: only this function needs it

    cached_data_file = 'cache_' + hashlib.md5(args.url.encode('utf-8')).hexdigest() + '.csv'

    raw_csv = None
    try:
        cache_age = time.time() - os.stat(cached_data_file).st_mtime
        if cache_age < 10 * 60:  # cache data for 10 minutes
            print('reading cached data from', cached_data_file)
            with open(cached_data_file, encoding='utf-8') as cache:
                raw_csv = cache.read()
        else:
            print('cached data is too old')
    except (OSError, UnicodeError) as error:
        # a missing or unreadable cache is not fatal; just download again
        print('exception while getting cached data', error)

    if not raw_csv:
        print(f'retrieving data from {args.url}')
        response = requests.get(args.url, timeout=30)
        # do not cache (or try to parse) an error page
        response.raise_for_status()
        raw_csv = response.text
        with open(cached_data_file, 'w', encoding='utf-8') as cache:
            cache.write(raw_csv)

    major_column = args.headers['major']
    minor_column = args.headers['minor']
    wanted_major = (args.major or '').lower()
    wanted_minor = (args.minor or '').lower()

    # let the csv module find the line breaks itself so that quoted fields
    # containing newlines are parsed correctly
    for row in csv.DictReader(io.StringIO(raw_csv, newline='')):
        # short rows are padded with None by DictReader
        row_major = (row[major_column] or '').lower()
        row_minor = (row[minor_column] or '').lower()
        if row_major == wanted_major and row_minor == wanted_minor:
            return row

    return None
def fit_exponential_model(args, log_train):
    '''Fit a linear model to the log-data, i.e. log(y) ~ theta[0] + theta[1] * x

    Arguments:
        args      -- parsed command line arguments; kept for interface
                     compatibility, the number of samples is taken from
                     log_train itself so the design matrix always matches it.
        log_train -- sequence of log-counts, one per consecutive day; entry i
                     is paired with x = i.

    Returns the array [theta[0], theta[1]] (intercept and slope) of the
    least-squares fit.

    Raises ValueError if log_train is empty.
    '''
    n_samples = len(log_train)
    if n_samples == 0:
        raise ValueError('cannot fit a model without any training data')

    # data matrix: a column of ones (intercept) and the day index
    x_matrix = numpy.array([[1, day] for day in range(n_samples)])
    x_transposed = numpy.transpose(x_matrix)

    # normal equations: theta = pinv(X^T X) X^T y
    x_t_x = numpy.matmul(x_transposed, x_matrix)
    xtx_inv_xt = numpy.matmul(numpy.linalg.pinv(x_t_x), x_transposed)
    theta = numpy.matmul(xtx_inv_xt, numpy.transpose(log_train))
    return theta
def main():
    '''Entry point: read the options, load the series, fit the model, report
    the predictions and draw the plot.'''
    args = parse_args()

    # look up the row for the requested region
    counts = get_data(args)
    if counts is None:
        print("Could not find the data you were looking for... bailing.")
        return

    def safe_log(count):
        '''Natural log of a count, with -1 standing in for a count of zero.'''
        return numpy.log(count) if count else -1

    # the last --days values of the row, split into a training prefix and a
    # test suffix of --test values
    recent = [int(value) for value in list(counts.values())[-args.days:]]
    split_at = len(recent) - args.test
    train, test = recent[:split_at], recent[split_at:]

    log_train = [safe_log(count) for count in train]  # this is our "y"
    log_test = [safe_log(count) for count in test]
    print('train', len(log_train), log_train)
    print('test', len(log_test), log_test)

    theta = fit_exponential_model(args, log_train)
    print('estimated parameters', theta)

    # report the % daily growth and the predictions derived from it
    growth_str = make_predictions(args, theta, recent)

    # plot raw counts on a linear axis, log-counts otherwise
    if args.linear_y:
        y_data = train + test
    else:
        y_data = log_train + log_test
    plot_data(args, counts, theta, y_data, growth_str)
def make_predictions(args, theta, recent):
    '''Print the model's daily growth rate and the counts it implies.

    The daily growth factor is exp(theta[1]). Starting from the most recent
    count (recent[-1]), one line is printed for each day offset 0 through 29
    with the count projected by compounding that factor.

    Returns the daily growth in percent, formatted with one decimal place.
    '''
    daily_factor = numpy.exp(theta[1])
    growth_str = f'{(daily_factor - 1) * 100:0.1f}'
    latest = recent[-1]

    print(f'the most recent count of {args.description} is {latest}')
    print(f'the {args.description} are growing {growth_str}% each day')

    for day in range(30):
        projected = int((daily_factor ** day) * latest)
        print(f'in {day} days, that means {projected} {args.description}')

    return growth_str
def plot_data(args, counts, theta, y_data, growth_str):
    '''Plot the data and model using matplotlib.

    Arguments:
        args       -- parsed command line arguments; "days", "test",
                      "linear_y", "save", "description" and "headers" are read.
        counts     -- the CSV row being plotted; its last key is shown as the
                      "as of" date and its major/minor columns name the region.
        theta      -- [intercept, slope] of the fit in log space.
        y_data     -- the args.days values to plot (raw counts when
                      args.linear_y is set, log-counts otherwise).
        growth_str -- daily growth percentage, shown in the title.

    The figure is written to args.save when that is set, and is always shown.
    '''
    plt = matplotlib.pyplot
    days = range(args.days)

    # model prediction for every plotted day, in the same space as y_data
    y_hat = [theta[0] + theta[1] * x for x in days]
    if args.linear_y:
        y_hat = numpy.exp(y_hat)

    # labels are attached to the artists themselves so the legend cannot get
    # out of step with what was drawn
    plt.plot(days, y_data, 'x-', label='reported cases')
    plt.plot(days, y_hat, label='model fit')

    # x axis: every other position, labelled with how many days ago it was
    x_ticks = list(days)
    plt.xticks(x_ticks[::2], x_ticks[::-2])
    plt.xlabel('days ago')

    if not args.linear_y:
        # the data is in natural-log units; put ticks at powers of ten near
        # the visible range and label them with the real counts
        y_min, y_max = plt.ylim()
        log_10 = numpy.log(10)
        decades = [i for i in range(10)
                   if y_min <= (i + 1) * log_10 and (i - 1) * log_10 <= y_max]
        plt.yticks([i * log_10 for i in decades], [10 ** i for i in decades])

    # no stray double/trailing space in the label
    if args.linear_y:
        plt.ylabel(args.description)
    else:
        plt.ylabel(f'{args.description} (log scale)')

    most_recent_date = list(counts)[-1]
    minor_name = ''
    if counts[args.headers['minor']]:
        minor_name = ' (' + counts[args.headers['minor']] + ')'
    name = '{}{}'.format(counts[args.headers['major']], minor_name)
    plt.title(f'COVID-19 {args.description} in {name}: '
              f'{growth_str}% daily growth (as of {most_recent_date})')

    # shade the days used for fitting and, if any, the held-out days
    train_end = args.days - args.test - 1
    plt.axvspan(0, train_end, facecolor='y', alpha=0.3, label='training region')
    if args.test:
        plt.axvspan(train_end, args.days - 1, facecolor='g', alpha=0.3,
                    label='prediction region')
    plt.legend()

    # make sure all the labels fit on the plot
    plt.tight_layout()

    if args.save:
        print('saving to', args.save)
        plt.savefig(args.save, dpi=300)

    plt.show()
# Run the tool only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment