You are viewing a plain text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this plain-text rendering.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/08 09:09:01 UTC

[incubator-mxnet] branch master updated: [BUGFIX] Type fix for large tensors (#20922)

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7a49008  [BUGFIX] Type fix for large tensors (#20922)
7a49008 is described below

commit 7a49008de42e97c3f5d1634955e0db4954739907
Author: DominikaJedynak <do...@intel.com>
AuthorDate: Tue Mar 8 10:06:59 2022 +0100

    [BUGFIX] Type fix for large tensors (#20922)
    
    * Fix dimension types: use index_t instead of int/size_t so large tensor dimensions are not truncated
    
    * Apply review suggestions
---
 src/operator/subgraph/dnnl/dnnl_conv.cc |  4 ++--
 src/operator/subgraph/dnnl/dnnl_fc.cc   | 12 ++++++------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/operator/subgraph/dnnl/dnnl_conv.cc b/src/operator/subgraph/dnnl/dnnl_conv.cc
index ccaabdd..4936d8c 100644
--- a/src/operator/subgraph/dnnl/dnnl_conv.cc
+++ b/src/operator/subgraph/dnnl/dnnl_conv.cc
@@ -59,11 +59,11 @@ static void UpdateConvWeightBias(NDArray* weight,
   const float* var_ptr     = variance.data().dptr<float>();
   DType* update_weight_ptr = update_weight.data().dptr<DType>();
   DType* update_bias_ptr   = update_bias.data().dptr<DType>();
-  size_t channel           = gamma.shape()[0];
+  index_t channel          = static_cast<index_t>(gamma.shape()[0]);
   const auto wshape        = weight->shape();
   size_t offset            = wshape.ProdShape(1, wshape.ndim());
 #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
-  for (int c = 0; c < static_cast<int>(channel); ++c) {
+  for (index_t c = 0; c < channel; ++c) {
     const DType* p1 = weight_ptr + c * offset;
     DType* p2       = update_weight_ptr + c * offset;
     float alpha     = (param->fix_gamma ? 1.0f : gamma_ptr[c]) / sqrt(var_ptr[c] + param->eps);
diff --git a/src/operator/subgraph/dnnl/dnnl_fc.cc b/src/operator/subgraph/dnnl/dnnl_fc.cc
index 8ead3e7..db75ccd 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc.cc
+++ b/src/operator/subgraph/dnnl/dnnl_fc.cc
@@ -235,15 +235,15 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
     const mxnet::TShape oshape = output.shape();
     dnnl::memory::dims out_dims(2);
     if (oshape.ndim() == 2) {
-      out_dims[0] = static_cast<int>(oshape[0]);
-      out_dims[1] = static_cast<int>(oshape[1]);
+      out_dims[0] = static_cast<index_t>(oshape[0]);
+      out_dims[1] = static_cast<index_t>(oshape[1]);
     } else {
       if (!default_param.flatten) {
-        out_dims[0] = static_cast<int>(oshape.ProdShape(0, oshape.ndim() - 1));
-        out_dims[1] = static_cast<int>(oshape[oshape.ndim() - 1]);
+        out_dims[0] = static_cast<index_t>(oshape.ProdShape(0, oshape.ndim() - 1));
+        out_dims[1] = static_cast<index_t>(oshape[oshape.ndim() - 1]);
       } else {
-        out_dims[0] = static_cast<int>(static_cast<int>(oshape[0]));
-        out_dims[1] = static_cast<int>(oshape.ProdShape(1, oshape.ndim()));
+        out_dims[0] = static_cast<index_t>(oshape[0]);
+        out_dims[1] = static_cast<index_t>(oshape.ProdShape(1, oshape.ndim()));
       }
     }
     dnnl::memory::desc out_md =