Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2022/08/25 12:18:10 UTC

[GitHub] [incubator-mxnet] Kacper-Pietkun commented on a diff in pull request #21132: [FEATURE] Dnnl sum primitive path

Kacper-Pietkun commented on code in PR #21132:
URL: https://github.com/apache/incubator-mxnet/pull/21132#discussion_r954894121


##########
src/operator/tensor/elemwise_binary_broadcast_op_basic.cc:
##########
@@ -38,31 +39,39 @@ void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
                          const std::vector<NDArray>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<NDArray>& outputs) {
-  mxnet::TShape new_lshape, new_rshape, new_oshape;
-  int ndim_diff = BinaryBroadcastShapeCompact(inputs[0].shape(),
-                                              inputs[1].shape(),
-                                              outputs[0].shape(),
-                                              &new_lshape,
-                                              &new_rshape,
-                                              &new_oshape);
-  std::vector<NDArray> new_inputs;
-  std::vector<NDArray> new_outputs;
-  if (ndim_diff) {
-    new_inputs  = {inputs[0].Reshape(new_lshape), inputs[1].Reshape(new_rshape)};
-    new_outputs = {outputs[0].Reshape(new_oshape)};
-  } else if (inputs[0].shape().Size() == 1 && inputs[1].shape().Size() == 1) {
-    // BinaryBroadcastShapeCompact function doesn't reshape tensors of size (1,1,...,1)
-    // into shape (1). It is mandatory for oneDNN primitive to have this reshape done.
-    mxnet::TShape one_shape = mxnet::TShape(1, 1);
-    new_inputs              = {inputs[0].Reshape(one_shape), inputs[1].Reshape(one_shape)};
-    new_outputs             = {outputs[0].Reshape(one_shape)};
+  // We can use the more efficient sum kernel when there is no broadcasting, i.e. when the shapes are the same
+  const bool same_shape = (inputs[0].shape() == inputs[1].shape());
+
+  if (same_shape && alg == dnnl::algorithm::binary_add) {
+    DNNLSumFwd& fwd = DNNLSumFwd::GetCached(inputs, outputs);
+    fwd.Execute(ctx, inputs, req, outputs);
   } else {
-    new_inputs  = {inputs[0], inputs[1]};
-    new_outputs = {outputs[0]};
-  }
+    mxnet::TShape new_lshape, new_rshape, new_oshape;
+    int ndim_diff = BinaryBroadcastShapeCompact(inputs[0].shape(),
+                                                inputs[1].shape(),
+                                                outputs[0].shape(),
+                                                &new_lshape,
+                                                &new_rshape,
+                                                &new_oshape);
+    std::vector<NDArray> new_inputs;
+    std::vector<NDArray> new_outputs;
+    if (ndim_diff) {
+      new_inputs  = {inputs[0].Reshape(new_lshape), inputs[1].Reshape(new_rshape)};
+      new_outputs = {outputs[0].Reshape(new_oshape)};
+    } else if (inputs[0].shape().Size() == 1 && inputs[1].shape().Size() == 1) {
+      // BinaryBroadcastShapeCompact function doesn't reshape tensors of size (1,1,...,1)
+      // into shape (1). It is mandatory for oneDNN primitive to have this reshape done.

Review Comment:
   BinaryBroadcastShapeCompact does not reshape tensors of size (1,1,...,1), but we need to reshape such tensors into shape (1) for oneDNN to work properly.
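
   For reference, the reshape in question can be shown in isolation. The sketch below is illustrative, not part of the PR: it uses a plain std::vector<int64_t> in place of mxnet::TShape, and the helper name CollapseScalarShape is made up. It reduces an all-ones shape such as (1,1,...,1) to the one-dimensional shape (1) that the oneDNN primitive expects:

       #include <cstdint>
       #include <iostream>
       #include <vector>

       // Collapse a shape whose total size is 1, e.g. (1,1,...,1), into the
       // one-dimensional shape (1). The patch does the equivalent with
       // mxnet::TShape(1, 1), a 1-dim shape whose single extent is 1.
       std::vector<int64_t> CollapseScalarShape(const std::vector<int64_t>& shape) {
         int64_t size = 1;
         for (int64_t dim : shape)
           size *= dim;
         if (size == 1)
           return {1};   // oneDNN handles the scalar case only as shape (1)
         return shape;   // everything else is left to BinaryBroadcastShapeCompact
       }

       int main() {
         for (int64_t dim : CollapseScalarShape({1, 1, 1, 1}))
           std::cout << dim << ' ';   // prints: 1
         std::cout << '\n';
         return 0;
       }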



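The new fast path at the top of DNNLBinaryOpForward dispatches to DNNLSumFwd whenever both input shapes match and the algorithm is binary_add. DNNLSumFwd itself is not shown in this diff; for context, here is a minimal standalone sketch of the dnnl::sum primitive it builds on. The oneDNN v2.x C++ API is assumed, and the engine setup and the 2x3 f32 buffers are illustrative:

    #include <vector>
    #include "dnnl.hpp"

    int main() {
      using namespace dnnl;
      engine eng(engine::kind::cpu, 0);
      stream strm(eng);

      // Two same-shape f32 sources and one destination, all 2x3 and dense.
      memory::desc md({2, 3}, memory::data_type::f32, memory::format_tag::ab);
      std::vector<float> a(6, 1.f), b(6, 2.f), out(6, 0.f);
      memory src0(md, eng, a.data());
      memory src1(md, eng, b.data());
      memory dst(md, eng, out.data());

      // dnnl::sum adds N same-shape tensors with one kernel; unlike
      // dnnl::binary with binary_add it never has to handle broadcasting,
      // which appears to be the efficiency win the new comment refers to.
      sum::primitive_desc pd(md, {1.f, 1.f}, {md, md}, eng);
      sum(pd).execute(strm, {{DNNL_ARG_MULTIPLE_SRC + 0, src0},
                             {DNNL_ARG_MULTIPLE_SRC + 1, src1},
                             {DNNL_ARG_DST, dst}});
      strm.wait();   // out now holds the elementwise sum a + b
      return 0;
    }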

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@mxnet.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org