Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Citation Processor #3872

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 97 additions & 107 deletions backend/onyx/chat/stream_processing/citation_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,10 @@ def process_token(
self.llm_out += token

# Handle code blocks without language tags
if "`" in self.curr_segment:
if self.curr_segment.endswith("`"):
pass
elif "```" in self.curr_segment:
piece_that_comes_after = self.curr_segment.split("```")[1][0]
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
if "```" in self.curr_segment and not self.curr_segment.endswith("`"):
piece_that_comes_after = self.curr_segment.split("```")[1][0]
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
self.curr_segment = self.curr_segment.replace("```", "```plaintext")

citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]" # [1], [[1]], etc.
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
Expand All @@ -80,115 +77,108 @@ def process_token(
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
self.current_citations = []

result = ""
self.result = ""
if citations_found and not in_code_block(self.llm_out):
last_citation_end = 0
length_to_add = 0
while len(citations_found) > 0:
citation = citations_found.pop(0)
numerical_value = int(
next(group for group in citation.groups() if group is not None)
)
yield from self.process_found_citations(citations_found)

if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
final_citation_num = self.final_order_mapping[
context_llm_doc.document_id
]

if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)

citation_order_idx = (
self.citation_order.index(final_citation_num) + 1
)

# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
context_llm_doc.document_id
]
else:
displayed_citation_num = final_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)
if not possible_citation_found:
self.result += self.curr_segment
self.curr_segment = ""

# Skip consecutive citations of the same work
if final_citation_num in self.current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
)
length_to_add -= diff
continue

# Handle edge case where LLM outputs citation itself
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
except Exception as e:
logger.warning(
f"Manual LLM citation didn't properly cite documents {e}"
)
else:
logger.warning(
"Manual LLM citation wasn't able to close brackets"
)
continue

link = context_llm_doc.link

self.past_cite_count = len(self.llm_out)
self.current_citations.append(final_citation_num)

if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
if self.result:
yield OnyxAnswerPiece(answer_piece=self.result)

def process_found_citations(
self, citations_found: list[re.Match[str]]
) -> Generator[CitationInfo, None, None]:
last_citation_end = 0
length_to_add = 0
while len(citations_found) > 0:
citation = citations_found.pop(0)
numerical_value = int(
next(group for group in citation.groups() if group is not None)
)

if numerical_value <= 0 or numerical_value > self.max_citation_num:
continue

context_llm_doc = self.context_docs[numerical_value - 1]
final_citation_num = self.final_order_mapping[context_llm_doc.document_id]

if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)

citation_order_idx = self.citation_order.index(final_citation_num) + 1

# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
context_llm_doc.document_id
]
else:
displayed_citation_num = final_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)

# Skip consecutive citations of the same work
if final_citation_num in self.current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
)
length_to_add -= diff
continue

# Handle edge case where LLM outputs citation itself
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# citation number is now the one that was displayed to user
# citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)

start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
else:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
except Exception as e:
logger.warning(
f"Manual LLM citation didn't properly cite documents {e}"
)
length_to_add += len(self.curr_segment) - prev_length
else:
logger.warning("Manual LLM citation wasn't able to close brackets")
continue

last_citation_end = end + length_to_add
link = context_llm_doc.link

if last_citation_end > 0:
result += self.curr_segment[:last_citation_end]
self.curr_segment = self.curr_segment[last_citation_end:]
self.past_cite_count = len(self.llm_out)
self.current_citations.append(final_citation_num)

if not possible_citation_found:
result += self.curr_segment
self.curr_segment = ""
if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
yield CitationInfo(
# citation number is now the one that was displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)

if result:
yield OnyxAnswerPiece(answer_piece=result)
start, end = citation.span()
prev_length = len(self.curr_segment)
link_str = link or ""
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link_str})" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length

last_citation_end = end + length_to_add

if last_citation_end > 0:
self.result += self.curr_segment[:last_citation_end]
self.curr_segment = self.curr_segment[last_citation_end:]
Loading