You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/02/11 22:59:51 UTC

[GitHub] larroy commented on a change in pull request #14098: softmax for fp16 with fp32 accumulator

larroy commented on a change in pull request #14098: softmax for fp16 with fp32 accumulator
URL: https://github.com/apache/incubator-mxnet/pull/14098#discussion_r255730106
 
 

 ##########
 File path: src/operator/nn/softmax-inl.h
 ##########
 @@ -275,14 +292,70 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
 struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
   int axis;
   dmlc::optional<double> temperature;
+  dmlc::optional<int> dtype;
   DMLC_DECLARE_PARAMETER(SoftmaxParam) {
     DMLC_DECLARE_FIELD(axis).set_default(-1)
-      .describe("The axis along which to compute softmax.");
+    .describe("The axis along which to compute softmax.");
     DMLC_DECLARE_FIELD(temperature).set_default(dmlc::optional<double>())
-      .describe("Temperature parameter in softmax");
+    .describe("Temperature parameter in softmax");
+    DMLC_DECLARE_FIELD(dtype)
+    .add_enum("float16", mshadow::kFloat16)
+    .add_enum("float32", mshadow::kFloat32)
+    .add_enum("float64", mshadow::kFloat64)
+    .set_default(dmlc::optional<int>())
+    .describe("DType of the output in case this can't be inferred. "
+              "Defaults to the same as input's dtype if not defined (dtype=None).");
   }
 };
 
+inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
+                          std::vector<int>* in_attrs,
+                          std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1);
+  CHECK_EQ(out_attrs->size(), 1);
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+
+  int arg_dtype = param.dtype.has_value()?param.dtype.value():-1,
 
 Review comment:
  Can we add spaces around the ternary operator (`a ? b : c`) for readability?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services