Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/10/30 02:05:32 UTC

[incubator-mxnet] branch master updated: [FEATURE] Add oneDNN support for numpy concatenate operator (#20652)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 94fc557  [FEATURE] Add oneDNN support for numpy concatenate operator (#20652)
94fc557 is described below

commit 94fc557c1716ca633508d9aa69ede5ed1b01892e
Author: AdamGrabowski <ad...@intel.com>
AuthorDate: Sat Oct 30 04:04:07 2021 +0200

    [FEATURE] Add oneDNN support for numpy concatenate operator (#20652)
    
    * Integrate oneDNN support for numpy concatenate operator
    
    * Fix sanity check errors
    
    * Remove redundant npi_concatenate nnvm_register_op for gpu
    
    * Fix sanity check
    
    * Fix gpu register op and dmlc::optional problems
    
    * Fix GPU build
---
 cpp-package/example/charRNN.cpp                    |   3 +-
 python/mxnet/amp/lists/symbol_fp16.py              |   1 -
 src/api/operator/numpy/np_matrix_op.cc             |   8 +-
 src/operator/nn/concat-inl.h                       |  34 ++++--
 src/operator/nn/concat.cc                          |  39 +++++--
 src/operator/nn/concat.cu                          |   5 +-
 src/operator/nn/dnnl/dnnl_concat.cc                |   8 +-
 src/operator/numpy/np_matrix_op.cc                 | 128 +--------------------
 src/operator/numpy/np_matrix_op.cu                 |   6 -
 .../quantization/dnnl/dnnl_quantized_concat.cc     |   4 +-
 src/operator/quantization/quantized_concat.cc      |  43 ++-----
 src/operator/subgraph/tensorrt/nnvm_to_onnx.cc     |   3 +-
 src/operator/subgraph/tensorrt/tensorrt-inl.h      |   3 +-
 tests/python/unittest/test_numpy_op.py             |   2 +-
 14 files changed, 85 insertions(+), 202 deletions(-)
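
For context, the user-facing effect of this change is that mxnet.np.concatenate is now served by the shared Concat operator, so CPU execution can go through the oneDNN concat primitive. A minimal usage sketch (assuming a oneDNN-enabled CPU build of MXNet; the dispatch itself is internal and not visible from Python):

    from mxnet import np

    a = np.array([[1., 1.], [2., 2.]])
    b = np.array([[3., 3.]])

    # Concatenation along an existing axis (axis=0 here).
    out = np.concatenate([a, b], axis=0)       # shape (3, 2)

    # axis=None flattens the inputs before joining them along axis 0,
    # matching the dim=None behaviour documented in concat.cc below.
    flat = np.concatenate([a, b], axis=None)   # shape (6,)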

diff --git a/cpp-package/example/charRNN.cpp b/cpp-package/example/charRNN.cpp
index 0b87abf..39d88a8 100644
--- a/cpp-package/example/charRNN.cpp
+++ b/cpp-package/example/charRNN.cpp
@@ -125,7 +125,8 @@ Symbol LSTMUnroll(int num_lstm_layer, int sequence_length, int input_dim,
     hidden_all.push_back(hidden);
   }
 
-  auto hidden_concat = isTrain? Concat(hidden_all, hidden_all.size(), 0) : hidden_all[0];
+  auto hidden_concat =
+      isTrain ? Concat(hidden_all, hidden_all.size(), dmlc::optional<int>(0)) : hidden_all[0];
   auto cls_weight = Symbol::Variable("cls_weight");
   auto cls_bias = Symbol::Variable("cls_bias");
   auto pred = FullyConnected("pred", hidden_concat, cls_weight, cls_bias, input_dim);
diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index 7e2f715..45ac56c 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -646,7 +646,6 @@ WIDEST_TYPE_CASTS = [
     '_mod',
     '_not_equal',
     '_npi_column_stack',
-    '_npi_concatenate',
     '_npi_copysign',
     '_npi_cross',
     '_npi_dot',
diff --git a/src/api/operator/numpy/np_matrix_op.cc b/src/api/operator/numpy/np_matrix_op.cc
index 921dd5f..498e11b 100644
--- a/src/api/operator/numpy/np_matrix_op.cc
+++ b/src/api/operator/numpy/np_matrix_op.cc
@@ -139,17 +139,17 @@ MXNET_REGISTER_API("_npi.concatenate")
       using namespace runtime;
       const nnvm::Op* op = Op::Get("_npi_concatenate");
       nnvm::NodeAttrs attrs;
-      op::NumpyConcatenateParam param;
+      op::ConcatParam param;
       int arg_size   = args.num_args;
       param.num_args = arg_size - 2;
       if (args[arg_size - 2].type_code() == kNull) {
-        param.axis = dmlc::nullopt;
+        param.dim = dmlc::nullopt;
       } else {
-        param.axis = args[arg_size - 2].operator int();
+        param.dim = args[arg_size - 2].operator int();
       }
       attrs.parsed = param;
       attrs.op     = op;
-      SetAttrDict<op::NumpyConcatenateParam>(&attrs);
+      SetAttrDict<op::ConcatParam>(&attrs);
       int num_inputs = arg_size - 2;
       std::vector<NDArray*> inputs;
       inputs.reserve(num_inputs);
diff --git a/src/operator/nn/concat-inl.h b/src/operator/nn/concat-inl.h
index 778324a..01cde26 100644
--- a/src/operator/nn/concat-inl.h
+++ b/src/operator/nn/concat-inl.h
@@ -47,10 +47,12 @@ enum ConcatOpOutputs { kOut };
 
 struct ConcatParam : public dmlc::Parameter<ConcatParam> {
   int num_args;
-  int dim;
+  dmlc::optional<int> dim;
   DMLC_DECLARE_PARAMETER(ConcatParam) {
     DMLC_DECLARE_FIELD(num_args).set_lower_bound(1).describe("Number of inputs to be concated.");
-    DMLC_DECLARE_FIELD(dim).set_default(1).describe("the dimension to be concated.");
+    DMLC_DECLARE_FIELD(dim)
+        .set_default(dmlc::optional<int>(1))
+        .describe("the dimension to be concated.");
   }
   void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
     std::ostringstream num_args_s, dim_s;
@@ -66,7 +68,7 @@ class ConcatOp {
  public:
   void Init(const ConcatParam& param) {
     this->size_      = param.num_args;
-    this->dimension_ = param.dim;
+    this->dimension_ = param.dim.has_value() ? param.dim.value() : 0;
   }
 
   void Forward(const OpContext& ctx,
@@ -140,10 +142,18 @@ void ConcatCompute(const nnvm::NodeAttrs& attrs,
                    const std::vector<OpReqType>& req,
                    const std::vector<TBlob>& outputs) {
   const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
-  MSHADOW_TYPE_SWITCH(inputs[concat_enum::kData0].type_flag_, DType, {
+  std::vector<TBlob> in_data(param.num_args);
+  for (int i = 0; i < param.num_args; i++) {
+    if (!param.dim.has_value()) {
+      in_data[i] = inputs[i].reshape(mxnet::TShape(1, inputs[i].shape_.Size()));
+    } else {
+      in_data[i] = inputs[i];
+    }
+  }
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(in_data[concat_enum::kData0].type_flag_, DType, {
     ConcatOp<xpu, DType> op;
     op.Init(param);
-    op.Forward(ctx, inputs, req, outputs);
+    op.Forward(ctx, in_data, req, outputs);
   });
 }
 
@@ -209,10 +219,18 @@ void ConcatGradCompute(const nnvm::NodeAttrs& attrs,
                        const std::vector<OpReqType>& req,
                        const std::vector<TBlob>& outputs) {
   const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
-  MSHADOW_TYPE_SWITCH(inputs[concat_enum::kOut].type_flag_, DType, {
+  std::vector<TBlob> out_data(param.num_args);
+  for (int i = 0; i < param.num_args; i++) {
+    if (!param.dim.has_value()) {
+      out_data[i] = outputs[i].reshape(mxnet::TShape(1, outputs[i].shape_.Size()));
+    } else {
+      out_data[i] = outputs[i];
+    }
+  }
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[concat_enum::kOut].type_flag_, DType, {
     ConcatOp<xpu, DType> op;
     op.Init(param);
-    op.Backward(ctx, inputs[concat_enum::kOut], req, outputs);
+    op.Backward(ctx, inputs[concat_enum::kOut], req, out_data);
   });
 }
 
@@ -318,7 +336,7 @@ void ConcatCSRImpl(const nnvm::NodeAttrs& attrs,
   using namespace csr;
   const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
   int num_args             = param.num_args;
-  int concat_dim           = param.dim;
+  int concat_dim           = param.dim.has_value() ? param.dim.value() : 0;
   CHECK_EQ(inputs.size(), num_args);
   CHECK_EQ(outputs.size(), 1);
   int axis = CheckAxis(concat_dim, inputs[0].shape().ndim());
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 580183f..6206c8e 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -39,12 +39,18 @@ bool ConcatShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
   mxnet::TShape dshape;
   dim_t size                = 0;
+  int param_dim             = param_.dim.has_value() ? param_.dim.value() : 0;
   bool has_unknown_dim_size = false;
   int axis                  = -1;
+  if (!param_.dim.has_value()) {
+    for (int i = 0; i < param_.num_args; ++i) {
+      (*in_shape)[i] = Shape1((*in_shape)[i].Size());
+    }
+  }
   for (int i = 0; i < param_.num_args; ++i) {
     mxnet::TShape tmp = (*in_shape)[i];
     if (tmp.ndim() > 0) {
-      axis                 = CheckAxis(param_.dim, tmp.ndim());
+      axis                 = CheckAxis(param_dim, tmp.ndim());
       has_unknown_dim_size = !mxnet::dim_size_is_known(tmp, axis) || has_unknown_dim_size;
       size += tmp[axis];
       tmp[axis] = -1;
@@ -54,7 +60,7 @@ bool ConcatShape(const nnvm::NodeAttrs& attrs,
 
   mxnet::TShape tmp = (*out_shape)[0];
   if (tmp.ndim() > 0) {
-    axis      = CheckAxis(param_.dim, tmp.ndim());
+    axis      = CheckAxis(param_dim, tmp.ndim());
     tmp[axis] = -1;
     shape_assign(&dshape, tmp);
   }
@@ -89,11 +95,12 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape dshape;
   index_t size = 0;
   std::vector<int> zero_indices;
-  int axis = -1;
+  int axis      = -1;
+  int param_dim = param_.dim.has_value() ? param_.dim.value() : 0;
   for (int i = 0; i < param_.num_args; ++i) {
     mxnet::TShape tmp = (*in_shape)[i];
     if (tmp.ndim() > 0) {
-      axis = CheckAxis(param_.dim, tmp.ndim());
+      axis = CheckAxis(param_dim, tmp.ndim());
       if (!mxnet::dim_size_is_known(tmp, axis)) {
         zero_indices.emplace_back(i);
       } else {
@@ -107,7 +114,7 @@ static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
 
   mxnet::TShape tmp = (*out_shape)[0];
   if (tmp.ndim() > 0) {
-    axis      = CheckAxis(param_.dim, tmp.ndim());
+    axis      = CheckAxis(param_dim, tmp.ndim());
     tmp[axis] = -1;
     shape_assign(&dshape, tmp);
   }
@@ -193,13 +200,14 @@ inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs,
   auto& out_stype          = out_attrs->at(0);
   bool dispatched          = false;
   const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
-  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kCSRStorage) && param.dim == 0) {
+  int param_dim            = param.dim.has_value() ? param.dim.value() : 0;
+  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kCSRStorage) && param_dim == 0) {
     dispatched =
         storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, DispatchMode::kFComputeEx);
   }
 #if MXNET_USE_ONEDNN == 1
   if (!dispatched && dev_mask == mshadow::cpu::kDevMask &&
-      common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.dim > 0) {
+      common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
     dispatched =
         storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
   }
@@ -225,10 +233,8 @@ inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
                                              std::vector<int>* out_attrs) {
   DispatchMode wanted_mode;
 #if MXNET_USE_ONEDNN == 1
-  const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
   CHECK_EQ(out_attrs->size(), in_attrs->size() - 1);
-  if (dev_mask == mshadow::cpu::kDevMask &&
-      common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.dim > 0)
+  if (dev_mask == mshadow::cpu::kDevMask && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage))
     wanted_mode = DispatchMode::kFComputeEx;
   else
 #endif  // MXNET_USE_ONEDNN == 1
@@ -251,8 +257,9 @@ bool SupportDNNLConcat(const std::vector<NDArray>& arrs) {
       return false;
     int ndim               = arr.shape().ndim();
     const int dnnl_ndims   = arr.GetDNNLData()->get_desc().data.ndims;
-    if (!(ndim == 2 || ndim == 4) || ndim != dnnl_ndims)
+    if ((ndim != 2 && ndim != 4) || ndim != dnnl_ndims) {
       return false;
+    }
   }
   return true;
 }
@@ -347,12 +354,14 @@ DMLC_REGISTER_PARAMETER(ConcatParam);
 NNVM_REGISTER_OP(Concat)
 MXNET_ADD_SPARSE_OP_ALIAS(concat)
     .add_alias("concat")
+    .add_alias("_npi_concatenate")
     .describe(R"code(Joins input arrays along a given axis.
 
 .. note:: `Concat` is deprecated. Use `concat` instead.
 
 The dimensions of the input arrays should be the same except the axis along
-which they will be concatenated.
+which they will be concatenated. With dimension parameter ``None`` input 
+arrays are flattened before concatenating them along axis 0.
 The dimension of the output array along the concatenated axis will be equal
 to the sum of the corresponding dimensions of the input arrays.
 
@@ -376,6 +385,11 @@ Example::
                           [ 7.,  7.],
                           [ 8.,  8.]]
 
+   concat(x,y,z,dim=None) = [1., 1., 2., 2., 
+                             3., 3., 4., 4.,
+                             5., 5., 6., 6.,
+                             7., 7., 8., 8.]
+
    Note that you cannot concat x,y,z along dimension 1 since dimension
    0 is not the same for all the input arrays.
 
@@ -397,6 +411,7 @@ Example::
     .add_arguments(ConcatParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_Concat)
+    .add_alias("_backward_np_concat")
     .set_num_inputs([](const NodeAttrs& attrs) {
 #if MXNET_USE_ONEDNN == 1
       const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
diff --git a/src/operator/nn/concat.cu b/src/operator/nn/concat.cu
index d50d218..b9926d0 100644
--- a/src/operator/nn/concat.cu
+++ b/src/operator/nn/concat.cu
@@ -47,6 +47,7 @@ static void ConcatComputeExGPU(const nnvm::NodeAttrs& attrs,
 }
 
 NNVM_REGISTER_OP(Concat)
+    .add_alias("_npi_concatenate")
     .set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>)
     .set_attr<FComputeEx>("FComputeEx<gpu>", ConcatComputeExGPU);
 
@@ -54,7 +55,9 @@ NNVM_REGISTER_OP(_rnn_param_concat)
     .set_attr<FCompute>("FCompute<gpu>", ConcatCompute<gpu>)
     .set_attr<FComputeEx>("FComputeEx<gpu>", ConcatComputeExGPU);
 
-NNVM_REGISTER_OP(_backward_Concat).set_attr<FCompute>("FCompute<gpu>", ConcatGradCompute<gpu>);
+NNVM_REGISTER_OP(_backward_Concat)
+    .add_alias("_backward_np_concat")
+    .set_attr<FCompute>("FCompute<gpu>", ConcatGradCompute<gpu>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/nn/dnnl/dnnl_concat.cc b/src/operator/nn/dnnl/dnnl_concat.cc
index 1214a31..83ba9df 100644
--- a/src/operator/nn/dnnl/dnnl_concat.cc
+++ b/src/operator/nn/dnnl/dnnl_concat.cc
@@ -64,7 +64,8 @@ void DNNLConcatForward(const nnvm::NodeAttrs& attrs,
   TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]);
   const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
   const int num_in_data    = param.num_args;
-  const int concat_dim     = param.dim;
+  int concat_dim           = param.dim.has_value() ? param.dim.value() : 0;
+  concat_dim               = CheckAxis(concat_dim, in_data[concat_enum::kData0].shape().ndim());
   std::vector<dnnl::memory::desc> data_md;
   std::vector<const dnnl::memory*> data_mem;
   data_md.reserve(num_in_data);
@@ -96,7 +97,8 @@ void DNNLConcatBackward(const nnvm::NodeAttrs& attrs,
   TmpMemMgr::Get()->Init(ctx.requested[concat_enum::kTempSpace]);
   const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
   const int num_in_data    = param.num_args;
-  const int axis           = param.dim;
+  int concat_dim           = param.dim.has_value() ? param.dim.value() : 0;
+  concat_dim               = CheckAxis(concat_dim, outputs[concat_enum::kData0].shape().ndim());
   const auto gradz_mem     = inputs[0].GetDNNLData();
   /* init the offset */
   dnnl::memory::dims offsets(outputs[0].shape().ndim());
@@ -112,7 +114,7 @@ void DNNLConcatBackward(const nnvm::NodeAttrs& attrs,
     auto from_md = gradz_mem->get_desc().submemory_desc(diff_src_tz, offsets);
     auto from_mem =
         new dnnl::memory(from_md, gradz_mem->get_engine(), gradz_mem->get_data_handle());
-    offsets[axis] += diff_src_tz[axis];
+    offsets[concat_dim] += diff_src_tz[concat_dim];
 
     std::unordered_map<int, dnnl::memory> net_args(
         {{DNNL_ARG_FROM, *gradz_mem}, {DNNL_ARG_TO, *gradi_mem.second}});
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index 042ff10..bc9e139 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -499,7 +499,7 @@ bool HStackShape(const nnvm::NodeAttrs& attrs,
 
   mxnet::TShape tmp = (*out_shape)[0];
   if (tmp.ndim() > 0) {
-    axis      = CheckAxis(param_.dim, tmp.ndim());
+    axis      = CheckAxis(param_.dim.value(), tmp.ndim());
     tmp[axis] = -1;
     shape_assign(&dshape, tmp);
   }
@@ -561,7 +561,7 @@ bool DStackShape(const nnvm::NodeAttrs& attrs,
 
   mxnet::TShape tmp = (*out_shape)[0];
   if (tmp.ndim() > 0) {
-    axis      = CheckAxis(param_.dim, tmp.ndim());
+    axis      = CheckAxis(param_.dim.value(), tmp.ndim());
     tmp[axis] = -1;
     shape_assign(&dshape, tmp);
   }
@@ -588,86 +588,6 @@ bool ConcatType(const nnvm::NodeAttrs& attrs,
                 std::vector<int>* in_type,
                 std::vector<int>* out_type);
 
-bool NumpyConcatenateType(const nnvm::NodeAttrs& attrs,
-                          std::vector<int>* in_type,
-                          std::vector<int>* out_type) {
-  const NumpyConcatenateParam& param = nnvm::get<NumpyConcatenateParam>(attrs.parsed);
-  const int num_args                 = param.num_args;
-  CHECK_EQ(in_type->size(), num_args);
-  CHECK_EQ(out_type->size(), 1);
-  int dtype = -1;
-  for (int i = 0; i < num_args; i++) {
-    if (dtype == -1) {
-      dtype = in_type->at(i);
-    }
-  }
-  if (dtype == -1) {
-    dtype = out_type->at(0);
-  }
-  for (int i = 0; i < num_args; i++) {
-    TYPE_ASSIGN_CHECK(*in_type, i, dtype);
-  }
-  TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
-  return dtype != -1;
-}
-
-bool NumpyConcatenateShape(const nnvm::NodeAttrs& attrs,
-                           mxnet::ShapeVector* in_shape,
-                           mxnet::ShapeVector* out_shape) {
-  using namespace mshadow;
-  const NumpyConcatenateParam& param_ = nnvm::get<NumpyConcatenateParam>(attrs.parsed);
-  const int num_args                  = param_.num_args;
-  CHECK_EQ(in_shape->size(), num_args);
-
-  int param_axis;
-  if (!(param_.axis.has_value())) {
-    for (int i = 0; i < num_args; ++i) {
-      (*in_shape)[i] = Shape1((*in_shape)[i].Size());
-    }
-    param_axis = 0;
-  } else {
-    param_axis = param_.axis.value();
-  }
-
-  mxnet::TShape dshape;
-  dim_t size                = 0;
-  bool has_unknown_dim_size = false;
-  int axis                  = -1;
-  for (int i = 0; i < num_args; ++i) {
-    mxnet::TShape tmp = (*in_shape)[i];
-    if (tmp.ndim() > 0) {
-      axis                 = CheckAxis(param_axis, tmp.ndim());
-      has_unknown_dim_size = !mxnet::dim_size_is_known(tmp, axis) || has_unknown_dim_size;
-      size += tmp[axis];
-      tmp[axis] = -1;
-      shape_assign(&dshape, tmp);
-    }
-  }
-
-  mxnet::TShape tmp = (*out_shape)[0];
-  if (tmp.ndim() > 0) {
-    axis      = CheckAxis(param_axis, tmp.ndim());
-    tmp[axis] = -1;
-    shape_assign(&dshape, tmp);
-  }
-
-  if (dshape.ndim() == -1)
-    return false;
-  CHECK_NE(dshape.ndim(), 0) << "zero-dimensional arrays cannot be concatenated";
-
-  for (int i = 0; i < num_args; ++i) {
-    CHECK(shape_assign(&(*in_shape)[i], dshape))
-        << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
-  }
-
-  if (!has_unknown_dim_size)
-    dshape[axis] = size;
-  CHECK(shape_assign(&(*out_shape)[0], dshape))
-      << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
-
-  return shape_is_known(dshape);
-}
-
 struct NumpyConcatGrad {
   const char* op_name;
   std::vector<nnvm::NodeEntry> operator()(const nnvm::ObjectPtr& n,
@@ -677,50 +597,6 @@ struct NumpyConcatGrad {
     return MakeGradNode(op_name, n, heads, n->attrs.dict);
   }
 };
-
-DMLC_REGISTER_PARAMETER(NumpyConcatenateParam);
-
-NNVM_REGISTER_OP(_npi_concatenate)
-    .describe(R"code(Join a sequence of arrays along an existing axis.)code" ADD_FILELINE)
-    .set_num_inputs([](const NodeAttrs& attrs) {
-      const NumpyConcatenateParam& params = nnvm::get<NumpyConcatenateParam>(attrs.parsed);
-      return params.num_args;
-    })
-    .set_num_outputs(1)
-    .set_attr_parser(ParamParser<NumpyConcatenateParam>)
-    .set_attr<nnvm::FListInputNames>("FListInputNames",
-                                     [](const NodeAttrs& attrs) {
-                                       const NumpyConcatenateParam& params =
-                                           nnvm::get<NumpyConcatenateParam>(attrs.parsed);
-                                       std::vector<std::string> ret;
-                                       ret.reserve(params.num_args);
-                                       for (int i = 0; i < params.num_args; ++i) {
-                                         ret.push_back(std::string("data") + std::to_string(i));
-                                       }
-                                       return ret;
-                                     })
-    .set_attr<nnvm::FListOutputNames>("FListOutputNames",
-                                      [](const NodeAttrs& attrs) {
-                                        return std::vector<std::string>{"out"};
-                                      })
-    .set_attr<std::string>("key_var_num_args", "num_args")
-    .set_attr<nnvm::FInferType>("FInferType", NumpyConcatenateType)
-    .set_attr<mxnet::FInferShape>("FInferShape", NumpyConcatenateShape)
-    .set_attr<FCompute>("FCompute<cpu>", NumpyConcatenateForward<cpu>)
-    .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_np_concat"})
-    .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
-    .add_arguments(ConcatParam::__FIELDS__());
-
-NNVM_REGISTER_OP(_backward_np_concat)
-    .set_num_inputs(1)
-    .set_num_outputs([](const NodeAttrs& attrs) {
-      const NumpyConcatenateParam& params = nnvm::get<NumpyConcatenateParam>(attrs.parsed);
-      return params.num_args;
-    })
-    .set_attr_parser(ParamParser<NumpyConcatenateParam>)
-    .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-    .set_attr<FCompute>("FCompute<cpu>", NumpyConcatenateBackward<cpu>);
-
 NNVM_REGISTER_OP(_npi_stack)
     .describe(R"code(Join a sequence of arrays along a new axis.
 
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
index b1ac4fe..7b7a3bd 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -34,12 +34,6 @@ NNVM_REGISTER_OP(_np_reshape).set_attr<FCompute>("FCompute<gpu>", UnaryOp::Ident
 
 NNVM_REGISTER_OP(_npi_squeeze).set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
 
-NNVM_REGISTER_OP(_npi_concatenate)
-    .set_attr<FCompute>("FCompute<gpu>", NumpyConcatenateForward<gpu>);
-
-NNVM_REGISTER_OP(_backward_np_concat)
-    .set_attr<FCompute>("FCompute<gpu>", NumpyConcatenateBackward<gpu>);
-
 NNVM_REGISTER_OP(_npi_stack).set_attr<FCompute>("FCompute<gpu>", StackOpForward<gpu>);
 
 NNVM_REGISTER_OP(_npi_vstack).set_attr<FCompute>("FCompute<gpu>", NumpyVstackForward<gpu>);
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_concat.cc b/src/operator/quantization/dnnl/dnnl_quantized_concat.cc
index 06582cb..3409ec2 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_concat.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_concat.cc
@@ -95,7 +95,9 @@ static void DNNLQuantizedConcatForward(const nnvm::NodeAttrs& attrs,
       data_md.push_back(mem_desc);
     }
   }
-  DNNLConcatFwd& fwd           = GetConcatForward(param_.dim, in_data, data_md);
+  int param_dim                = param_.dim.has_value() ? param_.dim.value() : 0;
+  param_dim                    = CheckAxis(param_dim, in_data[concat_enum::kData0].shape().ndim());
+  DNNLConcatFwd& fwd           = GetConcatForward(param_dim, in_data, data_md);
   mxnet::dnnl_output_t out_mem = CreateDNNLMem(
       out_data[quantized_concat_enum::kOut], fwd.fwd_pd.dst_desc(), req[concat_enum::kOut]);
   dnnl_args_map_t net_args;
diff --git a/src/operator/quantization/quantized_concat.cc b/src/operator/quantization/quantized_concat.cc
index f9f889e..43b0022 100644
--- a/src/operator/quantization/quantized_concat.cc
+++ b/src/operator/quantization/quantized_concat.cc
@@ -38,10 +38,11 @@ static bool QuantizedConcatShape(const nnvm::NodeAttrs& attrs,
   index_t size              = 0;
   bool has_unknown_dim_size = false;
   int axis                  = -1;
+  int param_dim             = param_.dim.has_value() ? param_.dim.value() : 0;
   for (int i = 0; i < param_.num_args; ++i) {
     mxnet::TShape tmp = (*in_shape)[i];
     if (tmp.ndim() > 0) {
-      axis                 = CheckAxis(param_.dim, tmp.ndim());
+      axis                 = CheckAxis(param_dim, tmp.ndim());
       has_unknown_dim_size = !mxnet::dim_size_is_known(tmp, axis) || has_unknown_dim_size;
       size += tmp[axis];
       tmp[axis] = -1;
@@ -51,7 +52,7 @@ static bool QuantizedConcatShape(const nnvm::NodeAttrs& attrs,
 
   mxnet::TShape tmp = (*out_shape)[0];
   if (tmp.ndim() > 0) {
-    axis      = CheckAxis(param_.dim, tmp.ndim());
+    axis      = CheckAxis(param_dim, tmp.ndim());
     tmp[axis] = -1;
     shape_assign(&dshape, tmp);
   }
@@ -146,45 +147,15 @@ If any input holds int8, then the output will be int8. Otherwise output will be
     .add_arguments(ConcatParam::__FIELDS__());
 
 NNVM_REGISTER_OP(Concat).set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
-  const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
-  nnvm::ObjectPtr node     = nnvm::Node::Create();
-  if (param.dim > 0) {
-    node->attrs.op   = Op::Get("_contrib_quantized_concat");
-    node->attrs.name = "quantized_" + attrs.name;
-  } else {
-    LOG(INFO) << "Currently, quantized concat only supports dim>0, exclude " << attrs.name
-              << " which dim is " << param.dim;
-    node->attrs.op   = nullptr;
-    node->attrs.name = attrs.name;
-  }
-  node->attrs.dict = attrs.dict;
+  nnvm::ObjectPtr node = nnvm::Node::Create();
+  node->attrs.op       = Op::Get("_contrib_quantized_concat");
+  node->attrs.name     = "quantized_" + attrs.name;
+  node->attrs.dict     = attrs.dict;
   if (node->op() != nullptr && node->op()->attr_parser != nullptr) {
     node->op()->attr_parser(&(node->attrs));
   }
   return node;
 });
 
-NNVM_REGISTER_OP(_npi_concatenate)
-    .set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
-      const NumpyConcatenateParam& param = nnvm::get<NumpyConcatenateParam>(attrs.parsed);
-      nnvm::ObjectPtr node               = nnvm::Node::Create();
-      if (param.axis.has_value() && param.axis.value() > 0) {
-        node->attrs.op   = Op::Get("_contrib_quantized_concat");
-        node->attrs.name = "quantized_" + attrs.name;
-      } else {
-        LOG(INFO) << "Currently, quantized numpy concatenate only supports axis>0, exclude "
-                  << attrs.name << " which axis is " << param.axis;
-        node->attrs.op   = nullptr;
-        node->attrs.name = attrs.name;
-      }
-      node->attrs.dict        = attrs.dict;
-      node->attrs.dict["dim"] = node->attrs.dict["axis"];
-      node->attrs.dict.erase("axis");
-      if (node->op() != nullptr && node->op()->attr_parser != nullptr) {
-        node->op()->attr_parser(&(node->attrs));
-      }
-      return node;
-    });
-
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
index 1b96726..5db3bb0 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
@@ -669,13 +669,14 @@ void ConvertConcatenate(GraphProto* graph_proto,
   NodeProto* node_proto = graph_proto->add_node();
   node_proto->set_name(node_name);
   const auto& _param = nnvm::get<ConcatParam>(attrs.parsed);
+  const int param_dim = _param.dim.has_value() ? _param.dim.value() : 0;
   node_proto->set_op_type("Concat");
   node_proto->set_name(attrs.name);
   // axis
   AttributeProto* const axis = node_proto->add_attribute();
   axis->set_name("axis");
   axis->set_type(AttributeProto::INT);
-  axis->set_i(static_cast<int64_t>(_param.dim));
+  axis->set_i(static_cast<int64_t>(param_dim));
   DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
 }
 
diff --git a/src/operator/subgraph/tensorrt/tensorrt-inl.h b/src/operator/subgraph/tensorrt/tensorrt-inl.h
index 564401b..d142dc1 100644
--- a/src/operator/subgraph/tensorrt/tensorrt-inl.h
+++ b/src/operator/subgraph/tensorrt/tensorrt-inl.h
@@ -193,7 +193,8 @@ class TensorrtSelector : public SubgraphSelector {
 
     if (op_name == "Concat") {
       const auto& param = nnvm::get<ConcatParam>(n.attrs.parsed);
-      return (param.dim != 0);
+      const int param_dim = param.dim.has_value() ? param.dim.value() : 0;
+      return (param_dim != 0);
     }
 
     if (op_name == "Dropout") {
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 1572061..6a2e6ac 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -4039,7 +4039,7 @@ def test_np_concat():
 
     shapes = [(0, 0), (2, 3), (2, 1, 3)]
     hybridizes = [True, False]
-    axes = [0, 1, None]
+    axes = [0, 1, -1, None]
     grad_reqs = ['write', 'add', 'null']
     dtypes = [np.float32, np.float64, np.bool]
     combinations = itertools.product(shapes, hybridizes, axes, grad_reqs, dtypes)
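
The updated parameterization above adds a negative axis and axis=None to the existing test_np_concat cases. Assuming pytest and an MXNet build are available, the case can be run in isolation with:

    pytest tests/python/unittest/test_numpy_op.py -k test_np_concat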