diff --git a/demos/benchmarks/conv_gpu_benchmark.ts b/demos/benchmarks/conv_gpu_benchmark.ts index 32192fa6ba..1af6c0cba1 100644 --- a/demos/benchmarks/conv_gpu_benchmark.ts +++ b/demos/benchmarks/conv_gpu_benchmark.ts @@ -29,22 +29,22 @@ export const BENCHMARK_TEST: BenchmarkTest = (size: number) => { const texManager = new TextureManager(gpgpu); initializeGPU(gpgpu, texManager); - const inputDepth = 1; - const inputShape: [number, number, number] = [size, size, inputDepth]; - const outputDepth = 1; - const fieldSize = 11; + const inDepth = 1; + const inShape: [number, number, number] = [size, size, inDepth]; + const outDepth = 1; + const filterSize = 11; const stride = 1; - const zeroPad = conv_util.computeDefaultPad(inputShape, fieldSize, stride); - const hasBias = true; - const program = new Conv2DProgram( - inputShape, fieldSize, outputDepth, stride, zeroPad, hasBias); + const convInfo = conv_util.computeConvInfo( + inShape, filterSize, filterSize, outDepth, stride, stride, 'same'); + const program = new Conv2DProgram(convInfo, hasBias); const outputShape = program.outputShape as [number, number, number]; const out = Array3D.zeros(outputShape); - const x = Array3D.randUniform(inputShape, -1, 1); - const wShape = conv_util.computeWeightsShape4D(1, outputDepth, fieldSize); + const x = Array3D.randUniform(inShape, -1, 1); + const wShape = + conv_util.computeWeightsShape4D(1, outDepth, filterSize, filterSize); const W = Array4D.randUniform(wShape, -1, 1); - const b = Array1D.randUniform([outputDepth], -1, 1); + const b = Array1D.randUniform([outDepth], -1, 1); const inputs = [x, W, b]; const binary = gpgpu_math.compileProgram(gpgpu, program, inputs, out); diff --git a/demos/benchmarks/conv_transpose_gpu_benchmark.ts b/demos/benchmarks/conv_transpose_gpu_benchmark.ts index 4d81f60d5e..2e11dee706 100644 --- a/demos/benchmarks/conv_transpose_gpu_benchmark.ts +++ b/demos/benchmarks/conv_transpose_gpu_benchmark.ts @@ -15,7 +15,7 @@ limitations under the License. import * as conv_util from '../../src/math/conv_util'; import {Array3D, Array4D, initializeGPU} from '../../src/math/ndarray'; -import {Conv2DTransposeProgram} from '../../src/math/webgl/conv_backprop_gpu'; +import {Conv2DDerInputProgram} from '../../src/math/webgl/conv_backprop_gpu'; import {GPGPUContext} from '../../src/math/webgl/gpgpu_context'; import * as gpgpu_math from '../../src/math/webgl/gpgpu_math'; import {TextureManager} from '../../src/math/webgl/texture_manager'; @@ -25,8 +25,8 @@ const OP_RUNS = 40; export const BENCHMARK_TEST: BenchmarkTest = (size: number) => { const origInputDepth = 1; - const origOutputDepth = 2; - const xShape: [number, number, number] = [size, size, 1]; + const origOutputDepth = 1; + const xShape: [number, number, number] = [size, size, origOutputDepth]; const fieldSize = 11; const origStride = 1; const origPad = 1; @@ -36,14 +36,15 @@ export const BENCHMARK_TEST: BenchmarkTest = (size: number) => { initializeGPU(gpgpu, texManager); gpgpu.enableAutomaticDebugValidation(true); - const hasBias = false; - const program = new Conv2DTransposeProgram( - xShape, fieldSize, origInputDepth, origStride, origPad, hasBias); + const convInfo = conv_util.computeConvInfo( + xShape, fieldSize, fieldSize, origOutputDepth, origStride, origStride, + origPad); + const program = new Conv2DDerInputProgram(convInfo); const outputShape = program.outputShape as [number, number, number]; const out = Array3D.zeros(outputShape); const x = Array3D.randUniform(xShape, -1, 1); const wShape = conv_util.computeWeightsShape4D( - origInputDepth, origOutputDepth, fieldSize); + origInputDepth, origOutputDepth, fieldSize, fieldSize); const W = Array4D.randUniform(wShape, -1, 1); const inputs = [x, W]; const binary = gpgpu_math.compileProgram(gpgpu, program, inputs, out); diff --git a/demos/benchmarks/max_pool_backprop_gpu_benchmark.ts b/demos/benchmarks/max_pool_backprop_gpu_benchmark.ts index 96ef41ce27..4f873ae1aa 100644 --- a/demos/benchmarks/max_pool_backprop_gpu_benchmark.ts +++ b/demos/benchmarks/max_pool_backprop_gpu_benchmark.ts @@ -29,12 +29,14 @@ export const BENCHMARK_TEST: BenchmarkTest = (size: number) => { const texManager = new TextureManager(gpgpu); initializeGPU(gpgpu, texManager); - const outputDepth = 1; - const dyShape: [number, number, number] = [size, size, outputDepth]; + const depth = 1; + const dyShape: [number, number, number] = [size, size, depth]; + const xShape: [number, number, number] = [size, size, depth]; const fSize = 11; const stride = 1; - const zeroPad = conv_util.computeDefaultPad(dyShape, fSize, stride); - const program = new MaxPool2DBackpropProgram(dyShape, fSize, stride, zeroPad); + const convInfo = conv_util.computeConvInfo( + xShape, fSize, fSize, depth, stride, stride, 'same'); + const program = new MaxPool2DBackpropProgram(convInfo); const res = NDArray.zeros(program.outputShape); const dy = Array3D.randUniform(dyShape, -1, 1); const positionsData = new Float32Array(dy.size); diff --git a/demos/benchmarks/max_pool_gpu_benchmark.ts b/demos/benchmarks/max_pool_gpu_benchmark.ts index 25de453ada..14dbdf107c 100644 --- a/demos/benchmarks/max_pool_gpu_benchmark.ts +++ b/demos/benchmarks/max_pool_gpu_benchmark.ts @@ -43,10 +43,9 @@ function testMaxPool(size: number, positions: boolean): number { const xShape: [number, number, number] = [size, size, outputDepth]; const fieldSize = 11; const stride = 1; - const zeroPad = conv_util.computeDefaultPad(xShape, fieldSize, stride); - - const program = - new Pool2DProgram(xShape, fieldSize, stride, zeroPad, 'max', positions); + const convInfo = conv_util.computeConvInfo( + xShape, fieldSize, fieldSize, outputDepth, stride, stride, 'same'); + const program = new Pool2DProgram(convInfo, 'max', positions); const res = NDArray.zeros(program.outputShape); const x = Array3D.randUniform(xShape, -1, 1); const binary = gpgpu_math.compileProgram(gpgpu, program, [x], res); diff --git a/demos/model-builder/layer_builder.ts b/demos/model-builder/layer_builder.ts index c7061f6a4f..41ef124a02 100644 --- a/demos/model-builder/layer_builder.ts +++ b/demos/model-builder/layer_builder.ts @@ -207,7 +207,7 @@ export class Convolution2DLayerBuilder implements LayerBuilder { { label: 'Output depth', initialValue: (inputShape: number[]) => - this.outputDepth != null ? this.outputDepth : 1, + this.outputDepth != null ? this.outputDepth : 1, type: 'number', min: 1, max: 1000, @@ -319,7 +319,7 @@ export class ReshapeLayerBuilder implements LayerBuilder { initialValue: (inputShape: number[]) => inputShape.join(', '), type: 'text' as 'text', setValue: (value: string) => this.outputShape = - value.split(',').map((value) => +value), + value.split(',').map((value) => +value), getValue: () => this.outputShape.join(', ') }]; } diff --git a/src/graph.ts b/src/graph.ts index db5f01862b..5d65b7fe09 100644 --- a/src/graph.ts +++ b/src/graph.ts @@ -694,10 +694,9 @@ export class MaxPoolNode extends Node { graph: Graph, private x: Tensor, public fieldSize: number, public stride = 1, public zeroPad?: number) { super( - graph, 'Max pool', {x}, - new Tensor(conv_util.computeOutputShape3D( - x.shape as [number, number, number], fieldSize, x.shape[2], stride, - zeroPad))); + graph, 'Max pool', {x}, new Tensor(conv_util.computeOutputShape3D( + x.shape as [number, number, number], + fieldSize, x.shape[2], stride, zeroPad))); } validate() { util.assert( @@ -875,4 +874,4 @@ export class ArgMaxEqualsNode extends Node { * @hidden */ export type ArrayData = - NDArray|number|number[]|number[][]|number[][][]|number[][][][]; + NDArray | number | number[] | number[][] | number[][][] | number[][][][]; diff --git a/src/math/conv_util.ts b/src/math/conv_util.ts index 15f4ecba3d..25d8f2e2a2 100644 --- a/src/math/conv_util.ts +++ b/src/math/conv_util.ts @@ -15,14 +15,88 @@ limitations under the License. import * as util from '../util'; +/** + * Information about the forward pass of a convolution/pooling operation. + * It includes input and output shape, strides, filter size and padding + * information. + */ +export type ConvInfo = { + inShape: [number, number, number], + outShape: [number, number, number], + strideHeight: number, + strideWidth: number, + filterHeight: number, + filterWidth: number, + padInfo: {top: number, left: number, right: number, bottom: number} +}; + +/** + * Computes the information about a forward pass of a convolution/pooling + * operation. + */ +export function computeConvInfo( + inShape: [number, number, number], filterHeight: number, + filterWidth: number, outDepth: number, strideHeight: number, + strideWidth: number, pad: 'same'|'valid'|number): ConvInfo { + if (typeof pad === 'number') { + const outShape = computeOutputShape3D( + inShape, filterHeight, outDepth, strideHeight, pad); + return { + inShape, + outShape, + padInfo: {top: pad, bottom: pad, left: pad, right: pad}, + strideHeight, + strideWidth, + filterHeight, + filterWidth + }; + } + const inHeight = inShape[0]; + const inWidth = inShape[1]; + let outShape: [number, number, number]; + let padInfo: {left: number, top: number, bottom: number, right: number}; + if (pad === 'same') { + const outHeight = Math.ceil(inHeight / strideHeight); + const outWidth = Math.ceil(inWidth / strideWidth); + outShape = [outHeight, outWidth, outDepth]; + const padAlongHeight = + (outHeight - 1) * strideHeight + filterHeight - inHeight; + const padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth; + const top = Math.floor(padAlongHeight / 2); + const bottom = padAlongHeight - top; + const left = Math.floor(padAlongWidth / 2); + const right = padAlongWidth - left; + padInfo = {top, bottom, left, right}; + } else if (pad === 'valid') { + const outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight); + const outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth); + outShape = [outHeight, outWidth, outDepth]; + padInfo = {top: 0, bottom: 0, left: 0, right: 0}; + } else { + throw Error(`Unknown padding parameter: ${pad}`); + } + return { + inShape, + outShape, + padInfo, + strideHeight, + strideWidth, + filterHeight, + filterWidth + }; +} + +/** + * @deprecated Use `conv_util.computeConvInfo` instead. + */ export function computeOutputShape3D( - inputShapeRowColDepth: [number, number, number], fieldSize: number, - depth: number, stride: number, zeroPad?: number): [number, number, number] { + inShape: [number, number, number], fieldSize: number, outDepth: number, + stride: number, zeroPad?: number): [number, number, number] { if (zeroPad == null) { - zeroPad = computeDefaultPad(inputShapeRowColDepth, fieldSize, stride); + zeroPad = computeDefaultPad(inShape, fieldSize, stride); } - const inputRows = inputShapeRowColDepth[0]; - const inputCols = inputShapeRowColDepth[1]; + const inputRows = inShape[0]; + const inputCols = inShape[1]; const outputRows = (inputRows - fieldSize + 2 * zeroPad) / stride + 1; util.assert( util.isInt(outputRows), @@ -35,7 +109,7 @@ export function computeOutputShape3D( `The output # of columns (${outputCols}) must be an integer. Change ` + `the stride and/or zero pad parameters`); - return [outputRows, outputCols, depth]; + return [outputRows, outputCols, outDepth]; } export function computeDefaultPad( @@ -50,9 +124,9 @@ export function computeTexShapeFrom3D( } export function computeWeightsShape4D( - inputDepth: number, outputDepth: number, - fSize: number): [number, number, number, number] { - return [fSize, fSize, inputDepth, outputDepth]; + inputDepth: number, outputDepth: number, filterHeight: number, + filterWidth: number): [number, number, number, number] { + return [filterHeight, filterWidth, inputDepth, outputDepth]; } export function computeDilatedRC( diff --git a/src/math/conv_util_test.ts b/src/math/conv_util_test.ts new file mode 100644 index 0000000000..7c84168052 --- /dev/null +++ b/src/math/conv_util_test.ts @@ -0,0 +1,77 @@ +/* Copyright 2017 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import * as conv_util from './conv_util'; + +describe('conv_util computeConvInfo', () => { + it('1x1 conv over 1x1 array with same pad', () => { + const inShape: [number, number, number] = [1, 1, 1]; + const convInfo = conv_util.computeConvInfo(inShape, 1, 1, 1, 1, 1, 'same'); + expect(convInfo.outShape).toEqual([1, 1, 1]); + }); + + it('2x2 conv over 3x3 array with same pad', () => { + const inShape: [number, number, number] = [3, 3, 1]; + const convInfo = conv_util.computeConvInfo(inShape, 2, 2, 1, 1, 1, 'same'); + expect(convInfo.outShape).toEqual([3, 3, 1]); + // Should produce non-even padding with extra pixel at the right/bottom. + expect(convInfo.padInfo.left).toBe(0); + expect(convInfo.padInfo.right).toBe(1); + expect(convInfo.padInfo.top).toBe(0); + expect(convInfo.padInfo.bottom).toBe(1); + }); + + it('2x2 conv over 3x3 array with same pad', () => { + const inShape: [number, number, number] = [3, 3, 1]; + const convInfo = conv_util.computeConvInfo(inShape, 2, 2, 1, 1, 1, 'same'); + expect(convInfo.outShape).toEqual([3, 3, 1]); + }); + + it('2x2 conv over 3x3 array with valid pad', () => { + const inShape: [number, number, number] = [3, 3, 1]; + const convInfo = conv_util.computeConvInfo(inShape, 2, 2, 1, 1, 1, 'valid'); + expect(convInfo.outShape).toEqual([2, 2, 1]); + }); + + it('2x2 conv over 3x3 array with valid pad with stride 2', () => { + const inShape: [number, number, number] = [3, 3, 1]; + const convInfo = conv_util.computeConvInfo(inShape, 2, 2, 1, 2, 2, 'valid'); + expect(convInfo.outShape).toEqual([1, 1, 1]); + }); + + it('2x2 conv over 3x3 array with valid pad with stride 2', () => { + const inShape: [number, number, number] = [3, 3, 1]; + const convInfo = conv_util.computeConvInfo(inShape, 2, 2, 1, 2, 2, 'valid'); + expect(convInfo.outShape).toEqual([1, 1, 1]); + }); + + it('2x1 conv over 3x3 array with valid pad with stride 1', () => { + const inShape: [number, number, number] = [3, 3, 1]; + const convInfo = conv_util.computeConvInfo(inShape, 2, 1, 1, 1, 1, 'valid'); + expect(convInfo.outShape).toEqual([2, 3, 1]); + }); + + it('2x1 conv over 3x3 array with valid pad with strides h=2, w=1', () => { + const inShape: [number, number, number] = [3, 3, 1]; + const convInfo = conv_util.computeConvInfo(inShape, 2, 1, 1, 2, 1, 'valid'); + expect(convInfo.outShape).toEqual([1, 3, 1]); + }); + + it('1x2 conv over 3x3 array with valid pad with stride 1', () => { + const inShape: [number, number, number] = [3, 3, 1]; + const convInfo = conv_util.computeConvInfo(inShape, 1, 2, 1, 1, 1, 'valid'); + expect(convInfo.outShape).toEqual([3, 2, 1]); + }); +}); diff --git a/src/math/math.ts b/src/math/math.ts index 2979b0f276..4e9e40c347 100644 --- a/src/math/math.ts +++ b/src/math/math.ts @@ -15,11 +15,12 @@ limitations under the License. import * as util from '../util'; import * as concat3d_util from './concat3d_util'; +import * as conv_util from './conv_util'; +import {ConvInfo} from './conv_util'; import * as copy2d_util from './copy2d_util'; - import {Array1D, Array2D, Array3D, Array4D, NDArray, Scalar} from './ndarray'; -export type ScopeResult = NDArray[] | NDArray | void; +export type ScopeResult = NDArray[]|NDArray|void; export interface LSTMCell { (data: Array2D, c: Array2D, h: Array2D): [Array2D, Array2D]; @@ -892,163 +893,248 @@ export abstract class NDArrayMath { /** * Computes a 2D convolution over the input x. - * @param x The input image, must be rank 3, of shape [rows, cols, depth1]. - * @param weights The weights NDArray, must be rank 4, of shape [f, f, depth1, - * depth2]. - * @param biases Optional biases NDArray, must be rank 1 of shape [depth2]. - * @param stride The stride of the convolution. - * @param zeroPad The zero padding of each side of the input NDArray. Will pad - * equally on all sides. + * @param x The input image, rank 3, of shape [height, width, inDepth]. + * @param filter The filter, rank 4, of shape + * [filterHeight, filterWidth, inDepth, outDepth]. + * @param bias Optional bias, rank 1 of shape [outDepth]. + * @param strides The strides of the convolution: [strideHeight, strideWidth]. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm. + * - 'same' pad and stride 1: output will be of same size as input, + * regardless of filter size. + * - 'valid' pad: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * https://www.tensorflow.org/api_guides/python/nn#Convolution */ conv2d( - x: Array3D, weights: Array4D, biases: Array1D|null, stride: number, - zeroPad: number): Array3D { + x: Array3D, filter: Array4D, bias: Array1D|null, + strides: [number, number]|number, pad: 'valid'|'same'|number): Array3D { util.assert( x.rank === 3, `Error in conv2d: x must be rank 3, but got rank ${x.rank}.`); util.assert( - weights.rank === 4, - `Error in conv2d: weights must be rank 4, but got rank ` + - `${weights.rank}.`); - if (biases != null) { + filter.rank === 4, + `Error in conv2d: filter must be rank 4, but got rank ` + + `${filter.rank}.`); + if (bias != null) { util.assert( - biases.rank === 1, - `Error in conv2d: biases must be rank 1, but got rank ` + - `${biases.rank}.`); + bias.rank === 1, + `Error in conv2d: bias must be rank 1, but got rank ` + + `${bias.rank}.`); } util.assert( - x.shape[2] === weights.shape[2], + x.shape[2] === filter.shape[2], `Error in conv2d: depth of input (${x.shape[2]}) must match ` + - `input depth for weights ${weights.shape[2]}.`); - - + `input depth for filter ${filter.shape[2]}.`); + + const filterHeight = filter.shape[0]; + const filterWidth = filter.shape[1]; + const outDepth = filter.shape[3]; + const [strideHeight, strideWidth] = parseTupleParam(strides); + const convInfo = conv_util.computeConvInfo( + x.shape, filterHeight, filterWidth, outDepth, strideHeight, strideWidth, + pad); return this.executeOp( - 'conv2d', - () => this.conv2dInternal(x, weights, biases, stride, zeroPad)); + 'conv2d', () => this.conv2dInternal(x, filter, bias, convInfo)); } protected abstract conv2dInternal( - x: Array3D, weights: Array4D, biases: Array1D|null, stride: number, - zeroPad: number): Array3D; + x: Array3D, filter: Array4D, bias: Array1D|null, + convInfo: ConvInfo): Array3D; /** * Computes the backprop of a 2D convolution. - * @param x The input image, must be rank 3, of shape [xrows, xcols, depth1]. - * @param dy The dy image, must be rank 3, of shape [yrows, ycols, depth2]. - * @param weights The weights NDArray, must be rank 4, of shape [f, f, depth1, - * depth2]. - * @param stride The stride of the original convolution. - * @param pad The padding of the original convolution. + * @param x The input image, rank 3, of shape [height, width, inDepth]. + * @param dy The dy image, rank 3, of shape [height, width, outDepth]. + * @param filter The filter, rank 4, of shape + * [filterHeight, filterWidth, inDepth, outDepth]. + * @param strides The strides of the convolution: [strideHeight, strideWidth]. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm + * used in the forward prop of the op. */ conv2dBackProp( - x: Array3D, dy: Array3D, weights: Array4D, stride: number, - pad: number): {dx: Array3D, dw: Array4D, db: Array1D} { + x: Array3D, dy: Array3D, filter: Array4D, + strides: [number, number]|number, + pad: 'valid'|'same'|number): {dx: Array3D, dw: Array4D, db: Array1D} { + const dw = this.conv2dDerFilter(x, dy, filter.shape, strides, pad); + const db = this.conv2dDerBias(dy); + const dx = this.conv2dDerInput(x.shape, dy, filter, strides, pad); + return {db, dw, dx}; + } + + /** + * Computes the derivative of the input of a 2D convolution. + * + * @param inShape The shape of the input. Length 3 [height, width, inDepth]. + * @param dy The derivative of the output. Rank 3 + * [outHeight, outWidth, outDepth]. + * @param filter The filter, rank 4, of shape + * [filterHeight, filterWidth, inDepth, outDepth]. + * @param strides The strides of the convolution: [strideHeight, strideWidth]. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm + * used in the forward prop of the op. + */ + conv2dDerInput( + inShape: [number, number, number], dy: Array3D, filter: Array4D, + strides: [number, number]|number, pad: 'valid'|'same'|number): Array3D { + const inDepth = inShape[2]; + const outDepth = dy.shape[2]; util.assert( - x.rank === 3, - `Error in conv2dBackProp: x must be rank 3, but got shape ` + - `${x.shape}.`); + inShape.length === 3, + `Error in conv2dDerInput: x must be rank 3, but got rank ` + + `${inShape.length}.`); util.assert( dy.rank === 3, - `Error in conv2dBackProp: dy must be rank 3, but got shape ` + - `${dy.shape}.`); + `Error in conv2dDerInput: dy must be rank 3, but got ` + + `rank ${dy.rank}`); util.assert( - weights.rank === 4, - `Error in conv2dBackProp: weights must be rank 4, but got shape ` + - `${weights.shape}.`); + filter.rank === 4, + `Error in conv2dDerInput: filter must be rank 4, but got ` + + `rank ${filter.rank}`); util.assert( - x.shape[2] === weights.shape[2], - `Error in conv2dBackProp: depth of x ${x.shape[2]}) must ` + - `match input depth for weights (${weights.shape[2]}.`); + inDepth === filter.shape[2], + `Error in conv2dDerInput: depth of input (${inDepth}) must ` + + `match input depth for filter ${filter.shape[2]}.`); util.assert( - dy.shape[2] === weights.shape[3], - `Error in conv2dBackProp: depth of dy (${dy.shape[2]}) must ` + - `match output depth for weights (${weights.shape[3]}).`); + outDepth === filter.shape[3], + `Error in conv2dDerInput: depth of output (${outDepth}) must` + + `match output depth for filter ${filter.shape[3]}.`); + const filterHeight = filter.shape[0]; + const filterWidth = filter.shape[1]; - let result: {dx: Array3D, dw: Array4D, db: Array1D}; - this.executeOp('conv2dBackProp', () => { - result = this.conv2dBackPropInternal(x, dy, weights, stride, pad); - return result.dx; - }); - this.track(result.db); - this.track(result.dw); - return result; + const [strideHeight, strideWidth] = parseTupleParam(strides); + + const convInfo = conv_util.computeConvInfo( + inShape, filterHeight, filterWidth, outDepth, strideHeight, strideWidth, + pad); + return this.executeOp( + 'conv2dDerInput', + () => this.conv2dDerInputInternal(dy, filter, convInfo)); } - protected abstract conv2dBackPropInternal( - x: Array3D, dy: Array3D, weights: Array4D, stride: number, - pad: number): {dx: Array3D, dw: Array4D, db: Array1D}; + protected abstract conv2dDerInputInternal( + dy: Array3D, filter: Array4D, convInfo: ConvInfo): Array3D; /** - * Computes the transposed 2D convolution of an image, also known as a - * deconvolution. - * @param x The input image, must be rank 3, of shape [xrows, xcols, depth1]. - * @param weights The weights NDArray, must be rank 4, of shape [f, f, depth1, - * depth2]. - * @param biases Optional biases NDArray, must be rank 1 of shape [depth2]. - * @param stride The stride of the convolution. - * @param pad The padding of each side of the input NDArray. Will pad equally - * on all sides. + * Computes the derivative of the bias of a 2D convolution. + * + * @param dy The gradient for the output of this op. Rank 3 of shape + * [height, width, outDepth]. */ - conv2dTranspose( - x: Array3D, weights: Array4D, biases: Array1D|null, stride: number, - pad: number): Array3D { + conv2dDerBias(dy: Array3D): Array1D { + return this.track(this.conv2dDerBiasInternal(dy)); + } + protected abstract conv2dDerBiasInternal(dY: Array3D): Array1D; + + /** + * Computes the derivative of the filter of a 2D convolution. + * + * @param x The input image, rank 3, of shape [height, width, inDepth]. + * @param dy The dy image, rank 3, of shape [height, width, outDepth]. + * @param filterSize The size of the filter, length 4, + * [filterHeight, filterWidth, inDepth, outDepth]. + * @param strides The strides of the convolution: [strideHeight, strideWidth]. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm + * used in the forward prop of the op. + */ + conv2dDerFilter( + x: Array3D, dy: Array3D, filterSize: [number, number, number, number], + strides: [number, number]|number, pad: 'valid'|'same'|number): Array4D { util.assert( x.rank === 3, - `Error in conv2dTranspose: x must be rank 3, but got rank ` + - `${x.rank}.`); + `Error in conv2dDerFilter: x must be rank 3, but got shape ` + + `${x.shape}.`); util.assert( - weights.rank === 4, - `Error in conv2dTranspose: weights must be rank 4, but got ` + - `rank ${weights.rank}`); - if (biases != null) { - util.assert( - biases.rank === 1, - `Error in conv2dTranspose: biases must be rank 1, but got ' + - 'rank ${biases.rank}.`); - } + dy.rank === 3, + `Error in conv2dDerFilter: dy must be rank 3, but got shape ` + + `${dy.shape}.`); + util.assert( + filterSize.length === 4, + `Error in conv2dDerFilter: filterSize must be length 4, but got ` + + `${filterSize}.`); + util.assert( + x.shape[2] === filterSize[2], + `Error in conv2dDerFilter: depth of x ${x.shape[2]}) must ` + + `match input depth in filter (${filterSize[2]}.`); util.assert( - x.shape[2] === weights.shape[3], - `Error in conv2dTranspose: depth of input (${x.shape[2]}) must ` + - `match input depth for weights ${weights.shape[3]}.`); + dy.shape[2] === filterSize[3], + `Error in conv2dDerFilter: depth of dy (${dy.shape[2]}) must ` + + `match output depth for filter (${filterSize[3]}).`); + + const filterHeight = filterSize[0]; + const filterWidth = filterSize[1]; + const outDepth = filterSize[3]; + const [strideHeight, strideWidth] = parseTupleParam(strides); + const convInfo = conv_util.computeConvInfo( + x.shape, filterHeight, filterWidth, outDepth, strideHeight, strideWidth, + pad); + return this.track(this.conv2dDerFilterInternal(x, dy, convInfo)); + } + protected abstract conv2dDerFilterInternal( + x: Array3D, dy: Array3D, convInfo: ConvInfo): Array4D; - return this.executeOp( - 'conv2dTranspose', - () => this.conv2dTransposeInternal(x, weights, biases, stride, pad)); + /** + * Computes the transposed 2D convolution of an image, also known as a + * deconvolution. + * + * @param x The input image, rank 3, of shape [height, width, inDepth]. + * @param filter The filter, rank 4, of shape + * `[filterHeight, filterWidth, outDepth, inDepth]`. + * `inDepth` must match `inDepth` in `x`. + * @param outputShape Output shape, rank 3 [height, width, outDepth]. + * @param strides The strides of the original convolution: + * `[strideHeight, strideWidth]`. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm + * used in the non-transpose version of the op. + */ + conv2dTranspose( + x: Array3D, filter: Array4D, outputShape: [number, number, number], + strides: [number, number]|number, pad: 'valid'|'same'|number): Array3D { + return this.conv2dDerInput(outputShape, x, filter, strides, pad); } - protected abstract conv2dTransposeInternal( - x: Array3D, weights: Array4D, biases: Array1D|null, stride: number, - pad: number): Array3D; /** * Computes the 2D max pooling of an image. - * @param x The input image, must be rank 3. - * @param fSize The field size of the max pool. - * @param stride The stride of the max pool. - * @param pad The padding of each side of the input NDArray. Will pad equally - * on all sides. - */ - maxPool(x: Array3D, fSize: number, stride: number, pad: number): Array3D { + * @param x The input image, rank 3 of shape [height, width, inDepth]. + * @param filterSize The filter size, a tuple [filterHeight, filterWidth]. + * @param strides The strides of the pooling: [strideHeight, strideWidth]. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm. + * - 'same' pad and stride 1: output will be of same size as input, + * regardless of filter size. + * - 'valid' pad: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * https://www.tensorflow.org/api_guides/python/nn#Convolution + */ + maxPool( + x: Array3D, filterSize: [number, number]|number, + strides: [number, number]|number, pad: 'valid'|'same'|number): Array3D { util.assert( x.rank === 3, 'Error in maxPool: x must be rank 3 but got rank ' + x.rank + '.'); - return this.executeOp( - 'maxPool', () => this.maxPoolInternal(x, fSize, stride, pad)); + + const [filterHeight, filterWidth] = parseTupleParam(filterSize); + const outDepth = x.shape[2]; + const [strideHeight, strideWidth] = parseTupleParam(strides); + const convInfo = conv_util.computeConvInfo( + x.shape, filterHeight, filterWidth, outDepth, strideHeight, strideWidth, + pad); + return this.executeOp('maxPool', () => this.maxPoolInternal(x, convInfo)); } - protected abstract maxPoolInternal( - x: Array3D, fSize: number, stride: number, pad: number): Array3D; + protected abstract maxPoolInternal(x: Array3D, convInfo: ConvInfo): Array3D; /** * Computes the backprop of a max pool. * @param dy The dy error. - * @param x The input image, must be rank 3. - * @param fSize The field size of the max pool. - * @param stride The stride of the max pool. - * @param pad The padding of each side of the input NDArray. Will pad equally - * on all sides. + * @param x The input image, rank 3 of shape [height, width, inDepth]. + * @param filterSize The filter size, a tuple [filterHeight, filterWidth]. + * @param strides The strides of the pooling: [strideHeight, strideWidth]. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm + * used in the forward prop of the op. */ maxPoolBackprop( - dy: Array3D, x: Array3D, fSize: number, stride: number, - pad: number): Array3D { + dy: Array3D, x: Array3D, filterSize: [number, number]|number, + strides: [number, number]|number, pad: 'valid'|'same'|number): Array3D { util.assert( dy.rank === 3, `Error in maxPoolBackprop: dy must be rank 3 but got rank ` + @@ -1058,49 +1144,77 @@ export abstract class NDArrayMath { `Error in maxPoolBackprop: x must be rank 3 but got rank ` + `${x.rank}.`); + const [filterHeight, filterWidth] = parseTupleParam(filterSize); + const outDepth = x.shape[2]; + const [strideHeight, strideWidth] = parseTupleParam(strides); + const convInfo = conv_util.computeConvInfo( + x.shape, filterHeight, filterWidth, outDepth, strideHeight, strideWidth, + pad); return this.executeOp( - 'maxPoolBackprop', - () => this.maxPoolBackpropInternal(dy, x, fSize, stride, pad)); + 'maxPoolBackprop', () => this.maxPoolBackpropInternal(dy, x, convInfo)); } protected abstract maxPoolBackpropInternal( - dy: Array3D, x: Array3D, fSize: number, stride: number, - pad: number): Array3D; + dy: Array3D, x: Array3D, convInfo: ConvInfo): Array3D; /** * Computes the 2D min pooling of an image. - * @param x The input image, must be rank 3. - * @param fSize The field size of the max pool. - * @param stride The stride of the max pool. - * @param pad The padding of each side of the input NDArray. Will pad equally - * on all sides. - */ - minPool(x: Array3D, fSize: number, stride: number, pad: number): Array3D { + * @param x The input image, rank 3 of shape [height, width, inDepth]. + * @param filterSize The filter size, a tuple [filterHeight, filterWidth]. + * @param strides The strides of the pooling: [strideHeight, strideWidth]. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm. + * - 'same' pad and stride 1: output will be of same size as input, + * regardless of filter size. + * - 'valid' pad: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * https://www.tensorflow.org/api_guides/python/nn#Convolution + */ + minPool( + x: Array3D, filterSize: [number, number]|number, + strides: [number, number]|number, pad: 'valid'|'same'|number): Array3D { util.assert( x.rank === 3, `Error in minPool: x must be rank 3 but got rank ${x.rank}.`); - return this.executeOp( - 'minPool', () => this.minPoolInternal(x, fSize, stride, pad)); + + const [filterHeight, filterWidth] = parseTupleParam(filterSize); + const outDepth = x.shape[2]; + const [strideHeight, strideWidth] = parseTupleParam(strides); + const convInfo = conv_util.computeConvInfo( + x.shape, filterHeight, filterWidth, outDepth, strideHeight, strideWidth, + pad); + return this.executeOp('minPool', () => this.minPoolInternal(x, convInfo)); } - protected abstract minPoolInternal( - x: Array3D, fSize: number, stride: number, pad: number): Array3D; + protected abstract minPoolInternal(x: Array3D, convInfo: ConvInfo): Array3D; /** * Computes the 2D average pooling of an image. - * @param x The input image, must be rank 3. - * @param fSize The field size of the max pool. - * @param stride The stride of the max pool. - * @param pad The padding of each side of the input NDArray. Will pad equally - * on all sides. - */ - avgPool(x: Array3D, fSize: number, stride: number, pad: number): Array3D { + * @param x The input image, rank 3 of shape [height, width, inDepth]. + * @param filterSize The filter size, a tuple [filterHeight, filterWidth]. + * @param strides The strides of the pooling: [strideHeight, strideWidth]. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm. + * - 'same' pad and stride 1: output will be of same size as input, + * regardless of filter size. + * - 'valid' pad: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * https://www.tensorflow.org/api_guides/python/nn#Convolution + */ + avgPool( + x: Array3D, filterSize: [number, number]|number, + strides: [number, number]|number, pad: 'valid'|'same'|number): Array3D { util.assert( x.rank === 3, `Error in avgPool: x must be rank 3 but got rank ${x.rank}.`); - return this.executeOp( - 'avgPool', () => this.avgPoolInternal(x, fSize, stride, pad)); + + const [filterHeight, filterWidth] = parseTupleParam(filterSize); + const outDepth = x.shape[2]; + const [strideHeight, strideWidth] = parseTupleParam(strides); + const convInfo = conv_util.computeConvInfo( + x.shape, filterHeight, filterWidth, outDepth, strideHeight, strideWidth, + pad); + return this.executeOp('avgPool', () => this.avgPoolInternal(x, convInfo)); } - protected abstract avgPoolInternal( - x: Array3D, fSize: number, stride: number, pad: number): Array3D; + protected abstract avgPoolInternal(x: Array3D, convInfo: ConvInfo): Array3D; /* * Bilinear resize a 3D array per each channel to a new 2D shape. @@ -1228,7 +1342,7 @@ export abstract class NDArrayMath { * Derived from tf.contrib.rnn.BasicLSTMCell. * @param forgetBias Forget bias for the cell. * @param lstmKernel The weights for the cell. - * @param lstmBias The biases for the cell. + * @param lstmBias The bias for the cell. * @param data The input to the cell. * @param c Previous cell state. * @param h Previous cell output. @@ -1262,10 +1376,11 @@ export abstract class NDArrayMath { const o = this.slice2D( res, [0, res.shape[1] / 4 * 3], [res.shape[0], res.shape[1] / 4]); - const newC = this.add( - this.multiplyStrict( - c, this.sigmoid(this.scalarPlusArray(forgetBias, f))), - this.multiplyStrict(this.sigmoid(i), this.tanh(j))) as Array2D; + const newC = + this.add( + this.multiplyStrict( + c, this.sigmoid(this.scalarPlusArray(forgetBias, f))), + this.multiplyStrict(this.sigmoid(i), this.tanh(j))) as Array2D; const newH = this.multiplyStrict(this.tanh(newC), this.sigmoid(o)) as Array2D; @@ -1279,3 +1394,7 @@ export enum MatrixOrientation { REGULAR, TRANSPOSED } + +function parseTupleParam(param: number|[number, number]): [number, number] { + return typeof param === 'number' ? [param, param] : param; +} diff --git a/src/math/math_cpu.ts b/src/math/math_cpu.ts index eca36ac269..c947f3a83b 100644 --- a/src/math/math_cpu.ts +++ b/src/math/math_cpu.ts @@ -13,10 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import * as conv_util from '../math/conv_util'; import * as util from '../util'; import * as concat3d_util from './concat3d_util'; +import * as conv_util from './conv_util'; +import {ConvInfo} from './conv_util'; import * as copy2D_util from './copy2d_util'; import {MatrixOrientation, NDArrayMath} from './math'; import {Array1D, Array2D, Array3D, Array4D, NDArray, Scalar} from './ndarray'; @@ -371,28 +372,26 @@ export class NDArrayMathCPU extends NDArrayMath { return NDArray.make(ndarray.shape, {values: resultValues}); } - /** - * image is of shape [r, c, d1]. - * weights is of shape [F, F, d1, d2]. - */ protected conv2dInternal( - x: Array3D, weights: Array4D, biases: Array1D|null, stride: number, - pad: number): Array3D { + x: Array3D, filter: Array4D, bias: Array1D|null, + convInfo: ConvInfo): Array3D { const [xRows, xCols, inputDepth] = x.shape; - const fieldSize = weights.shape[0]; - const outputDepth = weights.shape[3]; - const outputShape = conv_util.computeOutputShape3D( - [xRows, xCols, inputDepth], fieldSize, outputDepth, stride, pad); - const y = Array3D.zeros(outputShape); - for (let d2 = 0; d2 < outputDepth; ++d2) { + const filterHeight = filter.shape[0]; + const filterWidth = filter.shape[1]; + const outDepth = filter.shape[3]; + const padLeft = convInfo.padInfo.left; + const padTop = convInfo.padInfo.top; + + const y = Array3D.zeros(convInfo.outShape); + for (let d2 = 0; d2 < outDepth; ++d2) { for (let yR = 0; yR < y.shape[0]; ++yR) { - const xRCorner = yR * stride - pad; + const xRCorner = yR * convInfo.strideHeight - padLeft; const xRMin = Math.max(0, xRCorner); - const xRMax = Math.min(xRows, fieldSize + xRCorner); + const xRMax = Math.min(xRows, filterHeight + xRCorner); for (let yC = 0; yC < y.shape[1]; ++yC) { - const xCCorner = yC * stride - pad; + const xCCorner = yC * convInfo.strideWidth - padTop; const xCMin = Math.max(0, xCCorner); - const xCMax = Math.min(xCols, fieldSize + xCCorner); + const xCMax = Math.min(xCols, filterWidth + xCCorner); let dotProd = 0; for (let xR = xRMin; xR < xRMax; ++xR) { const wR = xR - xRCorner; @@ -400,147 +399,76 @@ export class NDArrayMathCPU extends NDArrayMath { const wC = xC - xCCorner; for (let d1 = 0; d1 < inputDepth; ++d1) { const pixel = x.get(xR, xC, d1); - const weight = weights.get(wR, wC, d1, d2); + const weight = filter.get(wR, wC, d1, d2); dotProd += pixel * weight; } } } - const bias = (biases != null) ? biases.get(d2) : 0; - y.set(dotProd + bias, yR, yC, d2); + const biasVal = (bias != null) ? bias.get(d2) : 0; + y.set(dotProd + biasVal, yR, yC, d2); } } } return y; } - protected conv2dBackPropInternal( - x: Array3D, dy: Array3D, weights: Array4D, stride: number, - pad: number): {dx: Array3D, dw: Array4D, db: Array1D} { - const fSize = weights.shape[0]; - const dw = this.conv2dDerWeights(x, dy, fSize, stride, pad); - const db = this.conv2dDerBias(dy); - const dx = this.conv2dTransposeInternal(dy, weights, null, stride, pad); - return {dx, db, dw}; - } - - /** - * image is of shape [r, c, d1]. - * weights is of shape [F, F, d1, d2]. - */ - protected conv2dTransposeInternal( - x: Array3D, weights: Array4D, biases: Array1D|null, origStride: number, - origPad: number): Array3D { - const fSize = weights.shape[0]; - const pad = fSize - 1 - origPad; - const origInputDepth = weights.shape[2]; - const origOutputDepth = weights.shape[3]; - const xRows = x.shape[0]; - const xCols = x.shape[1]; - - // Dilate the input. - const xRowsDilated = (xRows - 1) * origStride + 1; - const xColsDilated = (xCols - 1) * origStride + 1; - - const outputShape = conv_util.computeOutputShape3D( - [xRowsDilated, xColsDilated, origOutputDepth], fSize, origInputDepth, 1, - pad); - const y = Array3D.zeros(outputShape); - for (let d2 = 0; d2 < origInputDepth; ++d2) { - for (let yR = 0; yR < y.shape[0]; ++yR) { - const xRCorner = yR - pad; - const xRMin = Math.max(0, Math.ceil(xRCorner / origStride)); - const xRMax = Math.min(xRows, (fSize + xRCorner) / origStride); - - for (let yC = 0; yC < y.shape[1]; ++yC) { - const xCCorner = yC - pad; - const xCMin = Math.max(0, Math.ceil(xCCorner / origStride)); - const xCMax = Math.min(xCols, (fSize + xCCorner) / origStride); + protected conv2dDerInputInternal( + dy: Array3D, filter: Array4D, convInfo: ConvInfo): Array3D { + const inDepth = filter.shape[2]; + const outDepth = filter.shape[3]; + const yRows = dy.shape[0]; + const yCols = dy.shape[1]; + const filterHeight = filter.shape[0]; + const filterWidth = filter.shape[1]; + const topPad = filterHeight - 1 - convInfo.padInfo.top; + const leftPad = filterWidth - 1 - convInfo.padInfo.left; + const strideHeight = convInfo.strideHeight; + const strideWidth = convInfo.strideWidth; + + const dx = Array3D.zeros(convInfo.inShape); + for (let d1 = 0; d1 < inDepth; ++d1) { + for (let xR = 0; xR < dx.shape[0]; ++xR) { + const xRCorner = xR - leftPad; + const xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight)); + const yRMax = Math.min(yRows, (filterHeight + xRCorner) / strideHeight); + + for (let xC = 0; xC < dx.shape[1]; ++xC) { + const xCCorner = xC - topPad; + const xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth)); + const yCMax = Math.min(yCols, (filterWidth + xCCorner) / strideWidth); let dotProd = 0; - for (let xR = xRMin; xR < xRMax; ++xR) { - const wR = xR * origStride - xRCorner; + for (let yR = xRMin; yR < yRMax; ++yR) { + const wR = yR * strideHeight - xRCorner; - for (let xC = xCMin; xC < xCMax; ++xC) { - const wC = xC * origStride - xCCorner; + for (let yC = xCMin; yC < yCMax; ++yC) { + const wC = yC * strideWidth - xCCorner; - for (let d1 = 0; d1 < origOutputDepth; ++d1) { - const pixel = x.get(xR, xC, d1); - const weight = - weights.get(fSize - 1 - wR, fSize - 1 - wC, d2, d1); + for (let d2 = 0; d2 < outDepth; ++d2) { + const pixel = dy.get(yR, yC, d2); + const weight = filter.get( + filterHeight - 1 - wR, filterWidth - 1 - wC, d1, d2); dotProd += pixel * weight; } } } - const bias = biases != null ? biases.get(d2) : 0; - y.set(dotProd + bias, yR, yC, d2); + dx.set(dotProd, xR, xC, d1); } } } - return y; - } - - /** - * image is of shape [r, c, d1]. - * weights is of shape [F, F, d1, d2]. - */ - protected conv2dTransposeShaderLike( - x: Array3D, origWeights: Array4D, origStride: number, - origPad: number): Array3D { - const fSize = origWeights.shape[0]; - const pad = fSize - 1 - origPad; - const origInputDepth = origWeights.shape[2]; - const origOutputDepth = origWeights.shape[3]; - const xRows = x.shape[0]; - const xCols = x.shape[1]; - - // Dilate the input. - const xRowsDilated = (xRows - 1) * origStride + 1; - const xColsDilated = (xCols - 1) * origStride + 1; - - const outputShape = conv_util.computeOutputShape3D( - [xRowsDilated, xColsDilated, origOutputDepth], fSize, origInputDepth, 1, - pad); - const y = Array3D.zeros(outputShape); - - for (let d2 = 0; d2 < origInputDepth; ++d2) { - for (let yR = 0; yR < y.shape[0]; ++yR) { - for (let yC = 0; yC < y.shape[1]; ++yC) { - // Shader code begins. - const xRCorner = yR - pad; - const xCCorner = yC - pad; - let dotProd = 0; - for (let wR = 0; wR < fSize; ++wR) { - const xR = (xRCorner + wR) / origStride; - if (xR < 0 || xR >= xRows || Math.floor(xR) !== xR) { - continue; - } - for (let wC = 0; wC < fSize; ++wC) { - const xC = (xCCorner + wC) / origStride; - if (xC < 0 || xC >= xCols || Math.floor(xC) !== xC) { - continue; - } - for (let d1 = 0; d1 < origOutputDepth; ++d1) { - const pixel = x.get(xR, xC, d1); - const weight = - origWeights.get(fSize - 1 - wR, fSize - 1 - wC, d2, d1); - dotProd += pixel * weight; - } - } - } - y.set(dotProd, yR, yC, d2); - } - } - } - return y; + return dx; } - conv2dDerWeights( - x: Array3D, dY: Array3D, fSize: number, stride: number, - zeroPad: number): Array4D { + protected conv2dDerFilterInternal( + x: Array3D, dY: Array3D, convInfo: ConvInfo): Array4D { const inputDepth = x.shape[2]; const outputDepth = dY.shape[2]; - const weightsShape = - conv_util.computeWeightsShape4D(inputDepth, outputDepth, fSize); + const strideHeight = convInfo.strideHeight; + const strideWidth = convInfo.strideWidth; + const filterHeight = convInfo.filterHeight; + const filterWidth = convInfo.filterWidth; + const weightsShape = conv_util.computeWeightsShape4D( + inputDepth, outputDepth, filterHeight, filterWidth); const dW = Array4D.zeros(weightsShape); const yNumRows = dY.shape[0]; @@ -548,22 +476,26 @@ export class NDArrayMathCPU extends NDArrayMath { const xNumRows = x.shape[0]; const xNumCols = x.shape[1]; - for (let wR = 0; wR < fSize; ++wR) { - const yRMin = Math.max(0, Math.ceil((zeroPad - wR) / stride)); - const yRMax = Math.min(yNumRows, (xNumRows + zeroPad - wR) / stride); + const leftPad = convInfo.padInfo.left; + const topPad = convInfo.padInfo.top; + + for (let wR = 0; wR < filterHeight; ++wR) { + const yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight)); + const yRMax = Math.min(yNumRows, (xNumRows + topPad - wR) / strideHeight); - for (let wC = 0; wC < fSize; ++wC) { - const yCMin = Math.max(0, Math.ceil((zeroPad - wC) / stride)); - const yCMax = Math.min(yNumCols, (xNumCols + zeroPad - wC) / stride); + for (let wC = 0; wC < filterWidth; ++wC) { + const yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth)); + const yCMax = + Math.min(yNumCols, (xNumCols + leftPad - wC) / strideWidth); for (let d1 = 0; d1 < inputDepth; ++d1) { for (let d2 = 0; d2 < outputDepth; ++d2) { // Need to convolve. let dotProd = 0; for (let yR = yRMin; yR < yRMax; ++yR) { - const xR = wR + yR * stride - zeroPad; + const xR = wR + yR * strideHeight - topPad; for (let yC = yCMin; yC < yCMax; ++yC) { - const xC = wC + yC * stride - zeroPad; + const xC = wC + yC * strideWidth - leftPad; dotProd += x.get(xR, xC, d1) * dY.get(yR, yC, d2); } } @@ -575,7 +507,7 @@ export class NDArrayMathCPU extends NDArrayMath { return dW; } - conv2dDerBias(dY: Array3D): Array1D { + protected conv2dDerBiasInternal(dY: Array3D): Array1D { const outputDepth = dY.shape[2]; const numRows = dY.shape[0]; const numCols = dY.shape[1]; @@ -615,22 +547,24 @@ export class NDArrayMathCPU extends NDArrayMath { return result; } - private pool( - x: Array3D, fSize: number, stride: number, pad: number, - poolType: 'max'|'min'|'avg') { + private pool(x: Array3D, convInfo: ConvInfo, poolType: 'max'|'min'|'avg') { const [xRows, xCols, depth] = x.shape; - const outputShape = conv_util.computeOutputShape3D( - [xRows, xCols, depth], fSize, depth, stride, pad); - const y = Array3D.zeros(outputShape); + const strideHeight = convInfo.strideHeight; + const strideWidth = convInfo.strideWidth; + const filterHeight = convInfo.filterHeight; + const filterWidth = convInfo.filterWidth; + const y = Array3D.zeros(convInfo.outShape); + const padTop = convInfo.padInfo.top; + const padLeft = convInfo.padInfo.left; for (let d = 0; d < depth; ++d) { for (let yR = 0; yR < y.shape[0]; ++yR) { - const xRCorner = yR * stride - pad; + const xRCorner = yR * strideHeight - padTop; const xRMin = Math.max(0, xRCorner); - const xRMax = Math.min(xRows, fSize + xRCorner); + const xRMax = Math.min(xRows, filterHeight + xRCorner); for (let yC = 0; yC < y.shape[1]; ++yC) { - const xCCorner = yC * stride - pad; + const xCCorner = yC * strideWidth - padLeft; const xCMin = Math.max(0, xCCorner); - const xCMax = Math.min(xCols, fSize + xCCorner); + const xCMax = Math.min(xCols, filterWidth + xCCorner); let minMaxValue = @@ -650,7 +584,7 @@ export class NDArrayMathCPU extends NDArrayMath { (poolType === 'min' && pixel < minMaxValue)) { minMaxValue = pixel; } else if (poolType === 'avg') { - avgValue += pixel / (fSize * fSize); + avgValue += pixel / (filterHeight * filterWidth); } } if (isNaN(minMaxValue)) { @@ -664,25 +598,30 @@ export class NDArrayMathCPU extends NDArrayMath { return y; } - protected maxPoolInternal( - x: Array3D, fSize: number, stride: number, pad: number): Array3D { - return this.pool(x, fSize, stride, pad, 'max'); + protected maxPoolInternal(x: Array3D, convInfo: ConvInfo): Array3D { + return this.pool(x, convInfo, 'max'); } - maxPoolPositions(x: Array3D, fSize: number, stride: number, pad: number) { + maxPoolPositions(x: Array3D, convInfo: ConvInfo) { const [xRows, xCols, depth] = x.shape; - const outputShape = - conv_util.computeOutputShape3D(x.shape, fSize, depth, stride, pad); + const outputShape = convInfo.outShape; const maxPositions = Array3D.zeros(outputShape); + const strideHeight = convInfo.strideHeight; + const strideWidth = convInfo.strideWidth; + const filterHeight = convInfo.filterHeight; + const filterWidth = convInfo.filterWidth; + const padTop = convInfo.padInfo.top; + const padLeft = convInfo.padInfo.left; + for (let d = 0; d < depth; ++d) { for (let yR = 0; yR < outputShape[0]; ++yR) { - const xRCorner = yR * stride - pad; + const xRCorner = yR * strideHeight - padTop; const xRMin = Math.max(0, xRCorner); - const xRMax = Math.min(xRows, fSize + xRCorner); + const xRMax = Math.min(xRows, filterHeight + xRCorner); for (let yC = 0; yC < outputShape[1]; ++yC) { - const xCCorner = yC * stride - pad; + const xCCorner = yC * strideWidth - padLeft; const xCMin = Math.max(0, xCCorner); - const xCMax = Math.min(xCols, fSize + xCCorner); + const xCMax = Math.min(xCols, filterWidth + xCCorner); let maxValue = Number.NEGATIVE_INFINITY; let maxPosition = -1; for (let xR = xRMin; xR < xRMax; ++xR) { @@ -692,7 +631,7 @@ export class NDArrayMathCPU extends NDArrayMath { const pixel = x.get(xR, xC, d); if (pixel > maxValue) { maxValue = pixel; - maxPosition = wR * fSize + wC; + maxPosition = wR * filterWidth + wC; } } } @@ -704,39 +643,37 @@ export class NDArrayMathCPU extends NDArrayMath { } protected maxPoolBackpropInternal( - dy: Array3D, x: Array3D, fSize: number, origStride: number, - origPad: number): Array3D { - const maxPositions = this.maxPoolPositions(x, fSize, origStride, origPad); - const pad = fSize - 1 - origPad; + dy: Array3D, x: Array3D, convInfo: ConvInfo): Array3D { + const maxPositions = this.maxPoolPositions(x, convInfo); + const strideHeight = convInfo.strideHeight; + const strideWidth = convInfo.strideWidth; + const filterHeight = convInfo.filterHeight; + const filterWidth = convInfo.filterWidth; + const padLeft = filterWidth - 1 - convInfo.padInfo.left; + const padTop = filterHeight - 1 - convInfo.padInfo.top; const [dyRows, dyCols, depth] = dy.shape; - - // Dilate the input. - const dyRowsDilated = (dyRows - 1) * origStride + 1; - const dxColsDilated = (dyCols - 1) * origStride + 1; - - const outputShape = conv_util.computeOutputShape3D( - [dyRowsDilated, dxColsDilated, depth], fSize, depth, 1, pad); - const dx = Array3D.zeros(outputShape); + const dx = Array3D.zeros(x.shape); for (let d = 0; d < depth; ++d) { for (let dxR = 0; dxR < dx.shape[0]; ++dxR) { for (let dxC = 0; dxC < dx.shape[1]; ++dxC) { // Shader code begins. - const dyRCorner = dxR - pad; - const dyCCorner = dxC - pad; + const dyRCorner = dxR - padTop; + const dyCCorner = dxC - padLeft; let dotProd = 0; - for (let wR = 0; wR < fSize; ++wR) { - const dyR = (dyRCorner + wR) / origStride; + for (let wR = 0; wR < filterHeight; ++wR) { + const dyR = (dyRCorner + wR) / strideHeight; if (dyR < 0 || dyR >= dyRows || Math.floor(dyR) !== dyR) { continue; } - for (let wC = 0; wC < fSize; ++wC) { - const dyC = (dyCCorner + wC) / origStride; + for (let wC = 0; wC < filterWidth; ++wC) { + const dyC = (dyCCorner + wC) / strideWidth; if (dyC < 0 || dyC >= dyCols || Math.floor(dyC) !== dyC) { continue; } - const maxPos = fSize * fSize - 1 - maxPositions.get(dyR, dyC, d); - const curPos = wR * fSize + wC; + const maxPos = filterHeight * filterWidth - 1 - + maxPositions.get(dyR, dyC, d); + const curPos = wR * filterWidth + wC; const mask = maxPos === curPos ? 1 : 0; if (mask === 0) { @@ -754,14 +691,12 @@ export class NDArrayMathCPU extends NDArrayMath { return dx; } - protected minPoolInternal( - x: Array3D, fSize: number, stride: number, pad: number): Array3D { - return this.pool(x, fSize, stride, pad, 'min'); + protected minPoolInternal(x: Array3D, convInfo: ConvInfo): Array3D { + return this.pool(x, convInfo, 'min'); } - protected avgPoolInternal( - x: Array3D, fSize: number, stride: number, pad: number): Array3D { - return this.pool(x, fSize, stride, pad, 'avg'); + protected avgPoolInternal(x: Array3D, convInfo: ConvInfo): Array3D { + return this.pool(x, convInfo, 'avg'); } protected resizeBilinear3DInternal( diff --git a/src/math/math_gpu.ts b/src/math/math_gpu.ts index 7af86a1358..00a68c1fbf 100644 --- a/src/math/math_gpu.ts +++ b/src/math/math_gpu.ts @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +import {ConvInfo} from './conv_util'; import {MatrixOrientation, NDArrayMath} from './math'; import * as ndarray from './ndarray'; import {Array1D, Array2D, Array3D, Array4D, NDArray, Scalar} from './ndarray'; @@ -23,7 +24,7 @@ import {BatchNormProgram} from './webgl/batchnorm_gpu'; import {BinaryOpProgram} from './webgl/binaryop_gpu'; import {Concat3DProgram} from './webgl/concat3d_gpu'; // tslint:disable-next-line:max-line-length -import {Conv2DDerBiasProgram, Conv2DDerWeightsProgram, Conv2DTransposeProgram} from './webgl/conv_backprop_gpu'; +import {Conv2DDerBiasProgram, Conv2DDerInputProgram, Conv2DDerWeightsProgram} from './webgl/conv_backprop_gpu'; import {Conv2DProgram} from './webgl/conv_gpu'; import {Copy2DProgram} from './webgl/copy_gpu'; import {GPGPUContext} from './webgl/gpgpu_context'; @@ -279,84 +280,54 @@ export class NDArrayMathGPU extends NDArrayMath { } protected conv2dInternal( - x: Array3D, weights: Array4D, bias: Array1D|null, stride: number, - zeroPad: number): Array3D { - const fieldSize = weights.shape[0]; - const outputDepth = weights.shape[3]; - const program = new Conv2DProgram( - x.shape, fieldSize, outputDepth, stride, zeroPad, bias != null); - const inputs = bias != null ? [x, weights, bias] : [x, weights]; + x: Array3D, filter: Array4D, bias: Array1D|null, + convInfo: ConvInfo): Array3D { + const program = new Conv2DProgram(convInfo, bias != null); + const inputs = bias != null ? [x, filter, bias] : [x, filter]; return this.compileAndRun(program, inputs); } - protected conv2dBackPropInternal( - x: Array3D, dy: Array3D, weights: Array4D, stride: number, - pad: number): {dx: Array3D, dw: Array4D, db: Array1D} { - const fSize = weights.shape[0]; - const dw = this.conv2dDerWeights(x, dy, fSize, stride, pad); - const db = this.conv2dDerBias(dy); - const dx = this.conv2dTransposeInternal( - dy, weights, null /** biases */, stride, pad); - return {dx, db, dw}; - } - - protected conv2dTransposeInternal( - x: Array3D, weights: Array4D, bias: Array1D|null, origStride: number, - origPad: number): Array3D { - const origInputDepth = weights.shape[2]; - const fieldSize = weights.shape[0]; - const program = new Conv2DTransposeProgram( - x.shape, fieldSize, origInputDepth, origStride, origPad, bias != null); - const inputs = bias != null ? [x, weights, bias] : [x, weights]; - return this.compileAndRun(program, inputs); + protected conv2dDerInputInternal( + dy: Array3D, filter: Array4D, convInfo: ConvInfo): Array3D { + const program = new Conv2DDerInputProgram(convInfo); + return this.compileAndRun(program, [dy, filter]); } - conv2dDerWeights( - x: Array3D, dY: Array3D, fSize: number, stride: number, - zeroPad: number): Array4D { - const outputDepth = dY.shape[2]; - const program = new Conv2DDerWeightsProgram( - x.shape, fSize, outputDepth, stride, zeroPad); + protected conv2dDerFilterInternal( + x: Array3D, dY: Array3D, convInfo: ConvInfo): Array4D { + const program = new Conv2DDerWeightsProgram(convInfo); return this.compileAndRun(program, [x, dY]); } - conv2dDerBias(dY: Array3D): Array1D { + protected conv2dDerBiasInternal(dY: Array3D): Array1D { const program = new Conv2DDerBiasProgram(dY.shape); return this.compileAndRun(program, [dY]); } - protected maxPoolInternal( - x: Array3D, fSize: number, stride: number, pad: number): Array3D { - const program = - new Pool2DProgram(x.shape, fSize, stride, pad, 'max', false); + protected maxPoolInternal(x: Array3D, convInfo: ConvInfo): Array3D { + const program = new Pool2DProgram(convInfo, 'max', false); return this.compileAndRun(program, [x]); } - protected minPoolInternal( - x: Array3D, fSize: number, stride: number, pad: number): Array3D { - const program = - new Pool2DProgram(x.shape, fSize, stride, pad, 'min', false); + protected minPoolInternal(x: Array3D, convInfo: ConvInfo): Array3D { + const program = new Pool2DProgram(convInfo, 'min', false); return this.compileAndRun(program, [x]); } - protected avgPoolInternal( - x: Array3D, fSize: number, stride: number, pad: number): Array3D { - const program = - new Pool2DProgram(x.shape, fSize, stride, pad, 'avg', false); + protected avgPoolInternal(x: Array3D, convInfo: ConvInfo): Array3D { + const program = new Pool2DProgram(convInfo, 'avg', false); return this.compileAndRun(program, [x]); } protected maxPoolBackpropInternal( - dy: Array3D, x: Array3D, fSize: number, origStride: number, - origPad: number): Array3D { + dy: Array3D, x: Array3D, convInfo: ConvInfo): Array3D { const getPositions = true; - const maxPoolPositionsProgram = new Pool2DProgram( - x.shape, fSize, origStride, origPad, 'max', getPositions); + const maxPoolPositionsProgram = + new Pool2DProgram(convInfo, 'max', getPositions); const maxPoolPositions: Array3D = this.compileAndRun(maxPoolPositionsProgram, [x]); - const maxPoolBackPropProgram = - new MaxPool2DBackpropProgram(dy.shape, fSize, origStride, origPad); + const maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo); const result = this.compileAndRun(maxPoolBackPropProgram, [dy, maxPoolPositions]); diff --git a/src/math/math_gpu_test.ts b/src/math/math_gpu_test.ts index cb7fc18017..68b138e854 100644 --- a/src/math/math_gpu_test.ts +++ b/src/math/math_gpu_test.ts @@ -1656,10 +1656,9 @@ describe('NDArrayMathGPU conv2dTranspose', () => { const x = Array3D.new(inputShape, [2]); const w = Array4D.new( [fSize, fSize, origInputDepth, origOutputDepth], [3, 1, 5, 0]); - const b = Array1D.new([1]); - const result = math.conv2dTranspose(x, w, b, origStride, origPad); - const expected = new Float32Array([7, 3, 11, 1]); + const result = math.conv2dTranspose(x, w, [2, 2, 1], origStride, origPad); + const expected = new Float32Array([6, 2, 10, 0]); expect(result.inGPU()).toBe(true); expect(result.shape).toEqual([2, 2, 1]); @@ -1667,7 +1666,6 @@ describe('NDArrayMathGPU conv2dTranspose', () => { x.dispose(); w.dispose(); - b.dispose(); }); it('throws when x is not rank 3', () => { @@ -1681,14 +1679,12 @@ describe('NDArrayMathGPU conv2dTranspose', () => { const x: any = Array2D.new([2, 1], [2, 2]); const w = Array4D.new( [fSize, fSize, origInputDepth, origOutputDepth], [3, 1, 5, 0]); - const b = Array1D.new([1]); - expect(() => math.conv2dTranspose(x, w, b, origStride, origPad)) + expect(() => math.conv2dTranspose(x, w, [2, 2, 1], origStride, origPad)) .toThrowError(); x.dispose(); w.dispose(); - b.dispose(); }); it('throws when weights is not rank 4', () => { @@ -1702,36 +1698,12 @@ describe('NDArrayMathGPU conv2dTranspose', () => { const x = Array3D.new(inputShape, [2]); // tslint:disable-next-line:no-any const w: any = Array3D.new([fSize, fSize, origInputDepth], [3, 1, 5, 0]); - const b = Array1D.new([1]); - - expect(() => math.conv2dTranspose(x, w, b, origStride, origPad)) - .toThrowError(); - - x.dispose(); - w.dispose(); - b.dispose(); - }); - - it('throws when biases is not rank 1', () => { - const origInputDepth = 1; - const origOutputDepth = 1; - const inputShape: [number, number, number] = [1, 1, origOutputDepth]; - const fSize = 2; - const origPad = 0; - const origStride = 1; - - const x = Array3D.new(inputShape, [2]); - const w = Array4D.new( - [fSize, fSize, origInputDepth, origOutputDepth], [3, 1, 5, 0]); - // tslint:disable-next-line:no-any - const b: any = Array2D.new([2, 1], [1, 2]); - expect(() => math.conv2dTranspose(x, w, b, origStride, origPad)) + expect(() => math.conv2dTranspose(x, w, [2, 2, 1], origStride, origPad)) .toThrowError(); x.dispose(); w.dispose(); - b.dispose(); }); it('throws when x depth does not match weights original output depth', () => { @@ -1746,14 +1718,12 @@ describe('NDArrayMathGPU conv2dTranspose', () => { const x = Array3D.new(inputShape, [2, 2]); const w = NDArray.randNormal( [fSize, fSize, origInputDepth, wrongOrigOutputDepth]); - const b = Array1D.new([1]); - expect(() => math.conv2dTranspose(x, w, b, origStride, origPad)) + expect(() => math.conv2dTranspose(x, w, [2, 2, 2], origStride, origPad)) .toThrowError(); x.dispose(); w.dispose(); - b.dispose(); }); }); @@ -1777,12 +1747,13 @@ describe('NDArrayMathGPU conv2dDerWeights', () => { const stride = 1; const pad = 0; - const weightsShape = [fSize, fSize, inputDepth, outputDepth]; + const weightsShape: [number, number, number, number] = + [fSize, fSize, inputDepth, outputDepth]; const x = Array3D.new(inputShape, [1, 2, 3, 4, 5, 6, 7, 8, 9]); const dy = Array3D.new([2, 2, 1], [3, 1, 2, 0]); - const result = math.conv2dDerWeights(x, dy, fSize, stride, pad); + const result = math.conv2dDerFilter(x, dy, weightsShape, stride, pad); const expected = new Float32Array([13, 19, 31, 37]); expect(result.inGPU()).toBe(true); diff --git a/src/math/webgl/conv_backprop_gpu.ts b/src/math/webgl/conv_backprop_gpu.ts index 77dc4eb45d..3f798065de 100644 --- a/src/math/webgl/conv_backprop_gpu.ts +++ b/src/math/webgl/conv_backprop_gpu.ts @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ import * as conv_util from '../conv_util'; +import {ConvInfo} from '../conv_util'; import {GPGPUProgram} from './gpgpu_math'; export class Conv2DDerWeightsProgram implements GPGPUProgram { @@ -22,18 +23,18 @@ export class Conv2DDerWeightsProgram implements GPGPUProgram { outputShape: number[]; userCode: string; - constructor( - xShape: [number, number, number], fSize: number, outputDepth: number, - stride: number, zeroPad: number) { - const yShape = conv_util.computeOutputShape3D( - xShape, fSize, outputDepth, stride, zeroPad); - const yNumRows = yShape[0]; - const yNumCols = yShape[1]; - const xNumRows = xShape[0]; - const xNumCols = xShape[1]; - this.outputShape = - conv_util.computeWeightsShape4D(xShape[2], outputDepth, fSize); - this.params = [stride, zeroPad]; + constructor(convInfo: ConvInfo) { + const [yNumRows, yNumCols, outDepth] = convInfo.outShape; + const [xNumRows, xNumCols, inDepth] = convInfo.inShape; + const strideHeight = convInfo.strideHeight; + const strideWidth = convInfo.strideWidth; + this.outputShape = conv_util.computeWeightsShape4D( + inDepth, outDepth, convInfo.filterHeight, convInfo.filterWidth); + const padTop = convInfo.padInfo.top; + const padLeft = convInfo.padInfo.left; + + this.params = [strideHeight, strideWidth, padLeft, padTop]; + this.userCode = ` void main() { ivec4 coords = getOutputCoords(); @@ -46,14 +47,14 @@ export class Conv2DDerWeightsProgram implements GPGPUProgram { // ? = to be determined. : = across all values in that axis. float dotProd = 0.0; for (int yR = 0; yR < ${yNumRows}; yR++) { - int xR = wR + yR * ${stride} - ${zeroPad}; + int xR = wR + yR * ${strideHeight} - ${padTop}; if (xR < 0 || xR >= ${xNumRows}) { continue; } for (int yC = 0; yC < ${yNumCols}; yC++) { - int xC = wC + yC * ${stride} - ${zeroPad}; + int xC = wC + yC * ${strideWidth} - ${padLeft}; if (xC < 0 || xC >= ${xNumCols}) { continue; @@ -70,69 +71,66 @@ export class Conv2DDerWeightsProgram implements GPGPUProgram { } } -export class Conv2DTransposeProgram implements GPGPUProgram { - variableNames = ['x', 'W', 'bias']; +export class Conv2DDerInputProgram implements GPGPUProgram { + variableNames = ['dy', 'W']; params: Array<{}>; outputShape: number[]; userCode: string; - constructor( - xShape: [number, number, number], fSize: number, origInputDepth: number, - origStride: number, origPad: number, hasBias: boolean) { - const [xRows, xCols, origOutputDepth] = xShape; - const biasSnippet = hasBias ? 'dotProd += getBias(d2);' : ''; - - // Figure out the output shape by dilating the input. - const xRowsDilated = (xRows - 1) * origStride + 1; - const xColsDilated = (xCols - 1) * origStride + 1; - const pad = fSize - 1 - origPad; - this.outputShape = conv_util.computeOutputShape3D( - [xRowsDilated, xColsDilated, origOutputDepth], fSize, origInputDepth, 1, - pad); - this.params = [pad, fSize, origStride, hasBias]; + constructor(convInfo: ConvInfo) { + const [yRows, yCols, outDepth] = convInfo.outShape; + + this.outputShape = convInfo.inShape; + const filterHeight = convInfo.filterHeight; + const filterWidth = convInfo.filterWidth; + const strideHeight = convInfo.strideHeight; + const strideWidth = convInfo.strideWidth; + + const padTop = filterHeight - 1 - convInfo.padInfo.top; + const padLeft = filterWidth - 1 - convInfo.padInfo.left; + this.params = [strideHeight, strideWidth, padLeft, padTop]; this.userCode = ` - const ivec2 pads = ivec2(${pad}, ${pad}); + const ivec2 pads = ivec2(${padTop}, ${padLeft}); void main() { ivec3 coords = getOutputCoords(); - int d2 = coords.z; + int d1 = coords.z; - ivec2 xRCCorner = coords.xy - pads; - int xRCorner = xRCCorner.x; - int xCCorner = xRCCorner.y; + ivec2 dyCorner = coords.xy - pads; + int dyRCorner = dyCorner.x; + int dyCCorner = dyCorner.y; - // Convolve x(?, ?, d1) with w(:, :, d2, d1) to get y(yR, yC, d2). + // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. float dotProd = 0.0; - for (int wR = 0; wR < ${fSize}; wR++) { - float xR = float(xRCorner + wR) / ${origStride}.0; + for (int wR = 0; wR < ${filterHeight}; wR++) { + float dyR = float(dyRCorner + wR) / ${strideHeight}.0; - if (xR < 0.0 || xR >= ${xRows}.0 || fract(xR) > 0.0) { + if (dyR < 0.0 || dyR >= ${yRows}.0 || fract(dyR) > 0.0) { continue; } - int ixR = int(xR); + int idyR = int(dyR); - int wRPerm = ${fSize} - 1 - wR; + int wRPerm = ${filterHeight} - 1 - wR; - for (int wC = 0; wC < ${fSize}; wC++) { - float xC = float(xCCorner + wC) / ${origStride}.0; + for (int wC = 0; wC < ${filterWidth}; wC++) { + float dyC = float(dyCCorner + wC) / ${strideWidth}.0; - if (xC < 0.0 || xC >= ${xCols}.0 || fract(xC) > 0.0) { + if (dyC < 0.0 || dyC >= ${yCols}.0 || fract(dyC) > 0.0) { continue; } - int ixC = int(xC); + int idyC = int(dyC); - int wCPerm = ${fSize} - 1 - wC; + int wCPerm = ${filterWidth} - 1 - wC; - for (int d1 = 0; d1 < ${origOutputDepth}; d1++) { - float xValue = getX(ixR, ixC, d1); - float wValue = getW(wRPerm, wCPerm, d2, d1); + for (int d2 = 0; d2 < ${outDepth}; d2++) { + float xValue = getDy(idyR, idyC, d2); + float wValue = getW(wRPerm, wCPerm, d1, d2); dotProd += xValue * wValue; } } } - ${biasSnippet} setOutput(dotProd); } `; diff --git a/src/math/webgl/conv_backprop_gpu_derweights_test.ts b/src/math/webgl/conv_backprop_gpu_derweights_test.ts index 33e93cf9bb..e18c171f2a 100644 --- a/src/math/webgl/conv_backprop_gpu_derweights_test.ts +++ b/src/math/webgl/conv_backprop_gpu_derweights_test.ts @@ -34,10 +34,11 @@ describe('conv_gpu derWeights', () => { gpgpu.enableAutomaticDebugValidation(true); const outputDepth = dy.shape[2]; const inDepth = x.shape[2]; - const program = new Conv2DDerWeightsProgram( - x.shape, fSize, outputDepth, stride, zeroPad); + const convInfo = conv_util.computeConvInfo( + x.shape, fSize, fSize, outputDepth, stride, stride, zeroPad); + const program = new Conv2DDerWeightsProgram(convInfo); const out = Array4D.zeros( - conv_util.computeWeightsShape4D(inDepth, outputDepth, fSize)); + conv_util.computeWeightsShape4D(inDepth, outputDepth, fSize, fSize)); const binary = gpgpu_math.compileProgram(gpgpu, program, [x, dy], out); gpgpu_math.runProgram(binary, [x, dy], out); const result = out.getValues(); @@ -50,15 +51,17 @@ describe('conv_gpu derWeights', () => { } function compareToCPU( - inputShape: [number, number, number], fSize: number, outputDepth: number, + inputShape: [number, number, number], fSize: number, outDepth: number, stride: number, zeroPad: number) { const x = NDArray.randNormal(inputShape); const outputShape = conv_util.computeOutputShape3D( - x.shape, fSize, outputDepth, stride, zeroPad); + x.shape, fSize, outDepth, stride, zeroPad); const dy = NDArray.randNormal(outputShape); const mathCPU = new NDArrayMathCPU(); - const dwCPU = mathCPU.conv2dDerWeights(x, dy, fSize, stride, zeroPad); + const inDepth = x.shape[2]; + const dwCPU = mathCPU.conv2dDerFilter( + x, dy, [fSize, fSize, inDepth, outDepth], stride, zeroPad); const dwGPU = uploadDerWeightsDownload(x, dy, fSize, stride, zeroPad); test_util.expectArraysClose(dwGPU, dwCPU.getValues(), 1e-5); diff --git a/src/math/webgl/conv_backprop_transpose_gpu_test.ts b/src/math/webgl/conv_backprop_transpose_gpu_test.ts index 0227f80ab4..d2f6a099cf 100644 --- a/src/math/webgl/conv_backprop_transpose_gpu_test.ts +++ b/src/math/webgl/conv_backprop_transpose_gpu_test.ts @@ -14,10 +14,11 @@ limitations under the License. ==============================================================================*/ import * as test_util from '../../test_util'; +import * as conv_util from '../conv_util'; import {NDArrayMathCPU} from '../math_cpu'; -import {Array1D, Array3D, Array4D, initializeGPU, NDArray} from '../ndarray'; +import {Array3D, Array4D, initializeGPU, NDArray} from '../ndarray'; -import {Conv2DTransposeProgram} from './conv_backprop_gpu'; +import {Conv2DDerInputProgram} from './conv_backprop_gpu'; import {GPGPUContext} from './gpgpu_context'; import * as gpgpu_math from './gpgpu_math'; import {TextureManager} from './texture_manager'; @@ -25,17 +26,22 @@ import {TextureManager} from './texture_manager'; describe('conv_gpu transpose', () => { function uploadConvTransposeDownload( - x: Array3D, W: Array4D, bias: Array1D|null, fSize: number, - origStride: number, origPad: number): Float32Array { + x: Array3D, W: Array4D, origInputShape: [number, number, number], + fSize: number, origStride: number, origPad: number): Float32Array { const gpgpu = new GPGPUContext(); gpgpu.enableAutomaticDebugValidation(true); const textureManager = new TextureManager(gpgpu); initializeGPU(gpgpu, textureManager); - const origInputDepth = W.shape[2]; - const program = new Conv2DTransposeProgram( - x.shape, fSize, origInputDepth, origStride, origPad, bias != null); + + const filterHeight = W.shape[0]; + const filterWidth = W.shape[1]; + const origOutDepth = W.shape[3]; + const convInfo = conv_util.computeConvInfo( + origInputShape, filterHeight, filterWidth, origOutDepth, origStride, + origStride, origPad); + const program = new Conv2DDerInputProgram(convInfo); const res = NDArray.zeros(program.outputShape); - const inputs = bias != null ? [x, W, bias] : [x, W]; + const inputs = [x, W]; const binary = gpgpu_math.compileProgram(gpgpu, program, inputs, res); gpgpu_math.runProgram(binary, inputs, res); const resValues = res.getValues(); @@ -49,31 +55,32 @@ describe('conv_gpu transpose', () => { function compareToCPU( origInputShape: [number, number, number], fSize: number, origOutputDepth: number, origStride: number, origPad: number) { - const [xNumRows, xNumCols, origInputDepth] = origInputShape; + const origInputDepth = origInputShape[2]; - const x = - NDArray.randNormal([xNumRows, xNumCols, origOutputDepth]); + const convInfo = conv_util.computeConvInfo( + origInputShape, fSize, fSize, origOutputDepth, origStride, origStride, + origPad); + const x = NDArray.randNormal(convInfo.outShape); const weights = NDArray.randNormal( [fSize, fSize, origInputDepth, origOutputDepth]); - const biases = NDArray.randNormal([origInputDepth]); const mathCPU = new NDArrayMathCPU(); - const yCPU = - mathCPU.conv2dTranspose(x, weights, biases, origStride, origPad); + const yCPU = mathCPU.conv2dTranspose( + x, weights, origInputShape, origStride, origPad); const yGPU = uploadConvTransposeDownload( - x, weights, biases, fSize, origStride, origPad); + x, weights, origInputShape, fSize, origStride, origPad); test_util.expectArraysClose(yGPU, yCPU.getValues(), 1e-5); } it('matches CPU on random input, d1=1,d2=1,f=2,s=1,p=0', () => { const inputDepth = 1; - const inputShape: [number, number, number] = [8, 8, inputDepth]; + const origInputShape: [number, number, number] = [8, 8, inputDepth]; const fSize = 2; const outputDepth = 1; const stride = 1; const zeroPad = 0; - compareToCPU(inputShape, fSize, outputDepth, stride, zeroPad); + compareToCPU(origInputShape, fSize, outputDepth, stride, zeroPad); }); it('matches CPU on random input, d1=1,d2=1,f=3,s=2,p=1', () => { diff --git a/src/math/webgl/conv_gpu.ts b/src/math/webgl/conv_gpu.ts index 4565206861..41a3f89199 100644 --- a/src/math/webgl/conv_gpu.ts +++ b/src/math/webgl/conv_gpu.ts @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import * as conv_util from '../conv_util'; +import {ConvInfo} from '../conv_util'; import {GPGPUProgram} from './gpgpu_math'; export class Conv2DProgram implements GPGPUProgram { @@ -22,19 +22,22 @@ export class Conv2DProgram implements GPGPUProgram { outputShape: number[]; userCode: string; - constructor( - xShape: [number, number, number], fieldSize: number, outputDepth: number, - stride: number, pad: number, hasBias: boolean) { - this.outputShape = conv_util.computeOutputShape3D( - xShape, fieldSize, outputDepth, stride, pad); - const inputDepth = xShape[2]; - this.params = [fieldSize, stride, pad, hasBias]; + constructor(convInfo: ConvInfo, hasBias: boolean) { + this.outputShape = convInfo.outShape; const biasSnippet = hasBias ? 'dotProd += getBias(d2);' : ''; - const xNumRows = xShape[0]; - const xNumCols = xShape[1]; + const [xNumRows, xNumCols, inputDepth] = convInfo.inShape; + const padTop = convInfo.padInfo.top; + const padLeft = convInfo.padInfo.left; + const strideHeight = convInfo.strideHeight; + const strideWidth = convInfo.strideWidth; + const filterHeight = convInfo.filterHeight; + const filterWidth = convInfo.filterWidth; + + this.params = [strideHeight, strideWidth, hasBias, padLeft, padTop]; + this.userCode = ` - const ivec2 strides = ivec2(${stride}, ${stride}); - const ivec2 pads = ivec2(${pad}, ${pad}); + const ivec2 strides = ivec2(${strideHeight}, ${strideWidth}); + const ivec2 pads = ivec2(${padTop}, ${padLeft}); void main() { ivec3 coords = getOutputCoords(); @@ -47,14 +50,14 @@ export class Conv2DProgram implements GPGPUProgram { // Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2). // ? = to be determined. : = across all values in that axis. float dotProd = 0.0; - for (int wR = 0; wR < ${fieldSize}; wR++) { + for (int wR = 0; wR < ${filterHeight}; wR++) { int xR = xRCorner + wR; if (xR < 0 || xR >= ${xNumRows}) { continue; } - for (int wC = 0; wC < ${fieldSize}; wC++) { + for (int wC = 0; wC < ${filterWidth}; wC++) { int xC = xCCorner + wC; if (xC < 0 || xC >= ${xNumCols}) { diff --git a/src/math/webgl/conv_gpu_test.ts b/src/math/webgl/conv_gpu_test.ts index ef33bca443..50721b5cd3 100644 --- a/src/math/webgl/conv_gpu_test.ts +++ b/src/math/webgl/conv_gpu_test.ts @@ -26,16 +26,18 @@ import {TextureManager} from './texture_manager'; describe('conv_gpu', () => { function uploadConvolveDownload( - xVals: Float32Array, xShapeRCD: [number, number, number], - weights: Float32Array, biasVals: Float32Array|null, resultDepth: number, - fieldSize: number, stride: number, zeroPad?: number): Float32Array { - zeroPad = zeroPad != null ? - zeroPad : - conv_util.computeDefaultPad(xShapeRCD, fieldSize, stride); - - const x = Array3D.new(xShapeRCD, xVals); - const wShape = - conv_util.computeWeightsShape4D(xShapeRCD[2], resultDepth, fieldSize); + xVals: Float32Array, xShape: [number, number, number], + weights: Float32Array, biasVals: Float32Array|null, outDepth: number, + filterSizes: [number, number]|number, strides: [number, number]|number, + zeroPad?: number|'valid'|'same'): Float32Array { + zeroPad = zeroPad != null ? zeroPad : 'same'; + + const [filterHeight, filterWidth] = parseTuple(filterSizes); + const [strideHeight, strideWidth] = parseTuple(strides); + + const x = Array3D.new(xShape, xVals); + const wShape = conv_util.computeWeightsShape4D( + xShape[2], outDepth, filterHeight, filterWidth); const W = Array4D.new(wShape, weights); const b = biasVals != null ? Array1D.new(biasVals) : null; @@ -44,8 +46,10 @@ describe('conv_gpu', () => { const textureManager = new TextureManager(gpgpu); initializeGPU(gpgpu, textureManager); - const program = new Conv2DProgram( - xShapeRCD, fieldSize, resultDepth, stride, zeroPad, biasVals != null); + const convInfo = conv_util.computeConvInfo( + xShape, filterHeight, filterWidth, outDepth, strideHeight, strideWidth, + zeroPad); + const program = new Conv2DProgram(convInfo, biasVals != null); const res = NDArray.zeros(program.outputShape); const inputs = biasVals != null ? [x, W, b] : [x, W]; const binary = gpgpu_math.compileProgram(gpgpu, program, inputs, res); @@ -231,6 +235,24 @@ describe('conv_gpu', () => { expect(result[3]).toBe(12); }); + it('2x2x1 in, 1d out, 2x1 filter, s=1, p=valid', () => { + const x = new Float32Array([1, 2, 3, 4]); + const w = new Float32Array([3, 5]); + const bias: Float32Array = null; + const result = + uploadConvolveDownload(x, [2, 2, 1], w, bias, 1, [2, 1], 1, 'valid'); + expect(result).toEqual(new Float32Array([18, 26])); + }); + + it('2x2x1 in, 1d out, 1x2 filter, s=1, p=valid', () => { + const x = new Float32Array([1, 2, 3, 4]); + const w = new Float32Array([3, 5]); + const bias: Float32Array = null; + const result = + uploadConvolveDownload(x, [2, 2, 1], w, bias, 1, [1, 2], 1, 'valid'); + expect(result).toEqual(new Float32Array([13, 29])); + }); + it('2x2x1 in, 1d out, 2x2 filter, 1 stride, bias=-1', () => { const x = new Float32Array([1, 2, 3, 4]); const w = new Float32Array([3, 1, 5, 0]); @@ -367,3 +389,7 @@ describe('conv_gpu', () => { compareToCPU(inputShape, fSize, outputDepth, stride, zeroPad); }); }); + +function parseTuple(a: number|[number, number]): [number, number] { + return typeof a === 'number' ? [a, a] : a; +} diff --git a/src/math/webgl/max_pool_backprop_gpu.ts b/src/math/webgl/max_pool_backprop_gpu.ts index 083e2d7399..5cc4af7a73 100644 --- a/src/math/webgl/max_pool_backprop_gpu.ts +++ b/src/math/webgl/max_pool_backprop_gpu.ts @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import * as conv_util from '../conv_util'; +import {ConvInfo} from '../conv_util'; import {GPGPUProgram} from './gpgpu_math'; @@ -23,23 +23,23 @@ export class MaxPool2DBackpropProgram implements GPGPUProgram { outputShape: number[]; userCode: string; - constructor( - dyShape: [number, number, number], fSize: number, origStride: number, - origPad: number) { - const pad = fSize - 1 - origPad; - const dyRows = dyShape[0]; - const dyCols = dyShape[1]; - this.params = [fSize, origStride, origPad]; - - const dilatedDyRC = - conv_util.computeDilatedRC([dyShape[0], dyShape[1]], origStride); - this.outputShape = conv_util.computeOutputShape3D( - [dilatedDyRC[0], dilatedDyRC[1], dyShape[2]], fSize, dyShape[2], 1, - pad); - - const lastIndex = fSize * fSize - 1; + constructor(convInfo: ConvInfo) { + this.outputShape = convInfo.inShape; + const dyRows = convInfo.outShape[0]; + const dyCols = convInfo.outShape[1]; + const filterHeight = convInfo.filterHeight; + const filterWidth = convInfo.filterWidth; + const strideHeight = convInfo.strideHeight; + const strideWidth = convInfo.strideWidth; + + const padTop = filterHeight - 1 - convInfo.padInfo.top; + const padLeft = filterWidth - 1 - convInfo.padInfo.left; + this.params = + [filterHeight, filterWidth, strideHeight, strideWidth, padTop, padLeft]; + + const lastIndex = filterHeight * filterWidth - 1; this.userCode = ` - const ivec2 pads = ivec2(${pad}, ${pad}); + const ivec2 pads = ivec2(${padTop}, ${padLeft}); void main() { ivec3 coords = getOutputCoords(); @@ -52,16 +52,16 @@ export class MaxPool2DBackpropProgram implements GPGPUProgram { // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d). // ? = to be determined. : = across all values in that axis. float dotProd = 0.0; - for (int wR = 0; wR < ${fSize}; wR++) { - float dyR = float(dyRCorner + wR) / ${origStride}.0; + for (int wR = 0; wR < ${filterHeight}; wR++) { + float dyR = float(dyRCorner + wR) / ${strideHeight}.0; if (dyR < 0.0 || dyR >= ${dyRows}.0 || fract(dyR) > 0.0) { continue; } int idyR = int(dyR); - for (int wC = 0; wC < ${fSize}; wC++) { - float dyC = float(dyCCorner + wC) / ${origStride}.0; + for (int wC = 0; wC < ${filterWidth}; wC++) { + float dyC = float(dyCCorner + wC) / ${strideWidth}.0; if (dyC < 0.0 || dyC >= ${dyCols}.0 || fract(dyC) > 0.0) { continue; @@ -73,7 +73,7 @@ export class MaxPool2DBackpropProgram implements GPGPUProgram { // Get the current value, check it against the value from the // position matrix. - int curPosValue = wR * ${fSize} + wC; + int curPosValue = wR * ${filterWidth} + wC; float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0); dotProd += dyValue * mask; diff --git a/src/math/webgl/max_pool_backprop_gpu_test.ts b/src/math/webgl/max_pool_backprop_gpu_test.ts index e674dd84d0..7d7164e91d 100644 --- a/src/math/webgl/max_pool_backprop_gpu_test.ts +++ b/src/math/webgl/max_pool_backprop_gpu_test.ts @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ import * as test_util from '../../test_util'; +import * as conv_util from '../conv_util'; import {NDArrayMathCPU} from '../math_cpu'; import {Array3D, initializeGPU, NDArray} from '../ndarray'; - import {GPGPUContext} from './gpgpu_context'; import * as gpgpu_math from './gpgpu_math'; import {MaxPool2DBackpropProgram} from './max_pool_backprop_gpu'; @@ -34,15 +34,15 @@ describe('max_pool_backprop_gpu', () => { initializeGPU(gpgpu, textureManager); const getPositions = true; - const positionsProgram = new Pool2DProgram( - x.shape, fSize, origStride, origPad, 'max', getPositions); + const outDepth = x.shape[2]; + const convInfo = conv_util.computeConvInfo( + x.shape, fSize, fSize, outDepth, origStride, origStride, origPad); + const positionsProgram = new Pool2DProgram(convInfo, 'max', getPositions); const positionsRes = NDArray.zeros(positionsProgram.outputShape); const positionsBinary = gpgpu_math.compileProgram(gpgpu, positionsProgram, [x], positionsRes); gpgpu_math.runProgram(positionsBinary, [x], positionsRes); - - const program = - new MaxPool2DBackpropProgram(dy.shape, fSize, origStride, origPad); + const program = new MaxPool2DBackpropProgram(convInfo); const res = NDArray.zeros(program.outputShape); const binary = gpgpu_math.compileProgram(gpgpu, program, [dy, positionsRes], res); diff --git a/src/math/webgl/max_pool_gpu_test.ts b/src/math/webgl/max_pool_gpu_test.ts index 8f9380ca80..266c92c624 100644 --- a/src/math/webgl/max_pool_gpu_test.ts +++ b/src/math/webgl/max_pool_gpu_test.ts @@ -20,10 +20,11 @@ import * as pool_gpu_test_util from './pool_gpu_test_util'; describe('max_pool_gpu', () => { function uploadMaxPoolDownload( - a: Float32Array, xShape: [number, number, number], fieldSize: number, - stride: number, zeroPad: number): Float32Array { + a: Float32Array, xShape: [number, number, number], + filterSizes: [number, number]|number, strides: [number, number]|number, + zeroPad: number|'valid'|'same'): Float32Array { return pool_gpu_test_util.uploadPoolDownload( - a, xShape, fieldSize, stride, zeroPad, 'max'); + a, xShape, filterSizes, strides, zeroPad, 'max'); } function compareToCPU( @@ -74,4 +75,18 @@ describe('max_pool_gpu', () => { const zeroPad = 1; compareToCPU(inputShape, fSize, stride, zeroPad); }); + + it('non even filter 1x2 on 3x3 input', () => { + const x = Array3D.new([3, 3, 1], [1, 2, 3, 4, 5, 6, 7, 8, 9]); + const res = + uploadMaxPoolDownload(x.getValues(), x.shape, [1, 2], [1, 1], 'valid'); + expect(res).toEqual(new Float32Array([2, 3, 5, 6, 8, 9])); + }); + + it('non even filter 2x1 on 3x3 input', () => { + const x = Array3D.new([3, 3, 1], [1, 2, 3, 4, 5, 6, 7, 8, 9]); + const res = + uploadMaxPoolDownload(x.getValues(), x.shape, [2, 1], [1, 1], 'valid'); + expect(res).toEqual(new Float32Array([4, 5, 6, 7, 8, 9])); + }); }); diff --git a/src/math/webgl/max_pool_positions_gpu_test.ts b/src/math/webgl/max_pool_positions_gpu_test.ts index 59ca005408..61547ac152 100644 --- a/src/math/webgl/max_pool_positions_gpu_test.ts +++ b/src/math/webgl/max_pool_positions_gpu_test.ts @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ import * as test_util from '../../test_util'; +import * as conv_util from '../conv_util'; import {NDArrayMathCPU} from '../math_cpu'; import {Array3D, initializeGPU, NDArray} from '../ndarray'; @@ -31,8 +32,10 @@ describe('max_pool_position', () => { const textureManager = new TextureManager(gpgpu); initializeGPU(gpgpu, textureManager); const getPositions = true; - const program = - new Pool2DProgram(xShape, fieldSize, stride, pad, 'max', getPositions); + const outDepth = xShape[2]; + const convInfo = conv_util.computeConvInfo( + xShape, fieldSize, fieldSize, outDepth, stride, stride, pad); + const program = new Pool2DProgram(convInfo, 'max', getPositions); const res = NDArray.zeros(program.outputShape); const x = Array3D.new(xShape, xVals); const binary = gpgpu_math.compileProgram(gpgpu, program, [x], res); @@ -51,7 +54,10 @@ describe('max_pool_position', () => { const x = NDArray.randNormal(xShape); const mathCPU = new NDArrayMathCPU(); - const yCPU = mathCPU.maxPoolPositions(x, fSize, stride, pad); + const outDepth = x.shape[2]; + const convInfo = conv_util.computeConvInfo( + x.shape, fSize, fSize, outDepth, stride, stride, pad); + const yCPU = mathCPU.maxPoolPositions(x, convInfo); const yGPU = uploadMaxPoolPositionDownload( x.getValues(), x.shape, fSize, stride, pad); test_util.expectArraysClose(yGPU, yCPU.getValues(), 1e-5); diff --git a/src/math/webgl/pool_gpu.ts b/src/math/webgl/pool_gpu.ts index fd35f1cd39..de0545e2ae 100644 --- a/src/math/webgl/pool_gpu.ts +++ b/src/math/webgl/pool_gpu.ts @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -import * as conv_util from '../conv_util'; +import {ConvInfo} from '../conv_util'; import {GPGPUProgram} from './gpgpu_math'; export class Pool2DProgram implements GPGPUProgram { @@ -23,30 +23,38 @@ export class Pool2DProgram implements GPGPUProgram { userCode: string; constructor( - xShape: [number, number, number], fSize: number, stride: number, - pad: number, poolType: 'max'|'min'|'avg', computePositions: boolean) { + convInfo: ConvInfo, poolType: 'max'|'min'|'avg', + computePositions: boolean) { if (poolType === 'avg' && computePositions) { throw new Error('Cannot compute positions for average pool.'); } + const filterHeight = convInfo.filterHeight; + const filterWidth = convInfo.filterWidth; + const strideHeight = convInfo.strideHeight; + const strideWidth = convInfo.strideWidth; + let returnValue = 'minMaxValue'; if (computePositions) { returnValue = 'float(minMaxPosition)'; } else if (poolType === 'avg') { - returnValue = `avgValue / ${fSize * fSize}.0`; + returnValue = `avgValue / ${filterHeight * filterWidth}.0`; } - const xRowsLimit = xShape[0]; - const xColsLimit = xShape[1]; - this.params = [stride, pad, fSize, poolType, computePositions]; - this.outputShape = - conv_util.computeOutputShape3D(xShape, fSize, xShape[2], stride, pad); + const xNumRows = convInfo.inShape[0]; + const xNumCols = convInfo.inShape[1]; + const padTop = convInfo.padInfo.top; + const padLeft = convInfo.padInfo.left; + this.params = [ + strideHeight, strideWidth, padLeft, padTop, poolType, computePositions + ]; + this.outputShape = convInfo.outShape; const isAvgPool = poolType === 'avg'; const compareOp = poolType === 'min' ? '<=' : '>='; this.userCode = ` - const ivec2 strides = ivec2(${stride}, ${stride}); - const ivec2 pads = ivec2(${pad}, ${pad}); + const ivec2 strides = ivec2(${strideHeight}, ${strideWidth}); + const ivec2 pads = ivec2(${padTop}, ${padLeft}); void main() { ivec3 coords = getOutputCoords(); @@ -63,17 +71,17 @@ export class Pool2DProgram implements GPGPUProgram { int minMaxPosition = 0; float avgValue = 0.0; - for (int wR = 0; wR < ${fSize}; wR++) { + for (int wR = 0; wR < ${filterHeight}; wR++) { int xR = xRCorner + wR; - if (xR < 0 || xR >= ${xRowsLimit}) { + if (xR < 0 || xR >= ${xNumRows}) { continue; } - for (int wC = 0; wC < ${fSize}; wC++) { + for (int wC = 0; wC < ${filterWidth}; wC++) { int xC = xCCorner + wC; - if (xC < 0 || xC >= ${xColsLimit}) { + if (xC < 0 || xC >= ${xNumCols}) { continue; } @@ -95,7 +103,7 @@ export class Pool2DProgram implements GPGPUProgram { minMaxValue = value; minMaxValueFound = 1.0; if (${computePositions}) { - minMaxPosition = wR * ${fSize} + wC; + minMaxPosition = wR * ${filterWidth} + wC; } } } diff --git a/src/math/webgl/pool_gpu_test_util.ts b/src/math/webgl/pool_gpu_test_util.ts index a06be73b31..e3510cd2b8 100644 --- a/src/math/webgl/pool_gpu_test_util.ts +++ b/src/math/webgl/pool_gpu_test_util.ts @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +import * as conv_util from '../conv_util'; import {Array3D, initializeGPU, NDArray} from '../ndarray'; import {GPGPUContext} from './gpgpu_context'; @@ -20,16 +21,22 @@ import {Pool2DProgram} from './pool_gpu'; import {TextureManager} from './texture_manager'; export function uploadPoolDownload( - a: Float32Array, xShape: [number, number, number], fieldSize: number, - stride: number, zeroPad: number, op: 'min'|'max'|'avg'): Float32Array { + a: Float32Array, xShape: [number, number, number], + filterSizes: [number, number]|number, strides: [number, number]|number, + zeroPad: number|'valid'|'same', op: 'min'|'max'|'avg'): Float32Array { const gpgpu = new GPGPUContext(); gpgpu.enableAutomaticDebugValidation(true); const textureManager = new TextureManager(gpgpu); initializeGPU(gpgpu, textureManager); const x = Array3D.new(xShape, a); - const program = - new Pool2DProgram(xShape, fieldSize, stride, zeroPad, op, false); + const outDepth = x.shape[2]; + const [filterHeight, filterWidth] = parseTuple(filterSizes); + const [strideHeight, strideWidth] = parseTuple(strides); + const convInfo = conv_util.computeConvInfo( + xShape, filterHeight, filterWidth, outDepth, strideHeight, strideWidth, + zeroPad); + const program = new Pool2DProgram(convInfo, op, false); const res = NDArray.zeros(program.outputShape); const binary = gpgpu_math.compileProgram(gpgpu, program, [x], res); gpgpu_math.runProgram(binary, [x], res); @@ -40,3 +47,7 @@ export function uploadPoolDownload( gpgpu.dispose(); return resValues; } + +function parseTuple(a: number|[number, number]): [number, number] { + return typeof a === 'number' ? [a, a] : a; +} diff --git a/src/ops/convolution.ts b/src/ops/convolution.ts index 22d6efb0b3..6daa72c759 100644 --- a/src/ops/convolution.ts +++ b/src/ops/convolution.ts @@ -17,7 +17,7 @@ import {Tensor} from '../graph'; import * as conv_util from '../math/conv_util'; import {NDArrayMath} from '../math/math'; import {Array1D, Array3D, Array4D} from '../math/ndarray'; -import {TensorArrayMap, SummedTensorArrayMap} from '../tensor_array_map'; +import {SummedTensorArrayMap, TensorArrayMap} from '../tensor_array_map'; import * as util from '../util'; import {Operation} from './op'; diff --git a/src/ops/convolution_test.ts b/src/ops/convolution_test.ts index aa2e79e9de..0b7d274058 100644 --- a/src/ops/convolution_test.ts +++ b/src/ops/convolution_test.ts @@ -17,7 +17,7 @@ import {Tensor} from '../graph'; import * as conv_util from '../math/conv_util'; import {NDArrayMathCPU} from '../math/math_cpu'; import {Array1D, Array2D, Array3D, Array4D, NDArray} from '../math/ndarray'; -import {TensorArrayMap, SummedTensorArrayMap} from '../tensor_array_map'; +import {SummedTensorArrayMap, TensorArrayMap} from '../tensor_array_map'; import {Convolution2D} from './convolution'; diff --git a/src/ops/max_pool_test.ts b/src/ops/max_pool_test.ts index e66c7a6491..0e7606c3a1 100644 --- a/src/ops/max_pool_test.ts +++ b/src/ops/max_pool_test.ts @@ -17,7 +17,7 @@ import {Tensor} from '../graph'; import * as conv_util from '../math/conv_util'; import {NDArrayMathCPU} from '../math/math_cpu'; import {Array3D, NDArray} from '../math/ndarray'; -import {TensorArrayMap, SummedTensorArrayMap} from '../tensor_array_map'; +import {SummedTensorArrayMap, TensorArrayMap} from '../tensor_array_map'; import * as test_util from '../test_util'; import {MaxPool} from './max_pool';