From 2b9e27f464d00e38a4730766effef07e6370e772 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Sat, 1 Feb 2025 10:49:30 -0800 Subject: [PATCH] refactor --- .../stream_processing/citation_processing.py | 204 +++++++++--------- 1 file changed, 97 insertions(+), 107 deletions(-) diff --git a/backend/onyx/chat/stream_processing/citation_processing.py b/backend/onyx/chat/stream_processing/citation_processing.py index 071b28c3457..5386f8ce67e 100644 --- a/backend/onyx/chat/stream_processing/citation_processing.py +++ b/backend/onyx/chat/stream_processing/citation_processing.py @@ -62,13 +62,10 @@ def process_token( self.llm_out += token # Handle code blocks without language tags - if "`" in self.curr_segment: - if self.curr_segment.endswith("`"): - pass - elif "```" in self.curr_segment: - piece_that_comes_after = self.curr_segment.split("```")[1][0] - if piece_that_comes_after == "\n" and in_code_block(self.llm_out): - self.curr_segment = self.curr_segment.replace("```", "```plaintext") + if "```" in self.curr_segment and not self.curr_segment.endswith("`"): + piece_that_comes_after = self.curr_segment.split("```")[1][0] + if piece_that_comes_after == "\n" and in_code_block(self.llm_out): + self.curr_segment = self.curr_segment.replace("```", "```plaintext") citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]" # [1], [[1]], etc. citations_found = list(re.finditer(citation_pattern, self.curr_segment)) @@ -80,115 +77,108 @@ def process_token( if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5: self.current_citations = [] - result = "" + self.result = "" if citations_found and not in_code_block(self.llm_out): - last_citation_end = 0 - length_to_add = 0 - while len(citations_found) > 0: - citation = citations_found.pop(0) - numerical_value = int( - next(group for group in citation.groups() if group is not None) - ) + yield from self.process_found_citations(citations_found) - if 1 <= numerical_value <= self.max_citation_num: - context_llm_doc = self.context_docs[numerical_value - 1] - final_citation_num = self.final_order_mapping[ - context_llm_doc.document_id - ] - - if final_citation_num not in self.citation_order: - self.citation_order.append(final_citation_num) - - citation_order_idx = ( - self.citation_order.index(final_citation_num) + 1 - ) - - # get the value that was displayed to user, should always - # be in the display_doc_order_dict. But check anyways - if context_llm_doc.document_id in self.display_order_mapping: - displayed_citation_num = self.display_order_mapping[ - context_llm_doc.document_id - ] - else: - displayed_citation_num = final_citation_num - logger.warning( - f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead." - ) + if not possible_citation_found: + self.result += self.curr_segment + self.curr_segment = "" - # Skip consecutive citations of the same work - if final_citation_num in self.current_citations: - start, end = citation.span() - real_start = length_to_add + start - diff = end - start - self.curr_segment = ( - self.curr_segment[: length_to_add + start] - + self.curr_segment[real_start + diff :] - ) - length_to_add -= diff - continue - - # Handle edge case where LLM outputs citation itself - if self.curr_segment.startswith("[["): - match = re.match(r"\[\[(\d+)\]\]", self.curr_segment) - if match: - try: - doc_id = int(match.group(1)) - context_llm_doc = self.context_docs[doc_id - 1] - yield CitationInfo( - # citation_num is now the number post initial ranking, i.e. as displayed to user - citation_num=displayed_citation_num, - document_id=context_llm_doc.document_id, - ) - except Exception as e: - logger.warning( - f"Manual LLM citation didn't properly cite documents {e}" - ) - else: - logger.warning( - "Manual LLM citation wasn't able to close brackets" - ) - continue - - link = context_llm_doc.link - - self.past_cite_count = len(self.llm_out) - self.current_citations.append(final_citation_num) - - if citation_order_idx not in self.cited_inds: - self.cited_inds.add(citation_order_idx) + if self.result: + yield OnyxAnswerPiece(answer_piece=self.result) + + def process_found_citations( + self, citations_found: list[re.Match[str]] + ) -> Generator[CitationInfo, None, None]: + last_citation_end = 0 + length_to_add = 0 + while len(citations_found) > 0: + citation = citations_found.pop(0) + numerical_value = int( + next(group for group in citation.groups() if group is not None) + ) + + if numerical_value <= 0 or numerical_value > self.max_citation_num: + continue + + context_llm_doc = self.context_docs[numerical_value - 1] + final_citation_num = self.final_order_mapping[context_llm_doc.document_id] + + if final_citation_num not in self.citation_order: + self.citation_order.append(final_citation_num) + + citation_order_idx = self.citation_order.index(final_citation_num) + 1 + + # get the value that was displayed to user, should always + # be in the display_doc_order_dict. But check anyways + if context_llm_doc.document_id in self.display_order_mapping: + displayed_citation_num = self.display_order_mapping[ + context_llm_doc.document_id + ] + else: + displayed_citation_num = final_citation_num + logger.warning( + f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead." + ) + + # Skip consecutive citations of the same work + if final_citation_num in self.current_citations: + start, end = citation.span() + real_start = length_to_add + start + diff = end - start + self.curr_segment = ( + self.curr_segment[: length_to_add + start] + + self.curr_segment[real_start + diff :] + ) + length_to_add -= diff + continue + + # Handle edge case where LLM outputs citation itself + if self.curr_segment.startswith("[["): + match = re.match(r"\[\[(\d+)\]\]", self.curr_segment) + if match: + try: + doc_id = int(match.group(1)) + context_llm_doc = self.context_docs[doc_id - 1] yield CitationInfo( - # citation number is now the one that was displayed to user + # citation_num is now the number post initial ranking, i.e. as displayed to user citation_num=displayed_citation_num, document_id=context_llm_doc.document_id, ) - - start, end = citation.span() - if link: - prev_length = len(self.curr_segment) - self.curr_segment = ( - self.curr_segment[: start + length_to_add] - + f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user - + self.curr_segment[end + length_to_add :] - ) - length_to_add += len(self.curr_segment) - prev_length - else: - prev_length = len(self.curr_segment) - self.curr_segment = ( - self.curr_segment[: start + length_to_add] - + f"[[{displayed_citation_num}]]()" # use the value that was displayed to user - + self.curr_segment[end + length_to_add :] + except Exception as e: + logger.warning( + f"Manual LLM citation didn't properly cite documents {e}" ) - length_to_add += len(self.curr_segment) - prev_length + else: + logger.warning("Manual LLM citation wasn't able to close brackets") + continue - last_citation_end = end + length_to_add + link = context_llm_doc.link - if last_citation_end > 0: - result += self.curr_segment[:last_citation_end] - self.curr_segment = self.curr_segment[last_citation_end:] + self.past_cite_count = len(self.llm_out) + self.current_citations.append(final_citation_num) - if not possible_citation_found: - result += self.curr_segment - self.curr_segment = "" + if citation_order_idx not in self.cited_inds: + self.cited_inds.add(citation_order_idx) + yield CitationInfo( + # citation number is now the one that was displayed to user + citation_num=displayed_citation_num, + document_id=context_llm_doc.document_id, + ) - if result: - yield OnyxAnswerPiece(answer_piece=result) + start, end = citation.span() + prev_length = len(self.curr_segment) + link_str = link or "" + self.curr_segment = ( + self.curr_segment[: start + length_to_add] + + f"[[{displayed_citation_num}]]({link_str})" # use the value that was displayed to user + + self.curr_segment[end + length_to_add :] + ) + length_to_add += len(self.curr_segment) - prev_length + + last_citation_end = end + length_to_add + + if last_citation_end > 0: + self.result += self.curr_segment[:last_citation_end] + self.curr_segment = self.curr_segment[last_citation_end:]