Skip to content

Commit

Permalink
Resolves: Run model loading on separate thread (#19)
Browse files Browse the repository at this point in the history
* model loading now on separate thread for better ux

---------

Co-authored-by: Philip Colangelo <[email protected]>
  • Loading branch information
pcolange and Philip Colangelo authored Feb 27, 2025
1 parent 57fdb00 commit 1b7b9b0
Show file tree
Hide file tree
Showing 9 changed files with 576 additions and 815 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name="digestai",
version="1.1.1",
version="1.1.2",
description="Model analysis toolkit",
author="Philip Colangelo, Daniel Holanda",
packages=find_packages(where="src"),
Expand Down
2 changes: 1 addition & 1 deletion src/digest/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class ProgressDialog(QProgressDialog):
"""A pop up window with a progress label that goes from 1 to 100"""

def __init__(self, label: str, num_steps: int, parent=None):
def __init__(self, label: str, num_steps: int = 0, parent=None):
"""
label: the text to be shown in the pop up dialog
num_steps: the total number of events the progress bar will load through
Expand Down
822 changes: 216 additions & 606 deletions src/digest/main.py

Large diffs are not rendered by default.

43 changes: 43 additions & 0 deletions src/digest/model_class/digest_onnx_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.

# pylint: disable=no-name-in-module
import os
from typing import List, Dict, Optional, Tuple, cast
from PySide6.QtCore import QRunnable, Signal, Slot, QObject
from datetime import datetime
import importlib.metadata
from collections import OrderedDict
Expand Down Expand Up @@ -654,3 +656,44 @@ def save_text_report(self, filepath: str) -> None:
f_p.write("Output Tensor(s) Information:\n")
f_p.write(output_table.get_string())
f_p.write("\n\n")


class WorkerSignals(QObject):
completed = Signal(DigestOnnxModel)


class LoadDigestOnnxModelWorker(QRunnable):

def __init__(
self,
model_file_path: str,
model_name: str,
):
super().__init__()
self.signals = WorkerSignals()
self.tab_name = model_name
self.model_file_path = model_file_path
self.unique_id: Optional[str] = None

@Slot()
def run(self):
try:
model_proto = onnx_utils.load_onnx(
self.model_file_path, load_external_data=False
)
opt_model, _ = onnx_utils.optimize_onnx_model(model_proto)
except FileNotFoundError as e:
print(f"File not found: {e.filename}")

digest_model = DigestOnnxModel(
opt_model,
model_name=self.tab_name,
onnx_filepath=self.model_file_path,
)

self.unique_id = digest_model.unique_id

if not self.tab_name:
self.tab_name = digest_model.model_name

self.signals.completed.emit(digest_model)
25 changes: 25 additions & 0 deletions src/digest/model_class/digest_report_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import csv
import ast
import re
from PySide6.QtCore import QRunnable, Signal, Slot, QObject
from typing import Tuple, Optional, List, Dict, Any, Union
import yaml
from digest.model_class.digest_model import (
Expand Down Expand Up @@ -153,6 +154,30 @@ def save_text_report(self, filepath: str) -> None:
return


class WorkerSignals(QObject):
completed = Signal(DigestReportModel)


class LoadDigestReportModelWorker(QRunnable):

def __init__(
self,
model_file_path: str,
model_name: str,
):
super().__init__()
self.signals = WorkerSignals()
self.tab_name = model_name
self.model_file_path = model_file_path
self.unique_id: Optional[str] = None

@Slot()
def run(self):

digest_model = DigestReportModel(self.model_file_path)
self.signals.completed.emit(digest_model)


def validate_yaml(report_file_path: str) -> bool:
"""Check that the provided yaml file is indeed a Digest Report file."""
expected_keys = [
Expand Down
146 changes: 138 additions & 8 deletions src/digest/modelsummary.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright(C) 2024 Advanced Micro Devices, Inc. All rights reserved.

import os

# pylint: disable=invalid-name
from typing import Optional, Union
import os
from datetime import datetime
from typing import Optional

# pylint: disable=no-name-in-module
from PySide6.QtWidgets import QWidget
from PySide6.QtWidgets import QWidget, QTableWidgetItem
from PySide6.QtGui import QMovie
from PySide6.QtCore import QSize

Expand All @@ -16,6 +16,7 @@
from digest.freeze_inputs import FreezeInputs
from digest.popup_window import PopupWindow
from digest.qt_utils import apply_dark_style_sheet
from digest.model_class.digest_model import SupportedModelTypes, DigestModel
from digest.model_class.digest_onnx_model import DigestOnnxModel
from digest.model_class.digest_report_model import DigestReportModel

Expand All @@ -25,20 +26,19 @@

class modelSummary(QWidget):

def __init__(
self, digest_model: Union[DigestOnnxModel, DigestReportModel], parent=None
):
def __init__(self, digest_model: DigestModel, parent=None):
super().__init__(parent)
self.ui = Ui_modelSummary()
self.ui.setupUi(self)
apply_dark_style_sheet(self)

self.file: Optional[str] = None
self.ui.warningLabel.hide()
self.digest_model = digest_model
self.model_id = digest_model.unique_id
self.model_proto: Optional[ModelProto] = None
model_name: str = digest_model.model_name if digest_model.model_name else ""

self.png_file_path: Optional[str] = None
self.load_gif = QMovie(":/assets/gifs/load.gif")
# We set the size of the GIF to half the original
self.load_gif.setScaledSize(QSize(214, 120))
Expand All @@ -50,13 +50,143 @@ def __init__(
self.freeze_inputs: Optional[FreezeInputs] = None
self.freeze_window: Optional[QWidget] = None

self.model_type: Optional[SupportedModelTypes] = None

if isinstance(digest_model, DigestOnnxModel):
self.model_type = SupportedModelTypes.ONNX
self.model_proto = (
digest_model.model_proto if digest_model.model_proto else ModelProto()
)
self.freeze_inputs = FreezeInputs(self.model_proto, model_name)
self.ui.freezeButton.clicked.connect(self.open_freeze_inputs)
self.freeze_inputs.complete_signal.connect(self.close_freeze_window)
elif isinstance(digest_model, DigestReportModel):
self.model_type = SupportedModelTypes.REPORT

# Hide some of the components
self.ui.similarityCorrelation.hide()
self.ui.similarityCorrelationStatic.hide()

self.file = digest_model.filepath
self.setObjectName(model_name)
self.ui.modelName.setText(model_name)
if self.file:
self.ui.modelFilename.setText(self.file)

self.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y"))

self.ui.parameters.setText(format(digest_model.parameters, ","))

node_type_counts = digest_model.node_type_counts
if len(node_type_counts) < 15:
bar_spacing = 40
else:
bar_spacing = 20
self.ui.opHistogramChart.bar_spacing = bar_spacing
self.ui.opHistogramChart.set_data(node_type_counts)
self.ui.nodes.setText(str(sum(node_type_counts.values())))

# Format flops with commas if available
flops_str = "N/A"
if digest_model.flops is not None:
flops_str = format(digest_model.flops, ",")

# Set up the FLOPs pie chart
pie_chart_labels, pie_chart_data = zip(
*digest_model.node_type_flops.items()
)
self.ui.flopsPieChart.set_data(
"FLOPs Intensity Per Op Type",
pie_chart_labels,
pie_chart_data,
)

# Set up the params pie chart
pie_chart_labels, pie_chart_data = zip(
*digest_model.node_type_parameters.items()
)
self.ui.parametersPieChart.set_data(
"Parameter Intensity Per Op Type",
pie_chart_labels,
pie_chart_data,
)

self.ui.flops.setText(flops_str)

# Inputs Table
self.ui.inputsTable.setRowCount(len(digest_model.model_inputs))

for row_idx, (input_name, input_info) in enumerate(
digest_model.model_inputs.items()
):
self.ui.inputsTable.setItem(row_idx, 0, QTableWidgetItem(input_name))
self.ui.inputsTable.setItem(
row_idx, 1, QTableWidgetItem(str(input_info.shape))
)
self.ui.inputsTable.setItem(
row_idx, 2, QTableWidgetItem(str(input_info.dtype))
)
self.ui.inputsTable.setItem(
row_idx, 3, QTableWidgetItem(str(input_info.size_kbytes))
)

self.ui.inputsTable.resizeColumnsToContents()
self.ui.inputsTable.resizeRowsToContents()

# Outputs Table
self.ui.outputsTable.setRowCount(len(digest_model.model_outputs))
for row_idx, (output_name, output_info) in enumerate(
digest_model.model_outputs.items()
):
self.ui.outputsTable.setItem(row_idx, 0, QTableWidgetItem(output_name))
self.ui.outputsTable.setItem(
row_idx, 1, QTableWidgetItem(str(output_info.shape))
)
self.ui.outputsTable.setItem(
row_idx, 2, QTableWidgetItem(str(output_info.dtype))
)
self.ui.outputsTable.setItem(
row_idx, 3, QTableWidgetItem(str(output_info.size_kbytes))
)

self.ui.outputsTable.resizeColumnsToContents()
self.ui.outputsTable.resizeRowsToContents()

if isinstance(digest_model, DigestOnnxModel):

if digest_model.model_version:
# ModelProto Info
self.ui.modelProtoTable.setItem(
0, 1, QTableWidgetItem(digest_model.model_version)
)

if digest_model.graph_name:
self.ui.modelProtoTable.setItem(
1, 1, QTableWidgetItem(digest_model.graph_name)
)

producer_txt = (
f"{digest_model.producer_name} {digest_model.producer_version}"
)
self.ui.modelProtoTable.setItem(2, 1, QTableWidgetItem(producer_txt))

self.ui.modelProtoTable.setItem(
3, 1, QTableWidgetItem(str(digest_model.ir_version))
)

for domain, version in digest_model.imports.items():
row_idx = self.ui.importsTable.rowCount()
self.ui.importsTable.insertRow(row_idx)
if domain == "" or domain == "ai.onnx":
self.ui.opsetVersion.setText(str(version))
domain = "ai.onnx"
self.ui.importsTable.setItem(row_idx, 0, QTableWidgetItem(domain))
self.ui.importsTable.setItem(row_idx, 1, QTableWidgetItem(str(version)))
row_idx += 1

self.ui.importsTable.resizeColumnsToContents()
self.ui.modelProtoTable.resizeColumnsToContents()
self.setObjectName(model_name)

def open_freeze_inputs(self):
if self.freeze_inputs:
Expand Down
16 changes: 13 additions & 3 deletions src/digest/multi_model_selection_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ def set_directory(self, directory: str):

total_num_models = len(onnx_file_list) + len(report_file_list)

if total_num_models == 0:
self.update_message_label("No models found in the selected directory.")
return

serialized_models_paths: defaultdict[bytes, List[str]] = defaultdict(list)

progress = ProgressDialog("Loading models", total_num_models, self)
Expand Down Expand Up @@ -275,15 +279,19 @@ def set_directory(self, directory: str):
except DecodeError as error:
print(f"Error decoding model {filepath}: {error}")

progress = ProgressDialog("Processing Models", total_num_models, self)
progress = ProgressDialog(
"Processing Models",
len(serialized_models_paths) + len(report_file_list),
self,
)

num_duplicates = 0
self.item_model.clear()
self.ui.duplicateListWidget.clear()
for paths in serialized_models_paths.values():
progress.step()
if progress.wasCanceled():
break
progress.step()
if len(paths) > 1:
num_duplicates += 1
self.ui.duplicateListWidget.addItem(paths[0])
Expand All @@ -298,9 +306,9 @@ def set_directory(self, directory: str):
duplicate_reports: Dict[str, List[str]] = {}
processed_files = set()
for i in range(len(report_file_list)):
progress.step()
if progress.wasCanceled():
break
progress.step()
path1 = report_file_list[i]
if path1 in processed_files:
continue # Skip already processed files
Expand All @@ -316,6 +324,8 @@ def set_directory(self, directory: str):
duplicate_reports[path1].append(path2)
processed_files.add(path2)

progress.close()

for path, dupes in duplicate_reports.items():
if dupes:
self.ui.duplicateListWidget.addItem(path)
Expand Down
Loading

0 comments on commit 1b7b9b0

Please sign in to comment.