import argparse

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Date whose cumulative counts are used as the 2020 year-end totals.
REPORT_DATE = '2020-12-31'

# Our World in Data COVID-19 export read by main(), relative to the working directory.
DATA_PATH = 'data/owid-covid-data.csv'

# Settings shared by both figures.
CFR_YLIM = (0, 0.075)
FIGURE_DPI = 250.0


def parse_args(argv=None):
    """Parse the command line: two positional output paths for the figures.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with attributes ``scatter_a`` and ``scatter_b``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('scatter_a', help='output path of figure 1')
    parser.add_argument('scatter_b', help='output path of figure 2')
    return parser.parse_args(argv)


def compute_case_fatality_rate(covid_data, date=REPORT_DATE):
    """Return per-location totals and case fatality rate for a single date.

    Args:
        covid_data: DataFrame with at least the columns ``location``, ``date``,
            ``total_cases`` and ``total_deaths``. The ``date`` column is compared
            with ``==`` against *date*, so both must use the same representation
            (e.g. ``'YYYY-MM-DD'`` strings as produced by a plain ``read_csv``).
        date: The date whose cumulative figures are selected.

    Returns:
        A DataFrame indexed by ``location`` with columns ``case_fatality_rate``,
        ``total_deaths`` and ``total_cases`` (in that order), where
        ``case_fatality_rate = total_deaths / total_cases``. Zero or missing
        counts produce NaN (or inf) instead of raising.
    """
    # Select the date and the needed columns in one step; deriving cases, deaths
    # and the rate from this single frame keeps every row aligned by location.
    snapshot = covid_data.loc[
        covid_data['date'] == date, ['location', 'total_cases', 'total_deaths']
    ].set_index('location')
    # assign() returns a new frame, so no write goes through a slice of covid_data.
    snapshot = snapshot.assign(
        case_fatality_rate=snapshot['total_deaths'] / snapshot['total_cases']
    )
    return snapshot[['case_fatality_rate', 'total_deaths', 'total_cases']]


def _save_scatter(x, y, *, xlim, xlabel, output_path):
    """Draw *y* (case fatality rate) against *x* on a new figure and save it."""
    # A fresh figure per plot: reusing pyplot's current axes would overlay the
    # previous scatter and silently inherit its limits and labels.
    fig, ax = plt.subplots()
    try:
        ax.scatter(x, y)
        ax.set_xlim(*xlim)
        ax.set_ylim(*CFR_YLIM)
        ax.set_ylabel('Case fatality rate')
        ax.set_xlabel(xlabel)
        ax.grid(True)
        fig.savefig(output_path, dpi=FIGURE_DPI)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)


def main(argv=None):
    """Plot 2020 case fatality rate against total (and log10 total) cases.

    Args:
        argv: Argument list forwarded to ``parse_args``; ``None`` uses ``sys.argv``.
    """
    args = parse_args(argv)

    # NOTE(review): ISO-8859-1 decodes any byte sequence, so reading never fails;
    # if the file is really UTF-8, non-ASCII location names come out garbled.
    # Location names are not drawn on either plot.
    covid_data = pd.read_csv(DATA_PATH, encoding='ISO-8859-1')

    # NOTE(review): every location present on REPORT_DATE is kept. If the file
    # also holds aggregate rows (e.g. world or continent totals), they are
    # plotted as well — confirm that is intended.
    aggregated_data = compute_case_fatality_rate(covid_data)
    total_cases = aggregated_data['total_cases']
    case_fatality_rate = aggregated_data['case_fatality_rate']

    # Case fatality rate against the number of cases; the x-range excludes some
    # locations because they are outliers.
    _save_scatter(
        total_cases,
        case_fatality_rate,
        xlim=(0, 3000000),
        xlabel='Total confirmed cases in 2020',
        output_path=args.scatter_a,
    )

    # Case fatality rate against log10 of the number of cases. Non-positive
    # counts have no logarithm; mask them to NaN so they are skipped rather
    # than turned into -inf with a runtime warning.
    log_total_cases = np.log10(total_cases.where(total_cases > 0))
    _save_scatter(
        log_total_cases,
        case_fatality_rate,
        xlim=(0, 9),
        xlabel='Log of total confirmed cases in 2020',
        output_path=args.scatter_b,
    )


if __name__ == '__main__':
    main()