Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/08 09:09:01 UTC
[incubator-mxnet] branch master updated: [BUGFIX] Type fix for large tensors (#20922)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 7a49008 [BUGFIX] Type fix for large tensors (#20922)
7a49008 is described below
commit 7a49008de42e97c3f5d1634955e0db4954739907
Author: DominikaJedynak <do...@intel.com>
AuthorDate: Tue Mar 8 10:06:59 2022 +0100
[BUGFIX] Type fix for large tensors (#20922)
* Dimension type fix
* Review suggestion
---
src/operator/subgraph/dnnl/dnnl_conv.cc | 4 ++--
src/operator/subgraph/dnnl/dnnl_fc.cc | 12 ++++++------
2 files changed, 8 insertions(+), 8 deletions(-)
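The root cause in both files: a tensor dimension (or a ProdShape product over several axes) was funneled through int, which silently truncates once the value exceeds INT_MAX (2^31 - 1). MXNet's index_t is a signed 64-bit integer in builds with large tensor support (MSHADOW_INT64_TENSOR_SIZE=1), so casting to index_t preserves the full value. A minimal standalone sketch of the failure mode, assuming index_t is int64_t as in such builds (this snippet is illustrative, not from the patch):

    #include <cstdint>
    #include <iostream>

    // Assumption: mirrors mshadow's typedef when MSHADOW_INT64_TENSOR_SIZE=1.
    using index_t = int64_t;

    int main() {
      // A flattened shape product for a large tensor: 2^31 elements,
      // one past INT_MAX.
      int64_t prod = 1LL << 31;

      int narrow = static_cast<int>(prod);        // truncates (wraps on typical targets)
      index_t wide = static_cast<index_t>(prod);  // preserved: 2147483648

      std::cout << "as int:     " << narrow << "\n"
                << "as index_t: " << wide << "\n";
      return 0;
    }

On a typical 64-bit toolchain the int result wraps to a negative value while the index_t result stays intact, which is exactly the difference the two hunks below depend on.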
diff --git a/src/operator/subgraph/dnnl/dnnl_conv.cc b/src/operator/subgraph/dnnl/dnnl_conv.cc
index ccaabdd..4936d8c 100644
--- a/src/operator/subgraph/dnnl/dnnl_conv.cc
+++ b/src/operator/subgraph/dnnl/dnnl_conv.cc
@@ -59,11 +59,11 @@ static void UpdateConvWeightBias(NDArray* weight,
const float* var_ptr = variance.data().dptr<float>();
DType* update_weight_ptr = update_weight.data().dptr<DType>();
DType* update_bias_ptr = update_bias.data().dptr<DType>();
- size_t channel = gamma.shape()[0];
+ index_t channel = static_cast<index_t>(gamma.shape()[0]);
const auto wshape = weight->shape();
size_t offset = wshape.ProdShape(1, wshape.ndim());
#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
- for (int c = 0; c < static_cast<int>(channel); ++c) {
+ for (index_t c = 0; c < channel; ++c) {
const DType* p1 = weight_ptr + c * offset;
DType* p2 = update_weight_ptr + c * offset;
float alpha = (param->fix_gamma ? 1.0f : gamma_ptr[c]) / sqrt(var_ptr[c] + param->eps);
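In this conv hunk, channel was previously a size_t that the loop bound truncated back through static_cast<int>; the fix declares both channel and the induction variable c as index_t, which keeps the counter signed (as OpenMP's parallel for traditionally expects) while widening it to 64 bits. A hedged sketch of the resulting pattern, where scale_rows, channels, and offset are hypothetical stand-ins rather than names from the patch:

    #include <cstdint>

    // Assumption: matches MXNet large tensor builds.
    using index_t = int64_t;

    void scale_rows(float* data, index_t channels, size_t offset) {
      // Signed 64-bit induction variable: acceptable to OpenMP's
      // parallel-for rules and free of truncation for channels > INT_MAX.
      #pragma omp parallel for
      for (index_t c = 0; c < channels; ++c) {
        float* row = data + c * offset;  // pointer arithmetic done in 64 bits
        for (size_t i = 0; i < offset; ++i)
          row[i] *= 2.0f;
      }
    }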
diff --git a/src/operator/subgraph/dnnl/dnnl_fc.cc b/src/operator/subgraph/dnnl/dnnl_fc.cc
index 8ead3e7..db75ccd 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc.cc
+++ b/src/operator/subgraph/dnnl/dnnl_fc.cc
@@ -235,15 +235,15 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
const mxnet::TShape oshape = output.shape();
dnnl::memory::dims out_dims(2);
if (oshape.ndim() == 2) {
- out_dims[0] = static_cast<int>(oshape[0]);
- out_dims[1] = static_cast<int>(oshape[1]);
+ out_dims[0] = static_cast<index_t>(oshape[0]);
+ out_dims[1] = static_cast<index_t>(oshape[1]);
} else {
if (!default_param.flatten) {
- out_dims[0] = static_cast<int>(oshape.ProdShape(0, oshape.ndim() - 1));
- out_dims[1] = static_cast<int>(oshape[oshape.ndim() - 1]);
+ out_dims[0] = static_cast<index_t>(oshape.ProdShape(0, oshape.ndim() - 1));
+ out_dims[1] = static_cast<index_t>(oshape[oshape.ndim() - 1]);
} else {
- out_dims[0] = static_cast<int>(static_cast<int>(oshape[0]));
- out_dims[1] = static_cast<int>(oshape.ProdShape(1, oshape.ndim()));
+ out_dims[0] = static_cast<index_t>(oshape[0]);
+ out_dims[1] = static_cast<index_t>(oshape.ProdShape(1, oshape.ndim()));
}
}
dnnl::memory::desc out_md =
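For context on the fc hunk: oneDNN's dnnl::memory::dims holds dnnl::memory::dim values, which are int64_t (dnnl_dim_t), so the old static_cast<int> narrowed a 64-bit shape entry only for it to be widened again on assignment, losing any dimension above INT_MAX along the way. A minimal sketch of the fixed pattern, with make_out_dims_flatten as a hypothetical helper and a plain std::vector<int64_t> standing in for the real oshape and dnnl::memory::dims types:

    #include <cstdint>
    #include <vector>

    using index_t = int64_t;                  // assumption: large tensor build
    using dnnl_dims = std::vector<int64_t>;   // stands in for dnnl::memory::dims

    // Mirrors the flatten branch: keep axis 0, collapse trailing axes
    // into one dimension, with no intermediate int anywhere.
    dnnl_dims make_out_dims_flatten(const std::vector<int64_t>& oshape) {
      int64_t prod = 1;
      for (size_t i = 1; i < oshape.size(); ++i)
        prod *= oshape[i];                    // analogous to ProdShape(1, ndim)

      dnnl_dims out_dims(2);
      out_dims[0] = static_cast<index_t>(oshape[0]);
      out_dims[1] = static_cast<index_t>(prod);
      return out_dims;
    }

With this shape, a dimension such as 3'000'000'000 flows from the TShape into the oneDNN descriptor without ever passing through a 32-bit type.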