Skip to content

Commit

Permalink
Merge pull request #8 from huawei-noah/xt_release_0.3.0
Browse files Browse the repository at this point in the history
fix muzero
  • Loading branch information
hustqj authored Jan 4, 2021
2 parents 20a0477 + 1c848bc commit bc6c966
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions xt/agent/muzero/muzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,32 @@ def sync_model(self):
break

return ret_model_name

def run_one_episode(self, use_explore, need_collect):
"""
Do interaction with max steps in each episode.
:param use_explore:
:param need_collect: if collect the total transition of each episode.
:return:
"""
# clear the old trajectory data
self.clear_trajectory()
state = self.env.get_init_state(self.id)

self._stats.reset()

for _ in range(self.max_step):
self.clear_transition()
state = self.do_one_interaction(state, use_explore)

if need_collect:
self.add_to_trajectory(self.transition_data)

if self.transition_data["done"]:
if not self.keep_seq_len:
break
self.env.reset()
state = self.env.get_init_state()

return self.get_trajectory()

0 comments on commit bc6c966

Please sign in to comment.