You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/08/31 07:01:48 UTC
[incubator-mxnet] branch master updated: [FEATURE] Dnnl sum primitive path (#21132)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 3a19f0e50d [FEATURE] Dnnl sum primitive path (#21132)
3a19f0e50d is described below
commit 3a19f0e50d75fedb05eb558a9c835726b57df4cf
Author: Kacper Pietkun <ka...@intel.com>
AuthorDate: Wed Aug 31 09:01:31 2022 +0200
[FEATURE] Dnnl sum primitive path (#21132)
* Added dnnl_sum primitive path to mxnet binary_add when shapes are the same
* added test coverage
* added operation check
* Random number for tests
* delete unnecessary variables
* review changes
---
.../tensor/elemwise_binary_broadcast_op_basic.cc | 63 ++++++++++++----------
tests/python/unittest/test_operator.py | 15 ++++++
2 files changed, 51 insertions(+), 27 deletions(-)
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
index a29914ecbd..ebbcd9d3d9 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
@@ -21,11 +21,12 @@
* \file elemwise_binary_broadcast_op_basic.cc
* \brief CPU Implementation of basic functions for elementwise binary broadcast operator.
*/
-#include "./elemwise_unary_op.h"
-#include "./elemwise_binary_op-inl.h"
-#include "./elemwise_binary_broadcast_op.h"
+#include "operator/tensor/elemwise_unary_op.h"
+#include "operator/tensor/elemwise_binary_op-inl.h"
+#include "operator/tensor/elemwise_binary_broadcast_op.h"
#if MXNET_USE_ONEDNN == 1
-#include "../nn/dnnl/dnnl_binary-inl.h"
+#include "operator/nn/dnnl/dnnl_binary-inl.h"
+#include "operator/nn/dnnl/dnnl_sum-inl.h"
#endif // MXNET_USE_ONEDNN == 1
namespace mxnet {
@@ -38,31 +39,39 @@ void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
- mxnet::TShape new_lshape, new_rshape, new_oshape;
- int ndim_diff = BinaryBroadcastShapeCompact(inputs[0].shape(),
- inputs[1].shape(),
- outputs[0].shape(),
- &new_lshape,
- &new_rshape,
- &new_oshape);
- std::vector<NDArray> new_inputs;
- std::vector<NDArray> new_outputs;
- if (ndim_diff) {
- new_inputs = {inputs[0].Reshape(new_lshape), inputs[1].Reshape(new_rshape)};
- new_outputs = {outputs[0].Reshape(new_oshape)};
- } else if (inputs[0].shape().Size() == 1 && inputs[1].shape().Size() == 1) {
- // BinaryBroadcastShapeCompact function doesn't reshape tensors of size (1,1,...,1)
- // into shape (1). It is mandatory for oneDNN primitive to have this reshape done.
- mxnet::TShape one_shape = mxnet::TShape(1, 1);
- new_inputs = {inputs[0].Reshape(one_shape), inputs[1].Reshape(one_shape)};
- new_outputs = {outputs[0].Reshape(one_shape)};
+ const mxnet::TShape& input_0_shape = inputs[0].shape();
+ const mxnet::TShape& input_1_shape = inputs[1].shape();
+ const mxnet::TShape& output_0_shape = outputs[0].shape();
+ // When both inputs have the same shape no broadcasting is needed, so the more efficient
+ // sum kernel can be used instead of the generic binary primitive.
+ const bool same_shape = (input_0_shape == input_1_shape);
+
+ if (same_shape && alg == dnnl::algorithm::binary_add) {
+ DNNLSumFwd& fwd = DNNLSumFwd::GetCached(inputs, outputs);
+ fwd.Execute(ctx, inputs, req, outputs);
} else {
- new_inputs = {inputs[0], inputs[1]};
- new_outputs = {outputs[0]};
- }
+ mxnet::TShape new_lshape, new_rshape, new_oshape;
+ int ndim_diff = BinaryBroadcastShapeCompact(
+ input_0_shape, input_1_shape, output_0_shape, &new_lshape, &new_rshape, &new_oshape);
+ std::vector<NDArray> new_inputs;
+ std::vector<NDArray> new_outputs;
+ if (ndim_diff) {
+ new_inputs = {inputs[0].Reshape(new_lshape), inputs[1].Reshape(new_rshape)};
+ new_outputs = {outputs[0].Reshape(new_oshape)};
+ } else if (input_0_shape.Size() == 1 && input_1_shape.Size() == 1) {
+ // BinaryBroadcastShapeCompact function doesn't reshape tensors of size (1,1,...,1)
+ // into shape (1). It is mandatory for oneDNN primitive to have this reshape done.
+ mxnet::TShape one_shape = mxnet::TShape(1, 1);
+ new_inputs = {inputs[0].Reshape(one_shape), inputs[1].Reshape(one_shape)};
+ new_outputs = {outputs[0].Reshape(one_shape)};
+ } else {
+ new_inputs = {inputs[0], inputs[1]};
+ new_outputs = {outputs[0]};
+ }
- DNNLBinaryOpFwd& fwd = DNNLBinaryOpFwd::GetBinaryOpForward<alg>(new_inputs, new_outputs);
- fwd.Execute(new_inputs, req, new_outputs);
+ DNNLBinaryOpFwd& fwd = DNNLBinaryOpFwd::GetBinaryOpForward<alg>(new_inputs, new_outputs);
+ fwd.Execute(new_inputs, req, new_outputs);
+ }
}
#endif
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 2e65ae4f53..07cf2e8436 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -9414,6 +9414,21 @@ def test_elementwise_ops_on_misaligned_input():
mx.nd.waitall()
assert a[3].asscalar() == 4.0
+
+@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64'])
+@pytest.mark.parametrize('ndim', [1, 2, 3, 4, 5])
+@pytest.mark.parametrize('max_dim_size', [1, 2, 3, 4, 5])
+def test_broadcast_ops_on_input_with_the_same_shape(dtype, ndim, max_dim_size):  # a + b, equal shapes
+    shape = list(rand_shape_nd(ndim, dim=max_dim_size))
+    a = np.random.uniform(low=-100, high=100, size=shape).astype(dtype)  # honor parametrized dtype
+    b = np.random.uniform(low=-100, high=100, size=shape).astype(dtype)
+    expected = a + b  # numpy reference computed in the same dtype as the operands
+    am = mx.nd.array(a, dtype=dtype)  # pass dtype explicitly so each dtype variant is exercised
+    bm = mx.nd.array(b, dtype=dtype)
+    cm = am + bm
+    mx.nd.waitall()  # wait for the asynchronous engine before comparing results
+    assert_almost_equal(cm, expected)
+
@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64'])
@pytest.mark.parametrize('lead_dim', [2, 3, 4, 6, 10])
@pytest.mark.parametrize('both_ways', [False, True])