From b7c94b0203199ea43ac1057caa9e01975ebe76c3 Mon Sep 17 00:00:00 2001 From: pramitchoudhary Date: Thu, 25 May 2023 20:57:42 -0700 Subject: [PATCH] Fix duplicate removal #1 --- sidekick/query.py | 1 + sidekick/utils.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/sidekick/query.py b/sidekick/query.py index 4338d5b..1b02f7f 100644 --- a/sidekick/query.py +++ b/sidekick/query.py @@ -158,6 +158,7 @@ def generate_sql(self, table_name: str, input_question: str, _dialect: str = "po _tasks=_tasks.lower(), ) + logger.debug(f"Query Prompt:\n{query_str}") table_context_dict = {str(table_name[0]).lower(): str(additional_context).lower()} self.context_builder = SQLContextContainerBuilder(self.sql_database, context_dict=table_context_dict) diff --git a/sidekick/utils.py b/sidekick/utils.py index 84fce89..6782aa5 100644 --- a/sidekick/utils.py +++ b/sidekick/utils.py @@ -31,12 +31,12 @@ def remove_duplicates(input_x: list, model_path: str, threshold: float = 0.89): # Remove duplicates pairs embeddings = generate_sentence_embeddings(model_path, x=input_x, device="cpu") similarity_scores = compute_similarity(embeddings) - similar_indices = [(x, y) for (x, y) in np.argwhere(similarity_scores > threshold) if x != y] + # Remove identical pairs e.g. [(0, 3), (3, 0)] -> [(0, 3)] - si = [sum(tpl) for tpl in similar_indices] + si = [similarity_scores[tpl] for tpl in similar_indices] dup_pairs_idx = np.where(pd.Series(si).duplicated())[0].tolist() - for _itm in dup_pairs_idx: - similar_indices.pop(_itm) - res = list(set([item[1] for item in similar_indices])) + remove_vals = [similar_indices[_itm] for _itm in dup_pairs_idx] + [similar_indices.remove(_itm) for _itm in remove_vals] + res = list(set([item[0] for item in similar_indices])) return res