Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make training of segnet, unet and classifiers easier by providing a single entry point to all training steps #54

Merged
merged 19 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
47ab174
Fixed 'TypeError: Cannot convert 4.999899999999999e-07 to EagerTensor…
liebharc Nov 29, 2023
9f5a6e9
--format was deprecated in ruff and replaced with --output-format
liebharc Nov 29, 2023
b431701
Added a single entry point to train all models
liebharc Nov 30, 2023
bdbe6f5
Added convenience wrapper for oemer
liebharc Nov 30, 2023
a87d309
Tried to figure out the definitions for the dense dataset and to docu…
liebharc Nov 30, 2023
8d8c64f
Decreased queue sizes as otherwise the training process crashed with …
liebharc Nov 30, 2023
4c51b52
Added model outputs to git ignore
liebharc Nov 30, 2023
cae8dd7
Added checks for dataset folders
liebharc Nov 30, 2023
7dbd6ae
Using default training params
liebharc Nov 30, 2023
b200d0f
Added workarounds for removal of np.float
liebharc Nov 30, 2023
d087cf1
Using dataset definitions
liebharc Nov 30, 2023
2deda24
Added type annotations
liebharc Nov 30, 2023
6f1a678
Added a train_all_rests even if the resulting model is right now not …
liebharc Dec 1, 2023
0e56ca0
Merge pull request #4 from liebharc/main
liebharc Dec 1, 2023
f0bf03e
segnet and unet should now pick the correct model
liebharc Jan 5, 2024
d3db06a
Changed label definitions from what appears to be used in oemer right…
liebharc Jan 8, 2024
3655f61
With this commit the resulting arch.json matches the one inside of oe…
liebharc Jan 10, 2024
842dab2
Avoid that the OMR processes finishes prematurely (#53)
liebharc Jan 29, 2024
c44cdd2
Fix install from github command in README
BreezeWhite Jan 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,23 @@ checkpoints/

*.musicxml
*.mp3
*.swp
*.swp

# Model training datasets
/ds2_dense
/CvcMuscima-Distortions

# Model training checkpoints and outputs
/seg_unet
/test_data
/train_data
/*.model
/*.h5
/*.json

/segnet_*
/unet_*
/rests_*
/all_rests_*
/sfn_*
/clef_*
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pip install oemer
pip install oemer[tf]

# (optional) Or install the newest updates directly from Github.
pip install https://github.com/BreezeWhite/oemer
pip install git+https://github.com/BreezeWhite/oemer

# Run
oemer <path_to_image>
Expand Down
3 changes: 3 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Convenience entry point: run the oemer end-to-end OMR pipeline.

Usage: python main.py <args...>  (arguments are handled by oemer's own CLI).
"""
from oemer import ete

if __name__ == "__main__":
    # Guard the call so merely importing this module does not kick off a
    # full OMR run; only direct execution does.
    ete.main()
14 changes: 8 additions & 6 deletions oemer/bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Union, Any, List, Tuple, Dict

import cv2
from cv2.typing import RotatedRect
import numpy as np
from numpy import ndarray
from sklearn.cluster import AgglomerativeClustering
Expand Down Expand Up @@ -118,11 +119,12 @@ def find_lines(data: ndarray, min_len: int = 10, max_gap: int = 20) -> List[BBox

lines = cv2.HoughLinesP(data.astype(np.uint8), 1, np.pi/180, 50, None, min_len, max_gap)
new_line = []
for line in lines:
line = line[0]
top_x, bt_x = (line[0], line[2]) if line[0] < line[2] else (line[2], line[0])
top_y, bt_y = (line[1], line[3]) if line[1] < line[3] else (line[3], line[1])
new_line.append((top_x, top_y, bt_x, bt_y))
if lines is not None:
for line in lines:
line = line[0]
top_x, bt_x = (line[0], line[2]) if line[0] < line[2] else (line[2], line[0])
top_y, bt_y = (line[1], line[3]) if line[1] < line[3] else (line[3], line[1])
new_line.append((top_x, top_y, bt_x, bt_y))
return new_line


Expand Down Expand Up @@ -159,7 +161,7 @@ def draw_bounding_boxes(
return img


def get_rotated_bbox(data: ndarray) -> List[Tuple[Tuple[float, float], Tuple[float, float], float]]:
def get_rotated_bbox(data: ndarray) -> List[RotatedRect]:
contours, _ = cv2.findContours(data.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
bboxes = []
for cnt in contours:
Expand Down
23 changes: 10 additions & 13 deletions oemer/build_label.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import sys
import os
import random
from PIL import Image

import cv2
import numpy as np

from .constant_min import CLASS_CHANNEL_MAP
from .constant_min import CLASS_CHANNEL_MAP, CHANNEL_NUM
from .dense_dataset_definitions import DENSE_DATASET_DEFINITIONS as DEF


HALF_WHOLE_NOTE = [39, 41, 42, 43, 45, 46, 47, 49]
HALF_WHOLE_NOTE = DEF.NOTEHEADS_HOLLOW + DEF.NOTEHEADS_WHOLE + [42]


def fill_hole(gt, tar_color):
Expand Down Expand Up @@ -75,12 +77,12 @@ def build_label(seg_path):
color_set = set(np.unique(arr))
color_set.remove(0) # Remove background color from the candidates

total_chs = len(set(CLASS_CHANNEL_MAP.values())) + 2 # Plus 'background' and 'others' channel.
total_chs = CHANNEL_NUM
output = np.zeros(arr.shape + (total_chs,))

output[..., 0] = np.where(arr==0, 1, 0)
for color in color_set:
ch = CLASS_CHANNEL_MAP.get(color, -1)
ch = CLASS_CHANNEL_MAP.get(color, 0)
if (ch != 0) and color in HALF_WHOLE_NOTE:
note = fill_hole(arr, color)
output[..., ch] += note
Expand All @@ -101,12 +103,7 @@ def find_example(dataset_path: str, color: int, max_count=100, mark_value=200):


if __name__ == "__main__":
seg_folder = '/media/kohara/ADATA HV620S/dataset/ds2_dense/segmentation'
files = os.listdir(seg_folder)
path = os.path.join(seg_folder, random.choice(files))
#out = build_label(path)

color = 45
arr = find_example(color) # type: ignore
arr = np.where(arr==200, color, arr)
out = fill_hole(arr, color)
seg_folder = 'ds2_dense/segmentation'
color = int(sys.argv[1])
with_background, without_background = find_example(seg_folder, color)
cv2.imwrite("example.png", with_background)
52 changes: 51 additions & 1 deletion oemer/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def _collect(color, out_path, samples=100):
img = imaugs.resize(Image.fromarray(patch.astype(np.uint8)), width=tar_w, height=tar_h)

seed = random.randint(0, 1000)
np.float = float # Monkey patch to workaround removal of np.float
img = imaugs.perspective_transform(img, seed=seed, sigma=3)
img = np.where(np.array(img)>0, 255, 0)
Image.fromarray(img.astype(np.uint8)).save(out_path / f"{idx}.png")
Expand Down Expand Up @@ -118,10 +119,12 @@ def train(folders):
model.fit(train_x, train_y)
return model, class_map

def build_class_map(folders):
    """Return a mapping of class index -> folder basename.

    The index is the position of the folder in *folders*; the name is the
    final path component, e.g. ["train_data/sharp"] -> {0: "sharp"}.
    """
    class_map = {}
    for class_idx, folder in enumerate(folders):
        class_map[class_idx] = Path(folder).name
    return class_map

def train_tf(folders):
import tensorflow as tf
class_map = {idx: Path(ff).name for idx, ff in enumerate(folders)}
class_map = build_class_map(folders)
train_x = []
train_y = []
samples = None
Expand Down Expand Up @@ -234,6 +237,53 @@ def predict(region, model_name):
pred = model.predict(np.array(region).reshape(1, -1))
return m_info['class_map'][pred[0]]

def _train_and_save(folders, filename):
    """Train a TF classifier on train_data/<folder>, evaluate it on
    test_data/<folder>, and pickle the model plus its metadata to *filename*.

    Shared implementation for all of the train_* convenience wrappers below,
    which previously duplicated this body verbatim.
    """
    model, class_map = train_tf([f"train_data/{folder}" for folder in folders])
    test_tf(model, [f"test_data/{folder}" for folder in folders])
    output = {'model': model, 'w': TARGET_WIDTH, 'h': TARGET_HEIGHT, 'class_map': class_map}
    # Use a context manager so the file handle is closed even if pickling
    # fails (the previous open(...) call leaked the handle on error).
    with open(filename, "wb") as model_file:
        pickle.dump(output, model_file)


def train_rests_above8(filename = "rests_above8.model"):
    """Train a classifier for rests of an eighth note or shorter."""
    _train_and_save(["rest_8th", "rest_16th", "rest_32nd", "rest_64th"], filename)


def train_rests(filename = "rests.model"):
    """Train a classifier for whole, quarter and eighth rests."""
    _train_and_save(["rest_whole", "rest_quarter", "rest_8th"], filename)


def train_all_rests(filename = "all_rests.model"):
    """Train a single classifier covering all rest durations."""
    _train_and_save(
        ["rest_whole", "rest_quarter", "rest_8th", "rest_16th", "rest_32nd", "rest_64th"],
        filename,
    )


def train_sfn(filename = "sfn.model"):
    """Train a sharp/flat/natural accidental classifier."""
    _train_and_save(["sharp", "flat", "natural"], filename)


def train_clefs(filename = "clef.model"):
    """Train a G-clef vs. F-clef classifier."""
    _train_and_save(["gclef", "fclef"], filename)


def train_noteheads(filename = "notehead.model"):
    """Train a solid-vs-hollow notehead classifier.

    *filename* was added for consistency with the other wrappers; the default
    preserves the previous hard-coded output name.
    """
    _train_and_save(["notehead_solid", "notehead_hollow"], filename)


if __name__ == "__main__":
samples = 400
Expand Down
30 changes: 16 additions & 14 deletions oemer/constant.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
from enum import Enum, auto

from oemer.dense_dataset_definitions import DENSE_DATASET_DEFINITIONS as DEF


CLASS_CHANNEL_LIST = [
[165, 2], # staff, ledgerLine
[35, 37, 38], # noteheadBlack
[39, 41, 42], # noteheadHalf
[43, 45, 46, 47, 49], # noteheadWhole
[64, 58, 59, 60, 66, 63, 69, 68, 61, 62, 67, 65], # flags
[146, 51], # beam, augmentationDot
[3, 52], # barline, stem
[74, 70, 72, 76], # accidentalSharp, accidentalFlat, accidentalNatural, accidentalDoubleSharp
[80, 78, 79], # keySharp, keyFlat, keyNatural
[97, 100, 99, 98, 101, 102, 103, 104, 96, 163], # rests
[136, 156, 137, 155, 152, 151, 153, 154, 149, 155], # tuplets
[145, 147], # slur, tie
[10, 13, 12, 19, 11, 20], # clefs
[25, 24, 29, 22, 23, 28, 27, 34, 30, 21, 33, 26], # timeSigs
DEF.STAFF + DEF.LEDGERLINE,
DEF.NOTEHEADS_SOLID + [38],
DEF.NOTEHEADS_HOLLOW + [42],
DEF.NOTEHEADS_WHOLE + [46],
DEF.FLAG_DOWN + DEF.FLAG_UP + [59, 65],
DEF.BEAM + DEF.DOT,
DEF.BARLINE_BETWEEN + DEF.STEM,
DEF.ALL_ACCIDENTALS,
DEF.ALL_KEYS,
DEF.ALL_RESTS + [163],
DEF.TUPETS,
DEF.SLUR_AND_TIE,
DEF.ALL_CLEFS + DEF.NUMBERS,
DEF.TIME_SIGNATURE_SUBSET
]

CLASS_CHANNEL_MAP = {
Expand Down
17 changes: 7 additions & 10 deletions oemer/constant_min.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from oemer.dense_dataset_definitions import DENSE_DATASET_DEFINITIONS as DEF


CLASS_CHANNEL_LIST = [
[165, 2], # staff, ledgerLine
[35, 37, 38, 39, 41, 42, 43, 45, 46, 47, 49, 52], # notehead, stem
[
64, 58, 60, 66, 63, 69, 68, 61, 62, 67, 65, 59, 146, # flags, beam
97, 100, 99, 98, 101, 102, 103, 104, 96, 163, # rests
80, 78, 79, 74, 70, 72, 76, 3, # sharp, flat, natural, barline
10, 13, 12, 19, 11, 20, 51, # clefs, augmentationDot,
25, 24, 29, 22, 23, 28, 27, 34, 30, 21, 33, 26, # timeSigs
]
DEF.STEM + DEF.ALL_RESTS_EXCEPT_LARGE + DEF.BARLINE_BETWEEN + DEF.BARLINE_END,
DEF.NOTEHEADS_ALL,
DEF.ALL_CLEFS + DEF.ALL_KEYS + DEF.ALL_ACCIDENTALS,
]

CLASS_CHANNEL_MAP = {
Expand All @@ -16,4 +13,4 @@
for color in colors
}

CHANNEL_NUM = len(CLASS_CHANNEL_LIST) + 2
CHANNEL_NUM = len(CLASS_CHANNEL_LIST) + 1 # Plus 'background' and 'others' channel.
78 changes: 78 additions & 0 deletions oemer/dense_dataset_definitions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
class Symbols:
    """Label ids of the ds2_dense (DeepScores dense) segmentation dataset.

    Each attribute is a list of integer pixel/label values so that related
    classes can be combined with simple list concatenation (see the ALL_*
    aggregates). Consumers access these via the DENSE_DATASET_DEFINITIONS
    singleton at the bottom of this module.
    """
    BACKGROUND = [0]
    LEDGERLINE = [2]
    BARLINE_BETWEEN = [3]
    BARLINE_END = [4]
    ALL_BARLINES = BARLINE_BETWEEN + BARLINE_END
    REPEAT_DOTS = [7]
    # NOTE(review): likely a typo for G_CLEF — kept as-is because other code
    # may already reference this name.
    G_GLEF = [10]
    C_CLEF = [11, 12]
    F_CLEF = [13]
    ALL_CLEFS = G_GLEF + C_CLEF + F_CLEF
    NUMBERS = [19, 20]
    TIME_SIGNATURE_SUBSET = [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33, 34]
    TIME_SIGNATURE = TIME_SIGNATURE_SUBSET + [31, 32] # Oemer hasn't used these in the past
    NOTEHEAD_FULL_ON_LINE = [35]
    # Label ids whose meaning is not (yet) mapped to a named symbol.
    UNKNOWN = [36, 38, 40, 128, 143, 144, 148, 150, 157, 159, 160, 161, 162, 163, 164, 167, 170, 171]
    NOTEHEAD_FULL_BETWEEN_LINES = [37]
    NOTEHEAD_HOLLOW_ON_LINE = [39]
    NOTEHEAD_HOLLOW_BETWEEN_LINE = [41]
    WHOLE_NOTE_ON_LINE = [43]
    WHOLE_NOTE_BETWEEN_LINE = [45]
    DOUBLE_WHOLE_NOTE_ON_LINE = [47]
    DOUBLE_WHOLE_NOTE_BETWEEN_LINE = [49]
    # Aggregates by notehead appearance (solid/hollow/whole).
    NOTEHEADS_SOLID = NOTEHEAD_FULL_ON_LINE + NOTEHEAD_FULL_BETWEEN_LINES
    NOTEHEADS_HOLLOW = NOTEHEAD_HOLLOW_ON_LINE + NOTEHEAD_HOLLOW_BETWEEN_LINE
    NOTEHEADS_WHOLE = WHOLE_NOTE_ON_LINE + WHOLE_NOTE_BETWEEN_LINE + DOUBLE_WHOLE_NOTE_ON_LINE + DOUBLE_WHOLE_NOTE_BETWEEN_LINE
    NOTEHEADS_ALL = NOTEHEAD_FULL_ON_LINE + NOTEHEAD_FULL_BETWEEN_LINES + NOTEHEAD_HOLLOW_ON_LINE + NOTEHEAD_HOLLOW_BETWEEN_LINE + WHOLE_NOTE_ON_LINE + WHOLE_NOTE_BETWEEN_LINE + DOUBLE_WHOLE_NOTE_ON_LINE + DOUBLE_WHOLE_NOTE_BETWEEN_LINE
    DOT = [51]
    STEM = [52]
    TREMOLO = [53, 54, 55, 56]
    FLAG_DOWN = [58, 60, 61, 62, 63]
    FLAG_UP = [64, 66, 67, 68, 69]
    FLAT = [70]
    NATURAL = [72]
    SHARP = [74]
    DOUBLE_SHARP = [76]
    ALL_ACCIDENTALS = FLAT + NATURAL + SHARP + DOUBLE_SHARP
    KEY_FLAT = [78]
    KEY_NATURAL = [79]
    KEY_SHARP = [80]
    ALL_KEYS = KEY_FLAT + KEY_NATURAL + KEY_SHARP
    # Articulation marks, above/below the staff.
    ACCENT_ABOVE = [81]
    ACCENT_BELOW = [82]
    STACCATO_ABOVE = [83]
    STACCATO_BELOW = [84]
    TENUTO_ABOVE = [85]
    TENUTO_BELOW = [86]
    STACCATISSIMO_ABOVE = [87]
    STACCATISSIMO_BELOW = [88]
    MARCATO_ABOVE = [89]
    MARCATO_BELOW = [90]
    FERMATA_ABOVE = [91]
    FERMATA_BELOW = [92]
    BREATH_MARK = [93]
    REST_LARGE = [95]
    REST_LONG = [96]
    REST_BREVE = [97]
    REST_FULL = [98]
    REST_QUARTER = [99]
    REST_EIGHTH = [100]
    REST_SIXTEENTH = [101]
    REST_THIRTY_SECOND = [102]
    REST_SIXTY_FOURTH = [103]
    REST_ONE_HUNDRED_TWENTY_EIGHTH = [104]
    ALL_RESTS_EXCEPT_LARGE = REST_LONG + REST_BREVE + REST_FULL + REST_QUARTER + REST_EIGHTH + REST_SIXTEENTH + REST_THIRTY_SECOND + REST_SIXTY_FOURTH + REST_ONE_HUNDRED_TWENTY_EIGHTH
    # NOTE(review): ALL_RESTS deliberately(?) excludes REST_LARGE — the two
    # aggregates are currently identical; confirm whether [95] should be added.
    ALL_RESTS = ALL_RESTS_EXCEPT_LARGE
    TRILL = [127]
    GRUPPETO = [129]
    MORDENT = [130]
    DOWN_BOW = [131]
    UP_BOW = [132]
    SYMBOL = [133, 134, 135, 138, 139, 141, 142]
    # NOTE(review): likely a typo for TUPLETS — kept as-is because it is
    # referenced as DEF.TUPETS elsewhere in this PR.
    TUPETS = [136, 137, 149, 151, 152, 153, 154, 155, 156]
    SLUR_AND_TIE = [145, 147]
    BEAM = [146]
    STAFF = [165]

# Module-level singleton through which the label definitions are consumed.
DENSE_DATASET_DEFINITIONS = Symbols()
4 changes: 2 additions & 2 deletions oemer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def inference(
image_pil = Image.open(img_path)
if "GIF" != image_pil.format:
# Tricky workaround to avoid random mistery transpose when loading with 'Image'.
image_pil = cv2.imread(img_path)
image_pil = Image.fromarray(image_pil)
image_cv = cv2.imread(img_path)
image_pil = Image.fromarray(image_cv)

image_pil = image_pil.convert("RGB")
image = np.array(resize_image(image_pil))
Expand Down
Loading
Loading