You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/11 19:44:08 UTC

[GitHub] haojin2 commented on a change in pull request #11229: [MXNET-379] L1 Normalization

haojin2 commented on a change in pull request #11229: [MXNET-379] L1 Normalization
URL: https://github.com/apache/incubator-mxnet/pull/11229#discussion_r194524295
 
 

 ##########
 File path: src/operator/tensor/broadcast_reduce_op.h
 ##########
 @@ -880,30 +880,34 @@ inline bool L2NormStorageType(const nnvm::NodeAttrs& attrs,
   int& out_stype = out_attrs->at(0);
   const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
   bool dispatched = false;
-  // l2 norm on a particular axis only supports cpu
-  const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
-  const auto dispatch_ex =
-      invalid_ctx ? DispatchMode::kFComputeFallback : DispatchMode::kFComputeEx;
-  if (!dispatched && in_stype == kDefaultStorage) {
-    // dns -> dns
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
-                                     DispatchMode::kFCompute);
-  }
-  const TShape axis = param.axis.has_value() ? param.axis.value() : TShape();
-  if (!dispatched && (in_stype == kRowSparseStorage || in_stype == kCSRStorage) &&
-      axis.ndim() == 0 && param.ord == 2) {
-    // l2 norm: rsp/csr, axis = () -> dns
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
-                                     DispatchMode::kFComputeEx);
-  }
-  if (!dispatched && in_stype == kCSRStorage && axis.ndim() == 1 && !param.keepdims &&
-      (axis[0] == 0 || axis[0] == 1) && param.ord == 2) {
-    // l2 norm: csr, axis = 0/1 -> dns
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
-                                     dispatch_ex);
-  }
-  if (!dispatched) {
+  if (param.ord == 1) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+  } else if (param.ord == 2) {
+    // l2 norm on a particular axis only supports cpu
+    const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
+    const auto dispatch_ex =
+      invalid_ctx ? DispatchMode::kFComputeFallback : DispatchMode::kFComputeEx;
+    if (!dispatched && in_stype == kDefaultStorage) {
+      // dns -> dns
+      dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                       DispatchMode::kFCompute);
+    }
+    const TShape axis = param.axis.has_value() ? param.axis.value() : TShape();
+    if (!dispatched && (in_stype == kRowSparseStorage || in_stype == kCSRStorage) &&
+        axis.ndim() == 0 && param.ord == 2) {
+      // l2 norm: rsp/csr, axis = () -> dns
+      dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                       DispatchMode::kFComputeEx);
+    }
+    if (!dispatched && in_stype == kCSRStorage && axis.ndim() == 1 && !param.keepdims &&
+        (axis[0] == 0 || axis[0] == 1) && param.ord == 2) {
+      // l2 norm: csr, axis = 0/1 -> dns
+      dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                       dispatch_ex);
+    }
+    if (!dispatched) {
+      dispatched = dispatch_fallback(out_attrs, dispatch_mode);
+    }
 
 Review comment:
   You do not need the extra else branch for param.ord. Just do:
   ```c++
   if (param.ord == 2) {
        // original storage type assign logic
   }
   if (!dispatched) {
         dispatched = dispatch_fallback(out_attrs, dispatch_mode);
   }
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services