diff --git a/.gitignore b/.gitignore
index 2b08b63..1142069 100755
--- a/.gitignore
+++ b/.gitignore
@@ -12,4 +12,23 @@
 checkpoints/
 *.musicxml
 *.mp3
-*.swp
\ No newline at end of file
+*.swp
+
+# Model training datasets
+/ds2_dense
+/CvcMuscima-Distortions
+
+# Model training checkpoints and outputs
+/seg_unet
+/test_data
+/train_data
+/*.model
+/*.h5
+/*.json
+
+/segnet_*
+/unet_*
+/rests_*
+/all_rests_*
+/sfn_*
+/clef_*
\ No newline at end of file
diff --git a/README.md b/README.md
index 8b5b6c7..26d6db1 100755
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@
 pip install oemer
 pip install oemer[tf]
 
 # (optional) Or install the newest updates directly from Github.
-pip install https://github.com/BreezeWhite/oemer
+pip install git+https://github.com/BreezeWhite/oemer
 
 # Run oemer
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..aee9622
--- /dev/null
+++ b/main.py
@@ -0,0 +1,3 @@
+from oemer import ete
+
+ete.main()
\ No newline at end of file
diff --git a/oemer/bbox.py b/oemer/bbox.py
index 60baa3b..9ba88d3 100755
--- a/oemer/bbox.py
+++ b/oemer/bbox.py
@@ -2,6 +2,7 @@
 from typing import Union, Any, List, Tuple, Dict
 
 import cv2
+from cv2.typing import RotatedRect
 import numpy as np
 from numpy import ndarray
 from sklearn.cluster import AgglomerativeClustering
@@ -118,11 +119,12 @@ def find_lines(data: ndarray, min_len: int = 10, max_gap: int = 20) -> List[BBox
     lines = cv2.HoughLinesP(data.astype(np.uint8), 1, np.pi/180, 50, None, min_len, max_gap)
 
     new_line = []
-    for line in lines:
-        line = line[0]
-        top_x, bt_x = (line[0], line[2]) if line[0] < line[2] else (line[2], line[0])
-        top_y, bt_y = (line[1], line[3]) if line[1] < line[3] else (line[3], line[1])
-        new_line.append((top_x, top_y, bt_x, bt_y))
+    if lines is not None:
+        for line in lines:
+            line = line[0]
+            top_x, bt_x = (line[0], line[2]) if line[0] < line[2] else (line[2], line[0])
+            top_y, bt_y = (line[1], line[3]) if line[1] < line[3] else (line[3], line[1])
+            new_line.append((top_x, top_y, bt_x, bt_y))
     return new_line
 
 
@@ -159,7 +161,7 @@ def draw_bounding_boxes(
     return img
 
 
-def get_rotated_bbox(data: ndarray) -> List[Tuple[Tuple[float, float], Tuple[float, float], float]]:
+def get_rotated_bbox(data: ndarray) -> List[RotatedRect]:
     contours, _ = cv2.findContours(data.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
     bboxes = []
     for cnt in contours:
diff --git a/oemer/build_label.py b/oemer/build_label.py
index 1490266..6026696 100755
--- a/oemer/build_label.py
+++ b/oemer/build_label.py
@@ -1,3 +1,4 @@
+import sys
 import os
 import random
 from PIL import Image
@@ -5,10 +6,11 @@
 import cv2
 import numpy as np
 
-from .constant_min import CLASS_CHANNEL_MAP
+from .constant_min import CLASS_CHANNEL_MAP, CHANNEL_NUM
+from .dense_dataset_definitions import DENSE_DATASET_DEFINITIONS as DEF
 
 
-HALF_WHOLE_NOTE = [39, 41, 42, 43, 45, 46, 47, 49]
+HALF_WHOLE_NOTE = DEF.NOTEHEADS_HOLLOW + DEF.NOTEHEADS_WHOLE + [42]
 
 
 def fill_hole(gt, tar_color):
@@ -75,12 +77,12 @@ def build_label(seg_path):
     color_set = set(np.unique(arr))
     color_set.remove(0) # Remove background color from the candidates
 
-    total_chs = len(set(CLASS_CHANNEL_MAP.values())) + 2 # Plus 'background' and 'others' channel.
+    total_chs = CHANNEL_NUM
     output = np.zeros(arr.shape + (total_chs,))
     output[..., 0] = np.where(arr==0, 1, 0)
     for color in color_set:
-        ch = CLASS_CHANNEL_MAP.get(color, -1)
+        ch = CLASS_CHANNEL_MAP.get(color, 0)
         if (ch != 0) and color in HALF_WHOLE_NOTE:
             note = fill_hole(arr, color)
             output[..., ch] += note
@@ -101,12 +103,7 @@ def find_example(dataset_path: str, color: int, max_count=100, mark_value=200):
 
 
 if __name__ == "__main__":
-    seg_folder = '/media/kohara/ADATA HV620S/dataset/ds2_dense/segmentation'
-    files = os.listdir(seg_folder)
-    path = os.path.join(seg_folder, random.choice(files))
-    #out = build_label(path)
-
-    color = 45
-    arr = find_example(color) # type: ignore
-    arr = np.where(arr==200, color, arr)
-    out = fill_hole(arr, color)
+    seg_folder = 'ds2_dense/segmentation'
+    color = int(sys.argv[1])
+    with_background, without_background = find_example(seg_folder, color)
+    cv2.imwrite("example.png", with_background)
diff --git a/oemer/classifier.py b/oemer/classifier.py
index 8d8536f..d94a1ee 100755
--- a/oemer/classifier.py
+++ b/oemer/classifier.py
@@ -58,6 +58,7 @@ def _collect(color, out_path, samples=100):
         img = imaugs.resize(Image.fromarray(patch.astype(np.uint8)), width=tar_w, height=tar_h)
 
         seed = random.randint(0, 1000)
+        np.float = float # Monkey patch to work around the removal of np.float
        img = imaugs.perspective_transform(img, seed=seed, sigma=3)
         img = np.where(np.array(img)>0, 255, 0)
         Image.fromarray(img.astype(np.uint8)).save(out_path / f"{idx}.png")
@@ -118,10 +119,12 @@ def train(folders):
     model.fit(train_x, train_y)
     return model, class_map
 
+def build_class_map(folders):
+    return {idx: Path(ff).name for idx, ff in enumerate(folders)}
 
 def train_tf(folders):
     import tensorflow as tf
-    class_map = {idx: Path(ff).name for idx, ff in enumerate(folders)}
+    class_map = build_class_map(folders)
     train_x = []
     train_y = []
     samples = None
@@ -234,6 +237,53 @@ def predict(region, model_name):
     pred = model.predict(np.array(region).reshape(1, -1))
     return m_info['class_map'][pred[0]]
 
+def train_rests_above8(filename = "rests_above8.model"):
+    folders = ["rest_8th", "rest_16th", "rest_32nd", "rest_64th"]
+    model, class_map = train_tf([f"train_data/{folder}" for folder in folders])
+    test_tf(model, [f"test_data/{folder}" for folder in folders])
+    output = {'model': model, 'w': TARGET_WIDTH, 'h': TARGET_HEIGHT, 'class_map': class_map}
+    pickle.dump(output, open(filename, "wb"))
+
+
+def train_rests(filename = "rests.model"):
+    folders = ["rest_whole", "rest_quarter", "rest_8th"]
+    model, class_map = train_tf([f"train_data/{folder}" for folder in folders])
+    test_tf(model, [f"test_data/{folder}" for folder in folders])
+    output = {'model': model, 'w': TARGET_WIDTH, 'h': TARGET_HEIGHT, 'class_map': class_map}
+    pickle.dump(output, open(filename, "wb"))
+
+
+def train_all_rests(filename = "all_rests.model"):
+    folders = ["rest_whole", "rest_quarter", "rest_8th", "rest_16th", "rest_32nd", "rest_64th"]
+    model, class_map = train_tf([f"train_data/{folder}" for folder in folders])
+    test_tf(model, [f"test_data/{folder}" for folder in folders])
+    output = {'model': model, 'w': TARGET_WIDTH, 'h': TARGET_HEIGHT, 'class_map': class_map}
+    pickle.dump(output, open(filename, "wb"))
+
+
+def train_sfn(filename = "sfn.model"):
+    folders = ["sharp", "flat", "natural"]
+    model, class_map = train_tf([f"train_data/{folder}" for folder in folders])
+    test_tf(model, [f"test_data/{folder}" for folder in folders])
+    output = {'model': model, 'w': TARGET_WIDTH, 'h': TARGET_HEIGHT, 'class_map': class_map}
+    pickle.dump(output, open(filename, "wb"))
+
+
+def train_clefs(filename = "clef.model"):
+    folders = ["gclef", "fclef"]
+    model, class_map = train_tf([f"train_data/{folder}" for folder in folders])
+    test_tf(model, [f"test_data/{folder}" for folder in folders])
+    output = {'model': model, 'w': TARGET_WIDTH, 'h': TARGET_HEIGHT, 'class_map': class_map}
+    pickle.dump(output, open(filename, "wb"))
+
+
+def train_noteheads():
+    folders = ["notehead_solid", "notehead_hollow"]
+    model, class_map = train_tf([f"train_data/{folder}" for folder in folders])
+    test_tf(model, [f"test_data/{folder}" for folder in folders])
+    output = {'model': model, 'w': TARGET_WIDTH, 'h': TARGET_HEIGHT, 'class_map': class_map}
+    pickle.dump(output, open("notehead.model", "wb"))
+
 
 if __name__ == "__main__":
     samples = 400
diff --git a/oemer/constant.py b/oemer/constant.py
index 3cc789a..3d7a320 100755
--- a/oemer/constant.py
+++ b/oemer/constant.py
@@ -1,21 +1,23 @@
 from enum import Enum, auto
 
+from oemer.dense_dataset_definitions import DENSE_DATASET_DEFINITIONS as DEF
+
 CLASS_CHANNEL_LIST = [
-    [165, 2], # staff, ledgerLine
-    [35, 37, 38], # noteheadBlack
-    [39, 41, 42], # noteheadHalf
-    [43, 45, 46, 47, 49], # noteheadWhole
-    [64, 58, 59, 60, 66, 63, 69, 68, 61, 62, 67, 65], # flags
-    [146, 51], # beam, augmentationDot
-    [3, 52], # barline, stem
-    [74, 70, 72, 76], # accidentalSharp, accidentalFlat, accidentalNatural, accidentalDoubleSharp
-    [80, 78, 79], # keySharp, keyFlat, keyNatural
-    [97, 100, 99, 98, 101, 102, 103, 104, 96, 163], # rests
-    [136, 156, 137, 155, 152, 151, 153, 154, 149, 155], # tuplets
-    [145, 147], # slur, tie
-    [10, 13, 12, 19, 11, 20], # clefs
-    [25, 24, 29, 22, 23, 28, 27, 34, 30, 21, 33, 26], # timeSigs
+    DEF.STAFF + DEF.LEDGERLINE,
+    DEF.NOTEHEADS_SOLID + [38],
+    DEF.NOTEHEADS_HOLLOW + [42],
+    DEF.NOTEHEADS_WHOLE + [46],
+    DEF.FLAG_DOWN + DEF.FLAG_UP + [59, 65],
+    DEF.BEAM + DEF.DOT,
+    DEF.BARLINE_BETWEEN + DEF.STEM,
+    DEF.ALL_ACCIDENTALS,
+    DEF.ALL_KEYS,
+    DEF.ALL_RESTS + [163],
+    DEF.TUPETS,
+    DEF.SLUR_AND_TIE,
+    DEF.ALL_CLEFS + DEF.NUMBERS,
+    DEF.TIME_SIGNATURE_SUBSET
 ]
 
 CLASS_CHANNEL_MAP = {
diff --git a/oemer/constant_min.py b/oemer/constant_min.py
index ea64b4d..eb761b4 100755
--- a/oemer/constant_min.py
+++ b/oemer/constant_min.py
@@ -1,13 +1,10 @@
+from oemer.dense_dataset_definitions import DENSE_DATASET_DEFINITIONS as DEF
+
+
 CLASS_CHANNEL_LIST = [
-    [165, 2], # staff, ledgerLine
-    [35, 37, 38, 39, 41, 42, 43, 45, 46, 47, 49, 52], # notehead, stem
-    [
-        64, 58, 60, 66, 63, 69, 68, 61, 62, 67, 65, 59, 146, # flags, beam
-        97, 100, 99, 98, 101, 102, 103, 104, 96, 163, # rests
-        80, 78, 79, 74, 70, 72, 76, 3, # sharp, flat, natural, barline
-        10, 13, 12, 19, 11, 20, 51, # clefs, augmentationDot,
-        25, 24, 29, 22, 23, 28, 27, 34, 30, 21, 33, 26, # timeSigs
-    ]
+    DEF.STEM + DEF.ALL_RESTS_EXCEPT_LARGE + DEF.BARLINE_BETWEEN + DEF.BARLINE_END,
+    DEF.NOTEHEADS_ALL,
+    DEF.ALL_CLEFS + DEF.ALL_KEYS + DEF.ALL_ACCIDENTALS,
 ]
 
 CLASS_CHANNEL_MAP = {
@@ -16,4 +13,4 @@
     for color in colors
 }
 
-CHANNEL_NUM = len(CLASS_CHANNEL_LIST) + 2
+CHANNEL_NUM = len(CLASS_CHANNEL_LIST) + 1 # Plus the 'background' channel; 'others' colors now fall back to channel 0.
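NOTE: with the constant_min changes above, every dataset color now resolves to
one of CHANNEL_NUM = 4 output channels. A minimal sketch of the resulting
lookup, assuming the idx+1 dict comprehension of CLASS_CHANNEL_MAP shown in
the surrounding context (channel 0 is the background, and unmapped "others"
colors fall back to 0 via the build_label change):

    from oemer.constant_min import CLASS_CHANNEL_MAP, CHANNEL_NUM

    print(CHANNEL_NUM)            # 4: background + 3 grouped symbol channels
    print(CLASS_CHANNEL_MAP[52])  # 1: stems (DEF.STEM) are in the first group
    print(CLASS_CHANNEL_MAP[35])  # 2: noteheads (DEF.NOTEHEADS_ALL)
    print(CLASS_CHANNEL_MAP[74])  # 3: sharps (DEF.ALL_ACCIDENTALS)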
diff --git a/oemer/dense_dataset_definitions.py b/oemer/dense_dataset_definitions.py
new file mode 100644
index 0000000..b0fa1a3
--- /dev/null
+++ b/oemer/dense_dataset_definitions.py
@@ -0,0 +1,78 @@
+class Symbols:
+    BACKGROUND = [0]
+    LEDGERLINE = [2]
+    BARLINE_BETWEEN = [3]
+    BARLINE_END = [4]
+    ALL_BARLINES = BARLINE_BETWEEN + BARLINE_END
+    REPEAT_DOTS = [7]
+    G_GLEF = [10]
+    C_CLEF = [11, 12]
+    F_CLEF = [13]
+    ALL_CLEFS = G_GLEF + C_CLEF + F_CLEF
+    NUMBERS = [19, 20]
+    TIME_SIGNATURE_SUBSET = [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33, 34]
+    TIME_SIGNATURE = TIME_SIGNATURE_SUBSET + [31, 32] # Oemer hasn't used these in the past
+    NOTEHEAD_FULL_ON_LINE = [35]
+    UNKNOWN = [36, 38, 40, 128, 143, 144, 148, 150, 157, 159, 160, 161, 162, 163, 164, 167, 170, 171]
+    NOTEHEAD_FULL_BETWEEN_LINES = [37]
+    NOTEHEAD_HOLLOW_ON_LINE = [39]
+    NOTEHEAD_HOLLOW_BETWEEN_LINE = [41]
+    WHOLE_NOTE_ON_LINE = [43]
+    WHOLE_NOTE_BETWEEN_LINE = [45]
+    DOUBLE_WHOLE_NOTE_ON_LINE = [47]
+    DOUBLE_WHOLE_NOTE_BETWEEN_LINE = [49]
+    NOTEHEADS_SOLID = NOTEHEAD_FULL_ON_LINE + NOTEHEAD_FULL_BETWEEN_LINES
+    NOTEHEADS_HOLLOW = NOTEHEAD_HOLLOW_ON_LINE + NOTEHEAD_HOLLOW_BETWEEN_LINE
+    NOTEHEADS_WHOLE = WHOLE_NOTE_ON_LINE + WHOLE_NOTE_BETWEEN_LINE + DOUBLE_WHOLE_NOTE_ON_LINE + DOUBLE_WHOLE_NOTE_BETWEEN_LINE
+    NOTEHEADS_ALL = NOTEHEAD_FULL_ON_LINE + NOTEHEAD_FULL_BETWEEN_LINES + NOTEHEAD_HOLLOW_ON_LINE + NOTEHEAD_HOLLOW_BETWEEN_LINE + WHOLE_NOTE_ON_LINE + WHOLE_NOTE_BETWEEN_LINE + DOUBLE_WHOLE_NOTE_ON_LINE + DOUBLE_WHOLE_NOTE_BETWEEN_LINE
+    DOT = [51]
+    STEM = [52]
+    TREMOLO = [53, 54, 55, 56]
+    FLAG_DOWN = [58, 60, 61, 62, 63]
+    FLAG_UP = [64, 66, 67, 68, 69]
+    FLAT = [70]
+    NATURAL = [72]
+    SHARP = [74]
+    DOUBLE_SHARP = [76]
+    ALL_ACCIDENTALS = FLAT + NATURAL + SHARP + DOUBLE_SHARP
+    KEY_FLAT = [78]
+    KEY_NATURAL = [79]
+    KEY_SHARP = [80]
+    ALL_KEYS = KEY_FLAT + KEY_NATURAL + KEY_SHARP
+    ACCENT_ABOVE = [81]
+    ACCENT_BELOW = [82]
+    STACCATO_ABOVE = [83]
+    STACCATO_BELOW = [84]
+    TENUTO_ABOVE = [85]
+    TENUTO_BELOW = [86]
+    STACCATISSIMO_ABOVE = [87]
+    STACCATISSIMO_BELOW = [88]
+    MARCATO_ABOVE = [89]
+    MARCATO_BELOW = [90]
+    FERMATA_ABOVE = [91]
+    FERMATA_BELOW = [92]
+    BREATH_MARK = [93]
+    REST_LARGE = [95]
+    REST_LONG = [96]
+    REST_BREVE = [97]
+    REST_FULL = [98]
+    REST_QUARTER = [99]
+    REST_EIGHTH = [100]
+    REST_SIXTEENTH = [101]
+    REST_THIRTY_SECOND = [102]
+    REST_SIXTY_FOURTH = [103]
+    REST_ONE_HUNDRED_TWENTY_EIGHTH = [104]
+    ALL_RESTS_EXCEPT_LARGE = REST_LONG + REST_BREVE + REST_FULL + REST_QUARTER + REST_EIGHTH + REST_SIXTEENTH + REST_THIRTY_SECOND + REST_SIXTY_FOURTH + REST_ONE_HUNDRED_TWENTY_EIGHTH
+    ALL_RESTS = ALL_RESTS_EXCEPT_LARGE
+    TRILL = [127]
+    GRUPPETO = [129]
+    MORDENT = [130]
+    DOWN_BOW = [131]
+    UP_BOW = [132]
+    SYMBOL = [133, 134, 135, 138, 139, 141, 142]
+    TUPETS = [136, 137, 149, 151, 152, 153, 154, 155, 156]
+    SLUR_AND_TIE = [145, 147]
+    BEAM = [146]
+    STAFF = [165]
+
+DENSE_DATASET_DEFINITIONS = Symbols()
\ No newline at end of file
diff --git a/oemer/inference.py b/oemer/inference.py
index 74db2c0..082815c 100755
--- a/oemer/inference.py
+++ b/oemer/inference.py
@@ -58,8 +58,8 @@ def inference(
     image_pil = Image.open(img_path)
     if "GIF" != image_pil.format:
         # Tricky workaround to avoid random mistery transpose when loading with 'Image'.
-        image_pil = cv2.imread(img_path)
-        image_pil = Image.fromarray(image_pil)
+        image_cv = cv2.imread(img_path)
+        image_pil = Image.fromarray(image_cv)
     image_pil = image_pil.convert("RGB")
 
     image = np.array(resize_image(image_pil))
diff --git a/oemer/models/unet.py b/oemer/models/unet.py
index 66efd1d..ae6ee92 100755
--- a/oemer/models/unet.py
+++ b/oemer/models/unet.py
@@ -142,10 +142,19 @@ def my_conv_block(inp, kernels, kernel_size=(3, 3), strides=(1, 1)):
     return out
 
 
+def my_conv_small_block(inp, kernels, kernel_size=(3, 3), strides=(1, 1)):
+    inp = L.Conv2D(kernels, kernel_size, strides=strides, padding='same', dtype=tf.float32)(inp)
+    out = L.Activation("relu")(L.LayerNormalization()(inp))
+    out = L.Dropout(0.3)(out)
+    out = L.Add()([inp, out])
+    out = L.Activation("relu")(L.LayerNormalization()(out))
+    return out
+
+
 def my_trans_conv_block(inp, kernels, kernel_size=(3, 3), strides=(1, 1)):
     inp = L.Conv2DTranspose(kernels, kernel_size, strides=strides, padding='same', dtype=tf.float32)(inp)
-    out = L.Activation("relu")(L.LayerNormalization()(inp))
-    out = L.Conv2D(kernels, kernel_size, padding='same', dtype=tf.float32)(out)
+    #out = L.Activation("relu")(L.LayerNormalization()(inp))
+    out = L.Conv2D(kernels, kernel_size, padding='same', dtype=tf.float32)(inp)
     out = L.Activation("relu")(L.LayerNormalization()(out))
     out = L.Dropout(0.3)(out)
     out = L.Add()([inp, out])
@@ -157,25 +166,25 @@ def u_net(win_size=288, out_class=3):
     inp = L.Input(shape=(win_size, win_size, 3))
     tensor = L.SeparableConv2D(128, (3, 3), activation="relu", padding='same')(inp)
 
-    l1 = my_conv_block(tensor, 64, (3, 3), strides=(2, 2)) # 128
-    l1 = my_conv_block(l1, 128, (3, 3))
-    l1 = my_conv_block(l1, 128, (3, 3))
+    l1 = my_conv_small_block(tensor, 64, (3, 3), strides=(2, 2))
+    l1 = my_conv_small_block(l1, 64, (3, 3))
+    l1 = my_conv_small_block(l1, 64, (3, 3))
 
-    skip = my_conv_block(l1, 128, (3, 3), strides=(2, 2)) # 64
-    l2 = my_conv_block(skip, 128, (3, 3))
-    l2 = my_conv_block(l2, 128, (3, 3))
-    l2 = my_conv_block(l2, 128, (3, 3))
-    l2 = my_conv_block(l2, 128, (3, 3))
+    skip = my_conv_small_block(l1, 128, (3, 3), strides=(2, 2))
+    l2 = my_conv_small_block(skip, 128, (3, 3))
+    l2 = my_conv_small_block(l2, 128, (3, 3))
+    l2 = my_conv_small_block(l2, 128, (3, 3))
+    l2 = my_conv_small_block(l2, 128, (3, 3))
     l2 = L.Concatenate()([skip, l2])
 
-    l3 = my_conv_block(l2, 256, (3, 3))
-    l3 = my_conv_block(l3, 256, (3, 3))
-    l3 = my_conv_block(l3, 256, (3, 3))
-    l3 = my_conv_block(l3, 256, (3, 3))
-    l3 = my_conv_block(l3, 256, (3, 3))
+    l3 = my_conv_small_block(l2, 256, (3, 3))
+    l3 = my_conv_small_block(l3, 256, (3, 3))
+    l3 = my_conv_small_block(l3, 256, (3, 3))
+    l3 = my_conv_small_block(l3, 256, (3, 3))
+    l3 = my_conv_small_block(l3, 256, (3, 3))
     l3 = L.Concatenate()([l2, l3])
 
-    bot = my_conv_block(l3, 256, (3, 3), strides=(2, 2)) # 32
+    bot = my_conv_small_block(l3, 256, (3, 3), strides=(2, 2)) # 32
     st1 = L.SeparableConv2D(256, (3, 3), padding='same', dtype=tf.float32)(bot)
     st1 = L.Activation("relu")(L.LayerNormalization()(st1))
     st2 = L.SeparableConv2D(256, (3, 3), dilation_rate=(2, 2), padding='same', dtype=tf.float32)(bot)
@@ -189,20 +198,23 @@ def u_net(win_size=288, out_class=3):
     norm = L.Activation("relu")(L.LayerNormalization()(st))
 
     bot = my_trans_conv_block(norm, 256, (3, 3), strides=(2, 2)) # 64
-    tl3 = L.Conv2D(256, (3, 3), padding='same', dtype=tf.float32)(bot)
+    tl3 = L.Conv2D(128, (3, 3), padding='same', dtype=tf.float32)(bot)
     tl3 = L.Activation("relu")(L.LayerNormalization()(tl3))
     tl3 = L.Concatenate()([tl3, l3])
+    tl3 = my_conv_small_block(tl3, 128, (3, 3))
     tl3 = my_trans_conv_block(tl3, 128, (3, 3))
 
     # Head 1
     tl2 = L.Conv2D(128, (3, 3), padding='same', dtype=tf.float32)(tl3)
     tl2 = L.Activation("relu")(L.LayerNormalization()(tl2))
     tl2 = L.Concatenate()([tl2, l2])
+    tl2 = my_conv_small_block(tl2, 128, (3, 3))
     tl2 = my_trans_conv_block(tl2, 128, (3, 3), strides=(2, 2)) # 128
 
     tl1 = L.Conv2D(128, (3, 3), padding='same', dtype=tf.float32)(tl2)
     tl1 = L.Activation("relu")(L.LayerNormalization()(tl1))
     tl1 = L.Concatenate()([tl1, l1])
+    tl1 = my_conv_small_block(tl1, 128, (3, 3))
     tl1 = my_trans_conv_block(tl1, 128, (3, 3), strides=(2, 2)) # 256
 
     out1 = L.Conv2D(out_class, (1, 1), activation='softmax', padding='same', dtype=tf.float32)(tl1)
diff --git a/oemer/notehead_extraction.py b/oemer/notehead_extraction.py
index 13b8174..21eeb28 100755
--- a/oemer/notehead_extraction.py
+++ b/oemer/notehead_extraction.py
@@ -306,7 +306,7 @@ def fill_hole(region: ndarray) -> ndarray:
 
 def gen_notes(bboxes: List[ndarray], symbols: ndarray) -> List[NoteHead]:
     notes = []
     for bbox in bboxes:
-        # Instanitiate notehead.
+        # Instantiate notehead.
         nn = NoteHead()
         nn.bbox = typing.cast(BBox, bbox)
@@ -343,7 +343,7 @@ def assign_group_track(st: Staff) -> None:
     # The value could also be negative. The zero index starts from the position
     # same as D4, assert the staffline is in treble clef. The value increases
     # as the pitch goes up.
-    # Build centers of each postion first.
+    # Build centers of each position first.
     step = st_master.unit_size / 2
     pos_cen = [l.y_center for l in st_master.lines[::-1]]
     tmp_inter = []
@@ -354,7 +354,7 @@ def assign_group_track(st: Staff) -> None:
         pos_cen.insert(idx*2+1, interp)
     pos_cen = [pos_cen[0]+step] + pos_cen + [pos_cen[-1]-step]
 
-    # Estimate position by the closeset center.
+    # Estimate position by the closest center.
     pos_idx = np.argmin(np.abs(np.array(pos_cen)-cen_y))
     if 0 < pos_idx < len(pos_cen)-1:
         nn.staff_line_pos = int(pos_idx)
diff --git a/oemer/rhythm_extraction.py b/oemer/rhythm_extraction.py
index 26bf8f2..73d4e59 100755
--- a/oemer/rhythm_extraction.py
+++ b/oemer/rhythm_extraction.py
@@ -3,6 +3,7 @@
 import math
 
 import cv2
+from cv2.typing import RotatedRect
 import scipy.ndimage
 import numpy as np
 from numpy import ndarray
@@ -37,7 +38,12 @@ def scan_dot(
         # Find the right most bound for scan the dot.
         # Should have width less than unit_size, and can't
         # touch the nearby note.
-        cur_scan_line = note_id_map[int(start_y):int(bbox[3]), int(right_bound)]
+        try:
+            cur_scan_line = note_id_map[int(start_y):int(bbox[3]), int(right_bound)]
+        except IndexError as e:
+            print(e)
+            break
+
         ids = set(np.unique(cur_scan_line))
         if -1 in ids:
             ids.remove(-1)
@@ -134,7 +140,7 @@ def parse_beams(
     min_area_ratio: float = 0.07,
     min_tp_ratio: float = 0.4,
     min_width_ratio: float = 0.2
-) -> Tuple[ndarray, List[Tuple[Tuple[float, float], Tuple[float, float], float]], ndarray]:
+) -> Tuple[ndarray, List[RotatedRect], ndarray]:
     # Fetch parameters
     symbols = layers.get_layer('symbols_pred')
     staff_pred = layers.get_layer('staff_pred')
@@ -156,14 +162,14 @@ def parse_beams(
     ratio_map = np.copy(poly_map)
     null_color = (255, 255, 255)
 
-    valid_box = []
+    valid_box: List[RotatedRect] = []
     valid_idxs = []
     idx_map = np.zeros_like(poly_map) - 1
-    for idx, rbox in enumerate(rboxes): # type: ignore
+    for box_idx, rbox in enumerate(rboxes): # type: ignore
         # Used to find indexes of contour areas later. Must be check before
         # any 'continue' statement.
-        idx %= 255 # type: ignore
-        if idx == 0:
+        box_idx %= 255 # type: ignore
+        if box_idx == 0:
             idx_map = np.zeros_like(poly_map) - 1
 
         # Get the contour of the rotated box
@@ -186,8 +192,8 @@ def parse_beams(
             continue
 
         # Tricky way to get the index of the contour area
-        cv2.fillPoly(idx_map, [cnt], color=(idx, 0, 0))
-        yi, xi = np.where(idx_map[..., 0] == idx)
+        cv2.fillPoly(idx_map, [cnt], color=(box_idx, 0, 0))
+        yi, xi = np.where(idx_map[..., 0] == box_idx)
         pts = beams[yi, xi]
         meta_idx = np.where(pts>0)[0]
         ryi = yi[meta_idx]
@@ -429,7 +435,9 @@ def get_label(nbox, stem_up):
             end_x=min(set_box[2], cen_x+half_scan_width),
             end_y=end_y,
             threshold=threshold
-        )
+        )
+        if count >= len(note_type_map):
+            return note_type_map[len(note_type_map) - 1]
         return note_type_map[count]
 
     if len(nts) == 2:
@@ -501,7 +509,7 @@ def parse_rhythm(beam_map: ndarray, map_info: Dict[int, Dict[str, Any]], agree_t
         rev_map_info[gid] = {'reg': reg, 'bbox': box}
 
     # Define beam count to note type mapping
-    note_type_map = {
+    note_type_map: Dict[int, NoteType] = {
         0: NoteType.QUARTER,
         1: NoteType.EIGHTH,
         2: NoteType.SIXTEENTH,
@@ -580,7 +588,7 @@ def parse_rhythm(beam_map: ndarray, map_info: Dict[int, Dict[str, Any]], agree_t
             end_y = gbox[3]
 
         # Calculate how many beams/flags are there.
-        count = scan_beam_flag( # type: ignore
+        beam_flag_count = scan_beam_flag( # type: ignore
             bin_beam_map,
             max(reg_box[0], cen_x-half_scan_width),
             start_y,
@@ -590,12 +598,15 @@ def parse_rhythm(beam_map: ndarray, map_info: Dict[int, Dict[str, Any]], agree_t
         )
 
         #cv2.rectangle(beam_img, (gbox[0], gbox[1]), (gbox[2], gbox[3]), (255, 0, 255), 1)
-        cv2.putText(beam_img, str(count), (int(cen_x), int(gbox[3])+2), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 1)
+        cv2.putText(beam_img, str(beam_flag_count), (int(cen_x), int(gbox[3])+2), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 1)
 
         # Assign note label
         for nid in group.note_ids:
             if notes[nid].label is None:
-                notes[nid].label = note_type_map[count] # type: ignore
+                if beam_flag_count in note_type_map:
+                    notes[nid].label = note_type_map[beam_flag_count]
+                else:
+                    notes[nid].invalid = True
 
     return beam_img
@@ -605,7 +616,7 @@ def extract(
     dot_max_area_ratio: float = 0.2,
     beam_min_area_ratio: float = 0.07,
     agree_th: float = 0.15
-) -> Tuple[ndarray, List[Tuple[Tuple[float, float], Tuple[float, float], float]]]:
+) -> Tuple[ndarray, List[RotatedRect]]:
     logger.debug("Parsing dot")
     parse_dot(max_area_ratio=dot_max_area_ratio, min_area_ratio=dot_min_area_ratio)
diff --git a/oemer/staffline_extraction.py b/oemer/staffline_extraction.py
index 70bd92d..fc4b24f 100755
--- a/oemer/staffline_extraction.py
+++ b/oemer/staffline_extraction.py
@@ -331,7 +331,7 @@ def extract(
     # Start process
     zones, *_ = init_zones(staff_pred, splits=splits)
 
-    all_staffs = []
+    all_staffs: List[List[Staff]] = []
     for rr in zones:
         print(rr[0], rr[-1], end=' ')
         rr = np.array(rr, dtype=np.int64)
@@ -339,43 +339,45 @@
         if staffs is not None:
            all_staffs.append(staffs)
         print(len(staffs))
-    all_staffs = align_staffs(all_staffs) # type: ignore
+    aligned_staffs: np.ndarray = align_staffs(all_staffs)
 
     # Use barline information to infer the number of tracks for each group.
-    num_track = further_infer_track_nums(all_staffs, min_degree=barline_min_degree) # type: ignore
+    num_track = further_infer_track_nums(aligned_staffs, min_degree=barline_min_degree)
     logger.debug(f"Tracks: {num_track}")
-    for col_sts in all_staffs:
+    for col_sts in aligned_staffs:
         for idx, st in enumerate(col_sts):
             st.track = idx % num_track
             st.group = idx // num_track
 
     # Validate staffs across zones.
     # Should have same number of staffs
-    if not all([len(staff) == len(all_staffs[0]) for staff in all_staffs]):
+    if not all([len(staff) == len(aligned_staffs[0]) for staff in aligned_staffs]):
         raise Exception
-    assert all([len(staff) == len(all_staffs[0]) for staff in all_staffs])
+    assert all([len(staff) == len(aligned_staffs[0]) for staff in aligned_staffs])
 
     norm = lambda data: np.abs(np.array(data) / np.mean(data) - 1)
-    for staffs in all_staffs.T: # type: ignore
+    valid_staffs: List[List[Staff]] = []
+    for staffs in aligned_staffs.T:
         # Should all have 5 lines
         line_num = [len(staff.lines) for staff in staffs]
         if len(set(line_num)) != 1:
-            raise E.StafflineCountInconsistent(
-                f"Some of the stafflines contains less or more than 5 lines: {line_num}")
+            print(f"Some of the stafflines contain fewer or more than 5 lines: {line_num}")
+            continue
 
         # Check Staffs that are approximately at the same row.
         centers = np.array([staff.y_center for staff in staffs])
         if not np.all(norm(centers) < horizontal_diff_th):
-            raise E.StafflineNotAligned(
-                f"Centers of staff parts at the same row not aligned (Th: {horizontal_diff_th}): {norm(centers)}")
+            print(f"Centers of staff parts at the same row not aligned (Th: {horizontal_diff_th}): {norm(centers)}")
+            continue
 
         # Unit sizes should roughly all the same
         unit_size = np.array([staff.unit_size for staff in staffs])
-        if not np.all(norm(unit_size) < unit_size_diff_th):
-            raise E.StafflineUnitSizeInconsistent(
-                f"Unit sizes not consistent (th: {unit_size_diff_th}): {norm(unit_size)}")
+        if not np.all(norm(unit_size) < unit_size_diff_th):
+            print(f"Unit sizes not consistent (th: {unit_size_diff_th}): {norm(unit_size)}")
+            continue
+        valid_staffs.append(staffs)
 
-    return np.array(all_staffs), zones
+    return np.array(valid_staffs).T, zones
 
 
 def extract_part(pred: ndarray, x_offset: int, line_threshold: float = 0.8) -> List[Staff]:
diff --git a/oemer/symbol_extraction.py b/oemer/symbol_extraction.py
index 369dc84..9d0b2f9 100755
--- a/oemer/symbol_extraction.py
+++ b/oemer/symbol_extraction.py
@@ -270,7 +270,7 @@ def parse_clefs_keys(
     for box in bboxes:
         w = box[2] - box[0]
         h = box[3] - box[1]
-        region = clefs_keys[box[1]:box[3], box[0]:box[2]]
+        region: ndarray = clefs_keys[box[1]:box[3], box[0]:box[2]]
         usize = get_unit_size(*get_center(box))
         area_size_ratio = w * h / usize**2
         area_tp_ratio = region[region>0].size / (w * h)
@@ -317,6 +317,8 @@ def parse_rests(line_box: ndarray, unit_size: float) -> Tuple[List[BBox], List[s
 
     bboxes = get_bbox(rests)
     bboxes = filter_out_of_range_bbox(bboxes)
+    if len(bboxes) == 0:
+        return [], []
     bboxes = merge_nearby_bbox(bboxes, unit_size*1.2)
     bboxes = rm_merge_overlap_bbox(bboxes)
     bboxes = filter_out_small_area(bboxes, area_size_func=lambda usize: usize**2 * 0.7)
@@ -373,6 +375,9 @@ def get_nearby_note_id(box: BBox, note_id_map: ndarray) -> Union[int, None]:
     unit_size = int(round(get_unit_size(cen_x, cen_y)))
     nid = None
     for x in range(box[2], box[2]+unit_size):
+        is_in_range = (0 <= cen_y < note_id_map.shape[0]) and (0 <= x < note_id_map.shape[1])
+        if not is_in_range:
+            continue
         if note_id_map[cen_y, x] != -1:
             nid = note_id_map[cen_y, x]
             break
@@ -401,11 +406,14 @@ def gen_sfns(bboxes: List[BBox], labels: List[str]) -> List[Sfn]:
         if ss.note_id is not None:
             note = notes[ss.note_id]
             if ss.track != note.track:
-                raise E.SfnNoteTrackMismatch(f"Track of sfn and note not mismatch: {ss}\n{note}")
-            if ss.group != note.group:
-                raise E.SfnNoteGroupMismatch(f"Group of sfn and note not mismatch: {ss}\n{note}")
-            notes[ss.note_id].sfn = ss.label
-            ss.is_key = False
+                print(f"Track of sfn and note mismatch: {ss}\n{note}")
+                notes[ss.note_id].invalid = True
+            elif ss.group != note.group:
+                print(f"Group of sfn and note mismatch: {ss}\n{note}")
+                notes[ss.note_id].invalid = True
+            else:
+                notes[ss.note_id].sfn = ss.label
+                ss.is_key = False
         sfns.append(ss)
     return sfns
 
@@ -432,7 +440,7 @@ def gen_rests(bboxes: List[BBox], labels: List[str]) -> List[Rest]:
         rr.group = st1.group
 
         unit_size = int(round(get_unit_size(*get_center(box))))
-        dot_range = range(box[2]+1, box[2]+unit_size)
+        dot_range = range(box[2]+1, min(box[2]+unit_size, symbols.shape[1] - 1))
         dot_region = symbols[box[1]:box[3], dot_range]
         if 0 < np.sum(dot_region) < unit_size**2 / 7:
             rr.has_dot = True
diff --git a/oemer/train.py b/oemer/train.py
index dbe5b46..475acc2 100755
--- a/oemer/train.py
+++ b/oemer/train.py
@@ -11,11 +11,14 @@
 from .build_label import build_label
 from .models.unet import semantic_segmentation, u_net
 
-from .constant import CHANNEL_NUM
+from .constant_min import CHANNEL_NUM
 
 
 def get_cvc_data_paths(dataset_path):
+    if not os.path.exists(dataset_path):
+        raise FileNotFoundError(f"{dataset_path} not found, download the dataset first.")
+
     dirs = ["curvature", "ideal", "interrupted", "kanungo", "rotated",
             "staffline-thickness-variation-v1", "staffline-thickness-variation-v2",
             "staffline-y-variation-v1", "staffline-y-variation-v2",
             "thickness-ratio", "typeset-emulation", "whitespeckles"]
@@ -37,6 +40,9 @@ def get_cvc_data_paths(dataset_path):
 
 
 def get_deep_score_data_paths(dataset_path):
+    if not os.path.exists(dataset_path):
+        raise FileNotFoundError(f"{dataset_path} not found, download the dataset first.")
+
     imgs = os.listdir(os.path.join(dataset_path, "images"))
     paths = []
     for img in imgs:
@@ -122,9 +128,42 @@ def batch_transform(img, trans_func):
         result.append(np.array(tmp_img))
     return np.dstack(result)
 
+
+class MultiprocessingDataLoader:
+    def __init__(self, num_worker: int):
+        self._queue: Queue = Queue(maxsize=20)
+        self._dist_queue: Queue = Queue(maxsize=30)
+        self._process_pool = []
+        for _ in range(num_worker):
+            processor = Process(target=self._preprocess_image)
+            processor.daemon = True
+            self._process_pool.append(processor)
+        self._pdist = Process(target=self._distribute_process)
+        self._pdist.daemon = True
+
+    def _start_processes(self):
+        if not self._pdist.is_alive():
+            self._pdist.start()
+        for process in self._process_pool:
+            if not process.is_alive():
+                process.start()
 
-class DataLoader:
+    def _terminate_processes(self):
+        self._pdist.terminate()
+        for process in self._process_pool:
+            process.terminate()
+
+
+    def _distribute_process(self):
+        pass
+
+    def _preprocess_image(self):
+        pass
+
+
+class DataLoader(MultiprocessingDataLoader):
     def __init__(self, feature_files, win_size=256, num_samples=100, min_step_size=0.2, num_worker=4):
+        super().__init__(num_worker)
         self.feature_files = feature_files
         random.shuffle(self.feature_files)
         self.win_size = win_size
@@ -138,16 +177,6 @@ def __init__(self, feature_files, win_size=256, num_samples=100, min_step_size=0
 
         self.file_idx = 0
 
-        self._queue = Queue(maxsize=200)
-        self._dist_queue = Queue(maxsize=300)
-        self._process_pool = []
-        for _ in range(num_worker):
-            processor = Process(target=self._preprocess_image)
-            processor.daemon = True
-            self._process_pool.append(processor)
-        self._pdist = Process(target=self._distribute_process)
-        self._pdist.daemon = True
-
     def _distribute_process(self):
         while True:
             paths = self.feature_files[self.file_idx]
@@ -175,6 +204,7 @@ def _preprocess_image(self):
 
             # Random perspective transform
             seed = random.randint(0, 1000)
+            np.float = float # Monkey patch to work around the removal of np.float
             perspect_trans = lambda img: imaugs.perspective_transform(img, seed=seed, sigma=70)
             image = np.array(perspect_trans(image)) # RGB image
             staff_img = np.array(perspect_trans(staff_img)) # 1-bit mask
@@ -187,11 +217,7 @@ def _preprocess_image(self):
 
     def __iter__(self):
         samples = 0
-        if not self._pdist.is_alive():
-            self._pdist.start()
-        for process in self._process_pool:
-            if not process.is_alive():
-                process.start()
+        self._start_processes()
 
         while samples < self.num_samples:
             image, staff_img, symbol_img, ratio = self._queue.get()
@@ -218,9 +244,7 @@ def __iter__(self):
             start_y = min(start_y + y_step, max_y)
             start_x = min(start_x + x_step, max_x)
 
-        self._pdist.terminate()
-        for process in self._process_pool:
-            process.terminate()
+        self._terminate_processes()
 
     def get_dataset(self, batch_size, output_types=None, output_shapes=None):
         def gen_wrapper():
@@ -240,8 +264,9 @@ def gen_wrapper():
             .prefetch(tf.data.experimental.AUTOTUNE)
 
 
-class DsDataLoader:
+class DsDataLoader(MultiprocessingDataLoader):
     def __init__(self, feature_files, win_size=256, num_samples=100, step_size=0.5, num_worker=4):
+        super().__init__(num_worker)
         self.feature_files = feature_files
         random.shuffle(self.feature_files)
         self.win_size = win_size
@@ -255,16 +280,6 @@ def __init__(self, feature_files, win_size=256, num_samples=100, step_size=0.5,
 
         self.file_idx = 0
 
-        self._queue = Queue(maxsize=200)
-        self._dist_queue = Queue(maxsize=100)
-        self._process_pool = []
-        for _ in range(num_worker):
-            processor = Process(target=self._preprocess_image)
-            processor.daemon = True
-            self._process_pool.append(processor)
-        self._pdist = Process(target=self._distribute_process)
-        self._pdist.daemon = True
-
     def _distribute_process(self):
         while True:
             paths = self.feature_files[self.file_idx]
@@ -293,6 +308,7 @@ def _preprocess_image(self):
 
             # Random perspective transform
             seed = random.randint(0, 1000)
+            np.float = float # Monkey patch to work around the removal of np.float
             perspect_trans = lambda img: imaugs.perspective_transform(img, seed=seed, sigma=70)
             image = np.array(batch_transform(image, perspect_trans)) # RGB image
             label = np.array(batch_transform(label, perspect_trans))
@@ -302,11 +318,7 @@ def _preprocess_image(self):
 
     def __iter__(self):
         samples = 0
-        if not self._pdist.is_alive():
-            self._pdist.start()
-        for process in self._process_pool:
-            if not process.is_alive():
-                process.start()
+        self._start_processes()
 
         while samples < self.num_samples:
             image, label, ratio = self._queue.get()
@@ -337,10 +349,7 @@ def __iter__(self):
                 ll = label[index]
                 yield feat, ll
 
-        self._pdist.terminate()
-        for process in self._process_pool:
-            process.terminate()
-
+        self._terminate_processes()
     def get_dataset(self, batch_size, output_types=None, output_shapes=None):
         def gen_wrapper():
             for data in self:
@@ -410,7 +419,6 @@ def focal_tversky_loss(y_true, y_pred, fw=0.7, alpha=0.7, smooth=1., gamma=0.75)
 
 def train_model(
     dataset_path,
-    win_size=288,
     train_val_split=0.1,
     learning_rate=5e-4,
     epochs=15,
@@ -418,33 +426,51 @@ def train_model(
     batch_size=8,
     val_steps=200,
     val_batch_size=8,
-    early_stop=8
+    early_stop=8,
+    data_model="segnet"
 ):
-    # feat_files = get_cvc_data_paths(dataset_path)
-    feat_files = get_deep_score_data_paths(dataset_path)
+    if data_model == "segnet":
+        feat_files = get_deep_score_data_paths(dataset_path)
+    else:
+        feat_files = get_cvc_data_paths(dataset_path)
     random.shuffle(feat_files)
     split_idx = round(train_val_split * len(feat_files))
     train_files = feat_files[split_idx:]
     val_files = feat_files[:split_idx]
     print(f"Loading dataset. Train/validation: {len(train_files)}/{len(val_files)}")
 
-    train_data = DsDataLoader(
-            train_files,
-            win_size=win_size,
-            num_samples=epochs*steps*batch_size
-        ) \
-        .get_dataset(batch_size)
-    val_data = DsDataLoader(
-            val_files,
-            win_size=win_size,
-            num_samples=epochs*val_steps*val_batch_size
-        ) \
-        .get_dataset(val_batch_size)
+    if data_model == "segnet":
+        win_size=288
+        train_data = DsDataLoader(
+                train_files,
+                win_size=win_size,
+                num_samples=epochs*steps*batch_size
+            ) \
+            .get_dataset(batch_size)
+        val_data = DsDataLoader(
+                val_files,
+                win_size=win_size,
+                num_samples=epochs*val_steps*val_batch_size
+            ) \
+            .get_dataset(val_batch_size)
+        model = u_net(win_size=win_size, out_class=CHANNEL_NUM)
+    else:
+        win_size=256
+        train_data = DataLoader(
+                train_files,
+                win_size=win_size,
+                num_samples=epochs*steps*batch_size
+            ) \
+            .get_dataset(batch_size)
+        val_data = DataLoader(
+                val_files,
+                win_size=win_size,
+                num_samples=epochs*val_steps*val_batch_size
+            ) \
+            .get_dataset(val_batch_size)
+        model = semantic_segmentation(win_size=256, out_class=3)
 
     print("Initializing model")
-    #model = naive_conv(win_size=win_size)
-    #model = u_net(win_size=win_size, out_class=CHANNEL_NUM)
-    model = semantic_segmentation(win_size=win_size, out_class=CHANNEL_NUM)
     optim = tf.keras.optimizers.Adam(learning_rate=WarmUpLearningRate(learning_rate))
     #loss = tf.keras.losses.BinaryCrossentropy(label_smoothing=0.1)
     #loss = tf.keras.losses.CategoricalCrossentropy()
@@ -457,15 +483,19 @@ def train_model(
     ]
 
     print("Start training")
-    model.fit(
-        train_data,
-        validation_data=val_data,
-        epochs=epochs,
-        steps_per_epoch=steps,
-        validation_steps=val_steps,
-        callbacks=callbacks
-    )
-    return model
+    try:
+        model.fit(
+            train_data,
+            validation_data=val_data,
+            epochs=epochs,
+            steps_per_epoch=steps,
+            validation_steps=val_steps,
+            callbacks=callbacks
+        )
+        return model
+    except Exception as e:
+        print(e)
+        return model
 
 
 def resize_image(image: Image.Image):
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..2b90a51
--- /dev/null
+++ b/train.py
@@ -0,0 +1,64 @@
+import sys
+import time
+import os
+
+import tensorflow as tf
+
+from oemer import train
+from oemer import classifier
+
+
+def write_text_to_file(text, path):
+    with open(path, "w") as f:
+        f.write(text)
+
+if len(sys.argv) != 2:
+    print("Usage: python train.py <model_type>")
+    sys.exit(1)
+
+def get_model_base_name(model_name: str) -> str:
+    timestamp = str(round(time.time()))
+    return f"{model_name}_{timestamp}"
+
+model_type = sys.argv[1]
+
+def prepare_classifier_data():
+    if not os.path.exists("train_data"):
+        classifier.collect_data(2000)
+
+if model_type == "segnet":
+    model = train.train_model("ds2_dense", data_model=model_type, steps=1500, epochs=15)
+    filename = get_model_base_name(model_type)
+    os.makedirs(filename)
+    write_text_to_file(model.to_json(), os.path.join(filename, "arch.json"))
+    model.save_weights(os.path.join(filename, "weights.h5"))
+elif model_type == "unet":
+    model = train.train_model("CvcMuscima-Distortions", data_model=model_type, steps=1500, epochs=15)
+    filename = get_model_base_name(model_type)
+    os.makedirs(filename)
+    write_text_to_file(model.to_json(), os.path.join(filename, "arch.json"))
+    model.save_weights(os.path.join(filename, "weights.h5"))
+elif model_type == "unet_from_checkpoint" or model_type == "segnet_from_checkpoint":
+    model = tf.keras.models.load_model("seg_unet", custom_objects={"WarmUpLearningRate": train.WarmUpLearningRate})
+    filename = get_model_base_name(model_type.split("_")[0])
+    os.makedirs(filename)
+    write_text_to_file(model.to_json(), os.path.join(filename, "arch.json"))
+    model.save_weights(os.path.join(filename, "weights.h5"))
+elif model_type == "rests_above8":
+    prepare_classifier_data()
+    classifier.train_rests_above8(get_model_base_name(model_type))
+elif model_type == "rests":
+    prepare_classifier_data()
+    classifier.train_rests(get_model_base_name(model_type))
+elif model_type == "all_rests":
+    prepare_classifier_data()
+    classifier.train_all_rests(get_model_base_name(model_type))
+elif model_type == "sfn":
+    prepare_classifier_data()
+    classifier.train_sfn(get_model_base_name(model_type))
+elif model_type == "clef":
+    prepare_classifier_data()
+    classifier.train_clefs(get_model_base_name(model_type))
+else:
+    print("Unknown model: " + model_type)
+    sys.exit(1)
\ No newline at end of file