Commit
This commit implements the AdamW algorithm in the form of an `Optimizer`. Some (potential) further tasks include:

- Merging the implementations of Adam and AdamW, so that `AdamW` inherits from `Adam` (see the sketch below)
- Adding more unit tests

Signed-off-by: Daniel Jang <[email protected]>
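As a rough illustration of the first follow-up item, a merged version could look something like the sketch below. This is hypothetical and not part of the commit: it assumes `Adam::applyGradient` remains callable from a subclass, and the placement of a decoupled weight-decay step is only indicated in a comment.

```cpp
// Hypothetical sketch of the proposed merge: AdamW reuses Adam instead of
// duplicating the moment updates. Not part of this commit.
#include <adam.h>

namespace nntrainer {

class AdamW : public Adam {
public:
  const std::string getType() const override { return AdamW::type; }

  void applyGradient(RunOptimizerContext &context) override {
    // A decoupled weight-decay step would be applied to the weight around
    // this call; the moment updates themselves stay in Adam::applyGradient.
    Adam::applyGradient(context);
  }

  inline static const std::string type = "adamw";
};

} // namespace nntrainer
```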
Showing 8 changed files with 193 additions and 2 deletions. Only the two new source files, adamw.cpp and adamw.h, are reproduced below.
adamw.cpp (new file, +89 lines)
// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2024 Daniel Jang <[email protected]>
 *
 * @file adamw.cpp
 * @date 3 November 2024
 * @see https://github.com/nnstreamer/nntrainer
 * @author Jijoong Moon <[email protected]>
 * @author Parichay Kapoor <[email protected]>
 * @author Daniel Jang <[email protected]>
 * @bug No known bugs except for NYI items
 * @brief This is the AdamW Optimizer.
 */

#include <cmath>
#include <fstream>

#include <adamw.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <node_exporter.h>
#include <util_func.h>

namespace nntrainer {

AdamW::AdamW() : adam_props(PropsB1(), PropsB2(), PropsEpsilon(), TorchRef()) {
  /** default properties */
  auto &[b1, b2, eps, torch_ref] = adam_props;
  b1.set(0.9f);
  b2.set(0.999f);
  eps.set(1.0e-7f);
  torch_ref.set(false);
}

AdamW::~AdamW() {}

enum AdamParams { wm, wv };

std::vector<TensorDim> AdamW::getOptimizerVariableDim(const TensorDim &dim) {
  return {dim, dim};
}

void AdamW::exportTo(Exporter &exporter,
                     const ml::train::ExportMethods &method) const {
  exporter.saveResult(adam_props, method, this);
  Optimizer::exportTo(exporter, method);
}

void AdamW::setProperty(const std::vector<std::string> &values) {
  auto left = loadProperties(values, adam_props);
  Optimizer::setProperty(left);
}

void AdamW::applyGradient(RunOptimizerContext &context) {
  Tensor &x_grad = context.getGradient();

  auto &beta1 = std::get<PropsB1>(adam_props).get();
  auto &beta2 = std::get<PropsB2>(adam_props).get();
  auto &epsilon = std::get<PropsEpsilon>(adam_props).get();
  auto &torch_ref = std::get<TorchRef>(adam_props).get();

  // This is implementation of adam from original paper.
  // This is not deleted intentionally.
  unsigned int iteration = context.getIteration();
  float biasCorrection1 = 1 - pow(beta1, iteration + 1);
  float biasCorrection2 = 1 - pow(beta2, iteration + 1);
  Tensor &wm = context.getOptimizerVariable(AdamParams::wm);
  Tensor &wv = context.getOptimizerVariable(AdamParams::wv);

  wm.multiply_i(beta1);
  wm.add_i(x_grad, 1.0f - beta1);

  wv.multiply_i(beta2);
  wv.add_i(x_grad.multiply(x_grad), 1.0f - beta2);

  wv.divide_i(sqrtFloat(biasCorrection2));
  std::function<double(double)> sqrtEps = [epsilon](double f) {
    return 1 / (sqrtDouble(f) + epsilon);
  };
  Tensor &term = wv;
  term.apply<float>(sqrtEps);
  term.divide_i(biasCorrection1);
  term.multiply_i(wm);
  x_grad.add_i(term);

  context.applyGradient(context.getLearningRate());
}

} // namespace nntrainer
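For reference, `applyGradient` above accumulates the two moments, applies the bias corrections, and folds an m/(√v + ε)-style term into the gradient before `context.applyGradient` scales the result by the learning rate. The AdamW update as usually written (Loshchilov & Hutter, "Decoupled Weight Decay Regularization") differs from plain Adam only in the decoupled weight-decay term λθ applied directly to the parameters; the λθ term is not part of the two files shown here:

```math
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t, &
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2,\\
\hat m_t &= m_t / (1-\beta_1^{\,t}), &
\hat v_t &= v_t / (1-\beta_2^{\,t}),\\
\theta_t &= \theta_{t-1} - \eta \left( \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\, \theta_{t-1} \right).
\end{aligned}
```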
adamw.h (new file, +86 lines)
// SPDX-License-Identifier: Apache-2.0
/**
 * Copyright (C) 2024 Daniel Jang <[email protected]>
 *
 * @file adamw.h
 * @date 3 November 2024
 * @see https://github.com/nnstreamer/nntrainer
 * @author Jijoong Moon <[email protected]>
 * @author Parichay Kapoor <[email protected]>
 * @author Daniel Jang <[email protected]>
 * @bug No known bugs except for NYI items
 * @brief This is the AdamW Optimizer.
 */
#ifndef __ADAMW_H__
#define __ADAMW_H__
#ifdef __cplusplus

#include <tuple>

#include <adam.h>

#include <base_properties.h>
#include <optimizer_devel.h>

namespace nntrainer {

/**
 * @class AdamW Optimizer class
 * @brief AdamW Optimizer
 */
class AdamW : public Optimizer {
public:
  /**
   * @brief Construct a new AdamW object
   *
   */
  AdamW();

  /**
   * @brief Destroy the AdamW object
   *
   */
  ~AdamW();

  /**
   * @copydoc Optimizer::getDefaultLearningRate()
   *
   */
  double getDefaultLearningRate() const override { return 0.001; }

  /**
   * @copydoc applyGradient(RunOptimizerContext &context)
   */
  void applyGradient(RunOptimizerContext &context) override;

  /**
   * @copydoc Optimizer::getType()
   */
  const std::string getType() const override { return AdamW::type; }

  /**
   * @copydoc Optimizer::getOptimizerVariableDim(const TensorDim &dim)
   */
  std::vector<TensorDim> getOptimizerVariableDim(const TensorDim &dim) override;

  /**
   * @copydoc Optimizer::exportTo(Exporter &exporter, const
   * ml::train::ExportMethods& method)
   */
  void exportTo(Exporter &exporter,
                const ml::train::ExportMethods &method) const override;

  inline static const std::string type = "adamw";

  /**
   * @copydoc Optimizer::setProperty(const std::vector<std::string> &values)
   */
  void setProperty(const std::vector<std::string> &values) override;

private:
  std::tuple<PropsB1, PropsB2, PropsEpsilon, TorchRef> adam_props;
};
} /* namespace nntrainer */

#endif /* __cplusplus */
#endif /* __ADAMW_H__ */
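On the second follow-up item (more unit tests), a minimal additional test might look like the sketch below. It is hypothetical: the test name, the GTest setup, and the `TensorDim(b, c, h, w)` constructor call are assumptions; only the `AdamW` members it exercises come from the header above.

```cpp
// Hypothetical test sketch; assumes the GTest setup used for other nntrainer
// unit tests and a four-argument TensorDim constructor.
#include <gtest/gtest.h>

#include <adamw.h>
#include <tensor_dim.h>

TEST(AdamW, basic_properties) {
  nntrainer::AdamW optimizer;

  // Type string and default learning rate as declared in adamw.h.
  EXPECT_EQ(optimizer.getType(), "adamw");
  EXPECT_DOUBLE_EQ(optimizer.getDefaultLearningRate(), 0.001);

  // Two per-weight optimizer variables (first and second moments), each
  // with the same shape as the weight.
  nntrainer::TensorDim dim(1, 1, 2, 3);
  EXPECT_EQ(optimizer.getOptimizerVariableDim(dim).size(), 2u);
}
```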
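Finally, a rough usage sketch for the property path exposed through `setProperty`. The key names (`beta1`, `beta2`, `epsilon`) are assumed to mirror the existing Adam properties; the property classes are only referenced, not defined, in the files shown above.

```cpp
// Hypothetical usage sketch; the property key strings are assumptions.
#include <adamw.h>

int main() {
  nntrainer::AdamW optimizer;

  // Override the defaults set in the AdamW constructor (0.9, 0.999, 1.0e-7).
  optimizer.setProperty({"beta1=0.9", "beta2=0.95", "epsilon=1e-8"});

  return 0;
}
```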