Skip to content

Commit

Permalink
Enhance JSON extraction and validation in Predictor; update version to 0.3.4
Browse files Browse the repository at this point in the history
  • Loading branch information
LMMasters committed Feb 6, 2025
1 parent fa4faab commit 6cffe49
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
17 changes: 15 additions & 2 deletions llm_extractinator/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,11 @@ def validate_and_fix_results(
Args:
results (List[Dict[str, Any]]): The list of results to be validated and potentially fixed.
parser_model (BaseModel): The Pydantic model for validation.
max_attempts (int, optional): Maximum number of attempts to fix each result. Defaults to 3.
Returns:
List[Dict[str, Any]]: The list of results with all items validated or attempted to be fixed.
List[Dict[str, Any]]: The validated and fixed results.
"""

def extract_json_from_text(text: str) -> Optional[dict]:
Expand Down Expand Up @@ -423,8 +424,9 @@ def handle_failure(annotation):
result.setdefault("status", "pending")
result["retry_count"] = 0

# If format is not json, attempt to extract json from the output
# If format is not JSON, attempt to extract JSON from the output
if self.format != "json" and isinstance(result, dict):
extracted_json = None
for key, value in result.items():
if isinstance(value, str): # Only process string values
extracted_json = extract_json_from_text(value)
Expand All @@ -434,6 +436,10 @@ def handle_failure(annotation):
) # Merge extracted JSON into result
break # Stop after first successful extraction

# If no JSON was extracted, mark for retrying
if not extracted_json:
result["status"] = "invalid"

attempt = 0
while attempt < max_attempts:
# Collect indices of invalid results
Expand All @@ -453,6 +459,7 @@ def handle_failure(annotation):
invalid_results = [results[i] for i in invalid_indices]
index_mapping = {i: results[i]["original_index"] for i in invalid_indices}
fixing_inputs = [{"completion": str(result)} for result in invalid_results]

print(
f"Retry {attempt + 1}: Attempting to fix {len(invalid_indices)} invalid results..."
)
Expand All @@ -468,6 +475,12 @@ def handle_failure(annotation):
# Update the results with fixed outputs
for idx, fixed_result in zip(invalid_indices, fixed_results):
original_index = index_mapping[idx]

if self.format != "json":
extracted_json = extract_json_from_text(str(fixed_result))
if extracted_json:
fixed_result = extracted_json

try:
parser_model.model_validate(fixed_result)
fixed_result["retry_count"] = results[idx]["retry_count"] + 1
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = llm_extractinator
version = 0.3.2
version = 0.3.4
description = A framework that enables efficient extraction of structured data from unstructured text using large language models (LLMs).
long_description = file: README.md
long_description_content_type = text/markdown
Expand Down

0 comments on commit 6cffe49

Please sign in to comment.