Skip to content

Commit

Permalink
fix all but rpred tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mittagessen committed Jan 5, 2024
1 parent 5a14092 commit ca2e77f
Show file tree
Hide file tree
Showing 16 changed files with 97 additions and 120 deletions.
4 changes: 2 additions & 2 deletions kraken/blla.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import torch
import logging
import numpy as np
import pkg_resources
import importlib.resources
import shapely.geometry as geom
import torch.nn.functional as F
import torchvision.transforms as tf
Expand Down Expand Up @@ -310,7 +310,7 @@ def segment(im: PIL.Image.Image,
"""
if model is None:
logger.info('No segmentation model given. Loading default model.')
model = vgsl.TorchVGSLModel.load_model(pkg_resources.resource_filename(__name__, 'blla.mlmodel'))
model = vgsl.TorchVGSLModel.load_model(importlib.resources.files(__name__).joinpath('blla.mlmodel'))

if isinstance(model, vgsl.TorchVGSLModel):
model = [model]
Expand Down
10 changes: 6 additions & 4 deletions kraken/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,17 @@ class Segmentation:
imagename: Union[str, 'PathLike']
text_direction: Literal['horizontal-lr', 'horizontal-rl', 'vertical-lr', 'vertical-rl']
script_detection: bool
lines: List[Union[BaselineLine, BBoxLine]]
regions: Dict[str, List[Region]]
line_orders: List[List[int]]
lines: Optional[List[Union[BaselineLine, BBoxLine]]] = None
regions: Optional[Dict[str, List[Region]]] = None
line_orders: Optional[List[List[int]]] = None

def __post_init__(self):
if not self.regions:
self.regions = {}
if not self.lines:
self.lines = []
if not self.line_orders:
self.line_orders = []
if len(self.lines) and not isinstance(self.lines[0], BBoxLine) and not isinstance(self.lines[0], BaselineLine):
line_cls = BBoxLine if self.type == 'bbox' else BaselineLine
self.lines = [line_cls(**line) for line in self.lines]
Expand Down Expand Up @@ -502,7 +504,7 @@ def __init__(self,
ocr_record.__init__(self, prediction, cuts, confidences, display_order)

def __repr__(self) -> str:
return f'pred: {self.prediction} line: {self.line} confidences: {self.confidences}'
return f'pred: {self.prediction} bbox: {self.bbox} confidences: {self.confidences}'

def __next__(self) -> Tuple[str, int, float]:
if self.idx + 1 < len(self):
Expand Down
5 changes: 3 additions & 2 deletions kraken/ketos/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def publish(ctx, metadata, access_token, private, model):
Publishes a model on the zenodo model repository.
"""
import json
import pkg_resources
import importlib.resources

from jsonschema import validate
from jsonschema.exceptions import ValidationError
Expand All @@ -52,7 +52,8 @@ def publish(ctx, metadata, access_token, private, model):
from kraken.lib import models
from kraken.lib.progress import KrakenDownloadProgressBar

with pkg_resources.resource_stream('kraken', 'metadata.schema.json') as fp:
ref = importlib.resources.files('kraken').joinpath('metadata.schema.json')
with open(ref, 'rb') as fp:
schema = json.load(fp)

nn = models.load_any(model)
Expand Down
4 changes: 2 additions & 2 deletions kraken/kraken.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import warnings
import logging
import dataclasses
import pkg_resources
import importlib.resources

from PIL import Image
from pathlib import Path
Expand All @@ -43,7 +43,7 @@
install(suppress=[click])

APP_NAME = 'kraken'
SEGMENTATION_DEFAULT_MODEL = pkg_resources.resource_filename(__name__, 'blla.mlmodel')
SEGMENTATION_DEFAULT_MODEL = importlib.resources.files(APP_NAME).joinpath('blla.mlmodel')
DEFAULT_MODEL = ['en_best.mlmodel']

# raise default max image size to 20k * 20k pixels
Expand Down
2 changes: 1 addition & 1 deletion kraken/lib/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _extract_line(xml_record, skip_empty_lines: bool = True):
lines=[rec],
regions=None,
script_detection=False,
line_orders=None)
line_orders=[])
try:
line_im, line = next(extract_polygons(im, seg))
except KrakenInputException:
Expand Down
21 changes: 10 additions & 11 deletions kraken/lib/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

from skimage.draw import polygon

from kraken.lib.segmentation import scale_regions

if TYPE_CHECKING:
from kraken.containers import Segmentation
from kraken.lib.xml import XMLPage
Expand All @@ -51,7 +53,6 @@ def __init__(self,
line_width: int = 4,
padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
im_transforms: Callable[[Any], torch.Tensor] = transforms.Compose([]),
mode: Optional[Literal['alto', 'page', 'xml']] = 'xml',
augmentation: bool = False,
valid_baselines: Sequence[str] = None,
merge_baselines: Dict[str, Sequence[str]] = None,
Expand All @@ -65,10 +66,6 @@ def __init__(self,
padding: Tuple of ints containing the left/right, top/bottom
padding of the input images.
target_size: Target size of the image as a (height, width) tuple.
mode: Either path, alto, page, xml, or None. In alto, page, and xml
mode the baseline paths and image data is retrieved from an
ALTO/PageXML file. In `None` mode data is iteratively added
through the `add` method.
augmentation: Enable/disable augmentation.
valid_baselines: Sequence of valid baseline identifiers. If `None`
all are valid.
Expand All @@ -82,7 +79,7 @@ def __init__(self,
been discarded.
"""
super().__init__()
self.mode = mode
self.imgs = []
self.im_mode = '1'
self.pad = padding
self.targets = []
Expand Down Expand Up @@ -125,12 +122,12 @@ def __init__(self,
self.transforms = im_transforms
self.seg_type = None

def add(self, doc: Union['Segmentation', 'XMLPage']):
def add(self, doc: Union['Segmentation']):
"""
Adds a page to the dataset.
Args:
doc: Either a Segmentation container class or an XMLPage.
doc: A Segmentation container class.
"""
if doc.type != 'baselines':
raise ValueError(f'{doc} is of type {doc.type}. Expected "baselines".')
Expand All @@ -139,6 +136,7 @@ def add(self, doc: Union['Segmentation', 'XMLPage']):
for line in doc.lines:
if self.valid_baselines is None or set(line.tags.values()).intersection(self.valid_baselines):
tags = set(line.tags.values()).intersection(self.valid_baselines) if self.valid_baselines else line.tags.values()
tags = set([self.mbl_dict.get(v, v) for v in tags])
for tag in tags:
baselines_[tag].append(line.baseline)
self.class_stats['baselines'][tag] += 1
Expand All @@ -149,8 +147,8 @@ def add(self, doc: Union['Segmentation', 'XMLPage']):

regions_ = defaultdict(list)
for k, v in doc.regions.items():
reg_type = self.mreg_dict.get(k, k)
if self.valid_regions is None or reg_type in self.valid_regions:
if self.valid_regions is None or k in self.valid_regions:
reg_type = self.mreg_dict.get(k, k)
regions_[reg_type].extend(v)
self.class_stats['baselines'][reg_type] += len(v)
if reg_type not in self.class_mapping['regions']:
Expand All @@ -170,6 +168,7 @@ def __getitem__(self, idx):
im, target = self.transform(im, target)
return {'image': im, 'target': target}
except Exception:
raise
self.failed_samples.add(idx)
idx = np.random.randint(0, len(self.imgs))
logger.debug(traceback.format_exc())
Expand Down Expand Up @@ -235,7 +234,7 @@ def transform(self, image, target):
# skip regions of classes not present in the training set
continue
for region in regions:
region = np.array(region)*scale
region = np.array(scale_regions([region.boundary], scale)[0])
rr, cc = polygon(region[:, 1], region[:, 0], shape=image.shape[1:])
t[cls_idx, rr, cc] = 1
target = F.pad(t, self.pad)
Expand Down
5 changes: 3 additions & 2 deletions kraken/lib/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import json
import torch
import numbers
import pkg_resources
import importlib.resources
import torch.nn.functional as F

from functools import partial
Expand Down Expand Up @@ -320,7 +320,8 @@ def compute_confusions(algn1: Sequence[str], algn2: Sequence[str]):
script substitutions.
"""
counts: Dict[Tuple[str, str], int] = Counter()
with pkg_resources.resource_stream(__name__, 'scripts.json') as fp:
ref = importlib.resources.files(__name__).joinpath('scripts.json')
with ref.open('rb') as fp:
script_map = json.load(fp)

def _get_script(c):
Expand Down
2 changes: 1 addition & 1 deletion kraken/lib/xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,4 +598,4 @@ def to_container(self) -> Segmentation:
script_detection=True,
lines=self.get_sorted_lines(),
regions=self._regions,
line_orders=None)
line_orders=[])
15 changes: 7 additions & 8 deletions kraken/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging
import datetime

from pkg_resources import get_distribution
import importlib.metadata

from kraken.lib.util import make_printable

Expand Down Expand Up @@ -116,7 +116,7 @@ def serialize(results: 'Segmentation',
'base_dir': [rec.base_dir for rec in results.lines][0] if len(results.lines) else None,
'seg_type': results.type} # type: dict
metadata = {'processing_steps': processing_steps,
'version': get_distribution('kraken').version}
'version': importlib.metadata.version('kraken')}

seg_idx = 0
char_idx = 0
Expand All @@ -127,7 +127,7 @@ def serialize(results: 'Segmentation',
if line.tags is not None:
types.extend((k, v) for k, v in line.tags.items())
page['line_types'] = list(set(types))
page['region_types'] = [list(results.regions.keys())]
page['region_types'] = list(results.regions.keys())

# map reading orders indices to line IDs
ros = []
Expand All @@ -144,8 +144,8 @@ def serialize(results: 'Segmentation',
prev_reg = None
for idx, record in enumerate(results.lines):
# line not in region
if len(record.regions) == 0:
cur_ent = page['entitites']
if not record.regions or len(record.regions) == 0:
cur_ent = page['entities']
# line not in same region as previous line
elif prev_reg != record.regions[0]:
prev_reg = record.regions[0]
Expand All @@ -164,11 +164,11 @@ def serialize(results: 'Segmentation',
# set field to indicate the availability of baseline segmentation in
# addition to bounding boxes
line = {'index': idx,
'bbox': max_bbox([record.boundary] if record.type == 'baselines' else record.bbox),
'bbox': max_bbox([record.boundary] if record.type == 'baselines' else [record.bbox]),
'cuts': record.cuts,
'confidences': record.confidences,
'recognition': [],
'boundary': [list(x) for x in record.boundary],
'boundary': [list(x) for x in record.boundary] if record.type == 'baselines' else record.bbox,
'type': 'line'
}
if record.tags is not None:
Expand Down Expand Up @@ -199,7 +199,6 @@ def serialize(results: 'Segmentation',
segment,
range(char_idx, char_idx + len(segment)))],
'index': seg_idx}
# compute convex hull of all characters in segment
if record.type == 'baselines':
seg_struct['boundary'] = record[line_offset:line_offset + len(segment)][1]
line['recognition'].append(seg_struct)
Expand Down
12 changes: 6 additions & 6 deletions kraken/templates/alto
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
'postprocessing': 'postOperation'}
%}
{%+ macro render_line(page, line) +%}
<TextLine ID="line_{{ line.index }}" HPOS="{{ line.bbox[0] }}" VPOS="{{ line.bbox[1] }}" WIDTH="{{ line.bbox[2] - line.bbox[0] }}" HEIGHT="{{ line.bbox[3] - line.bbox[1] }}" {% if line.baseline %}BASELINE="{{ line.baseline|sum(start=[])|join(' ') }}"{% endif %} {% if line.tags %}TAGREFS="{% for type in page.types %}{% if type in line.tags.values() %}TYPE_{{ loop.index }}{% endif %}{% endfor %}"{% endif %}>
<TextLine ID="line_{{ line.index }}" HPOS="{{ line.bbox[0] }}" VPOS="{{ line.bbox[1] }}" WIDTH="{{ line.bbox[2] - line.bbox[0] }}" HEIGHT="{{ line.bbox[3] - line.bbox[1] }}" {% if line.baseline %}BASELINE="{{ line.baseline|sum(start=[])|join(' ') }}"{% endif %} {% if line.tags %}TAGREFS="{% for type in page.line_types %}{% if type[0] in line.tags and line.tags["type[0]"] == type[1] %}TYPE_{{ loop.index }}{% endif %}{% endfor %}"{% endif %}>
{% if line.boundary %}
<Shape>
<Polygon POINTS="{{ line.boundary|sum(start=[])|join(' ') }}"/>
Expand Down Expand Up @@ -52,7 +52,7 @@
<Processing ID="OCR_{{ step.id }}">
<processingCategory>{{ proc_type_table[step.category] }}</processingCategory>
<processingStepDescription>{{ step.description }}</processingStepDescription>
<processingStepSettings>{% for k, v in step.settings.items() %}{{k}}: {{v}}; {% endfor %}</processingStepSettings>
<processingStepSettings>{% for k, v in step.settings.items() %}{{k}}: {{v}}{% if not loop.last %}; {% endif %}{% endfor %}</processingStepSettings>
<processingSoftware>
<softwareName>kraken</softwareName>
<softwareVersion>{{ metadata.version }}</softwareVersion>
Expand All @@ -72,13 +72,13 @@
</Description>
<Tags>
{% for type, label in page.line_types %}
<OtherTag DESCRIPTION="line type" ID="TYPE_{{ loop.index }}" TYPE={{ type }} LABEL="{{ label }}"/>
<OtherTag DESCRIPTION="line type" ID="TYPE_{{ loop.index }}" TYPE="{{ type }}" LABEL="{{ label }}"/>
{% endfor %}
{% for label in page.region_types %}
<OtherTag DESCRIPTION="region type" ID="TYPE_{{ loop.index }}" TYPE={{ type }} LABEL="{{ label }}"/>
<OtherTag DESCRIPTION="region type" ID="TYPE_{{ loop.index }}" TYPE="region" LABEL="{{ label }}"/>
{% endfor %}
</Tags>
{% if len(page.line_orders) > 0 %}
{% if page.line_orders|length() > 0 %}
<ReadingOrder>
{% if len(page.line_orders) == 1 %}
<OrderedGroup ID="ro_0">
Expand Down Expand Up @@ -107,7 +107,7 @@
{% if loop.previtem and loop.previtem.type == 'line' %}
</TextBlock>
{% endif %}
<TextBlock ID="block_{{ entity.index }}" HPOS="{{ entity.bbox[0] }}" VPOS="{{ entity.bbox[1] }}" WIDTH="{{ entity.bbox[2] - entity.bbox[0] }}" HEIGHT="{{ entity.bbox[3] - entity.bbox[1] }}" {% for type in page.types %}{% if type == entity.region_type %}TAGREFS="TYPE_{{ loop.index }}"{% endif %}{% endfor %}>
<TextBlock ID="block_{{ entity.index }}" HPOS="{{ entity.bbox[0] }}" VPOS="{{ entity.bbox[1] }}" WIDTH="{{ entity.bbox[2] - entity.bbox[0] }}" HEIGHT="{{ entity.bbox[3] - entity.bbox[1] }}" {% for type in page.region_types %}{% if type == entity.region_type %}TAGREFS="TYPE_{{ loop.index }}"{% endif %}{% endfor %}>
<Shape>
<Polygon POINTS="{{ entity.boundary|sum(start=[])|join(' ') }}"/>
</Shape>
Expand Down
2 changes: 1 addition & 1 deletion tests/resources/bl_records.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/resources/records.json

Large diffs are not rendered by default.

Loading

0 comments on commit ca2e77f

Please sign in to comment.