Commit
handle dupes with reports in multimodel analysis
Philip Colangelo committed Dec 31, 2024
1 parent dd84d64 commit 03af35c
Showing 4 changed files with 157 additions and 136 deletions.
142 changes: 103 additions & 39 deletions src/digest/model_class/digest_report_model.py
@@ -3,7 +3,7 @@
 import csv
 import ast
 import re
-from typing import Tuple, Optional
+from typing import Tuple, Optional, List, Dict, Any, Union
 import yaml
 from digest.model_class.digest_model import (
     DigestModel,
@@ -15,7 +15,9 @@
 )
 
 
-def parse_tensor_info(csv_tensor_cell_value) -> Tuple[str, list, str, float]:
+def parse_tensor_info(
+    csv_tensor_cell_value,
+) -> Tuple[str, list, str, Union[str, float]]:
     """This is a helper function that expects the input to come from parsing
     the nodes csv and extracting either an input or output tensor."""
 
@@ -38,7 +40,10 @@ def parse_tensor_info(csv_tensor_cell_value) -> Tuple[str, list, str, float]:
     if not isinstance(shape, list):
         shape = list(shape)
 
-    return name.strip(), shape, dtype.strip(), float(size.split()[0])
+    if size != "None":
+        size = float(size.split()[0])
+
+    return name.strip(), shape, dtype.strip(), size
 
 
 class DigestReportModel(DigestModel):
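
The new guard above leaves size as the literal string "None" when the report omits a tensor size, and otherwise converts strings like "0.59 KB" to a float. A minimal self-contained sketch of that pattern (the helper name and sample values are hypothetical, not part of the diff):

def to_size(size: str):
    # Mirrors the guard above: "0.59 KB" -> 0.59, while the literal
    # string "None" passes through unchanged.
    return float(size.split()[0]) if size != "None" else size

assert to_size("0.59 KB") == 0.59
assert to_size("None") == "None"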
@@ -49,7 +54,7 @@ def __init__(
 
         self.model_type = SupportedModelTypes.REPORT
 
-        self.is_valid = self.validate_yaml(report_filepath)
+        self.is_valid = validate_yaml(report_filepath)
 
         if not self.is_valid:
             print(f"The yaml file {report_filepath} is not a valid digest report.")
@@ -131,41 +136,6 @@ def __init__(
             }
         )
 
-    def validate_yaml(self, report_file_path: str) -> bool:
-        """Check that the provided yaml file is indeed a Digest Report file."""
-        expected_keys = [
-            "report_date",
-            "model_file",
-            "model_type",
-            "model_name",
-            "flops",
-            "node_type_flops",
-            "node_type_parameters",
-            "node_type_counts",
-            "input_tensors",
-            "output_tensors",
-        ]
-        try:
-            with open(report_file_path, "r", encoding="utf-8") as file:
-                yaml_content = yaml.safe_load(file)
-
-            if not isinstance(yaml_content, dict):
-                print("Error: YAML content is not a dictionary")
-                return False
-
-            for key in expected_keys:
-                if key not in yaml_content:
-                    # print(f"Error: Missing required key '{key}'")
-                    return False
-
-            return True
-        except yaml.YAMLError as _:
-            # print(f"Error parsing YAML file: {e}")
-            return False
-        except IOError as _:
-            # print(f"Error reading file: {e}")
-            return False
-
     def parse_model_nodes(self) -> None:
         """There are no model nodes to parse"""
 
@@ -174,3 +144,97 @@ def save_yaml_report(self, filepath: str) -> None:
 
     def save_text_report(self, filepath: str) -> None:
         """Report models are not intended to be saved"""
+
+
+def validate_yaml(report_file_path: str) -> bool:
+    """Check that the provided yaml file is indeed a Digest Report file."""
+    expected_keys = [
+        "report_date",
+        "model_file",
+        "model_type",
+        "model_name",
+        "flops",
+        "node_type_flops",
+        "node_type_parameters",
+        "node_type_counts",
+        "input_tensors",
+        "output_tensors",
+    ]
+    try:
+        with open(report_file_path, "r", encoding="utf-8") as file:
+            yaml_content = yaml.safe_load(file)
+
+        if not isinstance(yaml_content, dict):
+            print("Error: YAML content is not a dictionary")
+            return False
+
+        for key in expected_keys:
+            if key not in yaml_content:
+                # print(f"Error: Missing required key '{key}'")
+                return False
+
+        return True
+    except yaml.YAMLError as _:
+        # print(f"Error parsing YAML file: {e}")
+        return False
+    except IOError as _:
+        # print(f"Error reading file: {e}")
+        return False
+
+
+def compare_yaml_files(
+    file1: str, file2: str, skip_keys: Optional[List[str]] = None
+) -> bool:
+    """
+    Compare two YAML files, ignoring specified keys.
+    :param file1: Path to the first YAML file
+    :param file2: Path to the second YAML file
+    :param skip_keys: List of keys to ignore in the comparison
+    :return: True if the files are equal (ignoring specified keys), False otherwise
+    """
+
+    def load_yaml(file_path: str) -> Dict[str, Any]:
+        with open(file_path, "r", encoding="utf-8") as file:
+            return yaml.safe_load(file)
+
+    def compare_dicts(
+        dict1: Dict[str, Any], dict2: Dict[str, Any], path: str = ""
+    ) -> List[str]:
+        differences = []
+        all_keys = set(dict1.keys()) | set(dict2.keys())
+
+        for key in all_keys:
+            if skip_keys and key in skip_keys:
+                continue
+
+            current_path = f"{path}.{key}" if path else key
+
+            if key not in dict1:
+                differences.append(f"Key '{current_path}' is missing in the first file")
+            elif key not in dict2:
+                differences.append(
+                    f"Key '{current_path}' is missing in the second file"
+                )
+            elif isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
+                differences.extend(compare_dicts(dict1[key], dict2[key], current_path))
+            elif dict1[key] != dict2[key]:
+                differences.append(
+                    f"Value mismatch for key '{current_path}': {dict1[key]} != {dict2[key]}"
+                )
+
+        return differences
+
+    yaml1 = load_yaml(file1)
+    yaml2 = load_yaml(file2)
+
+    differences = compare_dicts(yaml1, yaml2)
+
+    if differences:
+        # print("Differences found:")
+        # for diff in differences:
+        #     print(f"- {diff}")
+        return False
+    else:
+        # print("No differences found.")
+        return True
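
A hypothetical usage sketch of the two new module-level helpers (the report paths are placeholders); the skip_keys list mirrors the one passed from multi_model_selection_page.py below, so volatile metadata does not defeat the comparison:

report_a = "reports/resnet18.yaml"  # placeholder paths
report_b = "reports/resnet18_rerun.yaml"

if validate_yaml(report_a) and validate_yaml(report_b):
    # Ignore fields that legitimately differ between two runs
    # over the same model.
    if compare_yaml_files(
        report_a, report_b, skip_keys=["report_date", "model_files", "digest_version"]
    ):
        print("The second report duplicates the first; it can be skipped.")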
36 changes: 17 additions & 19 deletions src/digest/multi_model_analysis.py
@@ -101,7 +101,13 @@ def __init__(
         self.ui.dataTable.resizeColumnsToContents()
         self.ui.dataTable.resizeRowsToContents()
 
-        node_type_counter = {}
+        # Until we use the unique_id to represent the model contents, we store
+        # the entire model as the key so that we can hold models that happen to
+        # share a name. The models are guaranteed not to be duplicates.
+        node_type_counter: Dict[
+            Union[DigestOnnxModel, DigestReportModel], NodeTypeCounts
+        ] = {}
+
         for i, digest_model in enumerate(model_list):
             progress.step()
             progress.setLabelText(f"Analyzing model {digest_model.model_name}")
@@ -143,26 +149,18 @@ def __init__(
                 "flops": digest_model.flops,
             }
 
-            # Here we are creating a name that is a combination of the model name
-            # and the model type.
-            node_type_counter_key = (
-                f"{digest_model.model_name}-{digest_model.model_type.value}"
-            )
-
-            if node_type_counter_key in node_type_counter:
+            if digest_model in node_type_counter:
                 print(
                     f"Warning! {digest_model.model_name} with model type "
-                    f"{digest_model.model_type.value} has already been added to "
-                    "to the stacked histogram, skipping."
+                    f"{digest_model.model_type.value} and id {digest_model.unique_id} "
+                    "has already been added to the stacked histogram, skipping."
                 )
                 continue
 
-            node_type_counter[node_type_counter_key] = digest_model.node_type_counts
+            node_type_counter[digest_model] = digest_model.node_type_counts
 
             # Update global data structure for node type counter
-            self.global_node_type_counter.update(
-                node_type_counter[node_type_counter_key]
-            )
+            self.global_node_type_counter.update(node_type_counter[digest_model])
 
             node_shape_counts = digest_model.get_node_shape_counts()
 
@@ -180,20 +178,20 @@
         # Create stacked op histograms
         max_count = 0
         top_ops = [key for key, _ in self.global_node_type_counter.most_common(20)]
-        for model_name, _ in node_type_counter.items():
-            max_local = Counter(node_type_counter[model_name]).most_common()[0][1]
+        for model, _ in node_type_counter.items():
+            max_local = Counter(node_type_counter[model]).most_common()[0][1]
             if max_local > max_count:
                 max_count = max_local
-        for idx, model_name in enumerate(node_type_counter):
+        for idx, model in enumerate(node_type_counter):
             stacked_histogram_widget = StackedHistogramWidget()
             ordered_dict = OrderedDict()
-            model_counter = Counter(node_type_counter[model_name])
+            model_counter = Counter(node_type_counter[model])
             for key in top_ops:
                 ordered_dict[key] = model_counter.get(key, 0)
             title = "Stacked Op Histogram" if idx == 0 else ""
             stacked_histogram_widget.set_data(
                 ordered_dict,
-                model_name=model_name,
+                model_name=model.model_name,
                 y_max=max_count,
                 title=title,
                 set_ticks=False,
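
The dict typed in this file's first hunk uses whole model objects as keys, which works because Python hashes objects by identity unless __hash__/__eq__ are overridden, so two models that share a name remain distinct keys. A self-contained sketch with a hypothetical stand-in class (assuming DigestOnnxModel and DigestReportModel keep the default hashing):

class StandInModel:
    # Stand-in for DigestOnnxModel / DigestReportModel: no __eq__ or
    # __hash__ override, so instances hash by identity.
    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

a = StandInModel("resnet18")
b = StandInModel("resnet18")
node_type_counter = {a: {"Conv": 20}, b: {"Conv": 24}}
assert len(node_type_counter) == 2  # same name, two distinct keys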
49 changes: 35 additions & 14 deletions src/digest/multi_model_selection_page.py
@@ -23,7 +23,7 @@
 from digest.multi_model_analysis import MultiModelAnalysis
 from digest.qt_utils import apply_dark_style_sheet, prompt_user_ram_limit
 from digest.model_class.digest_onnx_model import DigestOnnxModel
-from digest.model_class.digest_report_model import DigestReportModel
+from digest.model_class.digest_report_model import DigestReportModel, compare_yaml_files
 from utils import onnx_utils
 
 
@@ -203,7 +203,7 @@ def set_directory(self, directory: str):
         else:
             return
 
-        progress = ProgressDialog("Searching Directory for ONNX Files", 0, self)
+        progress = ProgressDialog("Searching directory for model files", 0, self)
 
         onnx_file_list = list(
             glob.glob(os.path.join(directory, "**/*.onnx"), recursive=True)
@@ -227,11 +227,11 @@ def set_directory(self, directory: str):
         serialized_models_paths: defaultdict[bytes, List[str]] = defaultdict(list)
 
         progress.close()
-        progress = ProgressDialog("Loading Models", total_num_models, self)
+        progress = ProgressDialog("Loading models", total_num_models, self)
 
         memory_limit_percentage = 90
         models_loaded = 0
-        for filepath in onnx_file_list:
+        for filepath in onnx_file_list + report_file_list:
             progress.step()
             if progress.user_canceled:
                 break
@@ -284,17 +284,38 @@ def set_directory(self, directory: str):
                 self.ui.duplicateListWidget.addItem(paths[0])
                 for dupe in paths[1:]:
                     self.ui.duplicateListWidget.addItem(f"- Duplicate: {dupe}")
-                item = QStandardItem(paths[0])
-                item.setCheckable(True)
-                item.setCheckState(Qt.CheckState.Checked)
-                self.item_model.appendRow(item)
-            else:
-                item = QStandardItem(paths[0])
-                item.setCheckable(True)
-                item.setCheckState(Qt.CheckState.Checked)
-                self.item_model.appendRow(item)
+            item = QStandardItem(paths[0])
+            item.setCheckable(True)
+            item.setCheckState(Qt.CheckState.Checked)
+            self.item_model.appendRow(item)
 
-        for path in report_file_list:
+        # Use a standard nested loop to detect duplicate reports
+        duplicate_reports: Dict[str, List[str]] = {}
+        processed_files = set()
+        for i in range(len(report_file_list)):
+            progress.step()
+            if progress.user_canceled:
+                break
+            path1 = report_file_list[i]
+            if path1 in processed_files:
+                continue  # Skip already processed files
+
+            # We will use path1 as the unique model and save a list of duplicates
+            duplicate_reports[path1] = []
+            for j in range(i + 1, len(report_file_list)):
+                path2 = report_file_list[j]
+                if compare_yaml_files(
+                    path1, path2, ["report_date", "model_files", "digest_version"]
+                ):
+                    num_duplicates += 1
+                    duplicate_reports[path1].append(path2)
+                    processed_files.add(path2)
+
+        for path, dupes in duplicate_reports.items():
+            if dupes:
+                self.ui.duplicateListWidget.addItem(path)
+                for dupe in dupes:
+                    self.ui.duplicateListWidget.addItem(f"- Duplicate: {dupe}")
             item = QStandardItem(path)
             item.setCheckable(True)
             item.setCheckState(Qt.CheckState.Checked)
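
The nested loop above makes O(n^2) pairwise comparisons, and each compare_yaml_files call re-reads both files, which should be acceptable for modest report counts. A self-contained sketch of the same grouping pattern, with plain string equality standing in for compare_yaml_files:

from typing import Dict, List

def group_duplicates(paths: List[str]) -> Dict[str, List[str]]:
    # The first unseen path becomes the canonical entry; any later
    # path that compares equal is recorded as its duplicate and
    # never revisited as a canonical entry itself.
    groups: Dict[str, List[str]] = {}
    seen = set()
    for i, path1 in enumerate(paths):
        if path1 in seen:
            continue
        groups[path1] = []
        for path2 in paths[i + 1 :]:
            if path2 == path1:  # stands in for compare_yaml_files(path1, path2, ...)
                groups[path1].append(path2)
                seen.add(path2)
    return groups

assert group_duplicates(["a.yaml", "b.yaml", "a.yaml"]) == {
    "a.yaml": ["a.yaml"],
    "b.yaml": [],
}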