You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/11/06 20:03:48 UTC

[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #19067: Fix compilation for large tensor with MKL

Zha0q1 commented on a change in pull request #19067:
URL: https://github.com/apache/incubator-mxnet/pull/19067#discussion_r518975411



##########
File path: src/operator/tensor/la_op-inl.h
##########
@@ -931,15 +1020,20 @@ struct det_backward {
     if (dA.shape_.Size() == 0U) {
       return;
     }
-    // compute inverse(A) and stores it to LU
-    linalg_batch_det_backward_helper(LU, pivot, det, dA, DType(0), ctx);
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    convert_to_int_if_needed(s, pivot);
+    // Calculations on the GPU path are internally done on int type.
+    using IndexInternalT = typename LapackIndex<xpu>::IndexT;
+    linalg_batch_det_backward_helper(LU,
+                                     reinterpret_cast<const Tensor<xpu, 2, IndexInternalT>&>(pivot),
+                                     det, dA, DType(0), ctx);
     const_cast<Tensor<xpu, 3, DType>&>(dA) = broadcast_to(reshape(det * ddet, \
       Shape3(det.size(0), 1, 1)), mxnet::TShape(LU.shape_)) * \
       transpose(LU, Shape3(0, 2, 1));
-    Stream<xpu> *s = ctx.get_stream<xpu>();
     // stop grad for zero det temporarily
     Kernel<StopZeroDetGrad, xpu>::Launch(s, dA.shape_.Size(), dA.size(1) * dA.size(2), \
                                          dA.dptr_, det.dptr_, DType(0));
+    convert_to_int64_if_needed(s, pivot);

Review comment:
       I think the only output here is dA. If so, we might not need to convert `pivot` back to int64 afterwards.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org