Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/14 21:57:57 UTC

[GitHub] eric-haibin-lin commented on a change in pull request #10104: [WIP] Fused RNN implementation for CPU

eric-haibin-lin commented on a change in pull request #10104: [WIP] Fused RNN implementation for CPU
URL: https://github.com/apache/incubator-mxnet/pull/10104#discussion_r174622615
 
 

 ##########
 File path: src/operator/rnn.cc
 ##########
 @@ -21,38 +21,212 @@
  * Copyright (c) 2015 by Contributors
  * \file rnn.cc
  * \brief
- * \author Sebastian Bodenstein
+ * \author Sebastian Bodenstein, Shu Zhang (shu.zhang@intel.com)
 */
-
 #include "./rnn-inl.h"
 
 namespace mxnet {
 namespace op {
-template<>
-Operator *CreateOp<cpu>(RNNParam param, int dtype) {
-  Operator *op = NULL;
-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    op = new RNNOp<cpu, DType>(param);
-  });
-  return op;
+
+DMLC_REGISTER_PARAMETER(RNNParam);
+static inline std::vector<std::string> ListArguments(const RNNParam& param_) {
+  if (param_.mode == rnn_enum::kLstm) {
+    return {"data", "parameters", "state", "state_cell"};
+  } else {
+    return {"data", "parameters", "state"};
+  }
+}
+static inline int NumVisibleOutputs(const NodeAttrs& attrs) {
+  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+  int mode_num = (params.mode == rnn_enum::kLstm) ? 2 : 1;
+  int num_outputs = params.state_outputs ? (mode_num + 1) : 1;
+  return num_outputs;
 }
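
A minimal standalone sketch of the output-count rule above (names here are illustrative; the real values come from RNNParam). The reserve-space output is hidden, so the user sees out, plus state_out, plus statecell_out for LSTM, only when requested:

    #include <cassert>

    // Mirrors NumVisibleOutputs: LSTM carries a cell state, so it has one
    // more visible state output than the other modes.
    static int num_visible_outputs(bool is_lstm, bool state_outputs) {
      int mode_num = is_lstm ? 2 : 1;
      return state_outputs ? (mode_num + 1) : 1;
    }

    int main() {
      assert(num_visible_outputs(true, true) == 3);   // LSTM with state outputs
      assert(num_visible_outputs(false, true) == 2);  // RNN/GRU with state outputs
      assert(num_visible_outputs(true, false) == 1);  // output only
      return 0;
    }
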
+static bool RNNShape(const nnvm::NodeAttrs& attrs,
+                     std::vector<TShape> *in_shape,
+                     std::vector<TShape> *out_shape) {
+  const RNNParam& param_ = nnvm::get<RNNParam>(attrs.parsed);
+  using namespace mshadow;
+  if (param_.mode == rnn_enum::kLstm) {
+    CHECK_EQ(in_shape->size(), 4U) << "Input:[data, parameters, state, state_cell]";
+  } else {
+    CHECK_EQ(in_shape->size(), 3U) << "Input:[data, parameters, state]";
+  }
+  const TShape &dshape = (*in_shape)[rnn_enum::kData];
+  if (dshape.ndim() == 0) return false;
+  CHECK_EQ(dshape.ndim(), 3U)
+      << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]";
+  // data: [sequence len, batch, input dimension]
+  int batch_size = dshape[1];
+  int input_size = dshape[2];
+  int numDirections = param_.bidirectional ? 2 : 1;
+  int total_layers = numDirections * param_.num_layers;  // double for bidirectional
+  SHAPE_ASSIGN_CHECK(*in_shape,
+                     rnn_enum::kState,
+                     Shape3(total_layers, batch_size, param_.state_size));
+  if (param_.mode == rnn_enum::kLstm)
+    SHAPE_ASSIGN_CHECK(*in_shape,
+                      rnn_enum::kStateCell,
+                      Shape3(total_layers, batch_size, param_.state_size));
 
-Operator *RNNProp::CreateOperatorEx(Context ctx,
-                                  std::vector<TShape> *in_shape,
-                                  std::vector<int> *in_type) const {
-  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
+  // calculate parameter vector length
+  int param_size = rnn_param_size(param_.num_layers,
+                                  input_size,
+                                  param_.state_size,
+                                  param_.bidirectional,
+                                  param_.mode);
+  SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
+
+  out_shape->clear();
+  // output: [sequence len, batch, output size]
+  TShape oshape = dshape;
+  oshape[2] = numDirections * param_.state_size;
+  out_shape->push_back(oshape);
+  if (param_.state_outputs) {
+    // outStateShape: [layer_num, batch, state size]
+    TShape outStateShape = dshape;
+    outStateShape[0] = total_layers;
+    outStateShape[1] = batch_size;
+    outStateShape[2] = param_.state_size;
+    out_shape->push_back(outStateShape);
+    // Deal with lstm cell state
+    if (param_.mode == rnn_enum::kLstm)
+      out_shape->push_back(outStateShape);
+  }
+  // the reserve space shape
+  TShape outReserveShape = (*in_shape)[rnn_enum::kParams];
+  outReserveShape[0] = GetRNNReserveSpaceSize(dshape[0],
+                                              batch_size,
+                                              param_.state_size,
+                                              param_.mode);
+  out_shape->push_back(outReserveShape);
+  return true;
 }
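
To make the shape arithmetic concrete, a small sketch with illustrative numbers for a 2-layer bidirectional net with state_size = 50 and data of shape [10, 32, 100]:

    #include <cassert>

    int main() {
      // data: [sequence length, batch, input size] = [10, 32, 100]
      int state_size = 50, num_layers = 2;
      bool bidirectional = true;
      int num_directions = bidirectional ? 2 : 1;
      int total_layers = num_directions * num_layers;
      assert(total_layers == 4);            // state and cell state: [4, 32, 50]
      // output keeps [seq, batch] and widens the last axis per direction
      int output_size = num_directions * state_size;
      assert(output_size == 100);           // output: [10, 32, 100]
      return 0;
    }
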
 
-DMLC_REGISTER_PARAMETER(RNNParam);
+static bool RNNType(const nnvm::NodeAttrs& attrs,
+                    std::vector<int> *in_type,
+                    std::vector<int> *out_type) {
+  const RNNParam& param_ = nnvm::get<RNNParam>(attrs.parsed);
+  CHECK_GE(in_type->size(), 1U);
+  int dtype = (*in_type)[0];
+  CHECK_NE(dtype, -1) << "First input must have specified type";
+  for (index_t i = 0; i < in_type->size(); ++i) {
+    if ((*in_type)[i] == -1) {
+      (*in_type)[i] = dtype;
+    } else {
+      UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
+    }
+  }
+  out_type->clear();
+  out_type->push_back(dtype);
+  if (param_.state_outputs) {
+    out_type->push_back(dtype);
+    // Deal with lstm cell state
+    if (param_.mode == rnn_enum::kLstm)
+      out_type->push_back(dtype);
+  }
+  out_type->push_back(dtype);
+  return true;
+}
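
The rule above is the usual dtype unification: data's dtype is the reference, any input still set to -1 inherits it, and every output (including the hidden reserve-space output pushed last) gets the same dtype. A small sketch, assuming mshadow's kFloat32 == 0:

    #include <cassert>
    #include <vector>

    int main() {
      std::vector<int> in_type = {0, -1, -1, -1};  // data's dtype fixed, rest unknown
      int dtype = in_type[0];
      for (int &t : in_type)
        if (t == -1) t = dtype;                    // inherit from data
      // LSTM with state_outputs: {out, state_out, statecell_out, reserve}
      std::vector<int> out_type(4, dtype);
      assert(in_type[3] == 0 && out_type.back() == 0);
      return 0;
    }
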
 
-MXNET_REGISTER_OP_PROPERTY(RNN, RNNProp)
-.describe("Applies a recurrent layer to input.")
+inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs,
+                                  const int dev_mask,
+                                  DispatchMode* dispatch_mode,
+                                  std::vector<int> *in_attrs,
+                                  std::vector<int> *out_attrs) {
+  DispatchMode wanted_mode = DispatchMode::kFCompute;
+  return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                             dispatch_mode, wanted_mode);
+}
+
+inline static bool BackwardRNNStorageType(const nnvm::NodeAttrs& attrs,
+                                          const int dev_mask,
+                                          DispatchMode* dispatch_mode,
+                                          std::vector<int> *in_attrs,
+                                          std::vector<int> *out_attrs) {
+  DispatchMode wanted_mode = DispatchMode::kFCompute;
+  return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                             dispatch_mode, wanted_mode);
+}
+
+struct RNNGrad {
+  const char *op_name;
+  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr &n,
+          const std::vector<nnvm::NodeEntry> &ograd) const {
+    const RNNParam& params = nnvm::get<RNNParam>(n->attrs.parsed);
+    std::vector<nnvm::NodeEntry> heads{ n->inputs[rnn_enum::kData],
+      n->inputs[rnn_enum::kParams], n->inputs[rnn_enum::kState] };
+    heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kOut, 0});
+    heads.push_back(ograd[rnn_enum::kOut]);
+    // index of the forward output that holds the reserved intermediate results
+    uint32_t kTmpSpaceIdx = rnn_enum::kOut + 1;
+    if (params.state_outputs) {
+      heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kStateOut, 0});
+      heads.push_back(ograd[rnn_enum::kStateOut]);
+      ++kTmpSpaceIdx;
+    }
+    if (params.mode == rnn_enum::kLstm) {
+      heads.push_back(n->inputs[rnn_enum::kStateCell]);
+      if (params.state_outputs) {
+        heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kStateCellOut, 0});
+        heads.push_back(ograd[rnn_enum::kStateCellOut]);
+        ++kTmpSpaceIdx;
+      }
+    }
+    heads.emplace_back(nnvm::NodeEntry{n, kTmpSpaceIdx, 0});
+    return MakeGradNode(op_name, n, heads, n->attrs.dict);
+  }
+};
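
The positional layout of heads is what the backward op indexes into, so the bookkeeping around kTmpSpaceIdx matters. Assuming the output enum in rnn-inl.h is ordered kOut, kStateOut, kStateCellOut, an LSTM with state_outputs enabled builds:

    heads[0]  = data            (forward input)
    heads[1]  = parameters      (forward input)
    heads[2]  = state           (forward input)
    heads[3]  = out             (forward output 0)
    heads[4]  = grad of out
    heads[5]  = state_out       (forward output 1)
    heads[6]  = grad of state_out
    heads[7]  = state_cell      (forward input)
    heads[8]  = statecell_out   (forward output 2)
    heads[9]  = grad of statecell_out
    heads[10] = reserve space   (forward output kTmpSpaceIdx == 3)
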
+
+NNVM_REGISTER_OP(RNN)
+.describe(R"code(Applies a recurrent layer to input
 
 Review comment:
   Please provide a more detailed description.
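
For reference, a fuller description along these lines would likely address the comment. The wording below is an illustrative sketch, not text from the PR:

    .describe(R"code(Applies recurrent layers to input data.

    Supported modes: vanilla RNN with ReLU or tanh activation (``rnn_relu``,
    ``rnn_tanh``), LSTM (``lstm``), and GRU (``gru``). For ``rnn_tanh``, the
    hidden state at step t is

      h_t = tanh(W_ih * x_t + b_ih + W_hh * h_(t-1) + b_hh)

    ``data`` has shape [sequence_length, batch_size, input_size]; the output
    has shape [sequence_length, batch_size, num_directions * state_size].
    )code" ADD_FILELINE)
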
