Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/10/12 22:52:00 UTC

[GitHub] [incubator-mxnet] Zha0q1 opened a new pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Zha0q1 opened a new pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339


   Fixes https://github.com/apache/incubator-mxnet/issues/11061.
   
   This PR rewrites the `mxnet.np.argmax` operator so that it reuses our axes reduce kernels on both CPU and GPU; previously it used an mshadow expression implementation. It should work for all data types. The same rewrite can be applied to `np.argmin`, but for now I am keeping the old `argmin` implementation so that we can compare run times. The basic idea is to create a new struct, `IndexedNum`, that stores both the value and the index of each element in the tensor, and to use a custom reducer, `mshadow_op::argmax`, that updates both `num` and `idx`.
   
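   ```C++
   template <typename IType, typename DType>
   struct IndexedNum {
     IType idx;  // position of the element along the reduced axis
     DType num;  // value of the element
   
     // No-op: only provided so IndexedNum compiles in generic reduce code that uses +=
     IndexedNum& operator+=(const IndexedNum& rhs) {
       return *this;
     }
   };
   ```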
   
   ```C++
   /*! \brief arg max reducer */
   struct argmax {
     /*! \brief do reduction into dst */
     template<typename AType, typename DType>
     MSHADOW_XINLINE static void Reduce(volatile AType& dst, volatile DType src) { // NOLINT(*)
       // strict '<': on ties the pair already in dst is kept
       if (dst.num < src.num) {
         dst.num = src.num;
         dst.idx = src.idx;
       }
     }
     /*! \brief combine the results of two reducers */
     template<typename DType>
     MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& src_val) { // NOLINT(*)
       Reduce(dst_val, src_val);
     }
     /*!
      *\brief set the initial value during reduction
      */
     template<typename DType>
     MSHADOW_XINLINE static void SetInitValue(DType &initv) { // NOLINT(*)
       // lowest value of the element type (-inf for floating point)
       initv.num = mshadow::red::limits::NegInfValue<decltype(initv.num)>();
       initv.idx = 0;
     }
   };
   ```
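A minimal pure-Python model of the `IndexedNum` pair and the `argmax` reducer may help show how `SetInitValue`, `Reduce` and `Merge` fit together. This is a sketch for illustration only, not the C++ implementation: the names `ArgMax` and `argmax_1d` are hypothetical, and the methods return a new pair instead of updating `dst` in place.

```python
import math
from dataclasses import dataclass
from functools import reduce


@dataclass
class IndexedNum:
    """Python stand-in for IndexedNum<IType, DType>: an (index, value) pair."""
    idx: int = 0
    num: float = 0.0


class ArgMax:
    """Models the three static methods of the argmax reducer."""

    @staticmethod
    def set_init_value():
        # Like SetInitValue: start from -inf so any real value replaces it
        return IndexedNum(idx=0, num=-math.inf)

    @staticmethod
    def reduce(dst, src):
        # Like Reduce: strict '<' keeps dst when the values tie
        return src if dst.num < src.num else dst

    @staticmethod
    def merge(dst, src):
        # Like Merge: combining two partial results uses the same rule
        return ArgMax.reduce(dst, src)


def argmax_1d(values):
    """Pair each element with its position, then fold with ArgMax.reduce."""
    pairs = (IndexedNum(idx=k, num=v) for k, v in enumerate(values))
    return reduce(ArgMax.reduce, pairs, ArgMax.set_init_value()).idx


print(argmax_1d([3, 7, 7, 1]))  # 1: the first of the two 7s wins
```

Because ties keep `dst`, a sequential left-to-right scan returns the first maximal index; when partial results are merged, this only holds if the earlier partial is passed as `dst`.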
   
   This brings a considerable speedup on GPU when the reduced axis is large: about 60x for dim=30000, about 130x for dim=300000, and about 180x for dim=3000000. When the reduced axis is small (e.g. 64), the run time is about 1.8x longer because of the added memory overhead. This is probably acceptable, since the absolute run time in such cases is very small.
   
   On CPU we see roughly the opposite of GPU: the run time improves significantly when the reduced dim (M) is small and the number of output elements, i.e. the size of the non-reduced dims (N), is large. This is probably because the CPU kernel parallelizes over N but not over M. Should we parallelize over M instead? There should be room to optimize the CPU kernel, and that could be an impactful PR by itself, since many ops (e.g. sum) use this reduce kernel. (ref: https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/broadcast_reduce-inl.h#L270 and https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/broadcast_reduce-inl.h#L315)
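The two ways of splitting the work can be sketched in Python. The helper names below are hypothetical, and because of CPython's GIL the threads only model how the work is decomposed; they are not expected to run faster. Splitting M needs a merge step that combines partial `(index, value)` results in chunk order, which is the role `Merge` plays in the reducer above.

```python
from concurrent.futures import ThreadPoolExecutor


def chunk_argmax(row, start, stop):
    """Sequential argmax over row[start:stop]; returns (index, value)."""
    best = start
    for k in range(start + 1, min(stop, len(row))):
        if row[best] < row[k]:
            best = k
    return best, row[best]


def merge_argmax(a, b):
    """Combine two partials; `a` covers earlier indices, so ties keep the first index."""
    return b if a[1] < b[1] else a


def argmax_parallel_over_n(matrix, workers=4):
    """One task per output row: parallelism is limited to N = len(matrix)."""
    with ThreadPoolExecutor(max_workers=workers) as pool:
        return [idx for idx, _ in pool.map(lambda r: chunk_argmax(r, 0, len(r)), matrix)]


def argmax_split_m(row, chunks=4, workers=4):
    """Split the reduced axis M into chunks, reduce them concurrently, merge in order."""
    size = max(1, -(-len(row) // chunks))  # ceiling division
    starts = range(0, len(row), size)
    with ThreadPoolExecutor(max_workers=workers) as pool:
        partials = list(pool.map(lambda s: chunk_argmax(row, s, s + size), starts))
    result = partials[0]
    for part in partials[1:]:
        result = merge_argmax(result, part)
    return result[0]


print(argmax_split_m([5, 1, 9, 3, 9, 0, 2], chunks=3))  # 2
```

With shape (64, 300000) and axis=1, the first scheme has only 64 tasks, each scanning 300000 values, whereas the second can spread a single long row over several workers.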
   
   I used this script to time the runs:
   ```python
   import itertools
   import time
   
   import mxnet as mx
   
   
   def test(shape, xpu, axis, dtype, op, warmup=5, runs=10):
       data = mx.np.random.normal(-1, 1, shape, dtype=dtype, ctx=xpu)
       # Warm-up iterations are not timed
       for _ in range(warmup):
           op(data, axis=axis)
           mx.nd.waitall()
       begin = time.time()
       for _ in range(runs):
           op(data, axis=axis)
           # MXNet runs ops asynchronously; block until this call has finished
           mx.nd.waitall()
       avg = (time.time() - begin) / runs
       print(shape, xpu, "axis:" + str(axis), dtype, op)
       print("avg time %f" % avg)
   
   
   shapes = [(64, 300000)]
   axes = [1, 0]
   xpus = [mx.gpu(), mx.cpu()]
   ops = [mx.np.argmin, mx.np.argmax]
   dtypes = ['float32', 'float64', 'float16', 'int32', 'int64', 'int8']
   
   for s, x, a, d, o in itertools.product(shapes, xpus, axes, dtypes, ops):
       test(s, x, a, d, o)
   # test((64, 300000), mx.gpu(), axis=1, dtype='float32', op=mx.np.argmax)
   ```
   Results:
   ```
   ubuntu@ip-172-31-38-169:~/incubator-mxnet/build$ python ../sxj.py 
   [22:45:40] ../src/storage/storage.cc:199: Using Pooled (Naive) StorageManager for GPU
   (64, 300000) gpu(0) axis:1 float32 <function argmin at 0x7f098421e320>
   avg time 0.367691
   (64, 300000) gpu(0) axis:1 float32 <function argmax at 0x7f098421e290>
   avg time 0.002094
   (64, 300000) gpu(0) axis:1 float64 <function argmin at 0x7f098421e320>
   avg time 0.373724
   (64, 300000) gpu(0) axis:1 float64 <function argmax at 0x7f098421e290>
   avg time 0.006975
   (64, 300000) gpu(0) axis:1 float16 <function argmin at 0x7f098421e320>
   avg time 0.347810
   (64, 300000) gpu(0) axis:1 float16 <function argmax at 0x7f098421e290>
   avg time 0.001983
   (64, 300000) gpu(0) axis:1 int32 <function argmin at 0x7f098421e320>
   avg time 0.324442
   (64, 300000) gpu(0) axis:1 int32 <function argmax at 0x7f098421e290>
   avg time 0.001897
   (64, 300000) gpu(0) axis:1 int64 <function argmin at 0x7f098421e320>
   avg time 0.354146
   (64, 300000) gpu(0) axis:1 int64 <function argmax at 0x7f098421e290>
   avg time 0.005937
   (64, 300000) gpu(0) axis:1 int8 <function argmin at 0x7f098421e320>
   avg time 0.326005
   (64, 300000) gpu(0) axis:1 int8 <function argmax at 0x7f098421e290>
   avg time 0.002361
   (64, 300000) gpu(0) axis:0 float32 <function argmin at 0x7f098421e320>
   avg time 0.001255
   (64, 300000) gpu(0) axis:0 float32 <function argmax at 0x7f098421e290>
   avg time 0.002032
   (64, 300000) gpu(0) axis:0 float64 <function argmin at 0x7f098421e320>
   avg time 0.001961
   (64, 300000) gpu(0) axis:0 float64 <function argmax at 0x7f098421e290>
   avg time 0.007079
   (64, 300000) gpu(0) axis:0 float16 <function argmin at 0x7f098421e320>
   avg time 0.001491
   (64, 300000) gpu(0) axis:0 float16 <function argmax at 0x7f098421e290>
   avg time 0.001987
   (64, 300000) gpu(0) axis:0 int32 <function argmin at 0x7f098421e320>
   avg time 0.001174
   (64, 300000) gpu(0) axis:0 int32 <function argmax at 0x7f098421e290>
   avg time 0.001968
   (64, 300000) gpu(0) axis:0 int64 <function argmin at 0x7f098421e320>
   avg time 0.001406
   (64, 300000) gpu(0) axis:0 int64 <function argmax at 0x7f098421e290>
   avg time 0.005828
   (64, 300000) gpu(0) axis:0 int8 <function argmin at 0x7f098421e320>
   avg time 0.001173
   (64, 300000) gpu(0) axis:0 int8 <function argmax at 0x7f098421e290>
   avg time 0.002448
   [22:46:23] ../src/storage/storage.cc:199: Using Pooled (Naive) StorageManager for CPU
   (64, 300000) cpu(0) axis:1 float32 <function argmin at 0x7f098421e320>
   avg time 0.036137
   (64, 300000) cpu(0) axis:1 float32 <function argmax at 0x7f098421e290>
   avg time 0.076985
   (64, 300000) cpu(0) axis:1 float64 <function argmin at 0x7f098421e320>
   avg time 0.036389
   (64, 300000) cpu(0) axis:1 float64 <function argmax at 0x7f098421e290>
   avg time 0.075778
   (64, 300000) cpu(0) axis:1 float16 <function argmin at 0x7f098421e320>
   avg time 0.072825
   (64, 300000) cpu(0) axis:1 float16 <function argmax at 0x7f098421e290>
   avg time 0.117842
   (64, 300000) cpu(0) axis:1 int32 <function argmin at 0x7f098421e320>
   avg time 0.032637
   (64, 300000) cpu(0) axis:1 int32 <function argmax at 0x7f098421e290>
   avg time 0.075707
   (64, 300000) cpu(0) axis:1 int64 <function argmin at 0x7f098421e320>
   avg time 0.033149
   (64, 300000) cpu(0) axis:1 int64 <function argmax at 0x7f098421e290>
   avg time 0.075667
   (64, 300000) cpu(0) axis:1 int8 <function argmin at 0x7f098421e320>
   avg time 0.032188
   (64, 300000) cpu(0) axis:1 int8 <function argmax at 0x7f098421e290>
   avg time 0.074256
   (64, 300000) cpu(0) axis:0 float32 <function argmin at 0x7f098421e320>
   avg time 0.712089
   (64, 300000) cpu(0) axis:0 float32 <function argmax at 0x7f098421e290>
   avg time 0.082273
   (64, 300000) cpu(0) axis:0 float64 <function argmin at 0x7f098421e320>
   avg time 0.802298
   (64, 300000) cpu(0) axis:0 float64 <function argmax at 0x7f098421e290>
   avg time 0.087749
   (64, 300000) cpu(0) axis:0 float16 <function argmin at 0x7f098421e320>
   avg time 1.252073
   (64, 300000) cpu(0) axis:0 float16 <function argmax at 0x7f098421e290>
   avg time 0.118619
   (64, 300000) cpu(0) axis:0 int32 <function argmin at 0x7f098421e320>
   avg time 0.627230
   (64, 300000) cpu(0) axis:0 int32 <function argmax at 0x7f098421e290>
   avg time 0.078799
   (64, 300000) cpu(0) axis:0 int64 <function argmin at 0x7f098421e320>
   avg time 0.722380
   (64, 300000) cpu(0) axis:0 int64 <function argmax at 0x7f098421e290>
   avg time 0.085350
   (64, 300000) cpu(0) axis:0 int8 <function argmin at 0x7f098421e320>
   avg time 0.547371
   (64, 300000) cpu(0) axis:0 int8 <function argmax at 0x7f098421e290>
   avg time 0.071751
   ```
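The per-dtype GPU speedup for the reduced dim of 300000 can be worked out from the gpu(0), axis:1 rows above. Since `argmin` still uses the old mshadow kernel, this assumes that the old `argmin` and the old `argmax` would take the same time; the ratio is an estimate, not a direct old-vs-new `argmax` measurement.

```python
from statistics import mean

# Average times (seconds) copied from the gpu(0), axis:1 rows of the log above:
# (argmin with the old mshadow kernel, argmax with the new reduce kernel)
gpu_axis1 = {
    "float32": (0.367691, 0.002094),
    "float64": (0.373724, 0.006975),
    "float16": (0.347810, 0.001983),
    "int32": (0.324442, 0.001897),
    "int64": (0.354146, 0.005937),
    "int8": (0.326005, 0.002361),
}

# Speedup = old time / new time
speedups = {dtype: old / new for dtype, (old, new) in gpu_axis1.items()}
mean_speedup = mean(speedups.values())

for dtype, ratio in speedups.items():
    print(f"{dtype:>8}: {ratio:6.1f}x")
print(f"{'mean':>8}: {mean_speedup:6.1f}x")
```

The ratio ranges from roughly 54x (float64) to roughly 176x (float32), and the mean over the six dtypes is close to the ~130x quoted for dim=300000.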
   
   @sandeep-krishnamurthy @sxjscience @szha 


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19339: Numpy Argmax Rewrite

mxnet-bot commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-710665224


   Jenkins CI successfully triggered : [miscellaneous, unix-gpu, clang]



[GitHub] [incubator-mxnet] sandeep-krishnamurthy commented on pull request #19339: Numpy Argmax Rewrite

sandeep-krishnamurthy commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-710648900


   @ptrendx :)
   



[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19339: Numpy Argmax Rewrite

ptrendx commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r508730840



##########
File path: src/operator/numpy/np_broadcast_reduce_op.cuh
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015-2020 by Contributors
+ * \file np_broadcast_reduce-inl.cuh
+ * \brief GPU implementations for numpy binary broadcast ops
+ * \author Zhaoqi Zhu
+*/
+#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
+#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
+
+using namespace mshadow::cuda;
+using namespace mshadow;
+using namespace broadcast;
+
+#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \
+  if (do_unroll) {                                                    \
+    const int unrollVar = unrollAmount;                               \
+    {__VA_ARGS__}                                                     \
+  } else {                                                            \
+    const int unrollVar = 1;                                          \
+    {__VA_ARGS__}                                                     \
+  }
+
+template<typename Reducer, int NDim, typename DType, typename OType>
+void NumpyArgMinMaxReduce(Stream<gpu> *s, const TBlob& in_data, const TBlob& out_data,

Review comment:
       Yup, that would be better - reuse is pretty much always better than duplication ;-).





[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

mxnet-bot commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-708095169


   Jenkins CI successfully triggered : [windows-gpu]



[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

mxnet-bot commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-707384103


   Hey @Zha0q1, thanks for submitting the PR.
   All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands: 
   - To trigger all jobs: @mxnet-bot run ci [all] 
   - To trigger specific jobs: @mxnet-bot run ci [job1, job2] 
   *** 
   **CI supported jobs**: [centos-cpu, clang, unix-gpu, sanity, windows-cpu, windows-gpu, miscellaneous, edge, website, centos-gpu, unix-cpu]
   *** 
   _Note_: 
    Only the following 3 categories can trigger CI: PR Author, MXNet Committer, Jenkins Admin.
   All CI tests must pass before the PR can be merged. 
   



[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-707391025


   > Great.
   > Apart from the individual argmax kernel, can you talk about performance (+ve or -ve) on end to end customer use case?
   > Where is argmax used commonly?
   > Also, this will help make a judgement call on whether a 1.8x increase (i.e. an 80% regression on existing behavior) on GPU with small dims (<64) is acceptable.
   
   @sxjscience Would you suggest a model to test this operator on? 



[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #19339: Numpy Argmax Rewrite

Zha0q1 commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r508185614



##########
File path: src/operator/numpy/np_broadcast_reduce_op.cuh
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015-2020 by Contributors
+ * \file np_broadcast_reduce-inl.cuh
+ * \brief GPU implementations for numpy binary broadcast ops
+ * \author Zhaoqi Zhu
+*/
+#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
+#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
+
+using namespace mshadow::cuda;
+using namespace mshadow;
+using namespace broadcast;
+
+#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \
+  if (do_unroll) {                                                    \
+    const int unrollVar = unrollAmount;                               \
+    {__VA_ARGS__}                                                     \
+  } else {                                                            \
+    const int unrollVar = 1;                                          \
+    {__VA_ARGS__}                                                     \
+  }
+
+template<typename Reducer, int NDim, typename DType, typename OType>
+void NumpyArgMinMaxReduce(Stream<gpu> *s, const TBlob& in_data, const TBlob& out_data,

Review comment:
       The existing `Reduce` uses `MSHADOW_TYPE_SWITCH`, which only supports the mshadow data types. Since I am using a custom struct as the accumulation type and output type, I unrolled that switch into this custom reduce function.

##########
File path: src/operator/mshadow_op.h
##########
@@ -128,6 +128,27 @@ using std::is_integral;
 
 MXNET_UNARY_MATH_OP_NC(identity, a);
 
+template <typename IType, typename DType>
+struct IndexedNum {
+  IType idx;
+  DType num;
+
+  MSHADOW_XINLINE IndexedNum() : idx(0), num(0) {}
+
+  MSHADOW_XINLINE IndexedNum(DType n) : idx(0), num(n) {}
+
+  MSHADOW_XINLINE IndexedNum& operator+=(const IndexedNum& rhs){
+    return *this;
+  }
+};
+
+template<typename DType, typename OType>
+struct arg_min_max_map : public mxnet_op::tunable {

Review comment:
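       Here `OType` is actually `IndexedNum`, whereas `identity` is just the following (note that `Map` returns `DType`):
   ```
     // body generated by MXNET_UNARY_MATH_OP_NC(name, expr) with expr = a
     struct name : public mxnet_op::tunable {
       template<typename DType>
       MSHADOW_XINLINE static DType Map(DType a) {
         return a;
       }
     };
   ```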

##########
File path: src/operator/tensor/broadcast_reduce-inl.cuh
##########
@@ -60,16 +60,19 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
           for (int u=0;u < unroll;u++) {
             idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride);
           }
-          DType tmp[unroll];
+          AType tmp[unroll];
           #pragma unroll
           for (int u=0;u < unroll;u++) {
             if (k + u*by < Mend) {
               tmp[u] = OP::Map(big[idx_big[u]]);
+              // argmin/max, set IndexedNum.idx
+	      if (use_index)
+                *(reinterpret_cast<int*>(&tmp[u])) = k + u*by;

Review comment:
       The plan was to come back here and use C++17 `if constexpr` once we support only CUDA 11 and up. But yes, I had thought of this solution too, and I will change it tomorrow.





[GitHub] [incubator-mxnet] Zha0q1 edited a comment on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Zha0q1 edited a comment on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-707438901


   Update:
   I just found that even if the dim of the axis to be reduced is small, as long as it's the last dim, my implementation is still much faster than the old implementation.
   ```python
   test((300000, 64), mx.gpu(), axis=1, dtype='float32', op=mx.np.argmin)
   test((300000, 64), mx.gpu(), axis=1, dtype='float32', op=mx.np.argmax)
   ```
   gives this result:
   ```
   [02:02:22] ../src/storage/storage.cc:199: Using Pooled (Naive) StorageManager for GPU
   (300000, 64) gpu(0) axis:1 float32 <function argmin at 0x7fabb9521320>
   avg time 0.014822
   (300000, 64) gpu(0) axis:1 float32 <function argmax at 0x7fabb9521290>
   avg time 0.002667
   ```
   My take is that the axes reduce kernel benefits from caching when the axis to be reduced is the last one.
   
   I think reducing the last dim is the most common use case, e.g. for word embeddings? @sxjscience 
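One way to see the locality argument is to list, for a row-major (C-contiguous) 2-D tensor, which flat offsets a single output reads. This is a simplified model that ignores how the GPU kernel assigns threads and unrolls loops, so it only illustrates the access pattern, not a measured cause of the speedup. `flat_offsets` and `strides_between_reads` are hypothetical helpers.

```python
def flat_offsets(shape, axis, out_index):
    """Row-major flat offsets read to produce one output of a 2-D reduction."""
    rows, cols = shape
    if axis == 1:
        # Reducing the last axis: output `out_index` is one contiguous row
        return [out_index * cols + k for k in range(cols)]
    # Reducing axis 0: output `out_index` is a column, read with stride `cols`
    return [k * cols + out_index for k in range(rows)]


def strides_between_reads(offsets):
    """Distinct distances between consecutive reads for one output."""
    return {b - a for a, b in zip(offsets, offsets[1:])}


print(strides_between_reads(flat_offsets((300000, 64), 1, 0)))  # {1}
print(strides_between_reads(flat_offsets((64, 300000), 0, 0)))  # {300000}
```

Both cases reduce 64 elements per output, but only the last-axis reduction reads them from adjacent memory.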



[GitHub] [incubator-mxnet] sxjscience commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

sxjscience commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-707400093


   I think it's a very basic op; for example, it's used when calculating accuracy. Also, if we find that the new kernel is worse for some specific workloads, we can always fall back to the old kernel for those workloads.
   
   In addition, try also using `nvprof` to profile the performance. Measuring performance directly with Python timing is not accurate, since it also counts the overhead. We can also compare with the performance of PyTorch.



[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: Numpy Argmax Rewrite

Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-710665108


   @mxnet-bot run ci [clang, miscellaneous, unix-gpu]



[GitHub] [incubator-mxnet] sandeep-krishnamurthy commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

sandeep-krishnamurthy commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-707387391


   Great.
   Apart from the individual argmax kernel, can you talk about performance (+ve or -ve) on end to end customer use case?
   Where is argmax used commonly?
   Also, this will help make a judgement call on whether a 1.8x increase (i.e. an 80% regression on existing behavior) on GPU with small dims (<64) is acceptable.



[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19339: Numpy Argmax Rewrite

mxnet-bot commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-710685455


   Jenkins CI successfully triggered : [unix-cpu]



[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: Numpy Argmax Rewrite

Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-710732674


   @mxnet-bot run ci [unix-cpu]



[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #19339: Numpy Argmax Rewrite

Zha0q1 commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r508219885



##########
File path: src/operator/tensor/broadcast_reduce-inl.cuh
##########
@@ -60,16 +60,19 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
           for (int u=0;u < unroll;u++) {
             idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride);
           }
-          DType tmp[unroll];
+          AType tmp[unroll];
           #pragma unroll
           for (int u=0;u < unroll;u++) {
             if (k + u*by < Mend) {
               tmp[u] = OP::Map(big[idx_big[u]]);
+              // argmin/max, set IndexedNum.idx
+	      if (use_index)
+                *(reinterpret_cast<int*>(&tmp[u])) = k + u*by;

Review comment:
       I changed it. There is now one more template parameter, `IndexOP`, which does nothing for regular data types and sets the index in the argmin/max case.
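The `IndexOP` approach can be modelled in Python as a hook with a no-op default. In C++ the hook is a template parameter, so the compiler can presumably inline the no-op away for ordinary reductions; in this sketch it is just a function argument. All names below (`reduce_axis`, `noop_index`, `set_index`, `argmax_step`) are hypothetical stand-ins, not MXNet functions.

```python
def noop_index(tmp, k):
    """Default hook: ordinary reductions such as sum ignore the position k."""
    return tmp


def set_index(tmp, k):
    """argmin/argmax hook: store position k in the (idx, num) pair."""
    _, num = tmp
    return (k, num)


def reduce_axis(values, map_op, reduce_op, init, index_op=noop_index):
    """Generic sequential reduction with an index hook applied after Map."""
    acc = init
    for k, v in enumerate(values):
        tmp = index_op(map_op(v), k)  # replaces `if (use_index)` + pointer cast
        acc = reduce_op(acc, tmp)
    return acc


def add(a, b):
    return a + b


def argmax_step(a, b):
    # Keep the earlier (idx, num) pair on ties
    return b if a[1] < b[1] else a


print(reduce_axis([1, 2, 3], lambda v: v, add, 0))  # 6
print(reduce_axis([4, 8, 8, 2], lambda v: (0, v), argmax_step,
                  (0, float("-inf")), index_op=set_index))  # (1, 8)
```

The reduction loop itself no longer needs to know about argmin/max; only the caller chooses which hook to pass.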





[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

mxnet-bot commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-709643779


   Jenkins CI successfully triggered : [unix-cpu]



[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Zha0q1 commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r504213383



##########
File path: src/operator/tensor/broadcast_reduce-inl.h
##########
@@ -278,7 +279,10 @@ MSHADOW_XINLINE void seq_reduce_assign(const index_t idx, const size_t M, const
   Reducer::SetInitValue(val, residual);
   for (size_t k = 0; k < M; ++k) {
     coord = mxnet_op::unravel(k, rshape);
-    Reducer::Reduce(val, AType(OP::Map(big[j + mxnet_op::dot(coord, rstride)])), residual);
+    AType temp = OP::Map(big[j + mxnet_op::dot(coord, rstride)]);
+    if (use_index)
+      memcpy(reinterpret_cast<char*>(&temp), &k, sizeof(size_t));

Review comment:
       If we use `reinterpret_cast` to write the `size_t` index here, one of the CI builds fails to compile with a type-punning error.





[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19339: Numpy Argmax Rewrite

ptrendx commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r508123315



##########
File path: src/operator/mshadow_op.h
##########
@@ -128,6 +128,27 @@ using std::is_integral;
 
 MXNET_UNARY_MATH_OP_NC(identity, a);
 
+template <typename IType, typename DType>
+struct IndexedNum {
+  IType idx;
+  DType num;
+
+  MSHADOW_XINLINE IndexedNum() : idx(0), num(0) {}
+
+  MSHADOW_XINLINE IndexedNum(DType n) : idx(0), num(n) {}
+
+  MSHADOW_XINLINE IndexedNum& operator+=(const IndexedNum& rhs){
+    return *this;
+  }
+};
+
+template<typename DType, typename OType>
+struct arg_min_max_map : public mxnet_op::tunable {

Review comment:
       why not just identity?

##########
File path: src/operator/tensor/broadcast_reduce-inl.cuh
##########
@@ -60,16 +60,19 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
           for (int u=0;u < unroll;u++) {
             idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride);
           }
-          DType tmp[unroll];
+          AType tmp[unroll];
           #pragma unroll
           for (int u=0;u < unroll;u++) {
             if (k + u*by < Mend) {
               tmp[u] = OP::Map(big[idx_big[u]]);
+              // argmin/max, set IndexedNum.idx
+	      if (use_index)
+                *(reinterpret_cast<int*>(&tmp[u])) = k + u*by;

Review comment:
       That's pretty ugly and fragile. Maybe instead of `use_index` have a function template parameter here that by default would be a noop and you would change it to setting the index here?

##########
File path: src/operator/numpy/np_broadcast_reduce_op.cuh
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015-2020 by Contributors
+ * \file np_broadcast_reduce-inl.cuh
+ * \brief GPU implementations for numpy binary broadcast ops
+ * \author Zhaoqi Zhu
+*/
+#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
+#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
+
+using namespace mshadow::cuda;
+using namespace mshadow;
+using namespace broadcast;
+
+#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \
+  if (do_unroll) {                                                    \
+    const int unrollVar = unrollAmount;                               \
+    {__VA_ARGS__}                                                     \
+  } else {                                                            \
+    const int unrollVar = 1;                                          \
+    {__VA_ARGS__}                                                     \
+  }
+
+template<typename Reducer, int NDim, typename DType, typename OType>
+void NumpyArgMinMaxReduce(Stream<gpu> *s, const TBlob& in_data, const TBlob& out_data,

Review comment:
       How is this function different from existing `Reduce`?
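The `KERNEL_UNROLL_SWITCH` macro in this hunk turns a runtime flag into a compile-time unroll factor. Inside the switched body, `unrollVar` is a `const int` initialised from a constant, so it can be used as a template argument. The following host-only sketch copies the macro and uses a hypothetical `sum_unrolled` body in place of the real reduce kernel.

```cpp
#include <cassert>

// Copied from the diff above.
#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \
  if (do_unroll) {                                                    \
    const int unrollVar = unrollAmount;                               \
    {__VA_ARGS__}                                                     \
  } else {                                                            \
    const int unrollVar = 1;                                          \
    {__VA_ARGS__}                                                     \
  }

// Hypothetical body templated on the unroll factor.
template <int unroll>
int sum_unrolled(const int* x, int n) {
  int acc = 0;
  int k = 0;
  for (; k + unroll <= n; k += unroll) {
    for (int u = 0; u < unroll; ++u) acc += x[k + u];  // unrollable inner loop
  }
  for (; k < n; ++k) acc += x[k];  // leftover tail elements
  return acc;
}

// Both branches instantiate the same body, once with 4 and once with 1.
inline int sum_dispatch(const int* x, int n, bool do_unroll) {
  int result = 0;
  KERNEL_UNROLL_SWITCH(do_unroll, 4, UNROLL, {
    result = sum_unrolled<UNROLL>(x, n);
  });
  return result;
}
```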




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19339: Numpy Argmax Rewrite

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-710676204


   Jenkins CI successfully triggered : [clang]





[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r504270302



##########
File path: src/operator/mshadow_op.h
##########
@@ -1532,6 +1551,71 @@ struct sum {
   }
 };
 
+/*! \brief arg max reducer */
+struct argmax {
+  /*! \brief do reduction into dst */
+  template<typename AType, typename DType>
+  MSHADOW_XINLINE static void Reduce(volatile AType& dst,  volatile DType src) { // NOLINT(*)
+    if (dst.num < src.num) {
+      dst.num = src.num;
+      dst.idx = src.idx;
+    }
+  }
+  /*! \brief do stable reduction into dst */
+  template<typename AType, typename DType>
+  MSHADOW_XINLINE static void Reduce(volatile AType& dst,  volatile DType src, volatile DType& residual) { // NOLINT(*)

Review comment:
       `residual` is not used by argmin/argmax. It matters for `sum`, where it enables safe (compensated) accumulation. The reduce kernel always calls the three-argument overload, so we still need to define `Reduce` with the `residual` parameter for compatibility, even though it is unused here.
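A simplified host-only sketch of this compatibility pattern. A generic loop always calls `Reduce(dst, src, residual)`. The sum reducer uses the residual for a Kahan-style compensated step, and the argmax reducer provides the same signature but ignores the residual. `sum_reducer`, `argmax_reducer` and `reduce_all` are hypothetical stand-ins, not `mshadow_op`'s actual definitions.

```cpp
#include <cassert>
#include <limits>

template <typename IType, typename DType>
struct IndexedNum { IType idx; DType num; };

struct sum_reducer {
  template <typename DType>
  static void SetInitValue(DType& v, DType& residual) { v = 0; residual = 0; }
  // Kahan-style step: residual carries the low-order bits lost to rounding.
  template <typename DType>
  static void Reduce(DType& dst, DType src, DType& residual) {
    DType y = src - residual;
    DType t = dst + y;
    residual = (t - dst) - y;
    dst = t;
  }
};

struct argmax_reducer {
  template <typename AType>
  static void SetInitValue(AType& v, AType& /*residual*/) {
    v.idx = 0;
    v.num = std::numeric_limits<decltype(v.num)>::lowest();
  }
  template <typename AType>
  static void Reduce(AType& dst, const AType& src) {
    if (dst.num < src.num) dst = src;
  }
  // Exists only so the shared loop can call the three-argument form.
  template <typename AType>
  static void Reduce(AType& dst, const AType& src, AType& /*residual*/) {
    Reduce(dst, src);
  }
};

// One generic loop serves both reducers through the common signature.
template <typename Reducer, typename AType>
AType reduce_all(const AType* xs, int n) {
  AType val, residual;
  Reducer::SetInitValue(val, residual);
  for (int i = 0; i < n; ++i) Reducer::Reduce(val, xs[i], residual);
  return val;
}
```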







[GitHub] [incubator-mxnet] Zha0q1 edited a comment on pull request #19339: Numpy Argmax Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 edited a comment on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-709630699









[GitHub] [incubator-mxnet] sxjscience commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
sxjscience commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-707428163


   I mean the kernel for sum.





[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-709630699


   A brief recap of contiguous and non-contiguous reduction results on GPU:
   Contiguous:
   > (300000, 64) float axis=1
   > before: 14.0549
   > now: 3.1830
   
   > (64, 300000) float axis=1
   > before: 376.1702
   > now: 2.9527
   
   Non-contiguous:
   > (300000, 64) float axis=0
   > before: 14.0549
   > now: 3.1830
   
   > (64, 300000) float axis=0
   > before: 268.1024
   > now: 3.5889
   
   These timings come directly from mx.profiler.





[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: Numpy Argmax Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-710606466


   I just added a test with a large (in the general sense) tensor. With the old argmin/argmax implementation this test fails on float16 (I opened issue https://github.com/apache/incubator-mxnet/issues/19362), but with my new implementation the issue is resolved.





[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-708095139


   @mxnet-bot run ci [windows-gpu]





[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: Numpy Argmax Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-713101604


   If this PR gets merged, I will apply the same implementation to argmin next.





[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-709710316


   I cleaned up the code and ran more benchmarks. Could you review, @sxjscience @leezu @szha?





[GitHub] [incubator-mxnet] leezu commented on a change in pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
leezu commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r504319331



##########
File path: src/operator/tensor/broadcast_reduce-inl.h
##########
@@ -278,7 +279,10 @@ MSHADOW_XINLINE void seq_reduce_assign(const index_t idx, const size_t M, const
   Reducer::SetInitValue(val, residual);
   for (size_t k = 0; k < M; ++k) {
     coord = mxnet_op::unravel(k, rshape);
-    Reducer::Reduce(val, AType(OP::Map(big[j + mxnet_op::dot(coord, rstride)])), residual);
+    AType temp = OP::Map(big[j + mxnet_op::dot(coord, rstride)]);
+    if (use_index)
+      memcpy(reinterpret_cast<char*>(&temp), &k, sizeof(size_t));

Review comment:
       Can you paste the error?







[GitHub] [incubator-mxnet] sxjscience commented on a change in pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
sxjscience commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r505776653



##########
File path: src/operator/numpy/np_broadcast_reduce_op_index.cu
##########
@@ -28,7 +28,7 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_npi_argmax)
-.set_attr<FCompute>("FCompute<gpu>", NumpySearchAxisCompute<gpu, mshadow::red::maximum>);
+.set_attr<FCompute>("FCompute<gpu>", NumpyArgMinMaxCompute<gpu, int>);

Review comment:
       Sounds good to me.







[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-709643754


   @mxnet-bot run ci [unix-cpu]





[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #19339: Numpy Argmax Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r508219885



##########
File path: src/operator/tensor/broadcast_reduce-inl.cuh
##########
@@ -60,16 +60,19 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
           for (int u=0;u < unroll;u++) {
             idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride);
           }
-          DType tmp[unroll];
+          AType tmp[unroll];
           #pragma unroll
           for (int u=0;u < unroll;u++) {
             if (k + u*by < Mend) {
               tmp[u] = OP::Map(big[idx_big[u]]);
+              // argmin/max, set IndexedNum.idx
+	      if (use_index)
+                *(reinterpret_cast<int*>(&tmp[u])) = k + u*by;

Review comment:
       I changed it 







[GitHub] [incubator-mxnet] sxjscience commented on pull request #19339: Numpy Argmax Rewrite

Posted by GitBox <gi...@apache.org>.
sxjscience commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-709850808


   The performance numbers look good, and we can try to accelerate the CPU version later.





[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19339: Numpy Argmax Rewrite

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-710732692


   Jenkins CI successfully triggered : [unix-cpu]





[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19339: Numpy Argmax Rewrite

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-713230665


   Jenkins CI successfully triggered : [edge]





[GitHub] [incubator-mxnet] sxjscience commented on a change in pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
sxjscience commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r503613280



##########
File path: src/operator/tensor/broadcast_reduce-inl.cuh
##########
@@ -60,16 +60,18 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
           for (int u=0;u < unroll;u++) {
             idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride);
           }
-          DType tmp[unroll];
+          AType tmp[unroll];

Review comment:
       Why change DType --> AType?







[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-707559990


   @mxnet-bot run ci [unix-cpu]





[GitHub] [incubator-mxnet] Zha0q1 edited a comment on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 edited a comment on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-707438901


   Update:
   I just found that even when the reduced axis is small, as long as it is the last dimension, my implementation is still much faster than the old one (`argmin` still uses the old implementation and `argmax` the new one):
   ```python
   test((300000, 64), mx.gpu(), axis=1, dtype='float32', op=mx.np.argmin)
   test((300000, 64), mx.gpu(), axis=1, dtype='float32', op=mx.np.argmax)
   ```
   gives this result:
   ```
   [02:02:22] ../src/storage/storage.cc:199: Using Pooled (Naive) StorageManager for GPU
   (300000, 64) gpu(0) axis:1 float32 <function argmin at 0x7fabb9521320>
   avg time 0.014822
   (300000, 64) gpu(0) axis:1 float32 <function argmax at 0x7fabb9521290>
   avg time 0.002667
   ```
   
   I think reducing along the last dimension is the most common use case, e.g. for word embeddings? @sxjscience 





[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r504214158



##########
File path: src/operator/numpy/np_broadcast_reduce_op_index.cu
##########
@@ -28,7 +28,7 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_npi_argmax)
-.set_attr<FCompute>("FCompute<gpu>", NumpySearchAxisCompute<gpu, mshadow::red::maximum>);
+.set_attr<FCompute>("FCompute<gpu>", NumpyArgMinMaxCompute<gpu, int>);

Review comment:
       I think `int` suffices on GPU, since the kernels all use `int` indices and don't really support large tensors anyway. Also, using `size_t` would grow `sizeof(IndexedNum<indextype, float>)` from 8 to 16 bytes and roughly double the run time.
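A small sketch of the size arithmetic behind this comment. It assumes a typical 64-bit (LP64) target, where `int` and `float` are 4 bytes and `size_t` is 8 bytes with 8-byte alignment. The two type aliases are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstddef>

template <typename IType, typename DType>
struct IndexedNum { IType idx; DType num; };

// 4-byte int + 4-byte float: 8 bytes, no padding.
typedef IndexedNum<int, float> IndexedF32I32;
// 8-byte size_t + 4-byte float = 12 bytes of payload, padded to 16 so that
// idx stays 8-byte aligned in arrays (LP64 assumption).
typedef IndexedNum<std::size_t, float> IndexedF32U64;

// Ratio of per-element accumulator sizes between the two choices.
constexpr double kSizeRatio =
    static_cast<double>(sizeof(IndexedF32U64)) / sizeof(IndexedF32I32);
```

If the kernel is bound by memory traffic on these accumulators, doubling the element size is consistent with the roughly 2x slowdown reported above. That link is plausible, but the sketch does not measure it.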







[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r504321441



##########
File path: src/operator/tensor/broadcast_reduce-inl.h
##########
@@ -278,7 +279,10 @@ MSHADOW_XINLINE void seq_reduce_assign(const index_t idx, const size_t M, const
   Reducer::SetInitValue(val, residual);
   for (size_t k = 0; k < M; ++k) {
     coord = mxnet_op::unravel(k, rshape);
-    Reducer::Reduce(val, AType(OP::Map(big[j + mxnet_op::dot(coord, rstride)])), residual);
+    AType temp = OP::Map(big[j + mxnet_op::dot(coord, rstride)]);
+    if (use_index)
+      memcpy(reinterpret_cast<char*>(&temp), &k, sizeof(size_t));

Review comment:
       @leezu 
   ```
   [2020-10-13T00:32:03.091Z] /work/mxnet/src/operator/numpy/np_cross.cc:118:61:   required from here
   [2020-10-13T00:32:03.091Z] /work/mxnet/src/operator/numpy/./../tensor/./broadcast_reduce-inl.h:284:7: error: dereferencing type-punned pointer will break strict-aliasing rules [-Werror=strict-aliasing]
   [2020-10-13T00:32:03.091Z]        *(reinterpret_cast<size_t*>(&temp)) = k;
   [2020-10-13T00:32:03.091Z]        ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   [2020-10-13T00:32:03.091Z] /work/mxnet/src/operator/numpy/./../tensor/./broadcast_reduce-inl.h: In instantiation of 'void mxnet::op::broadcast::seq_reduce_assign(mxnet::index_t, size_t, bool, const DType*, OType*, const mshadow::Shape<dimension>&, const mshadow::Shape<dimension>&, const mshadow::Shape<dimension>&, const mshadow::Shape<dimension>&) [with Reducer = mxnet::op::mshadow_op::sum; int ndim = 1; AType = double; DType = double; OType = double; OP = mxnet::op::mshadow_op::identity; bool use_index = false; mxnet::index_t = int; size_t = long unsigned int]':
   ```
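The error above is GCC's strict-aliasing diagnostic: writing a `size_t` through a pointer to a `double` (or to an `IndexedNum`) is undefined behaviour. The sketch below shows two well-defined alternatives. One is `std::memcpy` sized to the destination member. The other, simpler one is plain member assignment. The helper names are hypothetical. Note that the diff copies `sizeof(size_t)` bytes. If an accumulator type ever used a 4-byte `int` index, that copy would spill into `num`, so the byte count should follow the destination member.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

template <typename IType, typename DType>
struct IndexedNum { IType idx; DType num; };

// memcpy copies object representations and is exempt from strict aliasing.
// Convert the counter to the member's type first, then copy exactly
// sizeof(member) bytes so neighbouring fields are untouched.
template <typename AType, typename IdxT>
void set_index_memcpy(AType& v, IdxT k) {
  decltype(v.idx) idx = static_cast<decltype(v.idx)>(k);
  std::memcpy(&v.idx, &idx, sizeof(idx));
}

// When AType is known to have an idx member, plain assignment is clearest.
template <typename AType, typename IdxT>
void set_index_member(AType& v, IdxT k) {
  v.idx = static_cast<decltype(v.idx)>(k);
}
```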







[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-707427950


   > In addition, I think the previous code may dispatch the reduce op to multiple kernels based on the shape and we may just extend the logic on top of this.
   
   By "previous code", do you mean the previous implementation of argmax? Could you elaborate on this?





[GitHub] [incubator-mxnet] sxjscience commented on a change in pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
sxjscience commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r504190037



##########
File path: src/operator/tensor/broadcast_reduce-inl.h
##########
@@ -278,7 +279,10 @@ MSHADOW_XINLINE void seq_reduce_assign(const index_t idx, const size_t M, const
   Reducer::SetInitValue(val, residual);
   for (size_t k = 0; k < M; ++k) {
     coord = mxnet_op::unravel(k, rshape);
-    Reducer::Reduce(val, AType(OP::Map(big[j + mxnet_op::dot(coord, rstride)])), residual);
+    AType temp = OP::Map(big[j + mxnet_op::dot(coord, rstride)]);
+    if (use_index)
+      memcpy(reinterpret_cast<char*>(&temp), &k, sizeof(size_t));

Review comment:
       Why memcpy?







[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: Numpy Argmax Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-710685442


   @mxnet-bot run ci [unix-cpu]





[GitHub] [incubator-mxnet] Zha0q1 edited a comment on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 edited a comment on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-709630699


   A brief recap of contiguous and non-contiguous reduction results on GPU:
   Contiguous:
   > (300000, 64) float axis=1
   > before: 14.0549
   > now: 3.1830
   
   > (64, 300000) float axis=1
   > before: 376.1702
   > now: 2.9527
   
   Non-contiguous:
   > (300000, 64) float axis=0
   > before: 268.1024
   > now: 3.5889
   
   > (64, 300000) float axis=0
   > before: 1.5355
   > now: 3.3389
   
   These timings come directly from mx.profiler.
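To make the recap easier to scan, the small sketch below computes the speedup (before / now) for each GPU case. The numbers are copied from the recap, and their units are whatever mx.profiler reported. Approximate ratios are roughly 4.4x, 127x, 75x and 0.46x. The last case, (64, 300000) reduced over axis=0, is therefore slower with the new kernel. `Timing` and `speedup` are hypothetical helpers.

```cpp
#include <cassert>

struct Timing {
  const char* shape;
  int axis;
  double before;  // old implementation, as reported by mx.profiler
  double now;     // new implementation
};

// Copied from the GPU recap above.
const Timing kGpuTimings[] = {
    {"(300000, 64)", 1, 14.0549, 3.1830},   // ~4.4x faster
    {"(64, 300000)", 1, 376.1702, 2.9527},  // ~127x faster
    {"(300000, 64)", 0, 268.1024, 3.5889},  // ~75x faster
    {"(64, 300000)", 0, 1.5355, 3.3389},    // ~0.46x, i.e. a slowdown
};

// Values above 1 mean the new kernel is faster.
inline double speedup(const Timing& t) { return t.before / t.now; }
```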





[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: Numpy Argmax Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-713230645


   @mxnet-bot  run ci [edge]


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] Zha0q1 edited a comment on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 edited a comment on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-709630699


   A brief recap of contiguous and non-contiguous reduction results on GPU:
   Contiguous:
   > (300000, 64) float axis=1
   > before: 14.0549
   > now: 3.1830
   
   > (64, 300000) float axis=1
   > before: 376.1702
   > now: 2.9527
   
   Non-contiguous:
   > (300000, 64) float axis=0
   > before: 268.1024
   > now: 3.5889
   
   > (64, 300000) float axis=0
   > before: 
   > now: 
   
   These timings come directly from mx.profiler.





[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-709634458


   CPU results:
   
   Contiguous:
   > (300000, 64) float axis=1
   > before: 38.3453
   > now: 66.3917
   
   > (64, 300000) float axis=1
   > before: 35.6059
   > now: 62.8450
   
   Non-contiguous:
   > (300000, 64) float axis=0
   > before: 994.6203
   > now: 147.3643
   
   > (64, 300000) float axis=0
   > before: 711.3664
   > now: 73.1802





[GitHub] [incubator-mxnet] sandeep-krishnamurthy commented on pull request #19339: Numpy Argmax Rewrite

Posted by GitBox <gi...@apache.org>.
sandeep-krishnamurthy commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-713293699


   Great contribution @Zha0q1 . Thank you @ptrendx and @sxjscience @access2rohit for your help and review.





[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-707438901


   Update:
   I just found that even when the reduced axis is small, as long as it is the last dimension, my implementation is still much faster than the old one (`argmin` still uses the old implementation and `argmax` the new one):
   ```python
   test((300000, 64), mx.gpu(), axis=1, dtype='float32', op=mx.np.argmin)
   test((300000, 64), mx.gpu(), axis=1, dtype='float32', op=mx.np.argmax)
   ```
   gives this result:
   ```
   [02:02:22] ../src/storage/storage.cc:199: Using Pooled (Naive) StorageManager for GPU
   (300000, 64) gpu(0) axis:1 float32 <function argmin at 0x7fabb9521320>
   avg time 0.014822
   (300000, 64) gpu(0) axis:1 float32 <function argmax at 0x7fabb9521290>
   avg time 0.002667
   ```
   
   I think reducing along the last dimension is the most common use case?





[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r504210967



##########
File path: src/operator/tensor/broadcast_reduce-inl.cuh
##########
@@ -60,16 +60,18 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
           for (int u=0;u < unroll;u++) {
             idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride);
           }
-          DType tmp[unroll];
+          AType tmp[unroll];

Review comment:
       Another reason is that the return type of `OP::Map()` is `struct IndexedNum`, so `tmp` has to be able to hold it. Switching to `AType` is the smallest change that makes the kernel compatible with that. I am trying to think of another solution.







[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: Numpy Argmax Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-713101077


   @ptrendx Thanks for reviewing! I have updated the PR accordingly. I reran the benchmark, and the results were essentially the same as with my previous commits, so those changes introduced no regression.





[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #19339: Numpy Argmax Rewrite

Posted by GitBox <gi...@apache.org>.
Zha0q1 commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r508797145



##########
File path: src/operator/numpy/np_broadcast_reduce_op.cuh
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015-2020 by Contributors
+ * \file np_broadcast_reduce-inl.cuh
+ * \brief GPU implementations for numpy binary broadcast ops
+ * \author Zhaoqi Zhu
+*/
+#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
+#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
+
+using namespace mshadow::cuda;
+using namespace mshadow;
+using namespace broadcast;
+
+#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \
+  if (do_unroll) {                                                    \
+    const int unrollVar = unrollAmount;                               \
+    {__VA_ARGS__}                                                     \
+  } else {                                                            \
+    const int unrollVar = 1;                                          \
+    {__VA_ARGS__}                                                     \
+  }
+
+template<typename Reducer, int NDim, typename DType, typename OType>
+void NumpyArgMinMaxReduce(Stream<gpu> *s, const TBlob& in_data, const TBlob& out_data,

Review comment:
       Just applied that change.
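The `KERNEL_UNROLL_SWITCH` macro quoted in the hunk above turns a runtime condition into a compile-time unroll constant by expanding the body twice. Below is a small host-side C++ sketch of how such a macro can be used. `SumWithUnroll` and `LaunchSum` are hypothetical and only stand in for a kernel templated on its unroll factor.

```C++
#include <cassert>
#include <cstddef>
#include <vector>

// Same shape as the macro in the quoted diff: the body is expanded twice, once
// with unrollVar fixed to unrollAmount and once with unrollVar fixed to 1.
#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \
  if (do_unroll) {                                                    \
    const int unrollVar = unrollAmount;                               \
    {__VA_ARGS__}                                                     \
  } else {                                                            \
    const int unrollVar = 1;                                          \
    {__VA_ARGS__}                                                     \
  }

// Hypothetical stand-in for a kernel templated on its unroll factor.
template <int kUnroll>
int SumWithUnroll(const std::vector<int>& data, int* unroll_used) {
  assert(unroll_used != nullptr);
  *unroll_used = kUnroll;
  int total = 0;
  for (std::size_t k = 0; k < data.size(); k += kUnroll) {
    for (std::size_t u = 0; u < static_cast<std::size_t>(kUnroll); ++u) {
      if (k + u < data.size()) total += data[k + u];
    }
  }
  return total;
}

// Hypothetical launcher: unroll by 4 only when there is enough work.
int LaunchSum(const std::vector<int>& data, int* unroll_used) {
  int result = 0;
  // kUnroll is a const int initialized from a constant, so it can be used as a
  // template argument inside the body.
  KERNEL_UNROLL_SWITCH(data.size() >= 8, 4, kUnroll, {
    result = SumWithUnroll<kUnroll>(data, unroll_used);
  })
  return result;
}
```

The cost of this approach is that every template instantiation in the body is compiled twice, once per unroll value.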







[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19339: Numpy Argmax Rewrite

ptrendx commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r508730209



##########
File path: src/operator/mshadow_op.h
##########
@@ -128,6 +128,27 @@ using std::is_integral;
 
 MXNET_UNARY_MATH_OP_NC(identity, a);
 
+template <typename IType, typename DType>
+struct IndexedNum {
+  IType idx;
+  DType num;
+
+  MSHADOW_XINLINE IndexedNum() : idx(0), num(0) {}
+
+  MSHADOW_XINLINE IndexedNum(DType n) : idx(0), num(n) {}
+
+  MSHADOW_XINLINE IndexedNum& operator+=(const IndexedNum& rhs){
+    return *this;
+  }
+};
+
+template<typename DType, typename OType>
+struct arg_min_max_map : public mxnet_op::tunable {

Review comment:
       Still, if you made `OType a = identity(x)` it would still work, right?
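The suggestion relies on the quoted `IndexedNum` having a non-`explicit` constructor that takes a `DType`. With that constructor, copy-initializing an `OType` from a `DType` performs an implicit conversion. The C++ sketch below works under that assumption. `identity` here is a simplified stand-in modeled on the op declared above, and `MapToAccumulator` is a hypothetical helper.

```C++
#include <cassert>

// Simplified copy of the IndexedNum quoted in the diff, including its
// non-explicit converting constructor from DType.
template <typename IType, typename DType>
struct IndexedNum {
  IType idx;
  DType num;
  IndexedNum() : idx(0), num(0) {}
  IndexedNum(DType n) : idx(0), num(n) {}  // implicit DType -> IndexedNum
};

// Simplified stand-in for the identity op: returns its argument unchanged.
struct identity {
  template <typename DType>
  static DType Map(DType a) { return a; }
};

// Hypothetical helper showing the suggested form: copy-initializing OType from
// a DType calls IndexedNum(DType), so idx starts at 0 and must be set separately.
template <typename OType, typename DType>
OType MapToAccumulator(DType x) {
  OType a = identity::Map(x);
  return a;
}
```

If that constructor were marked `explicit`, `OType a = identity(x)` would no longer compile. The suggested form depends on the conversion staying implicit.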







[GitHub] [incubator-mxnet] Zha0q1 edited a comment on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Zha0q1 edited a comment on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-709630699


   A brief recap of contiguous and non-contiguous reduction results on GPU:
   Contiguous:
   > (300000, 64) float axis=1
   > before: 14.0549
   > now: 3.1830
   
   > (64, 300000) float axis=1
   > before: 376.1702
   > now: 2.9527
   
   Non-contiguous:
   > (300000, 64) float axis=0
   > before: 268.1024
   > now: 3.5889
   
   > (64, 300000) float axis=0
   > before: 1.5355
   > now: 3.3389
   
   The results come straight from our mx.profiler and are averages of 20 runs.
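Dividing each "before" timing by its "now" timing gives the relative change per case. The small C++ sketch below only restates the numbers reported above, with units as printed by the profiler. `Case`, `Speedup` and `ReportedCases` are illustrative names.

```C++
#include <cassert>
#include <string>
#include <vector>

// One benchmark row from the comment above; times are as printed by the profiler.
struct Case {
  std::string label;
  double before;
  double now;
};

// before / now: values above 1 mean the rewrite is faster, below 1 slower.
double Speedup(const Case& c) {
  assert(c.now > 0.0);
  return c.before / c.now;
}

// The four measurements reported above, copied verbatim.
std::vector<Case> ReportedCases() {
  return {
      {"(300000, 64) float axis=1", 14.0549, 3.1830},
      {"(64, 300000) float axis=1", 376.1702, 2.9527},
      {"(300000, 64) float axis=0", 268.1024, 3.5889},
      {"(64, 300000) float axis=0", 1.5355, 3.3389},
  };
}
```

By this arithmetic, the rewrite is roughly 4.4x, 127x and 75x faster in the first three cases. For `(64, 300000)` with `axis=0` it is about 2.2x slower (a ratio of about 0.46).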





[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Zha0q1 commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r505913519



##########
File path: src/operator/tensor/broadcast_reduce-inl.cuh
##########
@@ -60,16 +60,18 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
           for (int u=0;u < unroll;u++) {
             idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride);
           }
-          DType tmp[unroll];
+          AType tmp[unroll];

Review comment:
       I just compared my PR build against a master build with the same config: CUDA on, large tensor support on.
   
   I checked np.sum with the MXNet profiler, and this change does not seem to affect caching.







[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Zha0q1 commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r503613902



##########
File path: src/operator/tensor/broadcast_reduce-inl.cuh
##########
@@ -60,16 +60,18 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
           for (int u=0;u < unroll;u++) {
             idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride);
           }
-          DType tmp[unroll];
+          AType tmp[unroll];

Review comment:
       I thought this was a little cleaner. This way we can get rid of the explicit conversion on line 74.







[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

mxnet-bot commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-707560042


   Jenkins CI successfully triggered : [unix-cpu]





[GitHub] [incubator-mxnet] sxjscience commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

sxjscience commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-707455290


   @Zha0q1 Usually reducing over the last dim is called "Contiguous Reduction" and reducing over another axis is called "NonContiguous Reduction". Both cases are important and I think our reduction op has been optimized for both scenarios.
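In a row-major tensor, the two cases differ only in stride. Reducing the last axis reads adjacent elements, while reducing axis 0 of a 2-D array skips `cols` elements between reads. The sequential C++ sketch below computes argmax over either axis. `ArgmaxAxis` is hypothetical and says nothing about how MXNet's kernels schedule threads.

```C++
#include <cassert>
#include <cstddef>
#include <vector>

// Argmax over one axis of a row-major (rows x cols) matrix.
// axis == 1 walks each row with stride 1 (contiguous reduction);
// axis == 0 walks each column with stride `cols` (non-contiguous reduction).
std::vector<std::size_t> ArgmaxAxis(const std::vector<float>& data,
                                    std::size_t rows, std::size_t cols, int axis) {
  assert(rows > 0 && cols > 0 && data.size() == rows * cols);
  assert(axis == 0 || axis == 1);
  const std::size_t outer = (axis == 1) ? rows : cols;  // number of outputs
  const std::size_t len = (axis == 1) ? cols : rows;    // length of the reduced axis
  const std::size_t step = (axis == 1) ? 1 : cols;      // stride between reduced elements
  const std::size_t base_step = (axis == 1) ? cols : 1; // offset between outputs
  std::vector<std::size_t> out(outer, 0);
  for (std::size_t o = 0; o < outer; ++o) {
    const float* base = data.data() + o * base_step;
    float best = base[0];
    for (std::size_t i = 1; i < len; ++i) {
      if (best < base[i * step]) {  // strict '<' keeps the first maximum
        best = base[i * step];
        out[o] = i;
      }
    }
  }
  return out;
}
```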





[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Zha0q1 commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r504225400



##########
File path: src/operator/tensor/broadcast_reduce-inl.cuh
##########
@@ -60,16 +60,18 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
           for (int u=0;u < unroll;u++) {
             idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride);
           }
-          DType tmp[unroll];
+          AType tmp[unroll];

Review comment:
       It's probably not a good idea to add too much branching to the GPU kernel. If there is no alternative solution, I can write a custom GPU kernel. That would also let us get rid of the residual variable, which we are not using for argmin/argmax.







[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #19339: Numpy Argmax Rewrite

Zha0q1 commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r508185614



##########
File path: src/operator/numpy/np_broadcast_reduce_op.cuh
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015-2020 by Contributors
+ * \file np_broadcast_reduce-inl.cuh
+ * \brief GPU implementations for numpy binary broadcast ops
+ * \author Zhaoqi Zhu
+*/
+#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
+#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
+
+using namespace mshadow::cuda;
+using namespace mshadow;
+using namespace broadcast;
+
+#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \
+  if (do_unroll) {                                                    \
+    const int unrollVar = unrollAmount;                               \
+    {__VA_ARGS__}                                                     \
+  } else {                                                            \
+    const int unrollVar = 1;                                          \
+    {__VA_ARGS__}                                                     \
+  }
+
+template<typename Reducer, int NDim, typename DType, typename OType>
+void NumpyArgMinMaxReduce(Stream<gpu> *s, const TBlob& in_data, const TBlob& out_data,

Review comment:
       The existing reduce uses `TBlob.dptr<OType>()`, which only supports the mshadow data types. Since I am using a custom struct, I need to use `reinterpret_cast<OType*>(TBlob.dptr_)` instead. But I can also apply the second approach to the original `ReduceImpl` and call into that. What do you think?
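The simplified C++ sketch below shows the trade-off described here. It assumes the typed accessor validates the element type against a fixed set of built-in types, so a custom struct fails to compile, while a raw cast of the untyped pointer accepts any type. `Blob`, `TypeFlag`, `FlagOf` and `RawDptr` are hypothetical stand-ins, not MXNet's `TBlob` API.

```C++
#include <cassert>
#include <vector>

// Hypothetical, much-simplified stand-in for a typed blob: an untyped pointer
// plus a runtime type flag. This is not MXNet's TBlob.
enum class TypeFlag { kFloat32, kInt32 };

// Only "built-in" element types get a flag; the primary template is left undefined.
template <typename T> struct FlagOf;
template <> struct FlagOf<float> { static constexpr TypeFlag value = TypeFlag::kFloat32; };
template <> struct FlagOf<int> { static constexpr TypeFlag value = TypeFlag::kInt32; };

struct Blob {
  void* dptr_;
  TypeFlag type_flag_;

  // Typed accessor: instantiating it for a type without a FlagOf
  // specialization (e.g. a custom struct) is a compile error.
  template <typename T>
  T* dptr() const {
    constexpr TypeFlag expected = FlagOf<T>::value;
    assert(type_flag_ == expected);
    (void)expected;  // keeps NDEBUG builds warning-free
    return static_cast<T*>(dptr_);
  }
};

template <typename IType, typename DType>
struct IndexedNum {
  IType idx;
  DType num;
};

// Raw view used for custom accumulator types: no type check at all, so the
// caller must guarantee the buffer really holds OType elements.
template <typename OType>
OType* RawDptr(const Blob& blob) {
  return reinterpret_cast<OType*>(blob.dptr_);
}
```

The raw cast skips type checking entirely. That flexibility is why it works for `IndexedNum`, and it is also why the caller must guarantee the buffer really holds `OType` elements.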








[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: Numpy Argmax Rewrite

Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-710676190


   @mxnet-bot run ci [clang]





[GitHub] [incubator-mxnet] access2rohit commented on a change in pull request #19339: [WIP] Numpy Argmax(min) Rewrite

access2rohit commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r504236931



##########
File path: src/operator/mshadow_op.h
##########
@@ -1532,6 +1551,71 @@ struct sum {
   }
 };
 
+/*! \brief arg max reducer */
+struct argmax {
+  /*! \brief do reduction into dst */
+  template<typename AType, typename DType>
+  MSHADOW_XINLINE static void Reduce(volatile AType& dst,  volatile DType src) { // NOLINT(*)
+    if (dst.num < src.num) {
+      dst.num = src.num;
+      dst.idx = src.idx;
+    }
+  }
+  /*! \brief do stable reduction into dst */
+  template<typename AType, typename DType>
+  MSHADOW_XINLINE static void Reduce(volatile AType& dst,  volatile DType src, volatile DType& residual) { // NOLINT(*)

Review comment:
       What is the purpose of `residual`?
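This sketch assumes the three-argument `Reduce` overload exists to match the interface of the `sum` reducer's stable (compensated) reduction. Under that assumption, `residual` holds the rounding error lost in each addition, as in Kahan summation. The C++ sketch below shows that technique; the function names are illustrative. An argmax reducer only compares and copies values, so it has no rounding error to carry. This is consistent with the residual being unused for argmin/argmax.

```C++
#include <cassert>
#include <vector>

// Compensated (Kahan) step: `residual` carries the low-order bits that were
// lost when `src` was added to a much larger `dst`.
template <typename DType>
void StableSumReduce(DType& dst, DType src, DType& residual) {
  DType y = src - residual;     // re-inject the error from the previous step
  DType t = dst + y;            // this addition may round
  residual = (t - dst) - y;     // recover what the rounding dropped
  dst = t;
}

template <typename DType>
DType KahanSum(const std::vector<DType>& xs) {
  DType sum = 0, residual = 0;
  for (DType x : xs) StableSumReduce(sum, x, residual);
  return sum;
}

template <typename DType>
DType NaiveSum(const std::vector<DType>& xs) {
  DType sum = 0;
  for (DType x : xs) sum += x;
  return sum;
}
```

Consider adding `1e-8f` to `1.0f` ten thousand times. The naive float sum stays at exactly `1.0f`, because each addend is below half an ulp of 1. The compensated sum ends up close to `1.0001`. This assumes the compiler does not reassociate floating-point math (for example, no fast-math flags).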







[GitHub] [incubator-mxnet] sxjscience commented on a change in pull request #19339: [WIP] Numpy Argmax(min) Rewrite

sxjscience commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r504190709



##########
File path: src/operator/numpy/np_broadcast_reduce_op_index.cu
##########
@@ -28,7 +28,7 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_npi_argmax)
-.set_attr<FCompute>("FCompute<gpu>", NumpySearchAxisCompute<gpu, mshadow::red::maximum>);
+.set_attr<FCompute>("FCompute<gpu>", NumpyArgMinMaxCompute<gpu, int>);

Review comment:
       How about also using `size_t` here?
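One consideration when choosing the index type is range. With large tensor support, an index may exceed what a 32-bit `int` can represent. The C++ sketch below performs that check. `IndexFits` is hypothetical. Note that a wider index type also makes each `IndexedNum` larger, which may cost registers on the GPU.

```C++
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>

// Hypothetical helper: can every position in a range of `num_elements`
// entries (0 .. num_elements - 1) be stored in the index type IType?
template <typename IType>
bool IndexFits(std::uint64_t num_elements) {
  static_assert(std::numeric_limits<IType>::is_integer, "index type must be integral");
  if (num_elements == 0) return true;
  const auto max_index = static_cast<std::uint64_t>(std::numeric_limits<IType>::max());
  return num_elements - 1 <= max_index;
}
```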







[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-708208448


   @mxnet-bot run ci [unix-cpu]





[GitHub] [incubator-mxnet] sandeep-krishnamurthy merged pull request #19339: Numpy Argmax Rewrite

sandeep-krishnamurthy merged pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339


   





[GitHub] [incubator-mxnet] sxjscience commented on a change in pull request #19339: [WIP] Numpy Argmax(min) Rewrite

sxjscience commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r504188306



##########
File path: src/operator/tensor/broadcast_reduce-inl.cuh
##########
@@ -60,16 +60,18 @@ __global__ void reduce_kernel(const int N, const int M, const bool addto,
           for (int u=0;u < unroll;u++) {
             idx_big[u] = idx_big0 + mxnet_op::unravel_dot(k + u*by, big_shape, big_stride);
           }
-          DType tmp[unroll];
+          AType tmp[unroll];

Review comment:
       I think this is for caching the data in local registers. Basically, we load `unroll` elements into local registers and then accumulate them into the accumulation type.
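The sequential C++ sketch below shows the load-then-accumulate pattern described above, assuming a simple sum. Each step first copies `kUnroll` inputs into a small local array, widening them to the accumulation type. It then folds them into the accumulator. `UnrolledSum` is a hypothetical name. On a GPU, the local array would typically live in registers; this host-side code cannot show that.

```C++
#include <cassert>
#include <cstddef>
#include <vector>

// Phase 1 loads up to kUnroll inputs into a local array (converted to AType);
// phase 2 accumulates them. Separating the phases lets the loads be issued
// back to back before any arithmetic depends on them.
template <int kUnroll, typename DType, typename AType>
AType UnrolledSum(const std::vector<DType>& xs) {
  static_assert(kUnroll > 0, "unroll factor must be positive");
  AType acc = 0;
  for (std::size_t k = 0; k < xs.size(); k += kUnroll) {
    AType tmp[kUnroll];
    for (int u = 0; u < kUnroll; ++u) {
      const std::size_t i = k + static_cast<std::size_t>(u);
      tmp[u] = (i < xs.size()) ? static_cast<AType>(xs[i]) : AType(0);
    }
    for (int u = 0; u < kUnroll; ++u) acc += tmp[u];
  }
  return acc;
}
```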
   







[GitHub] [incubator-mxnet] Zha0q1 commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

Zha0q1 commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-707435594


   > I mean the kernel for sum.
   
   Do you mean the GPU reduce kernel in `broadcast_reduce-inl.cuh` (https://github.com/apache/incubator-mxnet/blob/master/src/operator/tensor/broadcast_reduce-inl.cuh#L273)? Yes, my implementation calls into this function. We could add a new branch/kernel there for when the dimension to be reduced is small.





[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

mxnet-bot commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-708208498


   Jenkins CI successfully triggered : [unix-cpu]





[GitHub] [incubator-mxnet] Zha0q1 commented on a change in pull request #19339: Numpy Argmax Rewrite

Zha0q1 commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r508185614



##########
File path: src/operator/numpy/np_broadcast_reduce_op.cuh
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015-2020 by Contributors
+ * \file np_broadcast_reduce-inl.cuh
+ * \brief GPU implementations for numpy binary broadcast ops
+ * \author Zhaoqi Zhu
+*/
+#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
+#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
+
+using namespace mshadow::cuda;
+using namespace mshadow;
+using namespace broadcast;
+
+#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \
+  if (do_unroll) {                                                    \
+    const int unrollVar = unrollAmount;                               \
+    {__VA_ARGS__}                                                     \
+  } else {                                                            \
+    const int unrollVar = 1;                                          \
+    {__VA_ARGS__}                                                     \
+  }
+
+template<typename Reducer, int NDim, typename DType, typename OType>
+void NumpyArgMinMaxReduce(Stream<gpu> *s, const TBlob& in_data, const TBlob& out_data,

Review comment:
       The existing reduce uses `TBlob.dptr<OType>()`, which only supports the mshadow data types. Since I am using a custom struct, I need to use `reinterpret_cast<OType*>(TBlob.dptr_)` instead. But I can also apply the second approach to the original `ReduceImpl` and call into that. What do you think?







[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19339: Numpy Argmax Rewrite

ptrendx commented on a change in pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#discussion_r508731860



##########
File path: src/operator/numpy/np_broadcast_reduce_op.cuh
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015-2020 by Contributors
+ * \file np_broadcast_reduce-inl.cuh
+ * \brief GPU implementations for numpy binary broadcast ops
+ * \author Zhaoqi Zhu
+*/
+#ifndef MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
+#define MXNET_OPERATOR_NUMPY_NP_BROADCAST_REDUCE_OP_CUH_
+
+using namespace mshadow::cuda;
+using namespace mshadow;
+using namespace broadcast;
+
+#define KERNEL_UNROLL_SWITCH(do_unroll, unrollAmount, unrollVar, ...) \
+  if (do_unroll) {                                                    \
+    const int unrollVar = unrollAmount;                               \
+    {__VA_ARGS__}                                                     \
+  } else {                                                            \
+    const int unrollVar = 1;                                          \
+    {__VA_ARGS__}                                                     \
+  }
+
+template<typename Reducer, int NDim, typename DType, typename OType>
+void NumpyArgMinMaxReduce(Stream<gpu> *s, const TBlob& in_data, const TBlob& out_data,

Review comment:
       I am currently preparing a PR to refactor reductions to use RTC anyway, so having this logic in fewer places will also help me with that effort.







[GitHub] [incubator-mxnet] sxjscience commented on pull request #19339: [WIP] Numpy Argmax(min) Rewrite

sxjscience commented on pull request #19339:
URL: https://github.com/apache/incubator-mxnet/pull/19339#issuecomment-707427048


   In addition, I think the existing code may dispatch the reduce op to different kernels based on the shape, and we could simply extend that logic.
   




