-
Notifications
You must be signed in to change notification settings - Fork 61
/
Copy pathsqueeze.html
139 lines (121 loc) · 4.3 KB
/
squeeze.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
<!DOCTYPE html>
<html>
<head>
<title>WONNX</title>
<style type="text/css">
html {
font-family: Verdana, Arial, Helvetica, sans-serif;
color: rgb(77, 77, 77);
font-size: 10pt;
line-height: 1.2em;
}
</style>
</head>
<body>
<h1>WONNX: Squeeze image classification</h1>
<div id="log">Loading... (If this doesn't load, you may have to enable WebGPU in your browser and reload this page)</div>
<video id="player" autoplay width="224" height="224"></video>
<div id="perf"></div>
<script type="module">
import init, { Session, Input } from "/target/pkg/wonnx.js";
/**
 * Downloads a resource and returns its body as raw bytes.
 *
 * `fetch` only rejects on network failures, so the HTTP status is checked
 * explicitly; otherwise an error page (e.g. a 404) would be returned as if it
 * were the requested file.
 *
 * @param {string} url - URL of the resource to download.
 * @returns {Promise<Uint8Array>} The response body.
 * @throws {Error} If the server responds with a non-2xx status.
 */
async function fetchBytes(url) {
  const reply = await fetch(url);
  if (!reply.ok) {
    throw new Error(`Failed to fetch ${url}: ${reply.status} ${reply.statusText}`);
  }
  return new Uint8Array(await reply.arrayBuffer());
}
// Per-channel (R, G, B) normalization applied to SqueezeNet input pixels:
// normalized = (value / 255 - mean) / std. These are the commonly used ImageNet statistics.
const NORMALIZATION_MEAN = [0.485, 0.456, 0.406];
const NORMALIZATION_STD = [0.229, 0.224, 0.225];

/**
 * Converts interleaved RGBA pixels into planar RGB floats, normalizing each channel.
 *
 * @param {ArrayLike<number>} rgba - Pixel bytes as returned by `getImageData().data`
 *   (4 values per pixel, row-major, `width * height` pixels).
 * @param {number} width - Image width in pixels.
 * @param {number} height - Image height in pixels.
 * @returns {Float32Array} `width * height * 3` values: the whole R plane, then G, then B.
 */
function rgbaToNormalizedPlanes(rgba, width, height) {
  const planes = 3; // SqueezeNet expects RGB
  const valuesPerPixel = 4; // source data is RGBA (alpha is dropped)
  const planeSize = width * height;
  const out = new Float32Array(planeSize * planes);
  for (let plane = 0; plane < planes; plane++) {
    const mean = NORMALIZATION_MEAN[plane];
    const std = NORMALIZATION_STD[plane];
    for (let pixel = 0; pixel < planeSize; pixel++) {
      const v = rgba[pixel * valuesPerPixel + plane] / 255.0;
      out[plane * planeSize + pixel] = (v - mean) / std;
    }
  }
  return out;
}

/**
 * Finds the largest element of an array-like of numbers.
 *
 * Starts from -Infinity so that negative scores are handled correctly.
 *
 * @param {ArrayLike<number>} values - Scores to search.
 * @returns {{index: number, value: number}} Position and value of the maximum;
 *   `{index: -1, value: -Infinity}` when `values` is empty.
 */
function argMax(values) {
  let index = -1;
  let value = -Infinity;
  for (let i = 0; i < values.length; i++) {
    if (values[i] > value) {
      value = values[i];
      index = i;
    }
  }
  return { index, value };
}

/**
 * Page entry point.
 *
 * Initializes WONNX and loads the SqueezeNet model and label list in parallel,
 * then starts the webcam. It classifies video frames in a requestAnimationFrame
 * loop, writing the top label to #log and the average inference time to #perf.
 * Startup errors are logged to the console and appended to #log. Errors for a
 * single frame are logged, and the loop continues.
 */
async function run() {
  const logElement = document.getElementById("log");
  const perfElement = document.getElementById("perf");
  try {
    // Load model, labels file and WONNX in parallel
    const labelsUrl = "../data/models/squeeze-labels.txt";
    const labels = fetch(labelsUrl).then((r) => {
      // fetch does not reject on HTTP errors; fail loudly instead of parsing an error page
      if (!r.ok) {
        throw new Error(`Failed to fetch ${labelsUrl}: ${r.status} ${r.statusText}`);
      }
      return r.text();
    });
    const [modelBytes, initResult, labelsResult] = await Promise.all([
      fetchBytes("../data/models/opt-squeeze.onnx"),
      init(),
      labels,
    ]);
    console.log("Initialized", { modelBytes, initResult, Session, labelsResult });
    const squeezeWidth = 224;
    const squeezeHeight = 224;
    // Start inference session
    const session = await Session.fromBytes(modelBytes);
    // Parse labels: one per line, tolerating CRLF line endings
    const labelsList = labelsResult.split(/\r?\n/);
    console.log({ labelsList });
    // Start video
    const player = document.getElementById("player");
    const stream = await navigator.mediaDevices.getUserMedia({ video: true });
    player.srcObject = stream;
    // Canvas used to scale video frames down to the model's input size
    const canvas = document.createElement("canvas");
    canvas.width = squeezeWidth;
    canvas.height = squeezeHeight;
    const context = canvas.getContext("2d", { willReadFrequently: true });
    let inferenceCount = 0;
    let inferenceTime = 0;

    // Captures a frame and produces inference
    const inferImage = async () => {
      try {
        // Until the video has decoded data, drawImage draws nothing, so
        // classifying would only process an empty canvas.
        if (player.readyState < HTMLMediaElement.HAVE_CURRENT_DATA) {
          return;
        }
        // Draw the video frame to the canvas, scaled to the model input size
        context.drawImage(player, 0, 0, canvas.width, canvas.height);
        const { data } = context.getImageData(0, 0, canvas.width, canvas.height);
        // Transform the image data in the format expected by SqueezeNet
        const image = rgbaToNormalizedPlanes(data, canvas.width, canvas.height);
        // Start inference; free the Input even if run() throws so it is never leaked
        const input = new Input();
        let result;
        let duration;
        try {
          input.insert("data", image);
          const start = performance.now();
          result = await session.run(input);
          duration = performance.now() - start;
        } finally {
          input.free();
        }
        inferenceCount++;
        inferenceTime += duration;
        // Find the label with the highest score
        const { index, value } = argMax(result.get("squeezenet0_flatten0_reshape0"));
        // Write result
        logElement.innerText = `${labelsList[index]} (${value})`;
        const avgFrameTime = inferenceTime / inferenceCount;
        perfElement.innerText = `Inference time: ${avgFrameTime.toFixed(2)}ms, at most ${Math.floor(1000/avgFrameTime)} fps`;
      } catch (e) {
        console.error(e, e.toString());
      }
    };

    // Capture and infer as fast as the browser allows, one frame at a time
    const tick = () => {
      window.requestAnimationFrame(async () => {
        console.time("frame");
        await inferImage();
        console.timeEnd("frame");
        tick();
      });
    };
    tick();
  } catch (e) {
    console.error(e, e.toString());
    // Surface startup failures on the page instead of leaving it on "Loading..."
    logElement.innerText += `\nError: ${e}`;
  }
}
run();
</script>
</body>
</html>