From 49ac0401fc45510885a57f9432196f0dd490c3ac Mon Sep 17 00:00:00 2001 From: keith2018 Date: Tue, 31 Dec 2024 18:12:51 +0800 Subject: [PATCH] Add nn::BatchNorm2D --- README.md | 1 + TinyTorch/CMakeLists.txt | 4 +- TinyTorch/Function.cpp | 107 +++++++++ TinyTorch/Function.h | 34 ++- TinyTorch/Module.cpp | 93 ++++++++ TinyTorch/Module.h | 46 +++- TinyTorch/Tensor.cpp | 5 + TinyTorch/Tensor.h | 2 + TinyTorch/TensorImpl.cpp | 137 ++++++++---- TinyTorch/TensorImpl.h | 458 ++++++++++++++++++++++++++++----------- TinyTorch/Torch.cpp | 40 ++-- TinyTorch/Torch.h | 2 +- demo/demo_mnist.cpp | 2 + demo/demo_module.cpp | 6 +- demo/demo_optim.cpp | 6 +- test/test_function.cpp | 27 +++ test/test_module.cpp | 34 ++- test/test_tensorimpl.cpp | 85 +++++++- 18 files changed, 867 insertions(+), 222 deletions(-) diff --git a/README.md b/README.md index 8a761ab..c2976e6 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ Tiny deep learning training framework implemented from scratch in C++ that follo - Module - Linear - Conv2D + - BatchNorm2D - MaxPool2D - Dropout - Softmax diff --git a/TinyTorch/CMakeLists.txt b/TinyTorch/CMakeLists.txt index f25e36b..afe5080 100644 --- a/TinyTorch/CMakeLists.txt +++ b/TinyTorch/CMakeLists.txt @@ -22,7 +22,9 @@ if (MSVC) set_source_files_properties(${TinyTorch_src} PROPERTIES COMPILE_FLAGS "/WX") else () set_source_files_properties(${TinyTorch_src} PROPERTIES COMPILE_FLAGS "-Werror -Wno-deprecated-declarations") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") + if (CMAKE_BUILD_TYPE STREQUAL Release) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") + endif () endif () add_library(${PROJECT_NAME} ${TinyTorch_src}) diff --git a/TinyTorch/Function.cpp b/TinyTorch/Function.cpp index d12441b..7bb9fbc 100644 --- a/TinyTorch/Function.cpp +++ b/TinyTorch/Function.cpp @@ -38,6 +38,7 @@ std::unordered_map Function::funcTypeToString_ = { FUNC_ENUM_TO_STRING(Function_LogSoftmax), FUNC_ENUM_TO_STRING(Function_MaxPool2D), FUNC_ENUM_TO_STRING(Function_Conv2D), + FUNC_ENUM_TO_STRING(Function_BatchNorm), FUNC_ENUM_TO_STRING(Function_MSELoss), FUNC_ENUM_TO_STRING(Function_NLLLoss), }; @@ -132,6 +133,15 @@ Tensor Function::conv2d(const Tensor& input, const Tensor& weight, ->callForward({&input, &weight, &bias}); } +Tensor Function::batchNorm(const Tensor& input, Tensor& runningMean, + Tensor& runningVar, const Tensor& weight, + const Tensor& bias, bool training, float momentum, + float eps) { + return std::make_shared(runningMean, runningVar, momentum, eps, + training) + ->callForward({&input, &weight, &bias}); +} + Tensor Function::mseLoss(const Tensor& input, const Tensor& target, LossReduction reduction) { return std::make_shared(reduction)->callForward( @@ -595,6 +605,103 @@ std::vector FuncConv2D::backward(const TensorImpl& grad) { return ret; } +TensorImpl FuncBatchNorm::forward(const std::vector& inputs) { + auto& input = inputs[0]->data(); + auto& weight = inputs[1]->data(); + auto& bias = inputs[2]->data(); + + auto& shape = input.shape(); + assert(shape.size() == 3 || shape.size() == 4); + + if (shape.size() == 3) { + dims_ = {0, 2}; + viewShape_ = {1, shape[1], 1}; + } else { + dims_ = {0, 2, 3}; + viewShape_ = {1, shape[1], 1, 1}; + } + + Tensor mean; + Tensor var; + if (training_) { + mean.data() = input.mean(dims_, true); + var.data() = input.var(dims_, false, true); + auto varUnbiased = input.var(dims_, true, true); + + if (!runningMean_.empty() && !runningVar_.empty()) { + runningMean_.data() *= 1.f - momentum_; + runningMean_.data() += TensorImpl::squeeze(mean.data()) * momentum_; + runningVar_.data() *= 1.f - momentum_; + runningVar_.data() += TensorImpl::squeeze(varUnbiased) * momentum_; + } + } else { + if (!runningMean_.empty() && !runningVar_.empty()) { + mean = runningMean_; + var = runningVar_; + } else { + mean.data() = input.mean(dims_, true); + var.data() = input.var(dims_, true, true); + } + } + + auto inputCentered = Tensor(input - mean.data()); + auto std = Tensor((var.data() + eps_).sqrt()); + auto inputNorm = Tensor(inputCentered / std); + + saveForBackward(inputs); + saveForBackward({&inputNorm, &inputCentered, &std}); + + if (!weight.empty()) { + inputNorm.data() = inputNorm.data() * weight.view(viewShape_); + } + if (!bias.empty()) { + inputNorm.data() = inputNorm.data() + bias.view(viewShape_); + } + return inputNorm.data(); +} + +std::vector FuncBatchNorm::backward(const TensorImpl& grad) { + const auto& savedTensors = getSavedTensors(); + auto& input = savedTensors[0].data(); + auto& weight = savedTensors[1].data(); + // auto& bias = savedTensors[2].data(); + auto& inputNorm = savedTensors[3].data(); + auto& inputCentered = savedTensors[4].data(); + auto& std = savedTensors[5].data(); + + std::vector ret; + // grad of input + if (savedTensors[0].isRequiresGrad()) { + auto dInputNorm = grad; + if (!weight.empty()) { + dInputNorm = dInputNorm * weight.view(viewShape_); + } + int32_t N = 1; + for (int dim : dims_) { + N *= input.shape()[dim]; + } + auto dVar = + (dInputNorm * inputCentered * -0.5f * std.pow(-3.f)).sum(dims_, true); + auto dMean = (dInputNorm * -1.f / std).sum(dims_, true) + + dVar * (inputCentered * -2.f / (float)N).sum(dims_, true); + auto dInput = dInputNorm / std + dVar * 2.f * inputCentered / (float)N + + dMean / (float)N; + ret.push_back(dInput); + } + // grad of weight + if (savedTensors[1].isRequiresGrad()) { + auto dWeight = (grad * inputNorm).sum(dims_); + ret.push_back(dWeight); + } + + // grad of bias + if (savedTensors[2].isRequiresGrad()) { + auto dBias = grad.sum(dims_); + ret.push_back(dBias); + } + return ret; +} + TensorImpl FuncMSELoss::forward(const std::vector& inputs) { saveForBackward(inputs); auto ret = TensorImpl::pow(inputs[0]->data() - inputs[1]->data(), 2); diff --git a/TinyTorch/Function.h b/TinyTorch/Function.h index 1ce96d9..403dabe 100644 --- a/TinyTorch/Function.h +++ b/TinyTorch/Function.h @@ -35,6 +35,7 @@ enum FunctionType { Function_LogSoftmax, Function_MaxPool2D, Function_Conv2D, + Function_BatchNorm, Function_MSELoss, Function_NLLLoss, }; @@ -70,7 +71,8 @@ class Function : public std::enable_shared_from_this { static Tensor linear(const Tensor& input, const Tensor& weight, const Tensor& bias); - static Tensor dropout(const Tensor& input, float p, bool training = true); + static Tensor dropout(const Tensor& input, float p = 0.5f, + bool training = true); static Tensor softmax(const Tensor& input, int32_t dim); @@ -84,6 +86,11 @@ class Function : public std::enable_shared_from_this { const Tensor& bias = {}, Size2D stride = 1, Size2D padding = 0); + static Tensor batchNorm(const Tensor& input, Tensor& runningMean, + Tensor& runningVar, const Tensor& weight, + const Tensor& bias, bool training = false, + float momentum = 0.1f, float eps = 1e-5); + static Tensor nllloss(const Tensor& input, const Tensor& target, LossReduction reduction = MEAN); @@ -110,14 +117,13 @@ class Function : public std::enable_shared_from_this { if (!NoGradScope::isGradEnabled()) { return; } - savedTensors_.reserve(tensors.size()); + savedTensors_.reserve(savedTensors_.size() + tensors.size()); for (const auto& t : tensors) { savedTensors_.push_back(*t); } } std::vector& getSavedTensors() { return savedTensors_; }; - protected: std::weak_ptr owner_; std::vector savedTensors_; std::vector> nextFuncs_; @@ -286,6 +292,28 @@ class FuncConv2D : public Function { TensorImpl col_; }; +class FuncBatchNorm : public Function { + public: + explicit FuncBatchNorm(Tensor& runningMean, Tensor& runningVar, + float momentum, float eps, bool training) + : runningMean_(runningMean), + runningVar_(runningVar), + momentum_(momentum), + eps_(eps), + training_(training) {} + DEFINE_FUNCTION_MEMBERS(Function_BatchNorm) + + private: + Tensor& runningMean_; + Tensor& runningVar_; + std::vector dims_; + std::vector viewShape_; + + float momentum_; + float eps_; + bool training_; +}; + class FuncMSELoss : public Function { public: explicit FuncMSELoss(LossReduction reduction) : reduction_(reduction) {} diff --git a/TinyTorch/Module.cpp b/TinyTorch/Module.cpp index 9eb7029..d487953 100644 --- a/TinyTorch/Module.cpp +++ b/TinyTorch/Module.cpp @@ -6,6 +6,8 @@ #include "Module.h" +#include + #include "Function.h" #include "Init.h" @@ -21,6 +23,16 @@ std::vector Module::parameters() { return ret; } +std::vector Module::states() { + std::vector ret; + for (auto &module : subModules_) { + for (auto p : module.get().states()) { + ret.push_back(p); + } + } + return ret; +} + void Module::resetParameters() { for (auto &module : subModules_) { module.get().resetParameters(); @@ -52,6 +64,16 @@ std::vector Sequential::parameters() { return ret; } +std::vector Sequential::states() { + std::vector ret; + for (auto &module : modules_) { + for (auto p : module->states()) { + ret.push_back(p); + } + } + return ret; +} + void Sequential::resetParameters() { for (auto &module : modules_) { module->resetParameters(); @@ -91,6 +113,8 @@ std::vector Linear::parameters() { return {&weights_}; } +std::vector Linear::states() { return parameters(); } + void Linear::resetParameters() { Init::kaimingUniform(weights_, std::sqrt(5.f)); if (useBias_) { @@ -156,6 +180,8 @@ std::vector Conv2D::parameters() { return {&weights_}; } +std::vector Conv2D::states() { return parameters(); } + void Conv2D::resetParameters() { Init::kaimingUniform(weights_, std::sqrt(5.f)); if (useBias_) { @@ -174,4 +200,71 @@ void Conv2D::zeroGrad() { } } +BatchNorm2D::BatchNorm2D(int32_t numFeatures, float eps, float momentum, + bool affine, bool trackRunningStats) + : numFeatures_(numFeatures), + eps_(eps), + momentum_(momentum), + affine_(affine), + trackRunningStats_(trackRunningStats), + numBatchesTracked_(0) { + if (affine_) { + weights_ = Tensor::shape({numFeatures_}, true); + bias_ = Tensor::shape({numFeatures_}, true); + } + if (trackRunningStats_) { + runningMean_ = Tensor::shape({numFeatures_}, true); + runningVar_ = Tensor::shape({numFeatures_}, true); + } + + BatchNorm2D::resetParameters(); +} + +Tensor BatchNorm2D::forward(Tensor &input) { + assert(input.dim() == 4); + if (training_ && trackRunningStats_) { + numBatchesTracked_++; + } + + bool bnTrain = training_ || !trackRunningStats_; + return Function::batchNorm(input, runningMean_, runningVar_, weights_, bias_, + bnTrain, momentum_, eps_); +} + +std::vector BatchNorm2D::parameters() { + if (affine_) { + return {&weights_, &bias_}; + } + return {}; +} + +std::vector BatchNorm2D::states() { + std::vector ret({&runningMean_, &runningVar_}); + if (affine_) { + ret.push_back(&weights_); + ret.push_back(&bias_); + } + return ret; +} + +void BatchNorm2D::resetParameters() { + if (affine_) { + weights_.data().fill(1.f); + bias_.data().fill(0.f); + } + + if (trackRunningStats_) { + runningMean_.data().fill(0.f); + runningVar_.data().fill(1.f); + numBatchesTracked_ = 0; + } +} + +void BatchNorm2D::zeroGrad() { + if (affine_) { + weights_.zeroGrad(); + bias_.zeroGrad(); + } +} + } // namespace TinyTorch::nn diff --git a/TinyTorch/Module.h b/TinyTorch/Module.h index 6af7cc7..c8fcd17 100644 --- a/TinyTorch/Module.h +++ b/TinyTorch/Module.h @@ -14,6 +14,7 @@ class Module { public: virtual ~Module() = default; virtual std::vector parameters(); + virtual std::vector states(); virtual void resetParameters(); virtual void zeroGrad(); @@ -75,6 +76,7 @@ class Sequential : public Module { Tensor forward(Tensor &input) override; std::vector parameters() override; + std::vector states() override; void resetParameters() override; void zeroGrad() override; @@ -100,11 +102,12 @@ class Linear : public Module { Tensor forward(Tensor &input) override; std::vector parameters() override; + std::vector states() override; void resetParameters() override; void zeroGrad() override; - Tensor &Weights() { return weights_; } - Tensor &Bias() { return bias_; } + Tensor &weights() { return weights_; } + Tensor &bias() { return bias_; } private: int32_t inFeatures_; @@ -185,11 +188,12 @@ class Conv2D : public Module { Tensor forward(Tensor &input) override; std::vector parameters() override; + std::vector states() override; void resetParameters() override; void zeroGrad() override; - Tensor &Weights() { return weights_; } - Tensor &Bias() { return bias_; } + Tensor &weights() { return weights_; } + Tensor &bias() { return bias_; } private: int32_t inFeatures_; @@ -198,8 +202,42 @@ class Conv2D : public Module { Size2D stride_; Size2D padding_; bool useBias_; + Tensor weights_; Tensor bias_; }; +class BatchNorm2D : public Module { + public: + explicit BatchNorm2D(int32_t numFeatures, float eps = 1e-5, + float momentum = 0.1f, bool affine = true, + bool trackRunningStats = true); + + Tensor forward(Tensor &input) override; + std::vector parameters() override; + std::vector states() override; + void resetParameters() override; + void zeroGrad() override; + + Tensor &weights() { return weights_; } + Tensor &bias() { return bias_; } + + Tensor &runningMean() { return runningMean_; } + Tensor &runningVar() { return runningVar_; } + + private: + int32_t numFeatures_; + float eps_; + float momentum_; + bool affine_; + bool trackRunningStats_; + + Tensor weights_; + Tensor bias_; + + Tensor runningMean_; + Tensor runningVar_; + int32_t numBatchesTracked_; +}; + } // namespace TinyTorch::nn diff --git a/TinyTorch/Tensor.cpp b/TinyTorch/Tensor.cpp index 3e70824..b3fda9d 100644 --- a/TinyTorch/Tensor.cpp +++ b/TinyTorch/Tensor.cpp @@ -62,6 +62,11 @@ Tensor Tensor::randn(const Shape &shape, bool requiresGrad) { return Tensor(std::move(ret), requiresGrad); } +Tensor Tensor::arange(float start, float stop, float step, bool requiresGrad) { + auto ret = TensorImpl::arange(start, stop, step); + return Tensor(std::move(ret), requiresGrad); +} + Tensor Tensor::linspace(float start, float end, int steps, bool requiresGrad) { auto ret = TensorImpl::linspace(start, end, steps); return Tensor(std::move(ret), requiresGrad); diff --git a/TinyTorch/Tensor.h b/TinyTorch/Tensor.h index 373d5b9..d237510 100644 --- a/TinyTorch/Tensor.h +++ b/TinyTorch/Tensor.h @@ -33,6 +33,8 @@ class Tensor { static Tensor onesLike(const Tensor &t, bool requiresGrad = false); static Tensor zeros(const Shape &shape, bool requiresGrad = false); static Tensor randn(const Shape &shape, bool requiresGrad = false); + static Tensor arange(float start, float stop, float steps, + bool requiresGrad = false); static Tensor linspace(float start, float end, int steps, bool requiresGrad = false); diff --git a/TinyTorch/TensorImpl.cpp b/TinyTorch/TensorImpl.cpp index a75dc6e..93cb847 100644 --- a/TinyTorch/TensorImpl.cpp +++ b/TinyTorch/TensorImpl.cpp @@ -111,20 +111,26 @@ std::default_random_engine RandomGenerator::randomEngine_; } \ return ret -#define TENSOR_UFUNC_FAST_LOOP(scalarRet, func) \ +#define TENSOR_UFUNC_REDUCE_ALL(scalarRet, func) \ if (t.isScalar()) { \ return scalarRet; \ } \ - func functor; \ - functor.reset(); \ - for (int32_t i = 0; i < t.elemCount_; i++) { \ - functor.op(t.data_[i]); \ + auto functor = std::make_shared(); \ + return t.reduceAll(functor) + +#define TENSOR_UFUNC_REDUCE_SINGLE(scalarRet, func, axis) \ + if (t.isScalar()) { \ + return scalar(scalarRet); \ } \ - return functor.result() + auto functor = std::make_shared(); \ + return t.reduceSingle(functor, axis, keepDims) -#define TENSOR_UFUNC_REDUCE(func) \ - func functor; \ - return t.reduce(functor, axis.get(t.dimCount_), keepDims) +#define TENSOR_UFUNC_REDUCE_MULTI(scalarRet, func, axes) \ + if (t.isScalar()) { \ + return scalar(scalarRet); \ + } \ + auto functor = std::make_shared(); \ + return t.reduceMulti(functor, axes, keepDims) // clang-format on @@ -716,14 +722,14 @@ TensorImpl TensorImpl::col2im(const Shape &inputShape, Size2D kernelSize, return retTensor; } -TensorImpl TensorImpl::transpose(const std::vector &axis) const { +TensorImpl TensorImpl::transpose(const std::vector &axes) const { TENSOR_CHECK_EMPTY(*this, {}); if (dim() <= 1) { return *this; } TensorIter it(shape()); - if (axis.empty()) { + if (axes.empty()) { // If not specified, defaults to range(a.ndim)[::-1], which reverses the // order of the axes. std::vector reverseTrans; @@ -732,11 +738,11 @@ TensorImpl TensorImpl::transpose(const std::vector &axis) const { reverseTrans[i] = dim() - i - 1; } it.transpose(reverseTrans); - } else if (axis.size() != dim()) { + } else if (axes.size() != dim()) { error(__FUNCTION__, TensorError_InvalidAxis); return {}; } else { - it.transpose(axis); + it.transpose(axes); } TensorImpl ret = shape(it.shape()); @@ -1448,96 +1454,121 @@ TensorImpl TensorImpl::matmulTrans(const TensorImpl &a, const TensorImpl &b) { float TensorImpl::min(const TensorImpl &t) { TENSOR_CHECK_EMPTY(t, 0); - TENSOR_UFUNC_FAST_LOOP(t[0], UFuncMin); + TENSOR_UFUNC_REDUCE_ALL(t[0], UFuncSingleMin); } float TensorImpl::max(const TensorImpl &t) { TENSOR_CHECK_EMPTY(t, 0); - TENSOR_UFUNC_FAST_LOOP(t[0], UFuncMax); + TENSOR_UFUNC_REDUCE_ALL(t[0], UFuncSingleMax); } float TensorImpl::mean(const TensorImpl &t) { TENSOR_CHECK_EMPTY(t, 0); - TENSOR_UFUNC_FAST_LOOP(t[0], UFuncMean); + TENSOR_UFUNC_REDUCE_ALL(t[0], UFuncSingleMean); } float TensorImpl::sum(const TensorImpl &t) { TENSOR_CHECK_EMPTY(t, 0); - TENSOR_UFUNC_FAST_LOOP(t[0], UFuncSum); + TENSOR_UFUNC_REDUCE_ALL(t[0], UFuncSingleSum); } -float TensorImpl::var(const TensorImpl &t) { +float TensorImpl::var(const TensorImpl &t, bool unbiased) { TENSOR_CHECK_EMPTY(t, 0); - TENSOR_UFUNC_FAST_LOOP(0, UFuncVar); + if (unbiased) { + TENSOR_UFUNC_REDUCE_ALL(0, UFuncSingleVarUnbiased); + } + TENSOR_UFUNC_REDUCE_ALL(0, UFuncSingleVar); } float TensorImpl::argmin(const TensorImpl &t) { TENSOR_CHECK_EMPTY(t, 0); - TENSOR_UFUNC_FAST_LOOP(0, UFuncArgMin); + TENSOR_UFUNC_REDUCE_ALL(0, UFuncSingleArgMin); } float TensorImpl::argmax(const TensorImpl &t) { TENSOR_CHECK_EMPTY(t, 0); - TENSOR_UFUNC_FAST_LOOP(0, UFuncArgMax); + TENSOR_UFUNC_REDUCE_ALL(0, UFuncSingleArgMax); } TensorImpl TensorImpl::min(const TensorImpl &t, const Axis &axis, bool keepDims) { TENSOR_CHECK_EMPTY(t, {}); - TENSOR_UFUNC_REDUCE(UFuncMin); + TENSOR_UFUNC_REDUCE_SINGLE(t[0], UFuncSingleMin, axis.get(t.dimCount_)); } TensorImpl TensorImpl::max(const TensorImpl &t, const Axis &axis, bool keepDims) { TENSOR_CHECK_EMPTY(t, {}); - TENSOR_UFUNC_REDUCE(UFuncMax); + TENSOR_UFUNC_REDUCE_SINGLE(t[0], UFuncSingleMax, axis.get(t.dimCount_)); } TensorImpl TensorImpl::mean(const TensorImpl &t, const Axis &axis, bool keepDims) { TENSOR_CHECK_EMPTY(t, {}); - TENSOR_UFUNC_REDUCE(UFuncMean); + TENSOR_UFUNC_REDUCE_SINGLE(t[0], UFuncSingleMean, axis.get(t.dimCount_)); } TensorImpl TensorImpl::sum(const TensorImpl &t, const Axis &axis, bool keepDims) { TENSOR_CHECK_EMPTY(t, {}); - TENSOR_UFUNC_REDUCE(UFuncSum); + TENSOR_UFUNC_REDUCE_SINGLE(t[0], UFuncSingleSum, axis.get(t.dimCount_)); } -TensorImpl TensorImpl::var(const TensorImpl &t, const Axis &axis, +TensorImpl TensorImpl::var(const TensorImpl &t, const Axis &axis, bool unbiased, bool keepDims) { TENSOR_CHECK_EMPTY(t, {}); - TENSOR_UFUNC_REDUCE(UFuncVar); + if (unbiased) { + TENSOR_UFUNC_REDUCE_SINGLE(0, UFuncSingleVarUnbiased, + axis.get(t.dimCount_)); + } + TENSOR_UFUNC_REDUCE_SINGLE(0, UFuncSingleVar, axis.get(t.dimCount_)); } TensorImpl TensorImpl::argmin(const TensorImpl &t, const Axis &axis, bool keepDims) { TENSOR_CHECK_EMPTY(t, {}); - TENSOR_UFUNC_REDUCE(UFuncArgMin); + TENSOR_UFUNC_REDUCE_SINGLE(0, UFuncSingleArgMin, axis.get(t.dimCount_)); } TensorImpl TensorImpl::argmax(const TensorImpl &t, const Axis &axis, bool keepDims) { TENSOR_CHECK_EMPTY(t, {}); - TENSOR_UFUNC_REDUCE(UFuncArgMax); + TENSOR_UFUNC_REDUCE_SINGLE(0, UFuncSingleArgMax, axis.get(t.dimCount_)); } -void TensorImpl::traverse(UFunc &func, int32_t start, int32_t stride, - int32_t cnt) const { +TensorImpl TensorImpl::mean(const TensorImpl &t, + const std::vector &axes, bool keepDims) { + TENSOR_CHECK_EMPTY(t, {}); + TENSOR_UFUNC_REDUCE_MULTI(t[0], UFuncMultiMean, axes); +} + +TensorImpl TensorImpl::sum(const TensorImpl &t, + const std::vector &axes, bool keepDims) { + TENSOR_CHECK_EMPTY(t, {}); + TENSOR_UFUNC_REDUCE_MULTI(t[0], UFuncMultiSum, axes); +} + +TensorImpl TensorImpl::var(const TensorImpl &t, + const std::vector &axes, bool unbiased, + bool keepDims) { + TENSOR_CHECK_EMPTY(t, {}); + if (unbiased) { + TENSOR_UFUNC_REDUCE_MULTI(0, UFuncMultiVarUnbiased, axes); + } + TENSOR_UFUNC_REDUCE_MULTI(0, UFuncMultiVar, axes); +} + +void TensorImpl::traverse(const std::shared_ptr &func, + int32_t start, int32_t stride, int32_t cnt) const { int32_t idx = start; for (int32_t n = 0; n < cnt; n++) { - func.op(data_[idx]); + func->op(data_[idx]); idx += stride; } } -TensorImpl TensorImpl::reduce(UFunc &func, int32_t axis, bool keepDims) const { - // check scalar - if (isScalar()) { - return scalar(data_[0]); - } - +TensorImpl TensorImpl::reduceSingle(const std::shared_ptr &func, + int32_t axis, bool keepDims) const { // check axis if (axis >= dimCount_) { error(__FUNCTION__, TensorError_InvalidAxis); @@ -1571,9 +1602,9 @@ TensorImpl TensorImpl::reduce(UFunc &func, int32_t axis, bool keepDims) const { for (int32_t i = 0; i < groupCount; i++) { axisStart = i * groupStride; for (int32_t j = 0; j < axisStride; j++) { - func.reset(); + func->reset(); traverse(func, axisStart, axisStride, axisLength); - ret[retIdx++] = func.result(); + ret[retIdx++] = func->result(); axisStart++; } } @@ -1581,6 +1612,22 @@ TensorImpl TensorImpl::reduce(UFunc &func, int32_t axis, bool keepDims) const { return ret; } +TensorImpl TensorImpl::reduceMulti(const std::shared_ptr &func, + const std::vector &axes, + bool keepDims) const { + ReduceHelper helper(*this); + helper.initAxisReduce(axes, keepDims); + return func->doReduce(helper); +} + +float TensorImpl::reduceAll(const std::shared_ptr &func) const { + func->reset(); + for (int32_t i = 0; i < elemCount_; i++) { + func->op(data_[i]); + } + return func->result(); +} + void TensorImpl::splitAxis(std::vector &retTensors, std::vector &splitIndices, int32_t axis) const { @@ -1804,14 +1851,14 @@ void TensorIter::broadcast(const Shape &shape) { reset(); } -void TensorIter::transpose(const std::vector &axis) { +void TensorIter::transpose(const std::vector &axes) { // assume axis size equal to dimension count - assert(axis.size() == ndM1_ + 1); + assert(axes.size() == ndM1_ + 1); // reorder dimsM1_, strides_, backStrides_ - reorder(dimsM1_, axis); - reorder(strides_, axis); - reorder(backStrides_, axis); + reorder(dimsM1_, axes); + reorder(strides_, axes); + reorder(backStrides_, axes); } } // namespace TinyTorch \ No newline at end of file diff --git a/TinyTorch/TensorImpl.h b/TinyTorch/TensorImpl.h index 859b416..2c3711f 100644 --- a/TinyTorch/TensorImpl.h +++ b/TinyTorch/TensorImpl.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -19,6 +20,9 @@ namespace TinyTorch { #define TENSOR_MAX_DIMS 8 +class UFuncSingle; +class UFuncMulti; + typedef enum TensorError_ { TensorError_None = 0, TensorError_EmptyTensor, @@ -82,127 +86,6 @@ class Axis { int32_t axis_ = 0; }; -class UFunc { - public: - virtual ~UFunc() = default; - virtual void op(const float &val) { idx_++; }; - - virtual float result() { return tmp; }; - - virtual void reset() { - idx_ = 0; - tmp = 0.f; - } - - protected: - int32_t idx_ = 0; - float tmp = 0.f; -}; - -class UFuncSum : public UFunc { - public: - void op(const float &val) override { tmp += val; } -}; - -class UFuncMean : public UFunc { - public: - void op(const float &val) override { - idx_++; - tmp += val; - } - - float result() override { return tmp / (float)idx_; } -}; - -class UFuncVar : public UFunc { - public: - void op(const float &val) override { - idx_++; - tmp += val; - squareSum_ += val * val; - } - - float result() override { - float mean = tmp / (float)idx_; - return squareSum_ / (float)idx_ - mean * mean; - } - - void reset() override { - idx_ = 0; - tmp = 0; - squareSum_ = 0; - } - - private: - float squareSum_ = 0; -}; - -class UFuncMin : public UFunc { - public: - void op(const float &val) override { - if (val < tmp) { - tmp = val; - } - } - - void reset() override { tmp = std::numeric_limits::max(); } -}; - -class UFuncMax : public UFunc { - public: - void op(const float &val) override { - if (val > tmp) { - tmp = val; - } - } - - void reset() override { tmp = -std::numeric_limits::max(); } -}; - -class UFuncArgMin : public UFunc { - public: - void op(const float &val) override { - if (val < tmp) { - tmp = val; - minIdx_ = idx_; - } - idx_++; - } - - float result() override { return (float)minIdx_; } - - void reset() override { - tmp = std::numeric_limits::max(); - idx_ = 0; - minIdx_ = 0; - } - - private: - int32_t minIdx_ = 0; -}; - -class UFuncArgMax : public UFunc { - public: - void op(const float &val) override { - if (val > tmp) { - tmp = val; - maxIdx_ = idx_; - } - idx_++; - } - - float result() override { return (float)maxIdx_; } - - void reset() override { - tmp = -std::numeric_limits::max(); - idx_ = 0; - maxIdx_ = 0; - } - - private: - int32_t maxIdx_ = 0; -}; - // float type elements only class TensorImpl { public: @@ -371,11 +254,11 @@ class TensorImpl { Size2D padding = 0) const; // transpose - TensorImpl transpose(const std::vector &axis = {}) const; + TensorImpl transpose(const std::vector &axes = {}) const; static TensorImpl transpose(const TensorImpl &t, - const std::vector &axis = {}) { - return t.transpose(axis); + const std::vector &axes = {}) { + return t.transpose(axes); } // split @@ -541,7 +424,7 @@ class TensorImpl { static float max(const TensorImpl &t); static float mean(const TensorImpl &t); static float sum(const TensorImpl &t); - static float var(const TensorImpl &t); + static float var(const TensorImpl &t, bool unbiased = true); static float argmin(const TensorImpl &t); static float argmax(const TensorImpl &t); @@ -549,7 +432,9 @@ class TensorImpl { float max() const { return TensorImpl::max(*this); }; float mean() const { return TensorImpl::mean(*this); }; float sum() const { return TensorImpl::sum(*this); }; - float var() const { return TensorImpl::var(*this); }; + float var(bool unbiased = true) const { + return TensorImpl::var(*this, unbiased); + }; float argmin() const { return TensorImpl::argmin(*this); }; float argmax() const { return TensorImpl::argmax(*this); }; @@ -562,12 +447,19 @@ class TensorImpl { static TensorImpl sum(const TensorImpl &t, const Axis &axis, bool keepDims = false); static TensorImpl var(const TensorImpl &t, const Axis &axis, - bool keepDims = false); + bool unbiased = true, bool keepDims = false); static TensorImpl argmin(const TensorImpl &t, const Axis &axis, bool keepDims = false); static TensorImpl argmax(const TensorImpl &t, const Axis &axis, bool keepDims = false); + static TensorImpl mean(const TensorImpl &t, const std::vector &axes, + bool keepDims = false); + static TensorImpl sum(const TensorImpl &t, const std::vector &axes, + bool keepDims = false); + static TensorImpl var(const TensorImpl &t, const std::vector &axes, + bool unbiased = true, bool keepDims = false); + TensorImpl min(const Axis &axis, bool keepDims = false) const { return TensorImpl::min(*this, axis, keepDims); } @@ -584,8 +476,8 @@ class TensorImpl { return TensorImpl::sum(*this, axis, keepDims); } - TensorImpl var(const Axis &axis, bool keepDims = false) const { - return TensorImpl::var(*this, axis, keepDims); + TensorImpl var(const Axis &axis, bool unbiased, bool keepDims = false) const { + return TensorImpl::var(*this, axis, unbiased, keepDims); } TensorImpl argmin(const Axis &axis, bool keepDims = false) const { @@ -596,6 +488,21 @@ class TensorImpl { return TensorImpl::argmax(*this, axis, keepDims); } + TensorImpl mean(const std::vector &axes, + bool keepDims = false) const { + return TensorImpl::mean(*this, axes, keepDims); + } + + TensorImpl sum(const std::vector &axes, + bool keepDims = false) const { + return TensorImpl::sum(*this, axes, keepDims); + } + + TensorImpl var(const std::vector &axes, bool unbiased = true, + bool keepDims = false) const { + return TensorImpl::var(*this, axes, unbiased, keepDims); + } + public: class Iterator { public: @@ -621,8 +528,14 @@ class TensorImpl { void initMeta(); void initData(const float *from = nullptr); - void traverse(UFunc &func, int32_t start, int32_t stride, int32_t cnt) const; - TensorImpl reduce(UFunc &func, int32_t axis, bool keepDims = false) const; + void traverse(const std::shared_ptr &func, int32_t start, + int32_t stride, int32_t cnt) const; + TensorImpl reduceSingle(const std::shared_ptr &func, + int32_t axis, bool keepDims = false) const; + TensorImpl reduceMulti(const std::shared_ptr &func, + const std::vector &axes, + bool keepDims = false) const; + float reduceAll(const std::shared_ptr &func) const; void splitAxis(std::vector &retTensors, std::vector &splitIndices, int32_t axis) const; @@ -685,7 +598,7 @@ class TensorIter { void broadcast(const Shape &shape); // transpose - void transpose(const std::vector &axis); + void transpose(const std::vector &axes); protected: // reorder array @@ -711,4 +624,285 @@ class TensorIter { int32_t itCnt_ = 0; }; +class ReduceHelper { + public: + explicit ReduceHelper(const TensorImpl &tensor) + : srcTensor_(tensor), allReduce_(false), reduceSize_(1) {} + + void initAxisReduce(const std::vector &axes, bool keepDims) { + allReduce_ = false; + reduceAxes_ = axes; + Shape retShape; + retShape.reserve(srcTensor_.dim()); + reduceShape_.reserve(srcTensor_.dim()); + std::vector isAxis(srcTensor_.dim(), false); + for (int32_t axis : axes) { + axis = Axis(axis).get(srcTensor_.dim()); + isAxis[axis] = true; + reduceSize_ *= srcTensor_.shape()[axis]; + } + + // init retShape and reduceShape_ + for (int32_t dim = 0; dim < srcTensor_.dim(); dim++) { + if (isAxis[dim]) { + if (keepDims) { + retShape.emplace_back(1); + } + reduceShape_.emplace_back(1); + } else { + retShape.emplace_back(srcTensor_.shape()[dim]); + reduceShape_.emplace_back(srcTensor_.shape()[dim]); + } + } + + // calculate reduceStrides_ + auto dimCount = (int32_t)reduceShape_.size(); + auto elemCount = 1; + reduceStrides_.resize(dimCount); + for (auto dim = int32_t(dimCount - 1); dim >= 0; dim--) { + reduceStrides_[dim] = elemCount; + elemCount *= reduceShape_[dim]; + } + + dstTensor_ = TensorImpl::zeros(retShape); + } + + void initAllReduce() { + allReduce_ = true; + reduceSize_ = srcTensor_.size(); + dstTensor_ = TensorImpl::scalar(0.f); + } + + const TensorImpl &getOriginTensor() { return srcTensor_; } + TensorImpl &getReducedTensor() { return dstTensor_; } + int32_t getReduceSize() const { return reduceSize_; } + + // src index -> dst index + int32_t indexMapping(int32_t idx) { + if (allReduce_) { + return 0; + } + + int32_t ret = 0; + for (int i = 0; i < srcTensor_.dim(); i++) { + if (reduceShape_[i] != 1) { + ret += (idx / srcTensor_.strides()[i]) * reduceStrides_[i]; + } + idx %= srcTensor_.strides()[i]; + } + return ret; + } + + private: + const TensorImpl &srcTensor_; + std::vector reduceAxes_; + bool allReduce_; + int32_t reduceSize_; + Shape reduceShape_; + Shape reduceStrides_; + TensorImpl dstTensor_; +}; + +class UFuncSingle { + public: + virtual ~UFuncSingle() = default; + virtual void op(const float &val) { idx_++; }; + + virtual float result() { return tmp; }; + + virtual void reset() { + idx_ = 0; + tmp = 0.f; + } + + protected: + int32_t idx_ = 0; + float tmp = 0.f; +}; + +class UFuncSingleSum : public UFuncSingle { + public: + void op(const float &val) override { tmp += val; } +}; + +class UFuncSingleMean : public UFuncSingle { + public: + void op(const float &val) override { + idx_++; + tmp += val; + } + + float result() override { return tmp / (float)idx_; } +}; + +class UFuncSingleVar : public UFuncSingle { + public: + void op(const float &val) override { + idx_++; + tmp += val; + squareSum_ += val * val; + } + + virtual float result() override { + float mean = tmp / (float)idx_; + return squareSum_ / (float)idx_ - mean * mean; + } + + void reset() override { + idx_ = 0; + tmp = 0; + squareSum_ = 0; + } + + protected: + float squareSum_ = 0; +}; + +class UFuncSingleVarUnbiased : public UFuncSingleVar { + public: + float result() override { + float mean = tmp / (float)idx_; + return (squareSum_ / (float)idx_ - mean * mean) * + ((float)idx_ / ((float)idx_ - 1.f)); + } +}; + +class UFuncSingleMin : public UFuncSingle { + public: + void op(const float &val) override { + if (val < tmp) { + tmp = val; + } + } + + void reset() override { tmp = std::numeric_limits::max(); } +}; + +class UFuncSingleMax : public UFuncSingle { + public: + void op(const float &val) override { + if (val > tmp) { + tmp = val; + } + } + + void reset() override { tmp = -std::numeric_limits::max(); } +}; + +class UFuncSingleArgMin : public UFuncSingle { + public: + void op(const float &val) override { + if (val < tmp) { + tmp = val; + minIdx_ = idx_; + } + idx_++; + } + + float result() override { return (float)minIdx_; } + + void reset() override { + tmp = std::numeric_limits::max(); + idx_ = 0; + minIdx_ = 0; + } + + private: + int32_t minIdx_ = 0; +}; + +class UFuncSingleArgMax : public UFuncSingle { + public: + void op(const float &val) override { + if (val > tmp) { + tmp = val; + maxIdx_ = idx_; + } + idx_++; + } + + float result() override { return (float)maxIdx_; } + + void reset() override { + tmp = -std::numeric_limits::max(); + idx_ = 0; + maxIdx_ = 0; + } + + private: + int32_t maxIdx_ = 0; +}; + +class UFuncMulti { + public: + virtual ~UFuncMulti() = default; + virtual TensorImpl &&doReduce(ReduceHelper &reduceHelper) = 0; +}; + +class UFuncMultiSum : public UFuncMulti { + public: + TensorImpl &&doReduce(ReduceHelper &reduceHelper) override { + auto &src = reduceHelper.getOriginTensor(); + auto &dst = reduceHelper.getReducedTensor(); + for (int32_t i = 0; i < src.size(); i++) { + auto dstIdx = reduceHelper.indexMapping(i); + dst[dstIdx] += src[i]; + } + + return std::move(dst); + } +}; + +class UFuncMultiMean : public UFuncMulti { + public: + TensorImpl &&doReduce(ReduceHelper &reduceHelper) override { + auto &src = reduceHelper.getOriginTensor(); + auto &dst = reduceHelper.getReducedTensor(); + for (int32_t i = 0; i < src.size(); i++) { + auto dstIdx = reduceHelper.indexMapping(i); + dst[dstIdx] += src[i]; + } + dst *= 1.f / (float)reduceHelper.getReduceSize(); + return std::move(dst); + } +}; + +class UFuncMultiVar : public UFuncMulti { + public: + TensorImpl &&doReduce(ReduceHelper &reduceHelper) override { + auto &src = reduceHelper.getOriginTensor(); + auto &dst = reduceHelper.getReducedTensor(); + + auto mean = TensorImpl::zeros(dst.shape()); + for (int32_t i = 0; i < src.size(); i++) { + auto dstIdx = reduceHelper.indexMapping(i); + mean[dstIdx] += src[i]; + } + // mean + auto scale = 1.f / (float)reduceHelper.getReduceSize(); + mean *= scale; + + // squared diff + for (int32_t i = 0; i < src.size(); i++) { + auto dstIdx = reduceHelper.indexMapping(i); + auto diff = mean[dstIdx] - src[i]; + dst[dstIdx] += diff * diff; + } + varianceReduce(dst, reduceHelper); + return std::move(dst); + } + + protected: + virtual void varianceReduce(TensorImpl &dst, ReduceHelper &reduceHelper) { + dst *= 1.f / (float)reduceHelper.getReduceSize(); + } +}; + +class UFuncMultiVarUnbiased : public UFuncMultiVar { + protected: + void varianceReduce(TensorImpl &dst, ReduceHelper &reduceHelper) override { + dst *= 1.f / ((float)reduceHelper.getReduceSize() - 1.f); + } +}; + } // namespace TinyTorch \ No newline at end of file diff --git a/TinyTorch/Torch.cpp b/TinyTorch/Torch.cpp index c778c20..6b0f10a 100644 --- a/TinyTorch/Torch.cpp +++ b/TinyTorch/Torch.cpp @@ -13,10 +13,10 @@ namespace TinyTorch { void manualSeed(unsigned int seed) { RandomGenerator::setSeed(seed); } template -static std::string printArray(const T* vec, int32_t size, bool restrict) { +static std::string printArray(const T* vec, int32_t size, bool full) { std::ostringstream oss; oss << "("; - if (!restrict || size <= 16) { + if (full || size <= 16) { for (size_t i = 0; i < size; ++i) { oss << vec[i]; if (i != size - 1) { @@ -42,7 +42,7 @@ static std::string printArray(const T* vec, int32_t size, bool restrict) { return oss.str(); } -void print(const Tensor& tensor) { +void print(const Tensor& tensor, bool full) { std::ostringstream oss; oss << "Tensor { shape: " << printArray(&tensor.shape()[0], tensor.dim(), false); @@ -50,7 +50,7 @@ void print(const Tensor& tensor) { if (tensor.isRequiresGrad()) { oss << ", gradFunc: " << tensor.getGradFunc()->typeString(); } - oss << ", data: " << printArray(&tensor.data()[0], tensor.size(), true); + oss << ", data: " << printArray(&tensor.data()[0], tensor.size(), full); LOGD("%s", oss.str().c_str()); } @@ -65,11 +65,15 @@ void save(const Tensor& tensor, std::ofstream& ofs) { ofs.write((const char*)(&elemCount), sizeof(elemCount)); // shape, strides, data - ofs.write((const char*)(t.shape().data()), - std::streamsize(dimCount * sizeof(int32_t))); - ofs.write((const char*)(t.strides().data()), - std::streamsize(dimCount * sizeof(int32_t))); - ofs.write((const char*)(&t[0]), std::streamsize(elemCount * sizeof(float))); + if (dimCount > 0) { + ofs.write((const char*)(t.shape().data()), + std::streamsize(dimCount * sizeof(int32_t))); + ofs.write((const char*)(t.strides().data()), + std::streamsize(dimCount * sizeof(int32_t))); + } + if (elemCount > 0) { + ofs.write((const char*)(&t[0]), std::streamsize(elemCount * sizeof(float))); + } } void load(Tensor& tensor, std::ifstream& ifs) { @@ -93,11 +97,15 @@ void load(Tensor& tensor, std::ifstream& ifs) { } // shape, strides, data - ifs.read((char*)(t.shape().data()), - std::streamsize(dimCount * sizeof(int32_t))); - ifs.read((char*)(t.strides().data()), - std::streamsize(dimCount * sizeof(int32_t))); - ifs.read((char*)(&t[0]), std::streamsize(elemCount * sizeof(float))); + if (dimCount > 0) { + ifs.read((char*)(t.shape().data()), + std::streamsize(dimCount * sizeof(int32_t))); + ifs.read((char*)(t.strides().data()), + std::streamsize(dimCount * sizeof(int32_t))); + } + if (elemCount > 0) { + ifs.read((char*)(&t[0]), std::streamsize(elemCount * sizeof(float))); + } } void save(nn::Module& model, const char* path) { @@ -107,7 +115,7 @@ void save(nn::Module& model, const char* path) { return; } - auto params = model.parameters(); + auto params = model.states(); for (auto& param : params) { save(*param, outFile); } @@ -120,7 +128,7 @@ void load(nn::Module& model, const char* path) { return; } - auto params = model.parameters(); + auto params = model.states(); for (auto& param : params) { load(*param, inFile); } diff --git a/TinyTorch/Torch.h b/TinyTorch/Torch.h index 8976a57..015ee2d 100644 --- a/TinyTorch/Torch.h +++ b/TinyTorch/Torch.h @@ -25,7 +25,7 @@ constexpr float PI = 3.1415926535f; void manualSeed(unsigned int seed); -void print(const Tensor& tensor); +void print(const Tensor& tensor, bool full = false); void save(const Tensor& tensor, std::ofstream& ofs); diff --git a/demo/demo_mnist.cpp b/demo/demo_mnist.cpp index a49150d..775d751 100644 --- a/demo/demo_mnist.cpp +++ b/demo/demo_mnist.cpp @@ -84,6 +84,8 @@ void demo_mnist() { Timer timer; timer.start(); + manualSeed(0); + // config auto lr = 1.f; auto epochs = 2; diff --git a/demo/demo_module.cpp b/demo/demo_module.cpp index f633e0e..bbfdfa3 100644 --- a/demo/demo_module.cpp +++ b/demo/demo_module.cpp @@ -14,6 +14,8 @@ void demo_module() { Timer timer; timer.start(); + manualSeed(0); + auto x = Tensor::linspace(-PI, PI, 2000); auto y = x.sin(); @@ -43,8 +45,8 @@ void demo_module() { } auto* linearLayer = dynamic_cast(&model[0]); - auto& biasData = linearLayer->Bias().data(); - auto& weightData = linearLayer->Weights().data(); + auto& biasData = linearLayer->bias().data(); + auto& weightData = linearLayer->weights().data(); LOGD("Result: y = %f + %f x + %f x^2 + %f x^3", biasData[0], weightData[0], weightData[1], weightData[2]); diff --git a/demo/demo_optim.cpp b/demo/demo_optim.cpp index eaa2799..a2e371f 100644 --- a/demo/demo_optim.cpp +++ b/demo/demo_optim.cpp @@ -14,6 +14,8 @@ void demo_optim() { Timer timer; timer.start(); + manualSeed(0); + auto x = Tensor::linspace(-PI, PI, 2000); auto y = x.sin(); @@ -39,8 +41,8 @@ void demo_optim() { } auto* linearLayer = dynamic_cast(&model[0]); - auto& biasData = linearLayer->Bias().data(); - auto& weightData = linearLayer->Weights().data(); + auto& biasData = linearLayer->bias().data(); + auto& weightData = linearLayer->weights().data(); LOGD("Result: y = %f + %f x + %f x^2 + %f x^3", biasData[0], weightData[0], weightData[1], weightData[2]); diff --git a/test/test_function.cpp b/test/test_function.cpp index 93075cd..0ec9a03 100644 --- a/test/test_function.cpp +++ b/test/test_function.cpp @@ -361,3 +361,30 @@ TEST(TEST_Function, func_conv2d_03) { ElementsAre(54, 63, 90, 99, 54, 63, 90, 99)); EXPECT_THAT(bias.getGrad().data().toArray(), ElementsAre(9)); } + +TEST(TEST_Function, func_batchNorm_2d) { + auto input = Tensor::arange(1.f, 24.5f, 1.f); + input = input.reshape({2, 3, 2, 2}); + input.setRequiresGrad(true); + auto runningMean = Tensor::zeros({3}); + auto runningVar = Tensor::ones({3}); + auto weight = Tensor::ones({3}, true); + auto bias = Tensor::zeros({3}, true); + auto output = Function::batchNorm(input, runningMean, runningVar, weight, + bias, true, 0.2); + EXPECT_THAT(output.shape(), ElementsAre(2, 3, 2, 2)); + EXPECT_FLOAT_VEC_NEAR( + output.data().toArray(), + {-1.2288, -1.0650, -0.9012, -0.7373, -1.2288, -1.0650, -0.9012, -0.7373, + -1.2288, -1.0650, -0.9012, -0.7373, 0.7373, 0.9012, 1.0650, 1.2288, + 0.7373, 0.9012, 1.0650, 1.2288, 0.7373, 0.9012, 1.0650, 1.2288}); + + EXPECT_FLOAT_VEC_NEAR(runningMean.data().toArray(), {1.7000, 2.5000, 3.3000}); + EXPECT_FLOAT_VEC_NEAR(runningVar.data().toArray(), {9.3143, 9.3143, 9.3143}); + + output.backward(Tensor::onesLike(output)); + EXPECT_FLOAT_VEC_NEAR(input.getGrad().data().toArray(), + TensorImpl::zeros({input.size()}).toArray()); + EXPECT_FLOAT_VEC_NEAR(weight.getGrad().data().toArray(), {0., 0., 0.}); + EXPECT_FLOAT_VEC_NEAR(bias.getGrad().data().toArray(), {8., 8., 8.}); +} diff --git a/test/test_module.cpp b/test/test_module.cpp index 764b852..47367a1 100644 --- a/test/test_module.cpp +++ b/test/test_module.cpp @@ -12,8 +12,8 @@ using namespace TinyTorch; TEST(TEST_Module, linear) { auto layer = nn::Linear(4, 4, true); - layer.Weights().data().fill(1.2f); - layer.Bias().data().fill(0.2f); + layer.weights().data().fill(1.2f); + layer.bias().data().fill(0.2f); auto input = Tensor({{1, 2, 3, 4}, {5, 6, 7, 8}}); auto output = layer(input); @@ -24,11 +24,11 @@ TEST(TEST_Module, linear) { EXPECT_FLOAT_EQ(loss.item(), 4490.4165f); EXPECT_FLOAT_VEC_NEAR( - layer.Weights().getGrad().data().toArray(), + layer.weights().getGrad().data().toArray(), {346.306305, 433.741211, 521.176147, 608.611, 339.37558, 425.315826, 511.256042, 597.196289, 331.547913, 417.151703, 502.755493, 588.359314, 330.02002, 416.754913, 503.489807, 590.224731}); - EXPECT_FLOAT_VEC_NEAR(layer.Bias().getGrad().data().toArray(), + EXPECT_FLOAT_VEC_NEAR(layer.bias().getGrad().data().toArray(), {87.434906, 85.940239, 85.6037903, 86.7348938}); } @@ -61,3 +61,29 @@ TEST(TEST_Module, dropout) { EXPECT_THAT(output.shape(), ElementsAre(2, 4)); EXPECT_TRUE((output.data() == 0).sum() > 0); } + +TEST(TEST_Module, batchNorm2d) { + auto input = Tensor::arange(1.f, 24.5f, 1.f); + input = input.reshape({2, 3, 2, 2}); + input.setRequiresGrad(true); + + auto bn = nn::BatchNorm2D(3); + auto output = bn(input); + + auto target = Tensor(input.data() * 1000.f); + auto lossFn = nn::MSELoss(); + auto loss = lossFn(output, target); + loss.backward(); + + auto &dW = bn.weights().getGrad(); + auto &dB = bn.bias().getGrad(); + auto &runningMean = bn.runningMean(); + auto &runningVar = bn.runningVar(); + + EXPECT_FLOAT_VEC_NEAR(dW.data().toArray(), + {-4068.18481, -4068.18457, -4068.18505}); + EXPECT_FLOAT_VEC_NEAR(dB.data().toArray(), {-5666.6665, -8333.3330, -11000.}); + EXPECT_FLOAT_VEC_NEAR(runningMean.data().toArray(), {0.85, 1.25, 1.65}); + EXPECT_FLOAT_VEC_NEAR(runningVar.data().toArray(), + {5.15714, 5.15714, 5.15714}); +} diff --git a/test/test_tensorimpl.cpp b/test/test_tensorimpl.cpp index 1fbb1e3..5dcada2 100644 --- a/test/test_tensorimpl.cpp +++ b/test/test_tensorimpl.cpp @@ -565,6 +565,14 @@ TEST(TEST_TensorImpl, math_meam) { y = TensorImpl::mean(x, 1); EXPECT_THAT(y.shape(), ElementsAre(2)); EXPECT_THAT(y.toArray(), ElementsAre(2, 5)); + + y = TensorImpl::mean(x, {0, 1}, true); + EXPECT_THAT(y.shape(), ElementsAre(1, 1)); + EXPECT_THAT(y.toArray(), ElementsAre(3.5)); + + y = TensorImpl::mean(x, {0, 1}, false); + EXPECT_TRUE(y.isScalar()); + EXPECT_THAT(y.toArray(), ElementsAre(3.5)); } TEST(TEST_TensorImpl, math_sum) { @@ -580,6 +588,10 @@ TEST(TEST_TensorImpl, math_sum) { EXPECT_THAT(y.shape(), ElementsAre(2)); EXPECT_THAT(y.toArray(), ElementsAre(6, 15)); + y = TensorImpl::sum(x, {0, 1}, true); + EXPECT_THAT(y.shape(), ElementsAre(1, 1)); + EXPECT_THAT(y.toArray(), ElementsAre(21)); + x = TensorImpl({{{4, 2, 3}, {1, 0, 3}}, {{4, 2, 3}, {1, 0, 3}}}); EXPECT_TRUE(TensorImpl::sum(x) == 26); @@ -596,22 +608,59 @@ TEST(TEST_TensorImpl, math_sum) { EXPECT_THAT(y.toArray(), ElementsAre(8, 4, 6, 2, 0, 6)); } -TEST(TEST_TensorImpl, math_var) { +TEST(TEST_TensorImpl, math_var_01) { TensorImpl x({{1, 2, 3}, {4, 5, 6}}); - EXPECT_FLOAT_NEAR(TensorImpl::var(x), 2.9166666); + EXPECT_FLOAT_NEAR(TensorImpl::var(x, false), 2.9166666); + EXPECT_FLOAT_NEAR(TensorImpl::var(x), 3.5); - auto y = TensorImpl::var(x, 0); - EXPECT_THAT(y.shape(), ElementsAre(3)); - EXPECT_THAT(y.toArray(), ElementsAre(2.25, 2.25, 2.25)); - - y = TensorImpl::var(x, 1); - EXPECT_THAT(y.shape(), ElementsAre(2)); - EXPECT_FLOAT_NEAR(y[0], 0.666667); - EXPECT_FLOAT_NEAR(y[1], 0.666667); -} + auto y = TensorImpl::var(x, Axis(0), true, true); + EXPECT_THAT(y.shape(), ElementsAre(1, 3)); + EXPECT_THAT(y.toArray(), ElementsAre(4.5, 4.5, 4.5)); -TEST(TEST_TensorImpl, math_argmin) { + y = TensorImpl::var(x, Axis(1), true, true); + EXPECT_THAT(y.shape(), ElementsAre(2, 1)); + EXPECT_FLOAT_NEAR(y[0], 1.0); + EXPECT_FLOAT_NEAR(y[1], 1.0); + + y = TensorImpl::var(x, {0, 1}, true, true); + EXPECT_THAT(y.shape(), ElementsAre(1, 1)); + EXPECT_EQ(y.item(), 3.5); +} + +TEST(TEST_TensorImpl, math_var_02) { + TensorImpl x({3.14, 7.89, 1.23, 4.56, 9.01, 2.34, 5.67, 8.90, + 0.12, 6.78, 3.45, 7.12, 1.56, 4.89, 9.34, 2.67, + 5.89, 8.23, 0.45, 6.12, 3.78, 7.45, 1.89, 4.23, + 9.56, 2.12, 5.34, 8.67, 0.78, 6.45, 3.12, 7.78}); + x = x.reshape({2, 2, 2, 4}); + + auto y = TensorImpl::var(x, Axis(0), true, true); + EXPECT_THAT(y.shape(), ElementsAre(1, 2, 2, 4)); + EXPECT_FLOAT_VEC_NEAR( + y.toArray(), + {3.7812, 0.0578, 0.3042, 1.2168, 13.6765, 13.0560, 7.1442, 10.9044, + 44.5568, 10.8578, 1.7861, 1.2013, 0.3042, 1.2168, 19.3442, 13.0561}); + + y = TensorImpl::var(x, Axis(1), true, true); + EXPECT_THAT(y.shape(), ElementsAre(2, 1, 2, 4)); + EXPECT_FLOAT_VEC_NEAR( + y.toArray(), + {4.5602, 0.6160, 2.4642, 3.2768, 27.7513, 3.2512, 6.7345, 19.4064, 6.7345, + 18.6660, 11.9561, 3.2513, 4.5000, 0.5000, 0.7564, 6.3013}); + + y = TensorImpl::var(x, {0, 1}, true, true); + EXPECT_THAT(y.shape(), ElementsAre(1, 1, 2, 4)); + EXPECT_FLOAT_VEC_NEAR(y.toArray(), {16.1479, 7.9826, 4.9094, 2.9820, 13.7604, + 4.9578, 10.8303, 8.5854}); + + y = TensorImpl::var(x, {1, 2}, true, true); + EXPECT_THAT(y.shape(), ElementsAre(2, 1, 1, 4)); + EXPECT_FLOAT_VEC_NEAR(y.toArray(), {15.2235, 5.9019, 11.9586, 7.5621, 13.6275, + 7.4389, 4.2882, 3.8282}); +} + +TEST(TEST_TensorImpl, math_argmin_01) { TensorImpl x({{4, 2, 3}, {1, 0, 3}}); EXPECT_TRUE(TensorImpl::argmin(x) == 4); @@ -625,6 +674,18 @@ TEST(TEST_TensorImpl, math_argmin) { EXPECT_THAT(y.toArray(), ElementsAre(1, 1)); } +TEST(TEST_TensorImpl, math_argmin_02) { + TensorImpl x({3.14, 7.89, 1.23, 4.56, 9.01, 2.34, 5.67, 8.90, + 0.12, 6.78, 3.45, 7.12, 1.56, 4.89, 9.34, 2.67, + 5.89, 8.23, 0.45, 6.12, 3.78, 7.45, 1.89, 4.23, + 9.56, 2.12, 5.34, 8.67, 0.78, 6.45, 3.12, 7.78}); + x = x.reshape({2, 2, 2, 4}); + auto y = TensorImpl::argmin(x, Axis(2), true); + EXPECT_THAT(y.shape(), ElementsAre(2, 2, 1, 4)); + EXPECT_THAT(y.toArray(), + ElementsAre(0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1)); +} + TEST(TEST_TensorImpl, math_argmax) { TensorImpl x({{1, 2, 4}, {1, 0, 3}});