diff --git a/Applications/Custom/LayerClient/jni/main.cpp b/Applications/Custom/LayerClient/jni/main.cpp
index b655fce4c7..1608622048 100644
--- a/Applications/Custom/LayerClient/jni/main.cpp
+++ b/Applications/Custom/LayerClient/jni/main.cpp
@@ -151,7 +151,7 @@ int api_model_run() {
   /// creating array of layers same as in `custom_layer_client.ini`
   layers = std::vector<std::shared_ptr<ml::train::Layer>>{
     ml::train::layer::Input({"name=inputlayer", "input_shape=1:1:100"}),
-    ml::train::createLayer("pow", {"name=powlayer", "exponent=3"}),
+    ml::train::createLayer("custom_pow", {"name=powlayer", "exponent=3"}),
     ml::train::layer::FullyConnected(
       {"name=outputlayer", "input_layers=powlayer", "unit=10",
        "bias_initializer=zeros", "activation=softmax"}),
diff --git a/Applications/Custom/LayerClient/res/custom_layer_client.ini b/Applications/Custom/LayerClient/res/custom_layer_client.ini
index 1197e7f528..692701905e 100644
--- a/Applications/Custom/LayerClient/res/custom_layer_client.ini
+++ b/Applications/Custom/LayerClient/res/custom_layer_client.ini
@@ -17,7 +17,7 @@ Input_Shape = 1:1:100
 
 [powlayer]
 input_layers = inputlayer
-Type = pow # AppContext sees PowLayer::getType() and use this to parse type
+Type = custom_pow # AppContext sees PowLayer::getType() and use this to parse type
 exponent = 3 # registering a custom property is done at int PowLayer::setProperty
 
 [outputlayer]
diff --git a/Applications/Custom/LayerPlugin/layer_plugin_pow_test.cpp b/Applications/Custom/LayerPlugin/layer_plugin_pow_test.cpp
index 37cd463bbd..def5492cc8 100644
--- a/Applications/Custom/LayerPlugin/layer_plugin_pow_test.cpp
+++ b/Applications/Custom/LayerPlugin/layer_plugin_pow_test.cpp
@@ -21,7 +21,7 @@
 
 GTEST_PARAMETER_TEST(PowLayer, LayerPluginCommonTest,
                      ::testing::Values(std::make_tuple("libpow_layer.so",
-                                                       "pow")));
+                                                       "custom_pow")));
 
 auto semantic_pow =
   LayerSemanticsParamType(nntrainer::createLayer<custom::PowLayer>,
diff --git a/Applications/Custom/LayerPlugin/layer_plugin_test.cpp b/Applications/Custom/LayerPlugin/layer_plugin_test.cpp
index 9e4997e172..03c1e7ec7f 100644
--- a/Applications/Custom/LayerPlugin/layer_plugin_test.cpp
+++ b/Applications/Custom/LayerPlugin/layer_plugin_test.cpp
@@ -29,9 +29,9 @@ TEST(AppContext, DlRegisterOpen_p) {
 
   ac.registerLayer("libpow_layer.so", NNTRAINER_PATH);
 
-  auto layer = ac.createObject<nntrainer::Layer>("pow");
+  auto layer = ac.createObject<nntrainer::Layer>("custom_pow");
 
-  EXPECT_EQ(layer->getType(), "pow");
+  EXPECT_EQ(layer->getType(), "custom_pow");
 }
 
 TEST(AppContext, DlRegisterWrongPath_n) {
@@ -49,9 +49,9 @@ TEST(AppContext, DlRegisterDirectory_p) {
 
   ac.registerPluggableFromDirectory(NNTRAINER_PATH);
 
-  auto layer = ac.createObject<nntrainer::Layer>("pow");
+  auto layer = ac.createObject<nntrainer::Layer>("custom_pow");
 
-  EXPECT_EQ(layer->getType(), "pow");
+  EXPECT_EQ(layer->getType(), "custom_pow");
 }
 
 TEST(AppContext, DlRegisterDirectory_n) {
@@ -64,8 +64,8 @@ TEST(AppContext, DefaultEnvironmentPath_p) {
   /// as NNTRAINER_PATH is fed to the test, this should success without an
   /// error
-  std::shared_ptr<ml::train::Layer> l = ml::train::createLayer("pow");
-  EXPECT_EQ(l->getType(), "pow");
+  std::shared_ptr<ml::train::Layer> l = ml::train::createLayer("custom_pow");
+  EXPECT_EQ(l->getType(), "custom_pow");
 
   std::shared_ptr<nntrainer::LayerNode> lnode =
     std::static_pointer_cast<nntrainer::LayerNode>(l);
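[Note] The renames above are needed because this patch also introduces a built-in layer that claims the "pow" key in the default factory registry; keeping the sample plugin on the same key would collide at registration time. After the patch the two layers are addressed by distinct keys. A minimal sketch, assuming the ccapi createLayer(string, properties) overload used in the tests above:

    #include <layer.h>

    // built-in pow layer, registered by add_default_object() below
    auto builtin = ml::train::createLayer("pow", {"name=p0", "exponent=2"});

    // sample plugin layer from libpow_layer.so, now keyed "custom_pow"
    auto plugin = ml::train::createLayer("custom_pow", {"name=p1", "exponent=2"});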
type = "pow"; + inline static const std::string type = "custom_pow"; private: float exponent; diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h index a492cd9f89..3cdc82fd11 100644 --- a/api/ccapi/include/layer.h +++ b/api/ccapi/include/layer.h @@ -41,6 +41,7 @@ enum LayerType { LAYER_SUBTRACT = ML_TRAIN_LAYER_TYPE_SUBTRACT, /**< Subtract Layer type */ LAYER_MULTIPLY = ML_TRAIN_LAYER_TYPE_MULTIPLY, /**< Multiply Layer type */ LAYER_DIVIDE = ML_TRAIN_LAYER_TYPE_DIVIDE, /**< Divide Layer type */ + LAYER_POW = ML_TRAIN_LAYER_TYPE_POW, /**< Pow Layer type */ LAYER_FC = ML_TRAIN_LAYER_TYPE_FC, /**< Fully Connected Layer type */ LAYER_SWIGLU = ML_TRAIN_LAYER_TYPE_SWIGLU, /**< Swiglu Layer type */ LAYER_BN = ML_TRAIN_LAYER_TYPE_BN, /**< Batch Normalization Layer type */ @@ -337,6 +338,14 @@ DivideLayer(const std::vector &properties = {}) { return createLayer(LayerType::LAYER_DIVIDE, properties); } +/** + * @brief Helper function to create pow layer + */ +inline std::unique_ptr +PowLayer(const std::vector &properties = {}) { + return createLayer(LayerType::LAYER_POW, properties); +} + /** * @brief Helper function to create fully connected layer */ diff --git a/api/nntrainer-api-common.h b/api/nntrainer-api-common.h index 46c51c88ea..d820aecda5 100644 --- a/api/nntrainer-api-common.h +++ b/api/nntrainer-api-common.h @@ -62,14 +62,15 @@ typedef enum { 27, /**< Layer Normalization Layer type (Since 7.0) */ ML_TRAIN_LAYER_TYPE_POSITIONAL_ENCODING = 28, /**< Positional Encoding Layer type (Since 7.0) */ - ML_TRAIN_LAYER_TYPE_IDENTITY = 29, /**< Identity Layer type (Since 8.0) */ - ML_TRAIN_LAYER_TYPE_SWIGLU = 30, /**< Swiglu Layer type */ - ML_TRAIN_LAYER_TYPE_WEIGHT = 31, /**< Weight Layer type (Since 9.0)*/ - ML_TRAIN_LAYER_TYPE_ADD = 32, /**< Add Layer type (Since 9.0)*/ - ML_TRAIN_LAYER_TYPE_SUBTRACT = 33, /**< Subtract Layer type (Since 9.0)*/ - ML_TRAIN_LAYER_TYPE_MULTIPLY = 34, /**< Multiply Layer type (Since 9.0)*/ - ML_TRAIN_LAYER_TYPE_DIVIDE = 35, /**< Divide Layer type (Since 9.0)*/ + ML_TRAIN_LAYER_TYPE_IDENTITY = 29, /**< Identity Layer type (Since 8.0) */ + ML_TRAIN_LAYER_TYPE_SWIGLU = 30, /**< Swiglu Layer type */ + ML_TRAIN_LAYER_TYPE_WEIGHT = 31, /**< Weight Layer type (Since 9.0)*/ + ML_TRAIN_LAYER_TYPE_ADD = 32, /**< Add Layer type (Since 9.0)*/ + ML_TRAIN_LAYER_TYPE_SUBTRACT = 33, /**< Subtract Layer type (Since 9.0)*/ + ML_TRAIN_LAYER_TYPE_MULTIPLY = 34, /**< Multiply Layer type (Since 9.0)*/ + ML_TRAIN_LAYER_TYPE_DIVIDE = 35, /**< Divide Layer type (Since 9.0)*/ ML_TRAIN_LAYER_TYPE_TRANSPOSE = 36, /**< Transpose Layer type */ + ML_TRAIN_LAYER_TYPE_POW = 37, /**< Pow Layer type (Since 9.0)*/ ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP = 300, /**< Preprocess flip Layer (Since 6.5) */ ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE = diff --git a/nntrainer/app_context.cpp b/nntrainer/app_context.cpp index 3abec3d7c6..5c4ebf13be 100644 --- a/nntrainer/app_context.cpp +++ b/nntrainer/app_context.cpp @@ -70,6 +70,7 @@ #include #include #include +#include #include #include #include @@ -269,6 +270,8 @@ static void add_default_object(AppContext &ac) { LayerType::LAYER_MULTIPLY); ac.registerFactory(nntrainer::createLayer, DivideLayer::type, LayerType::LAYER_DIVIDE); + ac.registerFactory(nntrainer::createLayer, PowLayer::type, + LayerType::LAYER_POW); ac.registerFactory(nntrainer::createLayer, FullyConnectedLayer::type, LayerType::LAYER_FC); ac.registerFactory(nntrainer::createLayer, diff --git a/nntrainer/layers/common_properties.cpp b/nntrainer/layers/common_properties.cpp index 
diff --git a/nntrainer/layers/common_properties.cpp b/nntrainer/layers/common_properties.cpp
index 755f4407c6..563696248a 100644
--- a/nntrainer/layers/common_properties.cpp
+++ b/nntrainer/layers/common_properties.cpp
@@ -88,6 +88,8 @@ InputConnection::InputConnection(const Connection &value) :
 
 Epsilon::Epsilon(float value) { set(value); }
 
+Exponent::Exponent(float value) { set(value); }
+
 bool Epsilon::isValid(const float &value) const { return value > 0.0f; }
 
 Momentum::Momentum(float value) { set(value); }
diff --git a/nntrainer/layers/common_properties.h b/nntrainer/layers/common_properties.h
index 765c2e9c92..8579de7449 100644
--- a/nntrainer/layers/common_properties.h
+++ b/nntrainer/layers/common_properties.h
@@ -280,6 +280,22 @@ class Epsilon : public nntrainer::Property<float> {
   bool isValid(const float &value) const override;
 };
 
+/**
+ * @brief Exponent property, this is used for pow operation
+ *
+ */
+class Exponent : public nntrainer::Property<float> {
+
+public:
+  /**
+   * @brief Construct a new Exponent object
+   *
+   */
+  Exponent(float value = 1.0f);
+  static constexpr const char *key = "exponent"; /**< unique key to access */
+  using prop_tag = float_prop_tag;               /**< property type */
+};
+
 /**
  * @brief Momentum property, moving average in batch normalization layer
  *
diff --git a/nntrainer/layers/meson.build b/nntrainer/layers/meson.build
index 1218aa9b0a..81a43559da 100644
--- a/nntrainer/layers/meson.build
+++ b/nntrainer/layers/meson.build
@@ -9,6 +9,7 @@ layer_sources = [
   'subtract_layer.cpp',
   'multiply_layer.cpp',
   'divide_layer.cpp',
+  'pow_layer.cpp',
   'addition_layer.cpp',
   'attention_layer.cpp',
   'mol_attention_layer.cpp',
diff --git a/nntrainer/layers/pow_layer.cpp b/nntrainer/layers/pow_layer.cpp
new file mode 100644
index 0000000000..a99b13bb6b
--- /dev/null
+++ b/nntrainer/layers/pow_layer.cpp
@@ -0,0 +1,50 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 SeungBaek Hong <sb92.hong@samsung.com>
+ *
+ * @file   pow_layer.cpp
+ * @date   20 Nov 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author SeungBaek Hong <sb92.hong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is pow layer class (operation layer)
+ *
+ */
+
+#include "common_properties.h"
+#include <nntrainer_error.h>
+#include <nntrainer_log.h>
+#include <node_exporter.h>
+#include <pow_layer.h>
+#include <util_func.h>
+
+#include <layer_context.h>
+
+namespace nntrainer {
+
+void PowLayer::finalize(InitLayerContext &context) {
+  context.setOutputDimensions({context.getInputDimensions()[0]});
+}
+
+void PowLayer::forwarding_operation(const Tensor &input, Tensor &hidden) {
+  float exp = std::get<props::Exponent>(pow_props).get();
+  input.pow(exp, hidden);
+}
+
+void PowLayer::calcDerivative(RunLayerContext &context) {
+  float exp = std::get<props::Exponent>(pow_props).get();
+  context.getOutgoingDerivative(0).copy(
+    context.getIncomingDerivative(SINGLE_INOUT_IDX)
+      .multiply(exp)
+      .multiply(context.getInput(0).pow(exp - 1.0f)));
+}
+
+void PowLayer::setProperty(const std::vector<std::string> &values) {
+  auto remain_props = loadProperties(values, pow_props);
+  if (!remain_props.empty()) {
+    std::string msg = "[PowLayer] Unknown Layer Properties count " +
+                      std::to_string(values.size());
+    throw exception::not_supported(msg);
+  }
+}
+} /* namespace nntrainer */
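[Note] calcDerivative above applies the chain rule for y = x^k: dL/dx = dL/dy * k * x^(k-1), which is why the layer must keep its input around for the backward pass. A self-contained finite-difference check of that formula (plain C++, independent of nntrainer; exponent fixed at 3 for illustration):

    #include <cassert>
    #include <cmath>
    #include <cstdio>

    int main() {
      const float k = 3.0f, x = 1.7f, incoming = 0.5f; // incoming = dL/dy
      // analytic gradient, mirroring PowLayer::calcDerivative
      float analytic = incoming * k * std::pow(x, k - 1.0f);
      // central-difference gradient of L ~ incoming * x^k
      const float h = 1e-3f;
      float numeric =
        incoming * (std::pow(x + h, k) - std::pow(x - h, k)) / (2.0f * h);
      std::printf("analytic=%f numeric=%f\n", analytic, numeric);
      assert(std::fabs(analytic - numeric) < 1e-2f);
      return 0;
    }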
diff --git a/nntrainer/layers/pow_layer.h b/nntrainer/layers/pow_layer.h
new file mode 100644
index 0000000000..025b0d29b3
--- /dev/null
+++ b/nntrainer/layers/pow_layer.h
@@ -0,0 +1,124 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 SeungBaek Hong <sb92.hong@samsung.com>
+ *
+ * @file   pow_layer.h
+ * @date   20 Nov 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author SeungBaek Hong <sb92.hong@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is pow layer class (operation layer)
+ *
+ */
+
+#ifndef __POW_LAYER_H__
+#define __POW_LAYER_H__
+#ifdef __cplusplus
+
+#include <common_properties.h>
+#include <layer_devel.h>
+#include <operation_layer.h>
+
+namespace nntrainer {
+
+/**
+ * @class Pow Layer
+ * @brief Pow Layer
+ */
+class PowLayer : public UnaryOperationLayer {
+public:
+  /**
+   * @brief Constructor of Pow Layer
+   */
+  PowLayer() :
+    UnaryOperationLayer(),
+    pow_props(props::Print(), props::InPlaceProp(), props::Exponent()),
+    support_backwarding(true) {}
+
+  /**
+   * @brief Destructor of Pow Layer
+   */
+  ~PowLayer(){};
+
+  /**
+   * @brief Move constructor of Pow Layer.
+   * @param[in] PowLayer &&
+   */
+  PowLayer(PowLayer &&rhs) noexcept = default;
+
+  /**
+   * @brief Move assignment operator.
+   * @param[in] rhs PowLayer to be moved.
+   */
+  PowLayer &operator=(PowLayer &&rhs) = default;
+
+  /**
+   * @copydoc Layer::finalize(InitLayerContext &context)
+   */
+  void finalize(InitLayerContext &context) final;
+
+  /**
+   * @brief forwarding operation for pow
+   *
+   * @param input input tensor
+   * @param hidden tensor to store the result value
+   */
+  void forwarding_operation(const Tensor &input, Tensor &hidden) final;
+
+  /**
+   * @copydoc Layer::calcDerivative(RunLayerContext &context)
+   */
+  void calcDerivative(RunLayerContext &context) final;
+
+  /**
+   * @copydoc bool supportBackwarding() const
+   */
+  bool supportBackwarding() const final { return support_backwarding; };
+
+  /**
+   * @brief Initialize the in-place settings of the layer
+   * @return InPlaceType
+   */
+  InPlaceType initializeInPlace() final {
+    if (std::get<props::InPlaceProp>(pow_props).empty() ||
+        !std::get<props::InPlaceProp>(pow_props).get()) {
+      is_inplace = false;
+      support_backwarding = true;
+    } else {
+      is_inplace = true;
+      support_backwarding = false;
+    }
+
+    if (!supportInPlace())
+      return InPlaceType::NONE;
+    else
+      return InPlaceType::NON_RESTRICTING;
+  }
+
+  /**
+   * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
+   * method)
+   */
+  void exportTo(Exporter &exporter,
+                const ml::train::ExportMethods &method) const final {}
+
+  /**
+   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
+   */
+  void setProperty(const std::vector<std::string> &values) final;
+
+  /**
+   * @copydoc Layer::getType()
+   */
+  const std::string getType() const final { return PowLayer::type; };
+
+  std::tuple<props::Print, props::InPlaceProp, props::Exponent> pow_props;
+  bool support_backwarding; /**< support backwarding */
+
+  inline static const std::string type = "pow";
+};
+
+} // namespace nntrainer
+
+#endif /* __cplusplus */
+#endif /* __POW_LAYER_H__ */
diff --git a/nntrainer/utils/node_exporter.h b/nntrainer/utils/node_exporter.h
index 69be76962e..480cf085cb 100644
--- a/nntrainer/utils/node_exporter.h
+++ b/nntrainer/utils/node_exporter.h
@@ -237,6 +237,7 @@ class Packed;
 class LossScaleForMixed;
 class InPlaceProp;
 class InPlaceDirectionProp;
+class Exponent;
 } // namespace props
 
 class LayerNode;
diff --git a/packaging/unittest_models_v2.tar.gz b/packaging/unittest_models_v2.tar.gz
index 1793cb10ee..7b9c6a20a1 100644
Binary files a/packaging/unittest_models_v2.tar.gz and b/packaging/unittest_models_v2.tar.gz differ
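[Note] initializeInPlace in pow_layer.h above trades memory for trainability: once the user opts into in-place execution, the output overwrites the input buffer, so the x^(k-1) term in calcDerivative can no longer be recomputed and backwarding is disabled. A condensed restatement of that decision table (sketch only, mirroring the header above, not part of the patch):

    // inplace prop unset or false -> keep input, gradient stays computable
    // inplace prop set and true   -> reuse input buffer, gradient path lost
    struct InPlaceDecision {
      bool is_inplace;
      bool support_backwarding;
    };

    InPlaceDecision decide(bool prop_given, bool prop_value) {
      if (!prop_given || !prop_value)
        return {false, true};
      return {true, false};
    }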
diff --git a/test/input_gen/genModelTests_v2.py b/test/input_gen/genModelTests_v2.py
index b9b03cebee..849d250d10 100644
--- a/test/input_gen/genModelTests_v2.py
+++ b/test/input_gen/genModelTests_v2.py
@@ -518,6 +518,19 @@ def forward(self, inputs, labels):
         return out, loss
 
 
+class PowOperation(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc = torch.nn.Linear(2, 2)
+        self.loss = torch.nn.MSELoss()
+
+    def forward(self, inputs, labels):
+        out = self.fc(inputs[0])
+        out = out.pow(3)
+        loss = self.loss(out, labels[0])
+        return out, loss
+
+
 if __name__ == "__main__":
     record_v2(
         ReduceMeanLast(),
@@ -835,6 +848,16 @@ def forward(self, inputs, labels):
         name="multiply_operation",
     )
 
+    pow_operation = PowOperation()
+    record_v2(
+        pow_operation,
+        iteration=2,
+        input_dims=[(1, 2)],
+        input_dtype=[float],
+        label_dims=[(1, 2)],
+        name="pow_operation",
+    )
+
     # Function to check the created golden test file
     inspect_file("add_operation.nnmodelgolden")
     fc_mixed_training_nan_sgd = LinearMixedPrecisionNaNSGD()
diff --git a/test/input_gen/golden/unittest_models_v2.tar.gz b/test/input_gen/golden/unittest_models_v2.tar.gz
new file mode 100644
index 0000000000..152ff37cf4
Binary files /dev/null and b/test/input_gen/golden/unittest_models_v2.tar.gz differ
diff --git a/test/unittest/layers/meson.build b/test/unittest/layers/meson.build
index 7d6bb3b49b..54d7d782fa 100644
--- a/test/unittest/layers/meson.build
+++ b/test/unittest/layers/meson.build
@@ -51,6 +51,7 @@ test_target = [
   'unittest_layers_subtract.cpp',
   'unittest_layers_multiply.cpp',
   'unittest_layers_divide.cpp',
+  'unittest_layers_pow.cpp',
   'unittest_layers_multiout.cpp',
   'unittest_layers_rnn.cpp',
   'unittest_layers_rnncell.cpp',
diff --git a/test/unittest/layers/unittest_layers_pow.cpp b/test/unittest/layers/unittest_layers_pow.cpp
new file mode 100644
index 0000000000..68bd11bad0
--- /dev/null
+++ b/test/unittest/layers/unittest_layers_pow.cpp
@@ -0,0 +1,30 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 SeungBaek Hong <sb92.hong@samsung.com>
+ *
+ * @file unittest_layers_pow.cpp
+ * @date 20 Nov 2024
+ * @brief Pow Layer Test
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author SeungBaek Hong <sb92.hong@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+#include <tuple>
+
+#include <gtest/gtest.h>
+
+#include <layers_common_tests.h>
+#include <pow_layer.h>
+
+auto semantic_pow = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::PowLayer>, nntrainer::PowLayer::type,
+  {"exponent=3"}, LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT,
+  false, 1);
+
+auto semantic_pow_multi = LayerSemanticsParamType(
+  nntrainer::createLayer<nntrainer::PowLayer>, nntrainer::PowLayer::type,
+  {"exponent=3"}, LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT,
+  false, 2);
+
+GTEST_PARAMETER_TEST(Pow, LayerSemantics,
+                     ::testing::Values(semantic_pow, semantic_pow_multi));
diff --git a/test/unittest/models/unittest_models.cpp b/test/unittest/models/unittest_models.cpp
index 8de98a0446..6c691bd9f0 100644
--- a/test/unittest/models/unittest_models.cpp
+++ b/test/unittest/models/unittest_models.cpp
@@ -948,6 +948,25 @@ static std::unique_ptr<NeuralNetwork> makeDivideOperation() {
   return nn;
 }
 
+static std::unique_ptr<NeuralNetwork> makePowOperation() {
+  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+
+  auto outer_graph =
+    makeGraph({{"input", {"name=in", "input_shape=1:1:2"}},
+               {"fully_connected", {"name=fc", "unit=2", "input_layers=in"}},
+               {"pow", {"name=pow_layer", "exponent=3", "input_layers=fc"}},
+               {"mse", {"name=loss", "input_layers=pow_layer"}}});
+
+  for (auto &node : outer_graph) {
+    nn->addLayer(node);
+  }
+
+  nn->setProperty({"batch_size=1"});
+  nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate=0.1"}));
+
+  return nn;
+}
+
 GTEST_PARAMETER_TEST(
   model, nntrainerModelTest,
   ::testing::ValuesIn({
@@ -1026,6 +1045,7 @@ GTEST_PARAMETER_TEST(
                  ModelTestOption::ALL_V2),
     mkModelTc_V2(makeDivideOperation, "divide_operation",
                  ModelTestOption::ALL_V2),
+    mkModelTc_V2(makePowOperation, "pow_operation", ModelTestOption::ALL_V2),
   }),
   [](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info)
     -> const auto & { return std::get<1>(info.param); });
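[Note] The makePowOperation graph above (in -> fc -> pow -> mse) is the counterpart of the PyTorch PowOperation golden generator added to genModelTests_v2.py, so the model test compares nntrainer's forward and backward values against torch's. For reference, the same topology can be expressed through the ccapi helpers introduced earlier in this patch; a sketch, assuming the existing createModel, layer::loss::MSE, and createOptimizer helpers in ccapi:

    #include <layer.h>
    #include <model.h>
    #include <optimizer.h>

    using namespace ml::train;

    std::unique_ptr<Model> buildPowModel() {
      auto model = createModel(ModelType::NEURAL_NET);
      model->addLayer(layer::Input({"name=in", "input_shape=1:1:2"}));
      model->addLayer(
        layer::FullyConnected({"name=fc", "unit=2", "input_layers=in"}));
      model->addLayer(
        layer::PowLayer({"name=pow_layer", "exponent=3", "input_layers=fc"}));
      model->addLayer(layer::loss::MSE({"name=loss", "input_layers=pow_layer"}));
      model->setProperty({"batch_size=1"});
      model->setOptimizer(createOptimizer("sgd", {"learning_rate=0.1"}));
      return model;
    }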