Skip to content

Commit

Permalink
Merge pull request #504 from VikParuchuri/dev
Browse files Browse the repository at this point in the history
Fix parsing logic for table cells
  • Loading branch information
VikParuchuri authored Jan 24, 2025
2 parents 8a2a845 + d3c43d6 commit 9ed906d
Show file tree
Hide file tree
Showing 15 changed files with 232 additions and 111 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/scripts.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,6 @@ jobs:
- name: Test single script
run: poetry run marker_single benchmark_data/pdfs/switch_trans.pdf --page_range 0
- name: Test convert script
run: poetry run marker benchmark_data/pdfs --max_files 1 --workers 1 --page_range 0
run: poetry run marker benchmark_data/pdfs --max_files 1 --workers 1 --page_range 0
- name: Test convert script multiple workers
run: poetry run marker benchmark_data/pdfs --max_files 2 --workers 2 --page_range 0-5
3 changes: 2 additions & 1 deletion marker/config/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def parse_args(self, ctx, args):
["--" + attr],
type=info['type'],
help=" ".join(info['metadata']) + f" (Applies to: {', '.join(info['classes'])})",
default=info['default'],
default=None, # This is important, or it sets all the default keys again in config
is_flag=info['is_flag'],
)
)
Expand Down Expand Up @@ -71,6 +71,7 @@ def parse_args(self, ctx, args):
type=attr_type,
help=" ".join(metadata),
is_flag=is_flag,
default=None # This is important, or it sets all the default keys again in config
)
)

Expand Down
66 changes: 65 additions & 1 deletion marker/processors/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,27 @@ def __call__(self, document: Document):
continue

block = document.get_block(equation_d["block_id"])
block.latex = prediction
block.html = self.parse_latex_to_html(prediction)

def parse_latex_to_html(self, latex: str):
    """Convert ``$``/``$$``-delimited LaTeX text into ``<math>`` HTML.

    The input is HTML-escaped first so that characters such as ``<`` and
    ``&`` inside an equation (e.g. ``a < b``) cannot be mistaken for markup
    once the result is embedded in a page. Escaping happens before
    ``parse_latex`` runs, so the ``<br>`` tags that ``parse_latex`` inserts
    for newlines stay intact.

    Args:
        latex: Text containing LaTeX fenced by ``$`` (inline) or ``$$``
            (block) delimiters, possibly mixed with plain text.

    Returns:
        An HTML string in which block segments are wrapped in
        ``<math display="block">``, inline segments in
        ``<math display="inline">``, and plain text is emitted padded with
        single spaces. Leading and trailing whitespace is stripped.
    """
    # Local import keeps this method self-contained; ``html`` is stdlib.
    import html

    escaped = html.escape(latex, quote=False)
    try:
        segments = self.parse_latex(escaped)
    except ValueError:
        # If we have mismatched delimiters, we'll treat it as a single block
        # Strip the $'s from the latex
        segments = [
            {"class": "block", "content": escaped.replace("$", "")}
        ]

    parts = []
    for segment in segments:
        content = segment["content"]
        if segment["class"] == "block":
            parts.append(f'<math display="block">{content}</math>')
        elif segment["class"] == "inline":
            parts.append(f'<math display="inline">{content}</math>')
        else:
            # Pad plain text so it does not fuse with neighbouring math tags.
            parts.append(f" {content} ")
    return "".join(parts).strip()

def get_batch_size(self):
if self.texify_batch_size is not None:
Expand Down Expand Up @@ -110,3 +130,47 @@ def get_total_texify_tokens(self, text):
tokenizer = self.texify_model.processor.tokenizer
tokens = tokenizer(text)
return len(tokens["input_ids"])


@staticmethod
def parse_latex(text: str):
if text.count("$") % 2 != 0:
raise ValueError("Mismatched delimiters in LaTeX")

DELIMITERS = [
("$$", "block"),
("$", "inline")
]

text = text.replace("\n", "<br>") # we can't handle \n's inside <p> properly if we don't do this

i = 0
stack = []
result = []
buffer = ""

while i < len(text):
for delim, class_name in DELIMITERS:
if text[i:].startswith(delim):
if stack and stack[-1] == delim: # Closing
stack.pop()
result.append({"class": class_name, "content": buffer})
buffer = ""
i += len(delim)
break
elif not stack: # Opening
if buffer:
result.append({"class": "text", "content": buffer})
stack.append(delim)
buffer = ""
i += len(delim)
break
else:
raise ValueError(f"Nested {class_name} delimiters not supported")
else: # No delimiter match
buffer += text[i]
i += 1

if buffer:
result.append({"class": "text", "content": buffer})
return result
59 changes: 34 additions & 25 deletions marker/processors/llm/llm_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,68 +15,77 @@ class LLMEquationProcessor(BaseLLMProcessor):
min_equation_height: Annotated[
float,
"The minimum ratio between equation height and page height to consider for processing.",
] = 0.1
] = 0.08
equation_image_expansion_ratio: Annotated[
float,
"The ratio to expand the image by when cropping.",
] = 0.05 # Equations sometimes get bboxes that are too tight
equation_latex_prompt: Annotated[
str,
"The prompt to use for generating LaTeX from equations.",
"Default is a string containing the Gemini prompt."
] = """You're an expert mathematician who is good at writing LaTeX code for equations'.
You will receive an image of a math block that may contain one or more equations. Your job is to write the LaTeX code for the equation, along with markdown for any other text.
] = """You're an expert mathematician who is good at writing LaTeX code and html for equations.
You'll receive an image of a math block that may contain one or more equations. Your job is to write html that represents the content of the image, with the equations in LaTeX format, and fenced by delimiters.
Some guidelines:
- Keep the LaTeX code simple and concise.
- Make it KaTeX compatible.
- Use $$ as a block equation delimiter and $ for inline equations. Block equations should also be on their own line. Do not use any other delimiters.
- You can include text in between equation blocks as needed. Try to put long text segments into plain text and not inside the equations.
- Output valid html, where all the equations can render properly.
- Use <math display="block"> as a block equation delimiter and <math> for inline equations.
- Keep the LaTeX code inside the math tags simple, concise, and KaTeX compatible.
- Enclose all equations in the correct math tags. Use multiple math tags inside the html to represent multiple equations.
- Only use the html tags math, i, b, p, and br.
- Make sure to include all the equations in the image in the html output.
**Instructions:**
1. Carefully examine the provided image.
2. Analyze the existing markdown, which may include LaTeX code.
3. If the markdown and LaTeX are correct, write "No corrections needed."
4. If the markdown and LaTeX are incorrect, generate the corrected markdown and LaTeX.
5. Output only the corrected text or "No corrections needed."
2. Analyze the existing html, which may include LaTeX code.
3. If the html and LaTeX are correct, write "No corrections needed."
4. If the html and LaTeX are incorrect, generate the corrected html.
5. Output only the corrected html or "No corrections needed."
**Example:**
Input:
```markdown
```html
Equation 1:
$$x^2 + y^2 = z2$$
<math display="block">x2 + y2 = z2</math>
Equation 2:
<math display="block">\frac{ab \cdot x^5 + x^2 + 2 \cdot x + 123}{t}</math>
```
Output:
```markdown
Equation 1:
$$x^2 + y^2 = z^2$$
```html
<p>Equation 1:</p>
<math display="block">x^{2} + y^{2} = z^{2}</math>
<p>Equation 2:</p>
<math display="block">\frac{ab \cdot x^{5} + x^{2} + 2 \cdot x + 123}{t}</math>
```
**Input:**
```markdown
```html
{equation}
```
"""

def process_rewriting(self, document: Document, page: PageGroup, block: Equation):
text = block.latex if block.latex else block.raw_text(document)
text = block.html if block.html else block.raw_text(document)
prompt = self.equation_latex_prompt.replace("{equation}", text)

image = self.extract_image(document, block)
response_schema = content.Schema(
type=content.Type.OBJECT,
enum=[],
required=["markdown_equation"],
required=["html_equation"],
properties={
"markdown_equation": content.Schema(
"html_equation": content.Schema(
type=content.Type.STRING
)
},
)

response = self.model.generate_response(prompt, image, block, response_schema)

if not response or "markdown_equation" not in response:
if not response or "html_equation" not in response:
block.update_metadata(llm_error_count=1)
return

markdown_equation = response["markdown_equation"]
if len(markdown_equation) < len(text) * .5:
html_equation = response["html_equation"]
if len(html_equation) < len(text) * .5:
block.update_metadata(llm_error_count=1)
return

block.latex = markdown_equation
block.html = html_equation
6 changes: 3 additions & 3 deletions marker/processors/llm/llm_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class LLMTableProcessor(BaseLLMProcessor):
- Make sure to reproduce the original values as faithfully as possible.
- If you see any math in a table cell, fence it with the <math display="inline"> tag. Block math should be fenced with <math display="block">.
- Replace any images with a description, like "Image: [description]".
- Only use the tags th, td, tr, span, i, b, math, and table. Only use the attributes display, style, colspan, and rowspan if necessary.
- Only use the tags th, td, tr, br, span, i, b, math, and table. Only use the attributes display, style, colspan, and rowspan if necessary. You can use br to break up text lines in cells.
**Instructions:**
1. Carefully examine the provided text block image.
Expand Down Expand Up @@ -172,11 +172,11 @@ def rewrite_single_chunk(self, page: PageGroup, block: Block, block_html: str, c
return parsed_cells

@staticmethod
def get_cell_text(element, keep_tags=('br',)):
def get_cell_text(element, keep_tags=('br','i', 'b', 'span', 'math')) -> str:
for tag in element.find_all(True):
if tag.name not in keep_tags:
tag.unwrap()
return element.decode_contents().replace("<br>", "\n")
return element.decode_contents()

def parse_html_table(self, html_text: str, block: Block, page: PageGroup) -> List[TableCell]:
soup = BeautifulSoup(html_text, 'html.parser')
Expand Down
6 changes: 5 additions & 1 deletion marker/processors/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ class TableProcessor(BaseProcessor):
List[BlockTypes],
"Block types to remove if they're contained inside the tables."
] = (BlockTypes.Text, BlockTypes.TextInlineMath)
pdftext_workers: Annotated[
int,
"The number of workers to use for pdftext.",
] = 4

def __init__(
self,
Expand Down Expand Up @@ -273,7 +277,7 @@ def assign_pdftext_lines(self, extract_blocks: list, filepath: str):
"tables": tables,
"img_size": img_size
})
cell_text = table_output(filepath, table_inputs, page_range=unique_pages)
cell_text = table_output(filepath, table_inputs, page_range=unique_pages, workers=self.pdftext_workers)
assert len(cell_text) == len(unique_pages), "Number of pages and table inputs must match"

for pidx, (page_tables, pnum) in enumerate(zip(cell_text, unique_pages)):
Expand Down
31 changes: 28 additions & 3 deletions marker/renderers/markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Annotated, Tuple

import regex
from bs4 import NavigableString
from markdownify import MarkdownConverter
from pydantic import BaseModel

Expand All @@ -16,8 +17,32 @@ def cleanup_text(full_text):
full_text = re.sub(r'(\n\s){3,}', '\n\n', full_text)
return full_text.strip()

def get_text_with_br(element):
return ''.join(str(content) if content.name == 'br' else content.strip() for content in element.contents)
def get_formatted_table_text(element):
    """Flatten the direct children of a table cell into one text string.

    Text nodes are stripped and dropped when empty, ``<br>`` tags are kept
    as the literal ``<br>``, ``<math>`` tags become ``$...$`` around their
    text, and any other child is rendered with ``str()``. Pieces are joined
    with a single space, except that no space is placed on either side of
    a ``<br>``.
    """
    pieces = []
    for node in element.contents:
        if node is None:
            continue

        if isinstance(node, NavigableString):
            trimmed = node.strip()
            if trimmed:
                pieces.append(trimmed)
            continue

        tag_name = node.name
        if tag_name == 'br':
            pieces.append('<br>')
        elif tag_name == "math":
            pieces.append("$" + node.text + "$")
        else:
            pieces.append(str(node))

    joined = []
    previous = None
    for piece in pieces:
        # A space separates two neighbours only when neither is a line break.
        needs_space = previous is not None and piece != '<br>' and previous != '<br>'
        joined.append(" " + piece if needs_space else piece)
        previous = piece
    return "".join(joined)


class Markdownify(MarkdownConverter):
Expand Down Expand Up @@ -81,7 +106,7 @@ def convert_table(self, el, text, convert_as_inline):
col_idx += 1

# Fill in grid
value = get_text_with_br(cell).replace("\n", " ").replace("|", " ")
value = get_formatted_table_text(cell).replace("\n", " ").replace("|", " ").strip()
rowspan = int(cell.get('rowspan', 1))
colspan = int(cell.get('colspan', 1))

Expand Down
67 changes: 3 additions & 64 deletions marker/schema/blocks/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,76 +6,15 @@

class Equation(Block):
block_type: BlockTypes = BlockTypes.Equation
latex: str | None = None
html: str | None = None
block_description: str = "A block math equation."

def assemble_html(self, document, child_blocks, parent_structure=None):
if self.latex:
if self.html:
child_ref_blocks = [block for block in child_blocks if block.id.block_type == BlockTypes.Reference]
html_out = super().assemble_html(document, child_ref_blocks, parent_structure)
html_out += f"<p block-type='{self.block_type}'>"

try:
latex = self.parse_latex(html.escape(self.latex))
except ValueError as e:
# If we have mismatched delimiters, we'll treat it as a single block
# Strip the $'s from the latex
latex = [
{"class": "block", "content": self.latex.replace("$", "")}
]

for el in latex:
if el["class"] == "block":
html_out += f'<math display="block">{el["content"]}</math>'
elif el["class"] == "inline":
html_out += f'<math display="inline">{el["content"]}</math>'
else:
html_out += el["content"]
html_out += "</p>"
html_out += f"""<p block-type='{self.block_type}'>{self.html}</p>"""
return html_out
else:
template = super().assemble_html(document, child_blocks, parent_structure)
return f"<p block-type='{self.block_type}'>{template}</p>"

@staticmethod
def parse_latex(text: str):
if text.count("$") % 2 != 0:
raise ValueError("Mismatched delimiters in LaTeX")

DELIMITERS = [
("$$", "block"),
("$", "inline")
]

text = text.replace("\n", "<br>") # we can't handle \n's inside <p> properly if we don't do this

i = 0
stack = []
result = []
buffer = ""

while i < len(text):
for delim, class_name in DELIMITERS:
if text[i:].startswith(delim):
if stack and stack[-1] == delim: # Closing
stack.pop()
result.append({"class": class_name, "content": buffer})
buffer = ""
i += len(delim)
break
elif not stack: # Opening
if buffer:
result.append({"class": "text", "content": buffer})
stack.append(delim)
buffer = ""
i += len(delim)
break
else:
raise ValueError(f"Nested {class_name} delimiters not supported")
else: # No delimiter match
buffer += text[i]
i += 1

if buffer:
result.append({"class": "text", "content": buffer})
return result
2 changes: 1 addition & 1 deletion marker/scripts/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ def process_single_pdf(args):

@click.command(cls=CustomClickPrinter)
@click.argument("in_folder", type=str)
@ConfigParser.common_options
@click.option("--chunk_idx", type=int, default=0, help="Chunk index to convert")
@click.option("--num_chunks", type=int, default=1, help="Number of chunks being processed in parallel")
@click.option("--max_files", type=int, default=None, help="Maximum number of pdfs to convert")
@click.option("--workers", type=int, default=5, help="Number of worker processes to use.")
@click.option("--skip_existing", is_flag=True, default=False, help="Skip existing converted files.")
@ConfigParser.common_options
def convert_cli(in_folder: str, **kwargs):
in_folder = os.path.abspath(in_folder)
files = [os.path.join(in_folder, f) for f in os.listdir(in_folder)]
Expand Down
Loading

0 comments on commit 9ed906d

Please sign in to comment.