"""Scatter-plot case fatality rate against total confirmed cases.

Reads ``data/owid-covid-data.csv``, keeps only the rows dated 2020-12-31,
computes ``total_deaths / total_cases`` for every location and writes two
figures:

* figure 1 -- case fatality rate against the raw number of cases;
* figure 2 -- case fatality rate against log10 of the number of cases.

Usage::

    python parta2.py <scatter_a> <scatter_b>
"""
import argparse

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

DATA_PATH = 'data/owid-covid-data.csv'
SNAPSHOT_DATE = '2020-12-31'
FIGURE_DPI = 250.0
# Both figures share the same y-range so they can be compared side by side.
CFR_YLIM = (0, 0.075)
CFR_LABEL = 'Case fatality rate'


def parse_args():
    """Parse the two positional output paths from the command line."""
    parser = argparse.ArgumentParser()
    parser.add_argument('scatter_a', help='output path of figure 1')
    parser.add_argument('scatter_b', help='output path of figure 2')
    return parser.parse_args()


def load_snapshot(path=DATA_PATH, date=SNAPSHOT_DATE):
    """Return per-location totals and case fatality rate for one date.

    Args:
        path: CSV file to read; must contain the columns ``date``,
            ``location``, ``total_cases`` and ``total_deaths``.
        date: value of the ``date`` column to keep (compared as a string).

    Returns:
        A DataFrame with the columns ``location``, ``case_fatality_rate``,
        ``total_deaths`` and ``total_cases``, one row per CSV row that
        matches ``date``.
    """
    table = pd.read_csv(path, encoding='ISO-8859-1')

    # Select the needed columns in one step; ``.copy()`` makes the result an
    # independent frame so adding a column below does not warn.
    columns = ['location', 'total_cases', 'total_deaths']
    snapshot = table.loc[table['date'] == date, columns].copy()

    # A location with zero recorded cases has no defined fatality rate;
    # masking the denominator yields NaN there instead of inf.
    cases = snapshot['total_cases']
    snapshot['case_fatality_rate'] = (
        snapshot['total_deaths'] / cases.where(cases > 0)
    )

    return snapshot[
        ['location', 'case_fatality_rate', 'total_deaths', 'total_cases']
    ]


def save_scatter(x, y, xlim, xlabel, path):
    """Draw one scatter plot on a fresh figure and save it to ``path``.

    Args:
        x: x-coordinates of the points.
        y: y-coordinates (case fatality rates) of the points.
        xlim: ``(low, high)`` limits of the x-axis; points outside this
            range or outside ``CFR_YLIM`` are not visible.
        xlabel: label of the x-axis.
        path: output file path passed to ``Figure.savefig``.
    """
    # A new figure per plot keeps points from an earlier scatter out of
    # this image.
    fig, ax = plt.subplots()
    ax.scatter(x, y)
    ax.set_xlim(*xlim)
    ax.set_ylim(*CFR_YLIM)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(CFR_LABEL)
    ax.grid(True)
    fig.savefig(path, dpi=FIGURE_DPI)
    plt.close(fig)


def main():
    """Build both scatter plots from the 2020-12-31 snapshot."""
    args = parse_args()
    snapshot = load_snapshot()

    cases = snapshot['total_cases']
    fatality_rate = snapshot['case_fatality_rate']

    # plot case-fatality-rate against number of cases
    # the axis limits hide some locations as they are outliers
    save_scatter(
        cases,
        fatality_rate,
        (0, 3000000),
        'Total confirmed cases in 2020',
        args.scatter_a,
    )

    # plot case-fatality-rate against the log of the number of cases;
    # non-positive counts are masked because log10 is undefined for them
    log_cases = np.log10(cases.where(cases > 0))
    save_scatter(
        log_cases,
        fatality_rate,
        (0, 9),
        'Log of total confirmed cases in 2020',
        args.scatter_b,
    )


if __name__ == '__main__':
    main()