You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2022/05/05 08:55:17 UTC

[GitHub] [incubator-mxnet] bgawrych commented on a diff in pull request #20976: [FEATURE] Add _npi_power_scalar and _npi_multiply_scalar fuse

bgawrych commented on code in PR #20976:
URL: https://github.com/apache/incubator-mxnet/pull/20976#discussion_r865689653


##########
src/operator/subgraph/dnnl/dnnl_pow_mul_scalar_property.h:
##########
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_pow_mul_scalar_property.h
+ * \brief Graph property for fusing _npi_power_scalar with _npi_multiply_scalar
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POW_MUL_SCALAR_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POW_MUL_SCALAR_PROPERTY_H_
+#if MXNET_USE_ONEDNN == 1
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "operator/subgraph/common.h"
+#include "operator/tensor/elemwise_binary_scalar_op.h"
+#include "dnnl_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class SgDNNLPowMulScalarSelector : public SubgraphSelectorV2 {
+ private:
+  std::vector<const BiDirectedNode*> matched_list_;
+  SelectStatus status_;
+
+ public:
+  bool Select(const BiDirectedNode& seed_node,
+              const std::shared_ptr<NodeAttr>& node_attr) override {
+    if (seed_node.node->op() == Op::Get("_npi_power_scalar")) {
+      matched_list_.clear();
+      matched_list_.emplace_back(&seed_node);
+      status_ = kStart;
+      return true;
+    }
+    return false;
+  }
+
+  bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& input_node) override {
+    return false;
+  }
+
+  bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& output_node) override {
+    const nnvm::Node* raw_power_scalar_node = n.node;
+    const nnvm::Node* raw_next_node         = output_node.node;
+    if (raw_power_scalar_node->op() && raw_power_scalar_node->op()->name == "_npi_power_scalar") {
+      if (raw_next_node->op() && status_ == kStart &&
+          raw_next_node->op()->name == "_npi_multiply_scalar") {
+        status_ = kSuccess;
+        return true;
+      } else {
+        status_ = kFail;
+        return false;
+      }
+    }
+
+    if (matched_list_.back() != &n) {
+      if (std::find(matched_list_.begin(), matched_list_.end(), &n) != matched_list_.end()) {
+        while (matched_list_.back() != &n) {
+          matched_list_.pop_back();
+        }
+      }
+      status_ = kSuccess;
+      return false;
+    }
+
+    return false;
+  }
+
+  void Reset() override {
+    CHECK_GE(matched_list_.size(), 1);
+    auto new_selector = SgDNNLPowMulScalarSelector();
+    new_selector.Select(*matched_list_[0], nullptr);
+    *this = new_selector;
+  }
+};
+
+class SgDNNLPowMulScalarProperty : public SubgraphProperty {
+ public:
+  SgDNNLPowMulScalarProperty() {}
+
+  static SubgraphPropertyPtr Create() {
+    static const std::string& name = "DNNL PowMulScalar optimization pass";
+    auto property                  = std::make_shared<SgDNNLPowMulScalarProperty>();
+    property->SetAttr<std::string>("property_name", name);
+    property->SetAttr<bool>("inference_only", true);
+    if (dmlc::GetEnv("MXNET_DISABLE_ONEDNN_POW_MUL_SCALAR_OPT", 0)) {
+      property->SetAttr<bool>("disable", true);
+    }
+    return property;
+  }
+
+  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
+                                     const int subgraph_id = 0) const override {
+    nnvm::ObjectPtr n = nnvm::Node::Create();
+
+    std::ostringstream node_name;
+    node_name << "sg_dnnl_pow_mul_scalar_" << std::to_string(subgraph_id);

Review Comment:
   You are not using `node_name` anywhere, so `n->attrs.name = "sg_dnnl_pow_mul_scalar_" + std::to_string(subgraph_id)` will be enough.



##########
src/operator/nn/dnnl/dnnl_pow_mul_scalar.cc:
##########
@@ -18,51 +18,64 @@
  */
 
 /*!
- * \file dnnl_power_scalar.cc
- * \author: Adam Grabowski, adam.grabowski@intel.com
+ * \file dnnl_pow_mul_scalar.cc
  */
 
 #if MXNET_USE_ONEDNN == 1
 
-#include "dnnl_power_scalar-inl.h"
+#include "dnnl_pow_mul_scalar-inl.h"
 
 namespace mxnet {
 namespace op {
 
-DNNLPowerFwd& DNNLPowerFwd::GetPowerForward(const nnvm::NodeAttrs& attrs,
-                                            const NDArray& input,
-                                            const NDArray& output) {
-  const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
+bool SupportDNNLPower(const NDArray& input) {
+  return input.shape().Size() != 0 && input.shape().ndim() > 0 && input.shape().ndim() <= 6 &&
+         input.dtype() == mshadow::kFloat32;
+}
+
+DMLC_REGISTER_PARAMETER(DNNLPowMulScalarParam);
+
+DNNLPowMulScalarFwd& DNNLPowMulScalarFwd::GetCached(const DNNLPowMulScalarParam& param,
+                                                    const NDArray& input,
+                                                    const NDArray& output) {
 #if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<DNNLPowerSignature, DNNLPowerFwd, OpHash> fwds;
+  static thread_local std::unordered_map<DNNLPowMulScalarSignature, DNNLPowMulScalarFwd, OpHash>
+      fwds;
 #else
-  static MX_THREAD_LOCAL std::unordered_map<DNNLPowerSignature, DNNLPowerFwd, OpHash> fwds;
+  static MX_THREAD_LOCAL std::unordered_map<DNNLPowMulScalarSignature, DNNLPowMulScalarFwd, OpHash>
+      fwds;
 #endif
-  DNNLPowerSignature key;
-  key.AddSign(static_cast<float>(param.scalar));
+  DNNLPowMulScalarSignature key(param);
   key.AddSign(input);
   key.AddSign(output);
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
-    const DNNLPowerFwd fwd(input, static_cast<float>(param.scalar));
+    const DNNLPowMulScalarFwd fwd(param, input);
     it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }
 
-DNNLPowerFwd::DNNLPowerFwd(const NDArray& input, const float exponent) {
+DNNLPowMulScalarFwd::DNNLPowMulScalarFwd(const DNNLPowMulScalarParam& param, const NDArray& input) {
   auto src_desc = input.GetDNNLData()->get_desc();
-  dnnl::eltwise_forward::desc fwd_desc(
-      dnnl::prop_kind::forward_scoring, dnnl::algorithm::eltwise_pow, src_desc, 1, exponent);
+  dnnl::eltwise_forward::desc fwd_desc(dnnl::prop_kind::forward_scoring,
+                                       dnnl::algorithm::eltwise_pow,
+                                       src_desc,
+                                       param.multiplier,
+                                       param.exponent);
   fwd_pd = std::make_shared<eltwise_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
   fwd    = std::make_shared<eltwise_fwd_t>(*fwd_pd);
 }
 
-void DNNLPowerFwd::Execute(const NDArray& input, const OpReqType& req, const NDArray& output) {
-  auto engine           = mxnet::CpuEngine::Get()->get_engine();
-  auto src              = input.GetDNNLData();
-  dnnl_output_t out_mem = CreateDNNLMem(output, fwd_pd->dst_desc(), req, &input);
+void DNNLPowMulScalarFwd::Execute(const NDArray& input,
+                                  const OpReqType& req,
+                                  const NDArray& output) {
+  auto input_ = input;
+  if (input_.IsDNNLData())

Review Comment:
   I think this shouldn't be necessary here - reordering to the default format is done in the `DNNLRun` function.



##########
src/operator/subgraph/dnnl/dnnl_pow_mul_scalar_property.h:
##########
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_pow_mul_scalar_property.h
+ * \brief Graph property for fusing _npi_power_scalar with _npi_multiply_scalar
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POW_MUL_SCALAR_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POW_MUL_SCALAR_PROPERTY_H_
+#if MXNET_USE_ONEDNN == 1
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "operator/subgraph/common.h"
+#include "operator/tensor/elemwise_binary_scalar_op.h"
+#include "dnnl_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class SgDNNLPowMulScalarSelector : public SubgraphSelectorV2 {
+ private:
+  std::vector<const BiDirectedNode*> matched_list_;
+  SelectStatus status_;
+
+ public:
+  bool Select(const BiDirectedNode& seed_node,
+              const std::shared_ptr<NodeAttr>& node_attr) override {
+    if (seed_node.node->op() == Op::Get("_npi_power_scalar")) {
+      matched_list_.clear();
+      matched_list_.emplace_back(&seed_node);
+      status_ = kStart;
+      return true;
+    }
+    return false;
+  }
+
+  bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& input_node) override {
+    return false;
+  }
+
+  bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& output_node) override {
+    const nnvm::Node* raw_power_scalar_node = n.node;
+    const nnvm::Node* raw_next_node         = output_node.node;
+    if (raw_power_scalar_node->op() && raw_power_scalar_node->op()->name == "_npi_power_scalar") {
+      if (raw_next_node->op() && status_ == kStart &&
+          raw_next_node->op()->name == "_npi_multiply_scalar") {
+        status_ = kSuccess;
+        return true;
+      } else {
+        status_ = kFail;
+        return false;
+      }
+    }
+
+    if (matched_list_.back() != &n) {
+      if (std::find(matched_list_.begin(), matched_list_.end(), &n) != matched_list_.end()) {
+        while (matched_list_.back() != &n) {
+          matched_list_.pop_back();
+        }
+      }
+      status_ = kSuccess;
+      return false;
+    }
+
+    return false;
+  }
+
+  void Reset() override {
+    CHECK_GE(matched_list_.size(), 1);
+    auto new_selector = SgDNNLPowMulScalarSelector();
+    new_selector.Select(*matched_list_[0], nullptr);
+    *this = new_selector;
+  }
+};
+
+class SgDNNLPowMulScalarProperty : public SubgraphProperty {
+ public:
+  SgDNNLPowMulScalarProperty() {}
+
+  static SubgraphPropertyPtr Create() {
+    static const std::string& name = "DNNL PowMulScalar optimization pass";
+    auto property                  = std::make_shared<SgDNNLPowMulScalarProperty>();
+    property->SetAttr<std::string>("property_name", name);
+    property->SetAttr<bool>("inference_only", true);
+    if (dmlc::GetEnv("MXNET_DISABLE_ONEDNN_POW_MUL_SCALAR_OPT", 0)) {
+      property->SetAttr<bool>("disable", true);
+    }
+    return property;
+  }
+
+  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
+                                     const int subgraph_id = 0) const override {
+    nnvm::ObjectPtr n = nnvm::Node::Create();
+
+    std::ostringstream node_name;
+    node_name << "sg_dnnl_pow_mul_scalar_" << std::to_string(subgraph_id);
+
+    DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
+      if (node->is_variable())
+        return;
+      auto& sub_name = node->op()->name;
+      if (sub_name == "_npi_power_scalar") {

Review Comment:
   As above, I would use `Op::Get(...)` for the comparison.



##########
src/operator/subgraph/dnnl/dnnl_pow_mul_scalar_property.h:
##########
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_pow_mul_scalar_property.h
+ * \brief Graph property for fusing _npi_power_scalar with _npi_multiply_scalar
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POW_MUL_SCALAR_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POW_MUL_SCALAR_PROPERTY_H_
+#if MXNET_USE_ONEDNN == 1
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "operator/subgraph/common.h"
+#include "operator/tensor/elemwise_binary_scalar_op.h"
+#include "dnnl_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class SgDNNLPowMulScalarSelector : public SubgraphSelectorV2 {
+ private:
+  std::vector<const BiDirectedNode*> matched_list_;
+  SelectStatus status_;
+
+ public:
+  bool Select(const BiDirectedNode& seed_node,
+              const std::shared_ptr<NodeAttr>& node_attr) override {
+    if (seed_node.node->op() == Op::Get("_npi_power_scalar")) {
+      matched_list_.clear();
+      matched_list_.emplace_back(&seed_node);
+      status_ = kStart;
+      return true;
+    }
+    return false;
+  }
+
+  bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& input_node) override {
+    return false;
+  }
+
+  bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& output_node) override {
+    const nnvm::Node* raw_power_scalar_node = n.node;
+    const nnvm::Node* raw_next_node         = output_node.node;
+    if (raw_power_scalar_node->op() && raw_power_scalar_node->op()->name == "_npi_power_scalar") {

Review Comment:
   It would be better to compare the operator itself instead of comparing its name and separately checking that the node has an op:
   ```suggestion
       if (raw_power_scalar_node->op() == Op::Get("_npi_power_scalar")) {
   ```
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@mxnet.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org