Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for using the TensorFlow.js WebGPU backend #190

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions common/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,14 @@ export function getMedianValue(array) {
// Set tf.js backend based WebNN's 'MLDeviceType' option
export async function setPolyfillBackend(device) {
// Simulate WebNN's device selection using various tf.js backends.
// MLDeviceType: ['default', 'gpu', 'cpu']
// 'default' or 'gpu': tfjs-backend-webgl, 'cpu': tfjs-backend-wasm
if (!device) device = 'gpu';
// MLDeviceType: ['default', 'webgl', 'webgpu', 'cpu']
// 'default' or 'webgl': tfjs-backend-webgl, 'webgpu': tfjs-backend-webgpu,
// 'cpu': tfjs-backend-wasm
reillyeon marked this conversation as resolved.
Show resolved Hide resolved
if (!device) device = 'webgpu';
// Use 'webgl' by default for better performance.
// Note: 'wasm' backend may run failed on some samples since
// some ops aren't supported on 'wasm' backend at present
const backend = device === 'cpu' ? 'wasm' : 'webgl';
const backend = device === 'cpu' ? 'wasm' : device;
const context = await navigator.ml.createContext();
const tf = context.tf;
if (tf) {
Expand All @@ -221,8 +222,8 @@ export async function setPolyfillBackend(device) {
throw new Error(`Failed to set tf.js backend ${backend}.`);
}
await tf.ready();
let backendInfo = backend == 'wasm' ? 'WASM' : 'WebGL';
if (backendInfo == 'WASM') {
let backendInfo = tf.getBackend();
if (backendInfo == 'wasm') {
const hasSimd = tf.env().features['WASM_HAS_SIMD_SUPPORT'];
const hasThreads = tf.env().features['WASM_HAS_MULTITHREAD_SUPPORT'];
if (hasThreads && hasSimd) {
Expand All @@ -239,6 +240,13 @@ export async function setPolyfillBackend(device) {
`WebNN-polyfill</a> with tf.js ${tf.version_core} ` +
`<b>${backendInfo}</b> backend.`, 'info');
}
switch (device) {
case 'webgl':
case 'webgpu':
return 'gpu';
default:
return 'cpu';
}
}

// Get url params
Expand Down Expand Up @@ -304,7 +312,7 @@ export async function setBackend(backend, device) {
// Create WebNN-polyfill script
await loadScript(webnnPolyfillUrl, webnnPolyfillId);
}
await setPolyfillBackend(device);
return await setPolyfillBackend(device);
} else if (backend === 'webnn') {
// For Electron
if (isElectron()) {
Expand All @@ -326,8 +334,9 @@ export async function setBackend(backend, device) {
addAlert(`WebNN is not supported!`, 'warning');
}
}
return device;
} else {
addAlert(`Unknow backend: ${backend}`, 'warning');
addAlert(`Unknown backend: ${backend}`, 'warning');
}
}

Expand Down
5 changes: 4 additions & 1 deletion face_recognition/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
5 changes: 4 additions & 1 deletion facial_landmark_detection/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
reillyeon marked this conversation as resolved.
Show resolved Hide resolved
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
5 changes: 4 additions & 1 deletion image_classification/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
5 changes: 4 additions & 1 deletion lenet/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
5 changes: 4 additions & 1 deletion nsnet2/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
4 changes: 2 additions & 2 deletions nsnet2/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ browseButton.onclick = () => {

export async function main() {
try {
const [backend, deviceType] =
let [backend, deviceType] =
$('input[name="backend"]:checked').attr('id').split('_');
await setBackend(backend, deviceType);
deviceType = await setBackend(backend, deviceType);
// Handle frames parameter.
const searchParams = new URLSearchParams(location.search);
let frames = parseInt(searchParams.get('frames'));
Expand Down
5 changes: 4 additions & 1 deletion object_detection/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
5 changes: 4 additions & 1 deletion rnnoise/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
5 changes: 3 additions & 2 deletions rnnoise/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,14 @@ export async function main() {
try {
const [backend, deviceType] =
$('input[name="backend"]:checked').attr('id').split('_');
await utils.setBackend(backend, deviceType);
const contextOptions = {};
contextOptions['deviceType'] =
await utils.setBackend(backend, deviceType);
modelInfo.innerHTML = '';
await log(modelInfo, `Creating RNNoise with input shape ` +
`[${batchSize} (batch_size) x 100 (frames) x 42].`, true);
await log(modelInfo, '- Loading model...');
const powerPreference = utils.getUrlParams()[1];
const contextOptions = {deviceType};
if (powerPreference) {
contextOptions['powerPreference'] = powerPreference;
}
Expand Down
5 changes: 4 additions & 1 deletion semantic_segmentation/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down
5 changes: 4 additions & 1 deletion style_transfer/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
<input type="radio" name="backend" id="polyfill_cpu" autocomplete="off">Wasm (CPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_gpu" autocomplete="off">WebGL (GPU)
<input type="radio" name="backend" id="polyfill_webgl_gpu" autocomplete="off">WebGL (GPU)
</label>
<label class="btn btn-outline-info custom" name="polyfill">
<input type="radio" name="backend" id="polyfill_webgpu" autocomplete="off">WebGPU (GPU)
</label>
<label class="btn btn-outline-info custom" name="webnn">
<input type="radio" name="backend" id="webnn_cpu" autocomplete="off">WebNN (CPU)
Expand Down