[ layer ] Mixed precision forwarding / backwarding for bn layer @open sesame 03/07 10:42 #2462
Conversation
📝 TAOS-CI Version: 1.5.20200925. Thank you for submitting PR #2462. Please follow the 1 commit/1 PR (one commit per PR) policy to get comments from reviewers quickly. Your PR must pass all verification processes of cibot before the review process by reviewers starts. If you are a new member joining this project, please read the manuals in the documentation folder and wiki page. To monitor the progress of your PR in more detail, visit http://ci.nnstreamer.ai/.
Force-pushed from 412da0f to 11006c0
@skykongkong8, 💯 All CI checkers are successfully verified. Thanks.
nntrainer/layers/bn_layer.cpp
Outdated
TensorDim mu_dim = mu.getDim();
mu_dim.setDataType(ml::train::TensorDim::DataType::FP32);
Tensor mu32(mu_dim, true);
mu32.copyData(mu);
Could this be reduced? For example:
Tensor getFloatTensor(Tensor input) {
  TensorDim dim = input.getDim();
  dim.setDataType(ml::train::TensorDim::DataType::FP32);
  Tensor output(dim, true);
  output.copyData(input);
  return output;
}
Tensor mu32 = getFloatTensor(mu);
Tensor var32 = getFloatTensor(var);
Tensor gamma32 = getFloatTensor(gamma);
...
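A side note on the suggestion: passing input by value copies the argument on every call; taking it as const Tensor & would avoid that extra copy while keeping the helper otherwise identical.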
nntrainer/layers/bn_layer.cpp
Outdated
@@ -213,42 +306,151 @@ void BatchNormalizationLayer::calcDerivative(RunLayerContext &context) {

Tensor &t_reduced = context.getTensor(wt_idx[BNParams::t_reduced]);
Tensor &t_full = context.getTensor(wt_idx[BNParams::t_full]);
if (deriv.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
TensorDim gamma_dim = gamma.getDim();
same here
Applied!
- According to recent papers on mixed precision, the computation of statistics in the BN layer should use fp32 values:
  input (fp16) <-> BN layer (fp16 weights, but compute in fp32 and re-cast to fp16) <-> output (fp16)
- Although this requires a bulky code block for now, I believe we can revisit here for cleaner code when TensorV2 becomes official.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <[email protected]>
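As a minimal sketch of the fp32 round-trip this commit describes (the helper name toFp32 and the header paths are assumptions; getDim, setDataType, the Tensor(dim, true) constructor, and copyData are the nntrainer calls visible in the diff above):

#include <tensor.h>     // nntrainer::Tensor (assumed header path)
#include <tensor_dim.h> // nntrainer::TensorDim (assumed header path)

using nntrainer::Tensor;
using nntrainer::TensorDim;

// Promote a (possibly fp16) tensor to an fp32 working copy.
Tensor toFp32(const Tensor &in) {
  TensorDim dim = in.getDim();
  dim.setDataType(ml::train::TensorDim::DataType::FP32);
  Tensor out(dim, true); // allocate fp32 storage with the same shape
  out.copyData(in);      // copyData casts element-wise across data types
  return out;
}

// Usage inside the layer, schematically:
// Tensor mu32 = toFp32(mu);  // compute BN statistics on the fp32 copy
// ... fp32 math ...
// mu.copyData(mu32);         // re-cast the result back into fp16 storage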
Force-pushed from 11006c0 to 534d25c
Force-pushed from 534d25c to bd50fbe
- For mixed precision training in the BN layer, we have been declaring temporary single-precision Tensors, computing with Tensor ops, and saving the result back into the original half-precision Tensors.
- To remove redundant code, declare a temporary helper function and reuse it.
- We should revisit here for even cleaner code when the TensorV2 refactoring is finished.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <[email protected]>
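As a sketch, the refactored flow this commit describes might look like the following (getFloatTensor is the helper suggested in the review above; the exact operands and write-back placement are illustrative, not the final layer code):

// Promote the fp16 operands once through the shared helper.
Tensor mu32 = getFloatTensor(mu);
Tensor var32 = getFloatTensor(var);
Tensor gamma32 = getFloatTensor(gamma);

// ... BN statistics and normalization computed on the fp32 copies ...

// Save the fp32 results back into the original half-precision tensors.
mu.copyData(mu32);
var.copyData(var32);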
Force-pushed from bd50fbe to f48e585
@skykongkong8, 💯 All CI checkers are successfully verified. Thanks.
LGTM
/**
 * This calculates dgamma tensor.
 */
Tensor dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]);
Should this be Tensor &dgamma = context.getWeightGrad(wt_idx[BNParams::gamma]); instead?
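For context, a hedged sketch of the distinction behind this question (whether the by-value copy shares the underlying buffer depends on Tensor's copy semantics, and setZero is used only as an example mutation):

Tensor &dgamma_ref = context.getWeightGrad(wt_idx[BNParams::gamma]);
dgamma_ref.setZero(); // mutates the gradient tensor owned by the context

Tensor dgamma_val = context.getWeightGrad(wt_idx[BNParams::gamma]);
dgamma_val.setZero(); // may only touch a local copy if the copy
                      // constructor deep-copies the buffer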
@skykongkong8, 💯 All CI checkers are successfully verified. Thanks.
LGTM
I don't think this is the right way to enable mixed precision. Please consider #2455.
This PR is no longer needed. Close.