improving tests
davidsbatista committed Oct 30, 2024
1 parent b655014 commit a486008
Showing 2 changed files with 43 additions and 4 deletions.
@@ -158,8 +158,7 @@ def __init__( # pylint: disable=R0917
         if self.prompt_variable not in self.prompt:
             raise ValueError(f"Prompt variable '{self.prompt_variable}' must be in the prompt.")
         self.splitter = DocumentSplitter(split_by="page", split_length=1)
-        if page_range:
-            self.expanded_range = expand_page_range(page_range)
+        self.expanded_range = expand_page_range(page_range) if page_range else None

     @staticmethod
     def _init_generator(
@@ -278,7 +277,6 @@ def _extract_metadata_and_update_doc(self, document: Document, errors: Dict[str,
         else:
             errors[document.id] = llm_answer

-
     @component.output_types(documents=List[Document], errors=Dict[str, Any])
     def run(self, documents: List[Document], page_range: Optional[List[Union[str, int]]] = None):
         """
@@ -317,6 +315,10 @@ def run(self, documents: List[Document], page_range: Optional[List[Union[str, in
             if page_range:
                 self.expanded_range = expand_page_range(page_range)

+            if not self.expanded_range:
+                msg = f"Page range {self.expanded_range} invalid"
+                raise ValueError(msg)
+
             splitter = DocumentSplitter(split_by="page", split_length=1)
             pages = splitter.run(documents=[document])
             content = [p.content + "\n" for idx, p in enumerate(pages["documents"]) if idx in self.expanded_range]
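For context, a minimal standalone sketch of the page-range expansion this change relies on. The function name expand_page_range_sketch and its parsing rules are assumptions inferred from the new test (page_range=['1-5'] expanding to [1, 2, 3, 4, 5]) and from the guard added in run(); this is not the library's actual expand_page_range implementation.

from typing import List, Union

def expand_page_range_sketch(page_range: List[Union[str, int]]) -> List[int]:
    """Expand entries like '1-5' into inclusive page lists; plain ints and digit strings pass through."""
    expanded: List[int] = []
    for entry in page_range:
        text = str(entry)
        if "-" in text:
            start, end = (int(part) for part in text.split("-"))
            expanded.extend(range(start, end + 1))  # inclusive, matching the test's [1, 2, 3, 4, 5]
        else:
            expanded.append(int(text))
    if not expanded:
        # mirrors the new guard in run(): an empty expansion is treated as an invalid page range
        raise ValueError(f"Page range {page_range} invalid")
    return expanded

assert expand_page_range_sketch(['1-5']) == [1, 2, 3, 4, 5]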
test/components/extractors/test_llm_metadata_extractor.py: 39 changes (38 additions & 1 deletion)
@@ -36,7 +36,9 @@ def test_init_with_parameters(self, monkeypatch):
                 'model': 'gpt-3.5-turbo',
                 'generation_kwargs': {"temperature": 0.5}
             },
-            prompt_variable="test")
+            prompt_variable="test",
+            page_range=['1-5']
+        )
         assert isinstance(extractor.builder, PromptBuilder)
         assert extractor.expected_keys == ["key1", "key2"]
         assert extractor.raise_on_failure is True
@@ -45,6 +47,17 @@ def test_init_with_parameters(self, monkeypatch):
             'model': 'gpt-3.5-turbo',
             'generation_kwargs': {"temperature": 0.5}
         }
+        assert extractor.expanded_range == [1, 2, 3, 4, 5]
+
+    def test_init_missing_prompt_variable(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        with pytest.raises(ValueError):
+            _ = LLMMetadataExtractor(
+                prompt="prompt {{test}}",
+                expected_keys=["key1", "key2"],
+                generator_api=LLMProvider.OPENAI,
+                prompt_variable="test2"
+            )

     def test_to_dict(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
@@ -103,6 +116,30 @@ def test_from_dict(self, monkeypatch):
         assert extractor.prompt == "some prompt that was used with the LLM {{test}}"
         assert extractor.generator_api == LLMProvider.OPENAI

+    def test_output_invalid_json_raise_on_failure_true(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        extractor = LLMMetadataExtractor(
+            prompt="prompt {{test}}",
+            expected_keys=["key1", "key2"],
+            generator_api=LLMProvider.OPENAI,
+            prompt_variable="test",
+            raise_on_failure=True
+        )
+        with pytest.raises(ValueError):
+            extractor.is_valid_json_and_has_expected_keys(expected=["entities"], received="""{"json": "output"}""")
+
+    def test_output_valid_json_not_expected_keys(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        extractor = LLMMetadataExtractor(
+            prompt="prompt {{test}}",
+            expected_keys=["key1", "key2"],
+            generator_api=LLMProvider.OPENAI,
+            prompt_variable="test",
+            raise_on_failure=True
+        )
+        with pytest.raises(ValueError):
+            extractor.is_valid_json_and_has_expected_keys(expected=["entities"], received="{'json': 'output'}")
+
     @pytest.mark.integration
     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
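The two new is_valid_json_and_has_expected_keys tests pin down the same contract from both sides: the LLM reply must parse as JSON and contain every expected key, and with raise_on_failure=True a violation raises ValueError. Below is a standalone sketch of that contract; check_llm_reply is a hypothetical helper, not the component's method, whose internals are not shown in this diff.

import json
from typing import List

def check_llm_reply(expected: List[str], received: str) -> None:
    try:
        parsed = json.loads(received)  # "{'json': 'output'}" fails here: single quotes are not valid JSON
    except json.JSONDecodeError as exc:
        raise ValueError(f"LLM reply is not valid JSON: {received}") from exc
    missing = [key for key in expected if key not in parsed]
    if missing:
        # '{"json": "output"}' parses fine but lacks the expected "entities" key
        raise ValueError(f"LLM reply is missing expected keys {missing}: {received}")

The unit tests added here set OPENAI_API_KEY via monkeypatch, so they run without a real key; the @pytest.mark.integration tests are skipped when the key is unset and can also be deselected explicitly with pytest -m "not integration".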
