You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2020/03/23 16:56:05 UTC
[incubator-mxnet] branch master updated: Use FP32 copy of weights for norm (multitensor LAMB optimizer) (#17700)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 8e39518 Use FP32 copy of weights for norm (multitensor LAMB optimizer) (#17700)
8e39518 is described below
commit 8e3951876b3598c8b52606a467add5f239d88b38
Author: MoisesHer <50...@users.noreply.github.com>
AuthorDate: Mon Mar 23 09:55:24 2020 -0700
Use FP32 copy of weights for norm (multitensor LAMB optimizer) (#17700)
* Use fp32 copy of weights for computing norm in LAMB optimizer
* Fix cpplint
---
src/operator/contrib/multi_lamb-inl.h | 12 ++++++++----
1 file changed, 8 insertions(+), 4 deletions(-)
diff --git a/src/operator/contrib/multi_lamb-inl.h b/src/operator/contrib/multi_lamb-inl.h
index 7fb186f..256445a 100644
--- a/src/operator/contrib/multi_lamb-inl.h
+++ b/src/operator/contrib/multi_lamb-inl.h
@@ -282,10 +282,14 @@ inline void MultiLAMB(const nnvm::NodeAttrs& attrs,
FillMultiLAMBKernelParam<xpu, DType, MPDType, MultiLAMBParam, input_stride>
(attrs, ctx, inputs, outputs, &kernel_params);
- // create vector of TBlob with all the weights contiguous
- std::vector<TBlob> weights;
+ // create vector of TBlob with all the weights contiguous to compute the norm
+ // if mixed precision, use fp32 copy
+ std::vector<TBlob> weights_for_norm;
+ int position_weights = 0;
+ if (!std::is_same<DType, MPDType>::value)
+ position_weights = input_stride - 1;
for (size_t index = 0; index < kernel_params.ntensors; ++index) {
- weights.emplace_back(inputs[index*input_stride]);
+ weights_for_norm.emplace_back(inputs[index * input_stride + position_weights]);
}
// Calculate amount of temporary storage (temp_g, r1, r2, block_to_tensor, block_to_chunk)
@@ -327,7 +331,7 @@ inline void MultiLAMB(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 1, int> block_to_chunk(reinterpret_cast<int*>(&workspace[pos_wspace]),
Shape1(kernel_params.nchunks), s);
- MultiSumSqRun<xpu>(weights, kernel_params.ntensors, r1.dptr_, ctx);
+ MultiSumSqRun<xpu>(weights_for_norm, kernel_params.ntensors, r1.dptr_, ctx);
CallKernel1<MPDType, DType>(s, kernel_params, param, temp_g.dptr_,
block_to_tensor.dptr_,
block_to_chunk.dptr_);