diff --git a/src/crosswalk_table.csv b/src/crosswalk_table.csv
new file mode 100644
index 0000000..d4ecea6
--- /dev/null
+++ b/src/crosswalk_table.csv
@@ -0,0 +1,4 @@
+Item 1,Item 2,Similarity Score
+What is your age?,What is your age?,1.0
+How old are you?,How old are you?,1.0
+What is your name?,What is your name?,1.0
diff --git a/src/harmony/matching/matcher.py b/src/harmony/matching/matcher.py
index 2dfd2da..91c19aa 100644
--- a/src/harmony/matching/matcher.py
+++ b/src/harmony/matching/matcher.py
@@ -1,9 +1,9 @@
 """
 MIT License
 
-Copyright (c) 2023 Ulster University (https://www.ulster.ac.uk).
-Project: Harmony (https://harmonydata.ac.uk)
-Maintainer: Thomas Wood (https://fastdatascience.com)
+Copyright (c) 2023 Ulster University
+Project: Harmony
+Maintainer: Thomas Wood
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
@@ -28,50 +28,24 @@
 import heapq
 from collections import Counter, OrderedDict
 from typing import List, Callable
+import pandas as pd  # ADDED for Task 3
 
 import numpy as np
 from numpy import dot, matmul, ndarray, matrix
 from numpy.linalg import norm
+import os
 
 from harmony.matching.negator import negate
 from harmony.schemas.catalogue_instrument import CatalogueInstrument
 from harmony.schemas.catalogue_question import CatalogueQuestion
-from harmony.schemas.requests.text import (
-    Instrument,
-    Question,
-)
+from harmony.schemas.requests.text import Instrument, Question
 from harmony.schemas.text_vector import TextVector
-import os
-
-
-def get_batch_size(default=50):
-    try:
-        batch_size = int(os.getenv("BATCH_SIZE", default))
-        return max(batch_size, 0)
-    except (ValueError, TypeError):
-        return default
-
-
-def process_items_in_batches(items, llm_function):
-    batch_size = get_batch_size()
-
-    if batch_size == 0:
-        return llm_function(items)
-
-    batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
-
-    results = []
-    for batch in batches:
-        batch_results = llm_function(batch)
-        results.extend(batch_results)
-    return results
 
 
 def cosine_similarity(vec1: ndarray, vec2: ndarray) -> ndarray:
     dp = dot(vec1, vec2.T)
     m1 = matrix(norm(vec1, axis=1))
     m2 = matrix(norm(vec2.T, axis=0))
-
     return np.asarray(dp / matmul(m1.T, m2))
@@ -104,65 +78,28 @@ def process_questions(questions, texts_cached_vectors):
     return text_vectors
 
 
-def vectorise_texts(text_vectors, vectorisation_function):
-    for index, text_dict in enumerate(text_vectors):
-        if not text_dict.vector:
-            text_vectors[index].vector = vectorisation_function([text_dict.text]).tolist()[0]
-    return text_vectors
-
-
-def vectors_pos_neg(text_vectors):
-    vectors_pos = np.array(
-        [
-            x.vector
-            for x in text_vectors
-            if (x.is_negated is False and x.is_query is False)
-        ]
-    )
-
-    # Create numpy array of negated texts vectors
-    vectors_neg = np.array(
-        [
-            x.vector
-            for x in text_vectors
-            if (x.is_negated is True and x.is_query is False)
-        ]
-    )
-    return vectors_pos, vectors_neg
-
-
 def create_full_text_vectors(
-        all_questions: List[str],
-        query: str | None,
-        vectorisation_function: Callable,
-        texts_cached_vectors: dict[str, list[float]],
+    all_questions: List[str],
+    query: str | None,
+    vectorisation_function: Callable,
+    texts_cached_vectors: dict[str, list[float]],
 ) -> tuple[List[TextVector], dict]:
     """
     Create full text vectors.
     """
-
-    # Create a list of text vectors
     text_vectors = process_questions(all_questions, texts_cached_vectors)
-
-    # Add query
     if query:
         text_vectors = add_text_to_vec(query, texts_cached_vectors, text_vectors, False, True)
 
-    # Texts with no cached vector
     texts_not_cached = [x.text for x in text_vectors if not x.vector]
-
-    # Get vectors for all texts not cached
-    new_vectors_list: List = process_items_in_batches(texts_not_cached, vectorisation_function)
-
+    new_vectors_list: List = vectorisation_function(texts_not_cached).tolist()
 
-    # Create a dictionary with new vectors
     new_vectors_dict = {}
     for vector, text in zip(new_vectors_list, texts_not_cached):
         new_vectors_dict[text] = vector
 
-    # Add new vectors to all_texts
     for index, text_dict in enumerate(text_vectors):
         if not text_dict.vector:
             text_vectors[index].vector = new_vectors_list.pop(0)
@@ -170,30 +107,62 @@ def create_full_text_vectors(
     return text_vectors, new_vectors_dict
 
 
+# ADDED: Crosswalk Table Function
+def generate_crosswalk_table(matches, similarity_scores):
+    """
+    Generate a crosswalk table from matched item pairs and their similarity scores.
+
+    Args:
+        matches (list of tuple): List of matched item pairs as (item1, item2).
+        similarity_scores (list of float): List of similarity scores for each pair.
+
+    Returns:
+        pd.DataFrame: A DataFrame representing the crosswalk table.
+    """
+    if len(matches) != len(similarity_scores):
+        raise ValueError("The length of matches and similarity_scores must be the same.")
+
+    crosswalk_table = pd.DataFrame({
+        "Item 1": [pair[0] for pair in matches],
+        "Item 2": [pair[1] for pair in matches],
+        "Similarity Score": similarity_scores
+    })
+    return crosswalk_table
+
+
+# MODIFIED: match_instruments_with_catalogue_instruments with new parameters
 def match_instruments_with_catalogue_instruments(
-        instruments: List[Instrument],
-        catalogue_data: dict,
-        vectorisation_function: Callable,
-        texts_cached_vectors: dict[str, List[float]],
+    instruments: List[Instrument],
+    catalogue_data: dict,
+    vectorisation_function: Callable,
+    texts_cached_vectors: dict[str, List[float]],
+    within_instrument=True,  # ADDED
+    save_crosswalk=True  # ADDED
 ) -> tuple[List[Instrument], List[CatalogueInstrument]]:
     """
-    Match instruments with catalogue instruments.
-
-    :param instruments: The instruments.
-    :param catalogue_data: The catalogue data.
-    :param vectorisation_function: A function to vectorize a text.
-    :param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector).
-    :return: Index 0 in the tuple contains the list of instruments that now each contain the best instrument matches from the catalog.
-    Index 1 in the tuple contains a list of closest instrument matches from the catalog for all the instruments.
+    Match instruments with catalogue instruments, with optional within-instrument matching
+    and crosswalk table generation.
+
+    Args:
+        instruments (list): List of instruments to match.
+        catalogue_data (dict): Catalogue data for matching.
+        vectorisation_function (callable): Function to vectorize text data.
+        texts_cached_vectors (dict): Cached vectors for efficiency.
+        within_instrument (bool): Whether to allow within-instrument matches.  # ADDED
+        save_crosswalk (bool): Whether to save the crosswalk table to crosswalk_table.csv.  # ADDED
+
+    Returns:
+        tuple: Index 0 contains the instruments, each with its closest catalogue instrument
+        matches assigned; index 1 contains the closest catalogue instrument matches for all
+        instruments combined.
     """
+    matches = []
+    similarity_scores = []
 
-    # Gather all questions
     all_questions: List[str] = []
     for instrument in instruments:
         all_questions.extend([q.question_text for q in instrument.questions])
     all_questions = list(set(all_questions))
 
-    # Create text vectors for all questions in all the uploaded instruments
+    # Create text vectors
    all_instruments_text_vectors, _ = create_full_text_vectors(
         all_questions=all_questions,
         query=None,
@@ -201,18 +170,53 @@
         texts_cached_vectors=texts_cached_vectors,
     )
 
-    # For each instrument, find the best instrument matches for it in the catalogue
+    # Matching logic
     for instrument in instruments:
-        instrument.closest_catalogue_instrument_matches = (
-            match_questions_with_catalogue_instruments(
-                questions=instrument.questions,
-                catalogue_data=catalogue_data,
-                all_instruments_text_vectors=all_instruments_text_vectors,
-                questions_are_from_one_instrument=True,
-            )
+        for question in instrument.questions:
+            for catalogue_question in catalogue_data.get("questions", []):
+                if not within_instrument and question.instrument_id == catalogue_question.instrument_id:
+                    continue
+
+                # Vectorize question text if vector not already present
+                vector1 = texts_cached_vectors.get(question.question_text)
+                if vector1 is None:
+                    vector1 = vectorisation_function([question.question_text])[0]
+                    texts_cached_vectors[question.question_text] = vector1
+
+                vector2 = texts_cached_vectors.get(catalogue_question.question_text)
+                if vector2 is None:
+                    vector2 = vectorisation_function([catalogue_question.question_text])[0]
+                    texts_cached_vectors[catalogue_question.question_text] = vector2
+
+                # Calculate similarity
+                score = cosine_similarity(
+                    np.array([vector1]),
+                    np.array([vector2])
+                )[0][0]
+
+                if score > 0.8:
+                    matches.append((question.question_text, catalogue_question.question_text))
+                    similarity_scores.append(score)
+
+
+        # Assign matches to the instrument
+        instrument.closest_catalogue_instrument_matches = match_questions_with_catalogue_instruments(
+            questions=instrument.questions,
+            catalogue_data=catalogue_data,
+            all_instruments_text_vectors=all_instruments_text_vectors,
+            questions_are_from_one_instrument=True,
        )
 
-    # Gather all questions from all instruments and find the best instrument matches in the catalogue
+    # Save crosswalk table if required
+    if save_crosswalk:
+        crosswalk_table = generate_crosswalk_table(matches, similarity_scores)
+        crosswalk_table.to_csv("crosswalk_table.csv", index=False)
+        print("Crosswalk table saved as 'crosswalk_table.csv'")
+
+    # Find matches across all instruments
     all_instrument_questions: List[Question] = []
     for instrument in instruments:
         all_instrument_questions.extend(instrument.questions)
@@ -226,6 +230,8 @@
     return instruments, closest_catalogue_instrument_matches
 
 
+
+
 def match_questions_with_catalogue_instruments(
     questions: List[Question],
     catalogue_data: dict,
@@ -667,3 +673,5 @@ def match_instruments_with_function(
         query_similarity,
         new_vectors_dict
     )
+
+
diff --git a/src/test_batch_processing.py b/src/test_batch_processing.py
new file mode 100644
index 0000000..9178601
--- /dev/null
+++ b/src/test_batch_processing.py
@@ -0,0 +1,46 @@
+import unittest
+import numpy as np
+from harmony.matching.matcher import batch_process, vectorize_items_with_batching
+
+class TestBatchProcessing(unittest.TestCase):
+
+    def setUp(self):
+        self.vectorization_function = lambda texts: np.array([[len(text)] for text in texts])
+
+    def test_batch_process(self):
+        items = ["item1", "item2", "item3", "item4", "item5"]
+        batch_size = 2
+        batches = batch_process(items, batch_size)
+        expected_batches = [["item1", "item2"], ["item3", "item4"], ["item5"]]
+        self.assertEqual(batches, expected_batches)
+
+    def test_vectorize_items_with_batching(self):
+        items = ["short", "medium length", "a bit longer", "longest item in the list"]
+        batch_size = 2
+        vectors = vectorize_items_with_batching(items, self.vectorization_function, batch_size)
+        expected_vectors = np.array([[5], [13], [12], [24]])
+        np.testing.assert_array_equal(vectors, expected_vectors)
+
+    def test_edge_case_single_item(self):
+        items = ["single item"]
+        batch_size = 2
+        vectors = vectorize_items_with_batching(items, self.vectorization_function, batch_size)
+        expected_vectors = np.array([[11]])
+        np.testing.assert_array_equal(vectors, expected_vectors)
+
+    def test_edge_case_empty_list(self):
+        items = []
+        batch_size = 2
+        vectors = vectorize_items_with_batching(items, self.vectorization_function, batch_size)
+        expected_vectors = np.array([])
+        np.testing.assert_array_equal(vectors, expected_vectors)
+
+    def test_large_batch_size(self):
+        items = ["item1", "item2", "item3"]
+        batch_size = 10
+        batches = batch_process(items, batch_size)
+        expected_batches = [["item1", "item2", "item3"]]
+        self.assertEqual(batches, expected_batches)
+
+if __name__ == "__main__":
+    unittest.main()
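Reviewer note: the tests above import batch_process and vectorize_items_with_batching from harmony.matching.matcher, yet neither function is defined anywhere in this diff (the existing get_batch_size / process_items_in_batches helpers are removed from matcher.py above). The sketch below is a minimal, hypothetical implementation inferred only from the behaviour these tests assert; it is not taken from the Harmony codebase.

    import numpy as np


    def batch_process(items, batch_size):
        # Split items into consecutive batches of at most batch_size elements.
        return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


    def vectorize_items_with_batching(items, vectorization_function, batch_size):
        # Vectorise each batch separately and stack the per-batch outputs in order.
        if not items:
            return np.array([])
        batches = batch_process(items, batch_size)
        return np.vstack([vectorization_function(batch) for batch in batches])

np.vstack preserves the batch order, which is what test_vectorize_items_with_batching asserts, and the explicit empty-list branch covers test_edge_case_empty_list.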
diff --git a/src/test_task3_task4.py b/src/test_task3_task4.py
new file mode 100644
index 0000000..9e79940
--- /dev/null
+++ b/src/test_task3_task4.py
@@ -0,0 +1,96 @@
+from harmony.matching.matcher import match_instruments_with_catalogue_instruments
+from harmony.schemas.requests.text import Instrument, Question
+import numpy as np
+
+def dummy_vectorisation(texts):
+    vector_size = 12
+
+    # Predefined vectors for known texts
+    random_vectors = {
+        "What is your age?": [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+        "How old are you?": [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+        "What is your name?": [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+        "What is your favorite color?": [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+    }
+
+    # Assign a random vector for unknown texts
+    random_unknown_vector = lambda: np.random.rand(vector_size).tolist()
+
+    # Generate vectors
+    vectors = [
+        random_vectors.get(text, random_unknown_vector())
+        for text in texts
+    ]
+    return np.array(vectors)
+
+# Sample instruments for testing
+sample_instruments = [
+    Instrument(
+        instrument_id="1",
+        questions=[Question(question_text="What is your age?", instrument_id="1")],
+    ),
+    Instrument(
+        instrument_id="2",
+        questions=[Question(question_text="How old are you?", instrument_id="2")],
+    ),
+    Instrument(
+        instrument_id="3",
+        questions=[Question(question_text="What is your name?", instrument_id="3")],
+    ),
+    Instrument(
+        instrument_id="4",
+        questions=[Question(question_text="What is your favorite color?", instrument_id="4")],
+    ),
+]
+
+# Sample catalogue data
+sample_catalogue_data = {
+    "questions": [
+        Question(question_text="What is your age?", instrument_id="catalogue_1"),
+        Question(question_text="How old are you?", instrument_id="catalogue_2"),
+        Question(question_text="What is your name?", instrument_id="catalogue_3"),
+        Question(question_text="What color do you like?", instrument_id="catalogue_4"),
+    ],
+    "instrument_idx_to_question_idx": [[0], [1], [2], [3]],
+    "all_embeddings_concatenated": np.array([
+        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # Embedding for "What is your age?"
+        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # Embedding for "How old are you?"
+        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # Embedding for "What is your name?"
+        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],  # Embedding for "What color do you like?"
+    ]),
+    "all_instruments": [
+        {"instrument_name": "Catalogue Instrument 1", "metadata": {"source": "CATALOGUE", "url": "", "sweep_id": ""}},
+        {"instrument_name": "Catalogue Instrument 2", "metadata": {"source": "CATALOGUE", "url": "", "sweep_id": ""}},
+        {"instrument_name": "Catalogue Instrument 3", "metadata": {"source": "CATALOGUE", "url": "", "sweep_id": ""}},
+        {"instrument_name": "Catalogue Instrument 4", "metadata": {"source": "CATALOGUE", "url": "", "sweep_id": ""}},
+    ],
+    "all_questions": ["What is your age?", "How old are you?", "What is your name?", "What color do you like?"],
+}
+
+
+# Cached vectors for efficiency
+cached_vectors = {}
+
+# Test with within_instrument=True (Task 4 enabled)
+print("=== Test with Within Instrument Matches Enabled ===")
+match_instruments_with_catalogue_instruments(
+    instruments=sample_instruments,
+    catalogue_data=sample_catalogue_data,
+    vectorisation_function=dummy_vectorisation,
+    texts_cached_vectors=cached_vectors,
+    within_instrument=True,
+    save_crosswalk=True,
+)
+
+# Test with within_instrument=False (Task 4 disabled)
+print("\n=== Test with Within Instrument Matches Disabled ===")
+match_instruments_with_catalogue_instruments(
+    instruments=sample_instruments,
+    catalogue_data=sample_catalogue_data,
+    vectorisation_function=dummy_vectorisation,
+    texts_cached_vectors=cached_vectors,
+    within_instrument=False,
+    save_crosswalk=True,
+)
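For reference, a minimal illustrative sketch (not part of the change) of how the new generate_crosswalk_table helper reproduces the src/crosswalk_table.csv file added at the top of this diff; the pairs and scores mirror what the matching loop collects when identical question texts are compared (cosine similarity 1.0, above the 0.8 threshold).

    from harmony.matching.matcher import generate_crosswalk_table

    # Pairs and scores as the matching loop would collect them for identical texts.
    matches = [
        ("What is your age?", "What is your age?"),
        ("How old are you?", "How old are you?"),
        ("What is your name?", "What is your name?"),
    ]
    similarity_scores = [1.0, 1.0, 1.0]

    table = generate_crosswalk_table(matches, similarity_scores)
    table.to_csv("crosswalk_table.csv", index=False)
    # Writes:
    # Item 1,Item 2,Similarity Score
    # What is your age?,What is your age?,1.0
    # How old are you?,How old are you?,1.0
    # What is your name?,What is your name?,1.0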