From af18824f430225330ca58dd9a5a08bdc57900955 Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Mon, 26 Aug 2024 14:46:04 -0700 Subject: [PATCH] [JS/WebGPU] Add GatherBlockQuantized op support (#21734) ### Description Add GatherBlockQuantized operator to JSEP. ### Motivation and Context Gemma model requires this. --- js/common/lib/tensor-impl.ts | 4 +- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + js/web/lib/wasm/jsep/webgpu/ops/common.ts | 5 +- .../jsep/webgpu/ops/gather-block-quantized.ts | 196 +++++++++++++ .../jsep/webgpu/ops/multihead-attention.ts | 18 +- js/web/lib/wasm/wasm-common.ts | 4 +- js/web/script/generate-webgpu-operator-md.ts | 1 + ...nt4.jsonc => dequantize-linear-int4.jsonc} | 0 .../data/ops/gather-block-quantized.jsonc | 257 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 3 + .../contrib_ops/js/js_contrib_kernels.cc | 13 +- .../js/quantization/gather_block_quantized.cc | 31 +++ .../js/quantization/gather_block_quantized.h | 46 ++++ 14 files changed, 567 insertions(+), 14 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/gather-block-quantized.ts rename js/web/test/data/ops/{dequantize-linear_int4.jsonc => dequantize-linear-int4.jsonc} (100%) create mode 100644 js/web/test/data/ops/gather-block-quantized.jsonc create mode 100644 onnxruntime/contrib_ops/js/quantization/gather_block_quantized.cc create mode 100644 onnxruntime/contrib_ops/js/quantization/gather_block_quantized.h diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index 12c6d79d88d2b..4e0ef821dde57 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -140,7 +140,9 @@ export class Tensor implements TensorInterface { type !== 'int64' && type !== 'uint32' && type !== 'uint8' && - type !== 'bool' + type !== 'bool' && + type !== 'uint4' && + type !== 'int4' ) { throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`); } diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index cf21fe8ed117d..425d479ad305e 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -48,6 +48,7 @@ Do not modify directly.* | Floor | ai.onnx(6-12,13+) | | | FusedConv | com.microsoft(1+) | | | Gather | ai.onnx(1-10,11-12,13+) | | +| GatherBlockQuantized | com.microsoft(1+) | | | GatherElements | ai.onnx(11-12,13+) | | | Gelu | ai.onnx(20+); com.microsoft(1+) | | | Gemm | ai.onnx(7-8,9-10,11-12,13+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 0808d45a307ca..fe824a5c4558a 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -16,6 +16,7 @@ import { einsum, parseEinsumAttributes } from './ops/einsum'; import { expand } from './ops/expand'; import { fastGelu } from './ops/fast-gelu'; import { gather, parseGatherAttributes } from './ops/gather'; +import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized'; import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements'; import { gemm, parseGemmAttributes } from './ops/gemm'; import { groupQueryAttention, parseGroupQueryAttentionAttributes } from './ops/group-query-attention'; @@ -96,6 +97,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['FusedConv', [conv, parseConvAttributes]], ['Gather', [gather, parseGatherAttributes]], ['GatherElements', [gatherElements, 
parseGatherElementsAttributes]], + ['GatherBlockQuantized', [gatherBlockQuantized, parseGatherBlockQuantizedAttributes]], ['Gelu', [unaryOps.gelu]], ['Gemm', [gemm, parseGemmAttributes]], ['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 7696f22d44abd..65e54414e957e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -243,7 +243,10 @@ const getWgslMappedType = (type: number, components: 1 | 2 | 3 | 4): string | [s throw new Error('bool must be vec4'); } return ['u32', 'vec4']; - + case DataType.int4: + return 'i32'; + case DataType.uint4: + return 'u32'; default: throw new Error(`Unknown data type: ${type}`); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-block-quantized.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-block-quantized.ts new file mode 100644 index 0000000000000..f0f1f28342936 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-block-quantized.ts @@ -0,0 +1,196 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; + +import { + createTensorShapeVariables, + inputVariable, + outputVariable, + ShaderHelper, + tensorTypeToWsglValueType, + UniformsArrayType, +} from './common'; + +export interface GatherBlockQuantizedAttributes extends AttributeWithCacheKey { + gatherAxis: number; + quantizeAxis: number; + blockSize: number; +} + +export const validateInputs = (inputs: readonly TensorView[], attributes: GatherBlockQuantizedAttributes): void => { + if (inputs.length < 3 || inputs.length > 4) { + throw new Error('GatherBlockQuantized requires 3 or 4 inputs.'); + } + const quantizeAxis = ShapeUtil.normalizeAxis(attributes.quantizeAxis, inputs[0].dims.length); + const blockSize = attributes.blockSize; + const data = inputs[0]; + const scales = inputs[2]; + const zeroPoint = inputs.length === 4 ? inputs[3] : undefined; + if ( + scales.dims.length !== data.dims.length || + !data.dims + .map((d, i) => (i === quantizeAxis ? Math.ceil(d / blockSize) === scales.dims[i] : d === scales.dims[i])) + .reduce((a, b) => a && b, true) + ) { + throw new Error( + 'Scales must have the same rank as the input tensor and the dims should match except on gatherAxis.', + ); + } + // TODO Uncomment the following check once the test case creation code is fixed to create data correctly aligned. 
+ // const indices = inputs[1]; + // const validIndex = (index: number) => index >= 0 && index < data.dims[attributes.gatherAxis]; + // if (indices.dataType === DataType.int32 && indices.getInt32Array().some((v) => !validIndex(v)) || + // indices.dataType === DataType.int64 && indices.getBigInt64Array().some((v) => !validIndex(Number(v)))) { + // throw new Error('Indices must be within the bounds of the gatherAxis.'); + // } + if (zeroPoint) { + if (zeroPoint.dataType !== data.dataType) { + throw new Error('Zero point must have the same data type as the input tensor.'); + } + if ( + zeroPoint.dims.length !== scales.dims.length || + !zeroPoint.dims.map((d, i) => d === scales.dims[i]).reduce((a, b) => a && b, true) + ) { + throw new Error( + 'Zero point must have the same rank as the input tensor and the dims should match except on quantizeAxis.', + ); + } + } +}; + +const createGatherBlockQuantizedProgramInfo = ( + inputs: readonly TensorView[], + attributes: GatherBlockQuantizedAttributes, +): ProgramInfo => { + const inputShape = inputs[0].dims; + const indicesShape = inputs[1].dims; + const inputRank = inputShape.length; + const gatherAxis = ShapeUtil.normalizeAxis(attributes.gatherAxis, inputRank); + const quantizeAxis = ShapeUtil.normalizeAxis(attributes.quantizeAxis, inputRank); + const outputShape = inputShape.slice(0); + outputShape.splice(gatherAxis, 1, ...indicesShape); + const outputSize = ShapeUtil.size(outputShape); + const outputType = inputs[2].dataType; + const inputType = inputs[0].dataType; + const isSigned = inputType === DataType.int4; // input data type is either int4 or uint4. + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: quantizeAxis }, + { type: DataType.uint32, data: gatherAxis }, + { type: DataType.uint32, data: attributes.blockSize }, + ...createTensorShapeVariables(...inputs.map((input, _) => input.dims), outputShape), + ]; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const data = inputVariable('data', inputs[0].dataType, inputs[0].dims.length); + const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims.length); + const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length); + const zeroPoint = + inputs.length > 3 ? 
inputVariable('zeroPoint', inputs[3].dataType, inputs[3].dims.length) : undefined; + const output = outputVariable('output', outputType, outputShape.length); + const inputVariables = [data, indices, scales]; + if (zeroPoint) { + inputVariables.push(zeroPoint); + } + const uniforms: UniformsArrayType = [ + { name: 'output_size', type: 'u32' }, + { name: 'quantize_axis', type: 'u32' }, + { name: 'gather_axis', type: 'u32' }, + { name: 'block_size', type: 'u32' }, + ]; + return ` + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart()} + let output_indices = ${output.offsetToIndices('global_idx')}; + var indices_indices = ${indices.type.indices}(0); + ${(() => { + if (indicesShape.length > 1) { + return ` + for (var i: u32 = 0; i < ${indicesShape.length}; i++) { + let index = ${output.indicesGet('output_indices', 'uniforms.gather_axis + i')}; + ${indices.indicesSet('indices_indices', 'i', 'index')}; + }`; + } else { + return `indices_indices = ${output.indicesGet('output_indices', 'uniforms.gather_axis')};`; + } + })()}; + var data_indices = ${data.type.indices}(0); + for (var i: u32 = 0; i < uniforms.gather_axis; i++) { + let index = ${output.indicesGet('output_indices', 'i')}; + ${data.indicesSet('data_indices', 'i', 'index')}; + } + var index_from_indices = ${indices.getByIndices('indices_indices')}; + if (index_from_indices < 0) { + index_from_indices += ${inputShape[gatherAxis]}; + } + ${data.indicesSet('data_indices', 'uniforms.gather_axis', 'u32(index_from_indices)')}; + for (var i = uniforms.gather_axis + 1; i < ${outputShape.length}; i++) { + let index = ${output.indicesGet('output_indices', `i + ${indicesShape.length} - 1`)}; + ${data.indicesSet('data_indices', 'i', 'index')}; + } + let data_offset = ${data.indicesToOffset('data_indices')}; + let data_index = data_offset % 8; + // Convert 4-bit packed data to 8-bit packed data. + let packed_4bit_quantized_data = ${data.getByOffset('data_offset / 8')}; + let packed_8bit_quantized_data = (packed_4bit_quantized_data >> (4 * (data_index % 2))) & 0x0f0f0f0f; + let quantized_data_vec = ${isSigned ? 'unpack4xI8' : 'unpack4xU8'}(u32(packed_8bit_quantized_data)); + let quantized_data = quantized_data_vec[data_index / 2]; + var scale_indices = data_indices; + let quantize_axis_index = ${scales.indicesGet('data_indices', 'uniforms.quantize_axis')} / uniforms.block_size; + ${scales.indicesSet('scale_indices', 'uniforms.quantize_axis', 'quantize_axis_index')}; + var scale = ${scales.getByIndices('scale_indices')}; + ${(() => { + if (!zeroPoint) { + return 'var zero_point = 0'; + } else { + return ` + let zero_point_indices = scale_indices; + let zero_point_offset = ${zeroPoint.indicesToOffset('zero_point_indices')}; + let zero_point_index = zero_point_offset % 8; + let packed_4bit_zero_points = ${zeroPoint.getByOffset('zero_point_offset / 8')}; + let packed_8bit_zero_points = (packed_4bit_zero_points >> (4 * (zero_point_index % 2))) & 0x0f0f0f0f; + let zero_point_vec = ${isSigned ? 
'unpack4xI8' : 'unpack4xU8'}(u32(packed_8bit_zero_points)); + let zero_point = zero_point_vec[zero_point_index / 2];`; + } + })()}; + let dequantized_data = ${tensorTypeToWsglValueType(outputType)}(quantized_data - zero_point) * scale; + ${output.setByOffset('global_idx', 'dequantized_data')}; + }`; + }; + return { + name: 'GatherBlockQuantized', + shaderCache: { + hint: `${attributes.cacheKey};${inputs + .filter((_, i) => i !== 1) + .map((input) => input.dims.join('_')) + .join(';')}`, + inputDependencies: Array.from({ length: inputs.length }, (_v, _i) => 'rank'), + }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: outputType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }; +}; + +export const gatherBlockQuantized = (context: ComputeContext, attributes: GatherBlockQuantizedAttributes): void => { + const inputs = context.inputs; + validateInputs(inputs, attributes); + context.compute(createGatherBlockQuantizedProgramInfo(context.inputs, attributes)); +}; + +export const parseGatherBlockQuantizedAttributes = ( + attributes: Record, +): GatherBlockQuantizedAttributes => + createAttributeWithCacheKey({ + blockSize: attributes.blockSize as number, + gatherAxis: attributes.gatherAxis as number, + quantizeAxis: attributes.quantizeAxis as number, + }); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts index 485ebec9847fd..0949d65174b41 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts @@ -85,7 +85,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr let pastSequenceLength = 0; let maxSequenceLength = 0; const headSize = Math.floor(hiddenSize / attributes.numHeads); - if (pastKey && pastValue) { + if (pastKey && pastValue && ShapeUtil.size(pastKey.dims) && ShapeUtil.size(pastValue.dims)) { if (pastKey.dims.length !== 4) { throw new Error('Input "past_key" is expected to have 4 dimensions'); } @@ -107,12 +107,12 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr } pastSequenceLength = pastKey.dims[2]; maxSequenceLength = pastKey.dims[2]; - } else if (pastKey || pastValue) { + } else if ((pastKey && ShapeUtil.size(pastKey.dims)) || (pastValue && ShapeUtil.size(pastValue.dims))) { throw new Error('Input "past_key" and "past_value" shall be both present or both absent'); } let qkvFormat: AttentionQkvFormat; - if (key) { + if (key && ShapeUtil.size(key.dims) > 0) { if (query.dims.length !== 3) { throw new Error('Input "query" is expected to have 3 dimensions when key is given'); } @@ -159,7 +159,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr qkvFormat = AttentionQkvFormat.qkvBSN3H; } - if (bias) { + if (bias && ShapeUtil.size(bias.dims) > 0) { if (bias.dims.length !== 1) { throw new Error('Input "bias" is expected to have 1 dimension'); } @@ -174,7 +174,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr const totalSequenceLength = pastSequenceLength + kvSequenceLength; let maskType: AttentionMaskType = AttentionMaskType.none; - if (keyPaddingMask) { + if (keyPaddingMask && ShapeUtil.size(keyPaddingMask.dims) > 0) { maskType = AttentionMaskType.maskUnknown; const maskDims = keyPaddingMask.dims; if (maskDims.length === 1) { @@ -194,7 +194,7 @@ const validateInputs = (inputs: readonly TensorView[], attributes: 
AttentionAttr let passPastInKv = false; let vHiddenSize = hiddenSize; - if (value) { + if (value && ShapeUtil.size(value.dims) > 0) { if (value.dims.length !== 3 && value.dims.length !== 4) { throw new Error('Input "value" is expected to have 3 or 4 dimensions'); } @@ -220,11 +220,11 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr const broadcastResPosBias = false; - if (keyPaddingMask) { + if (keyPaddingMask && ShapeUtil.size(keyPaddingMask.dims) > 0) { throw new Error('Key padding mask is not supported'); } - if (attentionBias) { + if (attentionBias && ShapeUtil.size(attentionBias.dims) > 0) { if (attentionBias.dims.length !== 4) { throw new Error('Input "attention_bias" is expected to have 4 dimensions'); } @@ -334,7 +334,7 @@ export const maybeTransposeToBNSHAndAddBias = ( // const newDims = []; let reshapedInput = input; - if (!bias) { + if (!(bias && ShapeUtil.size(bias.dims) > 0)) { if (input.dims.length === 3) { reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]); } diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index fd5d93675154c..78ff14540d8cb 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -236,7 +236,9 @@ export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuB type === 'int64' || type === 'uint32' || type === 'uint8' || - type === 'bool'; + type === 'bool' || + type === 'uint4' || + type === 'int4'; /** * Map string data location to integer value diff --git a/js/web/script/generate-webgpu-operator-md.ts b/js/web/script/generate-webgpu-operator-md.ts index 5e9a7152bf185..26640749defc2 100644 --- a/js/web/script/generate-webgpu-operator-md.ts +++ b/js/web/script/generate-webgpu-operator-md.ts @@ -26,6 +26,7 @@ const MATCHERS = [ /class ONNX_OPERATOR_KERNEL_CLASS_NAME\(\s*(?\w+),\s*(?\w+),\s*(?\d+),\s*(?\w+)\)/g, /class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME\(\s*(?\w+),\s*(?\w+),\s*(?\d+),\s*(?\d+),\s*(?\w+),\s*(?\w+)\)/g, /class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME\(\s*(?\w+),\s*(?\w+),\s*(?\d+),\s*(?\w+),\s*(?\w+)\)/g, + /class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME\(\s*(?\w+),\s*(?\w+),\s*(?\d+),\s*(?\w+),\s*(?\w+),\s*(?\w+)\)/g, ]; /* eslint-enable max-len */ diff --git a/js/web/test/data/ops/dequantize-linear_int4.jsonc b/js/web/test/data/ops/dequantize-linear-int4.jsonc similarity index 100% rename from js/web/test/data/ops/dequantize-linear_int4.jsonc rename to js/web/test/data/ops/dequantize-linear-int4.jsonc diff --git a/js/web/test/data/ops/gather-block-quantized.jsonc b/js/web/test/data/ops/gather-block-quantized.jsonc new file mode 100644 index 0000000000000..a6a346ab42a39 --- /dev/null +++ b/js/web/test/data/ops/gather-block-quantized.jsonc @@ -0,0 +1,257 @@ +[ + { + "name": "GatherBlockQuantized; quantize_axis=0, gather_axis=1, signed input, block_size=16", + "operator": "GatherBlockQuantized", + "opset": { + "domain": "com.microsoft", + "version": 1 + }, + "attributes": [ + { + "name": "block_size", + "data": 16, + "type": "int" + }, + { + "name": "gather_axis", + "data": 0, + "type": "int" + }, + { + "name": "quantize_axis", + "data": 2, + "type": "int" + } + ], + "cases": [ + { + "name": "GatherBlockQuantized; quantize_axis=0, gather_axis=1, block_size=16, signed input", + "inputs": [ + // data + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 7, 0, 1, 2, 3, 4, 5, 6, 6, 7, 0, 1, 2, 3, 4, 5, 5, 6, 7, 0, 1, 2, 3, 4, 4, 5, 6, + 7, 0, 1, 2, 3, 3, 4, 5, 6, 7, 0, 1, 2, 2, 3, 4, 5, 6, 7, 0, 1, 1, 2, 3, 4, 5, 6, 7, 
0, 0, 1, 2, 3, 4, 5, + 6, 7, 7, 0, 1, 2, 3, 4, 5, 6, 6, 7, 0, 1, 2, 3, 4, 5, 5, 6, 7, 0, 1, 2, 3, 4 + ], + "dims": [2, 3, 16], + "type": "int4" + }, + // indices + { + "data": [1], + "dims": [1, 1, 1, 1], + "type": "int32" + }, + // scale + { + "data": [1.0, 2.0, 1.0, 2.0, 1.0, 2.0], + "dims": [2, 3, 1], + "type": "float32" + }, + // zero + { + "data": [1, 1, 0, 0, 1, -1], + "dims": [2, 3, 1], + "type": "int4" + } + ], + "outputs": [ + { + "data": [ + 4, 6, 8, 10, 12, 14, 0, 2, 2, 4, 6, 8, 10, 12, 14, 0, -1, 0, 1, 2, 3, 4, 5, 6, 6, -1, 0, 1, 2, 3, 4, 5, + 14, 16, 2, 4, 6, 8, 10, 12, 12, 14, 16, 2, 4, 6, 8, 10 + ], + "dims": [1, 1, 1, 1, 3, 16], + "type": "float32" + } + ] + }, + { + "name": "GatherBlockQuantized; quantize_axis=0, gather_axis=1, signed block_size=16, signed input, negative indices", + "inputs": [ + // data + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 7, 0, 1, 2, 3, 4, 5, 6, 6, 7, 0, 1, 2, 3, 4, 5, 5, 6, 7, 0, 1, 2, 3, 4, 4, 5, 6, + 7, 0, 1, 2, 3, 3, 4, 5, 6, 7, 0, 1, 2, 2, 3, 4, 5, 6, 7, 0, 1, 1, 2, 3, 4, 5, 6, 7, 0, 0, 1, 2, 3, 4, 5, + 6, 7, 7, 0, 1, 2, 3, 4, 5, 6, 6, 7, 0, 1, 2, 3, 4, 5, 5, 6, 7, 0, 1, 2, 3, 4 + ], + "dims": [2, 3, 16], + "type": "int4" + }, + // indices + { + "data": [-1], + "dims": [1], + "type": "int32" + }, + // scale + { + "data": [0.5, 1.0, 1.25, 1.5, 1.75, 2.0], + "dims": [2, 3, 1], + "type": "float32" + }, + // zero + { + "data": [0, 1, 2, 3, 4, 5], + "dims": [2, 3, 1], + "type": "int4" + } + ], + "outputs": [ + { + "data": [ + -1.5, 0, 1.5, 3, 4.5, 6, -4.5, -3, -3, -1.5, 0, 1.5, 3, 4.5, 6, -4.5, -7, -5.25, -3.5, -1.75, 0, 1.75, + 3.5, 5.25, 5.25, -7, -5.25, -3.5, -1.75, 0, 1.75, 3.5, 2, 4, -10, -8, -6, -4, -2, 0, 0, 2, 4, -10, -8, -6, + -4, -2 + ], + "dims": [1, 3, 16], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GatherBlockQuantized; quantize_axis=0, gather_axis=1, unsigned input, block_size=16", + "operator": "GatherBlockQuantized", + "opset": { + "domain": "com.microsoft", + "version": 1 + }, + "attributes": [ + { + "name": "block_size", + "data": 16, + "type": "int" + }, + { + "name": "gather_axis", + "data": 0, + "type": "int" + }, + { + "name": "quantize_axis", + "data": 2, + "type": "int" + } + ], + "cases": [ + { + "name": "GatherBlockQuantized; quantize_axis=0, gather_axis=1, block_size=16, unsigned input", + "inputs": [ + // data + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10 + ], + "dims": [2, 3, 16], + "type": "uint4" + }, + // indices + { + "data": [1], + "dims": [1], + "type": "int32" + }, + // scale + { + "data": [1.0, 2.0, 1.0, 2.0, 1.0, 2.0], + "dims": [2, 3, 1], + "type": "float32" + }, + // zero + { + "data": [1, 1, 0, 0, 1, 1], + "dims": [2, 3, 1], + "type": "uint4" + } + ], + "outputs": [ + { + "data": [ + 26, 28, 30, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 11, 12, 13, 14, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 20, 22, 24, 26, 28, -2, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18 + ], + "dims": [1, 3, 16], + "type": "float32" + } + ] + } + ] + }, + { + "name": "GatherBlockQuantized; quantize_axis=0, gather_axis=1, signed block_size=16", + "operator": "GatherBlockQuantized", + "opset": { + "domain": "com.microsoft", + "version": 1 + }, + "attributes": [ + { + "name": "block_size", + "data": 16, + "type": "int" + }, + { + 
"name": "gather_axis", + "data": 0, + "type": "int" + }, + { + "name": "quantize_axis", + "data": 2, + "type": "int" + } + ], + "cases": [ + { + "name": "GatherBlockQuantized; quantize_axis=0, gather_axis=1, signed block_size=16, signed input; indices dim > 1", + "inputs": [ + // data + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 7, 0, 1, 2, 3, 4, 5, 6, 6, 7, 0, 1, 2, 3, 4, 5, 5, 6, 7, 0, 1, 2, 3, 4, 4, 5, 6, + 7, 0, 1, 2, 3, 3, 4, 5, 6, 7, 0, 1, 2, 2, 3, 4, 5, 6, 7, 0, 1, 1, 2, 3, 4, 5, 6, 7, 0, 0, 1, 2, 3, 4, 5, + 6, 7, 7, 0, 1, 2, 3, 4, 5, 6, 6, 7, 0, 1, 2, 3, 4, 5, 5, 6, 7, 0, 1, 2, 3, 4 + ], + "dims": [2, 3, 16], + "type": "int4" + }, + // indices + { + "data": [1], + "dims": [1, 1], + "type": "int32" + }, + // scale + { + "data": [1.0, 2.0, 1.0, 2.0, 1.0, 2.0], + "dims": [2, 3, 1], + "type": "float32" + }, + // zero + { + "data": [1, 1, 0, 0, 1, -1], + "dims": [2, 3, 1], + "type": "int4" + } + ], + "outputs": [ + { + "data": [ + 4, 6, 8, 10, 12, 14, 0, 2, 2, 4, 6, 8, 10, 12, 14, 0, -1, 0, 1, 2, 3, 4, 5, 6, 6, -1, 0, 1, 2, 3, 4, 5, + 14, 16, 2, 4, 6, 8, 10, 12, 12, 14, 16, 2, 4, 6, 8, 10 + ], + "dims": [1, 1, 3, 16], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 829e55a625102..7f0c1cc3e420c 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1354,6 +1354,7 @@ "div_int32.jsonc", "depth-to-space.jsonc", "dequantizelinear.jsonc", + "dequantize-linear-int4.jsonc", "equal.jsonc", "exp.jsonc", "expand.jsonc", @@ -1361,6 +1362,8 @@ "floor.jsonc", "fused-conv.jsonc", "fused-conv3dncdhw.jsonc", + "gather.jsonc", + "gather-block-quantized.jsonc", "gather-elements.jsonc", "gemm.jsonc", "global-average-pool.jsonc", diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 11899feb6e1dc..36a6f9bd87013 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -24,6 +24,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLa class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, UInt4x2, int32_t, GatherBlockQuantized); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, UInt4x2, int64_t, GatherBlockQuantized); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Int4x2, int32_t, GatherBlockQuantized); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Int4x2, int64_t, GatherBlockQuantized); + template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -51,8 +56,12 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo}; - + SkipSimplifiedLayerNormalization)>, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + }; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); if (info.kernel_def != nullptr) { // filter disabled entries where type is void diff --git a/onnxruntime/contrib_ops/js/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/js/quantization/gather_block_quantized.cc new file mode 100644 index 
0000000000000..ea4a5448bb6ea --- /dev/null +++ b/onnxruntime/contrib_ops/js/quantization/gather_block_quantized.cc @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_data_types.h" +#include "contrib_ops/js/quantization/gather_block_quantized.h" +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +#define ONNX_GATHER_BLOCK_QUANTIZED_KERNELS(T1, Tind) \ + ONNX_OPERATOR_TWO_TYPED_KERNEL_EX( \ + GatherBlockQuantized, \ + kMSDomain, 1, \ + T1, Tind, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", JsepSupportedFloatTypes()) \ + .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \ + GatherBlockQuantized); + +ONNX_GATHER_BLOCK_QUANTIZED_KERNELS(UInt4x2, int32_t); +ONNX_GATHER_BLOCK_QUANTIZED_KERNELS(UInt4x2, int64_t); +ONNX_GATHER_BLOCK_QUANTIZED_KERNELS(Int4x2, int32_t); +ONNX_GATHER_BLOCK_QUANTIZED_KERNELS(Int4x2, int64_t); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/quantization/gather_block_quantized.h b/onnxruntime/contrib_ops/js/quantization/gather_block_quantized.h new file mode 100644 index 0000000000000..69f258916895d --- /dev/null +++ b/onnxruntime/contrib_ops/js/quantization/gather_block_quantized.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_data_types.h" +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; + +class GatherBlockQuantized : public JsKernel { + public: + GatherBlockQuantized(const OpKernelInfo& info) : JsKernel(info) { + int64_t gather_axis; + int64_t quantize_axis; + int64_t block_size; + if (!info.GetAttr("gather_axis", &gather_axis).IsOK()) { + gather_axis = 0; + } + + if (!info.GetAttr("quantize_axis", &quantize_axis).IsOK()) { + quantize_axis = 1; + } + + if (!info.GetAttr("block_size", &block_size).IsOK()) { + block_size = 128; + } + + ORT_ENFORCE(block_size >= 16 && ((block_size - 1) & block_size) == 0, + "'block_size' must be 2's power and not less than 16."); + JSEP_INIT_KERNEL_ATTRIBUTE(GatherBlockQuantized, ({ + "gatherAxis" : $1, + "quantizeAxis" : $2, + "blockSize" : $3 + }), + static_cast(gather_axis), + static_cast(quantize_axis), + static_cast(block_size)); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime
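
---

For reference, below is a minimal TypeScript sketch (not part of the patch) of the dequantization semantics the new WGSL shader implements, restricted to the layout exercised by the test cases above: 3-D data, `gather_axis = 0`, `quantize_axis = 2`. It assumes the 4-bit values have already been unpacked into plain numbers; `gatherBlockQuantizedRef` and its parameter names are hypothetical and chosen for illustration only.

```ts
// Reference semantics of GatherBlockQuantized for 3-D data with
// gather_axis = 0 and quantize_axis = 2 (the layout used in the test cases).
// `data` and `zeroPoints` hold already-unpacked int4/uint4 values as numbers.
function gatherBlockQuantizedRef(
  data: number[],                   // flattened, shape [d0, d1, d2]
  dims: [number, number, number],
  indices: number[],                // indices along axis 0; may be negative
  scales: number[],                 // flattened, shape [d0, d1, ceil(d2 / blockSize)]
  zeroPoints: number[] | undefined, // same shape as scales; defaults to 0
  blockSize: number,
): number[] {
  const [d0, d1, d2] = dims;
  const blocksPerRow = Math.ceil(d2 / blockSize);
  const out: number[] = [];
  for (const rawIdx of indices) {
    const i0 = rawIdx < 0 ? rawIdx + d0 : rawIdx; // negative indices wrap around
    for (let i1 = 0; i1 < d1; ++i1) {
      for (let i2 = 0; i2 < d2; ++i2) {
        const q = data[(i0 * d1 + i1) * d2 + i2];
        const block = Math.floor(i2 / blockSize);
        const sIdx = (i0 * d1 + i1) * blocksPerRow + block;
        const zp = zeroPoints ? zeroPoints[sIdx] : 0;
        out.push((q - zp) * scales[sIdx]); // dequantize: (q - zero_point) * scale
      }
    }
  }
  return out; // shape: [...indices.dims, d1, d2]
}
```

The shader follows the same structure: it maps each output element back to a data index, derives the scale (and zero-point) index by dividing the quantize-axis coordinate by `block_size`, and unpacks the 4-bit quantized value before applying `(q - zero_point) * scale`.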