You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/08/27 06:56:57 UTC

[GitHub] [incubator-mxnet] sxjscience opened a new issue #19026: [Bug] RTC Failed to compile

sxjscience opened a new issue #19026:
URL: https://github.com/apache/incubator-mxnet/issues/19026


   @ptrendx Let me create an issue and track it in the issue.
   
   It is reproducible by running the https://github.com/dmlc/gluon-nlp/blob/master/tests/test_attention_cell.py script on GPU with GluonNLP.
   
   Code file obtained after setting `MXNET_RTC_VERBOSE =1`:
   
   ```c++
   
   struct __align__(2) __half {
     __host__ __device__ __half() { }
     unsigned short __x;
   };
   /* Definitions of intrinsics */
   __device__ inline __half __float2half(const float f) {
     __half val;
    asm("{  cvt.rn.f16.f32 %0, %1;}\n" : "=h"(val.__x) : "f"(f));
     return val;
   }
   __device__ inline float __half2float(const __half h) {
     float val;
    asm("{  cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(h.__x));
     return val;
   }
   
   typedef __half half;
   
   template <typename DType>
   struct AccType {
     using type = DType;
   
     __device__ static inline type from(const DType& val) {
       return val;
     }
   
     __device__ static inline DType to(type val) {
       return val;
     }
   
   };
   
   template<>
   struct AccType<half> {
     using type = float;
   
     __device__ static inline type from(const half& val) {
       return __half2float(val);
     }
   
     __device__ static inline half to(type val) {
       return __float2half(val);
     }
   };
   
   
   using float32 = float;
   using float64 = double;
   using float16 = half;
   using uint8 = unsigned char;
   using int8 = char;
   using int32 = int;
   using int64 = long long;
   
   static_assert(sizeof(float32) == 4, "Size of float32 is expected to be 4B");
   static_assert(sizeof(float64) == 8, "Size of float64 is expected to be 8B");
   static_assert(sizeof(float16) == 2, "Size of float16 is expected to be 2B");
   static_assert(sizeof(uint8) == 1, "Size of uint8 is expected to be 1B");
   static_assert(sizeof(int8) == 1, "Size of int8 is expected to be 1B");
   static_assert(sizeof(int32) == 4, "Size of int32 is expected to be 4B");
   static_assert(sizeof(int64) == 8, "Size of int64 is expected to be 8B");
   
   typedef int32 index_t;
   
   // bool and int8 need to be accumulated in index_t
   // but bool needs to be treated in the special way
   // for ops like bitwise_not
   struct bool_t {
     index_t value;
   
     __device__ inline bool_t(const index_t& v) : value(v) {}
     __device__ inline bool_t(const volatile index_t& v) : value(v) {}
     __device__ inline bool_t() : value(0) {}
   
     __device__ inline operator index_t() const volatile { return value; }
     __device__ inline bool_t& operator= (const index_t& v) {
       value = v;
       return *this;
     }
     __device__ inline volatile bool_t& operator= (const index_t& v) volatile {
       value = v;
       return *this;
     }
     __device__ inline bool_t& operator= (const volatile index_t& v) {
       value = v;
       return *this;
     }
   };
   template<>
   struct AccType<bool> {
     using type = bool_t;
   
     __device__ static inline type from(const bool& val) {
       return val;
     }
   
     __device__ static inline bool to(type val) {
       return val;
     }
   };
   
   template<>
   struct AccType<int8> {
     using type = index_t;
   
     __device__ static inline type from(const int8& val) {
       return val;
     }
   
     __device__ static inline int8 to(type val) {
       return val;
     }
   };
   
   template<>
   struct AccType<uint8> {
     using type = index_t;
   
     __device__ static inline type from(const uint8& val) {
       return val;
     }
   
     __device__ static inline uint8 to(type val) {
       return val;
     }
   };
   
   namespace type_util {
   
   struct false_type {
     static constexpr bool value = false;
   };
   
   struct true_type {
     static constexpr bool value = true;
   };
   
   // is_integral
   template <typename T> struct is_integral : false_type {};
   template <> struct is_integral<uint8> : true_type {};
   template <> struct is_integral<int8>  : true_type {};
   template <> struct is_integral<int32> : true_type {};
   template <> struct is_integral<int64> : true_type {};
   template <> struct is_integral<bool>  : true_type {};
   template <> struct is_integral<bool_t>  : true_type {};
   
   // is_unsigned
   template <typename T> struct is_unsigned : false_type {};
   template <> struct is_unsigned<uint8> : true_type {};
   template <> struct is_unsigned<bool>  : true_type {};
   template <> struct is_unsigned<bool_t>  : true_type {};
   
   // is_same
   template <typename T, typename U>
   struct is_same : false_type {};
   template <typename T> struct is_same<T, T> : true_type {};
   
   // has_double
   template <typename... T> struct has_double : false_type {};
   
   template <typename A, typename... B>
   struct has_double<A, B...> {
       static constexpr bool value = is_same<A, double>::value ||
                                     has_double<B...>::value;
   };
   
   // has_double_or_integral
   template <typename... T> struct has_double_or_integral : false_type {};
   
   template <typename A, typename... B>
   struct has_double_or_integral<A, B...> {
       static constexpr bool value = is_same<A, double>::value ||
                                     is_integral<A>::value ||
                                     has_double_or_integral<B...>::value;
   };
   
   template <bool b>
   struct enable_if {};
   
   template <>
   struct enable_if<true> {
     using type = void;
   };
   
   template <typename T, typename U, class Enable = void>
   struct mixed_type;
   
   template <typename T>
   struct mixed_type<T, float64, typename enable_if<!is_same<float64, T>::value>::type> {
     using type = float64;
   };
   
   template <typename T>
   struct mixed_type<float64, T> {
     using type = float64;
   };
   
   template <typename T>
   struct mixed_type<T, float32, typename enable_if<!is_same<float64, T>::value &&
                                                    !is_same<float32, T>::value>::type> {
     using type = float32;
   };
   
   template <typename T>
   struct mixed_type<float32, T, typename enable_if<!is_same<float64, T>::value>::type> {
     using type = float32;
   };
   
   template <typename T>
   struct mixed_type<T, float16, typename enable_if<is_same<float16, T>::value ||
                                                    is_integral<T>::value>::type> {
     using type = float16;
   };
   
   template <typename T>
   struct mixed_type<float16, T, typename enable_if<is_integral<T>::value>::type> {
     using type = float16;
   };
   
   template <typename T, typename U>
   struct mixed_type<T, U, typename enable_if<is_integral<T>::value &&
                                              is_integral<U>::value &&
                                              !is_same<U, bool_t>::value &&
                                              sizeof(T) <= sizeof(U)>::type> {
     using type = U;
   };
   
   template <typename T, typename U>
   struct mixed_type<U, T, typename enable_if<is_integral<T>::value &&
                                              is_integral<U>::value &&
                                              !is_same<U, bool_t>::value &&
                                              sizeof(T) < sizeof(U)>::type> {
     using type = U;
   };
   
   template <typename T>
   struct mixed_type<T, bool_t, typename enable_if<is_integral<T>::value &&
                                                   sizeof(T) < sizeof(bool_t)>::type> {
     using type = index_t;
   };
   
   template <typename T>
   struct mixed_type<bool_t, T, typename enable_if<is_integral<T>::value &&
                                                   sizeof(T) < sizeof(bool_t)>::type> {
     using type = index_t;
   };
   
   template <typename T>
   struct mixed_type<T, bool_t, typename enable_if<is_integral<T>::value &&
                                                   sizeof(T) == sizeof(bool_t)>::type> {
     using type = T;
   };
   
   }  // namespace type_util
   
   
   enum class OpReqType {
     kNullOp,
     kWriteTo,
     kWriteInplace,
     kAddTo
   };
   
   constexpr int kRTCMaxThreadsPerBlock = 512;
   
   namespace util {
   
   constexpr int MAX_DIM = 5;
   
   template <int ndim>
   __device__ inline void unravel_dot(const index_t idx, const index_t (&shape)[MAX_DIM],
     const index_t (&stridej)[MAX_DIM], const index_t (&stridek)[MAX_DIM], index_t* j, index_t* k) {
     *j = 0;
     *k = 0;
     #pragma unroll
     for (index_t i = ndim-1, idx_t = idx; i >=0; --i) {
       const auto tmp = idx_t / shape[i];
       const auto coord = idx_t - tmp*shape[i];
       *j += coord*stridej[i];
       *k += coord*stridek[i];
       idx_t = tmp;
     }
   }
   
   template<int ndim>
   __device__ inline index_t unravel_dot(const index_t idx, const index_t (&shape)[MAX_DIM],
     const index_t (&stride)[MAX_DIM]) {
     index_t ret = 0;
     #pragma unroll
     for (index_t i = ndim-1, j = idx; i >=0; --i) {
       auto tmp = j / shape[i];
       ret += (j - tmp*shape[i])*stride[i];
       j = tmp;
     }
     return ret;
   }
   
   template<int ndim>
   __device__ inline index_t unravel_ravel(const index_t idx, const index_t (&shape1)[MAX_DIM],
                                           const index_t (&shape2)[MAX_DIM]) {
     index_t ret = 0;
     index_t total_shape = 1;
   #pragma unroll
     for (index_t i = ndim-1, j = idx; i >=0; --i) {
       if (i != ndim - 1) {
         total_shape *= shape2[i + 1];
       }
       auto tmp = j / shape1[i];
       const index_t coord = j - tmp*shape1[i];
       ret += total_shape * (shape2[i] > coord) * coord;
       j = tmp;
     }
     return ret;
   }
   
   template<int ndim, int ndim2>
   __device__ inline index_t ravel(const index_t (&coord)[ndim], const index_t (&shape)[ndim2]) {
     index_t ret = 0;
   #pragma unroll
     for (int i = 0; i < ndim; ++i) {
       ret = ret * shape[i] + (shape[i] > coord[i]) * coord[i];
     }
     return ret;
   }
   
   template<int ndim, int ndim2>
   __device__ inline void unravel(const index_t idx,
                                  const index_t (&shape)[ndim2],
                                  index_t (&coord)[ndim]) {
   #pragma unroll
     for (index_t i = ndim-1, j = idx; i >=0; --i) {
       auto tmp = j / shape[i];
       coord[i] = j - tmp*shape[i];
       j = tmp;
     }
   }
   
   template <typename DType>
   __device__ inline bool isinf(volatile const DType &val) {
     return false;
   }
   
   template <>
   __device__ inline bool isinf(volatile const float &val) {
     return ::isinf(val);
   }
   
   template <>
   __device__ inline bool isinf(volatile const double &val) {
     return ::isinf(val);
   }
   
   template <>
   __device__ inline bool isinf(volatile const long double &val) {
     return ::isinf(val);
   }
   
   template <>
   __device__ inline bool isinf(volatile const float16 &val) {
     return ::isinf(__half2float(const_cast<const float16&>(val)));
   }
   
   template <typename DType>
   __device__ inline bool isnan(volatile const DType &val) {
     return false;
   }
   
   template <>
   __device__ inline bool isnan(volatile const float &val) {
     return ::isnan(val);
   }
   
   template <>
   __device__ inline bool isnan(volatile const double &val) {
     return ::isnan(val);
   }
   
   template <>
   __device__ inline bool isnan(volatile const long double &val) {
     return ::isnan(val);
   }
   
   template <>
   __device__ inline bool isnan(volatile const float16 &val) {
     return ::isnan(__half2float(const_cast<const float16&>(val)));
   }
   
   }  // namespace util
   
   
   constexpr double DBL_MAX = 1.7976931348623157081e+308;
   
   namespace op {
   
   namespace special_functions {
   
   template<typename DType>
   __device__ inline static DType trigamma(DType x);
   
   template<>
   __device__ inline double trigamma<double>(double x) {
     double PI(3.14159265358979323846);
     double sign = +1;
     double result = 0;
     if (x < 0.5) {
       sign = -1;
       const double sin_pi_x = sin(PI * x);
       result -= (PI * PI) / (sin_pi_x * sin_pi_x);
       x = 1 - x;
     }
     for (int i = 0; i < 6; ++i) {
       result += 1 / (x * x);
       x += 1;
     }
     const double ixx = 1 / (x*x);
     result += (1 + 1 / (2*x) + ixx * (1./6 - ixx * (1./30 - ixx * (1./42)))) / x;
     return sign * result;
   }
   
   template<>
   __device__ inline float trigamma<float>(float x) {
     float PI(3.14159265358979323846);
     float sign = +1;
     float result = 0;
     if (x < 0.5f) {
       sign = -1;
       const float sin_pi_x = sinf(PI * x);
       result -= (PI * PI) / (sin_pi_x * sin_pi_x);
       x = 1 - x;
     }
     for (int i = 0; i < 6; ++i) {
       result += 1 / (x * x);
       x += 1;
     }
     const float ixx = 1 / (x*x);
     result += (1 + 1 / (2*x) + ixx * (1.f/6 - ixx * (1.f/30 - ixx * (1.f/42)))) / x;
     return sign * result;
   }
   
   struct cephes {
     /*
      * Helper to evaluate a polynomial given an array of coefficients.
      */
     template <typename DType>
     __device__ inline static DType polevl(DType x, const DType coef[], int N) {
       DType ans;
       DType const *p;
       int i;
   
       p = coef;
       ans = *p++;
   
       i = N;
       do {
         ans = ans * x  +  *p++;
       } while ( --i );
   
       return( ans );
     }
   
   
     /*
      * Helper function for psi that handles double/float specific differences
      * in the algorithm.
      */
     template<typename DType>
     __device__ inline static DType psi_helper(DType s);
   
     /*
      *
      *	Psi (digamma) function
      *
      *
      * SYNOPSIS:
      *
      * float x, y, psif();
      *
      * y = psif( x );
      *
      *
      * DESCRIPTION:
      *
      *              d      -
      *   psi(x)  =  -- ln | (x)
      *              dx
      *
      * is the logarithmic derivative of the gamma function.
      * For integer x,
      *                   n-1
      *                    -
      * psi(n) = -EUL  +   >  1/k.
      *                    -
      *                   k=1
      *
      * This formula is used for 0 < n <= 10.  If x is negative, it
      * is transformed to a positive argument by the reflection
      * formula  psi(1-x) = psi(x) + pi cot(pi x).
      * For general positive x, the argument is made greater than 10
      * using the recurrence  psi(x+1) = psi(x) + 1/x.
      * Then the following asymptotic expansion is applied:
      *
      *                           inf.   B
      *                            -      2k
      * psi(x) = log(x) - 1/2x -   >   -------
      *                            -        2k
      *                           k=1   2k x
      *
      * where the B2k are Bernoulli numbers.
      *
      * ACCURACY:
      *    Absolute error,  relative when |psi| > 1 :
      * arithmetic   domain     # trials      peak         rms
      *    IEEE      -33,0        30000      8.2e-7      1.2e-7
      *    IEEE      0,33        100000      7.3e-7      7.7e-8
      *
      * ERROR MESSAGES:
      *     message         condition      value returned
      * psi singularity    x integer <=0      MAXNUMF
      */
     template<typename DType>
     __device__ inline static DType psi(DType x) {
       DType p, q, nz, s, w, y;
       int i, n, negative;
   
       DType EUL(0.57721566490153286061);
       DType PI(3.14159265358979323846);
   
       negative = 0;
       nz = 0.0;
   
       if ( x <= 0.0 ) {
         negative = 1;
         q = x;
         p = ::floor(q);
         if ( p == q ) {
           return DBL_MAX;
         }
         /* Remove the zeros of tan(PI x)
          * by subtracting the nearest integer from x
          */
         nz = q - p;
         if ( nz != 0.5 ) {
           if ( nz > 0.5 ) {
             p += 1.0;
             nz = q - p;
           }
           nz = PI/::tan(PI*nz);
         } else {
           nz = 0.0;
         }
         x = 1.0 - x;
       }
   
       /* check for positive integer up to 10 */
       if ( (x <= 10.0) && (x == ::floor(x)) ) {
         y = 0.0;
         n = x;
         for ( i = 1; i < n; i++ ) {
           w = i;
           y += 1.0/w;
         }
         y -= EUL;
         goto done;
       }
   
       s = x;
       w = 0.0;
       while ( s < 10.0 ) {
         w += 1.0/s;
         s += 1.0;
       }
   
       y = psi_helper(s);
   
       y = logf(s)  -  (0.5/s)  -  y  -  w;
   
   done:
   
       if ( negative ) {
         y -= nz;
       }
   
       return(y);
     }
   };
   
   
   template<>
   __device__ inline double cephes::psi_helper<double>(double s) {
     double z;
     const double A[] = {
       8.33333333333333333333E-2,
       -2.10927960927960927961E-2,
       7.57575757575757575758E-3,
       -4.16666666666666666667E-3,
       3.96825396825396825397E-3,
       -8.33333333333333333333E-3,
       8.33333333333333333333E-2
     };
   
     if ( s < 1.0e17 ) {
       z = 1.0/(s * s);
       return z * cephes::polevl<double>(z, A, 6);
     } else {
       return 0.0;
     }
   }
   
   template<>
   __device__ inline float cephes::psi_helper<float>(float s) {
     float z;
     const float A[] = {
       -4.16666666666666666667E-3f,
       3.96825396825396825397E-3f,
       -8.33333333333333333333E-3f,
       8.33333333333333333333E-2f
     };
   
     if ( s < 1.0e8 ) {
       z = 1.0/(s * s);
       return z * cephes::polevl<float>(z, A, 3);
     } else {
       return 0.0;
     }
   }
   }  // namespace special_functions
   }  // namespace op
   
   
   
   namespace vector {
   
   template <int size>
   struct VectorType {
       static_assert(size <= 32, "VectorType needs to have size of at most 32B");
   };
   
   template <>
   struct VectorType<1> {
     using type = char;
   };
   
   template <>
   struct VectorType<2> {
     using type = short;
   };
   
   
   template <>
   struct VectorType<4> {
     using type = int;
   };
   
   template <>
   struct VectorType<8> {
     using type = long long;
   };
   
   template <>
   struct VectorType<16> {
     using type = ulonglong2;
   };
   
   template <>
   struct VectorType<32> {
     using type = ulonglong4;
   };
   
   template <typename DType>
   __device__ inline DType add_elem(const DType& x, const DType& y) {
     return x + y;
   }
   
   template <>
   __device__ inline half add_elem(const half& x, const half& y) {
     return __float2half(__half2float(x) + __half2float(y));
   }
   
   /* \brief Helper class that enables storing multiple values of type DType
             as 1 value of type LType.
   */
   template <typename DType, int n>
   class VectorizedStorage {
    public:
     using LType = typename VectorType<sizeof(DType) * n>::type;
     constexpr static int nvec = n;
     union vectorized_storage {
       LType aligned;
       DType separate[nvec];  // NOLINT(*)
   
       inline __device__ vectorized_storage() {}
       inline __device__ ~vectorized_storage() {}
     } scratch_;
   
     inline __device__ VectorizedStorage() {}
     inline __device__ VectorizedStorage (const VectorizedStorage<DType, n>& y2) {
         scratch_.aligned = y2.scratch_.aligned;
     }
     inline __device__ VectorizedStorage (const LType &y2) {
         scratch_.aligned = y2;
     }
     inline __device__ VectorizedStorage<DType, n>& operator+=(
         const VectorizedStorage<DType, n>& rhs) {
       #pragma unroll
       for (int i = 0; i < nvec; ++i) {
         scratch_.separate[i] = add_elem(scratch_.separate[i], rhs.scratch_.separate[i]);
       }
       return *this;
     }
     inline __device__ ~VectorizedStorage() {}
   };
   
   // Returns const LType is DType is const
   template <typename DType, typename LType>
   struct select_const {
     using type = LType;
   };
   
   template <typename DType, typename LType>
   struct select_const<const DType, LType> {
     using type = const LType;
   };
   
   template <typename DType>
   struct remove_const {
     using type = DType;
   };
   
   template <typename DType>
   struct remove_const<const DType> {
     using type = DType;
   };
   
   
   /* \brief Helper class that enables accessing multiple values of type DType
             as 1 value of type LType. Additional aligned template argument
             allows performance optimizations if the pointer and the size of
             the allocation is aligned to sizeof(LType) / sizeof(DType) elements.
   */
   template <typename DType, int nvec, bool aligned = false>
   class VectorizedAccessor {
    public:
     using StorageType = VectorizedStorage<typename remove_const<DType>::type,
                                           nvec>;
     using LType = typename select_const<DType, typename StorageType::LType>::type;
     StorageType storage_;
   
     LType* aligned_ptr_;
     DType* unaligned_ptr_;
     int alignment_;
     index_t n_elems_;
   
     inline __device__ VectorizedAccessor(DType* const ptr, const index_t size) {
       unaligned_ptr_ = ptr;
       if (aligned) {
         alignment_ = 0;
         aligned_ptr_ = reinterpret_cast<LType*>(ptr);
         n_elems_ = (size + nvec- 1) / nvec;
       } else {
         size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
         alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType);
         aligned_ptr_ = reinterpret_cast<LType*>(ptr - alignment_);
         n_elems_ = (size + alignment_ + nvec - 1) / nvec;
       }
     }
   
     /* \brief Alignment of the input pointer in elements. */
     inline __device__ int alignment() const {
       return alignment_;
     }
   
     /* \brief Access to separate elements. */
     inline __device__ DType* separate() {
       return storage_.scratch_.separate;
     }
   
     /* \brief Number of aligned elements that span the entire input tensor. */
     inline __device__ index_t num_aligned_elements() const {
       return n_elems_;
     }
   
     /* \brief Load values from the input.
        \param id Aligned index of the element.
        \param N size of the tensor.
     */
     inline __device__ void load(const index_t id, const index_t N) {
       if (aligned) {
         storage_.scratch_.aligned = aligned_ptr_[id];
       } else {
         if (id > 0 && id < n_elems_ - 1) {
           storage_.scratch_.aligned = aligned_ptr_[id];
         } else {
   #pragma unroll
           for (int j = 0; j < nvec; ++j) {
             DType* ptr = reinterpret_cast<DType*>(&(aligned_ptr_[id])) + j;
             if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(unaligned_ptr_) &&
                 reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(unaligned_ptr_ + N)) {
               storage_.scratch_.separate[j] = *ptr;
             }
           }
         }
       }
     }
   };
   
   /* \brief Class used for vectorized read-only access. */
   template <typename DType, int nvec, bool aligned = false>
   class VectorizedLoader : public VectorizedAccessor<const DType, nvec, aligned> {
    public:
     inline __device__ VectorizedLoader(const DType* ptr, const index_t N) :
       VectorizedAccessor<const DType, nvec, aligned>(ptr, N) {
     }
   };
   
   /* \brief Class used for vectorized writable access. */
   template <typename DType, int nvec, bool aligned = false>
   class VectorizedStorer : public VectorizedAccessor<DType, nvec, aligned> {
    public:
     inline __device__ VectorizedStorer(DType* ptr, const index_t N) :
       VectorizedAccessor<DType, nvec, aligned>(ptr, N) {
     }
   
     /* \brief Store values to the output.
        \param id Aligned index of the element.
        \param N size of the tensor.
     */
     inline __device__ void store(const index_t id, const index_t N) {
       if (aligned) {
         this->aligned_ptr_[id] = this->storage_.scratch_.aligned;
       } else {
         if (id > 0 && id < this->n_elems_ - 1) {
           this->aligned_ptr_[id] = this->storage_.scratch_.aligned;
         } else {
   #pragma unroll
           for (int j = 0; j < nvec; ++j) {
             DType* ptr = reinterpret_cast<DType*>(&(this->aligned_ptr_[id])) + j;
             if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(this->unaligned_ptr_) &&
                 reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(this->unaligned_ptr_ + N)) {
               *ptr = this->storage_.scratch_.separate[j];
             }
           }
         }
       }
     }
   };
   
   }  // namespace vector
   
   
   
   
   #define INT_MAX (2147483647)
   
   namespace op {
   
   template <typename DType>
   struct LoadType {
     using Type = DType;
   };
   
   template <>
   struct LoadType<half> {
     using Type = float;
   };
   
   template <typename DType>
   __device__ inline typename LoadType<DType>::Type load(const DType input) {
     return input;
   }
   
   template <>
   __device__ inline float load(const half input) {
     return __half2float(input);
   }
   
   template <typename DType1, typename DType2>
   __device__ inline DType1 store(const DType2 input, DType1* ref) {
     return input;
   }
   
   template <typename DType>
   __device__ inline half store(const DType input, half* ref) {
     return __float2half(input);
   }
   
   template <int ndim>
   struct Shape {
      int x[ndim];
      size_t size;
      __device__ inline const int& operator [](const int i) const {
          return x[i];
      }
      __device__ inline int& operator [](const int i) {
          return x[i];
      }
      __device__ inline void set(const int def) {
          #pragma unroll
          for (int i = 0; i < ndim; i++) {
              x[i] = def;
          }
      }
   };
   
   template <>
   struct Shape<0> {
      size_t size;
   };
   
   template <int nvec, typename DType, int ndim>
   __device__ inline vector::VectorizedStorage<DType, nvec> load_index(const DType * input, int i,
                                                                       const Shape<ndim> &shape) {
     using V = vector::VectorizedStorage<DType, nvec>;
     if (i < shape.size) {
       const auto* vector_input = reinterpret_cast<const typename V::LType *>(input + i);
       return V(*vector_input);
     } else {
       return V({0});
     }
   }
   
   template <int nvec, typename DType, int ndim>
   __device__ inline vector::VectorizedStorage<DType, nvec> global_load_index(const DType * input,
                       int i, const Shape<ndim> &shape) {
     using V = vector::VectorizedStorage<DType, nvec>;
     if (i < shape.size) {
       const auto* vector_input = reinterpret_cast<const typename V::LType *>(input + i);
       return V(__ldg(vector_input));
     } else {
       return V({0});
     }
   }
   
   template <int nvec, typename DType, int ndim>
   __device__ inline vector::VectorizedStorage<DType, nvec> load_slice(const DType * input,
                                                                       const Shape<ndim>& shape,
                                                                       Shape<ndim> begin,
                                                                       Shape<ndim> end,
                                                                       int offset) {
     int idx[nvec];
   
     Shape<ndim> ref_strides;
     Shape<ndim> strides;
     ref_strides[ndim-1] = 1;
     strides[ndim-1] = 1;
     #pragma unroll
     for (int dim = ndim-1; dim >=0; dim--) {
       if (begin[dim] < 0) begin[dim] = shape[dim] + begin[dim];
       if (end[dim] < 0) end[dim] = shape[dim] + end[dim];
       if (end[dim] == INT_MAX) end[dim] = shape[dim];
       if (dim > 0) {
         ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]);
         strides[dim-1] = strides[dim] * shape[dim];
       }
     }
     #pragma unroll
     for (int j = 0; j < nvec; j++) {
       idx[j] = 0;
       int ref_idx = offset + j;
       #pragma unroll
       for (int dim = 0; dim < ndim; dim++) {
          int stride = ref_strides[dim];
          if (shape[dim] > 1) {
            idx[j] += (ref_idx / stride + begin[dim]) * strides[dim];
          }
          ref_idx = ref_idx % stride;
       }
     }
     vector::VectorizedStorage<DType, nvec> ret;
     #pragma unroll
     for (int j = 0; j < nvec; j++) {
         ret.scratch_.separate[j] = *(input + idx[j]);
     }
     return ret;
   }
   
   template <int nvec, typename DType, int ndim>
   __device__ inline vector::VectorizedStorage<DType, nvec> fast_load_slice(const DType * input,
                                                                            const Shape<ndim>& shape,
                                                                            Shape<ndim> begin,
                                                                            Shape<ndim> end,
                                                                            int offset) {
     int idx = 0;
   
     Shape<ndim> ref_strides;
     Shape<ndim> strides;
     ref_strides[ndim-1] = 1;
     strides[ndim-1] = 1;
     #pragma unroll
     for (int dim = ndim-1; dim >=0; dim--) {
       if (begin[dim] < 0) begin[dim] = shape[dim] + begin[dim];
       if (end[dim] < 0) end[dim] = shape[dim] + end[dim];
       if (end[dim] == INT_MAX) end[dim] = shape[dim];
       if (dim > 0) {
         ref_strides[dim-1] = ref_strides[dim] * (end[dim] - begin[dim]);
         strides[dim-1] = strides[dim] * shape[dim];
       }
     }
     int ref_idx = offset;
     #pragma unroll
     for (int dim = 0; dim < ndim; dim++) {
        int stride = ref_strides[dim];
        if (shape[dim] > 1) {
          idx += (ref_idx / stride + begin[dim]) * strides[dim];
        }
        ref_idx = ref_idx % stride;
     }
     return global_load_index<nvec>(input, idx, shape);
   }
   
   template <int nvec, typename DType, int ndim>
   __device__ inline void store_index(const vector::VectorizedStorage<DType, nvec> value, int i,
                           DType * output, const Shape<ndim>& shape) {
     if (i < (shape.size + nvec - 1) / nvec) {
       auto vector_output = reinterpret_cast<
                             typename vector::VectorizedStorage<DType, nvec>::LType *>(output);
       vector_output[i] = value.scratch_.aligned;
     }
   }
   
   template <int nvec, typename DType, int ndim>
   __device__ inline void store_add_index(const vector::VectorizedStorage<DType, nvec> value, int i,
                               DType * output, const Shape<ndim>& shape) {
     if (i < (shape.size + nvec - 1) / nvec) {
       auto vector_output = reinterpret_cast<
                             typename vector::VectorizedStorage<DType, nvec>::LType *>(output);
       vector::VectorizedStorage<DType, nvec> ret(vector_output[i]);
       ret += value;
       vector_output[i] = ret.scratch_.aligned;
     }
   }
   
   }  // namespace op
   
   
   namespace op {
   
   template <typename DType>
   __device__ inline bool isnan(const DType val) {
     return util::isnan(val);
   }
   
   template <typename DType>
   __device__ inline bool_t isinf(const DType val) {
     return util::isinf(val);
   }
   
   template <typename DType>
   __device__ inline bool_t isposinf(const DType val) {
     return util::isinf(val) && (val > 0);
   }
   
   template <typename DType>
   __device__ inline bool_t isneginf(const DType val) {
     return util::isinf(val) && (val < 0);
   }
   
   template <typename DType>
   __device__ inline bool_t isfinite(const DType val) {
     return !op::isnan(val) && !op::isinf(val);
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   add(const DType a, const DType2 b) {
     return a + b;
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   sub(const DType a, const DType2 b) {
     return a - b;
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   rsub(const DType a, const DType2 b) {
     return b - a;
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   mul(const DType a, const DType2 b) {
     return a * b;
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   div(const DType a, const DType2 b) {
     return a / b;
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   rdiv(const DType a, const DType2 b) {
     return b / a;
   }
   
   #define DEFINE_BINARY_MATH_FUNC(name, double_version, float_version) \
   template <typename DType, typename DType2> \
   __device__ inline typename type_util::mixed_type<DType, DType2>::type \
   name (const DType a, const DType2 b) { \
     if (type_util::has_double_or_integral<DType, DType2>::value) { \
       return double_version ((double)a, (double)b); \
     } else { \
       return float_version ((float)a, (float)b); \
     } \
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   power (const DType a, const DType2 b) {
     if (type_util::has_double<DType, DType2>::value) {
       return ::pow ((double)a, (double)b); \
     } else {
       return ::powf ((float)a, (float)b);
     }
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   rpow(const DType a, const DType2 b) {
     return power(b, a);
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   max(const DType a, const DType2 b) {
     if (isnan(a)) return a;
     return a > b ? a : b;
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   fmax(const DType a, const DType2 b) {
     if (isnan(b)) return a;
     return a > b ? a : b;
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   min(const DType a, const DType2 b) {
     if (isnan(a)) return a;
     return a < b ? a : b;
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   fmin(const DType a, const DType2 b) {
     if (isnan(b)) return a;
     return a < b ? a : b;
   }
   
   DEFINE_BINARY_MATH_FUNC(hypot, ::hypot, ::hypotf)
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   mod(const DType a, const DType2 b) {
     if (b == 0) {
       return 0;
     }
     const double ad = static_cast<double>(a);
     const double bd = static_cast<double>(b);
     if (bd < 0) {
       if (ad < 0) {
         return -::fmod(-ad, -bd);
       } else {
         return ::fmod(ad, -bd) +
                (::fmod(ad, -bd) != 0 ? bd : 0);
       }
     } else {
       if (ad < 0) {
         return -::fmod(-ad, bd) +
                 (::fmod(-ad, bd) != 0 ? bd : 0);
       } else {
         return ::fmod(ad, bd);
       }
     }
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   fmod(const DType a, const DType2 b) {
     if (b == 0) {
       return 0;
     }
     return ::fmod(static_cast<double>(a), static_cast<double>(b));
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   rmod(const DType a, const DType2 b) {
     return op::mod(b, a);
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   rfmod(const DType a, const DType2 b) {
     return op::fmod(b, a);
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType equal(const DType a, const DType2 b) {
     using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
     const mixed_type real_a = a;
     const mixed_type real_b = b;
     return real_a == real_b ? 1 : 0;
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType not_equal(const DType a, const DType2 b) {
     using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
     const mixed_type real_a = a;
     const mixed_type real_b = b;
     return real_a != real_b ? 1 : 0;
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType greater(const DType a, const DType2 b) {
     using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
     const mixed_type real_a = a;
     const mixed_type real_b = b;
     return real_a > real_b ? 1 : 0;
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType greater_equal(const DType a, const DType2 b) {
     using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
     const mixed_type real_a = a;
     const mixed_type real_b = b;
     return real_a >= real_b ? 1 : 0;
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType less(const DType a, const DType2 b) {
     using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
     const mixed_type real_a = a;
     const mixed_type real_b = b;
     return real_a < real_b ? 1 : 0;
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType less_equal(const DType a, const DType2 b) {
     using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
     const mixed_type real_a = a;
     const mixed_type real_b = b;
     return real_a <= real_b ? 1 : 0;
   }
   
   template <typename DType, typename DType2>
   __device__ inline bool_t np_equal(const DType a, const DType2 b) {
     using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
     const mixed_type real_a = a;
     const mixed_type real_b = b;
     return real_a == real_b ? true : false;
   }
   
   template <typename DType, typename DType2>
   __device__ inline bool_t np_not_equal(const DType a, const DType2 b) {
     using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
     const mixed_type real_a = a;
     const mixed_type real_b = b;
     return real_a != real_b ? true : false;
   }
   
   template <typename DType, typename DType2>
   __device__ inline bool_t np_greater(const DType a, const DType2 b) {
     using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
     const mixed_type real_a = a;
     const mixed_type real_b = b;
     return real_a > real_b ? true : false;
   }
   
   template <typename DType, typename DType2>
   __device__ inline bool_t np_greater_equal(const DType a, const DType2 b) {
     using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
     const mixed_type real_a = a;
     const mixed_type real_b = b;
     return real_a >= real_b ? true : false;
   }
   
   template <typename DType, typename DType2>
   __device__ inline bool_t np_less(const DType a, const DType2 b) {
     using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
     const mixed_type real_a = a;
     const mixed_type real_b = b;
     return real_a < real_b ? true : false;
   }
   
   template <typename DType, typename DType2>
   __device__ inline bool_t np_less_equal(const DType a, const DType2 b) {
     using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
     const mixed_type real_a = a;
     const mixed_type real_b = b;
     return real_a <= real_b ? true : false;
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType logical_and(const DType a, const DType2 b) {
     return a && b ? 1 : 0;
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType logical_or(const DType a, const DType2 b) {
     return a || b ? 1 : 0;
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType logical_xor(const DType a, const DType2 b) {
     return ((a || b) && !(a && b)) ? 1 : 0;
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType copysign(const DType a, const DType2 b) {
     return (a >= 0 && b >= 0) || (a < 0 && b < 0) ? a : -a;
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType2 rcopysign(const DType a, const DType2 b) {
     return copysign(b, a);
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   lcm(const DType a, const DType2 b) {
     if (type_util::is_integral<DType>::value &&
         type_util::is_integral<DType2>::value) {
       DType A = a;
       DType2 B = b;
       // minus cases.
       if (a < 0) {
         A = -a;
       }
       if (b < 0) {
         B = -b;
       }
       // handle zero-valued cases.
       DType c;
       if (a == 0 || b == 0) {
         c = 0;
       } else {
         DType tmp;
         DType tmp_a = A;
         DType tmp_b = B;
         if (A < B) {
           tmp = A;
           A = B;
           B = tmp;
         }
         while (A % B != 0) {
           A = A % B;
           tmp = A;
           A = B;
           B = tmp;
         }
         c = tmp_a / B * tmp_b;
       }
       return c;
     } else {
       return 0;
     }
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type bitwise_xor(const DType a,
                                                                          const DType2 b) {
     using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
     const mixed_type real_a = a;
     const mixed_type real_b = b;
     return real_a ^ real_b;
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type bitwise_or(const DType a,
                                                                          const DType2 b) {
     using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
     const mixed_type real_a = a;
     const mixed_type real_b = b;
     return real_a | real_b;
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type bitwise_and(const DType a,
                                                                          const DType2 b) {
     using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
     const mixed_type real_a = a;
     const mixed_type real_b = b;
     return real_a & real_b;
   }
   
   DEFINE_BINARY_MATH_FUNC(arctan2, ::atan2, ::atan2f)
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   rarctan2(const DType a, const DType2 b) {
     return arctan2(b, a);
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   ldexp(const DType a, const DType2 b) {
     if (type_util::has_double_or_integral<DType, DType2>::value) {
       return a * ::pow(2.0, static_cast<double>(b));
     } else {
       return a * ::powf(2.0f, static_cast<float>(b));
     }
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   rldexp(const DType a, const DType2 b) {
     return ldexp(b, a);
   }
   
   #undef DEFINE_BINARY_MATH_FUNC
   
   template <typename DType, typename DType2>
   __device__ inline bool np_logical_and(const DType val, const DType2 val2) {
     return (val && val2) ? true : false;
   }
   
   template <typename DType, typename DType2>
   __device__ inline bool np_logical_or(const DType val, const DType2 val2) {
     return (val || val2) ? true : false;
   }
   
   template <typename DType, typename DType2>
   __device__ inline bool np_logical_xor(const DType val, const DType2 val2) {
     return ((val || val2) && !(val && val2)) ? true : false;
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType left(const DType left_val, const DType2 right_val) {
     return left_val;
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType2 right(const DType left_val, const DType2 right_val) {
     return right_val;
   }
   
   }  // namespace op
   
   
   namespace op {
   
   template <typename DType>
   __device__ inline DType identity(const DType val) {
     return val;
   }
   
   template <typename DType>
   __device__ inline DType negation(const DType val) {
     return -val;
   }
   
   template <typename OutType, typename DType>
   __device__ inline typename LoadType<OutType>::Type cast(const DType val) {
     return static_cast<typename LoadType<OutType>::Type>(val);
   }
   
   // activations
   
   template <typename DType>
   __device__ inline DType relu(const DType val) {
     return (isnan(val) || val > 0) ? val : 0;
   }
   
   template <typename DType>
   __device__ inline DType sigmoid(const DType val) {
     if (type_util::has_double_or_integral<DType>::value) {
       return 1./(1 + ::exp(-val));
     } else {
       return 1.f/(1 + expf(-val));
     }
   }
   
   template <typename DType>
   __device__ inline DType softrelu(const DType val) {
     if (type_util::has_double_or_integral<DType>::value) {
       return ::log(1 + ::exp(val));
     } else {
       return logf(1 + expf(val));
     }
   }
   
   template <typename DType>
   __device__ inline DType softsign(const DType val) {
     if (type_util::has_double_or_integral<DType>::value) {
       return val / (1 + fabs(val));
     } else {
       return val / (1 + fabsf(val));
     }
   }
   
   // exp and log
   
   #define DEFINE_UNARY_MATH_FUNC(name, double_version, float_version) \
   template <typename DType> \
   __device__ inline DType name (const DType a) { \
     if (type_util::has_double_or_integral<DType>::value) { \
       return double_version ((double)a); \
     } else { \
       return float_version (a); \
     } \
   }
   
   DEFINE_UNARY_MATH_FUNC(exp, ::exp, ::expf)
   DEFINE_UNARY_MATH_FUNC(expm1, ::expm1, ::expm1f)
   DEFINE_UNARY_MATH_FUNC(log, ::log, ::logf)
   DEFINE_UNARY_MATH_FUNC(log10, ::log10, ::log10f)
   DEFINE_UNARY_MATH_FUNC(log2, ::log2, ::log2f)
   DEFINE_UNARY_MATH_FUNC(log1p, ::log1p, ::log1pf)
   
   // trigonometric
   
   constexpr double pi = 3.14159265358979323846;
   
   template <typename DType>
   __device__ inline DType degrees(const DType val) {
     if (type_util::has_double_or_integral<DType>::value) {
       return (val / pi) * 180;
     } else {
       return (val / static_cast<float>(pi)) * 180.f;
     }
   }
   
   template <typename DType>
   __device__ inline DType radians(const DType val) {
     if (type_util::has_double_or_integral<DType>::value) {
       return (val / 180.0) * pi;
     } else {
       return (val / 180.0f) * static_cast<float>(pi);
     }
   }
   
   DEFINE_UNARY_MATH_FUNC(sin, ::sin, ::sinf)
   DEFINE_UNARY_MATH_FUNC(cos, ::cos, ::cosf)
   DEFINE_UNARY_MATH_FUNC(tan, ::tan, ::tanf)
   DEFINE_UNARY_MATH_FUNC(arcsin, ::asin, ::asinf)
   DEFINE_UNARY_MATH_FUNC(arccos, ::acos, ::acosf)
   DEFINE_UNARY_MATH_FUNC(arctan, ::atan, ::atanf)
   
   DEFINE_UNARY_MATH_FUNC(sinh, ::sinh, ::sinhf)
   DEFINE_UNARY_MATH_FUNC(cosh, ::cosh, ::coshf)
   DEFINE_UNARY_MATH_FUNC(tanh, ::tanh, ::tanhf)
   DEFINE_UNARY_MATH_FUNC(arcsinh, ::asinh, ::asinhf)
   DEFINE_UNARY_MATH_FUNC(arccosh, ::acosh, ::acoshf)
   DEFINE_UNARY_MATH_FUNC(arctanh, ::atanh, ::atanhf)
   
   // sqrt
   
   DEFINE_UNARY_MATH_FUNC(sqrt, ::sqrt, ::sqrtf)
   DEFINE_UNARY_MATH_FUNC(rsqrt, ::rsqrt, ::rsqrtf)
   DEFINE_UNARY_MATH_FUNC(cbrt, ::cbrt, ::cbrtf)
   DEFINE_UNARY_MATH_FUNC(rcbrt, ::rcbrt, ::rcbrtf)
   
   template <typename DType>
   __device__ inline DType square(const DType val) {
     return val * val;
   }
   
   template <typename DType, typename... DTypes>
   __device__ inline typename LoadType<DType>::Type zero(const DType val, const DTypes... args) {
     return 0;
   }
   
   template <typename DType>
   __device__ inline typename LoadType<DType>::Type zero() {
     return 0;
   }
   
   template <typename DType, typename... DTypes>
   __device__ inline typename LoadType<DType>::Type one(const DType val, const DTypes... args) {
     return 1;
   }
   
   template <typename DType>
   __device__ inline typename LoadType<DType>::Type one() {
     return 1;
   }
   
   template <typename DType, typename... DTypes>
   __device__ inline typename LoadType<DType>::Type negone(const DType val, const DTypes... args) {
     return -1;
   }
   
   template <typename DType>
   __device__ inline typename LoadType<DType>::Type negone() {
     return -1;
   }
   
   template <typename DType>
   __device__ inline DType round(const DType val) {
     if (type_util::has_double<DType>::value) {
       return ::round((double)val);
     } else if (type_util::is_integral<DType>::value) {
       return val;
     } else {
       return ::roundf(val);
     }
   }
   
   template <typename DType>
   __device__ inline DType floor(const DType val) {
     if (type_util::has_double<DType>::value) {
       return ::floor((double)val);
     } else if (type_util::is_integral<DType>::value) {
       return val;
     } else {
       return ::floorf(val);
     }
   }
   
   template <typename DType>
   __device__ inline DType ceil(const DType val) {
     if (type_util::has_double<DType>::value) {
       return ::ceil((double)val);
     } else if (type_util::is_integral<DType>::value) {
       return val;
     } else {
       return ::ceilf(val);
     }
   }
   
   template <typename DType>
   __device__ inline DType rint(const DType val) {
     if (type_util::has_double<DType>::value) {
       return ::rint((double)val);
     } else if (type_util::is_integral<DType>::value) {
       return val;
     } else {
       return ::rintf(val);
     }
   }
   
   template <typename DType>
   __device__ inline DType fix(const DType val) {
     const auto f = floor(val);
     const auto c = ceil(val);
     return (f > 0 ? f : -f) < (c > 0 ? c : -c) ? f : c;
   }
   
   template <typename DType>
   __device__ inline DType trunc(const DType val) {
     if (type_util::has_double<DType>::value) {
       return ::trunc((double)val);
     } else if (type_util::is_integral<DType>::value) {
       return val;
     } else {
       return ::truncf(val);
     }
   }
   
   template <typename DType>
   __device__ inline DType clip(const DType val, const float a_min, const float a_max) {
     return max(min(val, a_max), a_min);
   }
   
   template <typename DType>
   __device__ inline DType sign(const DType val) {
     if (val < 0) return -1;
     return val > 0 ? 1 : 0;
   }
   
   template <typename DType>
   __device__ inline DType reciprocal(const DType val) {
     return 1.0f / val;
   }
   
   DEFINE_UNARY_MATH_FUNC(abs, ::fabs, ::fabsf)
   DEFINE_UNARY_MATH_FUNC(gamma, ::tgamma, ::tgammaf)
   DEFINE_UNARY_MATH_FUNC(gammaln, ::lgamma, ::lgammaf)
   DEFINE_UNARY_MATH_FUNC(erf, ::erf, ::erff)
   DEFINE_UNARY_MATH_FUNC(erfinv, ::erfinv, ::erfinvf)
   
   template <typename DType>
   __device__ inline DType gelu(const DType val) {
     return 0.5f * val * (1.0f + op::erf(val / op::sqrt(2.0f)));
   }
   
   template <typename DType1, typename DType2>
   __device__ inline DType1 smooth_l1(const DType1 val, const DType2 scalar) {
     const auto bsq = scalar * scalar;
     const auto ibsq = 1.0f / bsq;
     if (val > ibsq) {
       return val - 0.5f * ibsq;
     } else if (val < -ibsq) {
       return -val - 0.5f * ibsq;
     } else {
       return 0.5f * val * val * bsq;
     }
   }
   
   template <typename DType>
   __device__ inline DType digamma(const DType val) {
     if (type_util::has_double_or_integral<DType>::value) {
       return special_functions::cephes::psi<double>(val);
     } else {
       return special_functions::cephes::psi<float>(val);
     }
   }
   
   template <typename DType>
   __device__ inline DType logical_not(const DType val) {
     return val != DType(0) ? DType(0) : DType(1);
   }
   
   template <typename DType>
   __device__ inline bool_t np_logical_not(const DType val) {
     return !static_cast<bool>(val);
   }
   
   #undef DEFINE_UNARY_MATH_FUNC
   
   template <typename DType>
   __device__ inline DType bitwise_not(const DType a) {
     if (type_util::is_same<DType, bool_t>::value) {
       return !a;
     } else {
       return ~static_cast<int64>(a);
     }
   }
   
   }  // namespace op
   
   
   
   
   namespace op {
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_relu(const DTypeGrad grad, const DType val) {
     if (isnan(val)) return val;
     return val > 0 ? grad : 0;
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_sigmoid(const DTypeGrad grad, const DType out) {
     return grad * out * (1 - out);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_softrelu(const DTypeGrad grad, const DType val) {
     const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
     return grad * sigmoid(v);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_softsign(const DTypeGrad grad, const DType val) {
     const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
     const auto ap1 = 1 + op::abs(v);
     return grad / (ap1 * ap1);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_abs(const DTypeGrad grad, const DType val) {
     const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
     return grad * op::sign(v);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_exp(const DTypeGrad grad, const DType val) {
     const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
     return grad * op::exp(v);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_expm1(const DTypeGrad grad, const DType val) {
     return backward_exp(grad, val);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_log(const DTypeGrad grad, const DType val) {
     return grad / val;
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_log10(const DTypeGrad grad, const DType val) {
     return grad / (val * op::log(static_cast<DTypeGrad>(10)));
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_log2(const DTypeGrad grad, const DType val) {
     return grad / (val * op::log(static_cast<DTypeGrad>(2)));
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_log1p(const DTypeGrad grad, const DType val) {
     return grad / (1 + val);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_sin(const DTypeGrad grad, const DType val) {
     const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
     return grad * op::cos(v);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_cos(const DTypeGrad grad, const DType val) {
     const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
     return -grad * op::sin(v);
   }
   
   // Uses output from tan
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_tan(const DTypeGrad grad, const DType out) {
     return grad * (out * out + 1);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_arcsin(const DTypeGrad grad, const DType val) {
     const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
     return grad / op::sqrt(1 - v*v);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_arccos(const DTypeGrad grad, const DType val) {
     const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
     return -grad / op::sqrt(1 - v*v);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_arctan(const DTypeGrad grad, const DType val) {
     return grad / (1 + val*val);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_degrees(const DTypeGrad grad, const DType /* val */) {
     return op::degrees(grad);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_radians(const DTypeGrad grad, const DType /* val */) {
     return op::radians(grad);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_sinh(const DTypeGrad grad, const DType val) {
     const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
     return grad * op::cosh(v);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_cosh(const DTypeGrad grad, const DType val) {
     const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
     return grad * op::sinh(v);
   }
   
   // Uses tanh output
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_tanh(const DTypeGrad grad, const DType out) {
     return grad * (1 - out * out);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_arcsinh(const DTypeGrad grad, const DType val) {
     const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
     return grad / op::sqrt(v * v + 1);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_arccosh(const DTypeGrad grad, const DType val) {
     const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
     return grad / op::sqrt(v * v - 1);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_arctanh(const DTypeGrad grad, const DType val) {
     return grad / (1 - val * val);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_sqrt(const DTypeGrad grad, const DType out) {
     return 0.5 * grad / out;
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_rsqrt(const DTypeGrad grad, const DType val) {
     const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
     const auto inv = 1 / v;
     return -0.5 * grad * op::sqrt(inv) * inv;
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_cbrt(const DTypeGrad grad, const DType out) {
     return grad / (3.0f * out * out);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_rcbrt(const DTypeGrad grad, const DType val) {
     const typename type_util::mixed_type<DTypeGrad, DType>::type v = val;
     const auto inv = 1 / v;
     return -1.f/3.f * grad * op::cbrt(inv) * inv;
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_square(const DTypeGrad grad, const DType val) {
     return 2 * val * grad;
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   rdiv_grad(const DType val,
             const DType2 val2) {
     return -val2 / (val * val);
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   div_grad(const DType val,
            const DType2 val2) {
     const typename type_util::mixed_type<DType, DType2>::type temp = val2;
     return op::reciprocal(temp);
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType div_rgrad(const DType val,
                                     const DType2 val2) {
     return -val / (val2 * val2);
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType mod_grad(const DType val,
                                    const DType2 val2) {
     if (type_util::is_integral<DType>::value) {
       return 0;
     } else {
       return 1;
     }
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType mod_rgrad(const DType val,
                                     const DType2 val2) {
     if (type_util::is_integral<DType>::value) {
       return 0;
     } else {
       return -op::floor(val / val2);
     }
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType rmod_grad(const DType val,
                                     const DType2 val2) {
     if (type_util::is_integral<DType>::value) {
       return 0;
     } else {
       return -op::floor(val2 / val);
     }
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   power_grad(const DType val,
              const DType2 val2) {
     return op::power(val, val2 - 1.f) * val2;
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   power_rgrad(const DType val,
               const DType2 val2) {
     const typename type_util::mixed_type<DType, DType2>::type temp = val;
     return op::power(val, val2) * op::log(temp);
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   rpower_grad(const DType val,
               const DType2 val2) {
     const typename type_util::mixed_type<DType, DType2>::type temp = val2;
     return val * op::log(temp);
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   hypot_grad_left(const DType val,
                   const DType2 val2) {
     return val / op::hypot(val, val2);
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   hypot_grad_right(const DType val,
                    const DType2 val2) {
     return val2 / op::hypot(val, val2);
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   copysign_grad(const DType val,
                 const DType2 val2) {
     return (val >= 0 && val2 >= 0) || (val < 0 && val2 < 0) ? 1 : -1;
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   arctan2_grad(const DType val,
                const DType2 val2) {
     return val2 / (val * val + val2 * val2);
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   rarctan2_grad(const DType val,
                 const DType2 val2) {
     return val / (val * val + val2 * val2);
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   arctan2_rgrad(const DType val,
                 const DType2 val2) {
     return -rarctan2_grad(val, val2);
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   ldexp_grad(const DType val,
              const DType2 val2) {
     return op::power(static_cast<DType>(2), val2);
   }
   
   template <typename DType, typename DType2>
   __device__ inline typename type_util::mixed_type<DType, DType2>::type
   rldexp_grad(const DType val,
               const DType2 val2) {
     using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
     return val2 * op::power(static_cast<mixed_type>(2), val) * op::log(static_cast<mixed_type>(2));
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_clip(const DTypeGrad grad, const DType val,
                 const float a_min, const float a_max) {
     if (val > a_max || val < a_min) {
       return 0;
     } else {
       return grad;
     }
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_reciprocal(const DTypeGrad grad, const DType val) {
     return -grad / (val * val);
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_erf(const DTypeGrad grad, const DType val) {
     using mixed_type = typename type_util::mixed_type<DTypeGrad, DType>::type;
     const mixed_type v = val;
     constexpr mixed_type my_pi = pi;
     return 2.0f / op::sqrt(my_pi) * op::exp(-(v*v)) * grad;
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_erfinv(const DTypeGrad grad, const DType val) {
     using mixed_type = typename type_util::mixed_type<DTypeGrad, DType>::type;
     constexpr mixed_type my_pi = pi;
     const mixed_type g = grad;
     const mixed_type v = val;
     return 0.5f * op::sqrt(my_pi) * op::exp(v * v) * g;
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_gamma(const DTypeGrad grad, const DType val) {
     using mixed_type = typename type_util::mixed_type<DTypeGrad, DType>::type;
     const mixed_type v = val;
     if (type_util::is_same<DTypeGrad, double>::value) {
       return grad * op::gamma(v) * op::special_functions::cephes::psi<double>(v);
     } else {
       return grad * op::gamma(v) * op::special_functions::cephes::psi<float>(v);
     }
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_gammaln(const DTypeGrad grad, const DType val) {
     using mixed_type = typename type_util::mixed_type<DTypeGrad, DType>::type;
     const mixed_type v = val;
     if (type_util::is_same<DTypeGrad, double>::value) {
       return grad * op::special_functions::cephes::psi<double>(v);
     } else {
       return grad * op::special_functions::cephes::psi<float>(v);
     }
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_digamma(const DTypeGrad grad, const DType val) {
     using mixed_type = typename type_util::mixed_type<DTypeGrad, DType>::type;
     const mixed_type v = val;
     if (type_util::is_same<DTypeGrad, double>::value) {
       return grad * op::special_functions::trigamma<double>(v);
     } else {
       return grad * op::special_functions::trigamma<float>(v);
     }
   }
   
   template <typename DType, typename DTypeGrad>
   __device__ inline typename type_util::mixed_type<DTypeGrad, DType>::type
   backward_gelu(const DTypeGrad grad, const DType val) {
     return 0.5f * (grad + grad * op::erf(val / op::sqrt(2.0f)) +
                    val * backward_erf(grad, val / op::sqrt(2.0f)) / op::sqrt(2.0f));
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType smooth_l1_grad(const DType val, const DType2 scalar) {
     auto bsq = scalar * scalar;
     auto ibsq = 1.0f / bsq;
     if (val > ibsq) {
       return 1;
     } else if (val < -ibsq) {
       return -1;
     } else {
       return bsq * val;
     }
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType2 xelu_grad(const DType val,
                                      const DType2 val2) {
     return (val > 0) ? 1 : val2;
   }
   
   template <typename DType, typename DType2>
   __device__ inline DType prelu_grad(const DType val,
                                      const DType2 val2) {
     return (val > 0) ? 0 : val;
   }
   
   }  // namespace op
   
   
   
   
   namespace red {
   
   /*! \brief sum reducer */
   struct sum {
     /*! \brief do reduction into dst */
     template<typename DType, typename DType2>
     __device__ inline static void Reduce(volatile DType& dst,  volatile DType2 src) {
       dst = op::add(dst, src);
     }
   
     /*! \brief do stable reduction into dst */
     template<typename DType, typename DType2>
     __device__ inline static void Reduce(volatile DType& dst,  volatile DType2 src,
                                          volatile DType& residual) {
       DType y = op::sub(src, residual);
       DType t = dst + y;
       if (util::isinf(t)) {
         residual = 0;
       } else {
         residual = (t - dst) - y;
       }
       dst = t;
     }
     /*! \brief combine the results of two reducers */
     template<typename DType>
     __device__ inline static void Merge(volatile DType& dst_val, volatile DType& src_val) {
       Reduce(dst_val, src_val);
     }
     /*! \brief combine the results of two reducers */
     template<typename DType>
     __device__ inline static void Merge(volatile DType& dst_val, volatile DType& dst_residual,
                                         volatile DType& src_val, volatile DType& src_residual) {
       DType t1 = dst_val + src_val;
       if (util::isinf(t1)) {
         dst_val = t1;
         dst_residual = 0;
       } else {
         DType e = t1 - dst_val;
         DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
         dst_val = t1 + t2;
         dst_residual = t2 - (dst_val - t1);
       }
     }
     /*! \brief finalize reduction result */
     template<typename DType>
     __device__ inline static void Finalize(volatile DType& dst) {}
     /*! \brief finalize reduction result */
     template<typename DType>
     __device__ inline static void Finalize(volatile DType& dst, volatile DType& none) {}
     /*!
      *\brief set the initial value during reduction
      */
     template<typename DType>
     __device__ inline static void SetInitValue(DType &initv) {
       initv = 0;
     }
     /*!
      *\brief set the initial value during reduction
      */
     template<typename DType>
     __device__ inline static void SetInitValue(DType &initv, DType &residual) {
       SetInitValue(initv);
       residual = 0;
     }
   };
   }  // namespace red
   
   
   using DType_output1 = float;
   static const int ndim_output1 = 2;
   using DType_output0 = int;
   static const int ndim_output0 = 2;
   static const int ndim_input_0 = 2;
   using DType_input_0 = int;
   static const int nvec = 1;
   
   __launch_bounds__(512)
   __global__ void FusedKernel_clip_Cast(size_t N,  const op::Shape<2> input_0_shape,  const op::Shape<2> output0_shape,  const op::Shape<2> output1_shape, DType_input_0* input_0, DType_output0* output0, DType_output1* output1) {
   
   const int tid = threadIdx.x + blockIdx.x * blockDim.x;
   for (int i = tid; i < N; i+= gridDim.x * blockDim.x) {
       int offset = i*nvec;
   
   const auto vec_input_0 = op::load_index<nvec>(input_0, offset, input_0_shape);
   vector::VectorizedStorage<DType_output0, nvec> vec_output0;
   vector::VectorizedStorage<DType_output1, nvec> vec_output1;
   for (int j = 0; j < nvec; j++ ) {
   const auto temp0 = op::load(vec_input_0.scratch_.separate[j]);
   const auto temp2 = op::clip(temp0, 0, inf);
   const auto temp4 = op::cast<float32>(temp2);
   vec_output0.scratch_.separate[j] = op::store(temp2, output0);
   vec_output1.scratch_.separate[j] = op::store(temp4, output1);
   }
   op::store_index(vec_output0, i, output0, output0_shape);
   op::store_index(vec_output1, i, output1, output1_shape);
   
   }
   }
   
   ```


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] ptrendx commented on issue #19026: [Bug] RTC Failed to compile

Posted by GitBox <gi...@apache.org>.
ptrendx commented on issue #19026:
URL: https://github.com/apache/incubator-mxnet/issues/19026#issuecomment-682250717


   Ok, I believe `inf` is generated by the ffi for `np.clip` here: https://github.com/apache/incubator-mxnet/blob/master/src/api/operator/tensor/matrix_op.cc#L52
   
   I will make PR with support for fusion of `clip` without `a_min` or `a_max` parameters tomorrow.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] sxjscience closed issue #19026: [Bug] RTC Failed to compile

Posted by GitBox <gi...@apache.org>.
sxjscience closed issue #19026:
URL: https://github.com/apache/incubator-mxnet/issues/19026


   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@mxnet.apache.org
For additional commands, e-mail: issues-help@mxnet.apache.org