search_dense_cpu.py (forked from octoml/Apple-M1-BERT)
import os

import numpy as np
import tvm
from tvm import relay, auto_scheduler
from tvm.contrib import graph_executor
from tvm.contrib.utils import tempdir

import relay_utils
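
# Assumed setup (not part of this script): an RPC tracker must already be
# running locally, with an Apple M1 device registered under the key "m1".
# A minimal sketch of the commands, run in two separate shells:
#
#   python -m tvm.exec.rpc_tracker --host 127.0.0.1 --port 9190
#   python -m tvm.exec.rpc_server --tracker 127.0.0.1:9190 --key m1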

def run_tuning(tasks, task_weights, log_file):
    """Tune all extracted tasks over RPC and append the records to log_file."""
    print("Begin tuning...")
    # Measure candidate schedules on the remote "m1" device through the tracker.
    measure_runner = auto_scheduler.RPCRunner(
        "m1",
        "127.0.0.1",
        9190,
        min_repeat_ms=300,
        timeout=30,
        repeat=2,
    )
    # The task scheduler allocates the trial budget across all tasks.
    tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=10000,
        runner=measure_runner,
        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
        verbose=2,
    )
    tuner.tune(tune_option)
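
# If a long tuning run is interrupted, it can be resumed from the records
# written so far; a sketch using TaskScheduler's load_log_file parameter:
#
#   tuner = auto_scheduler.TaskScheduler(
#       tasks, task_weights, load_log_file=log_file)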


if __name__ == "__main__":
    # name = "huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad"
    name = "bert-base-uncased"
    # The number of sequences in one input batch.
    batch_size = 1
    # The length of each input sequence.
    seq_len = 128
    # Compile for the Apple M1 (arm64) CPU; host and device target are the same.
    target = "llvm -mcpu=apple-latest -mtriple=arm64-apple-macos"
    target_host = "llvm -mcpu=apple-latest -mtriple=arm64-apple-macos"
    # Tuning log that collects the auto-scheduler records for this model.
    log_file = "./assets/{name}_{target}".format(
        name=name.replace("/", "_"),
        target="cpu",
    )
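
    # Guard not present in the original script: make sure the log directory
    # exists before tuning tries to write records into it.
    os.makedirs(os.path.dirname(log_file), exist_ok=True)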
print("Extract tasks...")
mod, params, shape_dict = relay_utils.load_pt_model(name.replace("/", "_"))
if not os.path.exists(log_file):
tasks, task_weights = auto_scheduler.extract_tasks(
mod["main"], params, target=target_host, target_host=target_host)
for idx, task in enumerate(tasks):
print("========== Task %d (workload key: %s) ==========" %
(idx, task.workload_key))
print(task.compute_dag)
run_tuning(tasks, task_weights, log_file)
print("Compile...")
with auto_scheduler.ApplyHistoryBest(log_file):
with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
lib = relay.build(mod, target=target,
target_host=target_host, params=params)
print("Upload")
tmp = tempdir()
filename = "net.tar"
lib.export_library(tmp.relpath(filename))
remote = auto_scheduler.utils.request_remote("m1", "127.0.0.1", 9190)
remote.upload(tmp.relpath(filename))
rlib = remote.load_module(filename)
print("run")
input_shape = [1, 128]
dtype = "int64"
ctx = remote.device(str(target), 0)
module = runtime.graph_executor.GraphModule(rlib["default"](ctx))
data_tvm = tvm.nd.array(
(np.random.uniform(size=input_shape, low=0, high=10000)).astype(dtype))
module.set_input("input_ids", data_tvm)
# evaluate
print("Evaluate inference time cost...")
ftimer = module.module.time_evaluator("run", ctx, number=10, repeat=30)
prof_res = np.array(ftimer().results) * 1000 # convert to millisecond
print(
"Mean inference time (std dev): %.2f ms (%.2f ms)"
% (np.mean(prof_res), np.std(prof_res))
)
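
    # Optional sanity check, not in the original flow: run the module once and
    # inspect the first output. get_output(0) assumes the traced model exposes
    # at least one output tensor.
    module.run()
    print("Output shape:", module.get_output(0).shape)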