Skip to content

Commit

Permalink
fix planner refelection bug (microsoft#196)
Browse files Browse the repository at this point in the history
  • Loading branch information
ShilinHe authored Feb 6, 2024
2 parents e3f6902 + 10c919f commit 0ebb59b
Showing 1 changed file with 23 additions and 19 deletions.
42 changes: 23 additions & 19 deletions taskweaver/planner/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def compose_conversation_for_prompt(
conversation.append(
format_chat_message(
role="user",
message="User: " + post.message,
message="User: " + post.get_attachment(type=AttachmentType.revise_message)[0],
),
) # append the self correction instruction message to chat history

Expand Down Expand Up @@ -249,26 +249,27 @@ def reply(
else:
selected_experiences = None

new_post = self.event_emitter.create_post_proxy("Planner")
post_proxy = self.event_emitter.create_post_proxy("Planner")

new_post.update_status("composing prompt")
post_proxy.update_status("composing prompt")
chat_history = self.compose_prompt(rounds, selected_experiences)

def check_post_validity(post: Post):
assert post.send_to is not None, "send_to field is None"
assert post.send_to != "Planner", "send_to field should not be Planner"
assert post.message is not None, "message field is None"
assert post.send_to is not None, "LLM failed to generate send_to field"
assert post.send_to != "Planner", "LLM failed to generate correct send_to field: Planner"
assert post.message is not None, "LLM failed to generate message field"
assert len(post.attachment_list) == 3, "LLM failed to generate complete attachments"
assert (
post.attachment_list[0].type == AttachmentType.init_plan
), f"attachment type {post.attachment_list[0].type} is not init_plan"
), f"LLM failed to generate correct attachment type {post.attachment_list[0].type}: init_plan"
assert (
post.attachment_list[1].type == AttachmentType.plan
), f"attachment type {post.attachment_list[1].type} is not plan"
), f"LLM failed to generate correct attachment type {post.attachment_list[1].type}: plan"
assert (
post.attachment_list[2].type == AttachmentType.current_plan_step
), "attachment type is not current_plan_step"
), "LLM failed to generate correct attachment type: current_plan_step"

new_post.update_status("calling LLM endpoint")
post_proxy.update_status("calling LLM endpoint")
if self.config.skip_planning and rounds[-1].post_list[-1].send_from == "User":
self.config.dummy_plan["response"][0]["content"] += rounds[-1].post_list[-1].message
llm_stream = [
Expand All @@ -289,7 +290,7 @@ def stream_filter(s: Iterable[ChatMessageType]):
try:
for c in s:
if is_first_chunk:
new_post.update_status("receiving LLM response")
post_proxy.update_status("receiving LLM response")
is_first_chunk = False
llm_output.append(c["content"])
yield c
Expand All @@ -301,32 +302,35 @@ def stream_filter(s: Iterable[ChatMessageType]):
pass

self.planner_post_translator.raw_text_to_post(
post_proxy=new_post,
post_proxy=post_proxy,
llm_output=stream_filter(llm_stream),
validation_func=check_post_validity,
)

except (JSONDecodeError, AssertionError) as e:
self.logger.error(f"Failed to parse LLM output due to {str(e)}")
new_post.error(f"failed to parse LLM output due to {str(e)}")
new_post.update_attachment(
post_proxy.error(f"failed to parse LLM output due to {str(e)}")
post_proxy.update_attachment(
"".join(llm_output),
AttachmentType.invalid_response,
)
new_post.update_message(
post_proxy.update_attachment(
f"Failed to parse Planner output due to {str(e)}."
f"The output format should follow the below format:"
f"{self.prompt_data['planner_response_schema']}"
"Please try to regenerate the output.",
AttachmentType.revise_message,
)
new_post.update_send_to("Planner")
self.ask_self_cnt += 1
if self.ask_self_cnt > self.max_self_ask_num: # if ask self too many times, return error message
self.ask_self_cnt = 0
new_post.end("Planner failed to generate response")
post_proxy.end(f"Planner failed to generate response because {str(e)}")
raise Exception(f"Planner failed to generate response because {str(e)}")
else:
post_proxy.update_send_to("Planner")
self.ask_self_cnt += 1
if prompt_log_path is not None:
self.logger.dump_log_file(chat_history, prompt_log_path)
return new_post.end()
return post_proxy.end()

def get_examples(self) -> List[Conversation]:
example_conv_list = load_examples(self.config.example_base_path)
Expand Down

0 comments on commit 0ebb59b

Please sign in to comment.