Skip to content

Commit

Permalink
add ONNXModifier class to perform various optimisations to the ONNX m…
Browse files Browse the repository at this point in the history
…odel before converting to RVC4
  • Loading branch information
ptoupas committed Dec 10, 2024
1 parent f139d32 commit 5e161e9
Show file tree
Hide file tree
Showing 4 changed files with 997 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
default_language_version:
python: python3
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.2
Expand Down
15 changes: 15 additions & 0 deletions modelconverter/packages/rvc4/exporter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import shutil
import subprocess
import time
Expand All @@ -6,6 +7,7 @@
from typing import Any, Dict, List, NamedTuple, Optional, cast

from modelconverter.utils import (
ONNXModifier,
exit_with,
onnx_attach_normalization_to_inputs,
read_image,
Expand Down Expand Up @@ -57,6 +59,19 @@ def __init__(self, config: SingleStageConfig, output_dir: Path):
self._attach_suffix(self.input_model, "modified.onnx"),
self.inputs,
)

onnx_modifier = ONNXModifier(
model_path=self.input_model,
output_path=self._attach_suffix(
self.input_model, "modified_optimised.onnx"
),
)
onnx_modifier.modify_onnx()
if onnx_modifier.compare_outputs():
logger.info("ONNX model has been optimised for RVC4.")
shutil.move(onnx_modifier.output_path, self.input_model)
else:
os.remove(onnx_modifier.output_path)
else:
logger.warning(
"Input file type is not ONNX. Skipping pre-processing."
Expand Down
3 changes: 2 additions & 1 deletion modelconverter/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
modelconverter_config_to_nn,
process_nn_archive,
)
from .onnx_tools import onnx_attach_normalization_to_inputs
from .onnx_tools import ONNXModifier, onnx_attach_normalization_to_inputs
from .subprocess import subprocess_run

__all__ = [
Expand All @@ -37,6 +37,7 @@
"S3Exception",
"SubprocessException",
"exit_with",
"ONNXModifier",
"onnx_attach_normalization_to_inputs",
"read_calib_dir",
"read_image",
Expand Down
Loading

0 comments on commit 5e161e9

Please sign in to comment.