🏂 Bert SNPE conversion & evaluation case #925

Merged · 10 commits · Feb 6, 2024
33 changes: 33 additions & 0 deletions examples/bert/npu/README.md
@@ -0,0 +1,33 @@
# BERT model optimization on Qualcomm NPU with SNPE SDK
This folder contains a sample use case of Olive that converts a BERT PyTorch model to ONNX, then to an SNPE DLC, and evaluates the accuracy of the DLC model.

It performs the following optimization pipeline:
- *PyTorch Model -> ONNX Model -> SNPE Model*

## Prerequisites
### Download and unzip SNPE SDK
Download the SNPE SDK zip following the [instructions from Qualcomm](https://developer.qualcomm.com/software/qualcomm-neural-processing-sdk).

This example was tested with SNPE v2.18.0.240101.

Unzip the file and set the unzipped directory path as the environment variable `SNPE_ROOT`.
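For example (the archive and directory names below are illustrative):
```sh
unzip snpe-2.18.0.240101.zip -d ~/qualcomm
export SNPE_ROOT=~/qualcomm/snpe-2.18.0.240101
```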

### Configure SNPE
```sh
python -m olive.platform_sdk.qualcomm.configure --py_version 3.8 --sdk snpe
```

## Run sample
Run the conversion and evaluation locally. Only `x64-Linux` is supported.
```sh
python -m olive.workflows.run --config bert_snpe.json
```

## Issues

1. "Module 'qti.aisw.converters' has no attribute 'onnx':
Refer to this: https://developer.qualcomm.com/comment/21810#comment-21810,
change the import statement in `{SNPE_ROOT}/lib/python/qti/aisw/converters/onnx/onnx_to_ir.py:L30` to:
```python
from qti.aisw.converters.onnx import composable_custom_op_utils as ComposableCustomOp
```
91 changes: 91 additions & 0 deletions examples/bert/npu/bert_snpe.json
@@ -0,0 +1,91 @@
{
"input_model":{
"type": "PyTorchModel",
"config": {
"hf_config": {
"model_name": "Intel/bert-base-uncased-mrpc",
"task": "text-classification"
},
"io_config": {
"input_names": ["input_ids", "attention_mask", "token_type_ids"],
"input_types": ["int64", "int64", "int64"],
"input_shapes": [[2, 128], [2, 128], [2, 128]],
"output_names": ["logits"]
}
}
},
"evaluators": {
"snpe_evaluator": {
"metrics": [
{
"name": "accuracy",
"type": "accuracy",
"backend": "huggingface_metrics",
"data_config": "snpe_dataset",
"user_config": {
"inference_settings": {
"snpe":{
"return_numpy_results": true
}
}
},
"sub_types": [
{"name": "accuracy", "priority": 1, "goal": {"type": "max-degradation", "value": 0.05}},
{"name": "f1"}
]
}
]
}
},
"data_configs": {
"snpe_dataset": {
"name": "snpe_dataset",
"type": "HuggingfaceContainer",
"user_script": "user_script.py",
"components": {
"post_process_data": {
"type": "snpe_post_process"
}
},
"params_config": {
"model_name": "Intel/bert-base-uncased-mrpc",
"task": "text-classification",
"batch_size": 2,
"data_name": "glue",
"input_cols": ["sentence1", "sentence2"],
"label_cols": ["label"],
"split": "validation",
"subset": "mrpc",
"component_kwargs": {
"pre_process_data": {
"max_length": 128,
"padding": "max_length"
}
}
}
}
},
"passes": {
"conversion": {
"type": "OnnxConversion",
"config": {
"target_opset": 13
}
},
"snpe_to_dlc": {
"type": "SNPEConversion",
"config": {
"input_names": ["input_ids", "attention_mask", "token_type_ids"],
"input_shapes": [[2, 128], [2, 128], [2, 128]],
"output_names": ["logits"]
}
}
},
"engine": {
"log_severity_level": 0,
"clean_cache": true,
"evaluator": "snpe_evaluator",
"evaluate_input_model": false,
"output_dir" : "models/bert_snpe"
}
}
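Besides the CLI invocation shown in the README, this config can also be launched from Python; a minimal sketch, assuming the standard Olive workflow entry point:

```python
from olive.workflows import run as olive_run

# Programmatic equivalent of `python -m olive.workflows.run --config bert_snpe.json`
olive_run("bert_snpe.json")
```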
11 changes: 11 additions & 0 deletions examples/bert/npu/user_script.py
@@ -0,0 +1,11 @@
from olive.data.registry import Registry


@Registry.register_post_process()
def snpe_post_process(output_data, **kwargs):
import torch

logits = torch.tensor(output_data["logits"])
_, preds = torch.max(logits, dim=-1)

return preds
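For illustration, a minimal sketch (with assumed logits for a batch of two samples) of what this post-process computes:

```python
import torch

output_data = {"logits": [[0.1, 0.9], [0.8, 0.2]]}  # assumed [batch, num_classes] outputs
logits = torch.tensor(output_data["logits"])
_, preds = torch.max(logits, dim=-1)  # same reduction as snpe_post_process above
print(preds.tolist())  # [1, 0]
```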
2 changes: 1 addition & 1 deletion examples/inception/user_script.py
@@ -7,4 +7,4 @@

@Registry.register_post_process()
def inception_post_process(output):
return output["results"]["InceptionV3/Predictions/Reshape_1:0"].squeeze(1).argmax(axis=1)
return output["InceptionV3/Predictions/Reshape_1:0"].argmax(axis=1)
2 changes: 1 addition & 1 deletion examples/mobilenet/raw_qnn_sdk_template.json
@@ -17,7 +17,7 @@
"data_dir": "data",
"batch_size": 1,
"dataloader_func": "qnn_data_loader",
"post_processing_func": "post_process",
"post_processing_func": "qnn_sdk_post_process",
"inference_settings": {
"qnn":{
"backend": "libQnnCpu"
4 changes: 4 additions & 0 deletions examples/mobilenet/user_script.py
@@ -68,5 +68,9 @@ def post_process(output):
return output.argmax(axis=1)


def qnn_sdk_post_process(output):
return np.array([output.argmax(axis=-1)])


def mobilenet_calibration_reader(data_dir, batch_size, *args, **kwargs):
return MobileNetCalibrationDataReader(data_dir, batch_size=batch_size)
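For reference, a minimal sketch (assumed data) of what `qnn_sdk_post_process` returns: wrapping the argmax in a one-element array keeps the result indexable and `tolist()`-convertible for the evaluator.

```python
import numpy as np

output = np.array([0.1, 0.7, 0.2])         # assumed raw QNN output for one sample
post = np.array([output.argmax(axis=-1)])  # mirrors qnn_sdk_post_process above
print(post.tolist())                       # [1]
```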
1 change: 1 addition & 0 deletions olive/data/component/pre_process_data.py
@@ -84,6 +84,7 @@ def _tokenizer_and_align_labels(examples):
*[examples[input_col] for input_col in input_cols],
padding=kwargs.get("padding", True),
truncation=kwargs.get("truncation", True),
max_length=kwargs.get("max_length"),
is_split_into_words=kwargs.get("is_split_into_words", False),
add_special_tokens=kwargs.get("add_special_tokens", True),
)
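Passing `max_length` through matters here because SNPE models are compiled with fixed input shapes; combined with `padding: "max_length"` in `bert_snpe.json`, every sample is tokenized to exactly 128 tokens, matching the `[2, 128]` `input_shapes`. A minimal sketch of the effect (model name taken from the config above):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Intel/bert-base-uncased-mrpc")
encoded = tokenizer(
    "A first sentence.",
    "A second sentence.",
    padding="max_length",  # pad every sample to the same fixed length
    truncation=True,
    max_length=128,        # matches the [2, 128] input_shapes in bert_snpe.json
)
print(len(encoded["input_ids"]))  # 128
```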
62 changes: 41 additions & 21 deletions olive/evaluator/olive_evaluator.py
@@ -942,23 +942,34 @@ def _inference(
execution_providers: Union[str, List[str]] = None,
) -> Tuple[OliveModelOutput, Any]:
dataloader = self._prepare_dataloader(dataloader, model)
session = model.prepare_session(
inference_settings=metric.get_inference_settings(self.framework.lower()), device=device
)
inference_settings = metric.get_inference_settings(self.framework.lower())
# Accuracy evaluation requires `return_numpy_results` to be True; plain model
# inference does not. Set it to True unconditionally to keep evaluation simple.
inference_settings["return_numpy_results"] = True

session = model.prepare_session(inference_settings=inference_settings, device=device)

preds = []
targets = []
logits = []
for data_dir, input_list, labels in dataloader:
result = session(input_list, data_dir)
if post_func:
outputs = post_func(result)
else:
raise ValueError("Post processing function is required for SNPE model")
preds.extend(outputs.tolist())
targets.extend(labels.tolist())
lg = result["results"].get("logits")
logits.extend(lg.to_list() if lg else [])
# SNPE inference returns a list of per-sample outputs rather than a single batched
# array, so iterate over the first dimension to compute the accuracy metrics correctly
for idx, output in enumerate(result.get("results")):
post_output = output
if post_func:
post_output = post_func(output)
else:
raise ValueError("Post processing function is required for SNPE model")
preds.extend(post_output.tolist())
if isinstance(labels[idx], list):
targets.extend(labels[idx])
else:
targets.append(labels[idx])
# the output is a dict with a "logits" key only when return_numpy_results is True
logits.extend(output.get("logits", np.array([])).tolist())
return OliveModelOutput(preds=preds, logits=logits), targets
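To make the per-sample accumulation concrete, a minimal sketch with assumed data (each entry of `result["results"]` is one file-list entry's output dict; `labels[idx]` holds its label or labels):

```python
import torch

result = {"results": [{"logits": [[0.1, 0.9], [0.6, 0.4]]}]}  # assumed: one entry, batch of 2
labels = [[1, 0]]  # assumed: one label list per file-list entry

preds, targets = [], []
for idx, output in enumerate(result["results"]):
    post_output = torch.tensor(output["logits"]).argmax(dim=-1)  # stand-in for post_func
    preds.extend(post_output.tolist())
    if isinstance(labels[idx], list):
        targets.extend(labels[idx])  # flatten per-entry label lists
    else:
        targets.append(labels[idx])

print(preds, targets)  # [1, 0] [1, 0]
```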

def _evaluate_accuracy(
Expand Down Expand Up @@ -998,7 +1009,9 @@ def _evaluate_raw_latency(
def _prepare_dataloader(self, dataloader: Dataset, model: SNPEModelHandler) -> FileListDataLoader:
if isinstance(dataloader, FileListDataLoader):
return dataloader
return FileListCommonDataLoader(dataloader, model.io_config)
# batch_size=1 makes the file-list dataloader yield one entry at a time; each entry
# keeps the batch size of the original dataloader.
return FileListCommonDataLoader(dataloader, model.io_config, batch_size=1)


class OpenVINOEvaluator(OliveEvaluator, framework=Framework.OPENVINO):
Expand Down Expand Up @@ -1087,14 +1100,19 @@ def _inference(
targets = []
logits = []
for data_dir, input_list, labels in dataloader:
result = session(input_list, data_dir).get("result")
if post_func:
outputs = post_func(result)
else:
raise ValueError("Post processing function is required for QNN model")
preds.extend(outputs.tolist())
targets.extend(labels.tolist())
logits.extend(result.tolist())
result = session(input_list, data_dir)
for idx, output in enumerate(result.get("result")):
post_output = output
if post_func:
post_output = post_func(output)
else:
raise ValueError("Post processing function is required for QNN model")
preds.extend(post_output.tolist())
if isinstance(labels[idx], list):
targets.extend(labels[idx])
else:
targets.append(labels[idx])
logits.extend(output.tolist())
return OliveModelOutput(preds=preds, logits=logits), targets

def _evaluate_accuracy(
Expand Down Expand Up @@ -1135,7 +1153,9 @@ def _evaluate_raw_latency(
def _prepare_dataloader(self, dataloader: Dataset, model: QNNModelHandler) -> FileListDataLoader:
if isinstance(dataloader, FileListDataLoader):
return dataloader
return FileListCommonDataLoader(dataloader, model.io_config)
# batch_size=1 makes the file-list dataloader yield one entry at a time; each entry
# keeps the batch size of the original dataloader.
return FileListCommonDataLoader(dataloader, model.io_config, batch_size=1)


class OliveEvaluatorFactory:
3 changes: 1 addition & 2 deletions olive/platform_sdk/qualcomm/snpe/tools/inference.py
@@ -268,7 +268,6 @@ def snpe_net_run(
for output_name in output_names:
results[output_name].sort(key=lambda x: x[0])
results[output_name] = [x[1] for x in results[output_name]]
results[output_name] = np.stack(results[output_name])

if workspace is not None:
# sort the result files by the input id
@@ -303,7 +302,7 @@ def snpe_net_run(

output_dict = {"latencies": latencies}
if return_numpy_results:
output_dict["results"] = results
output_dict["results"] = [{k: v[i] for k, v in results.items()} for i in range(len(results[output_names[0]]))]
if output_dir is not None:
output_dict["output_dir"] = str(output_dir)
output_dict["result_files"] = result_files
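The restructuring above turns the per-output-name lists into one dict per sample, which is what the evaluators now iterate over; a minimal sketch with assumed values:

```python
# Before: {"logits": [out0, out1]}  ->  After: [{"logits": out0}, {"logits": out1}]
results = {"logits": ["out0", "out1"]}  # assumed per-output lists, one entry per sample
output_names = ["logits"]

per_sample = [{k: v[i] for k, v in results.items()} for i in range(len(results[output_names[0]]))]
print(per_sample)  # [{'logits': 'out0'}, {'logits': 'out1'}]
```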
2 changes: 1 addition & 1 deletion olive/platform_sdk/qualcomm/utils/data_loader.py
@@ -323,7 +323,7 @@ def load_data(self) -> Tuple[str, str, np.ndarray]:
input_file_path = input_dir_path / input_file_name
data.tofile(input_file_path)

annotations.append(annotation)
annotations.append(annotation.tolist())

annotations = None if annotations[0] is None else np.array(annotations)
