Skip to content

Commit

Permalink
Fix spurious AgentConfig member reference (#277)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtrofin authored Jul 12, 2023
1 parent 5bf8776 commit b82deb3
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion compiler_opt/rl/distributed/ppo_collect_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def collect(corpus_path: str, replay_buffer_server_address: str,
agent_cfg = agent_config.DistributedPPOAgentConfig(
time_step_spec=time_step_spec, action_spec=action_spec)
agent = agent_config.create_agent(
agent_cfg.agent,
agent_cfg,
preprocessing_layer_creator=problem_config
.get_preprocessing_layer_creator())

Expand Down
2 changes: 1 addition & 1 deletion compiler_opt/rl/distributed/ppo_eval_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def evaluate(root_dir: str, corpus_path: str,
agent_cfg = agent_config.DistributedPPOAgentConfig(
time_step_spec=time_step_spec, action_spec=action_spec)
agent = agent_config.create_agent(
agent_cfg.agent,
agent_cfg,
preprocessing_layer_creator=problem_config
.get_preprocessing_layer_creator())

Expand Down
2 changes: 1 addition & 1 deletion compiler_opt/rl/train_locally.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def train_eval(worker_manager_class=LocalWorkerPoolManager,
agent_cfg = agent_config_type(
time_step_spec=time_step_spec, action_spec=action_spec)
agent: tf_agent.TFAgent = agent_config.create_agent(
agent_cfg.agent, preprocessing_layer_creator=preprocessing_layer_creator)
agent_cfg, preprocessing_layer_creator=preprocessing_layer_creator)
# create the random network distillation object
random_network_distillation = None
if use_random_network_distillation:
Expand Down

0 comments on commit b82deb3

Please sign in to comment.