You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/08/04 15:08:41 UTC

[incubator-mxnet] branch master updated: Add size threshold for few oneDNN operators (#21106)

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6d1fbe35d2 Add size threshold for few oneDNN operators (#21106)
6d1fbe35d2 is described below

commit 6d1fbe35d220a7789f5f79eac8bd55716d7d2492
Author: bgawrych <ba...@intel.com>
AuthorDate: Thu Aug 4 17:08:23 2022 +0200

    Add size threshold for few oneDNN operators (#21106)
    
    * Add size threshold for few oneDNN operators
    
    * add const
    
    * Add PR number
    
    * Apply suggestions from code review
    
    Co-authored-by: bartekkuncer <ba...@intel.com>
    
    * Refactor binary op
    
    * Fix true_divide
    
    Co-authored-by: Bartlomiej Gawrych <ba...@intel.com>
    Co-authored-by: bartekkuncer <ba...@intel.com>
---
 src/operator/nn/dnnl/dnnl_base-inl.h               |   2 +-
 src/operator/nn/dnnl/dnnl_binary.cc                |  13 ++-
 src/operator/nn/dnnl/dnnl_masked_softmax.cc        |   8 +-
 src/operator/nn/dnnl/dnnl_slice-inl.h              |  71 -------------
 src/operator/nn/dnnl/dnnl_slice.cc                 | 111 ---------------------
 src/operator/numpy/np_elemwise_broadcast_op.h      |   2 +-
 src/operator/numpy/np_true_divide-inl.h            |   2 +-
 .../tensor/elemwise_binary_broadcast_op_basic.cc   |   6 +-
 src/operator/tensor/matrix_op-inl.h                |  25 +----
 src/operator/tensor/matrix_op.cc                   |  19 +---
 10 files changed, 25 insertions(+), 234 deletions(-)

diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h b/src/operator/nn/dnnl/dnnl_base-inl.h
index 23105a8d20..7cf4eee5c5 100644
--- a/src/operator/nn/dnnl/dnnl_base-inl.h
+++ b/src/operator/nn/dnnl/dnnl_base-inl.h
@@ -251,7 +251,7 @@ struct SoftmaxOutputParam;
 bool SupportDNNLAct(const ActivationParam& param);
 bool SupportDNNLAct(const ActivationParam& param, const NDArray& input);
 bool SupportDNNLBatchDot(const std::vector<NDArray>& inputs);
-bool SupportDNNLBinary(const std::vector<NDArray>& inputs);
+bool SupportDNNLBinary(const std::vector<NDArray>& inputs, const std::vector<NDArray>& outputs);
 bool SupportDNNLConcat(const std::vector<NDArray>& arrs);
 bool SupportDNNLConv(const ConvolutionParam& params, const NDArray& input);
 bool SupportDNNLDeconv(const DeconvolutionParam& params, const NDArray& input);
diff --git a/src/operator/nn/dnnl/dnnl_binary.cc b/src/operator/nn/dnnl/dnnl_binary.cc
index 75c4805fb7..6566bfdcd6 100644
--- a/src/operator/nn/dnnl/dnnl_binary.cc
+++ b/src/operator/nn/dnnl/dnnl_binary.cc
@@ -64,9 +64,16 @@ void DNNLBinaryOpFwd::Execute(const std::vector<NDArray>& inputs,
 }
 
 // Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_binary.html
-bool SupportDNNLBinary(const std::vector<NDArray>& inputs) {
-  return SupportDNNL<DNNLTypeMode::FloatTypes>(inputs[1]) &&
-         SupportDNNL<DNNLTypeMode::FloatTypes>(inputs[0]);
+bool SupportDNNLBinary(const std::vector<NDArray>& inputs, const std::vector<NDArray>& outputs) {
+  // threshold value selected experimentally based on performance results - PR-21106
+  constexpr size_t optimal_size_threshold = 2 << 13;
+  const bool threshold_condition          = outputs[0].shape().Size() >= optimal_size_threshold;
+  const bool is_any_dnnl_data =
+      inputs[0].IsDNNLData() || inputs[1].IsDNNLData() || outputs[0].IsDNNLData();
+
+  return SupportDNNL<DNNLTypeMode::FloatTypes>(inputs[0]) &&
+         SupportDNNL<DNNLTypeMode::FloatTypes>(inputs[1]) &&
+         (threshold_condition || is_any_dnnl_data);
 }
 
 }  // namespace op
diff --git a/src/operator/nn/dnnl/dnnl_masked_softmax.cc b/src/operator/nn/dnnl/dnnl_masked_softmax.cc
index c9561fc17b..d3c3f0e24e 100644
--- a/src/operator/nn/dnnl/dnnl_masked_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_masked_softmax.cc
@@ -31,12 +31,16 @@ namespace op {
 // Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_softmax.html
 bool SupportDNNLMaskedSoftmax(const MaskedSoftmaxParam& param, const std::vector<NDArray>& inputs) {
   CHECK_EQ(inputs.size(), 2);
+  const auto data = inputs[0];
   const auto mask = inputs[1];
   SoftmaxParam softmax_param;
   softmax_param.axis        = param.axis;
-  softmax_param.dtype       = inputs[0].dtype();
+  softmax_param.dtype       = data.dtype();
   softmax_param.temperature = param.temperature;
-  return mask.dtype() == mshadow::kBool && SupportDNNLSoftmax(softmax_param, inputs[0]);
+  // threshold value selected experimentally based on performance results - PR-21106
+  constexpr size_t optimal_size_threshold = 2 << 13;
+  return data.shape().Size() >= optimal_size_threshold && mask.dtype() == mshadow::kBool &&
+         SupportDNNLSoftmax(softmax_param, data);
 }
 
 inline static dnnl::memory::dims GetOneDNNDims(const NDArray& arr) {
diff --git a/src/operator/nn/dnnl/dnnl_slice-inl.h b/src/operator/nn/dnnl/dnnl_slice-inl.h
deleted file mode 100644
index db4d562595..0000000000
--- a/src/operator/nn/dnnl/dnnl_slice-inl.h
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file dnnl_slice-inl.h
- * \brief
- * \author Zhiyuan Huang
- */
-
-#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_SLICE_INL_H_
-#define MXNET_OPERATOR_NN_DNNL_DNNL_SLICE_INL_H_
-
-#if MXNET_USE_ONEDNN == 1
-
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <mxnet/operator.h>
-
-#include <utility>
-
-#include "operator/operator_common.h"
-#include "operator/tensor/slice-inl.h"
-#include "dnnl_base-inl.h"
-
-namespace mxnet {
-namespace op {
-
-class DNNLSliceFwd {
- public:
-  DNNLSliceFwd(const SliceParam& param, const NDArray& in, const NDArray& out);
-  void SetNewMem(const dnnl::memory& input, const dnnl::memory& output);
-  void Register();
-
- private:
-  std::shared_ptr<dnnl::memory> data_;
-  std::shared_ptr<dnnl::memory> out_;
-  std::shared_ptr<dnnl::reorder> fwd_;
-};
-
-typedef ParamOpSign<SliceParam> DNNLSliceSignature;
-DNNLSliceFwd& GetSliceForward(const SliceParam& param,
-                              const bool is_train,
-                              const NDArray& in_data,
-                              const NDArray& out_data);
-
-void DNNLSlice(const nnvm::NodeAttrs& attrs,
-               const OpContext& ctx,
-               const NDArray& in,
-               OpReqType req,
-               const NDArray& out);
-
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_USE_ONEDNN == 1
-#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_SLICE_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_slice.cc b/src/operator/nn/dnnl/dnnl_slice.cc
deleted file mode 100644
index 102bf684fb..0000000000
--- a/src/operator/nn/dnnl/dnnl_slice.cc
+++ /dev/null
@@ -1,111 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file dnnl_slice.cc
- * \brief
- * \author Zhiyuan Huang
- */
-
-#if MXNET_USE_ONEDNN == 1
-
-#include "dnnl_base-inl.h"
-#include "dnnl_slice-inl.h"
-
-namespace mxnet {
-namespace op {
-
-DNNLSliceFwd::DNNLSliceFwd(const SliceParam& param, const NDArray& in, const NDArray& out) {
-  const mxnet::TShape ishape = in.shape();
-  const mxnet::TShape oshape = out.shape();
-  const int N                = ishape.ndim();
-  dnnl::memory::dims dims(N);
-  dnnl::memory::dims offsets(N);
-  for (int i = 0; i < N; ++i) {
-    dim_t s = 0;
-    if (i < param.begin.ndim() && param.begin[i]) {
-      s = *param.begin[i];
-      if (s < 0)
-        s += ishape[i];
-    }
-    dims[i]    = oshape[i];
-    offsets[i] = s;
-  }
-
-  auto in_md  = in.GetDNNLData()->get_desc();
-  auto out_md = out.GetDNNLData()->get_desc();
-  auto sub_md = in_md.submemory_desc(dims, offsets);
-
-  auto engine = CpuEngine::Get()->get_engine();
-  this->data_ = std::make_shared<dnnl::memory>(sub_md, engine, nullptr);
-  this->out_  = std::make_shared<dnnl::memory>(out_md, engine, nullptr);
-  this->fwd_  = std::make_shared<dnnl::reorder>(*this->data_, *this->out_);
-}
-
-void DNNLSliceFwd::SetNewMem(const dnnl::memory& input, const dnnl::memory& output) {
-  this->data_->set_data_handle(input.get_data_handle());
-  this->out_->set_data_handle(output.get_data_handle());
-}
-
-void DNNLSliceFwd::Register() {
-  DNNLStream::Get()->RegisterPrimArgs(
-      *fwd_, {{DNNL_ARG_FROM, *(this->data_)}, {DNNL_ARG_TO, *(this->out_)}});
-}
-
-DNNLSliceFwd& GetSliceForward(const SliceParam& param,
-                              const bool is_train,
-                              const NDArray& in_data,
-                              const NDArray& out_data) {
-#if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<DNNLSliceSignature, DNNLSliceFwd, OpHash> fwds;
-#else
-  static MX_THREAD_LOCAL std::unordered_map<DNNLSliceSignature, DNNLSliceFwd, OpHash> fwds;
-#endif
-  DNNLSliceSignature key(param);
-  key.AddSign(is_train);
-  key.AddSign(in_data);
-  key.AddSign(out_data);
-
-  auto it = fwds.find(key);
-  if (it == fwds.end()) {
-    DNNLSliceFwd fwd(param, in_data, out_data);
-    it = AddToCache(&fwds, key, fwd);
-  }
-  return it->second;
-}
-
-void DNNLSlice(const nnvm::NodeAttrs& attrs,
-               const OpContext& ctx,
-               const NDArray& in,
-               OpReqType req,
-               const NDArray& out) {
-  const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
-  DNNLSliceFwd& fwd       = GetSliceForward(param, ctx.is_train, in, out);
-  auto in_mem             = in.GetDNNLData();
-  auto out_md             = out.GetDNNLData()->get_desc();
-  auto out_mem            = CreateDNNLMem(out, out_md, req);
-  fwd.SetNewMem(*in_mem, *out_mem.second);
-  fwd.Register();
-  CommitOutput(out, out_mem);
-  DNNLStream::Get()->Submit();
-}
-
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_USE_ONEDNN == 1
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h
index c2db724cb8..94bf19bba1 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.h
+++ b/src/operator/numpy/np_elemwise_broadcast_op.h
@@ -910,7 +910,7 @@ void NumpyBinaryOperatorComputeExCPU(const nnvm::NodeAttrs& attrs,
                                      const std::vector<mxnet::NDArray>& inputs,
                                      const std::vector<OpReqType>& req,
                                      const std::vector<mxnet::NDArray>& outputs) {
-  if (SupportDNNLBinary(inputs)) {
+  if (SupportDNNLBinary(inputs, outputs)) {
     const dnnl::algorithm alg = DNNLAlgorithm<OP>::value;
     DNNLRun(DNNLBinaryOpForward<alg>, attrs, ctx, inputs, req, outputs);
     return;
diff --git a/src/operator/numpy/np_true_divide-inl.h b/src/operator/numpy/np_true_divide-inl.h
index 5d0700f56b..f049617a4b 100644
--- a/src/operator/numpy/np_true_divide-inl.h
+++ b/src/operator/numpy/np_true_divide-inl.h
@@ -95,7 +95,7 @@ void TrueDivideElemwiseCompute(const nnvm::NodeAttrs& attrs,
     if (common::is_float(lhs.type_flag_)) {
       // If both are the same floats, normal launch
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, DType, {
+        MSHADOW_REAL_TYPE_SWITCH_EX(lhs.type_flag_, DType, _, {
           Kernel<op_with_req<mshadow_op::true_divide, Req>, xpu>::Launch(
               s, out.Size(), out.dptr<DType>(), lhs.dptr<DType>(), rhs.dptr<DType>());
         });
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
index cc66a1e599..a29914ecbd 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
@@ -74,13 +74,11 @@ static void BinaryOperatorComputeExCPU(const nnvm::NodeAttrs& attrs,
                                        const std::vector<NDArray>& outputs) {
 #if MXNET_USE_ONEDNN == 1
   if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
-    if (SupportDNNLBinary(inputs)) {
+    if (SupportDNNLBinary(inputs, outputs)) {
       const dnnl::algorithm alg = DNNLAlgorithm<OP>::value;
       DNNLRun(DNNLBinaryOpForward<alg>, attrs, ctx, inputs, req, outputs);
     } else {
-      std::vector<mxnet::TBlob> in_data  = {inputs[0].data(), inputs[1].data()};
-      std::vector<mxnet::TBlob> out_data = {outputs[0].data()};
-      BinaryBroadcastCompute<cpu, OP>(attrs, ctx, in_data, req, out_data);
+      FallBackCompute(BinaryBroadcastCompute<cpu, OP>, attrs, ctx, inputs, req, outputs);
     }
     return;
   }
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 2831789fe4..2b0bf281d9 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -676,18 +676,6 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
   return shape_is_known(in_attrs->at(0)) && shape_is_known(out_attrs->at(0));
 }
 
-// Currently DNNL only supports step = 1 or step has no value
-// Support for https://oneapi-src.github.io/oneDNN/v2.6/dev_guide_reorder.html
-inline bool SupportDNNLSlice(const SliceParam& param) {
-  if (param.step.ndim() == 0U)
-    return true;
-  for (int i = 0; i < param.step.ndim(); ++i) {
-    if (param.step[i].has_value() && param.step[i].value() != 1)
-      return false;
-  }
-  return true;
-}
-
 inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                          const int dev_mask,
                                          DispatchMode* dispatch_mode,
@@ -709,16 +697,9 @@ inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs,
     trivial_step = true;
   }
 
-  if (in_stype == kDefaultStorage) {
-#if MXNET_USE_ONEDNN == 1
-    if (dev_mask == Context::kCPU && DNNLEnvSet() && SupportDNNLSlice(param)) {
-      dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, dispatch_ex);
-    }
-#endif
-    if (!dispatched) {
-      dispatched =
-          storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
-    }
+  if (!dispatched && in_stype == kDefaultStorage) {
+    dispatched =
+        storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
   }
 
   if (!dispatched && in_stype == kCSRStorage && trivial_step) {
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index 303011f058..e49f61c9eb 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -27,7 +27,6 @@
 #if MXNET_USE_ONEDNN == 1
 #include "operator/nn/dnnl/dnnl_base-inl.h"
 #include "operator/nn/dnnl/dnnl_reshape-inl.h"
-#include "operator/nn/dnnl/dnnl_slice-inl.h"
 #include "operator/nn/dnnl/dnnl_transpose-inl.h"
 #include "operator/nn/dnnl/dnnl_split-inl.h"
 #include "operator/nn/dnnl/dnnl_stack-inl.h"
@@ -473,12 +472,6 @@ will return a new array with shape ``(2,1,3,4)``.
     .add_argument("data", "NDArray-or-Symbol", "Source input")
     .add_arguments(ExpandDimParam::__FIELDS__());
 
-#if MXNET_USE_ONEDNN == 1
-bool SupportDNNLSlice(const SliceParam& param, const NDArray& input, const NDArray& output) {
-  return SupportDNNLSlice(param) && SupportDNNL(input) && SupportDNNL(output);
-}
-#endif
-
 void SliceExCPU(const nnvm::NodeAttrs& attrs,
                 const OpContext& ctx,
                 const std::vector<NDArray>& inputs,
@@ -490,14 +483,6 @@ void SliceExCPU(const nnvm::NodeAttrs& attrs,
   auto in_stype           = inputs[0].storage_type();
   if (in_stype == kCSRStorage) {
     SliceCsrImpl<cpu>(param, ctx, inputs[0], req[0], outputs[0]);
-#if MXNET_USE_ONEDNN == 1
-  } else if (in_stype == kDefaultStorage) {
-    if (SupportDNNLSlice(param, inputs[0], outputs[0])) {
-      DNNLRun(DNNLSlice, attrs, ctx, inputs[0], req[0], outputs[0]);
-    } else {
-      FallBackCompute(SliceOpForward<cpu>, attrs, ctx, inputs, req, outputs);
-    }
-#endif
   } else {
     LOG(FATAL) << "Slice not implemented for storage type" << in_stype;
   }
@@ -561,9 +546,7 @@ Example::
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_slice"})
     .set_attr<FCompute>("FCompute<cpu>", SliceOpForward<cpu>)
     .set_attr<FComputeEx>("FComputeEx<cpu>", SliceExCPU)
-#if MXNET_USE_ONEDNN == 1
-    .set_attr<bool>("TIsDNNL", true)
-#endif
+    // oneDNN support removed in PR-21106 due to performance reasons
     .add_argument("data", "NDArray-or-Symbol", "Source input")
     .add_arguments(SliceParam::__FIELDS__());