Skip to content

Commit

Permalink
Merge pull request #106 from jdebacker/tcja_ext
Browse files Browse the repository at this point in the history
Allow for alternative policy baselines
  • Loading branch information
jdebacker authored Apr 24, 2024
2 parents 4e1605e + 8229008 commit 300912a
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 29 deletions.
181 changes: 181 additions & 0 deletions examples/run_og_usa_current_policy_baseline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import multiprocessing
from distributed import Client
import os
import requests
import json
import time
from taxcalc import Calculator
from ogusa.calibrate import Calibration
from ogcore.parameters import Specifications
from ogcore import output_tables as ot
from ogcore import output_plots as op
from ogcore.execute import runner
from ogcore.utils import safe_read_pickle


def main():
# Define parameters to use for multiprocessing
num_workers = min(multiprocessing.cpu_count(), 7)
client = Client(n_workers=num_workers, threads_per_worker=1)
print("Number of workers = ", num_workers)

# Directories to save data
CUR_DIR = os.path.dirname(os.path.realpath(__file__))
base_dir = os.path.join(CUR_DIR, "OG-USA-CP-Example", "OUTPUT_BASELINE")
reform_dir = os.path.join(CUR_DIR, "OG-USA-CP-Example", "OUTPUT_REFORM")

"""
------------------------------------------------------------------------
Run baseline policy
------------------------------------------------------------------------
"""
# Set up baseline parameterization
p = Specifications(
baseline=True,
num_workers=num_workers,
baseline_dir=base_dir,
output_base=base_dir,
)
# Update parameters for baseline from default json file
p.update_specifications(
json.load(
open(
os.path.join(
CUR_DIR, "..", "ogusa", "ogusa_default_parameters.json"
)
)
)
)
p.tax_func_type = "GS"
# get current policy JSON file
base_url = (
"github://PSLmodels:Tax-Calculator@master/taxcalc/"
+ "reforms/ext.json"
)
ref = Calculator.read_json_param_objects(base_url, None)
iit_baseline = ref["policy"]
c = Calibration(
p,
estimate_tax_functions=True,
iit_baseline=iit_baseline,
client=client,
)
# close and delete client bc cache is too large
client.close()
del client
client = Client(n_workers=num_workers, threads_per_worker=1)
d = c.get_dict()
# # additional parameters to change
updated_params = {
"etr_params": d["etr_params"],
"mtrx_params": d["mtrx_params"],
"mtry_params": d["mtry_params"],
"mean_income_data": d["mean_income_data"],
"frac_tax_payroll": d["frac_tax_payroll"],
}
p.update_specifications(updated_params)
# Run model
start_time = time.time()
runner(p, time_path=True, client=client)
print("run time = ", time.time() - start_time)

"""
------------------------------------------------------------------------
Run reform policy
------------------------------------------------------------------------
"""
# Grab a reform JSON file already in Tax-Calculator
# In this example the 'reform' is a change to 2017 law
reform_url = (
"github://PSLmodels:Tax-Calculator@master/taxcalc/"
+ "reforms/2017_law.json"
)
ref = Calculator.read_json_param_objects(reform_url, None)
iit_reform = ref["policy"]

# create new Specifications object for reform simulation
p2 = Specifications(
baseline=False,
num_workers=num_workers,
baseline_dir=base_dir,
output_base=reform_dir,
)
# Update parameters for baseline from default json file
p2.update_specifications(
json.load(
open(
os.path.join(
CUR_DIR, "..", "ogusa", "ogusa_default_parameters.json"
)
)
)
)
p2.tax_func_type = "GS"
# Use calibration class to estimate reform tax functions from
# Tax-Calculator, specifying reform for Tax-Calculator in iit_reform
c2 = Calibration(
p2,
iit_baseline=iit_baseline,
iit_reform=iit_reform,
estimate_tax_functions=True,
client=client,
)
# close and delete client bc cache is too large
client.close()
del client
client = Client(n_workers=num_workers, threads_per_worker=1)
# update tax function parameters in Specifications Object
d = c2.get_dict()
# # additional parameters to change
updated_params = {
"cit_rate": [[0.35]],
"etr_params": d["etr_params"],
"mtrx_params": d["mtrx_params"],
"mtry_params": d["mtry_params"],
"mean_income_data": d["mean_income_data"],
"frac_tax_payroll": d["frac_tax_payroll"],
}
p2.update_specifications(updated_params)
# Run model
start_time = time.time()
runner(p2, time_path=True, client=client)
print("run time = ", time.time() - start_time)
client.close()

"""
------------------------------------------------------------------------
Save some results of simulations
------------------------------------------------------------------------
"""
base_tpi = safe_read_pickle(os.path.join(base_dir, "TPI", "TPI_vars.pkl"))
base_params = safe_read_pickle(os.path.join(base_dir, "model_params.pkl"))
reform_tpi = safe_read_pickle(
os.path.join(reform_dir, "TPI", "TPI_vars.pkl")
)
reform_params = safe_read_pickle(
os.path.join(reform_dir, "model_params.pkl")
)
ans = ot.macro_table(
base_tpi,
base_params,
reform_tpi=reform_tpi,
reform_params=reform_params,
var_list=["Y", "C", "K", "L", "r", "w"],
output_type="pct_diff",
num_years=10,
start_year=base_params.start_year,
)

# create plots of output
op.plot_all(
base_dir, reform_dir, os.path.join(CUR_DIR, "OG-USA_example_plots")
)

print("Percentage changes in aggregates:", ans)
# save percentage change output to csv file
ans.to_csv("ogusa_example_output.csv")


if __name__ == "__main__":
# execute only if run as a script
main()
36 changes: 32 additions & 4 deletions ogusa/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,34 @@ def __init__(
estimate_chi_n=False,
estimate_pop=False,
tax_func_path=None,
iit_baseline=None,
iit_reform={},
guid="",
data="cps",
client=None,
num_workers=1,
):
"""
Constructor for the Calibration class. This class is used to find
parameter values for the OG-USA model.
Args:
p (OGUSA Parameters object): parameters object
estimate_tax_functions (bool): whether to estimate tax functions
estimate_beta (bool): whether to estimate beta
estimate_chi_n (bool): whether to estimate chi_n
estimate_pop (bool): whether to estimate population
tax_func_path (str): path to tax function parameters
iit_baseline (dict): baseline policy to use
iit_reform (dict): reform tax parameters
guid (str): id for tax function parameters
data (str): data source for microsimulation model
client (Dask client object): client
num_workers (int): number of workers for Dask client
Returns:
Calibration class object instance
"""
self.estimate_tax_functions = estimate_tax_functions
self.estimate_beta = estimate_beta
self.estimate_chi_n = estimate_chi_n
Expand All @@ -36,6 +58,7 @@ def __init__(
run_micro = True
self.tax_function_params = self.get_tax_function_parameters(
p,
iit_baseline,
iit_reform,
guid,
data,
Expand Down Expand Up @@ -104,6 +127,7 @@ def __init__(
def get_tax_function_parameters(
self,
p,
iit_baseline=None,
iit_reform={},
guid="",
data="",
Expand All @@ -117,7 +141,13 @@ def get_tax_function_parameters(
parameters from microsimulation model output.
Args:
p (OG-Core Parameters object): parameters object
iit_baseline (dict): baseline policy to use
iit_reform (dict): reform tax parameters
guid (string): id for tax function parameters
data (string): data source for microsimulation model
client (Dask client object): client
num_workers (int): number of workers for Dask client
run_micro (bool): whether to estimate parameters from
microsimulation model
tax_func_path (string): path where find or save tax
Expand Down Expand Up @@ -152,7 +182,8 @@ def get_tax_function_parameters(
micro_data, taxcalc_version = get_micro_data.get_data(
baseline=p.baseline,
start_year=p.start_year,
reform=iit_reform,
iit_baseline=iit_baseline,
iit_reform=iit_reform,
data=data,
path=p.output_base,
client=client,
Expand All @@ -166,12 +197,9 @@ def get_tax_function_parameters(
p.starting_age,
p.ending_age,
start_year=p.start_year,
baseline=p.baseline,
analytical_mtrs=p.analytical_mtrs,
tax_func_type=p.tax_func_type,
age_specific=p.age_specific,
reform=iit_reform,
data=data,
client=client,
num_workers=num_workers,
tax_func_path=tax_func_path,
Expand Down
45 changes: 31 additions & 14 deletions ogusa/get_micro_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
def get_calculator(
baseline,
calculator_start_year,
reform=None,
iit_baseline=None,
iit_reform=None,
data=None,
gfactors=None,
weights=None,
Expand Down Expand Up @@ -72,17 +73,21 @@ def get_calculator(
records1 = Records() # pragma: no cover

if baseline:
if not reform:
if iit_baseline is None:
print("Running current law policy baseline")
else:
print("Baseline policy is: ", reform)
print("Baseline policy is: ", iit_baseline)
policy1.implement_reform(iit_baseline)
else:
if not reform:
if not iit_reform:
print("Running with current law as reform")
else:
print("Reform policy is: ", reform)
print("TYPE", type(reform))
policy1.implement_reform(reform)
print("Reform policy is: ", iit_reform)
if (
iit_baseline is not None
): # if alt baseline, stack reform on that
policy1.implement_reform(iit_baseline)
policy1.implement_reform(iit_reform)

# the default set up increments year to 2013
calc1 = Calculator(records=records1, policy=policy1)
Expand All @@ -97,7 +102,8 @@ def get_calculator(
def get_data(
baseline=False,
start_year=DEFAULT_START_YEAR,
reform={},
iit_baseline=None,
iit_reform={},
data=None,
path=CUR_PATH,
client=None,
Expand All @@ -112,7 +118,8 @@ def get_data(
Args:
baseline (boolean): True if baseline tax policy
calculator_start_year (int): first year of budget window
reform (dictionary): IIT policy reform parameters, None if
iit_baseline (dictionary): IIT policy parameters for baseline
iit_reform (dictionary): IIT policy reform parameters, None if
baseline
data (DataFrame or str): DataFrame or path to datafile for
Records object
Expand All @@ -132,7 +139,9 @@ def get_data(
lazy_values = []
for year in range(start_year, TC_LAST_YEAR + 1):
lazy_values.append(
delayed(taxcalc_advance)(baseline, start_year, reform, data, year)
delayed(taxcalc_advance)(
baseline, start_year, iit_baseline, iit_reform, data, year
)
)
if client: # pragma: no cover
futures = client.compute(lazy_values, num_workers=num_workers)
Expand Down Expand Up @@ -167,14 +176,21 @@ def get_data(
return micro_data_dict, taxcalc_version


def taxcalc_advance(baseline, start_year, reform, data, year):
def taxcalc_advance(
baseline, start_year, iit_baseline, iit_reform, data, year
):
"""
This function advances the year used in Tax-Calculator, compute
taxes and rates, and save the results to a dictionary.
Args:
calc1 (Tax-Calculator Calculator object): TC calculator
year (int): year to begin advancing from
baseline (boolean): True if baseline tax policy
start_year (int): first year of budget window
iit_baseline (dict): IIT policy parameters for baseline
iit_reform (dict): IIT policy reform parameters for reform
data (DataFrame or str): DataFrame or path to datafile for
Records object
year (int): year to advance to in Tax-Calculator
Returns:
tax_dict (dict): a dictionary of microdata with marginal tax
Expand All @@ -183,7 +199,8 @@ def taxcalc_advance(baseline, start_year, reform, data, year):
calc1 = get_calculator(
baseline=baseline,
calculator_start_year=start_year,
reform=reform,
iit_baseline=iit_baseline,
iit_reform=iit_reform,
data=data,
)
calc1.advance_to_year(year)
Expand Down
Loading

0 comments on commit 300912a

Please sign in to comment.