"""Compare the age distribution of U.S. Congress members with that of the
general U.S. population, year by year.

Congress data is read from a local SQLite database; population data is read
from a CSV extract with ``YEAR``, ``AGE`` and ``PERWT`` columns (``PERWT`` is
used as a per-row histogram weight).
"""
from collections import namedtuple
from contextlib import closing
import sqlite3

import numpy as np
import pandas as pd

# SQLite database queried by get_stats_congress (relative to the working dir).
CONGRESS_DB_PATH = 'us_congress_members.sqlite3'

# Summary of one age histogram. ``hist`` is the (pdf, bins) pair produced by
# np.histogram; mean/median/quartiles are in the same units as the bin edges.
Stats = namedtuple('Stats', ['hist', 'mean', 'median', 'quart_high', 'quart_low'])


def percentile_from_pdf(pdf, bin_centers, percentile=0.5):
    """Return the center of the first bin whose cumulative mass exceeds *percentile*.

    Args:
        pdf: Per-bin probability masses (expected to sum to 1). For a
            ``density=True`` histogram these are densities times bin widths.
        bin_centers: Center of each bin, same length as *pdf*.
        percentile: Target cumulative fraction.

    Returns:
        The center of the bin containing the requested percentile; no
        interpolation inside the bin is performed.

    Raises:
        ValueError: If the cumulative mass never exceeds *percentile*
            (e.g. an all-zero or NaN histogram).
    """
    cdf = 0
    # Every bin is visited, including the last one.
    for mass, center in zip(pdf, bin_centers):
        cdf += mass
        if cdf > percentile:
            return center  # TODO: interpolate within the bin
    raise ValueError(f"couldn't find percentile: {percentile}, cdf: {cdf}")


def pdf_stats(pdf, bins):
    """Summarize a density histogram as mean, median and quartiles.

    Args:
        pdf: Density values, as returned by ``np.histogram(..., density=True)``.
        bins: The matching bin edges (``len(pdf) + 1`` values).

    Returns:
        A ``Stats`` tuple whose ``hist`` field holds ``(pdf, bins)`` unchanged.
    """
    edges = np.asarray(bins, dtype=float)
    bin_centers = (edges[:-1] + edges[1:]) / 2
    # Densities only sum to 1 after multiplying by the bin widths; converting to
    # per-bin masses keeps the statistics correct for non-unit or uneven bins.
    masses = np.asarray(pdf, dtype=float) * np.diff(edges)
    mean = np.average(bin_centers, weights=masses)
    median = percentile_from_pdf(masses, bin_centers)
    quart_low = percentile_from_pdf(masses, bin_centers, 0.25)
    quart_high = percentile_from_pdf(masses, bin_centers, 0.75)
    return Stats((pdf, bins), mean, median, quart_high, quart_low)


def get_stats_congress(year, age_max=110, parties=None, states=None):
    """Return age ``Stats`` for Representatives and Senators in *year*.

    Age is computed as ``year - yob`` and histogrammed into ``age_max``
    one-year bins over ``[0, age_max)``; ages outside that range are dropped.

    Args:
        year: Value matched against the ``congress`` column of ``Member``.
        age_max: Upper bound of the age histogram.
        parties: Optional iterable of party names to restrict to.
        states: Optional iterable of state values to restrict to.

    Raises:
        ValueError: If the query returns no members.
    """
    # NOTE(review): the ``congress`` column is compared directly against
    # ``year``; confirm it stores years rather than congress numbers.
    query = (
        "SELECT yob, position, party FROM Member "
        "WHERE congress = ? AND position IN ('Representative', 'Senator')"
    )
    # Bound parameters instead of string interpolation: no SQL injection, and
    # no reliance on SQLite's double-quoted-string fallback.
    params = [int(year)]
    if parties:
        parties = list(parties)
        query += f" AND party IN ({', '.join('?' * len(parties))})"
        params.extend(parties)
    if states:
        states = list(states)
        query += f" AND state IN ({', '.join('?' * len(states))})"
        params.extend(states)

    with closing(sqlite3.connect(CONGRESS_DB_PATH)) as con:
        data = pd.read_sql_query(query, con, params=params)
    if data.empty:
        raise ValueError(f"no members of congress found for {year}")

    data['age'] = year - data.yob
    pdf, bins = np.histogram(data.age, bins=age_max, range=(0, age_max),
                             density=True)
    return pdf_stats(pdf, bins)


def get_stats_genpop(year_data, age_max=110):
    """Return age ``Stats`` for one year of population data.

    *year_data* must provide ``AGE`` and ``PERWT`` columns; ``PERWT`` weights
    each row. Ages are binned into ``age_max`` one-year bins over
    ``[0, age_max)``.
    """
    pdf, bins = np.histogram(year_data.AGE, bins=age_max, range=(0, age_max),
                             weights=year_data.PERWT, density=True)
    return pdf_stats(pdf, bins)


def plot_pdf(genpop_stats, congress_stats):
    """Plot the population and Congress age densities on one chart and show it."""
    # Imported lazily so the statistics helpers work without matplotlib.
    import matplotlib.pyplot as plt

    genpop_pdf, genpop_bins = genpop_stats.hist
    congress_pdf, congress_bins = congress_stats.hist
    # x-values are the left bin edges.
    plt.plot(genpop_bins[:-1], genpop_pdf, 'r.', label='U.S. Population')
    plt.plot(congress_bins[:-1], congress_pdf, 'b.', label='Congress')
    plt.legend()
    plt.show()


def _yearly_series(stats_by_year):
    """Split a ``{year: Stats}`` mapping into years, medians, upper and lower quartiles."""
    years = list(stats_by_year)
    values = list(stats_by_year.values())
    medians = [s.median for s in values]
    quart_highs = [s.quart_high for s in values]
    quart_lows = [s.quart_low for s in values]
    return years, medians, quart_highs, quart_lows


def plot_yearly_stats(congress_stats, genpop_stats):
    """Plot median age per year with a shaded interquartile band for each group.

    Both arguments map year -> ``Stats``; years are plotted in the mapping's
    iteration order.
    """
    import matplotlib.pyplot as plt

    for stats, color, label in ((genpop_stats, 'b', 'U.S. Population'),
                                (congress_stats, 'r', 'Congress')):
        years, medians, quart_highs, quart_lows = _yearly_series(stats)
        # One band from lower to upper quartile (the median line sits inside it).
        plt.fill_between(years, quart_lows, quart_highs, color=color, alpha=0.3)
        plt.plot(years, medians, color, label=label)
    plt.legend()
    plt.grid()
    plt.ylabel('Age')
    plt.xlabel('Year')
    plt.show()


def main():
    """Compute yearly age statistics for Congress and the population, then plot them."""
    stats_congress = {year: get_stats_congress(year) for year in range(1850, 2017)}
    people = pd.read_csv('usa_00001.csv', usecols=['YEAR', 'AGE', 'PERWT'],
                         index_col='YEAR')
    # groupby does one pass and always yields DataFrames, unlike repeated
    # .loc lookups on a non-unique index.
    stats_genpop = {year: get_stats_genpop(year_data)
                    for year, year_data in people.groupby(level='YEAR')}
    plot_yearly_stats(stats_congress, stats_genpop)


if __name__ == '__main__':
    main()