Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Graph] change of inplace-type setting method #2794

Merged
merged 1 commit into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 69 additions & 153 deletions nntrainer/graph/network_graph.cpp

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions nntrainer/layers/flatten_layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ class FlattenLayer : public ReshapeLayer {
/**
* @brief Constructor of Flatten Layer
*/
FlattenLayer() : ReshapeLayer(), flatten_props(
props::StartDimension(), props::EndDimension()) {}
FlattenLayer() :
ReshapeLayer(),
flatten_props(props::StartDimension(), props::EndDimension()) {}

/**
* @brief Destructor of Flatten Layer
Expand Down
12 changes: 12 additions & 0 deletions nntrainer/layers/identity_layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace nntrainer {

/**
* @class Identity Layer
* @brief Identity Layer
* @note Identity layers takes multiple tensors as input, redirects to output
* without doing nothing (or if unavoidable, copying)
*/
Expand Down Expand Up @@ -73,6 +74,17 @@ class IdentityLayer final : public Layer {
*/
bool supportInPlace() const override { return true; }

/**
* @brief Initialize the in-place type of the layer
* @return InPlaceType
*/
InPlaceType initializeInPlaceType() final {
if (!supportInPlace())
return InPlaceType::NONE;
else
return InPlaceType::RESTRICTING;
}

/**
* @copydoc Layer::getType()
*/
Expand Down
28 changes: 28 additions & 0 deletions nntrainer/layers/layer_devel.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ class InitLayerContext;
class RunLayerContext;
class Exporter;

/**
* @brief Enum class for the various types of inplace modes supported by layer
*
*/
enum class InPlaceType {
NONE, /**< layer is not inplace */
RESTRICTING, /**< layer is in-place and does place restriction on layers
ahead of it to be in-place */
NON_RESTRICTING /**< layer is in-place and does NOT place restriction on the
layers ahead of it to be in-place */
};

/**
* @class Layer Base class for layers
* @brief Base class for all layers
Expand Down Expand Up @@ -248,6 +260,22 @@ class Layer {
*/
virtual bool supportInPlace() const { return false; }

/**
* @brief Initialize the in-place type of the layer
* @details If it is a layer that supports in-place, the default in-place type
* is NONE_RESTRICTING, but if there is a RESTRICTING type among the input
* layers, it is set to NONE in the network_graph.cpp.
* Layers with exceptional behavior such as No-Operation layers should
* override this function.
* @return InPlaceType
*/
virtual InPlaceType initializeInPlaceType() {
if (!supportInPlace())
return InPlaceType::NONE;
else
return InPlaceType::NON_RESTRICTING;
}

/**
* @brief check if this layer requires label to be passed
* @note if requireLabel() == true means, for now, that it is endpoint of a
Expand Down
9 changes: 9 additions & 0 deletions nntrainer/layers/layer_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,15 @@ bool LayerNode::supportInPlace() const {
return layer->supportInPlace();
}

/**
* @brief Initialize the in-place type of the layer
* @return InPlaceType
*/
InPlaceType LayerNode::initializeInPlaceType() {
inplace_type = layer->initializeInPlaceType();
return inplace_type;
}

/**
* @brief check if this layer requires label to be passed
*/
Expand Down
22 changes: 10 additions & 12 deletions nntrainer/layers/layer_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,6 @@ class Packed;
class LossScaleForMixed;
} // namespace props

/**
* @brief Enum class for the various types of inplace modes supported by layer
*
*/
enum class InPlaceType {
NONE, /**< layer is not inplace */
RESTRICTING, /**< layer is in-place and does place restriction on layers
ahead of it to be in-place */
NON_RESTRICTING /**< layer is in-place and does NOT place restriction on the
layers ahead of it to be in-place */
};

/**
* @class LayerNode class
* @brief layer node class for the graph
Expand Down Expand Up @@ -365,6 +353,16 @@ class LayerNode final : public ml::train::Layer, public GraphNode {
*/
bool supportInPlace() const;

/**
* @brief Initialize the in-place type of the layer
* @details If it is a layer that supports in-place, the default in-place type
* is NONE_RESTRICTING, but if there is a RESTRICTING type among the input
* layers, it is set to NONE in the network_graph.cpp.
* Layers with exceptional behavior such as No-Operation layers should
* override this function.
* @return InPlaceType
*/
InPlaceType initializeInPlaceType();
/**
* @brief Notify that this layer will execute in-place
*
Expand Down
11 changes: 11 additions & 0 deletions nntrainer/layers/multiout_layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,17 @@ class MultiOutLayer : public Layer {
*/
bool supportBackwarding() const override { return true; };

/**
* @brief Initialize the in-place type of the layer
* @return InPlaceType
*/
InPlaceType initializeInPlaceType() final {
if (!supportInPlace())
return InPlaceType::NONE;
else
return InPlaceType::RESTRICTING;
}

/**
* @copydoc Layer::supportInPlace()
*/
Expand Down
11 changes: 11 additions & 0 deletions nntrainer/layers/reshape_layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,17 @@ class ReshapeLayer : public Layer {
*/
bool supportInPlace() const override { return true; }

/**
* @brief Initialize the in-place type of the layer
* @return InPlaceType
*/
InPlaceType initializeInPlaceType() final {
if (!supportInPlace())
return InPlaceType::NONE;
else
return InPlaceType::RESTRICTING;
}

/**
* @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
* method)
Expand Down
Loading