Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/01 18:20:27 UTC

[GitHub] szha commented on a change in pull request #9931: Add axes support to Dropout for variational dropout in NLP

szha commented on a change in pull request #9931: Add axes support to Dropout for variational dropout in NLP
URL: https://github.com/apache/incubator-mxnet/pull/9931#discussion_r171647499
 
 

 ##########
 File path: src/operator/nn/dropout-inl.h
 ##########
 @@ -67,9 +70,92 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> {
     .add_enum("always", dropout::kAlways)
     .set_default(dropout::kTraining)
     .describe("Whether to only turn on dropout during training or to also turn on for inference.");
+    DMLC_DECLARE_FIELD(axes).set_default(TShape())
+    .describe("Axes on which the dropout mask is shared, for variational dropout.");
   }
 };  // struct DropoutParam
 
+namespace mxnet_op {
+template<int ndim, typename DType, typename OP>
+struct binary_broadcast_kernel {
+  /*! \brief Map function for binary_broadcast_kernel */
+  MSHADOW_XINLINE static void Map(int base, int length, OpReqType req,
+                                  const Shape<ndim> &lstride, const Shape<ndim> &rstride,
+                                  const Shape<ndim> &oshape, DType *lhs, DType *rhs,
+                                  DType *out) {
+    Shape<ndim> coord = unravel(base, oshape);
+    auto lidx = static_cast<index_t>(dot(coord, lstride));
+    auto ridx = static_cast<index_t>(dot(coord, rstride));
+    KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx]));
+    // starts from 1 to avoid extra inc at end of loop
+    for (int i = 1; i < length; ++i) {
+      inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
+      // inc() advances the coordinate and both input indices together,
+      // avoiding a fresh unravel/dot for every element
+      KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx]));
+    }
+  }
+};
+}  // namespace mxnet_op
+
+#define BROADCAST_NDIM_SWITCH(ndim, NDim, ...)  \
+  if (ndim <= 2) {                    \
+    const int NDim = 2;               \
+    {__VA_ARGS__}                     \
+  } else if (ndim <= 4) {             \
+    const int NDim = 4;               \
+    {__VA_ARGS__}                     \
+  } else if (ndim <= MAX_DIM) {       \
+    const int NDim = MAX_DIM;         \
+    {__VA_ARGS__}                     \
+  } else {                            \
+    LOG(FATAL) << "NDim too large";   \
+  }
+
+inline int BinaryBroadcastShapeCompact(const TShape& lshape, const TShape& rshape,
 
 Review comment:
   These look like a copy from src/operator/tensor/elemwise_binary_broadcast_op.h. Can we avoid copying the code?
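
  One way to avoid the duplication, sketched here on the assumption that the existing header can be included from src/operator/nn/ without circular-dependency trouble:

      // In src/operator/nn/dropout-inl.h, instead of re-declaring the kernel:
      #include "../tensor/elemwise_binary_broadcast_op.h"
      // then refer to the existing mxnet_op::binary_broadcast_kernel and
      // BROADCAST_NDIM_SWITCH directly.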
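
  For readers following along: the copied binary_broadcast_kernel turns each flat output index into a coordinate (unravel), then maps that coordinate into each input with a dot product against per-tensor strides, where a broadcast (size-1) axis carries stride 0 so it always reads element 0. Below is a minimal standalone sketch of that indexing trick without mshadow; it re-unravels every element instead of using inc, purely for brevity:

      #include <cstdio>
      #include <vector>

      // Row-major strides, with broadcast (size-1) axes zeroed so that any
      // coordinate along them maps back to offset 0 in that tensor.
      std::vector<int> BroadcastStrides(const std::vector<int>& shape) {
        std::vector<int> strides(shape.size());
        int step = 1;
        for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
          strides[i] = (shape[i] == 1) ? 0 : step;
          step *= shape[i];
        }
        return strides;
      }

      int main() {
        // 2x3 data times a 1x3 mask row, reused for both rows -- the same
        // pattern a dropout mask shared along axis 0 would produce.
        std::vector<int> oshape = {2, 3};
        std::vector<float> lhs = {1, 2, 3, 4, 5, 6};
        std::vector<float> rhs = {0, 1, 1};
        std::vector<int> lstride = BroadcastStrides({2, 3});
        std::vector<int> rstride = BroadcastStrides({1, 3});
        std::vector<float> out(6);
        for (int base = 0; base < 6; ++base) {
          // unravel base into a coordinate, dotting with each stride vector
          int rem = base, lidx = 0, ridx = 0;
          for (int ax = static_cast<int>(oshape.size()) - 1; ax >= 0; --ax) {
            int coord = rem % oshape[ax];
            rem /= oshape[ax];
            lidx += coord * lstride[ax];
            ridx += coord * rstride[ax];
          }
          out[base] = lhs[lidx] * rhs[ridx];  // OP::Map is mul for masking
        }
        for (float v : out) std::printf("%g ", v);  // 0 2 3 0 5 6
        std::printf("\n");
        return 0;
      }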
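
  BROADCAST_NDIM_SWITCH itself is the usual runtime-to-compile-time dispatch: it rounds the runtime ndim up to one of a few fixed ranks so only a handful of Shape<NDim> template instantiations are needed. A sketch of the same pattern (the MAX_DIM = 5 cap is an assumption, matching MXNet's broadcast code):

      #include <cstdio>

      const int MAX_DIM = 5;  // assumed cap, as in MXNet's broadcast code

      // One instantiation per supported compile-time rank; the real code
      // launches a Shape<NDim>-templated kernel here instead of printing.
      template <int NDim>
      void LaunchBroadcast(int ndim) {
        std::printf("ndim %d handled by the NDim=%d instantiation\n", ndim, NDim);
      }

      void Dispatch(int ndim) {
        if (ndim <= 2) {
          LaunchBroadcast<2>(ndim);        // 1-D and 2-D share one instance
        } else if (ndim <= 4) {
          LaunchBroadcast<4>(ndim);        // 3-D and 4-D are padded up to 4
        } else if (ndim <= MAX_DIM) {
          LaunchBroadcast<MAX_DIM>(ndim);
        } else {
          std::printf("NDim too large\n");  // stand-in for LOG(FATAL)
        }
      }

      int main() {
        Dispatch(1);
        Dispatch(3);
        Dispatch(5);
        return 0;
      }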
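
  Finally, on the new axes parameter: a minimal sketch of the assumed semantics, in which the dropout mask keeps the input's shape except that every axis listed in axes is set to 1, so one mask is broadcast along those axes. The MaskShape helper is hypothetical, not part of the patch:

      #include <cstdio>
      #include <vector>

      // Hypothetical helper: the mask shape given the input shape and the
      // axes along which the mask is shared (assumed semantics of the PR).
      std::vector<int> MaskShape(std::vector<int> ishape,
                                 const std::vector<int>& axes) {
        for (int ax : axes) ishape[ax] = 1;  // one mask entry spans this axis
        return ishape;
      }

      int main() {
        // A (seq_len, batch, hidden) input with axes = {0}: the same mask is
        // reused at every time step -- the variational dropout pattern in NLP.
        std::vector<int> m = MaskShape({20, 32, 100}, {0});
        std::printf("%d %d %d\n", m[0], m[1], m[2]);  // prints: 1 32 100
        return 0;
      }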

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services