diff --git a/kaggle_environments/envs/llm_20_questions/llm_20_questions.py b/kaggle_environments/envs/llm_20_questions/llm_20_questions.py index 64e0abbf..ee2ca629 100644 --- a/kaggle_environments/envs/llm_20_questions/llm_20_questions.py +++ b/kaggle_environments/envs/llm_20_questions/llm_20_questions.py @@ -24,6 +24,7 @@ DONE = "DONE" INACTIVE = "INACTIVE" ACTIVE = "ACTIVE" +TIMEOUT = "TIMEOUT" GUESS = "guess" ASK = "ask" @@ -98,15 +99,19 @@ def guesser_action(active, inactive, step): if active.action and keyword_guessed(active.action): guessed = True score = 20 - int(step / 3) - active.reward = score - inactive.reward = score - active.status = DONE - inactive.status = DONE - active.observation.keyword = keyword - active.observation.category = category + end_game(active, inactive, score, DONE, DONE) + return guessed + +def end_game(active, inactive, reward, status, inactive_status): + active.observation.keyword = keyword + active.observation.category = category inactive.observation.keyword = keyword inactive.observation.category = category - return guessed + active.reward = reward + inactive.reward = reward + active.status = status + inactive.status = inactive_status + def answerer_action(active, inactive): active.observation.keyword = keyword @@ -114,27 +119,20 @@ def answerer_action(active, inactive): response = active.action if not response: response = "none" - active.status = ERROR + end_game(active, inactive, -1, ERROR, DONE) elif "yes" in response.lower(): response = "yes" elif "no" in response.lower(): response = "no" else: response = "maybe" - active.status = ERROR + end_game(active, inactive, -1, ERROR, DONE) active.observation.answers.append(response) inactive.observation.answers.append(response) def increment_turn(active, inactive, step, guessed): if step == 59 and not guessed: - active.observation.keyword = keyword - active.observation.category = category - inactive.observation.keyword = keyword - inactive.observation.category = category - active.reward = -1 - inactive.reward = -1 - active.status = DONE - inactive.status = DONE + end_game(active, inactive, -1, DONE, DONE) elif active.observation.turnType == "guess": active.observation.turnType = "ask" elif active.observation.turnType == "ask": @@ -166,13 +164,20 @@ def interpreter(state, env): step = state[0].observation.step + end_early = (active1 and active1.status) in (TIMEOUT, ERROR) or (active2 and active2.status in (TIMEOUT, ERROR)) + if active1 is not None: guessed = False if active1.observation.role == GUESSER: guessed = guesser_action(active1, inactive1, step) else: answerer_action(active1, inactive1) - increment_turn(active1, inactive1, step, guessed) + if active1.status in (TIMEOUT, ERROR): + end_game(active1, inactive1, -1, active1.status, DONE) + elif end_early: + end_game(active1, inactive1, 0, DONE, DONE) + else: + increment_turn(active1, inactive1, step, guessed) if active2 is not None: guessed = False @@ -180,7 +185,12 @@ def interpreter(state, env): guessed = guesser_action(active2, inactive2, step) else: answerer_action(active2, inactive2) - increment_turn(active2, inactive2, step, guessed) + if active2.status in (TIMEOUT, ERROR): + end_game(active2, inactive2, -1, active2.status, DONE) + elif end_early: + end_game(active2, inactive2, 0, DONE, DONE) + else: + increment_turn(active2, inactive2, step, guessed) return state