-
-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #106 from jdebacker/tcja_ext
Allow for alternative policy baselines
- Loading branch information
Showing
4 changed files
with
255 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
import multiprocessing | ||
from distributed import Client | ||
import os | ||
import requests | ||
import json | ||
import time | ||
from taxcalc import Calculator | ||
from ogusa.calibrate import Calibration | ||
from ogcore.parameters import Specifications | ||
from ogcore import output_tables as ot | ||
from ogcore import output_plots as op | ||
from ogcore.execute import runner | ||
from ogcore.utils import safe_read_pickle | ||
|
||
|
||
def main(): | ||
# Define parameters to use for multiprocessing | ||
num_workers = min(multiprocessing.cpu_count(), 7) | ||
client = Client(n_workers=num_workers, threads_per_worker=1) | ||
print("Number of workers = ", num_workers) | ||
|
||
# Directories to save data | ||
CUR_DIR = os.path.dirname(os.path.realpath(__file__)) | ||
base_dir = os.path.join(CUR_DIR, "OG-USA-CP-Example", "OUTPUT_BASELINE") | ||
reform_dir = os.path.join(CUR_DIR, "OG-USA-CP-Example", "OUTPUT_REFORM") | ||
|
||
""" | ||
------------------------------------------------------------------------ | ||
Run baseline policy | ||
------------------------------------------------------------------------ | ||
""" | ||
# Set up baseline parameterization | ||
p = Specifications( | ||
baseline=True, | ||
num_workers=num_workers, | ||
baseline_dir=base_dir, | ||
output_base=base_dir, | ||
) | ||
# Update parameters for baseline from default json file | ||
p.update_specifications( | ||
json.load( | ||
open( | ||
os.path.join( | ||
CUR_DIR, "..", "ogusa", "ogusa_default_parameters.json" | ||
) | ||
) | ||
) | ||
) | ||
p.tax_func_type = "GS" | ||
# get current policy JSON file | ||
base_url = ( | ||
"github://PSLmodels:Tax-Calculator@master/taxcalc/" | ||
+ "reforms/ext.json" | ||
) | ||
ref = Calculator.read_json_param_objects(base_url, None) | ||
iit_baseline = ref["policy"] | ||
c = Calibration( | ||
p, | ||
estimate_tax_functions=True, | ||
iit_baseline=iit_baseline, | ||
client=client, | ||
) | ||
# close and delete client bc cache is too large | ||
client.close() | ||
del client | ||
client = Client(n_workers=num_workers, threads_per_worker=1) | ||
d = c.get_dict() | ||
# # additional parameters to change | ||
updated_params = { | ||
"etr_params": d["etr_params"], | ||
"mtrx_params": d["mtrx_params"], | ||
"mtry_params": d["mtry_params"], | ||
"mean_income_data": d["mean_income_data"], | ||
"frac_tax_payroll": d["frac_tax_payroll"], | ||
} | ||
p.update_specifications(updated_params) | ||
# Run model | ||
start_time = time.time() | ||
runner(p, time_path=True, client=client) | ||
print("run time = ", time.time() - start_time) | ||
|
||
""" | ||
------------------------------------------------------------------------ | ||
Run reform policy | ||
------------------------------------------------------------------------ | ||
""" | ||
# Grab a reform JSON file already in Tax-Calculator | ||
# In this example the 'reform' is a change to 2017 law | ||
reform_url = ( | ||
"github://PSLmodels:Tax-Calculator@master/taxcalc/" | ||
+ "reforms/2017_law.json" | ||
) | ||
ref = Calculator.read_json_param_objects(reform_url, None) | ||
iit_reform = ref["policy"] | ||
|
||
# create new Specifications object for reform simulation | ||
p2 = Specifications( | ||
baseline=False, | ||
num_workers=num_workers, | ||
baseline_dir=base_dir, | ||
output_base=reform_dir, | ||
) | ||
# Update parameters for baseline from default json file | ||
p2.update_specifications( | ||
json.load( | ||
open( | ||
os.path.join( | ||
CUR_DIR, "..", "ogusa", "ogusa_default_parameters.json" | ||
) | ||
) | ||
) | ||
) | ||
p2.tax_func_type = "GS" | ||
# Use calibration class to estimate reform tax functions from | ||
# Tax-Calculator, specifying reform for Tax-Calculator in iit_reform | ||
c2 = Calibration( | ||
p2, | ||
iit_baseline=iit_baseline, | ||
iit_reform=iit_reform, | ||
estimate_tax_functions=True, | ||
client=client, | ||
) | ||
# close and delete client bc cache is too large | ||
client.close() | ||
del client | ||
client = Client(n_workers=num_workers, threads_per_worker=1) | ||
# update tax function parameters in Specifications Object | ||
d = c2.get_dict() | ||
# # additional parameters to change | ||
updated_params = { | ||
"cit_rate": [[0.35]], | ||
"etr_params": d["etr_params"], | ||
"mtrx_params": d["mtrx_params"], | ||
"mtry_params": d["mtry_params"], | ||
"mean_income_data": d["mean_income_data"], | ||
"frac_tax_payroll": d["frac_tax_payroll"], | ||
} | ||
p2.update_specifications(updated_params) | ||
# Run model | ||
start_time = time.time() | ||
runner(p2, time_path=True, client=client) | ||
print("run time = ", time.time() - start_time) | ||
client.close() | ||
|
||
""" | ||
------------------------------------------------------------------------ | ||
Save some results of simulations | ||
------------------------------------------------------------------------ | ||
""" | ||
base_tpi = safe_read_pickle(os.path.join(base_dir, "TPI", "TPI_vars.pkl")) | ||
base_params = safe_read_pickle(os.path.join(base_dir, "model_params.pkl")) | ||
reform_tpi = safe_read_pickle( | ||
os.path.join(reform_dir, "TPI", "TPI_vars.pkl") | ||
) | ||
reform_params = safe_read_pickle( | ||
os.path.join(reform_dir, "model_params.pkl") | ||
) | ||
ans = ot.macro_table( | ||
base_tpi, | ||
base_params, | ||
reform_tpi=reform_tpi, | ||
reform_params=reform_params, | ||
var_list=["Y", "C", "K", "L", "r", "w"], | ||
output_type="pct_diff", | ||
num_years=10, | ||
start_year=base_params.start_year, | ||
) | ||
|
||
# create plots of output | ||
op.plot_all( | ||
base_dir, reform_dir, os.path.join(CUR_DIR, "OG-USA_example_plots") | ||
) | ||
|
||
print("Percentage changes in aggregates:", ans) | ||
# save percentage change output to csv file | ||
ans.to_csv("ogusa_example_output.csv") | ||
|
||
|
||
if __name__ == "__main__": | ||
# execute only if run as a script | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.