analyze.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
from collections import namedtuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
  4. def percentile_from_pdf(pdf, bin_centers, percentile=0.5):
  5. cdf = 0
  6. for pdf_val, bin_low in zip(pdf, bin_centers[:-1]):
  7. cdf += pdf_val
  8. if cdf > percentile:
  9. return bin_low # TODO: Interpolate
  10. print(pdf, bin_centers)
  11. raise ValueError(f"couldn't find percentile: {percentile}, cdf: {cdf}" )
  12. def pdf_stats(pdf, bins):
  13. from collections import namedtuple
  14. Stats = namedtuple('Stats', ['hist', 'mean', 'median', 'quart_high', 'quart_low'])
  15. bin_centers = (bins[:-1] + bins[1:])/2
  16. mean = np.average(bin_centers, weights=pdf)
  17. median = percentile_from_pdf(pdf, bin_centers)
  18. quart_low = percentile_from_pdf(pdf, bin_centers, 0.25)
  19. quart_high = percentile_from_pdf(pdf, bin_centers, 0.75)
  20. return Stats((pdf, bins), mean, median, quart_high, quart_low)
  21. def get_stats_congress(year, age_max=110, parties=None, states=None):
  22. query = f'''\
  23. SELECT yob, position, party FROM Member
  24. WHERE congress={year} AND position IN ("Representative", "Senator")
  25. '''
  26. if parties:
  27. query += ' AND party IN (' + ", ".join(f'"{party}"' for party in parties) + ')'
  28. if states:
  29. query += ' AND state IN (' + ", ".join(f'"{state}"' for state in states) + ')'
  30. data = pd.read_sql_query(query, 'sqlite:///us_congress_members.sqlite3')
  31. data['age'] = year - data.yob
  32. pdf, bins = np.histogram(data.age, bins=age_max, range=(0, age_max), density=True)
  33. return pdf_stats(pdf, bins)
  34. def get_stats_genpop(year_data, age_max=110):
  35. pdf, bins = np.histogram(year_data.AGE, bins=age_max, range=(0, age_max), weights=year_data.PERWT, density=True)
  36. return pdf_stats(pdf, bins)
  37. def plot_pdf(genpop_stats, congress_stats):
  38. import matplotlib.pyplot as plt
  39. genpop_pdf, genpop_bins = genpop_stats.hist
  40. congress_pdf, congress_bins = congress_stats.hist
  41. plt.plot(genpop_bins[:-1], genpop_pdf, 'r.', label='U.S. Population')
  42. plt.plot(congress_bins[:-1], congress_pdf, 'b.', label='Congress')
  43. plt.legend()
  44. plt.show()
  45. def plot_yearly_stats(congress_stats, genpop_stats):
  46. congress_years = []
  47. congress_medians = []
  48. congress_quart_highs = []
  49. congress_quart_lows = []
  50. for year, year_stats in congress_stats.items():
  51. congress_years.append(year)
  52. congress_medians.append(year_stats.median)
  53. congress_quart_highs.append(year_stats.quart_high)
  54. congress_quart_lows.append(year_stats.quart_low)
  55. genpop_years = []
  56. genpop_medians = []
  57. genpop_quart_highs = []
  58. genpop_quart_lows = []
  59. for year, year_stats in genpop_stats.items():
  60. genpop_years.append(year)
  61. genpop_medians.append(year_stats.median)
  62. genpop_quart_highs.append(year_stats.quart_high)
  63. genpop_quart_lows.append(year_stats.quart_low)
  64. plt.fill_between(genpop_years, genpop_medians, genpop_quart_highs, color='b', alpha=0.3)
  65. plt.fill_between(genpop_years, genpop_medians, genpop_quart_lows, color='b', alpha=0.3)
  66. plt.plot(genpop_years, genpop_medians, 'b', label='U.S. Population')
  67. plt.fill_between(congress_years, congress_medians, congress_quart_highs, color='r', alpha=0.3)
  68. plt.fill_between(congress_years, congress_medians, congress_quart_lows, color='r', alpha=0.3)
  69. plt.plot(congress_years, congress_medians, 'r', label='Congress')
  70. plt.legend()
  71. plt.grid()
  72. plt.ylabel('Age')
  73. plt.xlabel('Year')
  74. plt.show()
  75. if __name__ == '__main__':
  76. stats_genpop = {}
  77. stats_congress = {}
  78. for year in range(1850, 2017):
  79. stats_congress[year] = get_stats_congress(year)
  80. people = pd.read_csv('usa_00001.csv', usecols=['YEAR', 'AGE', 'PERWT'], index_col='YEAR')
  81. for year in people.index.unique():
  82. stats_genpop[year] = get_stats_genpop(people.loc[year])
  83. plot_yearly_stats(stats_congress, stats_genpop)