Add nn::BatchNorm2D
keith2018 committed Jan 2, 2025
1 parent 521712c commit 49ac040
Showing 18 changed files with 867 additions and 222 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -11,6 +11,7 @@ Tiny deep learning training framework implemented from scratch in C++ that follo…
- Module
- Linear
- Conv2D
- BatchNorm2D
- MaxPool2D
- Dropout
- Softmax
4 changes: 3 additions & 1 deletion TinyTorch/CMakeLists.txt
@@ -22,7 +22,9 @@ if (MSVC)
set_source_files_properties(${TinyTorch_src} PROPERTIES COMPILE_FLAGS "/WX")
else ()
set_source_files_properties(${TinyTorch_src} PROPERTIES COMPILE_FLAGS "-Werror -Wno-deprecated-declarations")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
if (CMAKE_BUILD_TYPE STREQUAL Release)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
endif ()
endif ()

add_library(${PROJECT_NAME} ${TinyTorch_src})
107 changes: 107 additions & 0 deletions TinyTorch/Function.cpp
@@ -38,6 +38,7 @@ std::unordered_map<FunctionType, std::string> Function::funcTypeToString_ = {
FUNC_ENUM_TO_STRING(Function_LogSoftmax),
FUNC_ENUM_TO_STRING(Function_MaxPool2D),
FUNC_ENUM_TO_STRING(Function_Conv2D),
FUNC_ENUM_TO_STRING(Function_BatchNorm),
FUNC_ENUM_TO_STRING(Function_MSELoss),
FUNC_ENUM_TO_STRING(Function_NLLLoss),
};
@@ -132,6 +133,15 @@ Tensor Function::conv2d(const Tensor& input, const Tensor& weight,
->callForward({&input, &weight, &bias});
}

Tensor Function::batchNorm(const Tensor& input, Tensor& runningMean,
Tensor& runningVar, const Tensor& weight,
const Tensor& bias, bool training, float momentum,
float eps) {
return std::make_shared<FuncBatchNorm>(runningMean, runningVar, momentum, eps,
training)
->callForward({&input, &weight, &bias});
}

Tensor Function::mseLoss(const Tensor& input, const Tensor& target,
LossReduction reduction) {
return std::make_shared<FuncMSELoss>(reduction)->callForward(
@@ -595,6 +605,103 @@ std::vector<TensorImpl> FuncConv2D::backward(const TensorImpl& grad) {
return ret;
}

TensorImpl FuncBatchNorm::forward(const std::vector<const Tensor*>& inputs) {
auto& input = inputs[0]->data();
auto& weight = inputs[1]->data();
auto& bias = inputs[2]->data();

auto& shape = input.shape();
assert(shape.size() == 3 || shape.size() == 4);

if (shape.size() == 3) {
dims_ = {0, 2};
viewShape_ = {1, shape[1], 1};
} else {
dims_ = {0, 2, 3};
viewShape_ = {1, shape[1], 1, 1};
}

Tensor mean;
Tensor var;
if (training_) {
mean.data() = input.mean(dims_, true);
var.data() = input.var(dims_, false, true);
auto varUnbiased = input.var(dims_, true, true);

if (!runningMean_.empty() && !runningVar_.empty()) {
runningMean_.data() *= 1.f - momentum_;
runningMean_.data() += TensorImpl::squeeze(mean.data()) * momentum_;
runningVar_.data() *= 1.f - momentum_;
runningVar_.data() += TensorImpl::squeeze(varUnbiased) * momentum_;
}
} else {
if (!runningMean_.empty() && !runningVar_.empty()) {
mean = runningMean_;
var = runningVar_;
} else {
mean.data() = input.mean(dims_, true);
var.data() = input.var(dims_, true, true);
}
}

auto inputCentered = Tensor(input - mean.data());
auto std = Tensor((var.data() + eps_).sqrt());
auto inputNorm = Tensor(inputCentered / std);

saveForBackward(inputs);
saveForBackward({&inputNorm, &inputCentered, &std});

if (!weight.empty()) {
inputNorm.data() = inputNorm.data() * weight.view(viewShape_);
}
if (!bias.empty()) {
inputNorm.data() = inputNorm.data() + bias.view(viewShape_);
}
return inputNorm.data();
}

std::vector<TensorImpl> FuncBatchNorm::backward(const TensorImpl& grad) {
const auto& savedTensors = getSavedTensors();
auto& input = savedTensors[0].data();
auto& weight = savedTensors[1].data();
// auto& bias = savedTensors[2].data();
auto& inputNorm = savedTensors[3].data();
auto& inputCentered = savedTensors[4].data();
auto& std = savedTensors[5].data();

std::vector<TensorImpl> ret;
// grad of input
if (savedTensors[0].isRequiresGrad()) {
auto dInputNorm = grad;
if (!weight.empty()) {
dInputNorm = dInputNorm * weight.view(viewShape_);
}
int32_t N = 1;
for (int dim : dims_) {
N *= input.shape()[dim];
}
auto dVar =
(dInputNorm * inputCentered * -0.5f * std.pow(-3.f)).sum(dims_, true);
auto dMean = (dInputNorm * -1.f / std).sum(dims_, true) +
dVar * (inputCentered * -2.f / (float)N).sum(dims_, true);
auto dInput = dInputNorm / std + dVar * 2.f * inputCentered / (float)N +
dMean / (float)N;
ret.push_back(dInput);
}
// grad of weight
if (savedTensors[1].isRequiresGrad()) {
auto dWeight = (grad * inputNorm).sum(dims_);
ret.push_back(dWeight);
}

// grad of bias
if (savedTensors[2].isRequiresGrad()) {
auto dBias = grad.sum(dims_);
ret.push_back(dBias);
}
return ret;
}

TensorImpl FuncMSELoss::forward(const std::vector<const Tensor*>& inputs) {
saveForBackward(inputs);
auto ret = TensorImpl::pow(inputs[0]->data() - inputs[1]->data(), 2);
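For reference, the math implemented by FuncBatchNorm above is the standard batch-normalization recipe. With per-channel batch mean \(\mu\) and biased variance \(\sigma^2\) (reduced over dims_, so N is the number of elements per channel), the forward pass computes

\[
\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad y = \gamma\,\hat{x} + \beta,
\]

and in training mode the running statistics are updated as running ← (1 − momentum)·running + momentum·batch, with the unbiased variance used for runningVar. The terms assembled in backward are the usual chain-rule pieces, where \(\partial L/\partial\hat{x} = \gamma\,\partial L/\partial y\) when affine weights are present:

\[
\frac{\partial L}{\partial \sigma^2} = \sum \frac{\partial L}{\partial \hat{x}}\,(x-\mu)\left(-\tfrac12\right)(\sigma^2+\epsilon)^{-3/2},
\qquad
\frac{\partial L}{\partial \mu} = \sum \frac{\partial L}{\partial \hat{x}}\,\frac{-1}{\sqrt{\sigma^2+\epsilon}}
 + \frac{\partial L}{\partial \sigma^2}\,\frac{\sum -2(x-\mu)}{N},
\]
\[
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \hat{x}}\,\frac{1}{\sqrt{\sigma^2+\epsilon}}
 + \frac{\partial L}{\partial \sigma^2}\,\frac{2(x-\mu)}{N}
 + \frac{\partial L}{\partial \mu}\,\frac{1}{N},
\qquad
\frac{\partial L}{\partial \gamma} = \sum \frac{\partial L}{\partial y}\,\hat{x},
\qquad
\frac{\partial L}{\partial \beta} = \sum \frac{\partial L}{\partial y}.
\]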
34 changes: 31 additions & 3 deletions TinyTorch/Function.h
@@ -35,6 +35,7 @@ enum FunctionType {
Function_LogSoftmax,
Function_MaxPool2D,
Function_Conv2D,
Function_BatchNorm,
Function_MSELoss,
Function_NLLLoss,
};
@@ -70,7 +71,8 @@ class Function : public std::enable_shared_from_this<Function> {
static Tensor linear(const Tensor& input, const Tensor& weight,
const Tensor& bias);

static Tensor dropout(const Tensor& input, float p, bool training = true);
static Tensor dropout(const Tensor& input, float p = 0.5f,
bool training = true);

static Tensor softmax(const Tensor& input, int32_t dim);

@@ -84,6 +86,11 @@ class Function : public std::enable_shared_from_this<Function> {
const Tensor& bias = {}, Size2D stride = 1,
Size2D padding = 0);

static Tensor batchNorm(const Tensor& input, Tensor& runningMean,
Tensor& runningVar, const Tensor& weight,
const Tensor& bias, bool training = false,
float momentum = 0.1f, float eps = 1e-5);

static Tensor nllloss(const Tensor& input, const Tensor& target,
LossReduction reduction = MEAN);

@@ -110,14 +117,13 @@
if (!NoGradScope::isGradEnabled()) {
return;
}
savedTensors_.reserve(tensors.size());
savedTensors_.reserve(savedTensors_.size() + tensors.size());
for (const auto& t : tensors) {
savedTensors_.push_back(*t);
}
}
std::vector<Tensor>& getSavedTensors() { return savedTensors_; };

protected:
std::weak_ptr<AutogradMeta> owner_;
std::vector<Tensor> savedTensors_;
std::vector<std::shared_ptr<Function>> nextFuncs_;
@@ -286,6 +292,28 @@ class FuncConv2D : public Function {
TensorImpl col_;
};

class FuncBatchNorm : public Function {
public:
explicit FuncBatchNorm(Tensor& runningMean, Tensor& runningVar,
float momentum, float eps, bool training)
: runningMean_(runningMean),
runningVar_(runningVar),
momentum_(momentum),
eps_(eps),
training_(training) {}
DEFINE_FUNCTION_MEMBERS(Function_BatchNorm)

private:
Tensor& runningMean_;
Tensor& runningVar_;
std::vector<int32_t> dims_;
std::vector<int32_t> viewShape_;

float momentum_;
float eps_;
bool training_;
};

class FuncMSELoss : public Function {
public:
explicit FuncMSELoss(LossReduction reduction) : reduction_(reduction) {}
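For a sense of how the functional entry point is driven (the nn::BatchNorm2D module below does this internally), here is a minimal, hypothetical sketch. It only leans on calls visible in this diff — Tensor::shape, Tensor::data().fill, and Function::batchNorm — and treats the second argument of Tensor::shape and the include paths as assumptions, not documented API.

```cpp
#include "Function.h"  // include path assumed
#include "Tensor.h"    // include path assumed

using namespace TinyTorch;

// Hypothetical sketch: drive Function::batchNorm directly over an N x C x H x W input.
// Tensor::shape(shape, flag) mirrors its usage in this diff; the flag's meaning
// (presumably requiresGrad) is an assumption.
void batchNormFunctionalSketch() {
  const int32_t C = 8;

  Tensor input = Tensor::shape({4, C, 16, 16}, true);
  input.data().fill(0.5f);  // placeholder values

  Tensor runningMean = Tensor::shape({C}, true);
  Tensor runningVar = Tensor::shape({C}, true);
  runningMean.data().fill(0.f);
  runningVar.data().fill(1.f);

  Tensor weight = Tensor::shape({C}, true);  // gamma
  Tensor bias = Tensor::shape({C}, true);    // beta
  weight.data().fill(1.f);
  bias.data().fill(0.f);

  // training=true uses batch statistics and updates runningMean/runningVar
  Tensor out = Function::batchNorm(input, runningMean, runningVar, weight, bias,
                                   /*training=*/true, /*momentum=*/0.1f,
                                   /*eps=*/1e-5f);
}
```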
93 changes: 93 additions & 0 deletions TinyTorch/Module.cpp
@@ -6,6 +6,8 @@

#include "Module.h"

#include <cassert>

#include "Function.h"
#include "Init.h"

@@ -21,6 +23,16 @@ std::vector<Tensor *> Module::parameters() {
return ret;
}

std::vector<Tensor *> Module::states() {
std::vector<Tensor *> ret;
for (auto &module : subModules_) {
for (auto p : module.get().states()) {
ret.push_back(p);
}
}
return ret;
}

void Module::resetParameters() {
for (auto &module : subModules_) {
module.get().resetParameters();
@@ -52,6 +64,16 @@ std::vector<Tensor *> Sequential::parameters() {
return ret;
}

std::vector<Tensor *> Sequential::states() {
std::vector<Tensor *> ret;
for (auto &module : modules_) {
for (auto p : module->states()) {
ret.push_back(p);
}
}
return ret;
}

void Sequential::resetParameters() {
for (auto &module : modules_) {
module->resetParameters();
@@ -91,6 +113,8 @@ std::vector<Tensor *> Linear::parameters() {
return {&weights_};
}

std::vector<Tensor *> Linear::states() { return parameters(); }

void Linear::resetParameters() {
Init::kaimingUniform(weights_, std::sqrt(5.f));
if (useBias_) {
@@ -156,6 +180,8 @@ std::vector<Tensor *> Conv2D::parameters() {
return {&weights_};
}

std::vector<Tensor *> Conv2D::states() { return parameters(); }

void Conv2D::resetParameters() {
Init::kaimingUniform(weights_, std::sqrt(5.f));
if (useBias_) {
@@ -174,4 +200,71 @@ void Conv2D::zeroGrad() {
}
}

BatchNorm2D::BatchNorm2D(int32_t numFeatures, float eps, float momentum,
bool affine, bool trackRunningStats)
: numFeatures_(numFeatures),
eps_(eps),
momentum_(momentum),
affine_(affine),
trackRunningStats_(trackRunningStats),
numBatchesTracked_(0) {
if (affine_) {
weights_ = Tensor::shape({numFeatures_}, true);
bias_ = Tensor::shape({numFeatures_}, true);
}
if (trackRunningStats_) {
runningMean_ = Tensor::shape({numFeatures_}, true);
runningVar_ = Tensor::shape({numFeatures_}, true);
}

BatchNorm2D::resetParameters();
}

Tensor BatchNorm2D::forward(Tensor &input) {
assert(input.dim() == 4);
if (training_ && trackRunningStats_) {
numBatchesTracked_++;
}

bool bnTrain = training_ || !trackRunningStats_;
return Function::batchNorm(input, runningMean_, runningVar_, weights_, bias_,
bnTrain, momentum_, eps_);
}

std::vector<Tensor *> BatchNorm2D::parameters() {
if (affine_) {
return {&weights_, &bias_};
}
return {};
}

std::vector<Tensor *> BatchNorm2D::states() {
std::vector<Tensor *> ret({&runningMean_, &runningVar_});
if (affine_) {
ret.push_back(&weights_);
ret.push_back(&bias_);
}
return ret;
}

void BatchNorm2D::resetParameters() {
if (affine_) {
weights_.data().fill(1.f);
bias_.data().fill(0.f);
}

if (trackRunningStats_) {
runningMean_.data().fill(0.f);
runningVar_.data().fill(1.f);
numBatchesTracked_ = 0;
}
}

void BatchNorm2D::zeroGrad() {
if (affine_) {
weights_.zeroGrad();
bias_.zeroGrad();
}
}

} // namespace TinyTorch::nn
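To tie the pieces together, a minimal, hypothetical usage sketch of the new module. It only relies on the BatchNorm2D constructor, forward, parameters, and states signatures added in this diff; tensor construction via Tensor::shape and the include paths are assumptions carried over from the diff rather than documented API.

```cpp
#include "Module.h"  // include path assumed
#include "Tensor.h"  // include path assumed

using namespace TinyTorch;

// Hypothetical sketch: run nn::BatchNorm2D over an N x C x H x W input
// (BatchNorm2D::forward asserts input.dim() == 4).
void batchNormModuleSketch() {
  nn::BatchNorm2D bn(/*numFeatures=*/16, /*eps=*/1e-5f, /*momentum=*/0.1f,
                     /*affine=*/true, /*trackRunningStats=*/true);

  Tensor input = Tensor::shape({8, 16, 32, 32}, true);
  input.data().fill(0.25f);  // placeholder values

  // In training mode, batch statistics are used and the running stats updated;
  // in eval mode with trackRunningStats, the stored running stats are used.
  Tensor out = bn.forward(input);

  // parameters() exposes {gamma, beta}; states() additionally exposes the
  // running mean/var so they can be saved and restored alongside the weights.
  auto params = bn.parameters();
  auto states = bn.states();
}
```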
