You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2020/03/23 16:56:05 UTC

[incubator-mxnet] branch master updated: Use FP32 copy of weights for norm (multitensor LAMB optimizer) (#17700)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8e39518  Use FP32 copy of weights for norm (multitensor LAMB optimizer) (#17700)
8e39518 is described below

commit 8e3951876b3598c8b52606a467add5f239d88b38
Author: MoisesHer <50...@users.noreply.github.com>
AuthorDate: Mon Mar 23 09:55:24 2020 -0700

    Use FP32 copy of weights for norm (multitensor LAMB optimizer) (#17700)
    
    * Use fp32 copy of weights for computing norm in LAMB optimizer
    
    * Fix cpplint
---
 src/operator/contrib/multi_lamb-inl.h | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/operator/contrib/multi_lamb-inl.h b/src/operator/contrib/multi_lamb-inl.h
index 7fb186f..256445a 100644
--- a/src/operator/contrib/multi_lamb-inl.h
+++ b/src/operator/contrib/multi_lamb-inl.h
@@ -282,10 +282,14 @@ inline void MultiLAMB(const nnvm::NodeAttrs& attrs,
     FillMultiLAMBKernelParam<xpu, DType, MPDType, MultiLAMBParam, input_stride>
             (attrs, ctx, inputs, outputs, &kernel_params);
 
-    // create vector of TBlob with all the weights contiguous
-    std::vector<TBlob> weights;
+    // create vector of TBlob with all the weights contiguous to compute the norm
+    // if mixed precision, use fp32 copy
+    std::vector<TBlob> weights_for_norm;
+    int position_weights = 0;
+    if (!std::is_same<DType, MPDType>::value)
+      position_weights = input_stride - 1;
     for (size_t index = 0; index < kernel_params.ntensors; ++index) {
-        weights.emplace_back(inputs[index*input_stride]);
+      weights_for_norm.emplace_back(inputs[index * input_stride + position_weights]);
     }
 
     // Calculate amount of temporary storage (temp_g, r1, r2, block_to_tensor, block_to_chunk)
@@ -327,7 +331,7 @@ inline void MultiLAMB(const nnvm::NodeAttrs& attrs,
     Tensor<xpu, 1, int> block_to_chunk(reinterpret_cast<int*>(&workspace[pos_wspace]),
       Shape1(kernel_params.nchunks), s);
 
-    MultiSumSqRun<xpu>(weights, kernel_params.ntensors, r1.dptr_, ctx);
+    MultiSumSqRun<xpu>(weights_for_norm, kernel_params.ntensors, r1.dptr_, ctx);
     CallKernel1<MPDType, DType>(s, kernel_params, param, temp_g.dptr_,
                                 block_to_tensor.dptr_,
                                 block_to_chunk.dptr_);