Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/23 18:15:42 UTC

[GitHub] szha closed pull request #11984: Generalized broadcast_like operator

szha closed pull request #11984: Generalized broadcast_like operator
URL: https://github.com/apache/incubator-mxnet/pull/11984

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
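For orientation before the diff: the change adds optional lhs_axes/rhs_axes
parameters to broadcast_like, so the two inputs no longer need the same
number of dimensions. A short sketch of both modes, assuming an MXNet build
that includes this PR:

    import mxnet as mx

    # Original behavior: inputs have the same ndim; size-1 axes broadcast.
    lhs = mx.nd.array([[1, 2, 3]])               # shape (1, 3)
    rhs = mx.nd.array([[5, 6, 7], [7, 8, 9]])    # shape (2, 3)
    print(mx.nd.broadcast_like(lhs, rhs))        # [[1. 2. 3.] [1. 2. 3.]]

    # Generalized behavior added here: broadcast only the named axes.
    lhs = mx.nd.array([9])                       # shape (1,)
    rhs = mx.nd.array([1, 2, 3, 4, 5])           # shape (5,)
    print(mx.nd.broadcast_like(lhs, rhs, lhs_axes=(0,), rhs_axes=(-1,)))
    # [9. 9. 9. 9. 9.]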

diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index 351315ab0c8..0944d255a45 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -147,6 +147,17 @@ struct BroadcastToParam : public dmlc::Parameter<BroadcastToParam> {
   }
 };
 
+struct BroadcastLikeParam : public dmlc::Parameter<BroadcastLikeParam> {
+  dmlc::optional<TShape> lhs_axes;
+  dmlc::optional<TShape> rhs_axes;
+  DMLC_DECLARE_PARAMETER(BroadcastLikeParam) {
+    DMLC_DECLARE_FIELD(lhs_axes).set_default(dmlc::optional<TShape>())
+      .describe("Axes to perform broadcast on in the first input array");
+    DMLC_DECLARE_FIELD(rhs_axes).set_default(dmlc::optional<TShape>())
+      .describe("Axes to copy from the second input array");
+  }
+};
+
 inline int CheckAxis(int axis, int ndim) {
   CHECK(axis < ndim && axis >= -ndim)
     << "axis " << axis << " exceeds the input dimension of " << ndim;
@@ -350,20 +361,60 @@ inline bool BroadcastLikeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   TShape& lhs_shape = (*in_attrs)[0];
   TShape& rhs_shape = (*in_attrs)[1];
-  TShape oshape = TShape(rhs_shape);
-  if (lhs_shape.ndim() == 0 || lhs_shape.ndim() == 0) return false;
 
-  CHECK_EQ(lhs_shape.ndim(), rhs_shape.ndim())
-    << "Operand of shape " << lhs_shape << " cannot be broadcasted to " << rhs_shape;
+  if ((lhs_shape.ndim() == 0) || (rhs_shape.ndim() == 0)) {
+    return false;
+  }
 
-  for (index_t i = 0; i < lhs_shape.ndim(); ++i) {
-    if (rhs_shape[i] != 0) {
-      CHECK(lhs_shape[i] == rhs_shape[i] || lhs_shape[i] == 1)
-        << "Array cannot be broadcasted from " << lhs_shape << " to " << rhs_shape;
-    } else {
-      oshape[i] = lhs_shape[i];
+  const BroadcastLikeParam& param = nnvm::get<BroadcastLikeParam>(attrs.parsed);
+  TShape oshape;
+
+  // if either lhs_axes or rhs_axes is unspecified, fall back to the same-ndim rule
+  if (!param.lhs_axes.has_value() || !param.rhs_axes.has_value()) {
+    CHECK_EQ(lhs_shape.ndim(), rhs_shape.ndim())
+      << "Operand of shape " << lhs_shape << " cannot be broadcasted to " << rhs_shape;
+
+    oshape = TShape(rhs_shape);
+    for (index_t i = 0; i < lhs_shape.ndim(); ++i) {
+      if (rhs_shape[i] != 0) {
+        CHECK(lhs_shape[i] == rhs_shape[i] || lhs_shape[i] == 1)
+          << "Array cannot be broadcasted from " << lhs_shape << " to " << rhs_shape;
+      } else {
+        oshape[i] = lhs_shape[i];
+      }
+    }
+  } else {
+    auto lhs_axes = param.lhs_axes.value();
+    auto rhs_axes = param.rhs_axes.value();
+
+    CHECK(rhs_axes.ndim() == lhs_axes.ndim())
+      << "Input_axis and other_axis size does not match";
+
+    CHECK(lhs_axes.ndim() > 0)
+      << "Empty axes tuple is not allowed";
+
+    oshape = TShape(lhs_shape);
+    for (index_t i = 0; i < lhs_axes.ndim(); ++i) {
+      auto copyfrom = lhs_axes[i];
+      if (copyfrom < 0) {
+        copyfrom = lhs_shape.ndim() + copyfrom;
+      }
+      CHECK(copyfrom >= 0 && copyfrom < oshape.ndim())
+        << "Invalid dimension specified in lhs_axes: " << lhs_axes[i];
+
+      auto copyto = rhs_axes[i];
+      if (copyto < 0) {
+        copyto = rhs_shape.ndim() + copyto;
+      }
+      CHECK(copyto >= 0 && copyto < rhs_shape.ndim())
+        << "Invalid dimension specified in rhs_axes: " << rhs_axes[i];
+
+      CHECK(lhs_shape[copyfrom] == 1) << "Input axis " << lhs_axes[i]
+        << " at dimension " << i << " cannot be broadcasted to " << rhs_shape[copyto];
+      oshape[copyfrom] = rhs_shape[copyto];
     }
   }
+
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
   return true;
 }
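To make the inference rule above easier to follow, here is a minimal Python
sketch of the same logic (illustrative only, not the C++ implementation;
the function name is made up for this note):

    def broadcast_like_shape(lhs_shape, rhs_shape, lhs_axes=None, rhs_axes=None):
        def norm_axis(ax, ndim, name):
            # negative axes count from the end, as in the C++ code
            ax = ax + ndim if ax < 0 else ax
            assert 0 <= ax < ndim, "invalid dimension specified in %s" % name
            return ax

        if lhs_axes is None or rhs_axes is None:
            # classic path: ndim must match, and size-1 lhs dims broadcast
            assert len(lhs_shape) == len(rhs_shape), "ndim mismatch"
            out = list(rhs_shape)
            for i, (l, r) in enumerate(zip(lhs_shape, rhs_shape)):
                if r != 0:
                    assert l == r or l == 1, "cannot broadcast"
                else:
                    out[i] = l  # unknown rhs dim: keep the lhs dim
            return tuple(out)

        # generalized path: each named lhs axis (which must have size 1)
        # takes the size of the paired rhs axis
        assert len(lhs_axes) == len(rhs_axes) and len(lhs_axes) > 0
        out = list(lhs_shape)
        for la, ra in zip(lhs_axes, rhs_axes):
            la = norm_axis(la, len(lhs_shape), "lhs_axes")
            ra = norm_axis(ra, len(rhs_shape), "rhs_axes")
            assert out[la] == 1, "axis to broadcast must have size 1"
            out[la] = rhs_shape[ra]
        return tuple(out)

For example, broadcast_like_shape((1, 7, 9, 1, 1), (9,), (-2,), (0,))
returns (1, 7, 9, 9, 1), matching one of the unit-test cases below.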
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc
index 929c3dfcf0a..c3bc9cfd3f0 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cc
+++ b/src/operator/tensor/broadcast_reduce_op_value.cc
@@ -31,6 +31,7 @@ DMLC_REGISTER_PARAMETER(NormParam);
 DMLC_REGISTER_PARAMETER(ReduceAxisParam);
 DMLC_REGISTER_PARAMETER(BroadcastAxesParam);
 DMLC_REGISTER_PARAMETER(BroadcastToParam);
+DMLC_REGISTER_PARAMETER(BroadcastLikeParam);
 
 inline std::string get_reduce_axes_description(const std::string& op_name, int line) {
   std::string doc = R"code(Computes the __op__ of array elements over given axes.
@@ -309,7 +310,11 @@ For example::
   broadcast_like([[1,2,3]], [[5,6,7],[7,8,9]]) = [[ 1.,  2.,  3.],
                                                   [ 1.,  2.,  3.]]
 
+   broadcast_like([9], [1,2,3,4,5], lhs_axes=(0,), rhs_axes=(-1,)) = [9,9,9,9,9]
+
 )code" ADD_FILELINE)
+.set_attr_parser(ParamParser<BroadcastLikeParam>)
+.add_arguments(BroadcastLikeParam::__FIELDS__())
 .set_attr<nnvm::FInferShape>("FInferShape", BroadcastLikeShape)
 .set_attr<FCompute>("FCompute<cpu>", BroadcastCompute<cpu>);
 
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 2db39d5dd53..36522601814 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -551,8 +551,27 @@ def test_broadcast_like():
             err = np.square(ndarray_ret - numpy_ret).mean()
             assert err < 1E-8
 
+    def test_broadcast_like_axis():
+        testcases = [
+            # Lhs shape, rhs shape, lhs axis, rhs axis, result
+            [(1, 2, 1, 3), (5, 6, 7, 8), (0,2), (1,3), (6, 2, 8, 3)],
+            [(1,), (5,), (0,), (-1,), (5,)],
+            [(1, 7, 9, 1, 1), (9,), (-2,), (0,), (1, 7, 9, 9, 1)],
+            [(1, 7, 9, 1, 1), (9, 1), (-2, -1), (-2, -1), (1, 7, 9, 9, 1)],
+            [(2, 1), (1, 7, 9, 1, 1), (1,), (-3,), (2, 9)]
+        ]
+
+        for test_data in testcases:
+            lhs = mx.nd.random.uniform(shape=test_data[0])
+            rhs = mx.nd.random.uniform(shape=test_data[1])
+            output = mx.nd.broadcast_like(lhs, rhs, lhs_axes=test_data[2], rhs_axes=test_data[3])
+
+            assert output.shape == test_data[4]
+            assert_exception(mx.nd.broadcast_like, mx.base.MXNetError, lhs, rhs, lhs_axes=(), rhs_axes=())
+
     test_broadcast_to()
     test_broadcast_like()
+    test_broadcast_like_axis()
 
 
 @with_seed()
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index aece9a37812..d022c68237a 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -171,7 +171,7 @@ def test_symbol_infer_shape_var():
 def test_symbol_fluent():
     has_grad = set(['flatten', 'expand_dims', 'flip', 'tile', 'transpose', 'sum', 'nansum', 'prod',
                     'nanprod', 'mean', 'max', 'min', 'reshape', 'broadcast_to', 'split',
-                    'broadcast_axes', 'pad', 'swapaxes', 'slice', 'slice_axis', 'slice_like',
+                    'broadcast_axes', 'broadcast_like', 'pad', 'swapaxes', 'slice', 'slice_axis', 'slice_like',
                     'take', 'one_hot', 'pick', 'sort', 'topk', 'argsort', 'argmax', 'argmin',
                     'clip', 'abs', 'sign', 'sin', 'cos', 'tan', 'arcsin', 'arccos', 'arctan',
                     'degrees', 'radians', 'sinh', 'cosh', 'tanh', 'arcsinh', 'arccosh', 'arctanh',
@@ -212,6 +212,7 @@ def check_fluent_regular(func, kwargs, shape=(5, 17, 1), equal_nan=False):
     check_fluent_regular('slice_like', {'axes': (0, -2), 'shape_like': mx.sym.zeros((3, 3))})
     check_fluent_regular('clip', {'a_min': 0.25, 'a_max': 0.75})
     check_fluent_regular('broadcast_axes', {'axis': (2,), 'size': (5,)})
+    check_fluent_regular('broadcast_like', {'rhs': mx.sym.ones((1, 5)), 'lhs_axes': (0,), 'rhs_axes': (1,)}, shape=(1,9))
     check_fluent_regular('pad', {'mode': 'constant', 'pad_width': (0,0,0,0,3,0,0,4)}, shape=(5, 17, 2, 3))
     check_fluent_regular('reshape_like', {'rhs': mx.sym.ones((30, 17))}, shape=(5, 17, 2, 3))
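The broadcast_like fluent check above corresponds to roughly this symbolic
usage (a sketch; the shapes are taken from the test's arguments):

    import mxnet as mx

    x = mx.sym.var('x')
    y = x.broadcast_like(rhs=mx.sym.ones((1, 5)), lhs_axes=(0,), rhs_axes=(1,))
    _, out_shapes, _ = y.infer_shape(x=(1, 9))
    print(out_shapes)   # [(5, 9)]: axis 0 of x (size 1) takes rhs axis 1 (size 5)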
 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services