From 7efa458360c62e7e2958a7a4464846c9ff7f3399 Mon Sep 17 00:00:00 2001 From: Reilly Grant Date: Fri, 10 Nov 2023 16:57:38 -0800 Subject: [PATCH 1/2] Add support for using the TensorFlow.js WebGPU backend --- common/utils.js | 25 +++++++++++++++++-------- face_recognition/index.html | 5 ++++- facial_landmark_detection/index.html | 5 ++++- image_classification/index.html | 5 ++++- lenet/index.html | 5 ++++- nsnet2/index.html | 5 ++++- nsnet2/main.js | 4 ++-- object_detection/index.html | 5 ++++- rnnoise/index.html | 5 ++++- rnnoise/main.js | 5 +++-- semantic_segmentation/index.html | 5 ++++- style_transfer/index.html | 5 ++++- 12 files changed, 58 insertions(+), 21 deletions(-) diff --git a/common/utils.js b/common/utils.js index c28185fa..89fd13f2 100644 --- a/common/utils.js +++ b/common/utils.js @@ -202,13 +202,14 @@ export function getMedianValue(array) { // Set tf.js backend based WebNN's 'MLDeviceType' option export async function setPolyfillBackend(device) { // Simulate WebNN's device selection using various tf.js backends. - // MLDeviceType: ['default', 'gpu', 'cpu'] - // 'default' or 'gpu': tfjs-backend-webgl, 'cpu': tfjs-backend-wasm - if (!device) device = 'gpu'; + // MLDeviceType: ['default', 'webgl', 'webgpu', 'cpu'] + // 'default' or 'webgl': tfjs-backend-webgl, 'webgpu': tfjs-backend-webgpu, + // 'cpu': tfjs-backend-wasm + if (!device) device = 'webgpu'; // Use 'webgl' by default for better performance. // Note: 'wasm' backend may run failed on some samples since // some ops aren't supported on 'wasm' backend at present - const backend = device === 'cpu' ? 'wasm' : 'webgl'; + const backend = device === 'cpu' ? 'wasm' : device; const context = await navigator.ml.createContext(); const tf = context.tf; if (tf) { @@ -221,8 +222,8 @@ export async function setPolyfillBackend(device) { throw new Error(`Failed to set tf.js backend ${backend}.`); } await tf.ready(); - let backendInfo = backend == 'wasm' ? 'WASM' : 'WebGL'; - if (backendInfo == 'WASM') { + let backendInfo = tf.getBackend(); + if (backendInfo == 'wasm') { const hasSimd = tf.env().features['WASM_HAS_SIMD_SUPPORT']; const hasThreads = tf.env().features['WASM_HAS_MULTITHREAD_SUPPORT']; if (hasThreads && hasSimd) { @@ -239,6 +240,13 @@ export async function setPolyfillBackend(device) { `WebNN-polyfill with tf.js ${tf.version_core} ` + `${backendInfo} backend.`, 'info'); } + switch (device) { + case 'webgl': + case 'webgpu': + return 'gpu'; + default: + return 'cpu'; + } } // Get url params @@ -304,7 +312,7 @@ export async function setBackend(backend, device) { // Create WebNN-polyfill script await loadScript(webnnPolyfillUrl, webnnPolyfillId); } - await setPolyfillBackend(device); + return await setPolyfillBackend(device); } else if (backend === 'webnn') { // For Electron if (isElectron()) { @@ -326,8 +334,9 @@ export async function setBackend(backend, device) { addAlert(`WebNN is not supported!`, 'warning'); } } + return device; } else { - addAlert(`Unknow backend: ${backend}`, 'warning'); + addAlert(`Unknown backend: ${backend}`, 'warning'); } } diff --git a/face_recognition/index.html b/face_recognition/index.html index dd8db030..28cd18b5 100644 --- a/face_recognition/index.html +++ b/face_recognition/index.html @@ -35,7 +35,10 @@ Wasm (CPU) + + + + + + + + +