You are viewing a plain-text version of this content. (The link to the canonical version was lost when this copy was converted to plain text.)
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/30 18:45:29 UTC

[incubator-mxnet] branch master updated: Expand gpu-kernel-launch synchronous error checking. (#9560)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new a0a52b3  Expand gpu-kernel-launch synchronous error checking. (#9560)
a0a52b3 is described below

commit a0a52b38a9046a3ede20c900ff64154642e8e2da
Author: Dick Carter <di...@comcast.net>
AuthorDate: Tue Jan 30 10:45:25 2018 -0800

    Expand gpu-kernel-launch synchronous error checking. (#9560)
    
    * Expand gpu-kernel-launch synchronous error checking.
    
    * Added a Python test that passes only once this PR's code is merged.
    
    * Print SKIP message to stderr so it appears in CI log.
    
    * Add the Python version to the test's skip message.
    
    * Added diagnostic info when skipping test.
    
    * Improve robustness of test.
---
 src/common/random_generator.cu                    |  1 +
 src/operator/contrib/count_sketch.cu              |  2 ++
 src/operator/contrib/ctc_include/detail/gpu_ctc.h |  4 +++
 src/operator/linalg_impl.h                        |  2 ++
 src/operator/mxnet_op.h                           |  2 ++
 src/operator/nn/sequence_mask-inl.h               |  1 +
 src/operator/nn/softmax-inl.h                     |  2 ++
 src/operator/pad.cu                               | 12 +++++++
 src/operator/roi_pooling.cu                       |  2 ++
 src/operator/svm_output.cu                        |  2 ++
 src/operator/tensor/broadcast_reduce-inl.cuh      |  6 ++++
 tests/python/gpu/test_operator_gpu.py             | 38 +++++++++++++++++++++++
 12 files changed, 74 insertions(+)

diff --git a/src/common/random_generator.cu b/src/common/random_generator.cu
index f6f31cf..930e5e0 100644
--- a/src/common/random_generator.cu
+++ b/src/common/random_generator.cu
@@ -55,6 +55,7 @@ void RandGenerator<gpu, float>::Seed(mshadow::Stream<gpu> *s, uint32_t seed) {
           states_,
           RandGenerator<gpu, float>::kNumRandomStates,
           seed);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(rand_generator_seed_kernel);
   s->Wait();
 }
 
diff --git a/src/operator/contrib/count_sketch.cu b/src/operator/contrib/count_sketch.cu
index b849f4c..373ff3e 100644
--- a/src/operator/contrib/count_sketch.cu
+++ b/src/operator/contrib/count_sketch.cu
@@ -129,6 +129,7 @@ inline void CountSketchForward(const Tensor<gpu, 2, DType> &out,
                                     nthreads, out_ptr+bstart*out_dim, h_ptr,
                                     s_ptr, in_ptr+bstart*in_dim, batchlen,
                                     in_dim, out_dim);
+    MSHADOW_CUDA_POST_KERNEL_CHECK(sketch_forward_kernel);
     // cudaThreadSynchronize();
     bstart = (i+1)*batchlen;
   }
@@ -164,6 +165,7 @@ inline void CountSketchBackward(const Tensor<gpu, 2, DType> &in_grad,
                             nthreads, in_grad_ptr+bstart*in_dim, h_ptr,
                             s_ptr, out_grad_ptr+bstart*out_dim, batchlen,
                             in_dim, out_dim);
+    MSHADOW_CUDA_POST_KERNEL_CHECK(sketch_backward_kernel);
     bstart = (i+1)*batchlen;
   }
 }
diff --git a/src/operator/contrib/ctc_include/detail/gpu_ctc.h b/src/operator/contrib/ctc_include/detail/gpu_ctc.h
index c9cab80..c249046 100644
--- a/src/operator/contrib/ctc_include/detail/gpu_ctc.h
+++ b/src/operator/contrib/ctc_include/detail/gpu_ctc.h
@@ -404,6 +404,10 @@ GpuCTC<ProbT>::compute_log_probs(const ProbT* const activations) {
         (ctc_helper::identity<ProbT>(), log_probs_,
          denoms_, out_dim_, num_elements);
 
+    cuda_status = cudaGetLastError();
+    if (cuda_status != cudaSuccess)
+        return CTC_STATUS_EXECUTION_FAILED;
+
     return CTC_STATUS_SUCCESS;
 }
 
diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index b2a672f..d128617 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -641,6 +641,7 @@ void linalg_potri<gpu, DType>(const Tensor<gpu, 2, DType>& A, bool lower, Stream
                        static_cast<int>((A.MSize() + kBaseThreadNum - 1) / kBaseThreadNum)); \
   linalgInitIdentityGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \
     (static_cast<DType *>(buffer.dptr), A.MSize(), A.stride_, A.MSize());  \
+  MSHADOW_CUDA_POST_KERNEL_CHECK(linalgInitIdentityGPU); \
   Tensor<gpu, 2, DType> B((DType *)buffer.dptr, A.shape_, A.stride_, s); \
   linalg_trsm(A, B, DType(1.0), false, lower, !lower, s); \
   linalg_trsm(A, B, DType(1.0), false, lower, lower, s); \
@@ -664,6 +665,7 @@ void linalg_batch_potri<gpu, DType>(const Tensor<gpu, 3, DType>& A, bool lower,
                        static_cast<int>((A.MSize() + kBaseThreadNum - 1) / kBaseThreadNum)); \
   linalgInitIdentityGPU<<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>> \
     (static_cast<DType *>(buffer.dptr), A.size(1)*A.stride_, A.stride_, A.MSize()); \
+  MSHADOW_CUDA_POST_KERNEL_CHECK(linalgInitIdentityGPU); \
   Tensor<gpu, 3, DType> B((DType *)buffer.dptr, A.shape_, A.stride_, s); \
   linalg_batch_trsm(A, B, DType(1.0), false, lower, !lower, s); \
   linalg_batch_trsm(A, B, DType(1.0), false, lower, lower, s); \
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 5a36954..cd52524 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -563,6 +563,7 @@ struct Kernel<OP, gpu> {
     mxnet_generic_kernel<OP, Args...>
       <<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
         N, args...);
+    MSHADOW_CUDA_POST_KERNEL_CHECK(mxnet_generic_kernel);
   }
 
   template<typename ...Args>
@@ -572,6 +573,7 @@ struct Kernel<OP, gpu> {
     mxnet_generic_kernel_ex<OP, Args...>
       <<<ngrid, kBaseThreadNum, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
         N, args...);
+    MSHADOW_CUDA_POST_KERNEL_CHECK(mxnet_generic_kernel_ex);
   }
 };
 #endif  // __CUDACC__
diff --git a/src/operator/nn/sequence_mask-inl.h b/src/operator/nn/sequence_mask-inl.h
index a3b41f6..df98116 100644
--- a/src/operator/nn/sequence_mask-inl.h
+++ b/src/operator/nn/sequence_mask-inl.h
@@ -66,6 +66,7 @@ inline void SequenceMask(const mshadow::Tensor<gpu, 3, DType> &dst,
   CheckLaunchParam(dimGrid, dimBlock, "SequenceMask");
   cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
   SequenceMaskKernel<kBaseThreadBits, DType><<<dimGrid, dimBlock, 0, stream>>>(dst, lengths, value);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(SequenceMaskKernel);
 }
 #endif
 
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 2badecf..080bc08 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -173,6 +173,7 @@ inline void Softmax(Stream<gpu> *s, DType *in, DType *out,
   softmax_compute_kernel<x_bits, OP, DType, ndim>
     <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
       in, out, M, axis, sshape, stride);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_compute_kernel);
 }
 
 
@@ -216,6 +217,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, DType *out, DType *ograd,
   softmax_gradient_kernel<x_bits, OP1, OP2, DType, ndim>
     <<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
       out, ograd, igrad, M, axis, sshape, stride);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel);
 }
 #endif
 
diff --git a/src/operator/pad.cu b/src/operator/pad.cu
index 54242a4..372683a 100644
--- a/src/operator/pad.cu
+++ b/src/operator/pad.cu
@@ -78,6 +78,7 @@ inline void image_pad_edge(Tensor<gpu, 4, DType> dst,
   image_2d_pad_edge_kernel<kBaseThreadBits,
                            DType><<<dimGrid, dimBlock, 0, stream>>>(dst, src,
                                                                     padT, padL);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(image_2d_pad_edge_kernel);
 }
 
 template <int n_bits, typename DType>
@@ -119,6 +120,7 @@ inline void image_pad_edge_grad(Tensor<gpu, 4, DType> grad_in,
   image_2d_pad_edge_grad_kernel<kBaseThreadBits,
                                 DType><<<dimGrid, dimBlock, 0, stream>>>(
       grad_in, grad_out, padT, padL);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(image_2d_pad_edge_grad_kernel);
 }
 
 // Case 2: Constant Padding
@@ -166,6 +168,7 @@ inline void image_pad_constant(Tensor<gpu, 4, DType> dst,
   image_2d_pad_constant_kernel<kBaseThreadBits,
                                DType><<<dimGrid, dimBlock, 0, stream>>>(
       dst, src, padT, padL, constant);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(image_2d_pad_constant_kernel);
 }
 
 template <int n_bits, typename DType>
@@ -202,6 +205,7 @@ inline void image_pad_constant_grad(Tensor<gpu, 4, DType> grad_in,
   image_2d_pad_constant_grad_kernel<kBaseThreadBits,
                                     DType><<<dimGrid, dimBlock, 0, stream>>>(
       grad_in, grad_out, padT, padL);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(image_2d_pad_constant_grad_kernel);
 }
 
 
@@ -257,6 +261,7 @@ inline void image_pad_reflect(Tensor<gpu, 4, DType> dst,
   image_2d_pad_reflect_kernel<kBaseThreadBits,
                            DType><<<dimGrid, dimBlock, 0, stream>>>(dst, src,
                                                                     padT, padL);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(image_2d_pad_reflect_kernel);
 }
 
 template <int n_bits, typename DType>
@@ -307,6 +312,7 @@ inline void image_pad_reflect_grad(Tensor<gpu, 4, DType> grad_in,
   image_2d_pad_reflect_grad_kernel<kBaseThreadBits,
                                 DType><<<dimGrid, dimBlock, 0, stream>>>(
       grad_in, grad_out, padT, padL);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(image_2d_pad_reflect_grad_kernel);
 }
 
 
@@ -365,6 +371,7 @@ inline void image_pad_edge(Tensor<gpu, 5, DType> dst,
   image_3d_pad_edge_kernel<kBaseThreadBits,
                            DType><<<dimGrid, dimBlock, 0, stream>>>(
       dst, src, padF, padT, padL);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(image_3d_pad_edge_kernel);
 }
 
 template <int n_bits, typename DType>
@@ -416,6 +423,7 @@ inline void image_pad_edge_grad(Tensor<gpu, 5, DType> grad_in,
   image_3d_pad_edge_grad_kernel<kBaseThreadBits,
                                 DType><<<dimGrid, dimBlock, 0, stream>>>(
       grad_in, grad_out, padF, padT, padL);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(image_3d_pad_edge_grad_kernel);
 }
 
 // Case 2: Constant Padding
@@ -473,6 +481,7 @@ inline void image_pad_constant(Tensor<gpu, 5, DType> dst,
   image_3d_pad_constant_kernel<kBaseThreadBits,
                                DType><<<dimGrid, dimBlock, 0, stream>>>(
       dst, src, padF, padT, padL, constant);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(image_3d_pad_constant_kernel);
 }
 
 template <int n_bits, typename DType>
@@ -515,6 +524,7 @@ inline void image_pad_constant_grad(Tensor<gpu, 5, DType> grad_in,
   image_3d_pad_constant_grad_kernel<kBaseThreadBits,
                                     DType><<<dimGrid, dimBlock, 0, stream>>>(
       grad_in, grad_out, padF, padT, padL);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(image_3d_pad_constant_grad_kernel);
 }
 
 // Case 3: Reflection Padding
@@ -578,6 +588,7 @@ inline void image_pad_reflect(Tensor<gpu, 5, DType> dst,
   image_3d_pad_reflect_kernel<kBaseThreadBits,
                            DType><<<dimGrid, dimBlock, 0, stream>>>(
       dst, src, padF, padT, padL);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(image_3d_pad_reflect_kernel);
 }
 
 template <int n_bits, typename DType>
@@ -670,6 +681,7 @@ inline void image_pad_reflect_grad(Tensor<gpu, 5, DType> grad_in,
   image_3d_pad_reflect_grad_kernel<kBaseThreadBits,
                                 DType><<<dimGrid, dimBlock, 0, stream>>>(
       grad_in, grad_out, padF, padT, padL);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(image_3d_pad_reflect_grad_kernel);
 }
 
 ////////////////////////////////////////////////////////////////////////////////
diff --git a/src/operator/roi_pooling.cu b/src/operator/roi_pooling.cu
index 0f637b0..066c2ff 100644
--- a/src/operator/roi_pooling.cu
+++ b/src/operator/roi_pooling.cu
@@ -129,6 +129,7 @@ inline void ROIPoolForward(const Tensor<gpu, 4, Dtype> &out,
   ROIPoolForwardKernel<Dtype><<<dimGrid, dimBlock, 0, stream>>>(
       count, bottom_data, spatial_scale, channels, height, width,
       pooled_height, pooled_width, bottom_rois, top_data, argmax_data);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(ROIPoolForwardKernel);
 }
 
 template<typename Dtype>
@@ -232,6 +233,7 @@ inline void ROIPoolBackwardAcc(const Tensor<gpu, 4, Dtype> &in_grad,
   ROIPoolBackwardAccKernel<Dtype><<<dimGrid, dimBlock, 0, stream>>>(
       count, top_diff, argmax_data, num_rois, spatial_scale, channels, height, width,
       pooled_height, pooled_width, bottom_diff, bottom_rois);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(ROIPoolBackwardAccKernel);
 }
 
 }  // namespace cuda
diff --git a/src/operator/svm_output.cu b/src/operator/svm_output.cu
index d950107..fa11a6c 100644
--- a/src/operator/svm_output.cu
+++ b/src/operator/svm_output.cu
@@ -62,6 +62,7 @@ inline void L1_SVM(const DType & margin,
   cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
   L1_SVMKernel<cuda::kBaseThreadBits, DType> <<<dimGrid, dimBlock, 0, stream >>>
     (margin, reg_coef, dst, label, src);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(L1_SVMKernel);
 }
 
 
@@ -98,6 +99,7 @@ inline void L2_SVM(const DType & margin,
   cudaStream_t stream = Stream<gpu>::GetStream(dst.stream_);
   L2_SVMKernel<cuda::kBaseThreadBits, DType> <<<dimGrid, dimBlock, 0, stream >>>
     (margin, reg_coef, dst, label, src);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(L2_SVMKernel);
 }
 }  // namespace mshadow
 
diff --git a/src/operator/tensor/broadcast_reduce-inl.cuh b/src/operator/tensor/broadcast_reduce-inl.cuh
index f5d0152..630fef6 100644
--- a/src/operator/tensor/broadcast_reduce-inl.cuh
+++ b/src/operator/tensor/broadcast_reduce-inl.cuh
@@ -507,6 +507,7 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req,
     <<< config.kernel_1.gridDim, config.kernel_1.blockDim, 0, stream >>>(
       config.N, req == kAddTo, big.dptr<DType>(), small.dptr<DType>(), big.shape_.get<ndim>(),
       small.shape_.get<ndim>());
+    MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1);
   } else {
 
     DType* small_dptr = small.dptr<DType>();
@@ -531,11 +532,13 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const OpReqType req,
         small.shape_.get<ndim>(), config.rshape, config.rstride, config.Mnext,
         config.kernel_1.do_transpose);
     });
+    MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel);
 
     if (config.Mnext > 1) {
       reduce_lines_kernel<Reducer, DType>
       <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>>
         (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr<DType>());
+      MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel);
     }
   }
 }
@@ -550,6 +553,7 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const
       config.N, req == kAddTo, big.dptr<DType>(), lhs.dptr<DType>(), rhs.dptr<DType>(),
       small.dptr<DType>(), big.shape_.get<ndim>(), lhs.shape_.get<ndim>(),
       rhs.shape_.get<ndim>(), small.shape_.get<ndim>());
+    MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel_M1);
   } else {
     DType* small_dptr = small.dptr<DType>();
     bool addto = (req == kAddTo);
@@ -574,12 +578,14 @@ void ReduceImpl(cudaStream_t stream, const TBlob& small, const TBlob& lhs, const
         rhs.shape_.get<ndim>(), small.shape_.get<ndim>(), config.rshape, config.lhs_shape,
         config.rhs_shape, config.rstride, config.lhs_stride, config.rhs_stride, config.Mnext,
         config.kernel_1.do_transpose);
+      MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_kernel);
     });
 
     if (config.Mnext > 1) {
       reduce_lines_kernel<Reducer, DType>
       <<< config.kernel_2.gridSize, config.kernel_2.blockSize, 0, stream >>>
         (config.N, config.Mnext, req == kAddTo, config.N, small_dptr, small.dptr<DType>());
+      MSHADOW_CUDA_POST_KERNEL_CHECK(reduce_lines_kernel);
     }
   }
 }
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 52aca09..55bb30c 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -15,9 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from __future__ import print_function
 import sys
 import os
 import time
+import multiprocessing as mp
 import unittest
 import mxnet as mx
 import numpy as np
@@ -1478,6 +1480,42 @@ def test_cross_device_autograd():
 
     assert_almost_equal(dx, x.grad.asnumpy())
 
+
+# The next two functions launch 0-thread kernels (empty divisor), an error that must be caught.
+def kernel_error_check_imperative():
+    os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine'
+    dev = mx.gpu(0)  # one context object shared by both operands
+    numer, denom = mx.nd.array([1, 2, 3], ctx=dev), mx.nd.array([], ctx=dev)
+    (numer / denom).asnumpy()  # divisor is empty; the returned value is unused
+
+def kernel_error_check_symbolic():
+    os.environ['MXNET_ENGINE_TYPE'] = 'NaiveEngine'
+    dev = mx.gpu(0)  # one context object shared by the bind target and both inputs
+    lhs = mx.sym.Variable('a')
+    rhs = mx.sym.Variable('b')
+    executor = (lhs / rhs).bind(dev, {
+        'a': mx.nd.array([1, 2, 3], ctx=dev), 'b': mx.nd.array([], ctx=dev)})
+    executor.forward()
+    executor.outputs[0].asnumpy()  # 'b' is empty; the fetched value itself is unused
+
+def test_kernel_error_checking():
+    # Run each failing-launch helper in a 'spawn'ed child (own address space, as CUDA needs),
+    # so an exception escaping a worker thread kills only the child, not the CI test run.
+    try:
+        mpctx = mp.get_context('spawn')
+    except (AttributeError, ValueError):  # py2 has no get_context(); 'spawn' may be unsupported
+        print('SKIP: python%s.%s lacks the required process fork-exec support ... ' %
+              sys.version_info[0:2], file=sys.stderr, end='')
+    else:
+        with discard_stderr():  # presumably hides the children's expected error output
+            for f in (kernel_error_check_imperative, kernel_error_check_symbolic):
+                p = mpctx.Process(target=f)
+                p.start()
+                p.join()
+                assert p.exitcode != 0,\
+                    "Expected a synchronous kernel error from %s(), none seen." % f.__name__
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.