You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/02/05 15:35:02 UTC

[incubator-mxnet] branch v1.x updated: [BUGFIX] Fix AmpCast for float16 (#19749)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 0a65920  [BUGFIX] Fix AmpCast for float16 (#19749)
0a65920 is described below

commit 0a65920c8e6583a93e0ec22477c581ee15ea3a02
Author: Andrzej Kotłowski <An...@intel.com>
AuthorDate: Fri Feb 5 16:33:18 2021 +0100

    [BUGFIX] Fix AmpCast for float16 (#19749)
    
    * Fix AmpCast for float16
    
    OneDNN doesn't support the float16 format, so a fallback to the standard
    implementation is needed.
    This fixes issue #19631.
    
    * Enable amp_cast test for float16 on CPU context
---
 CONTRIBUTORS.md                        |  1 +
 src/operator/tensor/amp_cast.cc        | 38 +++++++++++++++++++---------------
 tests/python/unittest/test_operator.py |  3 +--
 3 files changed, 23 insertions(+), 19 deletions(-)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 286b142..cd5903d 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -254,6 +254,7 @@ List of Contributors
 * [Joe Evans](https://github.com/josephevans)
 * [Zhaoqi Zhu](https://github.com/zha0q1)
 * [Harshit Sharma](https://github.com/harshitshrma)
+* [Andrzej Kotlowski](https://github.com/anko-intel)
 
 Label Bot
 ---------
diff --git a/src/operator/tensor/amp_cast.cc b/src/operator/tensor/amp_cast.cc
index 7690783..964cd58 100644
--- a/src/operator/tensor/amp_cast.cc
+++ b/src/operator/tensor/amp_cast.cc
@@ -41,25 +41,29 @@ static void AMPCastExCPU(const nnvm::NodeAttrs& attrs,
   if (req[0] == kWriteInplace) {
     return;
   }
-  mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
   auto data = inputs[0];
-  if (data.IsView() && data.IsMKLDNNData())
-    data = data.Reorder2Default();
-  const auto i_mem = data.GetMKLDNNData();
-  const size_t i_ndim = data.shape().ndim();
-  mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
-  for (size_t i = 0; i < i_ndim; i++) {
-    i_dims[i] = static_cast<int>(data.shape()[i]);
+  if (data.dtype() != mshadow::kFloat16 && outputs[0].dtype() != mshadow::kFloat16) {
+    mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
+    if (data.IsView() && data.IsMKLDNNData())
+      data = data.Reorder2Default();
+    const auto i_mem = data.GetMKLDNNData();
+    const size_t i_ndim = data.shape().ndim();
+    mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
+    for (size_t i = 0; i < i_ndim; i++) {
+      i_dims[i] = static_cast<int>(data.shape()[i]);
+    }
+    const auto o_desc =
+        mkldnn::memory::desc(i_dims, get_mkldnn_type(outputs[0].dtype()),
+                            static_cast<mkldnn::memory::format_tag>(GetDefaultFormat(i_ndim)));
+    const auto out_mem = CreateMKLDNNMem(outputs[0], o_desc, req[0]);
+    mkldnn_args_map_t reorder_args;
+    reorder_args[MKLDNN_ARG_SRC] = *i_mem;
+    reorder_args[MKLDNN_ARG_DST] = *out_mem.second;
+    MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(*i_mem, *out_mem.second), reorder_args);
+    MKLDNNStream::Get()->Submit();
+    return;
   }
-  const auto o_desc =
-      mkldnn::memory::desc(i_dims, get_mkldnn_type(outputs[0].dtype()),
-                           static_cast<mkldnn::memory::format_tag>(GetDefaultFormat(i_ndim)));
-  const auto out_mem = CreateMKLDNNMem(outputs[0], o_desc, req[0]);
-  mkldnn_args_map_t reorder_args;
-  reorder_args[MKLDNN_ARG_SRC] = *i_mem;
-  reorder_args[MKLDNN_ARG_DST] = *out_mem.second;
-  MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(*i_mem, *out_mem.second), reorder_args);
-  MKLDNNStream::Get()->Submit();
+  FallBackCompute(AMPCastCompute<cpu>, attrs, ctx, inputs, req, outputs);
 }
 
 inline static bool AMPCastStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask,
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 92fd030..d02ff95 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4856,8 +4856,7 @@ def test_cast_float32_to_float16():
                     fp32_val, model_fp16_val, np_fp16_val)
 
     check_cast(mx.sym.Cast, input_np, expected_output)
-    if default_context().device_type == 'gpu':
-        check_cast(mx.sym.amp_cast, input_np, expected_output)
+    check_cast(mx.sym.amp_cast, input_np, expected_output)
 
 
 @with_seed()