"""Plot COVID-19 case fatality rate against total confirmed cases in 2020.

Reads ``data/owid-covid-data.csv``, keeps only the rows dated 2020-12-31
and writes two scatter plots whose output paths are given on the command
line:

* ``scatter_a``: case fatality rate vs. total confirmed cases;
* ``scatter_b``: case fatality rate vs. log10 of total confirmed cases.
"""
import argparse

import numpy as np
import pandas as pd

DATA_PATH = 'data/owid-covid-data.csv'
SNAPSHOT_DATE = '2020-12-31'
FIGURE_DPI = 250.0
# Shared y-range for both plots; together with the per-plot x-range it leaves
# some locations off-chart as outliers.
CFR_YLIM = (0, 0.075)


def compute_case_fatality(covid_data, date=SNAPSHOT_DATE):
    """Return per-location totals and case fatality rate for one date.

    Parameters
    ----------
    covid_data : pandas.DataFrame
        Must contain ``date`` (compared as strings, as read by
        ``pd.read_csv`` without date parsing), ``location``, ``total_cases``
        and ``total_deaths`` columns.
    date : str
        Only rows whose ``date`` equals this value are used.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``location`` with columns ``total_cases``,
        ``total_deaths`` and ``case_fatality_rate``
        (``total_deaths / total_cases``; NaN where cases are missing or
        not positive).
    """
    columns = ['location', 'total_cases', 'total_deaths']
    snapshot = covid_data.loc[covid_data['date'] == date, columns]
    # Keep cases, deaths and the derived rate in one frame so the x and y
    # values of the plots are always paired by the same row.
    aggregated = snapshot.set_index('location')
    cases = aggregated['total_cases']
    # No confirmed cases -> no defined fatality rate; avoid 0/0 and x/0 (inf).
    positive_cases = cases.where(cases > 0)
    aggregated['case_fatality_rate'] = (
        aggregated['total_deaths'] / positive_cases)
    return aggregated


def _save_scatter(x, y, xlabel, xlim, output_path):
    """Draw one case-fatality scatter on a fresh figure and save it."""
    # Imported lazily so compute_case_fatality can be used without a
    # plotting backend installed.
    import matplotlib.pyplot as plt

    # A new figure per plot: reusing the implicit current figure would draw
    # the second scatter on top of the first one's points.
    fig, ax = plt.subplots()
    ax.scatter(x, y)
    ax.set_xlim(*xlim)
    ax.set_ylim(*CFR_YLIM)
    ax.set_ylabel('Case fatality rate')
    ax.set_xlabel(xlabel)
    ax.grid(True)
    fig.savefig(output_path, dpi=FIGURE_DPI)
    plt.close(fig)


def main():
    """Parse the output paths, load the OWID data and write both figures."""
    parser = argparse.ArgumentParser()
    parser.add_argument('scatter_a', help='output path of figure 1')
    parser.add_argument('scatter_b', help='output path of figure 2')
    args = parser.parse_args()

    all_covid_data = pd.read_csv(DATA_PATH, encoding='ISO-8859-1')
    aggregated = compute_case_fatality(all_covid_data)
    total_cases = aggregated['total_cases']
    case_fatality_rate = aggregated['case_fatality_rate']

    # Linear x-axis; the x-limit excludes some locations as outliers.
    _save_scatter(total_cases, case_fatality_rate,
                  'Total confirmed cases in 2020', (0, 3000000),
                  args.scatter_a)

    # log10 is undefined for zero cases, so mask those points out instead of
    # producing -inf values.
    log_cases = np.log10(total_cases.where(total_cases > 0))
    _save_scatter(log_cases, case_fatality_rate,
                  'Log of total confirmed cases in 2020', (0, 9),
                  args.scatter_b)


if __name__ == '__main__':
    main()