You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2018/08/23 18:15:51 UTC

[incubator-mxnet] branch master updated: Generalized broadcast_like operator (#11984)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 43581a7  Generalized broadcast_like operator (#11984)
43581a7 is described below

commit 43581a7cb3393b1e6660c36c5ef4d59a09b212dc
Author: Istvan Fehervari <go...@gmail.com>
AuthorDate: Thu Aug 23 11:15:40 2018 -0700

    Generalized broadcast_like operator (#11984)
    
    * Added input_axes and other_axes to broadcast_like
    
    See https://github.com/apache/incubator-mxnet/issues/11871
    
    * Added a simple sanity test
    
    * Fixed linting
    
    * Fixed linting issues
    
    * Renamed parameters, added negative indexing, more testcases
    
    * Fixed linting
    
    * Replaced params with optionals
    
    Unspecified axes result in broadcasting over the whole shape; empty tuples raise an exception.
    Added tests
    
    * Re-added the default param values
    
    * Fixed indentation
---
 src/operator/tensor/broadcast_reduce_op.h        | 71 ++++++++++++++++++++----
 src/operator/tensor/broadcast_reduce_op_value.cc |  5 ++
 tests/python/unittest/test_ndarray.py            | 19 +++++++
 tests/python/unittest/test_symbol.py             |  3 +-
 4 files changed, 87 insertions(+), 11 deletions(-)

diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index 351315a..0944d25 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -147,6 +147,17 @@ struct BroadcastToParam : public dmlc::Parameter<BroadcastToParam> {
   }
 };
 
+struct BroadcastLikeParam : public dmlc::Parameter<BroadcastLikeParam> {
+  dmlc::optional<TShape> lhs_axes;
+  dmlc::optional<TShape> rhs_axes;
+  DMLC_DECLARE_PARAMETER(BroadcastLikeParam) {
+    DMLC_DECLARE_FIELD(lhs_axes).set_default(dmlc::optional<TShape>())
+      .describe("Axes to perform broadcast on in the first input array");
+    DMLC_DECLARE_FIELD(rhs_axes).set_default(dmlc::optional<TShape>())
+      .describe("Axes to copy from the second input array");
+  }
+};
+
 inline int CheckAxis(int axis, int ndim) {
   CHECK(axis < ndim && axis >= -ndim)
     << "axis " << axis << " exceeds the input dimension of " << ndim;
@@ -350,20 +361,60 @@ inline bool BroadcastLikeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   TShape& lhs_shape = (*in_attrs)[0];
   TShape& rhs_shape = (*in_attrs)[1];
-  TShape oshape = TShape(rhs_shape);
-  if (lhs_shape.ndim() == 0 || lhs_shape.ndim() == 0) return false;
 
-  CHECK_EQ(lhs_shape.ndim(), rhs_shape.ndim())
-    << "Operand of shape " << lhs_shape << " cannot be broadcasted to " << rhs_shape;
+  if ((lhs_shape.ndim() == 0) || (lhs_shape.ndim() == 0)) {
+    return false;
+  }
 
-  for (index_t i = 0; i < lhs_shape.ndim(); ++i) {
-    if (rhs_shape[i] != 0) {
-      CHECK(lhs_shape[i] == rhs_shape[i] || lhs_shape[i] == 1)
-        << "Array cannot be broadcasted from " << lhs_shape << " to " << rhs_shape;
-    } else {
-      oshape[i] = lhs_shape[i];
+  const BroadcastLikeParam& param = nnvm::get<BroadcastLikeParam>(attrs.parsed);
+  TShape oshape;
+
+  // lhs or rhs or both params were not specified
+  if (!param.lhs_axes.has_value() || !param.rhs_axes.has_value()) {
+    CHECK_EQ(lhs_shape.ndim(), rhs_shape.ndim())
+      << "Operand of shape " << lhs_shape << " cannot be broadcasted to " << rhs_shape;
+
+    oshape = TShape(rhs_shape);
+    for (index_t i = 0; i < lhs_shape.ndim(); ++i) {
+      if (rhs_shape[i] != 0) {
+        CHECK(lhs_shape[i] == rhs_shape[i] || lhs_shape[i] == 1)
+          << "Array cannot be broadcasted from " << lhs_shape << " to " << rhs_shape;
+      } else {
+        oshape[i] = lhs_shape[i];
+      }
+    }
+  } else {
+    auto lhs_axes = param.lhs_axes.value();
+    auto rhs_axes = param.rhs_axes.value();
+
+    CHECK(rhs_axes.ndim() == lhs_axes.ndim())
+      << "Input_axis and other_axis size does not match";
+
+    CHECK(lhs_axes.ndim() > 0)
+      << "Empty axes tuple is not allowed";
+
+    oshape = TShape(lhs_shape);
+    for (index_t i = 0; i < lhs_axes.ndim(); ++i) {
+      auto copyfrom = lhs_axes[i];
+      if (copyfrom < 0) {
+        copyfrom =  lhs_shape.ndim() + copyfrom;
+      }
+      CHECK(copyfrom >= 0 && copyfrom < oshape.ndim())
+        << "Invalid dimension specified in lhs_axes: " << lhs_axes[i];
+
+      auto copyto = rhs_axes[i];
+      if (copyto < 0) {
+        copyto =  rhs_shape.ndim() + copyto;
+      }
+      CHECK(copyto >= 0 && copyto < rhs_shape.ndim())
+        << "Invalid dimension specified in rhs_axes: " << rhs_axes[i];
+
+      CHECK(lhs_shape[copyfrom] == 1) << "Input axis " << lhs_axes[i]
+        << " at dimension " << i << " cannot be broadcasted to " << rhs_shape[copyto];
+      oshape[copyfrom] = rhs_shape[copyto];
     }
   }
+
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
   return true;
 }
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc
index 929c3df..c3bc9cf 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cc
+++ b/src/operator/tensor/broadcast_reduce_op_value.cc
@@ -31,6 +31,7 @@ DMLC_REGISTER_PARAMETER(NormParam);
 DMLC_REGISTER_PARAMETER(ReduceAxisParam);
 DMLC_REGISTER_PARAMETER(BroadcastAxesParam);
 DMLC_REGISTER_PARAMETER(BroadcastToParam);
+DMLC_REGISTER_PARAMETER(BroadcastLikeParam);
 
 inline std::string get_reduce_axes_description(const std::string& op_name, int line) {
   std::string doc = R"code(Computes the __op__ of array elements over given axes.
@@ -309,7 +310,11 @@ For example::
    broadcast_like([[1,2,3]], [[5,6,7],[7,8,9]]) = [[ 1.,  2.,  3.],
                                                    [ 1.,  2.,  3.]])
 
+   broadcast_like([9], [1,2,3,4,5], lhs_axes=(0,), rhs_axes=(-1,)) = [9,9,9,9,9]
+
 )code" ADD_FILELINE)
+.set_attr_parser(ParamParser<BroadcastLikeParam>)
+.add_arguments(BroadcastLikeParam::__FIELDS__())
 .set_attr<nnvm::FInferShape>("FInferShape", BroadcastLikeShape)
 .set_attr<FCompute>("FCompute<cpu>", BroadcastCompute<cpu>);
 
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index e9eea43..071c770 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -549,8 +549,27 @@ def test_broadcast():
             err = np.square(ndarray_ret - numpy_ret).mean()
             assert err < 1E-8
 
+    def test_broadcast_like_axis():
+        testcases = [
+            # Lhs shape, rhs shape, lhs axis, rhs axis, result
+            [(1, 2, 1, 3), (5, 6, 7, 8), (0,2), (1,3), (6, 2, 8, 3)],
+            [(1,), (5,), (0,), (-1,), (5,)],
+            [(1, 7, 9, 1, 1), (9,), (-2,), (0,), (1, 7, 9, 9, 1)],
+            [(1, 7, 9, 1, 1), (9, 1), (-2, -1), (-2, -1), (1, 7, 9, 9, 1)],
+            [(2, 1), (1, 7, 9, 1, 1), (1,), (-3,), (2, 9)]
+        ]
+        
+        for test_data in testcases:
+            lhs = mx.nd.random.uniform(shape=test_data[0])
+            rhs = mx.nd.random.uniform(shape=test_data[1])
+            output = mx.nd.broadcast_like(lhs, rhs, lhs_axes=test_data[2], rhs_axes=test_data[3])
+
+            assert_exception(mx.nd.broadcast_like, mx.base.MXNetError, lhs, rhs, lhs_axes=(), rhs_axes=())
+            assert output.shape == test_data[4]
+
     test_broadcast_to()
     test_broadcast_like()
+    test_broadcast_like_axis()
 
 
 @with_seed()
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index aece9a3..d022c68 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -171,7 +171,7 @@ def test_symbol_infer_shape_var():
 def test_symbol_fluent():
     has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 'sum', 'nansum', 'prod',
                     'nanprod', 'mean', 'max', 'min', 'reshape', 'broadcast_to', 'split',
-                    'broadcast_axes', 'pad', 'swapaxes', 'slice', 'slice_axis', 'slice_like',
+                    'broadcast_axes', 'broadcast_like', 'pad', 'swapaxes', 'slice', 'slice_axis', 'slice_like',
                     'take', 'one_hot', 'pick', 'sort', 'topk', 'argsort', 'argmax', 'argmin',
                     'clip', 'abs', 'sign', 'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan',
                     'degrees', 'radians', 'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh',
@@ -212,6 +212,7 @@ def test_symbol_fluent():
     check_fluent_regular('slice_like', {'axes': (0, -2), 'shape_like': mx.sym.zeros((3, 3))})
     check_fluent_regular('clip', {'a_min': 0.25, 'a_max': 0.75})
     check_fluent_regular('broadcast_axes', {'axis': (2,), 'size': (5,)})
+    check_fluent_regular('broadcast_like', {'rhs': mx.sym.ones((1, 5)), 'lhs_axes': (0,), 'rhs_axes': (1,)}, shape=(1,9))
     check_fluent_regular('pad', {'mode': 'constant', 'pad_width': (0,0,0,0,3,0,0,4)}, shape=(5, 17, 2, 3))
     check_fluent_regular('reshape_like', {'rhs': mx.sym.ones((30, 17))}, shape=(5, 17, 2, 3))