diff --git a/tfjs-layers/src/engine/topology.ts b/tfjs-layers/src/engine/topology.ts index 0bc4fb843a4..64d37b3cb11 100644 --- a/tfjs-layers/src/engine/topology.ts +++ b/tfjs-layers/src/engine/topology.ts @@ -944,9 +944,9 @@ export abstract class Layer extends serialization.Serializable { * @doc {heading: 'Models', 'subheading': 'Classes'} */ // Porting Note: This is a replacement for __call__() in Python. - apply( - inputs: Tensor|Tensor[]|SymbolicTensor|SymbolicTensor[], - kwargs?: Kwargs): Tensor|Tensor[]|SymbolicTensor|SymbolicTensor[] { + apply( + inputs: Tensor|Tensor[]|SymbolicTensor|SymbolicTensor[], + kwargs?: Kwargs): Tensor|Tensor[]|SymbolicTensor|SymbolicTensor[] { kwargs = kwargs || {}; this.assertNotDisposed(); @@ -954,131 +954,83 @@ export abstract class Layer extends serialization.Serializable { // Ensure inputs are all the same type. const inputsList = generic_utils.toList(inputs); + // Check if inputs are Tensors or SymbolicTensors. const allAreSymbolic = checkAllSymbolic(inputs); const noneAreSymbolic = checkNoneSymbolic(inputs); if (allAreSymbolic === noneAreSymbolic) { throw new ValueError( - 'Arguments to apply() must be all ' + - 'SymbolicTensors or all Tensors'); + 'Arguments to apply() must be all SymbolicTensors or all Tensors' + ); } - // TODO(michaelterry): nameScope() may not be necessary. - return nameScope(this.name, () => { - // Handle laying building (weight creating, input spec locking). - if (!this.built) { - /* - Throw exceptions in case the input is not compatible - with the inputSpec specified in the layer constructor. - */ - this.assertInputCompatibility(inputs); - - // Collect input shapes to build layer. - const inputShapes: Shape[] = []; - for (const xElem of generic_utils.toList(inputs)) { - inputShapes.push(xElem.shape); - } - this.build(generic_utils.singletonOrArray(inputShapes)); - this.built = true; + // Determine if the execution should be synchronous or asynchronous + if (noneAreSymbolic) { + // Synchronous execution for Tensors + console.log('Processing Tensors synchronously...'); + return nameScope(this.name, () => { + // Original logic for processing tensors + if (!this.built) { + this.assertInputCompatibility(inputs); + const inputShapes: Shape[] = []; + for (const xElem of generic_utils.toList(inputs)) { + inputShapes.push(xElem.shape); + } + this.build(generic_utils.singletonOrArray(inputShapes)); + this.built = true; - // Load weights that were specified at layer instantiation. - if (this.initialWeights) { - this.setWeights(this.initialWeights); - } + if (this.initialWeights) { + this.setWeights(this.initialWeights); + } - if (this._refCount === null && noneAreSymbolic) { - // The first use of this layer is a non-symbolic call, set ref count - // to 1 so the Layer can be properly disposed if its dispose() method - // is called. - this._refCount = 1; + if (this._refCount === null) { + this._refCount = 1; + } } - } - - /* - Throw exceptions in case the input is not compatible - with the inputSpec set at build time. - */ - this.assertInputCompatibility(inputs); - // Handle mask propagation. - // TODO(michaelterry): Mask propagation not currently implemented. + this.assertInputCompatibility(inputs); - // Actually call the layer, collecting output(s), mask(s), and shape(s). - if (noneAreSymbolic) { let output = this.call(inputs, kwargs); - - // Apply masks to the output tensors if the layer supports it. - if (this.supportsMasking) { - // TODO(mattsoulanille): pass the input tensors' masks to computeMask - this.setMaskMetadata(inputs, output); - } - - // If the layer returns tensors from its inputs, unmodified, - // we copy them to avoid loss of tensor metadata. + const outputList: Tensor[] = generic_utils.toList(output); const outputListCopy: Tensor[] = []; - // TODO(michaelterry): This copying may not be necessary given our eager - // backend. for (let x of outputList) { if (inputsList.indexOf(x) !== -1) { x = x.clone(); } outputListCopy.push(x); } - output = generic_utils.singletonOrArray(outputListCopy); - - if (this.activityRegularizer != null) { - throw new NotImplementedError( - 'Layer invocation in the presence of activity ' + - 'regularizer(s) is not supported yet.'); - } - // TODO(michaelterry): Call addInboundNode()? + output = generic_utils.singletonOrArray(outputListCopy); return output; - } else { + }); + } else { + // Asynchronous execution for SymbolicTensors + console.log('Processing SymbolicTensors asynchronously...'); + return nameScope(this.name, () => { const inputShape = collectInputShape(inputs); const outputShape = this.computeOutputShape(inputShape); let output: SymbolicTensor|SymbolicTensor[]; const outputDType = guessOutputDType(inputs); - this.warnOnIncompatibleInputShape( - Array.isArray(inputs) ? inputShape[0] as Shape : - inputShape as Shape); - - if (outputShape != null && outputShape.length > 0 && - Array.isArray(outputShape[0])) { - // We have multiple output shapes. Create multiple output tensors. - output = (outputShape as Shape[]) - .map( - (shape, index) => new SymbolicTensor( - outputDType, shape, this, - generic_utils.toList(inputs), kwargs, this.name, - index)); + + if (outputShape != null && outputShape.length > 0 + && Array.isArray(outputShape[0])) { + output = (outputShape as Shape[]).map((shape, index) => + new SymbolicTensor(outputDType, shape, this, + generic_utils.toList(inputs), kwargs, this.name, index) + ); } else { - output = new SymbolicTensor( - outputDType, outputShape as Shape, this, - generic_utils.toList(inputs), kwargs, this.name); + output = new SymbolicTensor(outputDType, outputShape as + Shape, this, generic_utils.toList(inputs), kwargs, this.name); } - /* - Add an inbound node to the layer, so that it keeps track - of the call and of all new variables created during the call. - This also updates the layer history of the output tensor(s). - If the input tensor(s) had no previous history, - this does nothing. - */ - this.addInboundNode( - inputs, output, null, null, inputShape, outputShape, kwargs); + this.addInboundNode(inputs, output, null, null, inputShape, + outputShape, kwargs); this._refCount++; - if (this.activityRegularizer != null) { - throw new NotImplementedError( - 'Layer invocation in the presence of activity ' + - 'regularizer(s) is not supported yet.'); - } - return output; - } - }); + }); + } } /**