Skip to content

Commit

Permalink
Create LLMSchemaExtractor and LLMPropertyExtractor classes. (#945)
Browse files Browse the repository at this point in the history
We previously had OpenAI-specific versions of these, even though they
weren't particularly tied to the LLM. This commit generalizes those
classes and tweaks the Bedrock LLM implementation to handle the basic
guidance prompts so that it works with the existing implementation.
  • Loading branch information
bsowell authored Oct 18, 2024
1 parent 4cebb73 commit e6e9877
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 21 deletions.
14 changes: 11 additions & 3 deletions lib/sycamore/sycamore/llms/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from PIL import Image

from sycamore.llms.llms import LLM
from sycamore.llms.prompts.default_prompts import SimplePrompt
from sycamore.utils.cache import Cache
from sycamore.utils.image_utils import base64_data

Expand Down Expand Up @@ -107,13 +108,20 @@ def _get_generate_kwargs(self, prompt_kwargs: Dict, llm_kwargs: Optional[Dict] =

if "prompt" in prompt_kwargs:
prompt = prompt_kwargs.get("prompt")
kwargs.update({"messages": [{"role": "user", "content": f"{prompt}"}]})

if isinstance(prompt, SimplePrompt):
kwargs.update({"messages": prompt.as_messages(prompt_kwargs)})
else:
kwargs.update({"messages": [{"role": "user", "content": f"{prompt}"}]})

elif "messages" in prompt_kwargs:
kwargs.update({"messages": prompt_kwargs["messages"]})
if self._model_name.startswith("anthropic."):
kwargs["messages"] = self._rewrite_system_messages(kwargs["messages"])
else:
raise ValueError("Either prompt or messages must be present in prompt_kwargs.")

if self._model_name.startswith("anthropic."):
kwargs["messages"] = self._rewrite_system_messages(kwargs["messages"])

return kwargs

def generate(self, *, prompt_kwargs: dict, llm_kwargs: Optional[dict] = None) -> str:
Expand Down
19 changes: 10 additions & 9 deletions lib/sycamore/sycamore/llms/prompts/default_prompts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from abc import ABC
from typing import Optional, Type
from typing import Any, Optional, Type

logger = logging.getLogger(__name__)

Expand All @@ -10,18 +10,19 @@ class SimplePrompt(ABC):
user: Optional[str] = None
var_name: str = "answer"

"""
Using this method assumes that the system and user prompts are populated with any placeholder values. Or the
caller is responsible for processing the messages after.
"""

def as_messages(self) -> list[dict]:
def as_messages(self, prompt_kwargs: Optional[dict[str, Any]] = None) -> list[dict]:
messages = []
if self.system is not None:
messages.append({"role": "system", "content": self.system})
system = self.system
if prompt_kwargs is not None:
system = self.system.format(**prompt_kwargs)
messages.append({"role": "system", "content": system})

if self.user is not None:
messages.append({"role": "user", "content": self.user})
user = self.user
if prompt_kwargs is not None:
user = self.user.format(**prompt_kwargs)
messages.append({"role": "user", "content": user})
return messages

def __eq__(self, other):
Expand Down
14 changes: 7 additions & 7 deletions lib/sycamore/sycamore/tests/unit/transforms/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sycamore.data import Document, Element
from sycamore.llms.llms import LLM, FakeLLM
from sycamore.transforms.extract_schema import ExtractBatchSchema, SchemaExtractor
from sycamore.transforms.extract_schema import OpenAISchemaExtractor, OpenAIPropertyExtractor
from sycamore.transforms.extract_schema import LLMSchemaExtractor, LLMPropertyExtractor
from sycamore.utils.ray_utils import check_serializable


Expand All @@ -25,7 +25,7 @@ def test_serializable(self, mocker):
check_serializable(t)

llm = FakeLLM()
o = OpenAISchemaExtractor("Foo", llm)
o = LLMSchemaExtractor("Foo", llm)
check_serializable(o)

llm = mocker.Mock(spec=LLM)
Expand All @@ -49,7 +49,7 @@ def test_extract_schema(self, mocker):
element2.text_representation = "".join(random.choices(string.ascii_letters, k=20))
doc.elements = [element1, element2]

schema_extractor = OpenAISchemaExtractor(
schema_extractor = LLMSchemaExtractor(
class_name, llm, num_of_elements=num_of_elements, max_num_properties=max_num_properties
)
doc = schema_extractor.extract_schema(doc)
Expand All @@ -75,7 +75,7 @@ def test_extract_batch_schema(self, mocker):
llm = mocker.Mock(spec=LLM)
generate = mocker.patch.object(llm, "generate")
generate.return_value = '```json {"accidentNumber": "string"}```'
schema_extractor = OpenAISchemaExtractor("AircraftIncident", llm)
schema_extractor = LLMSchemaExtractor("AircraftIncident", llm)

dicts = [
{"index": 1, "doc": "Members of a strike at Yale University."},
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_extract_properties(self, mocker):
"entity": {"weather": "sunny"},
}

property_extractor = OpenAIPropertyExtractor(llm)
property_extractor = LLMPropertyExtractor(llm)
doc = property_extractor.extract_properties(doc)

assert doc.properties["entity"]["weather"] == "sunny"
Expand All @@ -138,7 +138,7 @@ def test_extract_properties_explicit_json(self, mocker):
"_schema_class": "AircraftIncident",
}

property_extractor = OpenAIPropertyExtractor(llm)
property_extractor = LLMPropertyExtractor(llm)
doc = property_extractor.extract_properties(doc)

assert doc.properties["entity"]["accidentNumber"] == "FTW95FA129"
Expand All @@ -155,7 +155,7 @@ def test_extract_properties_fixed_json(self, mocker):
element2.text_representation = "".join(random.choices(string.ascii_letters, k=20))
doc.elements = [element1, element2]

property_extractor = OpenAIPropertyExtractor(
property_extractor = LLMPropertyExtractor(
llm, schema_name="AircraftIncident", schema={"accidentNumber": "string"}
)
doc = property_extractor.extract_properties(doc)
Expand Down
28 changes: 26 additions & 2 deletions lib/sycamore/sycamore/transforms/extract_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def extract_properties(self, document: Document) -> Document:
pass


class OpenAISchemaExtractor(SchemaExtractor):
class LLMSchemaExtractor(SchemaExtractor):
"""
LLMSchemaExtractor uses a language model (LLM) to extract a schema from documents,
given a suggested entity type to be extracted.
Expand Down Expand Up @@ -110,7 +110,19 @@ def _handle_zero_shot_prompting(self, document: Document) -> Any:
return entities


class OpenAIPropertyExtractor(PropertyExtractor):
class OpenAISchemaExtractor(LLMSchemaExtractor):
    """Alias for :class:`LLMSchemaExtractor` retained for backward compatibility.

    .. deprecated:: 0.1.25
       Use :class:`LLMSchemaExtractor` instead.
    """

    def __init__(self, *args, **kwargs):
        # The docstring alone is easy to miss; emit a runtime warning so callers
        # actually notice the deprecation. stacklevel=2 points at the caller.
        import warnings

        warnings.warn(
            "OpenAISchemaExtractor is deprecated; use LLMSchemaExtractor instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)


class LLMPropertyExtractor(PropertyExtractor):
"""
LLMPropertyExtractor uses a language model (LLM) to extract actual property values once
a schema has been detected or provided.
Expand Down Expand Up @@ -219,6 +231,18 @@ def __init__(self, child: Node, schema_extractor: SchemaExtractor, **resource_ar
super().__init__(child, f=schema_extractor.extract_schema, **resource_args)


class OpenAIPropertyExtractor(LLMPropertyExtractor):
    """Alias for :class:`LLMPropertyExtractor` retained for backward compatibility.

    .. deprecated:: 0.1.25
       Use :class:`LLMPropertyExtractor` instead.
    """

    def __init__(self, *args, **kwargs):
        # The docstring alone is easy to miss; emit a runtime warning so callers
        # actually notice the deprecation. stacklevel=2 points at the caller.
        import warnings

        warnings.warn(
            "OpenAIPropertyExtractor is deprecated; use LLMPropertyExtractor instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)


class ExtractBatchSchema(Map):
"""
ExtractBatchSchema is a transformation class for extracting a schema from a dataset using an SchemaExtractor.
Expand Down

0 comments on commit e6e9877

Please sign in to comment.