Skip to content

Commit

Permalink
[js/web] Update API for ort.env.webgpu (#23026)
Browse files Browse the repository at this point in the history
### Description

This PR is a replacement of #21671. It offers a new way for accessing
the following:
- `ort.env.webgpu.adapter`:
- **deprecating**. There is no point to get the value of it. Once
`GPUDevice.adapterInfo` is supported, there is no point to set the value
too.
- `ort.env.webgpu.device`:
  - set value of `GPUDevice` if user created it. Use at user's own risk.
- get value of `Promise<GPUDevice>`. if not exist, create a new one. if
exist return it.
- `ort.env.webgpu.powerPreference`:
- **deprecating**. encouraging users to set `ort.env.webgpu.device` if
necessary.
- `ort.env.webgpu.forceFallbackAdapter`:
- **deprecating**. encouraging users to set `ort.env.webgpu.device` if
necessary.
  • Loading branch information
fs-eire authored Dec 11, 2024
1 parent 8800830 commit e605870
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
33 changes: 25 additions & 8 deletions js/common/lib/env.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,19 @@ export declare namespace Env {
*
* This setting is available only when WebAssembly SIMD feature is available in current context.
*
* @defaultValue `true`
*
* @deprecated This property is deprecated. Since SIMD is supported by all major JavaScript engines, non-SIMD
* build is no longer provided. This property will be removed in future release.
* @defaultValue `true`
*/
simd?: boolean;

/**
* set or get a boolean value indicating whether to enable trace.
*
* @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored.
* @defaultValue `false`
*
* @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored.
*/
trace?: boolean;

Expand Down Expand Up @@ -153,7 +155,7 @@ export declare namespace Env {
/**
* Set or get the profiling configuration.
*/
profiling?: {
profiling: {
/**
* Set or get the profiling mode.
*
Expand All @@ -176,6 +178,9 @@ export declare namespace Env {
* See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details.
*
* @defaultValue `undefined`
*
* @deprecated Create your own GPUAdapter, use it to create a GPUDevice instance and set {@link device} property if
* you want to use a specific power preference.
*/
powerPreference?: 'low-power' | 'high-performance';
/**
Expand All @@ -187,6 +192,9 @@ export declare namespace Env {
* See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details.
*
* @defaultValue `undefined`
*
* @deprecated Create your own GPUAdapter, use it to create a GPUDevice instance and set {@link device} property if
* you want to use a specific fallback option.
*/
forceFallbackAdapter?: boolean;
/**
Expand All @@ -199,16 +207,25 @@ export declare namespace Env {
* value will be the GPU adapter that created by the underlying WebGPU backend.
*
* When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types".
*
* @deprecated It is no longer recommended to use this property. The latest WebGPU spec adds `GPUDevice.adapterInfo`
* (https://www.w3.org/TR/webgpu/#dom-gpudevice-adapterinfo), which allows to get the adapter information from the
* device. When it's available, there is no need to set/get the {@link adapter} property.
*/
adapter: TryGetGlobalType<'GPUAdapter'>;
/**
* Get the device for WebGPU.
*
* This property is only available after the first WebGPU inference session is created.
* Set or get the GPU device for WebGPU.
*
* When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types".
* There are 3 valid scenarios of accessing this property:
* - Set a value before the first WebGPU inference session is created. The value will be used by the WebGPU backend
* to perform calculations. If the value is not a `GPUDevice` object, an error will be thrown.
* - Get the value before the first WebGPU inference session is created. This will try to create a new GPUDevice
* instance. Returns a `Promise` that resolves to a `GPUDevice` object.
* - Get the value after the first WebGPU inference session is created. Returns a resolved `Promise` to the
* `GPUDevice` object used by the WebGPU backend.
*/
readonly device: TryGetGlobalType<'GPUDevice'>;
get device(): Promise<TryGetGlobalType<'GPUDevice'>>;
set device(value: TryGetGlobalType<'GPUDevice'>);
/**
* Set or get whether validate input content.
*
Expand Down
12 changes: 6 additions & 6 deletions js/web/test/test-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -586,11 +586,11 @@ export class TensorResultValidator {
}
}

function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor {
async function createGpuTensorForInput(cpuTensor: ort.Tensor): Promise<ort.Tensor> {
if (!isGpuBufferSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) {
throw new Error(`createGpuTensorForInput can not work with ${cpuTensor.type} tensor`);
}
const device = ort.env.webgpu.device as GPUDevice;
const device = await ort.env.webgpu.device;
const gpuBuffer = device.createBuffer({
// eslint-disable-next-line no-bitwise
usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
Expand All @@ -612,14 +612,14 @@ function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor {
});
}

function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) {
async function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) {
if (!isGpuBufferSupportedType(type)) {
throw new Error(`createGpuTensorForOutput can not work with ${type} tensor`);
}

const size = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(type), dims)!;

const device = ort.env.webgpu.device as GPUDevice;
const device = await ort.env.webgpu.device;
const gpuBuffer = device.createBuffer({
// eslint-disable-next-line no-bitwise
usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
Expand Down Expand Up @@ -725,7 +725,7 @@ export async function sessionRun(options: {
if (options.ioBinding === 'ml-location' || options.ioBinding === 'ml-tensor') {
feeds[name] = await createMLTensorForInput(options.mlContext!, feeds[name]);
} else {
feeds[name] = createGpuTensorForInput(feeds[name]);
feeds[name] = await createGpuTensorForInput(feeds[name]);
}
}
}
Expand All @@ -742,7 +742,7 @@ export async function sessionRun(options: {
if (options.ioBinding === 'ml-tensor') {
fetches[name] = await createMLTensorForOutput(options.mlContext!, type, dims);
} else {
fetches[name] = createGpuTensorForOutput(type, dims);
fetches[name] = await createGpuTensorForOutput(type, dims);
}
}
}
Expand Down

0 comments on commit e605870

Please sign in to comment.