You are viewing a plain-text version of this content. The canonical link for it was not preserved in this plain-text rendering.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2021/12/07 16:47:21 UTC

[incubator-mxnet] 01/01: Fix oneDNN fallback for concat with scalar

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch concat_fallback_fix
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit c796cea22b12168b2cb2925eae333c065ccbdd98
Author: Bartlomiej Gawrych <ba...@intel.com>
AuthorDate: Fri Dec 3 10:45:52 2021 +0100

    Fix oneDNN fallback for concat with scalar
---
 src/operator/nn/concat.cc                | 2 +-
 src/operator/nn/dnnl/dnnl_log_softmax.cc | 4 ++--
 src/operator/nn/dnnl/dnnl_softmax.cc     | 3 ++-
 src/operator/nn/log_softmax.cc           | 4 ----
 src/operator/nn/softmax.cc               | 4 ----
 tests/python/unittest/test_numpy_op.py   | 4 +++-
 6 files changed, 8 insertions(+), 13 deletions(-)

diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index f5a6f7f..2c83a08 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -253,7 +253,7 @@ bool SupportDNNLConcat(const std::vector<NDArray>& arrs) {
     if (!(arr.dtype() == mshadow::kFloat32 || arr.dtype() == mshadow::kBfloat16))
       return false;
     // DO not support zero-size tensors.
-    if (arr.shape().Size() == 0)
+    if (arr.shape().Size() == 0 || arr.shape().ndim() == 0)
       return false;
     int ndim             = arr.shape().ndim();
     const int dnnl_ndims = arr.GetDNNLData()->get_desc().data.ndims;
diff --git a/src/operator/nn/dnnl/dnnl_log_softmax.cc b/src/operator/nn/dnnl/dnnl_log_softmax.cc
index 9408e60..a3c8c90 100644
--- a/src/operator/nn/dnnl/dnnl_log_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_log_softmax.cc
@@ -60,8 +60,8 @@ bool SupportDNNLLogSoftmax(const SoftmaxParam& param, const NDArray& data, const
   // DNNL does not support temperature argument in their log_softmax function
   // now. Need update this once they start to support it.
   // Currently, DNNL shows bad performance when log_softmax is not performed on the last dimension
-  if (param.temperature.has_value() || in_dtype != mshadow::kFloat32 || in_dtype != out_dtype ||
-      axis != (ndim - 1)) {
+  if (data.shape().Size() == 0 || data.shape().ndim() == 0 || param.temperature.has_value() ||
+      in_dtype != mshadow::kFloat32 || in_dtype != out_dtype || axis != (ndim - 1)) {
     return false;
   }
 
diff --git a/src/operator/nn/dnnl/dnnl_softmax.cc b/src/operator/nn/dnnl/dnnl_softmax.cc
index 72a25d4..48c2944 100644
--- a/src/operator/nn/dnnl/dnnl_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_softmax.cc
@@ -31,6 +31,7 @@ namespace op {
 
 bool SupportDNNLSoftmax(const SoftmaxParam& param, const NDArray& data, const NDArray& output) {
   const int ndim      = data.shape().ndim();
+  const int in_size   = data.shape().Size();
   const int in_dtype  = data.dtype();
   const int out_dtype = output.dtype();
   const int axis      = CheckAxis(param.axis, ndim);
@@ -44,7 +45,7 @@ bool SupportDNNLSoftmax(const SoftmaxParam& param, const NDArray& data, const ND
   }
 
   // Supports ndim up to 6
-  return (ndim >= 1 && ndim <= 6);
+  return (ndim >= 1 && ndim <= 6 && in_size != 0);
 }
 
 void DNNLSoftmaxForward(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/nn/log_softmax.cc b/src/operator/nn/log_softmax.cc
index 197f892..f56e7ac 100644
--- a/src/operator/nn/log_softmax.cc
+++ b/src/operator/nn/log_softmax.cc
@@ -39,8 +39,6 @@ static void LogSoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const std::vector<NDArray>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
-  if (inputs[0].shape().Size() == 0U)
-    return;
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   if (SupportDNNLLogSoftmax(param, inputs[0], outputs[0])) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
@@ -57,8 +55,6 @@ static void LogSoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                        const std::vector<NDArray>& inputs,
                                        const std::vector<OpReqType>& req,
                                        const std::vector<NDArray>& outputs) {
-  if (inputs[0].shape().Size() == 0U)
-    return;
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   if (SupportDNNLLogSoftmax(param, inputs[1], outputs[0])) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 5b9c4ae..29f54645 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -41,8 +41,6 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
-  if (inputs[0].shape().Size() == 0U)
-    return;
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   if (SupportDNNLSoftmax(param, inputs[0], outputs[0])) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
@@ -59,8 +57,6 @@ static void SoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                     const std::vector<NDArray>& inputs,
                                     const std::vector<OpReqType>& req,
                                     const std::vector<NDArray>& outputs) {
-  if (inputs[0].shape().Size() == 0U)
-    return;
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   if (SupportDNNLSoftmax(param, inputs[1], outputs[0])) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 99aacbf..bd3624b 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -4096,7 +4096,7 @@ def test_np_concat():
             shape_lst[axis] = random.randint(0, 3)
         return tuple(shape_lst)
 
-    shapes = [(0, 0), (2, 3), (2, 1, 3)]
+    shapes = [(), (0, 0), (2, 3), (2, 1, 3)]
     hybridizes = [True, False]
     axes = [0, 1, -1, None]
     grad_reqs = ['write', 'add', 'null']
@@ -4105,6 +4105,8 @@ def test_np_concat():
 
     for shape, hybridize, axis, grad_req, dtype in combinations:
         # test gluon
+        if shape == () and axis != None:
+            continue
         test_concat = TestConcat(axis=axis)
         if hybridize:
             test_concat.hybridize()