Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for using the TensorFlow.js WebGPU backend #190

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions common/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,12 @@ export function getMedianValue(array) {
(array[array.length / 2 - 1] + array[array.length / 2]) / 2;
}

// Set tf.js backend based WebNN's 'MLDeviceType' option
export async function setPolyfillBackend(device) {
// Simulate WebNN's device selection using various tf.js backends.
// MLDeviceType: ['default', 'gpu', 'cpu']
// 'default' or 'gpu': tfjs-backend-webgl, 'cpu': tfjs-backend-wasm
if (!device) device = 'gpu';
// Use 'webgl' by default for better performance.
// Set tf.js backend
export async function setPolyfillBackend(backend) {
if (!backend) backend = 'webgpu';
// Use 'webgpu' by default for better performance.
// Note: 'wasm' backend may run failed on some samples since
// some ops aren't supported on 'wasm' backend at present
const backend = device === 'cpu' ? 'wasm' : 'webgl';
const context = await navigator.ml.createContext();
const tf = context.tf;
if (tf) {
Expand All @@ -221,8 +217,8 @@ export async function setPolyfillBackend(device) {
throw new Error(`Failed to set tf.js backend ${backend}.`);
}
await tf.ready();
let backendInfo = backend == 'wasm' ? 'WASM' : 'WebGL';
if (backendInfo == 'WASM') {
let backendInfo = tf.getBackend();
if (backendInfo == 'wasm') {
const hasSimd = tf.env().features['WASM_HAS_SIMD_SUPPORT'];
const hasThreads = tf.env().features['WASM_HAS_MULTITHREAD_SUPPORT'];
if (hasThreads && hasSimd) {
Expand Down Expand Up @@ -277,7 +273,7 @@ export function getUrlParams() {
}

// Set backend for using WebNN-polyfill or WebNN
export async function setBackend(backend, device) {
export async function setBackend(backend, device, polyfillBackend) {
const webnnPolyfillId = 'webnn_polyfill';
const webnnNodeId = 'webnn_node';
const webnnPolyfillElem = document.getElementById(webnnPolyfillId);
Expand All @@ -304,7 +300,7 @@ export async function setBackend(backend, device) {
// Create WebNN-polyfill script
await loadScript(webnnPolyfillUrl, webnnPolyfillId);
}
await setPolyfillBackend(device);
await setPolyfillBackend(polyfillBackend);
} else if (backend === 'webnn') {
// For Electron
if (isElectron()) {
Expand All @@ -327,7 +323,7 @@ export async function setBackend(backend, device) {
}
}
} else {
addAlert(`Unknow backend: ${backend}`, 'warning');
addAlert(`Unknown backend: ${backend}`, 'warning');
}
}

Expand Down
7 changes: 5 additions & 2 deletions face_recognition/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@
<div class="col-md-auto">
<div class="btn-group-toggle" data-toggle="buttons" id="backendBtns">
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
<input type="radio" name="backend" id="polyfill_cpu_wasm" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_gpu_webgl" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
22 changes: 10 additions & 12 deletions face_recognition/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,17 @@ let buildTime = 0;
let computeTime = 0;
let fdOutputs;
let frOutputs;
let deviceType = '';
let lastdeviceType = '';
let backend = '';
let lastBackend = '';
let deviceType = '';
const disabledSelectors = ['#tabs > li', '.btn'];

$(document).ready(async () => {
$('.icdisplay').hide();
if (await utils.isWebNN()) {
$('#webnn_cpu').click();
} else {
$('#polyfill_cpu').click();
$('#polyfill_cpu_wasm').click();
}
});

Expand Down Expand Up @@ -275,22 +274,21 @@ function constructNetObject(type) {
async function main() {
try {
if (fdModelName === '') return;
[backend, deviceType] =
$('input[name="backend"]:checked').attr('id').split('_');
backend = $('input[name="backend"]:checked').attr('id');
ui.handleClick(disabledSelectors, true);
if (isFirstTimeLoad) $('#hint').hide();
const [numRuns, powerPreference, numThreads] = utils.getUrlParams();
let start;
// Only do load() and build() when model first time loads,
// there's new model choosed, backend changed or device changed
if (isFirstTimeLoad || fdInstanceType !== fdModelName + layout ||
lastdeviceType != deviceType || lastBackend != backend) {
if (lastdeviceType != deviceType || lastBackend != backend) {
// Set backend and device
await utils.setBackend(backend, deviceType);
lastdeviceType = lastdeviceType != deviceType ?
deviceType : lastdeviceType;
lastBackend = lastBackend != backend ? backend : lastBackend;
lastBackend != backend) {
if (lastBackend != backend) {
let backendType;
let polyfillType;
[backendType, deviceType, polyfillType] = backend.split('_');
await utils.setBackend(backendType, deviceType, polyfillType);
lastBackend = backend;
}
if (frInstance !== null) {
// Call dispose() to and avoid memory leak
Expand Down
7 changes: 5 additions & 2 deletions facial_landmark_detection/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@
<div class="col-md-auto">
<div class="btn-group-toggle" data-toggle="buttons" id="backendBtns">
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
<input type="radio" name="backend" id="polyfill_cpu_wasm" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_gpu_webgl" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
22 changes: 10 additions & 12 deletions facial_landmark_detection/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,17 @@ let buildTime = 0;
let computeTime = 0;
let fdOutputs;
let fldOutputs;
let deviceType = '';
let lastdeviceType = '';
let backend = '';
let lastBackend = '';
let deviceType = '';
const disabledSelectors = ['#tabs > li', '.btn'];

$(document).ready(async () => {
$('.icdisplay').hide();
if (await utils.isWebNN()) {
$('#webnn_cpu').click();
} else {
$('#polyfill_cpu').click();
$('#polyfill_cpu_wasm').click();
}
});

Expand Down Expand Up @@ -210,22 +209,21 @@ function constructNetObject(type) {
async function main() {
try {
if (fdModelName === '') return;
[backend, deviceType] =
$('input[name="backend"]:checked').attr('id').split('_');
backend = $('input[name="backend"]:checked').attr('id');
ui.handleClick(disabledSelectors, true);
if (isFirstTimeLoad) $('#hint').hide();
const [numRuns, powerPreference, numThreads] = utils.getUrlParams();
let start;
// Only do load() and build() when model first time loads,
// there's new model choosed, backend changed or device changed
if (isFirstTimeLoad || fdInstanceType !== fdModelName + layout ||
lastdeviceType != deviceType || lastBackend != backend) {
if (lastdeviceType != deviceType || lastBackend != backend) {
// Set backend and device
await utils.setBackend(backend, deviceType);
lastdeviceType = lastdeviceType != deviceType ?
deviceType : lastdeviceType;
lastBackend = lastBackend != backend ? backend : lastBackend;
lastBackend != backend) {
if (lastBackend != backend) {
let backendType;
let polyfillType;
[backendType, deviceType, polyfillType] = backend.split('_');
await utils.setBackend(backendType, deviceType, polyfillType);
lastBackend = backend;
}
if (fldInstance !== null) {
// Call dispose() to and avoid memory leak
Expand Down
7 changes: 5 additions & 2 deletions image_classification/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@
<div class="col-md-auto">
<div class="btn-group-toggle" data-toggle="buttons" id="backendBtns">
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
<input type="radio" name="backend" id="polyfill_cpu_wasm_wasm" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_gpu_webgl" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
24 changes: 11 additions & 13 deletions image_classification/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@ let buildTime = 0;
let computeTime = 0;
let inputOptions;
let outputBuffer;
let deviceType = '';
let lastdeviceType = '';
let backend = '';
let lastBackend = '';
let deviceType = '';
const disabledSelectors = ['#tabs > li', '.btn'];

async function fetchLabels(url) {
Expand All @@ -45,7 +44,7 @@ $(document).ready(async () => {
if (await utils.isWebNN()) {
$('#webnn_cpu').click();
} else {
$('#polyfill_cpu').click();
$('#polyfill_cpu_wasm').click();
}
});

Expand Down Expand Up @@ -205,23 +204,22 @@ function constructNetObject(type) {
async function main() {
try {
if (modelName === '') return;
[backend, deviceType] =
$('input[name="backend"]:checked').attr('id').split('_');
backend = $('input[name="backend"]:checked').attr('id');
ui.handleClick(disabledSelectors, true);
if (isFirstTimeLoad) $('#hint').hide();
let start;
const [numRuns, powerPreference, numThreads] = utils.getUrlParams();

// Only do load() and build() when model first time loads,
// there's new model choosed, backend changed or device changed
// there's new model choosed or backend changed
if (isFirstTimeLoad || instanceType !== modelName + layout ||
lastdeviceType != deviceType || lastBackend != backend) {
if (lastdeviceType != deviceType || lastBackend != backend) {
// Set backend and device
await utils.setBackend(backend, deviceType);
lastdeviceType = lastdeviceType != deviceType ?
deviceType : lastdeviceType;
lastBackend = lastBackend != backend ? backend : lastBackend;
lastBackend != backend) {
if (lastBackend != backend) {
let backendType;
let polyfillType;
[backendType, deviceType, polyfillType] = backend.split('_');
await utils.setBackend(backendType, deviceType, polyfillType);
lastBackend = backend;
}
if (netInstance !== null) {
// Call dispose() to and avoid memory leak
Expand Down
7 changes: 5 additions & 2 deletions lenet/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@
<div class="btn-group-toggle" data-toggle="buttons" id="backendBtns">
<span class='mr-3'>Backend</span>
<label class="btn btn-outline-info" name="polyfill">
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
<input type="radio" name="backend" id="polyfill_cpu_wasm" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_gpu_webgl" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
4 changes: 2 additions & 2 deletions lenet/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ function clearResult() {
}

async function main() {
const [backend, deviceType] =
const [backend, deviceType, polyfillType] =
$('input[name="backend"]:checked').attr('id').split('_');
await utils.setBackend(backend, deviceType);
await utils.setBackend(backend, deviceType, polyfillType);
drawNextDigitFromMnist();
const pen = new Pen(visualCanvas);
const weightUrl = '../test-data/models/lenet_nchw/weights/lenet.bin';
Expand Down
7 changes: 5 additions & 2 deletions nsnet2/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@
<div class="btn-group-toggle" data-toggle="buttons" id="backendBtns">
<span class='mr-3'>Backend</span>
<label class="btn btn-outline-info" name="polyfill">
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
<input type="radio" name="backend" id="polyfill_cpu_wasm" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_gpu_webgl" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
4 changes: 2 additions & 2 deletions nsnet2/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ browseButton.onclick = () => {

export async function main() {
try {
const [backend, deviceType] =
let [backend, deviceType, polyfillType] =
$('input[name="backend"]:checked').attr('id').split('_');
await setBackend(backend, deviceType);
deviceType = await setBackend(backend, deviceType, polyfillType);
// Handle frames parameter.
const searchParams = new URLSearchParams(location.search);
let frames = parseInt(searchParams.get('frames'));
Expand Down
7 changes: 5 additions & 2 deletions object_detection/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@
<div class="col-md-auto">
<div class="btn-group-toggle" data-toggle="buttons" id="backendBtns">
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
<input type="radio" name="backend" id="polyfill_cpu_wasm" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_gpu_webgl" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
22 changes: 10 additions & 12 deletions object_detection/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ let buildTime = 0;
let computeTime = 0;
let inputOptions;
let outputs;
let deviceType = '';
let lastdeviceType = '';
let backend = '';
let lastBackend = '';
let deviceType = '';
const disabledSelectors = ['#tabs > li', '.btn'];

async function fetchLabels(url) {
Expand All @@ -43,7 +42,7 @@ $(document).ready(async () => {
if (await utils.isWebNN()) {
$('#webnn_cpu').click();
} else {
$('#polyfill_cpu').click();
$('#polyfill_cpu_wasm').click();
}
});

Expand Down Expand Up @@ -177,8 +176,7 @@ function constructNetObject(type) {
async function main() {
try {
if (modelName === '') return;
[backend, deviceType] =
$('input[name="backend"]:checked').attr('id').split('_');
backend = $('input[name="backend"]:checked').attr('id');
ui.handleClick(disabledSelectors, true);
if (isFirstTimeLoad) $('#hint').hide();
let start;
Expand All @@ -187,13 +185,13 @@ async function main() {
// Only do load() and build() when model first time loads,
// there's new model choosed, backend changed or device changed
if (isFirstTimeLoad || instanceType !== modelName + layout ||
lastdeviceType != deviceType || lastBackend != backend) {
if (lastdeviceType != deviceType || lastBackend != backend) {
// Set backend and device
await utils.setBackend(backend, deviceType);
lastdeviceType = lastdeviceType != deviceType ?
deviceType : lastdeviceType;
lastBackend = lastBackend != backend ? backend : lastBackend;
lastBackend != backend) {
if (lastBackend != backend) {
let backendType;
let polyfillType;
[backendType, deviceType, polyfillType] = backend.split('_');
await utils.setBackend(backendType, deviceType, polyfillType);
lastBackend = backend;
}
if (netInstance !== null) {
// Call dispose() to and avoid memory leak
Expand Down
Loading