chore: gui b4 presentation
dvdblk committed Nov 26, 2023
1 parent 8b47a26 commit 8c46f5f
Showing 10 changed files with 2,920 additions and 470 deletions.
129 changes: 93 additions & 36 deletions app/llm.py
@@ -1,4 +1,5 @@
-from typing import Any, Dict, List, Optional, Sequence, Type
+import json
+from typing import Any, Dict, Generator, List, Optional, Sequence, Type, Union
 
 from langchain.callbacks import get_openai_callback
 from langchain.callbacks.manager import CallbackManagerForToolRun
@@ -7,7 +8,7 @@
 from langchain.pydantic_v1 import BaseModel, Field
 from langchain.tools import BaseTool, format_tool_to_openai_function
 
-from app.preprocessing.adobe.model import Section
+from app.preprocessing.adobe.model import Document, Section
 from app.prompts import (
     create_summaries_prompt_template,
     refine_answer_prompt_template,
@@ -33,10 +34,6 @@ class FetchSectionsTool(BaseTool):
     description = "fetches an entire section or sections from a document that might contain an answer to the question"
     args_schema: Type[FetchSectionsSchema] = FetchSectionsSchema
 
-    def __init__(self, section_summaries: SectionSummaryDict, *args, **kwargs):
-        self.section_summaries = section_summaries
-        super().__init__(*args, **kwargs)
-
     def _run(
         self,
         reasoning: str,
@@ -45,6 +42,8 @@ def _run(
         **kwargs,
     ) -> Sequence[str]:
         """Use the tool."""
+        # FIXME: can't add section_summaries to self, don't use this
+        raise NotImplementedError()
         sections = []
         section_mapper = lambda s_id: self.section_summaries[s_id]
 
@@ -112,11 +111,73 @@ def wrapper(self, *args, **kwargs):
         self.n_prompt_tokens += cb.prompt_tokens
         self.n_completion_tokens += cb.completion_tokens
         self.total_cost += cb.total_cost
+        print(cb)
         return result
 
     return wrapper
 
 
+def document_to_structured_metadata(section, section_summaries):
+    """Convert document to structured metadata"""
+    # Check if the document is the root node
+    if section.section_type == "document":
+        return {
+            "document": {
+                "title": section.title,
+                "sections": [
+                    document_to_structured_metadata(section, section_summaries)
+                    for section in section.subsections
+                ],
+            }
+        }
+    else:
+        # find section from section summaries
+        section_summary = section_summaries[section.id]
+        result = {
+            "title": section.title_clean,
+            "id": section.id,
+            "pages": sorted(section.pages),
+            "summary": section_summary,
+        }
+        if subsections := [
+            document_to_structured_metadata(subsection, section_summaries)
+            for subsection in section.subsections
+        ]:
+            result["sections"] = subsections
+
+        return result
+
+
+def parse_function_output(response, document) -> Union[List[Section], str]:
+    # Get the function call
+    fn_call = response.additional_kwargs.get("function_call")
+
+    # Check if the response content is empty and that there is a function call
+    if response.content == "" and fn_call is not None:
+        # Get the attributes of the function call
+        tool_name = fn_call["name"]
+        tool_args = json.loads(fn_call["arguments"])
+        # Get the correct tool from the tools list
+        # tool = next(filter(lambda x: x.name == tool_name, tools))
+        # fn_output = tool._run(**tool_args)
+
+        if tool_name == "fetch_sections" and "section_ids" in tool_args:
+            sections = []
+            # get full section text from document
+            for section_id in sorted(tool_args["section_ids"]):
+                if section := document.get_section_by_id(section_id):
+                    result = {
+                        "title": section.title_clean,
+                        "id": section.id,
+                        "text": section.text,
+                    }
+                    sections.append(result)
+            return sections
+    else:
+        # Otherwise return the content
+        return response.content
 
 
 class OpenAIPromptExecutor:
     """Executes all pre-defined prompts (chains) while keeping track of OpenAI costs."""
 
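For reference, a rough sketch of the nested dict that document_to_structured_metadata builds for the prompt. The titles, ids, page numbers, and summaries below are made-up examples, not values from this repository; the "sections" key only appears when a section has subsections:

# Hypothetical output for a document with one top-level section and one
# subsection (all field values are illustrative):
{
    "document": {
        "title": "Example Paper",
        "sections": [
            {
                "title": "Introduction",
                "id": 1,
                "pages": [1, 2],
                "summary": "Motivates the problem setting.",
                "sections": [
                    {
                        "title": "Related Work",
                        "id": 2,
                        "pages": [2],
                        "summary": "Surveys prior approaches.",
                    }
                ],
            }
        ],
    }
}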
@@ -137,14 +198,16 @@ def temp(self, question: str) -> str:
         return response.content
 
     @track_costs
-    def create_summaries_chain(self, sections: List[Section]) -> SectionSummaryDict:
+    def create_summaries_chain(
+        self, sections: List[Section]
+    ) -> Generator[float, List[Section], None]:
         """Create summaries for all sections in the document
 
         Args:
             sections (List[Section]): the sections to summarize
 
         Returns:
-            SectionSummaryDict: a dictionary containing the summaries for each section id
+            Generator: yields a dictionary containing the summaries for each section id
         """
 
         # key = section id, value = summary
@@ -156,6 +219,7 @@ def create_summaries_chain(self, sections: List[Section]) -> SectionSummaryDict:
         )
         # Generate summaries for each section
         i = 0
+        max_i = len(sections)
         for section in sections:
             section_text = section.paragraph_text
 
@@ -165,18 +229,16 @@ def create_summaries_chain(self, sections: List[Section]) -> SectionSummaryDict:
                 {"section_title": section.title_clean, "section_text": section_text}
             )
             summary_dict[section.id] = response.summary
-                i += 1
         else:
             summary_dict[section.id] = None
 
-            if i == 20:
-                break
-
-        return summary_dict
+            i += 1
+            yield (i / max_i, summary_dict)
 
     @track_costs
     def generic_question_chain(
         self,
+        document: Document,
         section_summaries: SectionSummaryDict,
         question: str,
     ):
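With this change create_summaries_chain is a generator that yields a (progress, summary_dict) pair after every section instead of returning a single dict at the end, so a GUI can render a progress bar while summaries stream in. A minimal consumption sketch, assuming executor is an OpenAIPromptExecutor and sections is the document's List[Section] (both names are illustrative):

summary_dict = {}
for progress, summary_dict in executor.create_summaries_chain(sections):
    # progress is i / max_i, a float in (0, 1]; feed it to a progress bar
    print(f"summarized {progress:.0%} of sections")
# after the loop, summary_dict maps section.id -> summary (None for empty sections)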
@@ -190,34 +252,29 @@ def generic_question_chain(
         fetch_sections_response = self.llm.invoke(
             structured_metadata_prompt_template.format(
                 question=question,
-                section_summaries=section_summaries,
-                openai_functions=openai_functions,
+                document_structural_metadata=document_to_structured_metadata(
+                    document, section_summaries
+                ),
             ),
             functions=openai_functions,
         )
 
-        import json
-
         # Refine all sections into one answer if there are more than 1 section returned by the chain above
-        def parse_function_output(response) -> str:
-            # Get the function call
-            fn_call = response.additional_kwargs.get("function_call")
-
-            # Check if the response content is empty and that there is a function call
-            if response.content == "" and fn_call is not None:
-                # Get the attributes of the function call
-                tool_name = fn_call["name"]
-                tool_args = json.loads(fn_call["arguments"])
-                # Get the correct tool from the tools list
-                tool = next(filter(lambda x: x.name == tool_name, tools))
-                fn_output = tool._run(**tool_args)
-                return fn_output
-            else:
-                # Otherwise return the content
-                return response.content
-
-        fetched_sections = parse_function_output(fetch_sections_response)
-
+        fetched_sections = parse_function_output(fetch_sections_response, document)
+        if isinstance(fetched_sections, list):
+            if len(fetched_sections) == 0:
+                # No sections were fetched by LLM, return generic response
+                return RefineIO(
+                    intermediate_answer="The answer could not be determined from the given context and question 😓",
+                    section_ids=[],
+                )
+        elif isinstance(fetched_sections, str):
+            # LLM returned an unexpected response.content (should have been null + function call but wasn't)
+            return RefineIO(
+                intermediate_answer=fetched_sections,
+                section_ids=[],
+            )
+        print(fetched_sections)
         refine_io = RefineIO(intermediate_answer="", section_ids=[])
         refine_answer_runnable = create_structured_output_runnable(
             RefineIO, self.llm, refine_answer_prompt_template

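For context on the flow in generic_question_chain: the model either answers directly in response.content or returns an OpenAI function call whose JSON arguments parse_function_output decodes into section ids to fetch. A sketch of the payload shape being handled, with made-up argument values:

import json

# Shape of an OpenAI function-call response as surfaced by langchain:
# content is empty and the call sits in additional_kwargs (values invented).
additional_kwargs = {
    "function_call": {
        "name": "fetch_sections",
        "arguments": '{"reasoning": "sections 2 and 5 look relevant", "section_ids": [5, 2]}',
    }
}

fn_call = additional_kwargs["function_call"]
tool_args = json.loads(fn_call["arguments"])  # arguments arrive as a JSON string
print(sorted(tool_args["section_ids"]))       # -> [2, 5], the ids to pull text for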