diff --git a/demos/fast-style-transfer/net.ts b/demos/fast-style-transfer/net.ts index 87656001c9..8d51717c72 100644 --- a/demos/fast-style-transfer/net.ts +++ b/demos/fast-style-transfer/net.ts @@ -77,7 +77,7 @@ export class TransformNet implements dl.Model { return convT3.tanh() .mul(this.timesScalar) .add(this.plusScalar) - .clip(0, 255) + .clipByValue(0, 255) .div(dl.scalar(255)) as dl.Tensor3D; }); diff --git a/src/gradients.ts b/src/gradients.ts index 0d06317916..60a960916f 100644 --- a/src/gradients.ts +++ b/src/gradients.ts @@ -49,6 +49,28 @@ export class Gradients { * `x` is computed instead. `f(x)` must take a single tensor `x` and return a * single tensor `y`. If `f()` takes multiple inputs, use `grads` instead. * + * ```js + * // f(x) = x ^ 2 + * const f = x => x.square(); + * // f'(x) = 2x + * const g = dl.grad(f); + * + * const x = dl.tensor1d([2, 3]); + * g(x).print(); + * ``` + * + * ```js + * // f(x) = x ^ 3 + * const f = x => x.pow(dl.scalar(3, 'int32')); + * // f'(x) = 3x ^ 2 + * const g = dl.grad(f); + * // f''(x) = 6x + * const gg = dl.grad(g); + * + * const x = dl.tensor1d([2, 3]); + * gg(x).print(); + * ``` + * * @param f The function f(x), to compute gradient for. */ @doc({heading: 'Training', subheading: 'Gradients'}) @@ -85,6 +107,21 @@ export class Gradients { * The provided `f` must take one or more tensors and return a single tensor * `y`. If `f()` takes a single input, we recommend using `grad` instead. * + * ```js + * // f(a, b) = a * b + * const f = (a, b) => a.mul(b); + * // df / da = b, df / db = a + * const g = dl.grads(f); + * + * const a = dl.tensor1d([2, 3]); + * const b = dl.tensor1d([-2, -3]); + * const [da, db] = g([a, b]); + * console.log('da'); + * da.print(); + * console.log('db'); + * db.print(); + * ``` + * * @param f The function `f(x1, x2,...)` to compute gradients for. */ @doc({heading: 'Training', subheading: 'Gradients'}) @@ -119,6 +156,21 @@ export class Gradients { * The result is a rich object with the following properties: * - grad: The gradient of `f(x)` w.r.t `x` (result of `grad`). * - value: The value returned by `f(x)`. + * + * ```js + * // f(x) = x ^ 2 + * const f = x => x.square(); + * // f'(x) = 2x + * const g = dl.valueAndGrad(f); + * + * const x = dl.tensor1d([2, 3]); + * const {value, grad} = g(x); + * + * console.log('value'); + * value.print(); + * console.log('grad'); + * grad.print(); + * ``` */ @doc({heading: 'Training', subheading: 'Gradients'}) static valueAndGrad(f: (x: I) => O): @@ -149,6 +201,27 @@ export class Gradients { * The result is a rich object with the following properties: * - grads: The gradients of `f()` w.r.t each input (result of `grads`). * - value: The value returned by `f(x)`. + * + * ```js + * // f(a, b) = a * b + * const f = (a, b) => a.mul(b); + * // df/da = b, df/db = a + * const g = dl.valueAndGrads(f); + * + * const a = dl.tensor1d([2, 3]); + * const b = dl.tensor1d([-2, -3]); + * const {value, grads} = g([a, b]); + * + * const [da, db] = grads; + * + * console.log('value'); + * value.print(); + * + * console.log('da'); + * da.print(); + * console.log('db'); + * db.print(); + * ``` */ @doc({heading: 'Training', subheading: 'Gradients'}) static valueAndGrads(f: (...args: Tensor[]) => O): @@ -183,9 +256,20 @@ export class Gradients { * trainable variables provided by `varList`. If no list is provided, it * defaults to all trainable variables. * + * ```js + * const a = dl.variable(dl.tensor1d([3, 4])); + * const b = dl.variable(dl.tensor1d([5, 6])); + * const x = dl.tensor1d([1, 2]); + * + * // f(a, b) = a * x ^ 2 + b * x + * const f = () => a.mul(x.square()).add(b.mul(x)).sum(); + * // df/da = x ^ 2, df/db = x + * const {value, grads} = dl.variableGrads(f); + * + * Object.keys(grads).forEach(varName => grads[varName].print()); + * ``` + * * @param f The function to execute. f() should return a scalar. - * @param varList An optional list of variables to provide gradients with - * respect to. Defaults to all trainable variables. */ @doc({heading: 'Training', subheading: 'Gradients'}) static variableGrads(f: () => Scalar, varList?: Variable[]): @@ -239,6 +323,21 @@ export class Gradients { * called, `g` returns `f().value`. In backward mode, custom gradients with * respect to each input of `f` are computed using `f().gradFunc`. * + * ```js + * const customOp = dl.customGrad(x => { + * // Override gradient of our custom x ^ 2 op to be dy * abs(x); + * return {value: x.square(), gradFunc: dy => [dy.mul(x.abs())]}; + * }); + * + * const x = dl.tensor1d([-1, -2, 3]); + * const dx = dl.grad(x => customOp(x)); + * + * console.log(`f(x):`); + * customOp(x).print(); + * console.log(`f'(x):`); + * dx(x).print(); + * ``` + * * @param f The function to evaluate in forward mode, which should return * `{value: Tensor, gradFunc: (dy) => Tensor[]}`, where `gradFunc` returns * the custom gradients of `f` with respect to its inputs. diff --git a/src/math.ts b/src/math.ts index 9ad583bf8b..bb22a5c29e 100644 --- a/src/math.ts +++ b/src/math.ts @@ -84,7 +84,6 @@ export class NDArrayMath { conv2dDerInput = conv.Ops.conv2dDerInput; argMax = reduction_ops.Ops.argMax; - argMaxEquals = reduction_ops.Ops.argMaxEquals; argMin = reduction_ops.Ops.argMin; logSumExp = reduction_ops.Ops.logSumExp; max = reduction_ops.Ops.max; @@ -372,6 +371,12 @@ export class NDArrayMath { const res = ops.conv2d(x, filter, strides, pad, dimRoundingMode); return res.add(bias) as T; } + + /** @deprecated */ + argMaxEquals(x1: Tensor, x2: Tensor): Scalar { + util.assertShapesMatch(x1.shape, x2.shape, 'Error in argMaxEquals: '); + return x1.argMax().equal(x2.argMax()); + } } export type ScopeFn = diff --git a/src/ops/array_ops.ts b/src/ops/array_ops.ts index 3cd52827e7..fa1285a0db 100644 --- a/src/ops/array_ops.ts +++ b/src/ops/array_ops.ts @@ -32,18 +32,18 @@ export class Ops { * * ```js * // Pass an array of values to create a vector. - * dl.tensor([1, 2, 3, 4]).print() // shape: [4] + * dl.tensor([1, 2, 3, 4]).print(); * ``` * * ```js * // Pass a nested array of values to make a matrix or a higher * // dimensional tensor. - * dl.tensor([[1, 2], [3, 4]]).print(); // shape: [2, 2] + * dl.tensor([[1, 2], [3, 4]]).print(); * ``` * * ```js * // Pass a flat array and specify a shape yourself. - * dl.tensor([1, 2, 3, 4], [2, 2]).print(); // shape: [2, 2] + * dl.tensor([1, 2, 3, 4], [2, 2]).print(); * ``` * * @param values The values of the tensor. Can be nested array of numbers, @@ -339,6 +339,10 @@ export class Ops { /** * Creates a `Tensor` with values sampled from a normal distribution. * + * ```js + * dl.randomNormal([2, 2]).print(); + * ``` + * * @param shape An array of integers defining the output tensor shape. * @param mean The mean of the normal distribution. * @param stdDev The standard deviation of the normal distribution. @@ -362,6 +366,10 @@ export class Ops { * Creates a `Tensor` with values sampled from a truncated normal * distribution. * + * ```js + * dl.truncatedNormal([2, 2]).print(); + * ``` + * * The generated values follow a normal distribution with specified mean and * standard deviation, except that values whose magnitude is more than 2 * standard deviations from the mean are dropped and re-picked. @@ -392,6 +400,10 @@ export class Ops { * maxval). The lower bound minval is included in the range, while the upper * bound maxval is excluded. * + * ```js + * dl.randomUniform([2, 2]).print(); + * ``` + * * @param shape An array of integers defining the output tensor shape. * @param minval The lower bound on the range of random values to generate. * Defaults to 0. @@ -489,6 +501,10 @@ export class Ops { * value `onValue` (defaults to 1), while all other locations take value * `offValue` (defaults to 0). * + * ```js + * dl.oneHot(dl.tensor1d([0, 1]), 3).print(); + * ``` + * * @param indices 1D Array of indices. * @param depth The depth of the one hot dimension. * @param onValue A number used to fill in output when the index matches @@ -510,6 +526,16 @@ export class Ops { /** * Creates a `Tensor` from an image. * + * ```js + * const image = new ImageData(1, 1); + * image.data[0] = 100; + * image.data[1] = 150; + * image.data[2] = 200; + * image.data[3] = 255; + * + * dl.fromPixels(image).print(); + * ``` + * * @param pixels The input image to construct the tensor from. Accepts image * of type `ImageData`, `HTMLImageElement`, `HTMLCanvasElement`, or * `HTMLVideoElement`. @@ -545,6 +571,11 @@ export class Ops { * elements implied by shape must be the same as the number of elements in * tensor. * + * ```js + * const x = dl.tensor1d([1, 2, 3, 4]); + * x.reshape([2, 2]).print(); + * ``` + * * @param x A tensor. * @param shape An array of integers defining the output tensor shape. */ @@ -567,6 +598,10 @@ export class Ops { /** * Removes dimensions of size 1 from the shape of a `Tensor`. * + * ```js + * const x = dl.tensor([1, 2, 3, 4], [1, 1, 4]); + * x.squeeze().print(); + * ``` * @param axis An optional list of numbers. If specified, only * squeezes the dimensions listed. The dimension index starts at 0. It is * an error to squeeze a dimension that is not 1. @@ -578,6 +613,11 @@ export class Ops { /** * Casts a tensor to a new dtype. + * + * ```js + * const x = dl.tensor1d([1.5, 2.5, 3]); + * dl.cast(x, 'int32').print(); + * ``` * @param x A tensor. * @param dtype The dtype to cast the input tensor to. */ @@ -600,6 +640,17 @@ export class Ops { * `reps[i]` times along the i'th dimension. For example, tiling * `[a, b, c, d]` by `[2]` produces `[a, b, c, d, a, b, c, d]`. * + * ```js + * const a = dl.tensor1d([1, 2]); + * + * a.tile([2]).print(); // or a.tile([2]) + * ``` + * + * ```js + * const a = dl.tensor2d([1, 2, 3, 4], [2, 2]); + * + * a.tile([1, 2]).print(); // or a.tile([1, 2]) + * ``` * @param x The tensor to transpose. * @param reps Determines the number of replications per dimension. */ @@ -616,6 +667,19 @@ export class Ops { /** * Gather slices from tensor `x`'s axis `axis` according to `indices`. * + * ```js + * const x = dl.tensor1d([1, 2, 3, 4]); + * const indices = dl.tensor1d([1, 3, 3]); + * + * x.gather(indices).print(); + * ``` + * + * ```js + * const x = dl.tensor2d([1, 2, 3, 4], [2, 2]); + * const indices = dl.tensor1d([1, 1, 0]); + * + * x.gather(indices).print(); + * ``` * @param x The input tensor. * @param indices The indices of the values to extract. * @param axis The axis over which to select values. Defaults to 0. @@ -678,6 +742,10 @@ export class Ops { * This operation currently only implements the `CONSTANT` mode from * Tensorflow's `pad` operation. * + * ```js + * const x = dl.tensor1d([1, 2, 3, 4]); + * x.pad([[1, 2]]).print(); + * ``` * @param x The tensor to pad. * @param paddings An array of length `R` (the rank of the tensor), where each * element is a length-2 tuple of ints `[padBefore, padAfter]`, specifying @@ -705,10 +773,17 @@ export class Ops { /** * Stacks a list of rank-`R` `Tensor`s into one rank-`(R+1)` `Tensor`. * + * ```js + * const a = dl.tensor1d([1, 2]); + * const b = dl.tensor1d([3, 4]); + * const c = dl.tensor1d([5, 6]); + * dl.stack([a, b, c]).print(); + * ``` + * * @param tensors A list of tensor objects with the same shape and dtype. * @param axis The axis to stack along. Defaults to 0 (the first dim). */ - @doc({heading: 'Tensors', subheading: 'Transformations'}) + @doc({heading: 'Tensors', subheading: 'Slicing and Joining'}) @operation static stack(tensors: T[], axis = 0): Tensor { util.assert(tensors.length >= 2, 'Pass at least two tensors to dl.stack'); @@ -737,6 +812,12 @@ export class Ops { * Returns a `Tensor` that has expanded rank, by inserting a dimension * into the tensor's shape. * + * ```js + * const x = dl.tensor1d([1, 2, 3, 4]); + * const axis = 1; + * x.expandDims(axis).print(); + * ``` + * * @param axis The dimension index at which to insert shape of `1`. Defaults * to 0 (the first dimension). */ @@ -786,6 +867,10 @@ export class Ops { * excludes stop. Decrementing ranges and negative step values are also * supported. * + * ```js + * dl.range(0, 9, 2).print(); + * ``` + * * @param start An integer start value * @param stop An integer stop value * @param step An integer increment (will default to 1 or -1) @@ -860,26 +945,40 @@ export class Ops { /** * Prints information about the `Tensor` including its data. * - * @param verbose Whether to print verbose information about the `Tensor`, + * ```js + * const verbose = true; + * dl.tensor2d([1, 2, 3, 4], [2, 2]).print(verbose); + * ``` + * + * @param verbose Whether to print verbose information about the ` Tensor`, * including dtype and size. */ @doc({heading: 'Tensors', subheading: 'Creation'}) static print(x: T, verbose = false): void { const C = class Tensor { shape: number[]; - data: number[]; + values: number[]; dtype: string; size: number; }; const displayTensor = new C(); displayTensor.shape = x.shape; - displayTensor.data = Array.from(x.dataSync()); + displayTensor.values = Array.from(x.dataSync()); displayTensor.toString = function() { - return `Tensor {\n` + - ` data: [${this.data.join(', ')}],\n` + - ` shape: [${x.shape.join(', ')}]\n` + - `}`; + const fields = [ + `values: [${this.values.join(', ')}]`, `shape: [${x.shape.join(', ')}]`, + `rank: ${x.rank}` + ]; + if (verbose) { + fields.push(`dtype: '${this.dtype}'`); + fields.push(`size: ${this.size}`); + } + for (let i = 0; i < fields.length; i++) { + fields[i] = ' ' + fields[i]; + } + + return 'TensorInfo {\n' + fields.join(',\n') + '\n}'; }; if (verbose) { @@ -900,7 +999,7 @@ function makeZerosTypedArray( } else if (dtype === 'bool') { return new Uint8Array(size); } else { - throw new Error(`Unknown data type ${dtype}`); + throw new Error(`Unknown data type $ {dtype}`); } } diff --git a/src/ops/binary_ops.ts b/src/ops/binary_ops.ts index a80ba9ba81..31f5e3eb52 100644 --- a/src/ops/binary_ops.ts +++ b/src/ops/binary_ops.ts @@ -30,6 +30,21 @@ export class Ops { * * We also expose `addStrict` which has the same signature as this op and * asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = dl.tensor1d([1, 2, 3, 4]); + * const b = dl.tensor1d([10, 20, 30, 40]); + * + * a.add(b).print(); // or dl.add(a, b) + * ``` + * + * ```js + * // Broadcast add a with b. + * const a = dl.scalar(5); + * const b = dl.tensor1d([10, 20, 30, 40]); + * + * a.add(b).print(); // or dl.add(a, b) + * ``` * @param a The first `Tensor` to add. * @param b The second `Tensor` to add. Must have the same type as `a`. */ @@ -82,6 +97,20 @@ export class Ops { * We also expose `subStrict` which has the same signature as this op and * asserts that `a` and `b` are the same shape (does not broadcast). * + * ```js + * const a = dl.tensor1d([10, 20, 30, 40]); + * const b = dl.tensor1d([1, 2, 3, 4]); + * + * a.sub(b).print(); // or dl.sub(a, b) + * ``` + * + * ```js + * // Broadcast subtract a with b. + * const a = dl.tensor1d([10, 20, 30, 40]); + * const b = dl.scalar(5); + * + * a.sub(b).print(); // or dl.sub(a, b) + * ``` * @param a The first `Tensor`. * @param b The second `Tensor`. Must have the same dtype as `a`. */ @@ -136,11 +165,18 @@ export class Ops { * corresponding elements in x and y. * * ```js - * const a = dl.tensor([[2, 2], [3, 3]]) - * const b = dl.tensor([[8, 16], [2, 3]]).toInt() - * dl.pow(a, b).print(); // [256, 65536, 9, 27] + * const a = dl.tensor([[2, 3], [4, 5]]) + * const b = dl.tensor([[1, 2], [3, 0]]).toInt(); + * + * a.pow(b).print(); // or dl.pow(a, b) * ``` * + * ```js + * const a = dl.tensor([[1, 2], [3, 4]]) + * const b = dl.tensor(2).toInt(); + * + * a.pow(b).print(); // or dl.pow(a, b) + * ``` * We also expose `powStrict` which has the same signature as this op and * asserts that `base` and `exp` are the same shape (does not broadcast). * @@ -156,7 +192,8 @@ export class Ops { broadcast_util.assertAndGetBroadcastShape(base.shape, exp.shape); const gradient = (dy: Tensor, y: Tensor) => { - if (!util.arraysEqual(base.shape, exp.shape)) { + if (!util.arraysEqual(base.shape, exp.shape) && + !util.isScalarShape(exp.shape)) { throw new Error( `Gradient of pow not yet supported for broadcasted shapes.`); } @@ -196,6 +233,20 @@ export class Ops { * We also expose `mulStrict` which has the same signature as this op and * asserts that `a` and `b` are the same shape (does not broadcast). * + * ```js + * const a = dl.tensor1d([1, 2, 3, 4]); + * const b = dl.tensor1d([2, 3, 4, 5]); + * + * a.mul(b).print(); // or dl.mul(a, b) + * ``` + * + * ```js + * // Broadcast mul a with b. + * const a = dl.tensor1d([1, 2, 3, 4]); + * const b = dl.scalar(5); + * + * a.mul(b).print(); // or dl.mul(a, b) + * ``` * @param a The first tensor. * @param b The second tensor. Must have the same dtype as `a`. */ @@ -248,6 +299,20 @@ export class Ops { * We also expose `divStrict` which has the same signature as this op and * asserts that `a` and `b` are the same shape (does not broadcast). * + * ```js + * const a = dl.tensor1d([1, 4, 9, 16]); + * const b = dl.tensor1d([1, 2, 3, 4]); + * + * a.div(b).print(); // or dl.div(a, b) + * ``` + * + * ```js + * // Broadcast div a with b. + * const a = dl.tensor1d([2, 4, 6, 8]); + * const b = dl.scalar(2); + * + * a.div(b).print(); // or dl.div(a, b) + * ``` * @param a The first tensor. * @param b The second tensor. Must have the same dtype as `a`. */ @@ -299,6 +364,20 @@ export class Ops { * We also expose `minimumStrict` which has the same signature as this op and * asserts that `a` and `b` are the same shape (does not broadcast). * + * ```js + * const a = dl.tensor1d([1, 4, 3, 16]); + * const b = dl.tensor1d([1, 2, 9, 4]); + * + * a.minumum(b).print(); // or dl.minumum(a, b) + * ``` + * + * ```js + * // Broadcast minumum a with b. + * const a = dl.tensor1d([2, 4, 6, 8]); + * const b = dl.scalar(5); + * + * a.minumum(b).print(); // or dl.minumum(a, b) + * ``` * @param a The first tensor. * @param b The second tensor. Must have the same type as `a`. */ @@ -335,7 +414,20 @@ export class Ops { * We also expose `maximumStrict` which has the same signature as this op and * asserts that `a` and `b` are the same shape (does not broadcast). * + * ```js + * const a = dl.tensor1d([1, 4, 3, 16]); + * const b = dl.tensor1d([1, 2, 9, 4]); * + * a.maximum(b).print(); // or dl.maximum(a, b) + * ``` + * + * ```js + * // Broadcast maximum a with b. + * const a = dl.tensor1d([2, 4, 6, 8]); + * const b = dl.scalar(5); + * + * a.maximum(b).print(); // or dl.maximum(a, b) + * ``` * @param a The first tensor. * @param b The second tensor. Must have the same type as `a`. */ diff --git a/src/ops/concat.ts b/src/ops/concat.ts index 54ae3c6016..1e5427eb00 100644 --- a/src/ops/concat.ts +++ b/src/ops/concat.ts @@ -121,6 +121,25 @@ export class Concat { * The tensors ranks and types must match, and their sizes must match in all * dimensions except `axis`. * + * ```js + * const a = dl.tensor1d([1, 2]); + * const b = dl.tensor1d([3, 4]); + * a.concat(b).print(); // or a.concat(b) + * ``` + * + * ```js + * const a = dl.tensor1d([1, 2]); + * const b = dl.tensor1d([3, 4]); + * const c = dl.tensor1d([5, 6]); + * dl.concat([a, b, c]).print(); + * ``` + * + * ```js + * const a = dl.tensor2d([[1, 2], [10, 20]]); + * const b = dl.tensor2d([[3, 4], [30, 40]]); + * const axis = 1; + * dl.concat([a, b], axis).print(); + * ``` * @param tensors A list of tensors to concatenate. * @param axis The axis to concate along. Defaults to 0 (the first dim). */ diff --git a/src/ops/matmul.ts b/src/ops/matmul.ts index 32fab6191d..2ebf869be1 100644 --- a/src/ops/matmul.ts +++ b/src/ops/matmul.ts @@ -26,6 +26,12 @@ export class Ops { /** * Computes the dot product of two matrices, A * B. These must be matrices. * + * ```js + * const a = dl.tensor2d([1, 2], [1, 2]); + * const b = dl.tensor2d([1, 2, 3, 4], [2, 2]); + * + * a.matMul(b).print(); // or dl.matMul(a, b) + * ``` * @param a First matrix in dot product operation. * @param b Second matrix in dot product operation. * @param transposeA If true, `a` is transposed before multiplication. @@ -74,7 +80,6 @@ export class Ops { * @param v The vector in dot product operation. * @param matrix The matrix in dot product operation. */ - @doc({heading: 'Operations', subheading: 'Matrices'}) @operation static vectorTimesMatrix(v: Tensor1D, matrix: Tensor2D): Tensor1D { util.assert( @@ -94,10 +99,10 @@ export class Ops { /** * Computes the dot product of a matrix and vector, A * v. + * * @param matrix The matrix in dot product operation. * @param v The vector in dot product operation. */ - @doc({heading: 'Operations', subheading: 'Matrices'}) @operation static matrixTimesVector(matrix: Tensor2D, v: Tensor1D): Tensor1D { util.assert( @@ -123,7 +128,6 @@ export class Ops { * @param v1 The first vector in the dot product operation. * @param v2 The second vector in the dot product operation. */ - @doc({heading: 'Operations', subheading: 'Matrices'}) @operation static dotProduct(v1: Tensor1D, v2: Tensor1D): Scalar { util.assert( @@ -140,6 +144,12 @@ export class Ops { /** * Computes the outer product of two vectors, v1 and v2. * + * ```js + * const a = dl.tensor1d([1, 2, 3]); + * const b = dl.tensor1d([3, 4, 5]); + * + * dl.outerProduct(a, b).print(); + * ``` * @param v1 The first vector in the outer product operation. * @param v2 The second vector in the dot product operation. */ diff --git a/src/ops/matmul_test.ts b/src/ops/matmul_test.ts index 509ecaa9ed..0db4c0c370 100644 --- a/src/ops/matmul_test.ts +++ b/src/ops/matmul_test.ts @@ -20,6 +20,8 @@ import * as dl from '../index'; import {ALL_ENVS, describeWithFlags, expectArraysClose, expectNumbersClose, WEBGL_ENVS} from '../test_util'; import {Rank} from '../types'; +import {Ops as MatmulOps} from './matmul'; + describeWithFlags('matmul', ALL_ENVS, () => { it('A x B', () => { const a = dl.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); @@ -193,7 +195,7 @@ describeWithFlags('matmul', ALL_ENVS, () => { it('Dot product', () => { const v1 = dl.tensor1d([2, 3]); const v2 = dl.tensor1d([2, 1]); - const result = dl.dotProduct(v1, v2); + const result = MatmulOps.dotProduct(v1, v2); expectNumbersClose(result.get(), 7); }); @@ -201,7 +203,7 @@ describeWithFlags('matmul', ALL_ENVS, () => { it('Dot product propagates NaNs', () => { const v1 = dl.tensor1d([2, NaN]); const v2 = dl.tensor1d([2, 1]); - const result = dl.dotProduct(v1, v2); + const result = MatmulOps.dotProduct(v1, v2); expect(result.get()).toEqual(NaN); }); @@ -209,8 +211,8 @@ describeWithFlags('matmul', ALL_ENVS, () => { const v1 = dl.tensor1d([2, 3, 3]); const v2 = dl.tensor1d([2, 1]); - expect(() => dl.dotProduct(v1, v2)).toThrowError(); - expect(() => dl.dotProduct(v2, v1)).toThrowError(); + expect(() => MatmulOps.dotProduct(v1, v2)).toThrowError(); + expect(() => MatmulOps.dotProduct(v2, v1)).toThrowError(); }); it('Dot product throws when passed non vectors', () => { @@ -218,8 +220,8 @@ describeWithFlags('matmul', ALL_ENVS, () => { const v1: any = dl.tensor2d([1, 2, 3, 3], [2, 2]); const v2 = dl.tensor1d([2, 1]); - expect(() => dl.dotProduct(v1, v2)).toThrowError(); - expect(() => dl.dotProduct(v2, v1)).toThrowError(); + expect(() => MatmulOps.dotProduct(v1, v2)).toThrowError(); + expect(() => MatmulOps.dotProduct(v2, v1)).toThrowError(); }); it('Outer product', () => { diff --git a/src/ops/norm.ts b/src/ops/norm.ts index 86d4fa8035..adc7c9f59c 100644 --- a/src/ops/norm.ts +++ b/src/ops/norm.ts @@ -29,6 +29,12 @@ export class Ops { * Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) * and matrix norms (Frobenius, 1-norm, and inf-norm). * + * ```js + * const x = dl.tensor1d([1, 2, 3, 4]); + * + * x.norm().print(); // or dl.norm(x) + * ``` + * * @param x The input array. * @param ord Optional. Order of the norm. Supported norm types are * following: @@ -42,7 +48,6 @@ export class Ops { * |1 |max(sum(abs(x), axis=0)) |sum(abs(x)) * |2 | |sum(abs(x)^2)^1/2* * - * * @param axis Optional. If axis is null (the default), the input is * considered a vector and a single vector norm is computed over the entire * set of values in the Tensor, i.e. norm(x, ord) is equivalent diff --git a/src/ops/ops.ts b/src/ops/ops.ts index e341a5b273..11a93dbcc8 100644 --- a/src/ops/ops.ts +++ b/src/ops/ops.ts @@ -51,7 +51,6 @@ export const conv2d = conv_ops.Ops.conv2d; export const conv2dTranspose = conv_ops.Ops.conv2dTranspose; export const depthwiseConv2d = conv_ops.Ops.depthwiseConv2d; -export const dotProduct = matmul_ops.Ops.dotProduct; export const matMul = matmul_ops.Ops.matMul; export const matrixTimesVector = matmul_ops.Ops.matrixTimesVector; export const outerProduct = matmul_ops.Ops.outerProduct; @@ -76,7 +75,6 @@ export const slice3d = slice_ops.Ops.slice3d; export const slice4d = slice_ops.Ops.slice4d; export const argMax = reduction_ops.Ops.argMax; -export const argMaxEquals = reduction_ops.Ops.argMaxEquals; export const argMin = reduction_ops.Ops.argMin; export const logSumExp = reduction_ops.Ops.logSumExp; export const max = reduction_ops.Ops.max; diff --git a/src/ops/reduction_ops.ts b/src/ops/reduction_ops.ts index 37d5ac6dd5..dec6d23e66 100644 --- a/src/ops/reduction_ops.ts +++ b/src/ops/reduction_ops.ts @@ -18,7 +18,7 @@ import {doc} from '../doc'; import {ENV} from '../environment'; import {customGrad} from '../globals'; -import {Scalar, Tensor} from '../tensor'; +import {Tensor} from '../tensor'; import * as util from '../util'; import * as axis_util from './axis_util'; import {operation} from './operation'; @@ -34,6 +34,18 @@ export class Ops { * If `axis` has no entries, all dimensions are reduced, and an array with a * single element is returned. * + * ```js + * const x = dl.tensor1d([1, 2, 3]); + * + * x.logSumExp().print(); // or dl.logSumExp(x) + * ``` + * + * ```js + * const x = dl.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.logSumExp(axis).print(); // or dl.logSumExp(a, axis) + * ``` * @param input The input tensor. * @param axis The dimension(s) to reduce. If null (the default), * reduces all dimensions. @@ -68,6 +80,19 @@ export class Ops { * If axes has no entries, all dimensions are reduced, and a `Tensor` with a * single element is returned. * + * ```js + * const x = dl.tensor1d([1, 2, 3]); + * + * x.sum().print(); // or dl.logSumExp(x) + * ``` + * + * ```js + * const x = dl.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.sum(axis).print(); // or dl.sum(x, axis) + * ``` + * * @param x The input tensor to compute the sum over. * @param axis The dimension(s) to reduce. By default it reduces * all dimensions. @@ -121,6 +146,19 @@ export class Ops { * If `axis` has no entries, all dimensions are reduced, and a `Tensor` with * a single element is returned. * + * ```js + * const x = dl.tensor1d([1, 2, 3]); + * + * x.mean().print(); // or dl.logSumExp(a) + * ``` + * + * ```js + * const x = dl.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.mean(axis).print(); // or dl.mean(x, axis) + * ``` + * * @param x The input tensor. * @param axis The dimension(s) to reduce. By default it reduces * all dimensions. @@ -167,6 +205,19 @@ export class Ops { * If `axes` has no entries, all dimensions are reduced, and an array with a * single element is returned. * + * ```js + * const x = dl.tensor1d([1, 2, 3]); + * + * x.min().print(); // or dl.min(x) + * ``` + * + * ```js + * const x = dl.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.min(axis).print(); // or dl.min(x, axis) + * ``` + * * @param x The input Tensor. * @param axis The dimension(s) to reduce. By default it reduces * all dimensions. @@ -201,6 +252,19 @@ export class Ops { * If `axes` has no entries, all dimensions are reduced, and an `Tensor` with * a single element is returned. * + * ```js + * const x = dl.tensor1d([1, 2, 3]); + * + * x.max().print(); // or dl.max(x) + * ``` + * + * ```js + * const x = dl.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.max(axis).print(); // or dl.max(x, axis) + * ``` + * * @param x The input tensor. * @param axis The dimension(s) to reduce. By default it reduces * all dimensions. @@ -232,6 +296,19 @@ export class Ops { * The result has the same shape as `input` with the dimension along `axis` * removed. * + * ```js + * const x = dl.tensor1d([1, 2, 3]); + * + * x.argMin().print(); // or dl.argMin(x) + * ``` + * + * ```js + * const x = dl.tensor2d([1, 2, 4, 3], [2, 2]); + * + * const axis = 1; + * x.argMin(axis).print(); // or dl.argMin(x, axis) + * ``` + * * @param x The input tensor. * @param axis The dimension to reduce. By default it reduces * across all axes and returns the flat index. @@ -255,6 +332,19 @@ export class Ops { * The result has the same shape as `input` with the dimension along `axis` * removed. * + * ```js + * const x = dl.tensor1d([1, 2, 3]); + * + * x.argMax().print(); // or dl.argMax(x) + * ``` + * + * ```js + * const x = dl.tensor2d([1, 2, 4, 3], [2, 2]); + * + * const axis = 1; + * x.argMax(axis).print(); // or dl.argMax(x, axis) + * ``` + * * @param x The input tensor. * @param axis The dimension to reduce. By default it reduces * across all axes and returns the flat index @@ -272,19 +362,6 @@ export class Ops { return ENV.engine.executeKernel('ArgMax', {inputs: {x}, args: {axes}}) as T; } - /** - * Returns a 1 if the argMax of x1 and x2 are the same, otherwise 0. - * - * @param x1 The first input tensor. - * @param x2 The second input tensor. - */ - @doc({heading: 'Operations', subheading: 'Reduction'}) - @operation - static argMaxEquals(x1: Tensor, x2: Tensor): Scalar { - util.assertShapesMatch(x1.shape, x2.shape, 'Error in argMaxEquals: '); - return x1.argMax().equal(x2.argMax()); - } - /** * Calculates the mean and variance of `x`. The mean and variance are * calculated by aggregating the contents of `x` across `axes`. If `x` is diff --git a/src/ops/reduction_ops_test.ts b/src/ops/reduction_ops_test.ts index 1661a7e555..55c98b8825 100644 --- a/src/ops/reduction_ops_test.ts +++ b/src/ops/reduction_ops_test.ts @@ -259,35 +259,6 @@ describeWithFlags('argmin', ALL_ENVS, () => { }); }); -describeWithFlags('argMaxEquals', ALL_ENVS, () => { - it('equals', () => { - const a = dl.tensor1d([5, 0, 3, 7, 3]); - const b = dl.tensor1d([-100.3, -20.0, -10.0, -5, -100]); - const result = dl.argMaxEquals(a, b); - expect(result.get()).toBe(1); - }); - - it('not equals', () => { - const a = dl.tensor1d([5, 0, 3, 1, 3]); - const b = dl.tensor1d([-100.3, -20.0, -10.0, -5, 0]); - const result = dl.argMaxEquals(a, b); - expect(result.get()).toBe(0); - }); - - it('propagates NaNs', () => { - const a = dl.tensor1d([0, 3, 1, 3]); - const b = dl.tensor1d([NaN, -20.0, -10.0, -5]); - const result = dl.argMaxEquals(a, b); - assertIsNan(result.get(), result.dtype); - }); - - it('throws when given arrays of different shape', () => { - const a = dl.tensor1d([5, 0, 3, 7, 3, 10]); - const b = dl.tensor1d([-100.3, -20.0, -10.0, -5, -100]); - expect(() => dl.argMaxEquals(a, b)).toThrowError(); - }); -}); - describeWithFlags('logSumExp', ALL_ENVS, () => { it('0', () => { const a = dl.scalar(0); diff --git a/src/ops/reverse.ts b/src/ops/reverse.ts index eb8362c2bb..8a8cfdbf75 100644 --- a/src/ops/reverse.ts +++ b/src/ops/reverse.ts @@ -74,6 +74,18 @@ export class Ops { /** * Reverses a `Tensor` along a specified axis. * + * ```js + * const x = dl.tensor1d([1, 2, 3, 4]); + * + * x.reverse().print(); + * ``` + * + * ```js + * const x = dl.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.reverse(axis).print(); + * ``` * @param x The input tensor. * @param axis The set of dimensions to reverse. Must be in the * range [-rank(x), rank(x)). diff --git a/src/ops/slice.ts b/src/ops/slice.ts index a87fbafded..16be1a4746 100644 --- a/src/ops/slice.ts +++ b/src/ops/slice.ts @@ -99,6 +99,18 @@ export class Ops { * - `dl.slice2d` * - `dl.slice3d` * - `dl.slice4d` + * + * ```js + * const x = dl.tensor1d([1, 2, 3, 4]); + * + * x.slice([1], [2]).print(); + * ``` + * + * ```js + * const x = dl.tensor2d([1, 2, 3, 4], [2, 2]); + * + * x.slice([1, 0], [1, 2]).print(); + * ``` * @param x The input `Tensor` to slice from. * @param begin The coordinates to start the slice from. The length of this * array should match the rank of `x`. diff --git a/src/ops/softmax.ts b/src/ops/softmax.ts index 786802fb46..7af3d92c43 100644 --- a/src/ops/softmax.ts +++ b/src/ops/softmax.ts @@ -28,6 +28,18 @@ export class Ops { /** * Computes the softmax normalized vector given the logits. * + * ```js + * const a = dl.tensor1d([1, 2, 3]); + * + * a.softmax().print(); // or dl.softmax(a) + * ``` + * + * ```js + * const a = dl.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]); + * + * a.softmax().print(); // or dl.softmax(a) + * ``` + * * @param logits The logits array. * @param dim The dimension softmax would be performed on. Defaults to -1 * which indicates the last dimension. diff --git a/src/ops/transpose.ts b/src/ops/transpose.ts index 5d941b6133..eec05d093f 100644 --- a/src/ops/transpose.ts +++ b/src/ops/transpose.ts @@ -32,6 +32,12 @@ export class Ops { * where `n` is the rank of the input `Tensor`. Hence by default, this * operation performs a regular matrix transpose on 2-D input `Tensor`s. * + * ```js + * const a = dl.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); + * + * a.transpose().print(); // or dl.transpose(a) + * ``` + * * @param x The tensor to transpose. * @param perm The permutation of the dimensions of a. */ diff --git a/src/ops/unary_ops.ts b/src/ops/unary_ops.ts index f37f3acfaf..9f3c396612 100644 --- a/src/ops/unary_ops.ts +++ b/src/ops/unary_ops.ts @@ -28,6 +28,12 @@ export class Ops { /** * Computes `-1 * x` element-wise. * + * ```js + * const x = dl.tensor2d([1, 2, -2, 0], [2, 2]); + * + * x.neg().print(); // or dl.neg(x) + * ``` + * * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -41,6 +47,11 @@ export class Ops { /** * Computes ceiling of input `Tensor` element-wise: `ceil(x)` * + * ```js + * const x = dl.tensor1d([.6, 1.1, -3.3]); + * + * x.ceil().print(); // or dl.ceil(x) + * ``` * @param x The input Tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -55,6 +66,12 @@ export class Ops { /** * Computes floor of input `Tensor` element-wise: `floor(x)`. + * + * ```js + * const x = dl.tensor1d([.6, 1.1, -3.3]); + * + * x.floor().print(); // or dl.floor(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -70,6 +87,12 @@ export class Ops { /** * Computes exponential of the input `Tensor` element-wise. `e ^ x` + * + * ```js + * const x = dl.tensor1d([1, 2, -3]); + * + * x.exp().print(); // or dl.exp(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -82,6 +105,12 @@ export class Ops { /** * Computes natural logarithm of the input `Tensor` element-wise: `ln(x)` + * + * ```js + * const x = dl.tensor1d([1, 2, Math.E]); + * + * x.log().print(); // or dl.log(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -94,6 +123,12 @@ export class Ops { /** * Computes square root of the input `Tensor` element-wise: `y = sqrt(x)` + * + * ```js + * const x = dl.tensor1d([1, 2, 4, -1]); + * + * x.sqrt().print(); // or dl.sqrt(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -107,6 +142,11 @@ export class Ops { /** * Computes square of `x` element-wise: `x ^ 2` * + * ```js + * const x = dl.tensor1d([1, 2, Math.sqrt(2), -1]); + * + * x.square().print(); // or dl.square(x) + * ``` * @param x The input Tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -120,6 +160,11 @@ export class Ops { /** * Computes absolute value element-wise: `abs(x)` * + * ```js + * const x = dl.tensor1d([-1, 2, -3, 4]); + * + * x.abs().print(); // or dl.abs(x) + * ``` * @param x The input `Tensor`. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -133,6 +178,11 @@ export class Ops { /** * Clips values element-wise. `max(min(x, clipValueMax), clipValueMin)` * + * ```js + * const x = dl.tensor1d([-1, 2, -3, 4]); + * + * x.clipByValue(-2, 3).print(); // or dl.clipByValue(x, -2, 3) + * ``` * @param x The input tensor. * @param clipValueMin Lower-bound of range to be clipped to. * @param clipValueMax Upper-bound of range to be clipped to. @@ -163,6 +213,11 @@ export class Ops { /** * Computes rectified linear element-wise: `max(x, 0)` * + * ```js + * const x = dl.tensor1d([-1, 2, -3, 4]); + * + * x.relu().print(); // or dl.relu(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -177,6 +232,11 @@ export class Ops { /** * Computes exponential linear element-wise, `x > 0 ? e ^ x - 1 : 0` * + * ```js + * const x = dl.tensor1d([-1, 1, -3, 2]); + * + * x.elu().print(); // or dl.elu(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -200,6 +260,11 @@ export class Ops { * * `x < 0 ? scale * alpha * (exp(x) - 1) : x` * + * ```js + * const x = dl.tensor1d([-1, 2, -3, 4]); + * + * x.selu().print(); // or dl.selu(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -232,6 +297,11 @@ export class Ops { * [http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf]( * http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf) * + * ```js + * const x = dl.tensor1d([-1, 2, -3, 4]); + * + * x.leakyRelu(0.1).print(); // or dl.leakyRelu(x, 0.1) + * ``` * @param x The input tensor. * @param alpha The scaling factor for negative values, defaults to 0.2. */ @@ -250,6 +320,12 @@ export class Ops { * * `x < 0 ? alpha * x : f(x) = x` * + * ```js + * const x = dl.tensor1d([-1, 2, -3, 4]); + * const alpha = dl.scalar(0.1); + * + * x.prelu(alpha).print(); // or dl.prelu(x, alpha) + * ``` * @param x The input tensor. * @param alpha Scaling factor for negative values. */ @@ -272,6 +348,11 @@ export class Ops { /** * Computes sigmoid element-wise, `1 / (1 + exp(-x))` * + * ```js + * const x = dl.tensor1d([0, -1, 2, -3]); + * + * x.sigmoid().print(); // or dl.sigmoid(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -285,6 +366,11 @@ export class Ops { /** * Computes sin of the input Tensor element-wise: `sin(x)` * + * ```js + * const x = dl.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); + * + * x.sin().print(); // or dl.sin(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -298,6 +384,11 @@ export class Ops { /** * Computes cos of the input `Tensor` element-wise: `cos(x)` * + * ```js + * const x = dl.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); + * + * x.cos().print(); // or dl.cos(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -311,6 +402,11 @@ export class Ops { /** * Computes tan of the input `Tensor` element-wise, `tan(x)` * + * ```js + * const x = dl.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); + * + * x.tan().print(); // or dl.tan(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -324,6 +420,11 @@ export class Ops { /** * Computes asin of the input `Tensor` element-wise: `asin(x)` * + * ```js + * const x = dl.tensor1d([0, 1, -1, .7]); + * + * x.asin().print(); // or dl.asin(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -339,6 +440,11 @@ export class Ops { /** * Computes acos of the input `Tensor` element-wise: `acos(x)` * + * ```js + * const x = dl.tensor1d([0, 1, -1, .7]); + * + * x.acos().print(); // or dl.acos(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -354,6 +460,11 @@ export class Ops { /** * Computes atan of the input `Tensor` element-wise: `atan(x)` * + * ```js + * const x = dl.tensor1d([0, 1, -1, .7]); + * + * x.atan().print(); // or dl.atan(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -367,6 +478,11 @@ export class Ops { /** * Computes hyperbolic sin of the input `Tensor` element-wise: `sinh(x)` * + * ```js + * const x = dl.tensor1d([0, 1, -1, .7]); + * + * x.sinh().print(); // or dl.sinh(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -379,6 +495,12 @@ export class Ops { /** * Computes hyperbolic cos of the input `Tensor` element-wise: `cosh(x)` + * + * ```js + * const x = dl.tensor1d([0, 1, -1, .7]); + * + * x.cosh().print(); // or dl.cosh(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -391,6 +513,12 @@ export class Ops { /** * Computes hyperbolic tangent of the input `Tensor` element-wise: `tanh(x)` + * + * ```js + * const x = dl.tensor1d([0, 1, -1, 70]); + * + * x.tanh().print(); // or dl.tanh(x) + * ``` * @param x The input tensor. */ @doc({heading: 'Operations', subheading: 'Basic math'}) @@ -404,6 +532,11 @@ export class Ops { /** * Computes step of the input `Tensor` element-wise: `x > 0 ? 1 : alpha * x` * + * ```js + * const x = dl.tensor1d([0, 2, -1, -3]); + * + * x.step(.5).print(); // or dl.step(x, .5) + * ``` * @param x The input tensor. * @param alpha The gradient when input is negative. */ diff --git a/src/ops/unary_ops_test.ts b/src/ops/unary_ops_test.ts index 62f18f2f8e..1effb1c127 100644 --- a/src/ops/unary_ops_test.ts +++ b/src/ops/unary_ops_test.ts @@ -1436,7 +1436,7 @@ describeWithFlags('clip', ALL_ENVS, () => { const max = 2; const x = dl.tensor1d([3, -2, 1]); // Only 1 is not clipped. const dy = dl.tensor1d([5, 50, 500]); - const gradients = dl.grad(x => x.clip(min, max))(x, dy); + const gradients = dl.grad(x => x.clipByValue(min, max))(x, dy); expect(gradients.shape).toEqual(x.shape); expect(gradients.dtype).toEqual('float32'); @@ -1448,7 +1448,7 @@ describeWithFlags('clip', ALL_ENVS, () => { const max = 2; const x = dl.scalar(-10); // Clipped. const dy = dl.scalar(5); - const gradients = dl.grad(x => x.clip(min, max))(x, dy); + const gradients = dl.grad(x => x.clipByValue(min, max))(x, dy); expect(gradients.shape).toEqual(x.shape); expect(gradients.dtype).toEqual('float32'); diff --git a/src/optimizers/optimizer_constructors.ts b/src/optimizers/optimizer_constructors.ts index fbd9eee072..55eac661a3 100644 --- a/src/optimizers/optimizer_constructors.ts +++ b/src/optimizers/optimizer_constructors.ts @@ -29,6 +29,36 @@ export class OptimizerConstructors { /** * Constructs a `SGDOptimizer` that uses stochastic gradient descent. * + * ```js + * // Fit a quadratic function by learning the coefficients a, b, c. + * const xs = dl.tensor1d([0, 1, 2, 3]); + * const ys = dl.tensor1d([1.1, 5.9, 16.8, 33.9]); + * + * const a = dl.variable(dl.scalar(Math.random())); + * const b = dl.variable(dl.scalar(Math.random())); + * const c = dl.variable(dl.scalar(Math.random())); + * + * // y = a * x^2 + b * x + c. + * const f = x => a.mul(x.square()).add(b.mul(x)).add(c); + * const loss = (pred, label) => pred.sub(label).square().mean(); + * + * const learningRate = 0.01; + * const optimizer = dl.train.sgd(learningRate); + * + * // Train the model. + * for (let i = 0; i < 10; i++) { + * optimizer.minimize(() => loss(f(xs), ys)); + * } + * + * // Make predictions. + * console.log( + * `a: ${a.dataSync()}, b: ${b.dataSync()}, c: ${c.dataSync()}`); + * const preds = f(xs).dataSync(); + * preds.forEach((pred, i) => { + * console.log(`x: ${i}, pred: ${pred}`); + * }); + * ``` + * * @param learningRate The learning rate to use for the SGD algorithm. */ @doc({heading: 'Training', subheading: 'Optimizers', namespace: 'train'}) diff --git a/src/tensor.ts b/src/tensor.ts index c846d4450e..eaf591e2eb 100644 --- a/src/tensor.ts +++ b/src/tensor.ts @@ -511,6 +511,12 @@ export class Tensor { this.throwIfDisposed(); return ops.matMul(this as Tensor2D, b, transposeA, transposeB); } + norm( + ord: number|'euclidean'|'fro' = 'euclidean', axis: number|number[] = null, + keepDims = false): Tensor { + this.throwIfDisposed(); + return ops.norm(this, ord, axis, keepDims); + } slice(begin: ShapeMap[R], size: ShapeMap[R]): Tensor { this.throwIfDisposed(); return ops.slice(this, begin, size); @@ -570,10 +576,6 @@ export class Tensor { this.throwIfDisposed(); return ops.argMax(this, axis); } - argMaxEquals(x: Tensor): Scalar { - this.throwIfDisposed(); - return ops.argMaxEquals(this, x); - } // Binary ops. @@ -740,7 +742,7 @@ export class Tensor { this.throwIfDisposed(); return ops.abs(this); } - clip(min: number, max: number): Tensor { + clipByValue(min: number, max: number): Tensor { this.throwIfDisposed(); return ops.clipByValue(this, min, max); } @@ -959,6 +961,12 @@ export class Variable extends Tensor { /** * Creates a new variable with the provided initial value. + * ```js + * const x = dl.variable(dl.tensor([1, 2, 3])); + * x.assign(dl.tensor([4, 5, 6])); + * + * x.print(); + * ``` * * @param initialValue A tensor. * @param trainable If true, optimizers are allowed to update it. @@ -976,8 +984,8 @@ export class Variable extends Tensor { } /** - * Assign a new `Tensor` to this variable. The new `Tensor` must have the same - * shape and dtype as the old `Tensor`. + * Assign a new `Tensor` to this variable. The new `Tensor` must have the + * same shape and dtype as the old `Tensor`. */ @doc({heading: 'Tensors', subheading: 'Classes'}) assign(newValue: Tensor): void { diff --git a/src/tracking.ts b/src/tracking.ts index 4a9df254f1..fb383f7388 100644 --- a/src/tracking.ts +++ b/src/tracking.ts @@ -34,6 +34,25 @@ export class Tracking { * When in safe mode, you must enclose all `Tensor` creation and ops * inside a `tidy` to prevent memory leaks. * + * ```js + * // y = 2 ^ 2 + 1 + * const y = dl.tidy(() => { + * // a, b, and one will be cleaned up when the tidy ends. + * const one = dl.scalar(1); + * const a = dl.scalar(2); + * const b = a.square(); + * + * console.log('numTensors (in tidy): ' + dl.memory().numTensors); + * + * // The value returned inside the tidy function will return + * // through the tidy, in this case to the variable y. + * return b.add(one); + * }); + * + * console.log('numTensors (outside tidy): ' + dl.memory().numTensors); + * y.print(); + * ``` + * * @param nameOrFn The name of the closure, or the function to execute. * If a name is provided, the 2nd argument should be the function. * If a name is provided, and debug mode is on, the timing and the memory @@ -79,9 +98,34 @@ export class Tracking { } /** - * Keeps a Tensor generated inside a `tidy` from being disposed + * Keeps a `Tensor` generated inside a `tidy` from being disposed * automatically. - * @param result The Tensor to keep from being disposed. + * + * ```js + * let b; + * const y = dl.tidy(() => { + * const one = dl.scalar(1); + * const a = dl.scalar(2); + * + * // b will not be cleaned up by the tidy. a and one will be cleaned up + * // when the tidy ends. + * b = dl.keep(a.square()); + * + * console.log('numTensors (in tidy): ' + dl.memory().numTensors); + * + * // The value returned inside the tidy function will return + * // through the tidy, in this case to the variable y. + * return b.add(one); + * }); + * + * console.log('numTensors (outside tidy): ' + dl.memory().numTensors); + * console.log('y:'); + * y.print(); + * console.log('b:'); + * b.print(); + * ``` + * + * @param result The tensor to keep from being disposed. */ @doc({heading: 'Performance', subheading: 'Memory'}) static keep(result: T): T { @@ -99,6 +143,13 @@ export class Tracking { * - `uploadWaitMs`: cpu blocking time on texture uploads. * - `downloadWaitMs`: cpu blocking time on texture downloads (readPixels). * + * ```js + * const x = dl.randomNormal([20, 20]); + * const time = await dl.time(() => x.matMul(x)); + * + * console.log(`kernelMs: ${time.kernelMs}, wallTimeMs: ${time.wallMs}`); + * ``` + * * @param f The function to execute and time. */ @doc({heading: 'Performance', subheading: 'Timing'}) diff --git a/website/api/index.html b/website/api/index.html index c7c6fd59c1..6ecbf26c3a 100644 --- a/website/api/index.html +++ b/website/api/index.html @@ -241,10 +241,6 @@ padding-bottom: 16px; } - .reference .symbol .documentation p { - margin: 0 - } - .parameter { margin-bottom: 24px; } @@ -302,7 +298,7 @@ background: #f4f4f4; padding: 8px 16px; min-height: 16px; - white-space: pre; + white-space: pre-wrap; font-family: "Roboto Mono", monospace; } @@ -495,7 +491,7 @@
Methods: