diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py
index d57f1671..026f7d4a 100644
--- a/docling/datamodel/pipeline_options.py
+++ b/docling/datamodel/pipeline_options.py
@@ -1,8 +1,8 @@
 from enum import Enum
 from pathlib import Path
-from typing import List, Literal, Optional, Union
+from typing import Annotated, Any, Dict, List, Literal, Optional, Union
 
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import AnyUrl, BaseModel, ConfigDict, Field
 
 
 class TableFormerMode(str, Enum):
@@ -61,6 +61,56 @@ class TesseractOcrOptions(OcrOptions):
     )
 
 
+class PicDescBaseOptions(BaseModel):
+    """Options shared by every picture description backend."""
+
+    kind: str
+    batch_size: int = 8
+    scale: float = 2
+
+    bitmap_area_threshold: float = (
+        0.2  # percentage of the area for a bitmap to be processed with the models
+    )
+
+
+class PicDescApiOptions(PicDescBaseOptions):
+    """Options for describing pictures through a chat-completion style HTTP API."""
+
+    kind: Literal["api"] = "api"
+
+    # `AnyUrl("")` fails URL validation when this module is imported, so the
+    # default has to be a well-formed URL; override it for real deployments.
+    url: AnyUrl = AnyUrl("http://localhost:8000/v1/chat/completions")
+    headers: Dict[str, str] = {}
+    params: Dict[str, Any] = {}
+    timeout: float = 20
+
+    llm_prompt: str = ""
+    provenance: str = ""
+
+
+class PicDescVllmOptions(PicDescBaseOptions):
+    """Options for describing pictures with a model run through vLLM."""
+
+    kind: Literal["vllm"] = "vllm"
+
+    # For more example parameters see https://docs.vllm.ai/en/latest/getting_started/examples/offline_inference_vision_language.html
+
+    # Parameters for LLaVA-1.6/LLaVA-NeXT
+    llm_name: str = "llava-hf/llava-v1.6-mistral-7b-hf"
+    # NOTE(review): vLLM's LLaVA examples put an `<image>` placeholder in the
+    # prompt; this template has none - confirm it was not lost.
+    llm_prompt: str = "[INST] \nDescribe the image in details. [/INST]"
+    llm_extra: Dict[str, Any] = dict(max_model_len=8192)
+
+    # Parameters for Phi-3-Vision
+    # llm_name: str = "microsoft/Phi-3-vision-128k-instruct"
+    # llm_prompt: str = "<|user|>\n<|image_1|>\nDescribe the image in details.<|end|>\n<|assistant|>\n"
+    # llm_extra: Dict[str, Any] = dict(max_num_seqs=5, trust_remote_code=True)
+
+    sampling_params: Dict[str, Any] = dict(max_tokens=64, seed=42)
+
+
 class PipelineOptions(BaseModel):
     create_legacy_output: bool = (
         True  # This defautl will be set to False on a future version of docling
@@ -71,11 +121,15 @@ class PdfPipelineOptions(PipelineOptions):
     artifacts_path: Optional[Union[Path, str]] = None
     do_table_structure: bool = True  # True: perform table structure extraction
     do_ocr: bool = True  # True: perform OCR, replace programmatic PDF text
+    do_picture_description: bool = False  # True: annotate pictures with a description
 
     table_structure_options: TableStructureOptions = TableStructureOptions()
     ocr_options: Union[EasyOcrOptions, TesseractCliOcrOptions, TesseractOcrOptions] = (
         Field(EasyOcrOptions(), discriminator="kind")
     )
+    picture_description_options: Annotated[
+        Union[PicDescApiOptions, PicDescVllmOptions], Field(discriminator="kind")
+    ] = PicDescApiOptions()  # TODO: needs defaults or optional
 
     images_scale: float = 1.0
     generate_page_images: bool = False
diff --git a/docling/models/pic_description_api_model.py b/docling/models/pic_description_api_model.py
new file mode 100644
index 00000000..465b6516
--- /dev/null
+++ b/docling/models/pic_description_api_model.py
@@ -0,0 +1,117 @@
+import base64
+import io
+import logging
+from typing import List, Optional
+
+import httpx
+from docling_core.types.doc import PictureItem
+from docling_core.types.doc.document import (  # TODO: move import to docling_core.types.doc
+    PictureDescriptionData,
+)
+from pydantic import BaseModel, ConfigDict
+
+from docling.datamodel.pipeline_options import PicDescApiOptions
+from docling.models.pic_description_base_model import PictureDescriptionBaseModel
+
+_log = logging.getLogger(__name__)
+
+
+class ChatMessage(BaseModel):
+    """A single chat message as found in the API response."""
+
+    role: str
+    content: str
+
+
+class ResponseChoice(BaseModel):
+    """One entry of the `choices` list in the API response."""
+
+    index: int
+    message: ChatMessage
+    finish_reason: str
+
+
+class ResponseUsage(BaseModel):
+    """Token accounting reported in the API response."""
+
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+class ApiResponse(BaseModel):
+    """Response body that the description endpoint is expected to return."""
+
+    model_config = ConfigDict(
+        protected_namespaces=(),
+    )
+
+    id: str
+    model: Optional[str] = None  # returned by openai
+    choices: List[ResponseChoice]
+    created: int
+    usage: ResponseUsage
+
+
+class PictureDescriptionApiModel(PictureDescriptionBaseModel):
+    """Describe pictures by posting them to the HTTP endpoint in the options."""
+
+    def __init__(self, enabled: bool, options: PicDescApiOptions):
+        super().__init__(enabled=enabled, options=options)
+        self.options: PicDescApiOptions
+
+    def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
+        """Send the picture as a base64 PNG data URL and return the generated text.
+
+        Raises:
+            httpx.HTTPStatusError: if the endpoint answers with a non-success status.
+            RuntimeError: if the response contains no choices.
+        """
+        assert picture.image is not None
+
+        img_io = io.BytesIO()
+        picture.image.pil_image.save(img_io, "PNG")
+
+        image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": self.options.llm_prompt,
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
+                    },
+                ],
+            }
+        ]
+
+        payload = {
+            "messages": messages,
+            **self.options.params,
+        }
+
+        r = httpx.post(
+            str(self.options.url),
+            headers=self.options.headers,
+            json=payload,
+            timeout=self.options.timeout,
+        )
+        if not r.is_success:
+            _log.error(f"Error calling the API. Reponse was {r.text}")
+        r.raise_for_status()
+
+        api_resp = ApiResponse.model_validate_json(r.text)
+        if not api_resp.choices:
+            # An empty list would otherwise surface as an opaque IndexError.
+            raise RuntimeError("The API response did not contain any choices.")
+        generated_text = api_resp.choices[0].message.content.strip()
+
+        return PictureDescriptionData(
+            provenance=self.options.provenance,
+            text=generated_text,
+        )
diff --git a/docling/models/pic_description_base_model.py b/docling/models/pic_description_base_model.py
new file mode 100644
index 00000000..76949294
--- /dev/null
+++ b/docling/models/pic_description_base_model.py
@@ -0,0 +1,55 @@
+import logging
+from pathlib import Path
+from typing import Any, Iterable
+
+from docling_core.types.doc import (
+    DoclingDocument,
+    NodeItem,
+    PictureClassificationClass,
+    PictureItem,
+)
+from docling_core.types.doc.document import (  # TODO: move import to docling_core.types.doc
+    PictureDescriptionData,
+)
+
+from docling.datamodel.pipeline_options import PicDescBaseOptions
+from docling.models.base_model import BaseEnrichmentModel
+
+
+class PictureDescriptionBaseModel(BaseEnrichmentModel):
+    """Common enrichment logic for models that attach a description to pictures.
+
+    Subclasses provide `_annotate_image`.
+    """
+
+    def __init__(self, enabled: bool, options: PicDescBaseOptions):
+        self.enabled = enabled
+        self.options = options
+        self.provenance = "TODO"
+
+    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
+        """Return True when the model is enabled and `element` is a picture."""
+        # TODO: once the image classifier is active, we can differentiate among image types
+        return self.enabled and isinstance(element, PictureItem)
+
+    def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
+        """Produce the description for a single picture; implemented by subclasses."""
+        # `NotImplemented` is a comparison sentinel, not an exception: raising it
+        # fails with a TypeError instead of signalling a missing override.
+        raise NotImplementedError
+
+    def __call__(
+        self, doc: DoclingDocument, element_batch: Iterable[NodeItem]
+    ) -> Iterable[Any]:
+        """Annotate every picture of the batch in place and yield it."""
+        if not self.enabled:
+            return
+
+        for element in element_batch:
+            assert isinstance(element, PictureItem)
+            assert element.image is not None
+
+            annotation = self._annotate_image(element)
+            element.annotations.append(annotation)
+
+            yield element
diff --git a/docling/models/pic_description_vllm_model.py b/docling/models/pic_description_vllm_model.py
new file mode 100644
index 00000000..b461256c
--- /dev/null
+++ b/docling/models/pic_description_vllm_model.py
@@ -0,0 +1,65 @@
+import hashlib
+import json
+from typing import List
+
+from docling_core.types.doc import PictureItem
+from docling_core.types.doc.document import (  # TODO: move import to docling_core.types.doc
+    PictureDescriptionData,
+)
+
+from docling.datamodel.pipeline_options import PicDescVllmOptions
+from docling.models.pic_description_base_model import PictureDescriptionBaseModel
+
+
+class PictureDescriptionVllmModel(PictureDescriptionBaseModel):
+    """Describe pictures with a model loaded through vLLM."""
+
+    def __init__(self, enabled: bool, options: PicDescVllmOptions):
+        super().__init__(enabled=enabled, options=options)
+        self.options: PicDescVllmOptions
+
+        if self.enabled:
+            # `raise NotImplemented` would fail with a TypeError; the exception
+            # class is `NotImplementedError`. While this guard is in place the
+            # setup below is never reached.
+            raise NotImplementedError
+
+        if self.enabled:
+            try:
+                from vllm import LLM, SamplingParams  # type: ignore
+            except ImportError:
+                raise ImportError(
+                    "VLLM is not installed. Please install Docling with the required extras `pip install docling[vllm]`."
+                )
+
+            self.sampling_params = SamplingParams(**self.options.sampling_params)  # type: ignore
+            self.llm = LLM(model=self.options.llm_name, **self.options.llm_extra)  # type: ignore
+
+            # Generate a stable hash from the extra parameters
+            def create_hash(t):
+                return hashlib.sha256(t.encode("utf-8")).hexdigest()
+
+            params_hash = create_hash(
+                json.dumps(self.options.llm_extra, sort_keys=True)
+                + json.dumps(self.options.sampling_params, sort_keys=True)
+            )
+            self.provenance = f"{self.options.llm_name}-{params_hash[:8]}"
+
+    def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
+        """Generate a description for the picture with the loaded vLLM model."""
+        assert picture.image is not None
+
+        from vllm import RequestOutput
+
+        inputs = [
+            {
+                "prompt": self.options.llm_prompt,
+                "multi_modal_data": {"image": picture.image.pil_image},
+            }
+        ]
+        outputs: List[RequestOutput] = self.llm.generate(  # type: ignore
+            inputs, sampling_params=self.sampling_params  # type: ignore
+        )
+
+        generated_text = outputs[0].outputs[0].text
+        return PictureDescriptionData(provenance=self.provenance, text=generated_text)
diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py
index 65803d4f..0fe7a3c9 100644
--- a/docling/pipeline/standard_pdf_pipeline.py
+++ b/docling/pipeline/standard_pdf_pipeline.py
@@ -11,6 +11,8 @@
 from docling.datamodel.pipeline_options import (
     EasyOcrOptions,
     PdfPipelineOptions,
+    PicDescApiOptions,
+    PicDescVllmOptions,
     TesseractCliOcrOptions,
     TesseractOcrOptions,
 )
@@ -23,6 +25,9 @@
     PagePreprocessingModel,
     PagePreprocessingOptions,
 )
+from docling.models.pic_description_api_model import PictureDescriptionApiModel
+from docling.models.pic_description_base_model import PictureDescriptionBaseModel
+from docling.models.pic_description_vllm_model import PictureDescriptionVllmModel
 from docling.models.table_structure_model import TableStructureModel
 from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
 from docling.models.tesseract_ocr_model import TesseractOcrModel
@@ -83,8 +88,15 @@ def __init__(self, pipeline_options: PdfPipelineOptions):
             PageAssembleModel(options=PageAssembleOptions(keep_images=keep_images)),
         ]
 
+        # Picture description model
+        if (pic_desc_model := self.get_pic_description_model()) is None:
+            raise RuntimeError(
+                f"The specified picture description kind is not supported: {pipeline_options.picture_description_options.kind}."
+            )
+
         self.enrichment_pipe = [
             # Other models working on `NodeItem` elements in the DoclingDocument
+            pic_desc_model,
         ]
 
     @staticmethod
@@ -120,6 +132,24 @@ def get_ocr_model(self) -> Optional[BaseOcrModel]:
             )
         return None
 
+    def get_pic_description_model(self) -> Optional[PictureDescriptionBaseModel]:
+        """Build the model matching the configured options, or None if unsupported."""
+        if isinstance(
+            self.pipeline_options.picture_description_options, PicDescApiOptions
+        ):
+            return PictureDescriptionApiModel(
+                enabled=self.pipeline_options.do_picture_description,
+                options=self.pipeline_options.picture_description_options,
+            )
+        elif isinstance(
+            self.pipeline_options.picture_description_options, PicDescVllmOptions
+        ):
+            return PictureDescriptionVllmModel(
+                enabled=self.pipeline_options.do_picture_description,
+                options=self.pipeline_options.picture_description_options,
+            )
+        return None
+
     def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
         with TimeRecorder(conv_res, "page_init"):
             page._backend = conv_res.input._backend.load_page(page.page_no)  # type: ignore