Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/12/27 18:55:21 UTC

[incubator-mxnet] branch master updated: Support negative axis in concat (#9204)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 911c81e  Support negative axis in concat (#9204)
911c81e is described below

commit 911c81ea7f0598235adea68a788fec429481c317
Author: Xingjian Shi <xs...@ust.hk>
AuthorDate: Wed Dec 27 10:55:18 2017 -0800

    Support negative axis in concat (#9204)
    
    * try to enable negative axis in concat
    
    fix bug
    
    update test
    
    revise test
    
    * initialize the variable
    
    * revise test
---
 src/operator/concat-inl.h              | 39 +++++++++++++++++-----------------
 tests/python/unittest/test_operator.py | 11 +++++++++-
 2 files changed, 30 insertions(+), 20 deletions(-)
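
In effect, this change lets Concat accept a negative `dim` and resolve it against the input's number of dimensions, NumPy-style. A minimal usage sketch, assuming the Python NDArray API of MXNet around this release (`mx.nd.concat` taking a `dim` keyword):

    import mxnet as mx

    a = mx.nd.ones((2, 3))
    b = mx.nd.zeros((2, 4))
    # With this patch, dim=-1 resolves to axis 1 for 2-D inputs,
    # so both calls below yield the same (2, 7) result.
    out_neg = mx.nd.concat(a, b, dim=-1)
    out_pos = mx.nd.concat(a, b, dim=1)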

diff --git a/src/operator/concat-inl.h b/src/operator/concat-inl.h
index fdbe330..4225ddf 100644
--- a/src/operator/concat-inl.h
+++ b/src/operator/concat-inl.h
@@ -35,6 +35,7 @@
 #include <utility>
 #include "./operator_common.h"
 #include "./channel_op_common.h"
+#include "./tensor/broadcast_reduce_op.h"
 
 namespace mxnet {
 namespace op {
@@ -50,7 +51,7 @@ struct ConcatParam : public dmlc::Parameter<ConcatParam> {
   DMLC_DECLARE_PARAMETER(ConcatParam) {
     DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
     .describe("Number of inputs to be concated.");
-    DMLC_DECLARE_FIELD(dim).set_range(0,  4).set_default(1)
+    DMLC_DECLARE_FIELD(dim).set_default(1)
     .describe("the dimension to be concated.");
   }
 };  // struct ConcatParam
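
The `set_range(0, 4)` constraint on `dim` is dropped because negative values are now legal; validation moves to `CheckAxis` (pulled in via the new `broadcast_reduce_op.h` include), which maps a negative axis into range by adding the number of dimensions. A rough Python sketch of that normalization, not the actual implementation:

    def check_axis(axis, ndim):
        # Map a possibly negative axis into [0, ndim); fail if out of range.
        if not -ndim <= axis < ndim:
            raise ValueError("axis %d out of range for %d-d input" % (axis, ndim))
        return axis + ndim if axis < 0 else axis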
@@ -70,23 +71,23 @@ class ConcatOp : public Operator {
     using namespace mshadow::expr;
     CHECK_EQ(static_cast<int>(in_data.size()), size_);
     CHECK_EQ(out_data.size(), 1U);
-    CHECK_LT(dimension_, in_data[concat_enum::kData0].ndim());
+    int axis = CheckAxis(dimension_, in_data[concat_enum::kData0].ndim());
     Stream<xpu> *s = ctx.get_stream<xpu>();
     std::vector<Tensor<xpu, 3, DType> > data(size_);
     Tensor<xpu, 3, DType> out;
     size_t leading = 1, trailing = 1;
-    for (int i = 0; i < dimension_; ++i) {
+    for (int i = 0; i < axis; ++i) {
       leading *= out_data[concat_enum::kOut].shape_[i];
     }
-    for (int i = dimension_ + 1; i < out_data[concat_enum::kOut].ndim(); ++i) {
+    for (int i = axis + 1; i < out_data[concat_enum::kOut].ndim(); ++i) {
       trailing *= out_data[concat_enum::kOut].shape_[i];
     }
-    size_t mid = out_data[concat_enum::kOut].shape_[dimension_];
+    size_t mid = out_data[concat_enum::kOut].shape_[axis];
     Shape<3> oshape = Shape3(leading, mid, trailing);
     out = out_data[concat_enum::kOut].get_with_shape<xpu, 3, DType>(oshape, s);
 
     for (int i = 0; i < size_; ++i) {
-      Shape<3> dshape = Shape3(leading, in_data[i].shape_[dimension_], trailing);
+      Shape<3> dshape = Shape3(leading, in_data[i].shape_[axis], trailing);
       data[i] = in_data[i].get_with_shape<xpu, 3, DType>(dshape, s);
     }
     Concatenate(data, &out, 1, req[concat_enum::kOut]);
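
The forward pass collapses an N-D concat into a 3-D one: dimensions before the normalized axis fold into `leading`, dimensions after it into `trailing`, and the concat axis becomes the middle dimension, so `Concatenate` always operates over axis 1. A NumPy sketch of the same reshaping trick (hypothetical helper, assuming `axis` is already normalized):

    import numpy as np

    def concat_via_3d(arrays, axis):
        # Fold dims before/after the axis so the concat is always over axis 1.
        shape = arrays[0].shape
        leading = int(np.prod(shape[:axis], dtype=int))
        trailing = int(np.prod(shape[axis + 1:], dtype=int))
        flat = [a.reshape(leading, a.shape[axis], trailing) for a in arrays]
        mid = sum(a.shape[axis] for a in arrays)
        return np.concatenate(flat, axis=1).reshape(
            shape[:axis] + (mid,) + shape[axis + 1:])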
@@ -103,22 +104,23 @@ class ConcatOp : public Operator {
     using namespace mshadow::expr;
     CHECK_EQ(out_grad.size(), 1U);
     CHECK_EQ(in_grad.size(), static_cast<size_t>(size_));
+    int axis = CheckAxis(dimension_, out_grad[concat_enum::kData0].ndim());
     Stream<xpu> *s = ctx.get_stream<xpu>();
     std::vector<Tensor<xpu, 3, DType> > grad_in(size_);
     Tensor<xpu, 3, DType> grad;
     size_t leading = 1, trailing = 1;
-    for (int i = 0; i < dimension_; ++i) {
+    for (int i = 0; i < axis; ++i) {
       leading *= out_grad[concat_enum::kOut].shape_[i];
     }
-    for (int i = dimension_ + 1; i < out_grad[concat_enum::kOut].ndim(); ++i) {
+    for (int i = axis + 1; i < out_grad[concat_enum::kOut].ndim(); ++i) {
       trailing *= out_grad[concat_enum::kOut].shape_[i];
     }
-    size_t mid = out_grad[concat_enum::kOut].shape_[dimension_];
+    size_t mid = out_grad[concat_enum::kOut].shape_[axis];
     Shape<3> oshape = Shape3(leading, mid, trailing);
     grad = out_grad[concat_enum::kOut].get_with_shape<xpu, 3, DType>(oshape, s);
 
     for (int i = 0; i < size_; ++i) {
-      Shape<3> dshape = Shape3(leading, in_grad[i].shape_[dimension_], trailing);
+      Shape<3> dshape = Shape3(leading, in_grad[i].shape_[axis], trailing);
       grad_in[i] = in_grad[i].get_with_shape<xpu, 3, DType>(dshape, s);
     }
     Split(grad, &grad_in, 1, req);
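
Backward is the mirror image: the output gradient is viewed through the same (leading, mid, trailing) shape and `Split` carves it back up along the middle axis. A NumPy sketch of the split step (hypothetical helper names):

    import numpy as np

    def concat_backward(grad_out, in_shapes, axis):
        # Split the incoming gradient at each input's boundary
        # along the (already normalized) concat axis.
        sizes = [s[axis] for s in in_shapes]
        offsets = np.cumsum(sizes)[:-1]
        return np.split(grad_out, offsets, axis=axis)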
@@ -159,23 +161,22 @@ class ConcatProp : public OperatorProperty {
     TShape dshape;
     index_t size = 0;
     bool has_zero = false;
+    int axis = -1;
     for (int i = 0; i < param_.num_args; ++i) {
       TShape tmp = (*in_shape)[i];
       if (tmp.ndim()) {
-        CHECK_LT(static_cast<index_t>(param_.dim), tmp.ndim())
-          << "concat dim " << param_.dim << " out of range of input shape " << tmp;
-        has_zero = tmp[param_.dim] == 0 || has_zero;
-        size += tmp[param_.dim];
-        tmp[param_.dim] = 0;
+        axis = CheckAxis(param_.dim, tmp.ndim());
+        has_zero = tmp[axis] == 0 || has_zero;
+        size += tmp[axis];
+        tmp[axis] = 0;
         shape_assign(&dshape, tmp);
       }
     }
 
     TShape tmp = (*out_shape)[0];
     if (tmp.ndim()) {
-      CHECK_LT(static_cast<index_t>(param_.dim), tmp.ndim())
-        << "concat dim " << param_.dim << " out of range of input shape " << tmp;
-      tmp[param_.dim] = 0;
+      axis = CheckAxis(param_.dim, tmp.ndim());
+      tmp[axis] = 0;
       shape_assign(&dshape, tmp);
     }
 
@@ -186,7 +187,7 @@ class ConcatProp : public OperatorProperty {
         << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
     }
 
-    if (!has_zero) dshape[param_.dim] = size;
+    if (!has_zero) dshape[axis] = size;
     CHECK(shape_assign(&(*out_shape)[0], dshape))
       << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
 
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index ba1b991..d05e325 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -134,6 +134,10 @@ def test_concat():
                         shapes.append((a, merge[i]))
                     check_concat_with_shape(shapes,dimension,True)
                     check_concat_with_shape(shapes,dimension,False)
+                    # Test negative dim
+                    check_concat_with_shape(shapes, dimension - 2, True)
+                    check_concat_with_shape(shapes, dimension - 2, False)
+
         #test 3D
         if dimension<3:
             for dim in range(2, 6):
@@ -147,6 +151,9 @@ def test_concat():
                         shapes.append((a,b,merge[i]))
                 check_concat_with_shape(shapes,dimension,True)
                 check_concat_with_shape(shapes,dimension,False)
+                # Test negative dim
+                check_concat_with_shape(shapes, dimension - 3, True)
+                check_concat_with_shape(shapes, dimension - 3, False)
         # test 4D
         for dim in range(2, 6):
             shapes = []
@@ -161,7 +168,9 @@ def test_concat():
                     shapes.append((a,b,c,merge[i]))
             check_concat_with_shape(shapes,dimension,True)
             check_concat_with_shape(shapes,dimension,False)
-
+            # Test negative dim
+            check_concat_with_shape(shapes, dimension - 4, True)
+            check_concat_with_shape(shapes, dimension - 4, False)
 
 def test_slice_channel():
     def check_slice_channel(data_ndim, axis, num_outputs, squeeze_axis):
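
One note on the test additions: for N-dimensional shapes, `dimension - N` names the same axis as `dimension`, so each negative-dim call is expected to behave identically to its positive counterpart. With the `check_axis` sketch from above:

    # For the 4-D case, dimension=1 and dimension-4=-3 resolve to the same axis.
    assert check_axis(1, 4) == check_axis(-3, 4) == 1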

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].