diff --git a/configs/Cini/cini_config.json b/configs/Cini/cini_config.json index cc154f2..5cf9659 100644 --- a/configs/Cini/cini_config.json +++ b/configs/Cini/cini_config.json @@ -5,21 +5,7 @@ "make_patches": true, "n_epochs": 60, "input_resized_size" : 800000, - "patch_shape": [300, 300], - "data_augmentation": false - }, - "model_params": { - "batch_norm": true, - "batch_renorm": true, - "weight_decay": 1e-4, - "selected_levels_upscaling" : [ - true, - true, - true, - true, - false, - false - ] + "patch_shape": [400, 400] }, "pretrained_model_name" : "resnet50", "prediction_type" : "CLASSIFICATION" diff --git a/configs/Cini/cini_variant2.json b/configs/Cini/cini_variant.json similarity index 100% rename from configs/Cini/cini_variant2.json rename to configs/Cini/cini_variant.json diff --git a/configs/cBAD/cbad_post_processing_configs.json b/configs/cBAD/cbad_post_processing_configs.json index 4719711..7c06c64 100644 --- a/configs/cBAD/cbad_post_processing_configs.json +++ b/configs/cBAD/cbad_post_processing_configs.json @@ -1,6 +1,19 @@ { "configs":[ - {"sigma": 2.5, "low_threshold": 0.1, "high_threshold": 0.3}, - {"sigma": 2.5, "low_threshold": 0.3, "high_threshold": 0.7} + {"low_threshold": 0.1, "high_threshold": 0.3}, + {"low_threshold": 0.2, "high_threshold": 0.3}, + {"low_threshold": 0.2, "high_threshold": 0.4}, + {"low_threshold": 0.3, "high_threshold": 0.4}, + {"low_threshold": 0.3, "high_threshold": 0.5}, + {"low_threshold": 0.1, "high_threshold": 0.3, "sigma": 1.5}, + {"low_threshold": 0.2, "high_threshold": 0.3, "sigma": 1.5}, + {"low_threshold": 0.2, "high_threshold": 0.4, "sigma": 1.5}, + {"low_threshold": 0.3, "high_threshold": 0.4, "sigma": 1.5}, + {"low_threshold": 0.3, "high_threshold": 0.5, "sigma": 1.5}, + {"low_threshold": 0.1, "high_threshold": 0.3, "sigma": 2.5}, + {"low_threshold": 0.2, "high_threshold": 0.3, "sigma": 2.5}, + {"low_threshold": 0.2, "high_threshold": 0.4, "sigma": 2.5}, + {"low_threshold": 0.3, "high_threshold": 0.4, "sigma": 2.5}, + {"low_threshold": 0.3, "high_threshold": 0.5, "sigma": 2.5} ] } \ No newline at end of file diff --git a/configs/cBAD/cbad_variant2.json b/configs/cBAD/cbad_variant.json similarity index 100% rename from configs/cBAD/cbad_variant2.json rename to configs/cBAD/cbad_variant.json diff --git a/configs/cBAD/cbad_variant1.json b/configs/cBAD/cbad_variant1.json deleted file mode 100644 index eb67abb..0000000 --- a/configs/cBAD/cbad_variant1.json +++ /dev/null @@ -1,6 +0,0 @@ -{ -"train_dir" : "/home/datasets/cBAD/Baseline_Competition_Complex_Documents/generated_baseline_endpoints_thick_multi/train", -"eval_dir" : "/home/datasets/cBAD/Baseline_Competition_Complex_Documents/generated_baseline_endpoints_thick_multi//validation", -"classes_file" : "/home/datasets/cBAD/Baseline_Competition_Complex_Documents/generated_baseline_endpoints_thick_multi/classes.txt", -"model_output_dir" : "/home/docseg_models/cBAD/" -} \ No newline at end of file diff --git a/doc_seg/post_processing/PAGE.py b/doc_seg/post_processing/PAGE.py index 30f37c5..1d6bd47 100644 --- a/doc_seg/post_processing/PAGE.py +++ b/doc_seg/post_processing/PAGE.py @@ -273,6 +273,18 @@ def draw_baselines(self, img_canvas, color=(255, 0, 0), thickness=2, endpoint_ra cv2.circle(img_canvas, (coords[-1, 0, 0], coords[-1, 0, 1]), radius=endpoint_radius, color=color, thickness=-1) + def draw_textregions(self, img_canvas, color=(255, 0, 0), autoscale=True): + if autoscale: + assert self.image_height is not None + assert self.image_width is not None + ratio = 
(img_canvas.shape[0]/self.image_height, img_canvas.shape[1]/self.image_width)
+        else:
+            ratio = (1, 1)
+
+        tr_coords = [(Point.list_to_cv2poly(tr.coords)*ratio).astype(np.int32) for tr in self.text_regions
+                     if len(tr.coords) > 0]
+        cv2.fillPoly(img_canvas, tr_coords, color)
+
 
 def parse_file(filename: str) -> Page:
     xml_page = ET.parse(filename)
diff --git a/doc_seg/post_processing/__init__.py b/doc_seg/post_processing/__init__.py
index 277076a..7b6a84b 100644
--- a/doc_seg/post_processing/__init__.py
+++ b/doc_seg/post_processing/__init__.py
@@ -1,4 +1,4 @@
-from .segmentation import dibco_binarization_fn, diva_post_processing_fn, page_post_processing_fn
+from .segmentation import dibco_binarization_fn
 from .line_detection import cbad_post_processing_fn
 from .boxes_detection import cini_post_processing_fn, ornaments_post_processing_fn
 
diff --git a/doc_seg/post_processing/line_detection.py b/doc_seg/post_processing/line_detection.py
index 6ea0477..4ca0c87 100644
--- a/doc_seg/post_processing/line_detection.py
+++ b/doc_seg/post_processing/line_detection.py
@@ -14,7 +14,7 @@
 def cbad_post_processing_fn(probs: np.array, sigma: float=2.5, low_threshold: float=0.8, high_threshold: float=0.9,
-                            output_basename=None):
+                            filter_width: float=0, output_basename=None):
     """
 
     :param probs: output of the model (probabilities) in range [0, 255]
@@ -28,7 +28,7 @@ def cbad_post_processing_fn(probs: np.array, sigma: float=2.5, low_threshold: fl
         WARNING : contours IN OPENCV format List[np.ndarray(n_points, 1, (x,y))]
     """
-    contours, lines_mask = line_extraction_v1(probs[:, :, 1], sigma, low_threshold, high_threshold)
+    contours, lines_mask = line_extraction_v1(probs[:, :, 1], sigma, low_threshold, high_threshold, filter_width)
     if output_basename is not None:
         dump_pickle(output_basename+'.pkl', (contours, lines_mask.shape))
     return contours, lines_mask
@@ -47,14 +47,17 @@ def line_extraction_v0(probs, sigma, threshold):
     return contours, lines_mask
 
 
-def line_extraction_v1(probs, sigma, low_threshold, high_threshold):
-    # probs_line = probs[:, :, 1]
+def line_extraction_v1(probs, sigma=0.0, low_threshold=0.8, high_threshold=0.9, filter_width=0.00, vertical_maxima=True):
     probs_line = probs
     # Smooth
-    probs2 = cv2.GaussianBlur(probs_line, (int(3*sigma)*2+1, int(3*sigma)*2+1), sigma)
-    local_maxima = vertical_local_maxima(probs2)
-    lines_mask = hysteresis_thresholding(probs2, local_maxima, low_threshold, high_threshold)
-    # Remove lines touching border
+    if sigma > 0.:
+        probs2 = cv2.GaussianBlur(probs_line, (int(3*sigma)*2+1, int(3*sigma)*2+1), sigma)
+    else:
+        probs2 = cv2.fastNlMeansDenoising((probs_line*255).astype(np.uint8), h=50)/255
+    #probs2 = probs_line
+    #local_maxima = vertical_local_maxima(probs2)
+    lines_mask = hysteresis_thresholding(probs2, low_threshold, high_threshold,
+                                         candidates=vertical_local_maxima(probs2) if vertical_maxima else None)
     #lines_mask = remove_borders(lines_mask)
     # Extract polygons from line mask
     contours = extract_line_polygons(lines_mask)
@@ -62,10 +65,11 @@ def line_extraction_v1(probs, sigma, low_threshold, high_threshold):
     filtered_contours = []
     page_width = probs.shape[1]
     for cnt in contours:
-        if cv2.arcLength(cnt, False) < 0.05*page_width:
-            continue
-        if cv2.arcLength(cnt, False) < 0.05*page_width:
+        centroid_x, centroid_y = np.mean(cnt, axis=0)[0]
+        if centroid_x < filter_width*page_width or centroid_x > (1-filter_width)*page_width:
             continue
+        # if cv2.arcLength(cnt, False) < filter_width*page_width:
+        #     continue
         filtered_contours.append(cnt)
     return filtered_contours, lines_mask
 
@@ -153,12 +157,14 @@ def goal_reached(self, int_index, float_cumcost): def vertical_local_maxima(probs): local_maxima = np.zeros_like(probs, dtype=bool) - local_maxima[1:-1] = (probs[1:-1] > probs[:-2]) & (probs[2:] < probs[1:-1]) + local_maxima[1:-1] = (probs[1:-1] >= probs[:-2]) & (probs[2:] <= probs[1:-1]) return local_maxima -def hysteresis_thresholding(probs: np.array, candidates: np.array, low_threshold: float, high_threshold: float): - low_mask = candidates & (probs > low_threshold) +def hysteresis_thresholding(probs: np.array, low_threshold: float, high_threshold: float, candidates=None): + low_mask = probs > low_threshold + if candidates is not None: + low_mask = candidates & low_mask # Connected components extraction label_components, count = label(low_mask, np.ones((3, 3))) # Keep components with high threshold elements diff --git a/doc_seg/utils.py b/doc_seg/utils.py index 7a8a5e7..99582b3 100644 --- a/doc_seg/utils.py +++ b/doc_seg/utils.py @@ -3,6 +3,7 @@ import os import json import pickle +from hashlib import sha1 class PredictionType: @@ -269,3 +270,7 @@ def load_pickle(filename): def dump_pickle(filename, obj): with open(filename, 'wb') as f: return pickle.dump(obj, f) + + +def hash_dict(params): + return sha1(json.dumps(params, sort_keys=True).encode()).hexdigest() \ No newline at end of file diff --git a/exps/Cini/__init__.py b/exps/Cini/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/exps/Cini/cini_evaluation.py b/exps/Cini/cini_evaluation.py index 0b9084c..b39081f 100644 --- a/exps/Cini/cini_evaluation.py +++ b/exps/Cini/cini_evaluation.py @@ -3,7 +3,7 @@ from scipy.misc import imread, imsave, imresize import cv2 import numpy as np -from .cini_post_processing import cini_post_processing_fn +from cini_post_processing import cini_post_processing_fn from doc_seg.utils import load_pickle import pandas as pd diff --git a/exps/Cini/cini_process_set.py b/exps/Cini/cini_process_set.py new file mode 100644 index 0000000..3e8342e --- /dev/null +++ b/exps/Cini/cini_process_set.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python +__author__ = 'solivr' + +import os +import sys + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))) +from doc_seg.loader import LoadedModel +from cini_post_processing import cini_post_processing_fn +from cini_evaluation import cini_evaluate_folder + +import tensorflow as tf +from tqdm import tqdm +import numpy as np +import argparse +from glob import glob +from scipy.misc import imread, imresize, imsave +import tempfile +import json +from doc_seg.post_processing import PAGE +from doc_seg.utils import hash_dict, dump_json + + +def predict_on_set(filenames_to_predict, model_dir, output_dir): + """ + + :param filenames_to_predict: + :param model_dir: + :param output_dir: + :return: + """ + with tf.Session(): + m = LoadedModel(model_dir, 'filename') + for filename in tqdm(filenames_to_predict, desc='Prediction'): + pred = m.predict(filename)['probs'][0] + np.save(os.path.join(output_dir, os.path.basename(filename).split('.')[0]), + np.uint8(255 * pred)) + + +def find_elements(img_filenames, dir_predictions, post_process_params, output_dir, debug=False, mask_dir: str=None): + """ + + :param img_filenames: + :param dir_predictions: + :param post_process_params: + :param output_dir: + :return: + """ + + os.makedirs(output_dir, exist_ok=True) + + for filename in tqdm(img_filenames, 'Post-processing'): + orig_img = imread(filename, mode='RGB') + basename = 
os.path.basename(filename).split('.')[0]
+
+        filename_pred = os.path.join(dir_predictions, basename + '.npy')
+        pred = np.load(filename_pred)/255  # type: np.ndarray
+
+        contours, lines_mask = cini_post_processing_fn(pred, **post_process_params,
+                                                       output_basename=os.path.join(output_dir, basename))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-m', '--model-dir', type=str, required=True,
+                        help='Directory of the model (should be of type ''*/export/)')
+    parser.add_argument('-i', '--input-files', type=str, required=True, nargs='+',
+                        help='Folder containing the images to evaluate the model on')
+    parser.add_argument('-o', '--output-dir', type=str, required=True,
+                        help='Folder containing the outputs (.npy predictions and visualization errors)')
+    parser.add_argument('-gt', '--ground_truth_dir', type=str, required=True,
+                        help='Ground truth directory containing the labeled images')
+    parser.add_argument('--params-file', type=str, default=None,
+                        help='JSON file containing the params for post-processing')
+    parser.add_argument('--gpu', type=str, default='0', help='Which GPU to use')
+    parser.add_argument('-pp', '--post-process-only', default=False, action='store_true',
+                        help='Only run the post-processing (skip the prediction step)')
+    args = parser.parse_args()
+    args = vars(args)
+
+    os.environ["CUDA_VISIBLE_DEVICES"] = args.get('gpu')
+    model_dir = args.get('model_dir')
+    input_files = args.get('input_files')
+    if len(input_files) == 0:
+        raise FileNotFoundError
+
+    output_dir = args.get('output_dir')
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Prediction
+    npy_directory = output_dir
+    if not args.get('post_process_only'):
+        predict_on_set(input_files, model_dir, npy_directory)
+
+    npy_files = glob(os.path.join(npy_directory, '*.npy'))
+
+    if args.get('params_file') is None:
+        print('No params file found')
+        params_list = [{"clean_predictions": True, "advanced": True}]
+    else:
+        with open(args.get('params_file'), 'r') as f:
+            configs_data = json.load(f)
+            # If the file contains a list of configurations
+            if 'configs' in configs_data.keys():
+                params_list = configs_data['configs']
+                assert isinstance(params_list, list)
+            # Or if there is a single configuration
+            else:
+                params_list = [configs_data]
+
+    gt_dir = args.get('ground_truth_dir')
+
+    for params in tqdm(params_list, desc='Params'):
+        print(params)
+        exp_dir = os.path.join(output_dir, '_' + hash_dict(params))
+        find_elements(input_files, npy_directory, params, exp_dir, debug=False)
+
+        if gt_dir is not None:
+            scores = cini_evaluate_folder(exp_dir, gt_dir, debug_folder=os.path.join(exp_dir, '_debug'))
+            dump_json(os.path.join(exp_dir, 'post_process_config.json'), params)
+            dump_json(os.path.join(exp_dir, 'scores.json'), scores)
+            print('Scores : {}'.format(scores))
+
+
diff --git a/exps/DIBCO/__init__.py b/exps/DIBCO/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/exps/DIVA/diva_evaluation.py b/exps/DIVA/diva_evaluation.py
index e44e8c1..8e46555 100644
--- a/exps/DIVA/diva_evaluation.py
+++ b/exps/DIVA/diva_evaluation.py
@@ -9,7 +9,7 @@
 import numpy as np
 import cv2
 import json
-from diva_dataset_generator import MAP_COLORS
+from .diva_dataset_generator import MAP_COLORS
 
 
 DIVA_CLASSES = {
diff --git a/exps/Ornaments/__init__.py b/exps/Ornaments/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/exps/__init__.py b/exps/__init__.py
index 98d5bcb..3d8b54e 100644
--- a/exps/__init__.py
+++ b/exps/__init__.py
@@ -3,9 +3,9 @@
 from .Page.page_post_processing import 
page_post_processing_fn from .cBAD.cbad_post_processing import cbad_post_processing_fn from .DIBCO.dibco_post_processing import dibco_binarization_fn -from .Cini.cini_post_processing import cini_post_processing_fn +#from .Cini.cini_post_processing import cini_post_processing_fn from .DIVA.diva_evaluation import diva_evaluate_folder -from .Cini.cini_evaluation import cini_evaluate_folder +#from .Cini.cini_evaluation import cini_evaluate_folder from .cBAD.cbad_evaluation import cbad_evaluate_folder from .DIBCO.dibco_evaluation import dibco_evaluate_folder from .Ornaments.ornaments_evaluation import ornament_evaluate_folder diff --git a/exps/cBAD/cbad_evaluation.py b/exps/cBAD/cbad_evaluation.py index 82dca2a..7d60aa3 100644 --- a/exps/cBAD/cbad_evaluation.py +++ b/exps/cBAD/cbad_evaluation.py @@ -35,16 +35,21 @@ def cbad_evaluate_folder(output_folder: str, validation_dir: str, verbose=False, gt_dir = os.path.join(validation_dir, 'gt') filenames_processed = glob(os.path.join(output_folder, '*.pkl')) + filenames_processed.extend(glob(os.path.join(output_folder, '*.xml'))) xml_filenames_list = list() for filename in filenames_processed: basename = os.path.basename(filename).split('.')[0] - gt_page = PAGE.parse_file(os.path.join(gt_dir, - '{}.xml'.format(basename))) - - contours, img_shape = load_pickle(filename) - ratio = (gt_page.image_height/img_shape[0], gt_page.image_width/img_shape[1]) + gt_page = PAGE.parse_file(os.path.join(gt_dir, '{}.xml'.format(basename))) xml_filename = os.path.join(tmpdirname, basename + '.xml') + if filename[-4:] == '.pkl': + contours, img_shape = load_pickle(filename) + else: + extracted_page = PAGE.parse_file(filename) + img_shape = (extracted_page.image_height, extracted_page.image_width) + contours = [PAGE.Point.list_to_cv2poly(tl.baseline) + for tr in extracted_page.text_regions for tl in tr.text_lines] + ratio = (gt_page.image_height/img_shape[0], gt_page.image_width/img_shape[1]) PAGE.save_baselines(xml_filename, contours, ratio, initial_shape=img_shape[:2]) gt_xml_file = os.path.join(gt_dir, basename + '.xml') @@ -52,7 +57,7 @@ def cbad_evaluate_folder(output_folder: str, validation_dir: str, verbose=False, if debug_folder is not None: img = imread(os.path.join(validation_dir, 'images', basename+'.jpg')) - img = imresize(img, img_shape[:2]) + img = imresize(img, 1000/img.shape[0]) gt_page.draw_baselines(img, color=(0, 255, 0)) generated_page = PAGE.parse_file(xml_filename) generated_page.draw_baselines(img, color=(255, 0, 0)) diff --git a/exps/cBAD/cbad_post_processing.py b/exps/cBAD/cbad_post_processing.py index 9f3dd84..6be46aa 100644 --- a/exps/cBAD/cbad_post_processing.py +++ b/exps/cBAD/cbad_post_processing.py @@ -7,13 +7,14 @@ from collections import defaultdict import numpy as np from scipy.ndimage import label + import cv2 import os from doc_seg.utils import dump_pickle def cbad_post_processing_fn(probs: np.array, sigma: float=2.5, low_threshold: float=0.8, high_threshold: float=0.9, - output_basename=None): + filter_width: float=0, output_basename=None): """ :param probs: output of the model (probabilities) in range [0, 255] @@ -27,7 +28,7 @@ def cbad_post_processing_fn(probs: np.array, sigma: float=2.5, low_threshold: fl WARNING : contours IN OPENCV format List[np.ndarray(n_points, 1, (x,y))] """ - contours, lines_mask = line_extraction_v1(probs[:, :, 1], sigma, low_threshold, high_threshold) + contours, lines_mask = line_extraction_v1(probs[:, :, 1], sigma, low_threshold, high_threshold, filter_width) if output_basename is not None: 
dump_pickle(output_basename+'.pkl', (contours, lines_mask.shape))
     return contours, lines_mask
 
@@ -46,13 +47,17 @@ def line_extraction_v0(probs, sigma, threshold):
     return contours, lines_mask
 
 
-def line_extraction_v1(probs, sigma, low_threshold, high_threshold):
-    # probs_line = probs[:, :, 1]
+def line_extraction_v1(probs, sigma=0.0, low_threshold=0.8, high_threshold=0.9, filter_width=0.00, vertical_maxima=False):
     probs_line = probs
     # Smooth
-    probs2 = cv2.GaussianBlur(probs_line, (int(3*sigma)*2+1, int(3*sigma)*2+1), sigma)
-    local_maxima = vertical_local_maxima(probs2)
-    lines_mask = hysteresis_thresholding(probs2, local_maxima, low_threshold, high_threshold)
+    if sigma > 0.:
+        probs2 = cv2.GaussianBlur(probs_line, (int(3*sigma)*2+1, int(3*sigma)*2+1), sigma)
+    else:
+        probs2 = cv2.fastNlMeansDenoising((probs_line*255).astype(np.uint8), h=20)/255
+    #probs2 = probs_line
+    #local_maxima = vertical_local_maxima(probs2)
+    lines_mask = hysteresis_thresholding(probs2, low_threshold, high_threshold,
+                                         candidates=vertical_local_maxima(probs2) if vertical_maxima else None)
     # Remove lines touching border
     #lines_mask = remove_borders(lines_mask)
     # Extract polygons from line mask
@@ -61,10 +66,11 @@ def line_extraction_v1(probs, sigma, low_threshold, high_threshold):
     filtered_contours = []
     page_width = probs.shape[1]
     for cnt in contours:
-        if cv2.arcLength(cnt, False) < 0.05*page_width:
-            continue
-        if cv2.arcLength(cnt, False) < 0.05*page_width:
+        centroid_x, centroid_y = np.mean(cnt, axis=0)[0]
+        if centroid_x < filter_width*page_width or centroid_x > (1-filter_width)*page_width:
             continue
+        # if cv2.arcLength(cnt, False) < filter_width*page_width:
+        #     continue
         filtered_contours.append(cnt)
     return filtered_contours, lines_mask
 
@@ -152,12 +158,15 @@ def goal_reached(self, int_index, float_cumcost):
 
 def vertical_local_maxima(probs):
     local_maxima = np.zeros_like(probs, dtype=bool)
-    local_maxima[1:-1] = (probs[1:-1] > probs[:-2]) & (probs[2:] < probs[1:-1])
-    return local_maxima
+    local_maxima[1:-1] = (probs[1:-1] >= probs[:-2]) & (probs[2:] <= probs[1:-1])
+    local_maxima = cv2.morphologyEx(local_maxima.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((5, 5), dtype=np.uint8))
+    return local_maxima > 0
 
 
-def hysteresis_thresholding(probs: np.array, candidates: np.array, low_threshold: float, high_threshold: float):
-    low_mask = candidates & (probs > low_threshold)
+def hysteresis_thresholding(probs: np.array, low_threshold: float, high_threshold: float, candidates=None):
+    low_mask = probs > low_threshold
+    if candidates is not None:
+        low_mask = candidates & low_mask
     # Connected components extraction
     label_components, count = label(low_mask, np.ones((3, 3)))
     # Keep components with high threshold elements
diff --git a/exps/cBAD/cbad_process_set.py b/exps/cBAD/cbad_process_set.py
index 6175a4b..0f5a5ad 100644
--- a/exps/cBAD/cbad_process_set.py
+++ b/exps/cBAD/cbad_process_set.py
@@ -6,17 +6,19 @@
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)))
 
 from doc_seg.loader import LoadedModel
-from doc_seg.post_processing.line_detection import line_extraction_v1
+from cbad_post_processing import line_extraction_v1
+from cbad_evaluation import cbad_evaluate_folder
 
 import tensorflow as tf
 from tqdm import tqdm
 import numpy as np
 import argparse
 from glob import glob
-from scipy.misc import imsave, imread
+from scipy.misc import imread, imresize, imsave
 import tempfile
 import json
 from doc_seg.post_processing import PAGE
+from doc_seg.utils import hash_dict, dump_json
 
 
 def predict_on_set(filenames_to_predict, model_dir, output_dir):
@@ -35,7 +37,7 @@ def predict_on_set(filenames_to_predict, model_dir, output_dir):
             np.uint8(255 * pred))
 
 
-def find_lines(img_filenames, dir_predictions, post_process_params, output_dir, debug=False):
+def find_lines(img_filenames, dir_predictions, post_process_params, output_dir, debug=False, mask_dir: str=None):
     """
 
     :param img_filenames:
@@ -45,14 +47,23 @@ def find_lines(img_filenames, dir_predictions, post_process_params, output_dir,
     :return:
     """
 
+    os.makedirs(output_dir, exist_ok=True)
+
     for filename in tqdm(img_filenames, 'Post-processing'):
         orig_img = imread(filename, mode='RGB')
         basename = os.path.basename(filename).split('.')[0]
 
         filename_pred = os.path.join(dir_predictions, basename + '.npy')
-        pred = np.load(filename_pred)
+        pred = np.load(filename_pred)/255  # type: np.ndarray
+        lines_prob = pred[:, :, 1]
+
+        if mask_dir is not None:
+            mask = imread(os.path.join(mask_dir, basename + '.png'), mode='L')
+            mask = imresize(mask, lines_prob.shape)
+            lines_prob[mask == 0] = 0.
+
+        contours, lines_mask = line_extraction_v1(lines_prob, **post_process_params)
 
-        contours, lines_mask = line_extraction_v1(pred[:, :, 1], **post_process_params)
         if debug:
             imsave(os.path.join(output_dir, '{}_bin.jpg'.format(basename)), lines_mask)
@@ -67,17 +78,21 @@ def find_lines(img_filenames, dir_predictions, post_process_params, output_dir,
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-m', '--model_dir', type=str, required=True,
+    parser.add_argument('-m', '--model-dir', type=str, required=True,
                         help='Directory of the model (should be of type ''*/export/)')
-    parser.add_argument('-i', '--input_files', type=str, required=True, nargs='+',
+    parser.add_argument('-i', '--input-files', type=str, required=True, nargs='+',
                         help='Folder containing the images to evaluate the model on')
-    parser.add_argument('-o', '--output_dir', type=str, required=True,
+    parser.add_argument('-o', '--output-dir', type=str, required=True,
                         help='Folder containing the outputs (.npy predictions and visualization errors)')
-    parser.add_argument('--post_process_params', type=str, default=None,
+    parser.add_argument('-gt', '--ground_truth_dir', type=str, required=True,
+                        help='Ground truth directory containing the labeled images')
+    parser.add_argument('--params-file', type=str, default=None,
                         help='JSOn file containing the params for post-processing')
+    parser.add_argument('--mask-dir', type=str, default=None,
+                        help='Folder with the binary masks if available for predictions')
     parser.add_argument('--gpu', type=str, default='0', help='Which GPU to use')
-    # parser.add_argument('-pp', '--post_proces_only', default=False, action='store_true',
-    #                     help='Whether to make or not the prediction')
+    parser.add_argument('-pp', '--post-process-only', default=False, action='store_true',
+                        help='Only run the post-processing (skip the prediction step)')
     args = parser.parse_args()
     args = vars(args)
 
@@ -89,23 +104,44 @@ def find_lines(img_filenames, dir_predictions, post_process_params, output_dir,
 
     output_dir = args.get('output_dir')
     os.makedirs(output_dir, exist_ok=True)
 
-    post_process_params = args.get('post_proces_params')
-
-    if post_process_params is not None:
-        with open(post_process_params, 'r') as f:
-            post_process_params = json.load(f)
-        post_process_params = post_process_params['params']
-    else:
-        post_process_params = {"low_threshold": 0.1, "sigma": 2.5, "high_threshold": 0.3}
 
     # Prediction
-    with tempfile.TemporaryDirectory() as tmpdirname:
-        npy_directory = 
tmpdirname + npy_directory = output_dir + if not args.get('post_process_only'): predict_on_set(input_files, model_dir, npy_directory) - npy_files = glob(os.path.join(npy_directory, '*.npy')) + npy_files = glob(os.path.join(npy_directory, '*.npy')) - find_lines(input_files, npy_directory, post_process_params, output_dir, debug=True) + if args.get('params_file') is None: + print('No params file found') + params_list = [{"low_threshold": 0.25, "high_threshold": 0.6}] + else: + with open(args.get('params_file'), 'r') as f: + configs_data = json.load(f) + # If the file contains a list of configurations + if 'configs' in configs_data.keys(): + params_list = configs_data['configs'] + assert isinstance(params_list, list) + # Or if there is a single configuration + else: + params_list = [configs_data] + + gt_dir = args.get('ground_truth_dir') + if gt_dir is not None: + assert os.path.basename(gt_dir) == 'gt' + gt_dir = os.path.join(gt_dir, os.path.pardir) + + for params in tqdm(params_list, desc='Params'): + print(params) + exp_dir = os.path.join(output_dir, '_' + hash_dict(params)) + find_lines(input_files, npy_directory, params, exp_dir, + debug=False, mask_dir=args.get('mask_dir')) + + if gt_dir is not None: + scores = cbad_evaluate_folder(exp_dir, gt_dir, debug_folder=os.path.join(exp_dir, '_debug')) + dump_json(os.path.join(exp_dir, 'post_process_config.json'), params) + dump_json(os.path.join(exp_dir, 'scores.json'), scores) + print('Scores : {}'.format(scores)) diff --git a/exps/post_processing/PAGE.py b/exps/post_processing/PAGE.py index 30f37c5..7904fbb 100644 --- a/exps/post_processing/PAGE.py +++ b/exps/post_processing/PAGE.py @@ -100,16 +100,6 @@ def from_array(cls, cv2_coords: np.array=None, baseline_coords: np.array=None, text_equiv=text_equiv ) - @classmethod - def from_coords_array(cls, coords: np.array=None, baseline_coords: np.array=None, # shape [N, 1, 2] - text_equiv: str=None, id: str=None): - return TextLine( - id=id, - coords=Point.arr_to_point_list(coords) if coords is not None else [], - baseline=Point.arr_to_point_list(baseline_coords) if baseline_coords is not None else [], - text_equiv=text_equiv - ) - @classmethod def from_cv2_array(cls, cv2_coords: np.array=None, baseline_coords: np.array=None, # shape [N, 1, 2] text_equiv: str=None, id: str=None): diff --git a/model_selection.py b/model_selection.py index 6f8a473..1d13abb 100644 --- a/model_selection.py +++ b/model_selection.py @@ -55,7 +55,7 @@ # Perform test prediction (is it the right place?) 
test_folder = args.get('test_folder') - for i, best_experiment in enumerate(sorted_experiments[:1]): + for i, best_experiment in enumerate(sorted_experiments[:10]): print(best_experiment) print('Validation :') print(best_experiment.get_best_validated_epoch()) @@ -68,7 +68,7 @@ if len(test_files) == 0: print('No files in ', test_folder) continue - if args.get('eval_only'): + if not args.get('eval_only'): with tf.Graph().as_default(), tf.Session() as sess: m = LoadedModel(model_folder, input_dict_key='filename') for filename in tqdm(test_files, desc='Test images'): diff --git a/scripts/cBAD/cbad_process_set.py b/scripts/cBAD/cbad_process_set.py index 84aaea4..91a72f6 100644 --- a/scripts/cBAD/cbad_process_set.py +++ b/scripts/cBAD/cbad_process_set.py @@ -35,7 +35,7 @@ def predict_on_set(filenames_to_predict, model_dir, output_dir): np.uint8(255 * pred)) -def find_lines(img_filenames, dir_predictions, post_process_params, output_dir, debug=False): +def find_lines(img_filenames, dir_predictions, post_process_params, output_dir, debug=False, mask_dir: str=None): """ :param img_filenames: @@ -51,8 +51,13 @@ def find_lines(img_filenames, dir_predictions, post_process_params, output_dir, filename_pred = os.path.join(dir_predictions, basename + '.npy') pred = np.load(filename_pred) + lines_prob = pred[:, :, 1] - contours, lines_mask = line_extraction_v1(pred[:, :, 1], **post_process_params) + if mask_dir is not None: + mask = imread(os.path.join(mask_dir, basename + '_bin.png'), mode='L') + lines_prob[mask == 0] = 0. + + contours, lines_mask = line_extraction_v1(lines_prob, **post_process_params) if debug: imsave(os.path.join(output_dir, '{}_bin.jpg'.format(basename)), lines_mask)
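
Note on the refactored line_extraction_v1 (both the doc_seg and exps copies above): sigma is now optional, and sigma=0 switches the smoothing step from GaussianBlur to fastNlMeansDenoising, so call sites are safest with keyword arguments. A minimal usage sketch, assuming the repository root is on sys.path and that 'example_pred.npy' (a hypothetical file) holds a uint8 prediction saved by predict_on_set with the line class in channel 1:

    import numpy as np
    from exps.cBAD.cbad_post_processing import line_extraction_v1

    # Predictions are stored in [0, 255]; the function expects probabilities in [0, 1].
    pred = np.load('example_pred.npy') / 255
    contours, lines_mask = line_extraction_v1(
        pred[:, :, 1],        # probability map of the 'line' class
        sigma=1.5,            # > 0: Gaussian smoothing; 0: non-local means denoising
        low_threshold=0.2,    # hysteresis lower bound
        high_threshold=0.4,   # hysteresis upper bound
        filter_width=0.05)    # drop contours whose centroid lies in the outer 5% of the page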
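
The hysteresis_thresholding rewrite makes the vertical-maxima candidates optional, but the core rule is unchanged: a pixel above the low threshold is kept only if its connected component contains at least one pixel above the high threshold. A self-contained toy illustration of that rule (a sketch of the same idea, not the patch code):

    import numpy as np
    from scipy.ndimage import label

    def hysteresis(probs, low, high):
        # Pixels above the low threshold form candidate components.
        components, _ = label(probs > low, np.ones((3, 3)))
        # Keep only components containing at least one pixel above the high threshold.
        keep = np.unique(components[probs > high])
        return np.isin(components, keep[keep > 0])

    probs = np.array([[0.15, 0.25, 0.25],
                      [0.00, 0.00, 0.00],
                      [0.25, 0.25, 0.95]])
    print(hysteresis(probs, low=0.2, high=0.9).astype(int))
    # Top component is dropped (no pixel > 0.9); bottom row survives.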
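
The new hash_dict utility gives each post-processing configuration a stable, order-independent directory name, so re-running a parameter sweep maps every config back to the same experiment folder instead of creating duplicates. Sketch of the naming pattern used by the *_process_set scripts (output_dir is a placeholder):

    import os
    from doc_seg.utils import hash_dict

    params = {'low_threshold': 0.2, 'high_threshold': 0.4, 'sigma': 1.5}
    output_dir = '/tmp/experiments'  # placeholder
    exp_dir = os.path.join(output_dir, '_' + hash_dict(params))

    # json.dumps(..., sort_keys=True) inside hash_dict makes the digest key-order independent:
    assert hash_dict({'sigma': 1.5, 'high_threshold': 0.4, 'low_threshold': 0.2}) == hash_dict(params)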
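
The expanded cbad_post_processing_configs.json enumerates its threshold/sigma grid by hand; if the sweep grows, the same 15 entries can be generated programmatically. A sketch that writes the {'configs': [...]} structure the scripts expect (the output path is an assumption):

    import itertools
    import json

    threshold_pairs = [(0.1, 0.3), (0.2, 0.3), (0.2, 0.4), (0.3, 0.4), (0.3, 0.5)]
    sigmas = [None, 1.5, 2.5]  # None -> omit 'sigma' and use the function default

    configs = []
    for sigma, (low, high) in itertools.product(sigmas, threshold_pairs):
        config = {'low_threshold': low, 'high_threshold': high}
        if sigma is not None:
            config['sigma'] = sigma
        configs.append(config)

    with open('configs/cBAD/cbad_post_processing_configs.json', 'w') as f:
        json.dump({'configs': configs}, f, indent=2)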