Created
March 20, 2020 08:38
-
-
Save monodera/f9a250bdf73827cb3a7ef0a133824030 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# %% | |
from datetime import datetime, timedelta, date | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
from functools import reduce | |
import colorcet as cc | |
from bokeh.plotting import figure, output_file, show | |
from bokeh.models import ColumnDataSource, Range1d | |
from bokeh.palettes import viridis, cividis, all_palettes | |
from bokeh.layouts import column, row, gridplot | |
# (Click on legend entries to highlight the corresponding lines.)
def _draw_highlighted_line(pp, source, color, legend_column):
    """Draw one thick, coloured, initially muted line of ``Count`` vs ``Date`` on *pp*."""
    pp.line(
        x="Date",
        y="Count",
        source=source,
        line_width=3,
        color=color,
        alpha=0.8,
        muted_color=color,
        muted_alpha=0.4,
        muted_line_width=2,
        muted=True,
        # The legend label is taken from this column of the source.
        legend_group=legend_column,
    )


def plot_cases(
    corona_sums_states,
    corona_sums_countries,
    colors,
    states_to_plot,
    countries_to_plot,
    case="Confirmed",
    ymin=1,
    ymax=1e5,
    thresh_confirmed=1000,
    xmin=datetime(2020, 1, 1),
    xmax=None,
):
    """Build a log-scale Bokeh time-series figure of COVID-19 counts.

    Every name in ``states_to_plot`` and ``countries_to_plot`` is drawn as a
    thick coloured line. Every other country whose peak "Confirmed" count
    exceeds ``thresh_confirmed`` is drawn as a thin gray line. All lines start
    muted, and the legend's click policy is "mute", so clicking an entry
    highlights that line.

    Parameters
    ----------
    corona_sums_states : pandas.DataFrame
        Long-format table; expected columns "type", "Date", "Province/State",
        "Count".
    corona_sums_countries : pandas.DataFrame
        Long-format table; expected columns "type", "Date", "Country/Region",
        "Count".
    colors : sequence of str
        Line colours, indexed in order: states first, then countries. Must
        hold at least ``len(states_to_plot) + len(countries_to_plot)`` entries,
        otherwise an IndexError is raised.
    states_to_plot, countries_to_plot : sequence of str
        Regions to draw as highlighted lines.
    case : str
        Value of the "type" column to plot (e.g. "Confirmed", "Deaths").
    ymin, ymax : float
        Y-axis limits (log scale).
    thresh_confirmed : float
        A non-highlighted country is drawn in gray only if its maximum
        *Confirmed* count (whatever ``case`` is) exceeds this value.
    xmin : datetime
        Left x-axis limit.
    xmax : datetime or None
        Right x-axis limit; ``None`` means the current time at call time.

    Returns
    -------
    The Bokeh figure.
    """
    pp = figure(
        width=800,
        height=600,
        x_axis_type="datetime",
        y_axis_type="log",
        x_axis_label="Date",
        y_axis_label="Number of cases",
    )
    pp.title.text_font_size = "16pt"
    pp.xaxis.axis_label_text_font_size = "16pt"
    pp.yaxis.axis_label_text_font_size = "16pt"

    # Highlighted regions: states first, then countries, sharing one colour index.
    highlighted = [
        (corona_sums_states, "Province/State", state) for state in states_to_plot
    ] + [
        (corona_sums_countries, "Country/Region", country)
        for country in countries_to_plot
    ]
    for i, (table, column, name) in enumerate(highlighted):
        selected = table[(table[column] == name) & (table["type"] == case)]
        _draw_highlighted_line(pp, selected, colors[i], column)

    # Background countries, in order of first appearance (sort=False).
    highlighted_countries = set(countries_to_plot)
    for country, df_country in corona_sums_countries.groupby(
        "Country/Region", sort=False
    ):
        if country in highlighted_countries:
            continue
        peak_confirmed = df_country.loc[df_country["type"] == "Confirmed", "Count"].max()
        # Written as "not >" so that a NaN peak (no Confirmed rows) is skipped.
        if not peak_confirmed > thresh_confirmed:
            continue
        df_case = df_country[df_country["type"] == case]
        # Report each background country and its peak count for this case.
        print(country, df_case["Count"].max())
        pp.line(
            x="Date",
            y="Count",
            source=df_case,
            line_width=1.0,
            color="gray",
            alpha=0.8,
            muted_color="gray",
            muted_alpha=0.5,
            muted_line_width=0.5,
            muted=True,
            legend_group="Country/Region",
        )

    pp.legend.location = "top_left"
    pp.legend.click_policy = "mute"
    pp.legend.title = "Click to highlight"
    pp.x_range = Range1d(xmin, datetime.now() if xmax is None else xmax)
    pp.y_range = Range1d(ymin, ymax)
    return pp
# reference : https://medium.com/analytics-vidhya/mapping-the-spread-of-coronavirus-covid-19-d7830c4282e | |
# %%
# Load the three wide time-series tables (one row per region, one column per
# date) in the order confirmed, deaths, recovered.
_time_series_csvs = (
    "../COVID-19/csse_covid_19_data/csse_covid_19_time_series/time_series_19-covid-Confirmed.csv",
    "../COVID-19/csse_covid_19_data/csse_covid_19_time_series/time_series_19-covid-Deaths.csv",
    "../COVID-19/csse_covid_19_data/csse_covid_19_time_series/time_series_19-covid-Recovered.csv",
)
df_confirmed, df_deaths, df_recovered = (
    pd.read_csv(csv_path) for csv_path in _time_series_csvs
)

# Alphabetical list of the distinct countries in the confirmed table.
countries = list(df_confirmed["Country/Region"].unique())
countries.sort()
# %%
# 1.2 Tidying the data
# Reshape each wide table (one column per date) into long format using
# pandas.melt() (similar to gather() in R's tidyr).
# The first four columns of df_confirmed are treated as identifier columns;
# all remaining columns are dates.
id_list = df_confirmed.columns.to_list()[:4]


def _melt_by_date(df_wide, value_name):
    """Melt *df_wide* into the columns ``id_list + ["Date", value_name]``.

    The date columns are taken from *df_wide* itself (every column after the
    first four), not from df_confirmed. A table that covers a different set of
    dates therefore melts cleanly instead of raising KeyError.
    """
    return pd.melt(
        df_wide,
        id_vars=id_list,
        value_vars=df_wide.columns.to_list()[4:],
        var_name="Date",
        value_name=value_name,
    )


confirmed_tidy = _melt_by_date(df_confirmed, "Confirmed")
deaths_tidy = _melt_by_date(df_deaths, "Deaths")
recovered_tidy = _melt_by_date(df_recovered, "Recovered")
print(recovered_tidy.head(10))

# Active = Confirmed - Deaths - Recovered.
# Rows are matched on the identifier columns plus the date, not by positional
# index: the three CSV files need not list regions in the same order or have
# the same number of rows. The row set follows recovered_tidy.
key_cols = id_list + ["Date"]
active_tidy = recovered_tidy.merge(confirmed_tidy, on=key_cols, how="left").merge(
    deaths_tidy, on=key_cols, how="left"
)
active_tidy["Active"] = (
    active_tidy["Confirmed"] - active_tidy["Deaths"] - active_tidy["Recovered"]
)
active_tidy = active_tidy[key_cols + ["Active"]]
# %%
# 1.3 Merging the four long tables into one table with one row per
# region and date, plus one column per count type.
# The keys are named explicitly so the later melt does not depend on the
# column order produced by the merges.
id_vars = id_list + ["Date"]
data_frames = [confirmed_tidy, deaths_tidy, recovered_tidy, active_tidy]
df_corona = reduce(
    lambda left, right: pd.merge(left, right, on=id_vars, how="outer"),
    data_frames,
)

# 1.4 Each row should only represent one observation: melt the four count
# columns into ("type", "Count") pairs.
data_type = ["Confirmed", "Deaths", "Recovered", "Active"]
df_corona = pd.melt(
    df_corona,
    id_vars=id_vars,
    value_vars=data_type,
    var_name="type",
    value_name="Count",
)
# Dates arrive as strings such as "1/22/20" (month/day/two-digit year).
df_corona["Date"] = pd.to_datetime(df_corona["Date"], format="%m/%d/%y", errors="raise")

# Summed counts per type and date, once by province/state and once by country.
corona_sums_states = df_corona.groupby(
    ["type", "Date", "Province/State"], as_index=False
).agg({"Count": "sum"})
corona_sums_countries = df_corona.groupby(
    ["type", "Date", "Country/Region"], as_index=False
).agg({"Count": "sum"})
# %%
# Regions drawn as coloured, highlightable lines. Every other country above
# the confirmed-case threshold is drawn in gray by plot_cases().
states_to_plot = ["California", "Hawaii"]
countries_to_plot = [
    "China",
    "Italy",
    "Germany",
    "Japan",
    "Korea, South",
    "Spain",
    "US",
]
n_plot = len(states_to_plot) + len(countries_to_plot)
# One colour per highlighted region, taken in order from colorcet's
# glasbey_dark list (one spare entry is included).
colors = cc.glasbey_dark[: n_plot + 1]

# (count type, figure title) for each panel, top to bottom.
panels = [
    ("Confirmed", 'Number of "confirmed" COVID-19 cases'),
    ("Deaths", 'Number of "death" COVID-19 cases'),
    ("Recovered", 'Number of "recovered" COVID-19 cases'),
    ("Active", 'Number of "currently active" COVID-19 cases'),
]
figures = []
for case, title in panels:
    fig = plot_cases(
        corona_sums_states,
        corona_sums_countries,
        colors,
        states_to_plot,
        countries_to_plot,
        case=case,
        ymin=1,
        ymax=1e5,
        thresh_confirmed=1000,
    )
    fig.title.text = title
    figures.append(fig)
# Keep the individual panels addressable from interactive cells.
p1, p2, p3, p4 = figures

output_file("../gist/index.html", title="COVID-19 Cases")
p = column(*figures)
show(p)
# %%
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment