Skip to content

Commit

Permalink
update run script
Browse files Browse the repository at this point in the history
  • Loading branch information
jdebacker committed Aug 22, 2024
1 parent 5e93add commit 4076bd3
Showing 1 changed file with 19 additions and 45 deletions.
64 changes: 19 additions & 45 deletions examples/run_og_usa.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import os
import json
import time
import importlib.resources
import copy
from taxcalc import Calculator
import matplotlib.pyplot as plt
from ogusa.calibrate import Calibration
Expand All @@ -28,8 +30,9 @@ def main():

# Directories to save data
CUR_DIR = os.path.dirname(os.path.realpath(__file__))
base_dir = os.path.join(CUR_DIR, "OG-USA-Example", "OUTPUT_BASELINE")
reform_dir = os.path.join(CUR_DIR, "OG-USA-Example", "OUTPUT_REFORM")
save_dir = os.path.join(CUR_DIR, "OG-USA-Example")
base_dir = os.path.join(save_dir, "OUTPUT_BASELINE")
reform_dir = os.path.join(save_dir, "OUTPUT_REFORM")

"""
------------------------------------------------------------------------
Expand All @@ -44,22 +47,13 @@ def main():
output_base=base_dir,
)
# Update parameters for baseline from default json file
p.update_specifications(
json.load(
open(
os.path.join(
CUR_DIR, "..", "ogusa", "ogusa_default_parameters.json"
)
)
)
)
p.tax_func_type = "GS"
p.age_specific = False
with importlib.resources.open_text(
"ogusa", "ogusa_default_parameters.json"
) as file:
defaults = json.load(file)
p.update_specifications(defaults)
p.tax_func_type = "HSV"
c = Calibration(p, estimate_tax_functions=True, client=client)
# close and delete client bc cache is too large
client.close()
del client
client = Client(n_workers=num_workers, threads_per_worker=1)
d = c.get_dict()
# # additional parameters to change
updated_params = {
Expand All @@ -84,43 +78,23 @@ def main():
# In this example the 'reform' is a change to 2017 law (the
# baseline policy is tax law in 2018)
reform_url = (
"github://PSLmodels:examples@main/psl_examples/"
+ "taxcalc/2017_law.json"
"github://PSLmodels:Tax-Calculator@master/taxcalc/"
+ "reforms/2017_law.json"
)

ref = Calculator.read_json_param_objects(reform_url, None)
iit_reform = ref["policy"]

# create new Specifications object for reform simulation
p2 = Specifications(
baseline=False,
num_workers=num_workers,
baseline_dir=base_dir,
output_base=reform_dir,
)
# Update parameters for baseline from default json file
p2.update_specifications(
json.load(
open(
os.path.join(
CUR_DIR, "..", "ogusa", "ogusa_default_parameters.json"
)
)
)
)
p2.tax_func_type = "GS"
p2.age_specific = False
p2 = copy.deepcopy(p)
# Use calibration class to estimate reform tax functions from
# Tax-Calculator, specifying reform for Tax-Calculator in iit_reform
c2 = Calibration(
p2, iit_reform=iit_reform, estimate_tax_functions=True, client=client
)
# close and delete client bc cache is too large
client.close()
del client
client = Client(n_workers=num_workers, threads_per_worker=1)
# update tax function parameters in Specifications Object
d = c2.get_dict()
# # additional parameters to change
# additional parameters to change
updated_params = {
"cit_rate": [[0.35]],
"etr_params": d["etr_params"],
Expand Down Expand Up @@ -164,7 +138,7 @@ def main():
op.plot_all(
base_dir,
reform_dir,
os.path.join(CUR_DIR, "OG-USA_example_plots_tables"),
os.path.join(save_dir, "OG-USA_example_plots_tables"),
)
# Create CSV file with output
ot.tp_output_dump_table(
Expand All @@ -174,7 +148,7 @@ def main():
reform_tpi,
table_format="csv",
path=os.path.join(
CUR_DIR,
save_dir,
"OG-USA_example_plots_tables",
"macro_time_series_output.csv",
),
Expand All @@ -184,7 +158,7 @@ def main():
# save percentage change output to csv file
ans.to_csv(
os.path.join(
CUR_DIR, "OG-USA_example_plots_tables", "ogusa_example_output.csv"
save_dir, "OG-USA_example_plots_tables", "ogusa_example_output.csv"
)
)

Expand Down

0 comments on commit 4076bd3

Please sign in to comment.