Skip to content

Commit

Permalink
Fix duplicate removal #1
Browse files: browse the repository at this point in the history
  • Loading branch information
pramitchoudhary committed May 26, 2023
1 parent f84c848 commit b7c94b0
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
1 change: 1 addition & 0 deletions sidekick/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def generate_sql(self, table_name: str, input_question: str, _dialect: str = "po
_tasks=_tasks.lower(),
)

logger.debug(f"Query Prompt:\n{query_str}")
table_context_dict = {str(table_name[0]).lower(): str(additional_context).lower()}
self.context_builder = SQLContextContainerBuilder(self.sql_database, context_dict=table_context_dict)

Expand Down
10 changes: 5 additions & 5 deletions sidekick/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ def remove_duplicates(input_x: list, model_path: str, threshold: float = 0.89):
# Remove duplicate pairs
embeddings = generate_sentence_embeddings(model_path, x=input_x, device="cpu")
similarity_scores = compute_similarity(embeddings)

similar_indices = [(x, y) for (x, y) in np.argwhere(similarity_scores > threshold) if x != y]

# Remove identical pairs e.g. [(0, 3), (3, 0)] -> [(0, 3)]
si = [sum(tpl) for tpl in similar_indices]
si = [similarity_scores[tpl] for tpl in similar_indices]
dup_pairs_idx = np.where(pd.Series(si).duplicated())[0].tolist()
for _itm in dup_pairs_idx:
similar_indices.pop(_itm)
res = list(set([item[1] for item in similar_indices]))
remove_vals = [similar_indices[_itm] for _itm in dup_pairs_idx]
[similar_indices.remove(_itm) for _itm in remove_vals]
res = list(set([item[0] for item in similar_indices]))
return res

0 comments on commit b7c94b0

Please sign in to comment.