From 083ddbd5b0389d96c1916e918d87aeef5f14c00a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maciej=20Budy=C5=9B?=
Date: Sat, 29 Jun 2024 11:23:04 +0200
Subject: [PATCH] use ruff (#73)

* use ruff

* fix pytest call

* set line length to 120
---
 .github/workflows/black.yml                   | 10 ------
 .github/workflows/main.yml                    |  7 ++--
 .github/workflows/ruff.yml                    |  8 +++++
 manga_ocr/__init__.py                         |  4 +--
 manga_ocr/ocr.py                              | 16 +++------
 manga_ocr/run.py                              | 16 +++------
 manga_ocr_dev/data/generate_backgrounds.py    |  3 +-
 manga_ocr_dev/data/process_manga109s.py       | 16 +++------
 .../synthetic_data_generator/generator.py     |  9 ++---
 .../synthetic_data_generator/renderer.py      | 36 +++++--------------
 .../synthetic_data_generator/run_generate.py  |  6 ++--
 .../synthetic_data_generator/utils.py         |  4 +--
 manga_ocr_dev/training/dataset.py             |  4 +--
 manga_ocr_dev/training/get_model.py           | 12 ++-----
 manga_ocr_dev/training/metrics.py             |  4 +--
 manga_ocr_dev/training/train.py               | 12 ++-----
 manga_ocr_dev/training/utils.py               |  7 +---
 pyproject.toml                                | 10 ++++++
 tests/test_ocr.py                             |  4 +--
 19 files changed, 58 insertions(+), 130 deletions(-)
 delete mode 100644 .github/workflows/black.yml
 create mode 100644 .github/workflows/ruff.yml

diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml
deleted file mode 100644
index 0a8b8e9..0000000
--- a/.github/workflows/black.yml
+++ /dev/null
@@ -1,10 +0,0 @@
-name: Lint
-
-on: [push, pull_request]
-
-jobs:
-  lint:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v3
-      - uses: psf/black@stable
\ No newline at end of file
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index cf778f8..eaa22c5 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -16,8 +16,6 @@ jobs:
     steps:
       - name: Checkout
         uses: actions/checkout@v3
-        with:
-          path: manga_ocr
 
       - name: Set up Python
         uses: actions/setup-python@v3
@@ -27,9 +25,8 @@
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install pytest
-          pip install -e manga_ocr
+          pip install -e ".[dev]"
 
       - name: Test
         run: |
-          pytest manga_ocr/tests
+          pytest
diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
new file mode 100644
index 0000000..c8a0958
--- /dev/null
+++ b/.github/workflows/ruff.yml
@@ -0,0 +1,8 @@
+name: Ruff
+on: [ push, pull_request ]
+jobs:
+  ruff:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: chartboost/ruff-action@v1
\ No newline at end of file
diff --git a/manga_ocr/__init__.py b/manga_ocr/__init__.py
index 14fd587..1525c54 100644
--- a/manga_ocr/__init__.py
+++ b/manga_ocr/__init__.py
@@ -1,2 +1,2 @@
-from ._version import __version__
-from manga_ocr.ocr import MangaOcr
+from ._version import __version__ as __version__
+from manga_ocr.ocr import MangaOcr as MangaOcr
diff --git a/manga_ocr/ocr.py b/manga_ocr/ocr.py
index 25e4576..c5f398c 100644
--- a/manga_ocr/ocr.py
+++ b/manga_ocr/ocr.py
@@ -9,17 +9,11 @@
 
 
 class MangaOcr:
-    def __init__(
-        self, pretrained_model_name_or_path="kha-white/manga-ocr-base", force_cpu=False
-    ):
+    def __init__(self, pretrained_model_name_or_path="kha-white/manga-ocr-base", force_cpu=False):
         logger.info(f"Loading OCR model from {pretrained_model_name_or_path}")
-        self.processor = ViTImageProcessor.from_pretrained(
-            pretrained_model_name_or_path
-        )
+        self.processor = ViTImageProcessor.from_pretrained(pretrained_model_name_or_path)
         self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
-        self.model = VisionEncoderDecoderModel.from_pretrained(
-            pretrained_model_name_or_path
-        )
+        self.model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path)
 
         if not force_cpu and torch.cuda.is_available():
             logger.info("Using CUDA")
@@ -43,9 +37,7 @@ def __call__(self, img_or_path):
         elif isinstance(img_or_path, Image.Image):
             img = img_or_path
         else:
-            raise ValueError(
-                f"img_or_path must be a path or PIL.Image, instead got: {img_or_path}"
-            )
+            raise ValueError(f"img_or_path must be a path or PIL.Image, instead got: {img_or_path}")
 
         img = img.convert("L").convert("RGB")
 
diff --git a/manga_ocr/run.py b/manga_ocr/run.py
index 3de5bf2..5af6e6a 100644
--- a/manga_ocr/run.py
+++ b/manga_ocr/run.py
@@ -34,9 +34,7 @@ def process_and_write_results(mocr, img_or_path, write_to):
     else:
         write_to = Path(write_to)
         if write_to.suffix != ".txt":
-            raise ValueError(
-                'write_to must be either "clipboard" or a path to a text file'
-            )
+            raise ValueError('write_to must be either "clipboard" or a path to a text file')
 
         with write_to.open("a", encoding="utf-8") as f:
             f.write(text + "\n")
@@ -102,13 +100,9 @@ def run(
                     # Pillow error when clipboard contains text (Linux, X11)
                     pass
                 else:
-                    logger.warning(
-                        "Error while reading from clipboard ({})".format(error)
-                    )
+                    logger.warning("Error while reading from clipboard ({})".format(error))
             else:
-                if isinstance(img, Image.Image) and not are_images_identical(
-                    img, old_img
-                ):
+                if isinstance(img, Image.Image) and not are_images_identical(img, old_img):
                     process_and_write_results(mocr, img, write_to)
 
             time.sleep(delay_secs)
@@ -116,9 +110,7 @@
     else:
         read_from = Path(read_from)
        if not read_from.is_dir():
-            raise ValueError(
-                'read_from must be either "clipboard" or a path to a directory'
-            )
+            raise ValueError('read_from must be either "clipboard" or a path to a directory')
 
         logger.info(f"Reading from directory {read_from}")
 
diff --git a/manga_ocr_dev/data/generate_backgrounds.py b/manga_ocr_dev/data/generate_backgrounds.py
index 8178db7..d0f1342 100644
--- a/manga_ocr_dev/data/generate_backgrounds.py
+++ b/manga_ocr_dev/data/generate_backgrounds.py
@@ -77,8 +77,7 @@ def generate_backgrounds(crops_per_page=5, min_size=40):
 
         if crop.shape[0] >= min_size and crop.shape[1] >= min_size:
             out_filename = (
-                "_".join(Path(page_path).with_suffix("").parts[-2:])
-                + f"_{ymin}_{ymax}_{xmin}_{xmax}.png"
+                "_".join(Path(page_path).with_suffix("").parts[-2:]) + f"_{ymin}_{ymax}_{xmin}_{xmax}.png"
             )
             cv2.imwrite(str(BACKGROUND_DIR / out_filename), crop)
 
diff --git a/manga_ocr_dev/data/process_manga109s.py b/manga_ocr_dev/data/process_manga109s.py
index 2c79314..64aeef5 100644
--- a/manga_ocr_dev/data/process_manga109s.py
+++ b/manga_ocr_dev/data/process_manga109s.py
@@ -14,9 +14,7 @@ def get_books():
     books = pd.DataFrame(
         {
             "book": books,
-            "annotations": [
-                str(root / "annotations" / f"{book}.xml") for book in books
-            ],
+            "annotations": [str(root / "annotations" / f"{book}.xml") for book in books],
             "images": [str(root / "images" / book) for book in books],
         }
     )
@@ -36,9 +34,7 @@ def export_frames():
                 row = {}
                 row["book"] = book.book
                 row["page_index"] = int(page.attrib["index"])
-                row["page_path"] = str(
-                    Path(book.images) / f'{row["page_index"]:03d}.jpg'
-                )
+                row["page_path"] = str(Path(book.images) / f'{row["page_index"]:03d}.jpg')
                 row["page_width"] = int(page.attrib["width"])
                 row["page_height"] = int(page.attrib["height"])
                 row["id"] = frame.attrib["id"]
@@ -69,9 +65,7 @@ def export_crops():
                 row = {}
                 row["book"] = book.book
                 row["page_index"] = int(page.attrib["index"])
-                row["page_path"] = str(
-                    Path(book.images) / f'{row["page_index"]:03d}.jpg'
-                )
+                row["page_path"] = str(Path(book.images) / f'{row["page_index"]:03d}.jpg')
                 row["page_width"] = int(page.attrib["width"])
                 row["page_height"] = int(page.attrib["height"])
                 row["id"] = text.attrib["id"]
@@ -93,9 +87,7 @@ def export_crops():
     data.crop_path = data.crop_path.apply(lambda x: "/".join(Path(x).parts[-2:]))
     data.to_csv(MANGA109_ROOT / "data.csv", index=False)
 
-    for page_path, boxes in tqdm(
-        data.groupby("page_path"), total=data.page_path.nunique()
-    ):
+    for page_path, boxes in tqdm(data.groupby("page_path"), total=data.page_path.nunique()):
         img = cv2.imread(str(MANGA109_ROOT / page_path))
 
         for box in boxes.itertuples():
diff --git a/manga_ocr_dev/synthetic_data_generator/generator.py b/manga_ocr_dev/synthetic_data_generator/generator.py
index 5ec4606..5229784 100644
--- a/manga_ocr_dev/synthetic_data_generator/generator.py
+++ b/manga_ocr_dev/synthetic_data_generator/generator.py
@@ -128,7 +128,6 @@ def add_random_furigana(self, line, word_prob=1.0, vocab=None):
         kanji_group = ""
         ascii_group = ""
         for i, c in enumerate(line):
-
             if is_kanji(c):
                 c_type = "kanji"
                 kanji_group += c
@@ -141,12 +140,8 @@ def add_random_furigana(self, line, word_prob=1.0, vocab=None):
             if c_type != "kanji" or i == len(line) - 1:
                 if kanji_group:
                     if np.random.uniform() < word_prob:
-                        furigana_len = int(
-                            np.clip(np.random.normal(1.5, 0.5), 1, 4) * len(kanji_group)
-                        )
-                        char_source = np.random.choice(
-                            ["hiragana", "katakana", "all"], p=[0.8, 0.15, 0.05]
-                        )
+                        furigana_len = int(np.clip(np.random.normal(1.5, 0.5), 1, 4) * len(kanji_group))
+                        char_source = np.random.choice(["hiragana", "katakana", "all"], p=[0.8, 0.15, 0.05])
                         char_source = {
                             "hiragana": self.hiragana,
                             "katakana": self.katakana,
diff --git a/manga_ocr_dev/synthetic_data_generator/renderer.py b/manga_ocr_dev/synthetic_data_generator/renderer.py
index 135dfa1..d8a307a 100644
--- a/manga_ocr_dev/synthetic_data_generator/renderer.py
+++ b/manga_ocr_dev/synthetic_data_generator/renderer.py
@@ -60,9 +60,7 @@ def get_random_css_params():
            if np.random.rand() < 0.7:
                params["text_orientation"] = "upright"
 
-        stroke_variant = np.random.choice(
-            ["stroke", "shadow", "none"], p=[0.8, 0.15, 0.05]
-        )
+        stroke_variant = np.random.choice(["stroke", "shadow", "none"], p=[0.8, 0.15, 0.05])
        if stroke_variant == "stroke":
            params["stroke_size"] = np.random.choice([1, 2, 3, 4, 8])
            params["stroke_color"] = "white"
@@ -88,9 +86,7 @@ def render_background(self, img):
                 A.HorizontalFlip(),
                 A.RandomRotate90(),
                 A.InvertImg(),
-                A.RandomBrightnessContrast(
-                    (-0.2, 0.4), (-0.8, -0.3), p=0.5 if draw_bubble else 1
-                ),
+                A.RandomBrightnessContrast((-0.2, 0.4), (-0.8, -0.3), p=0.5 if draw_bubble else 1),
                 A.Blur((3, 5), p=0.3),
                 A.Resize(img.shape[0], img.shape[1]),
             ]
@@ -108,17 +104,9 @@ def render_background(self, img):
             sigma = np.random.randint(10, 15)
 
             ymin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
-            ymax = (
-                img.shape[0]
-                - m0
-                + int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
-            )
+            ymax = img.shape[0] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
             xmin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
-            xmax = (
-                img.shape[1]
-                - m0
-                + int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
-            )
+            xmax = img.shape[1] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
 
             bubble_fill_color = (255, 255, 255, 255)
             bubble_contour_color = (0, 0, 0, 255)
@@ -150,13 +138,9 @@ def render_background(self, img):
         img = blend(img, background)
 
         ymin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
-        ymax = (
-            img.shape[0] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
-        )
+        ymax = img.shape[0] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
         xmin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
-        xmax = (
-            img.shape[1] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
-        )
+        xmax = img.shape[1] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
 
         img = img[ymin:ymax, xmin:xmax]
         return img
@@ -184,9 +168,7 @@ def blend(img, background):
     return img
 
 
-def rounded_rectangle(
-    src, top_left, bottom_right, radius=1, color=255, thickness=1, line_type=cv2.LINE_AA
-):
+def rounded_rectangle(src, top_left, bottom_right, radius=1, color=255, thickness=1, line_type=cv2.LINE_AA):
     """From https://stackoverflow.com/a/60210706"""
 
     # corners:
@@ -345,9 +327,7 @@ def get_css(
         # stroke is simulated by shadow overlaid multiple times
         styles.extend(
             [
-                f"text-shadow: "
-                + ",".join([f"0 0 {stroke_size}px {stroke_color}"] * 10 * stroke_size)
-                + ";",
+                "text-shadow: " + ",".join([f"0 0 {stroke_size}px {stroke_color}"] * 10 * stroke_size) + ";",
                 "-webkit-font-smoothing: antialiased;",
             ]
         )
diff --git a/manga_ocr_dev/synthetic_data_generator/run_generate.py b/manga_ocr_dev/synthetic_data_generator/run_generate.py
index 65bb7c3..bb5b25d 100644
--- a/manga_ocr_dev/synthetic_data_generator/run_generate.py
+++ b/manga_ocr_dev/synthetic_data_generator/run_generate.py
@@ -24,7 +24,7 @@ def f(args):
         ret = source, id_, text_gt, params["vertical"], str(font_path)
         return ret
 
-    except Exception as e:
+    except Exception:
         print(traceback.format_exc())
 
 
@@ -54,9 +54,7 @@ def run(package=0, n_random=1000, n_limit=None, max_workers=16):
     OUT_DIR = DATA_SYNTHETIC_ROOT / "img" / package
     OUT_DIR.mkdir(parents=True, exist_ok=True)
 
-    data = thread_map(
-        f, args, max_workers=max_workers, desc=f"Processing package {package}"
-    )
+    data = thread_map(f, args, max_workers=max_workers, desc=f"Processing package {package}")
     data = pd.DataFrame(data, columns=["source", "id", "text", "vertical", "font_path"])
 
     meta_path = DATA_SYNTHETIC_ROOT / f"meta/{package}.csv"
diff --git a/manga_ocr_dev/synthetic_data_generator/utils.py b/manga_ocr_dev/synthetic_data_generator/utils.py
index 4715151..17e1570 100644
--- a/manga_ocr_dev/synthetic_data_generator/utils.py
+++ b/manga_ocr_dev/synthetic_data_generator/utils.py
@@ -52,7 +52,5 @@ def get_charsets(vocab_path=None):
 def get_font_meta():
     df = pd.read_csv(ASSETS_PATH / "fonts.csv")
     df.font_path = df.font_path.apply(lambda x: str(FONTS_ROOT / x))
-    font_map = {
-        row.font_path: set(row.supported_chars) for row in df.dropna().itertuples()
-    }
+    font_map = {row.font_path: set(row.supported_chars) for row in df.dropna().itertuples()}
     return df, font_map
diff --git a/manga_ocr_dev/training/dataset.py b/manga_ocr_dev/training/dataset.py
index 53f0833..8fb9293 100644
--- a/manga_ocr_dev/training/dataset.py
+++ b/manga_ocr_dev/training/dataset.py
@@ -40,9 +40,7 @@ def __init__(
                     continue
                 df = pd.read_csv(path)
                 df = df.dropna()
-                df["path"] = df.id.apply(
-                    lambda x: str(DATA_SYNTHETIC_ROOT / "img" / path.stem / f"{x}.jpg")
-                )
+                df["path"] = df.id.apply(lambda x: str(DATA_SYNTHETIC_ROOT / "img" / path.stem / f"{x}.jpg"))
                 df = df[["path", "text"]]
                 df["synthetic"] = True
                 data.append(df)
diff --git a/manga_ocr_dev/training/get_model.py b/manga_ocr_dev/training/get_model.py
index 78a7395..c7fc478 100644
--- a/manga_ocr_dev/training/get_model.py
+++ b/manga_ocr_dev/training/get_model.py
@@ -40,21 +40,15 @@ def get_model(encoder_name, decoder_name, max_length, num_decoder_layers=None):
 
     if num_decoder_layers is not None:
         if decoder_config.model_type == "bert":
-            decoder.bert.encoder.layer = decoder.bert.encoder.layer[
-                -num_decoder_layers:
-            ]
+            decoder.bert.encoder.layer = decoder.bert.encoder.layer[-num_decoder_layers:]
         elif decoder_config.model_type in ("roberta", "xlm-roberta"):
-            decoder.roberta.encoder.layer = decoder.roberta.encoder.layer[
-                -num_decoder_layers:
-            ]
+            decoder.roberta.encoder.layer = decoder.roberta.encoder.layer[-num_decoder_layers:]
         else:
             raise ValueError(f"Unsupported model_type: {decoder_config.model_type}")
 
         decoder_config.num_hidden_layers = num_decoder_layers
 
-    config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(
-        encoder.config, decoder.config
-    )
+    config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
     config.tie_word_embeddings = False
 
     model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder, config=config)
diff --git a/manga_ocr_dev/training/metrics.py b/manga_ocr_dev/training/metrics.py
index 78ba2df..20aeb71 100644
--- a/manga_ocr_dev/training/metrics.py
+++ b/manga_ocr_dev/training/metrics.py
@@ -21,9 +21,7 @@ def compute_metrics(self, pred):
 
         results = {}
         try:
-            results["cer"] = self.cer_metric.compute(
-                predictions=pred_str, references=label_str
-            )
+            results["cer"] = self.cer_metric.compute(predictions=pred_str, references=label_str)
         except Exception as e:
             print(e)
             print(pred_str)
diff --git a/manga_ocr_dev/training/train.py b/manga_ocr_dev/training/train.py
index 4bcebf7..d6ef516 100644
--- a/manga_ocr_dev/training/train.py
+++ b/manga_ocr_dev/training/train.py
@@ -20,17 +20,11 @@ def run(
 ):
     wandb.login()
 
-    model, processor = get_model(
-        encoder_name, decoder_name, max_len, num_decoder_layers
-    )
+    model, processor = get_model(encoder_name, decoder_name, max_len, num_decoder_layers)
 
     # keep package 0 for validation
-    train_dataset = MangaDataset(
-        processor, "train", max_len, augment=True, skip_packages=[0]
-    )
-    eval_dataset = MangaDataset(
-        processor, "test", max_len, augment=False, skip_packages=range(1, 9999)
-    )
+    train_dataset = MangaDataset(processor, "train", max_len, augment=True, skip_packages=[0])
+    eval_dataset = MangaDataset(processor, "test", max_len, augment=False, skip_packages=range(1, 9999))
 
     metrics = Metrics(processor)
 
diff --git a/manga_ocr_dev/training/utils.py b/manga_ocr_dev/training/utils.py
index 1dff603..cb60e8b 100644
--- a/manga_ocr_dev/training/utils.py
+++ b/manga_ocr_dev/training/utils.py
@@ -37,9 +37,4 @@ def decoder_summary(model, batch_size=4):
 
 
 def tensor_to_image(img):
-    return (
-        ((img.cpu().numpy() + 1) / 2 * 255)
-        .clip(0, 255)
-        .astype(np.uint8)
-        .transpose(1, 2, 0)
-    )
+    return ((img.cpu().numpy() + 1) / 2 * 255).clip(0, 255).astype(np.uint8).transpose(1, 2, 0)
diff --git a/pyproject.toml b/pyproject.toml
index c8f5def..6f9f0c6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,6 +25,12 @@ dependencies = [
     "unidic_lite",
 ]
 
+[project.optional-dependencies]
+dev = [
+    "pytest",
+    "ruff",
+]
+
 [project.urls]
 Homepage = "https://github.com/kha-white/manga-ocr"
 
@@ -38,5 +44,9 @@ packages = ["manga_ocr"]
 [tool.setuptools.dynamic]
 version = {attr = "manga_ocr._version.__version__"}
 
+[tool.ruff]
+line-length = 120
+indent-width = 4
+
 [project.scripts]
 manga_ocr = "manga_ocr.__main__:main"
\ No newline at end of file
diff --git a/tests/test_ocr.py b/tests/test_ocr.py
index 3ca7f9b..1369585 100644
--- a/tests/test_ocr.py
+++ b/tests/test_ocr.py
@@ -9,9 +9,7 @@ def test_ocr():
 
     mocr = MangaOcr()
 
-    expected_results = json.loads(
-        (TEST_DATA_ROOT / "expected_results.json").read_text(encoding="utf-8")
-    )
+    expected_results = json.loads((TEST_DATA_ROOT / "expected_results.json").read_text(encoding="utf-8"))
 
     for item in expected_results:
         result = mocr(TEST_DATA_ROOT / "images" / item["filename"])
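
To reproduce these checks locally, the commands below are a minimal sketch of what the two workflows run after this patch, assuming they are executed from the repository root. The ".[dev]" extra and the 120-character line length come from the pyproject.toml hunks above; running `ruff check` by hand is an assumption about what chartboost/ruff-action amounts to by default.

    # install the package in editable mode, plus the new dev extras (pytest, ruff)
    pip install -e ".[dev]"

    # lint; ruff picks up the [tool.ruff] settings (line-length = 120) from pyproject.toml
    ruff check .

    # run the test suite with a bare pytest call, matching the updated main.yml
    pytest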