You are viewing a plain text version of this content. The canonical link for it was a hyperlink in the original message and is not preserved in this plain-text rendering.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/12/27 18:55:21 UTC
[incubator-mxnet] branch master updated: Support negative axis in
concat (#9204)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 911c81e Support negative axis in concat (#9204)
911c81e is described below
commit 911c81ea7f0598235adea68a788fec429481c317
Author: Xingjian Shi <xs...@ust.hk>
AuthorDate: Wed Dec 27 10:55:18 2017 -0800
Support negative axis in concat (#9204)
* try to enable negative axis in concat
fix bug
update test
revise test
* initialize the variable
* revise test
---
src/operator/concat-inl.h | 39 +++++++++++++++++-----------------
tests/python/unittest/test_operator.py | 11 +++++++++-
2 files changed, 30 insertions(+), 20 deletions(-)
diff --git a/src/operator/concat-inl.h b/src/operator/concat-inl.h
index fdbe330..4225ddf 100644
--- a/src/operator/concat-inl.h
+++ b/src/operator/concat-inl.h
@@ -35,6 +35,7 @@
#include <utility>
#include "./operator_common.h"
#include "./channel_op_common.h"
+#include "./tensor/broadcast_reduce_op.h"
namespace mxnet {
namespace op {
@@ -50,7 +51,7 @@ struct ConcatParam : public dmlc::Parameter<ConcatParam> {
DMLC_DECLARE_PARAMETER(ConcatParam) {
DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
.describe("Number of inputs to be concated.");
- DMLC_DECLARE_FIELD(dim).set_range(0, 4).set_default(1)
+ DMLC_DECLARE_FIELD(dim).set_default(1)
.describe("the dimension to be concated.");
}
}; // struct ConcatParam
@@ -70,23 +71,23 @@ class ConcatOp : public Operator {
using namespace mshadow::expr;
CHECK_EQ(static_cast<int>(in_data.size()), size_);
CHECK_EQ(out_data.size(), 1U);
- CHECK_LT(dimension_, in_data[concat_enum::kData0].ndim());
+ int axis = CheckAxis(dimension_, in_data[concat_enum::kData0].ndim());
Stream<xpu> *s = ctx.get_stream<xpu>();
std::vector<Tensor<xpu, 3, DType> > data(size_);
Tensor<xpu, 3, DType> out;
size_t leading = 1, trailing = 1;
- for (int i = 0; i < dimension_; ++i) {
+ for (int i = 0; i < axis; ++i) {
leading *= out_data[concat_enum::kOut].shape_[i];
}
- for (int i = dimension_ + 1; i < out_data[concat_enum::kOut].ndim(); ++i) {
+ for (int i = axis + 1; i < out_data[concat_enum::kOut].ndim(); ++i) {
trailing *= out_data[concat_enum::kOut].shape_[i];
}
- size_t mid = out_data[concat_enum::kOut].shape_[dimension_];
+ size_t mid = out_data[concat_enum::kOut].shape_[axis];
Shape<3> oshape = Shape3(leading, mid, trailing);
out = out_data[concat_enum::kOut].get_with_shape<xpu, 3, DType>(oshape, s);
for (int i = 0; i < size_; ++i) {
- Shape<3> dshape = Shape3(leading, in_data[i].shape_[dimension_], trailing);
+ Shape<3> dshape = Shape3(leading, in_data[i].shape_[axis], trailing);
data[i] = in_data[i].get_with_shape<xpu, 3, DType>(dshape, s);
}
Concatenate(data, &out, 1, req[concat_enum::kOut]);
@@ -103,22 +104,23 @@ class ConcatOp : public Operator {
using namespace mshadow::expr;
CHECK_EQ(out_grad.size(), 1U);
CHECK_EQ(in_grad.size(), static_cast<size_t>(size_));
+ int axis = CheckAxis(dimension_, out_grad[concat_enum::kData0].ndim());
Stream<xpu> *s = ctx.get_stream<xpu>();
std::vector<Tensor<xpu, 3, DType> > grad_in(size_);
Tensor<xpu, 3, DType> grad;
size_t leading = 1, trailing = 1;
- for (int i = 0; i < dimension_; ++i) {
+ for (int i = 0; i < axis; ++i) {
leading *= out_grad[concat_enum::kOut].shape_[i];
}
- for (int i = dimension_ + 1; i < out_grad[concat_enum::kOut].ndim(); ++i) {
+ for (int i = axis + 1; i < out_grad[concat_enum::kOut].ndim(); ++i) {
trailing *= out_grad[concat_enum::kOut].shape_[i];
}
- size_t mid = out_grad[concat_enum::kOut].shape_[dimension_];
+ size_t mid = out_grad[concat_enum::kOut].shape_[axis];
Shape<3> oshape = Shape3(leading, mid, trailing);
grad = out_grad[concat_enum::kOut].get_with_shape<xpu, 3, DType>(oshape, s);
for (int i = 0; i < size_; ++i) {
- Shape<3> dshape = Shape3(leading, in_grad[i].shape_[dimension_], trailing);
+ Shape<3> dshape = Shape3(leading, in_grad[i].shape_[axis], trailing);
grad_in[i] = in_grad[i].get_with_shape<xpu, 3, DType>(dshape, s);
}
Split(grad, &grad_in, 1, req);
@@ -159,23 +161,22 @@ class ConcatProp : public OperatorProperty {
TShape dshape;
index_t size = 0;
bool has_zero = false;
+ int axis = -1;
for (int i = 0; i < param_.num_args; ++i) {
TShape tmp = (*in_shape)[i];
if (tmp.ndim()) {
- CHECK_LT(static_cast<index_t>(param_.dim), tmp.ndim())
- << "concat dim " << param_.dim << " out of range of input shape " << tmp;
- has_zero = tmp[param_.dim] == 0 || has_zero;
- size += tmp[param_.dim];
- tmp[param_.dim] = 0;
+ axis = CheckAxis(param_.dim, tmp.ndim());
+ has_zero = tmp[axis] == 0 || has_zero;
+ size += tmp[axis];
+ tmp[axis] = 0;
shape_assign(&dshape, tmp);
}
}
TShape tmp = (*out_shape)[0];
if (tmp.ndim()) {
- CHECK_LT(static_cast<index_t>(param_.dim), tmp.ndim())
- << "concat dim " << param_.dim << " out of range of input shape " << tmp;
- tmp[param_.dim] = 0;
+ axis = CheckAxis(param_.dim, tmp.ndim());
+ tmp[axis] = 0;
shape_assign(&dshape, tmp);
}
@@ -186,7 +187,7 @@ class ConcatProp : public OperatorProperty {
<< "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
}
- if (!has_zero) dshape[param_.dim] = size;
+ if (!has_zero) dshape[axis] = size;
CHECK(shape_assign(&(*out_shape)[0], dshape))
<< "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index ba1b991..d05e325 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -134,6 +134,10 @@ def test_concat():
shapes.append((a, merge[i]))
check_concat_with_shape(shapes,dimension,True)
check_concat_with_shape(shapes,dimension,False)
+ # Test negative dim
+ check_concat_with_shape(shapes, dimension - 2, True)
+ check_concat_with_shape(shapes, dimension - 2, False)
+
#test 3D
if dimension<3:
for dim in range(2, 6):
@@ -147,6 +151,9 @@ def test_concat():
shapes.append((a,b,merge[i]))
check_concat_with_shape(shapes,dimension,True)
check_concat_with_shape(shapes,dimension,False)
+ # Test negative dim
+ check_concat_with_shape(shapes, dimension - 3, True)
+ check_concat_with_shape(shapes, dimension - 3, False)
# test 4D
for dim in range(2, 6):
shapes = []
@@ -161,7 +168,9 @@ def test_concat():
shapes.append((a,b,c,merge[i]))
check_concat_with_shape(shapes,dimension,True)
check_concat_with_shape(shapes,dimension,False)
-
+ # Test negative dim
+ check_concat_with_shape(shapes, dimension - 4, True)
+ check_concat_with_shape(shapes, dimension - 4, False)
def test_slice_channel():
def check_slice_channel(data_ndim, axis, num_outputs, squeeze_axis):
--
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].