Skip to content

Commit

Permalink
Refactor shader/execution/expression/constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
greggman committed Feb 25, 2025
1 parent 3d17e72 commit 8948efa
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 90 deletions.
100 changes: 36 additions & 64 deletions src/webgpu/shader/execution/expression/constructor/non_zero.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Execution Tests for value constructors from components
`;

import { makeTestGroup } from '../../../../../common/framework/test_group.js';
import { GPUTest } from '../../../../gpu_test.js';
import { AllFeaturesMaxLimitsGPUTest } from '../../../../gpu_test.js';
import {
ArrayValue,
MatrixType,
Expand All @@ -24,7 +24,7 @@ import {
run,
} from '../expression.js';

export const g = makeTestGroup(GPUTest);
export const g = makeTestGroup(AllFeaturesMaxLimitsGPUTest);

/** @returns true if 'v' is 'min' or 'max' */
function isMinOrMax(v: number | 'min' | 'max') {
Expand Down Expand Up @@ -73,12 +73,12 @@ g.test('scalar_identity')
.combine('value', ['min', 'max', 1, 2, 5, 100] as const)
)
.beforeAllSubcases(t => {
if (t.params.type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
}
t.skipIf(t.params.type === 'bool' && !isMinOrMax(t.params.value));
})
.fn(async t => {
if (t.params.type === 'f16') {
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
const type = Type[t.params.type];
const value = valueFor(t.params.value, t.params.type);
await run(
Expand All @@ -101,12 +101,10 @@ g.test('vector_identity')
.combine('width', [2, 3, 4] as const)
.combine('infer_type', [false, true] as const)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const elementType = Type[t.params.type];
const vectorType = Type.vec(t.params.width, elementType);
const elements: number[] = [];
Expand Down Expand Up @@ -146,12 +144,12 @@ g.test('concrete_vector_splat')
.combine('infer_type', [false, true] as const)
)
.beforeAllSubcases(t => {
if (t.params.type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
}
t.skipIf(t.params.type === 'bool' && !isMinOrMax(t.params.value));
})
.fn(async t => {
if (t.params.type === 'f16') {
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
const value = valueFor(t.params.value, t.params.type);
const elementType = Type[t.params.type];
const vectorType = Type.vec(t.params.width, elementType);
Expand All @@ -176,12 +174,10 @@ g.test('abstract_vector_splat')
.combine('value', [1, 2, 5, 100] as const)
.combine('width', [2, 3, 4] as const)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.concrete_type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const suffix = t.params.abstract_type === 'abstract-float' ? '.0' : '';
const concreteElementType = Type[t.params.concrete_type];
const concreteVectorType = Type.vec(t.params.width, concreteElementType);
Expand All @@ -206,12 +202,10 @@ g.test('concrete_vector_elements')
.combine('width', [2, 3, 4] as const)
.combine('infer_type', [false, true] as const)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const elementType = Type[t.params.type];
const vectorType = Type.vec(t.params.width, elementType);
const elements: number[] = [];
Expand Down Expand Up @@ -248,12 +242,10 @@ g.test('abstract_vector_elements')
.expand('concrete_type', t => kConcreteTypesForAbstractType[t.abstract_type])
.combine('width', [2, 3, 4] as const)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.concrete_type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const suffix = t.params.abstract_type === 'abstract-float' ? '.0' : '';
const concreteElementType = Type[t.params.concrete_type];
const concreteVectorType = Type.vec(t.params.width, concreteElementType);
Expand Down Expand Up @@ -297,12 +289,10 @@ g.test('concrete_vector_mix')
.combine('signature', kMixSignatures)
.combine('infer_type', [false, true] as const)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const elementType = Type[t.params.type];
let width = 0;
const elementValue = (i: number) => (t.params.type === 'bool' ? i & 1 : (i + 1) * 10);
Expand Down Expand Up @@ -354,12 +344,10 @@ g.test('abstract_vector_mix')
.expand('concrete_type', t => kConcreteTypesForAbstractType[t.abstract_type])
.combine('signature', kMixSignatures)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.concrete_type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
let width = 0;
const suffix = t.params.abstract_type === 'abstract-float' ? '.0' : '';
const concreteElementType = Type[t.params.concrete_type];
Expand Down Expand Up @@ -412,12 +400,10 @@ g.test('matrix_identity')
.combine('rows', [2, 3, 4] as const)
.combine('infer_type', [false, true] as const)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const elementType = Type[t.params.type];
const matrixType = Type.mat(t.params.columns, t.params.rows, elementType);
const elements: number[] = [];
Expand Down Expand Up @@ -453,12 +439,10 @@ g.test('concrete_matrix_elements')
.combine('rows', [2, 3, 4] as const)
.combine('infer_type', [false, true] as const)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const elementType = Type[t.params.type];
const matrixType = Type.mat(t.params.columns, t.params.rows, elementType);
const elements: number[] = [];
Expand Down Expand Up @@ -492,12 +476,10 @@ g.test('abstract_matrix_elements')
.combine('columns', [2, 3, 4] as const)
.combine('rows', [2, 3, 4] as const)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.concrete_type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const concreteElementType = Type[t.params.concrete_type];
const concreteMatrixType = Type.mat(t.params.columns, t.params.rows, concreteElementType);
const elements: number[] = [];
Expand Down Expand Up @@ -535,12 +517,10 @@ g.test('concrete_matrix_column_vectors')
.combine('rows', [2, 3, 4] as const)
.combine('infer_type', [false, true] as const)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const elementType = Type[t.params.type];
const columnType = Type.vec(t.params.rows, elementType);
const matrixType = Type.mat(t.params.columns, t.params.rows, elementType);
Expand Down Expand Up @@ -580,12 +560,10 @@ g.test('abstract_matrix_column_vectors')
.combine('columns', [2, 3, 4] as const)
.combine('rows', [2, 3, 4] as const)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.concrete_type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const concreteElementType = Type[t.params.concrete_type];
const concreteMatrixType = Type.mat(t.params.columns, t.params.rows, concreteElementType);
const elements: number[] = [];
Expand Down Expand Up @@ -625,12 +603,10 @@ g.test('concrete_array_elements')
.combine('length', [1, 5, 10] as const)
.combine('infer_type', [false, true] as const)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const elementType = Type[t.params.type];
const arrayType = Type.array(t.params.length, elementType);
const elements: number[] = [];
Expand Down Expand Up @@ -668,12 +644,10 @@ g.test('abstract_array_elements')
.expand('concrete_type', t => kConcreteTypesForAbstractType[t.abstract_type])
.combine('length', [1, 5, 10] as const)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (scalarTypeOf(Type[t.params.concrete_type]).kind === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const count = t.params.length;
const concreteElementType = Type[t.params.concrete_type];
const concreteArrayType = Type.array(count, concreteElementType);
Expand Down Expand Up @@ -762,12 +736,10 @@ g.test('structure')
.beginSubcases()
.expand('member_index', t => t.member_types.map((_, i) => i))
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.member_types.includes('f16')) {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const memberType = Type[t.params.member_types[t.params.member_index]];
const values = t.params.member_types.map((ty, i) => Type[ty].create(i));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,20 @@ Execution Tests for zero value constructors
`;

import { makeTestGroup } from '../../../../../common/framework/test_group.js';
import { GPUTest } from '../../../../gpu_test.js';
import { AllFeaturesMaxLimitsGPUTest } from '../../../../gpu_test.js';
import { ScalarKind, Type } from '../../../../util/conversion.js';
import { ShaderBuilderParams, basicExpressionBuilder, run } from '../expression.js';

export const g = makeTestGroup(GPUTest);
export const g = makeTestGroup(AllFeaturesMaxLimitsGPUTest);

g.test('scalar')
.specURL('https://www.w3.org/TR/WGSL/#zero-value-builtin-function')
.desc(`Test that a zero value scalar constructor produces the expected zero value`)
.params(u => u.combine('type', ['bool', 'i32', 'u32', 'f32', 'f16'] as const))
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const type = Type[t.params.type];
await run(
t,
Expand All @@ -38,12 +36,10 @@ g.test('vector')
.combine('type', ['bool', 'i32', 'u32', 'f32', 'f16'] as const)
.combine('width', [2, 3, 4] as const)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const type = Type.vec(t.params.width, Type[t.params.type]);
await run(
t,
Expand All @@ -60,12 +56,10 @@ g.test('vector_prefix')
.params(u =>
u.combine('type', ['i32', 'u32', 'f32', 'f16'] as const).combine('width', [2, 3, 4] as const)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const type = Type.vec(t.params.width, Type[t.params.type]);
await run(
t,
Expand All @@ -86,12 +80,10 @@ g.test('matrix')
.combine('columns', [2, 3, 4] as const)
.combine('rows', [2, 3, 4] as const)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const type = Type.mat(t.params.columns, t.params.rows, Type[t.params.type]);
await run(
t,
Expand All @@ -111,12 +103,10 @@ g.test('array')
.combine('type', ['bool', 'i32', 'u32', 'f32', 'f16', 'vec3f', 'vec4i'] as const)
.combine('length', [1, 5, 10] as const)
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.type === 'f16') {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const type = Type.array(t.params.length, Type[t.params.type]);
await run(
t,
Expand Down Expand Up @@ -146,12 +136,10 @@ g.test('structure')
.beginSubcases()
.expand('member_index', t => t.member_types.map((_, i) => i))
)
.beforeAllSubcases(t => {
.fn(async t => {
if (t.params.member_types.includes('f16')) {
t.selectDeviceOrSkipTestCase('shader-f16');
t.skipIfDeviceDoesNotHaveFeature('shader-f16');
}
})
.fn(async t => {
const memberType = Type[t.params.member_types[t.params.member_index]];
const builder = basicExpressionBuilder(_ =>
t.params.nested
Expand Down

0 comments on commit 8948efa

Please sign in to comment.