Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/02/26 18:00:21 UTC

[GitHub] [tvm] JCBrouwer opened a new issue #10397: [Bug] Error: identifier “hfabs” is undefined

JCBrouwer opened a new issue #10397:
URL: https://github.com/apache/tvm/issues/10397


   ### Expected behavior
   
   It should be possible to take the absolute value of a half-precision tensor.
   
   ### Actual behavior
   
   The generated CUDA kernel fails to compile because the identifier `hfabs` is undefined:
   
   <details>
   <summary>Stack trace</summary>
   
   ```
   Traceback (most recent call last):
     File "/home/hans/code/stylegan3/test.py", line 10, in <module>
       lib = relay.build(mod)
     File "/home/hans/code/tvm/python/tvm/relay/build_module.py", line 468, in build
       graph_json, runtime_mod, params = bld_mod.build(
     File "/home/hans/code/tvm/python/tvm/relay/build_module.py", line 196, in build
       self._build(mod, target, target_host, executor, runtime, workspace_memory_pools, mod_name)
     File "tvm/_ffi/_cython/./packed_func.pxi", line 323, in tvm._ffi._cy3.core.PackedFuncBase.__call__
     File "tvm/_ffi/_cython/./packed_func.pxi", line 267, in tvm._ffi._cy3.core.FuncCall
     File "tvm/_ffi/_cython/./base.pxi", line 163, in tvm._ffi._cy3.core.CALL
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     7: TVMFuncCall
     6: tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
     5: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
     4: tvm::build(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
     3: tvm::codegen::Build(tvm::IRModule, tvm::Target)
     2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::Target), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     1: tvm::codegen::BuildCUDA(tvm::IRModule, tvm::Target)
     0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) [clone .cold]
     File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
     File "/home/hans/code/tvm/python/tvm/contrib/nvcc.py", line 183, in tvm_callback_cuda_compile
       ptx = compile_cuda(code, target_format="fatbin")
     File "/home/hans/code/tvm/python/tvm/contrib/nvcc.py", line 113, in compile_cuda
       raise RuntimeError(msg)
   RuntimeError: #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
   #include <cuda_fp16.h>
   __device__ half max(half a, half b)
   {
     return __hgt(__half(a), __half(b)) ? a : b;
   }
   __device__ half min(half a, half b)
   {
     return __hlt(__half(a), __half(b)) ? a : b;
   }
   #else
   
   typedef unsigned short uint16_t;
   typedef unsigned char uint8_t;
   typedef signed char int8_t;
   typedef int int32_t;
   typedef unsigned long long uint64_t;
   typedef unsigned int uint32_t;
   
   #define TVM_FORCE_INLINE inline __attribute__((always_inline))
   #define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__
   #define TVM_ALIGNED(x) __attribute__ ((aligned(x)))
   #define TVM_HALF_OPERATOR(RTYPE, OP)                              \
     TVM_XINLINE RTYPE operator OP (half a, half b) {                \
       return RTYPE(float(a) OP float(b));                           \
     }                                                               \
     template<typename T>                                            \
     TVM_XINLINE RTYPE operator OP (half a, T b) {                   \
       return RTYPE(float(a) OP float(b));                           \
     }                                                               \
     template<typename T>                                            \
     TVM_XINLINE RTYPE operator OP (T a, half b) {                   \
       return RTYPE(float(a) OP float(b));                           \
     }
   
   #define TVM_HALF_ASSIGNOP(AOP, OP)                                \
     template<typename T>                                            \
     TVM_XINLINE half operator AOP (const T& a) {                    \
       return *this = half(float(*this) OP float(a));                \
     }                                                               \
     template<typename T>                                            \
     TVM_XINLINE half operator AOP (const volatile T& a) volatile {  \
       return *this = half(float(*this) OP float(a));                \
     }
   
   class TVM_ALIGNED(2) half {
    public:
     uint16_t half_;
   
     static TVM_XINLINE half Binary(uint16_t value) {
       half res;
       res.half_ = value;
       return res;
     }
   
     TVM_XINLINE half() {}
   
     TVM_XINLINE half(const float& value) { constructor(value); }
     TVM_XINLINE explicit half(const double& value) { constructor(value); }
     TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }
     TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
     TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
     TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
     TVM_XINLINE explicit half(const long long& value) { constructor(value); }
     TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }
   
     TVM_XINLINE operator float() const {                          \
       return float(half2float(half_));                            \
     }                                                             \
     TVM_XINLINE operator float() const volatile {                 \
       return float(half2float(half_));                            \
     }
   
   
     TVM_HALF_ASSIGNOP(+=, +)
     TVM_HALF_ASSIGNOP(-=, -)
     TVM_HALF_ASSIGNOP(*=, *)
     TVM_HALF_ASSIGNOP(/=, /)
   
     TVM_XINLINE half operator+() {
       return *this;
     }
   
     TVM_XINLINE half operator-() {
       return half(-float(*this));
     }
   
     TVM_XINLINE half operator=(const half& a) {
       half_ = a.half_;
       return a;
     }
   
     template<typename T>
     TVM_XINLINE half operator=(const T& a) {
       return *this = half(a);
     }
   
     TVM_XINLINE half operator=(const half& a) volatile {
       half_ = a.half_;
       return a;
     }
   
     template<typename T>
     TVM_XINLINE half operator=(const T& a) volatile {
       return *this = half(a);
     }
   
    private:
     union Bits {
       float f;
       int32_t si;
       uint32_t ui;
     };
   
     static int const fp16FractionBits = 10;
     static int const fp32FractionBits = 23;
     static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits);   // == 0x7fffff
     static int32_t const fp32HiddenBit = 1 << fp32FractionBits;   // == 0x800000
     static int const shift = fp32FractionBits - fp16FractionBits;   // == 13
     static int const shiftSign = 16;
     static int32_t const expAdjust = 127 - 15;   // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)
   
     static int32_t const infN = 0x7F800000;   // flt32 infinity
     static int32_t const maxN = 0x477FFFFF;   // max flt32 that's a flt16 normal after >> by shift
     static int32_t const minN = 0x38800000;   // min flt16 normal as a flt32
     static int32_t const maxZ = 0x33000000;   // max fp32 number that's still rounded to zero in fp16
     static int32_t const signN = 0x80000000;  // flt32 sign bit
   
     static int32_t const infC = infN >> shift;
     static int32_t const nanN = (infC + 1) << shift;   // minimum flt16 nan as a flt32
     static int32_t const maxC = maxN >> shift;
     static int32_t const minC = minN >> shift;
     static int32_t const signC = signN >> shiftSign;  // flt16 sign bit
   
     static int32_t const mulN = 0x52000000;  // (1 << 23) / minN
     static int32_t const mulC = 0x33800000;  // minN / (1 << (23 - shift))
   
     static int32_t const subC = 0x003FF;  // max flt32 subnormal down shifted
     static int32_t const norC = 0x00400;  // min flt32 normal down shifted
   
     static int32_t const maxD = infC - maxC - 1;
     static int32_t const minD = minC - subC - 1;
   
     TVM_XINLINE uint16_t float2half(const float& value) const {
       Bits v;
       v.f = value;
       uint32_t sign = v.si & signN;    // grab sign bit
       v.si ^= sign;                    // clear sign bit from v
       sign >>= shiftSign;              // logical shift sign to fp16 position
   
       if (v.si <= maxZ) {
         // Handle eventual zeros here to ensure
         // vshift will not exceed 32 below.
         v.ui = 0;
       } else if (v.si < minN) {
         // Handle denorms
         uint32_t exp32 = v.ui >> fp32FractionBits;
         int32_t exp16 = exp32 - expAdjust;
         // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
         // Smaller (so negative) exp16 values should result in greater right shifts.
         uint32_t vshift = 1 - exp16;
         uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
         v.ui = significand >> vshift;
         v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
       } else if (v.si <= maxN) {
         // Handle norms
         v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
         v.ui -= expAdjust << fp32FractionBits;
       } else if (v.si <= infN) {
         v.si = infN;
       } else if (v.si < nanN) {
         v.si = nanN;
       }
   
       v.ui >>= shift;
       return sign | (v.ui & 0x7fff);
     }
   
     // Same as above routine, except for addition of volatile keyword
     TVM_XINLINE uint16_t float2half(
       const volatile float& value) const volatile {
       Bits v;
       v.f = value;
       uint32_t sign = v.si & signN;    // grab sign bit
       v.si ^= sign;                    // clear sign bit from v
       sign >>= shiftSign;              // logical shift sign to fp16 position
   
       if (v.si <= maxZ) {
         // Handle eventual zeros here to ensure
         // vshift will not exceed 32 below.
         v.ui = 0;
       } else if (v.si < minN) {
         // Handle denorms
         uint32_t exp32 = v.ui >> fp32FractionBits;
         int32_t exp16 = exp32 - expAdjust;
         // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
         // Smaller (so negative) exp16 values should result in greater right shifts.
         uint32_t vshift = 1 - exp16;
         uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
         v.ui = significand >> vshift;
         v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
       } else if (v.si <= maxN) {
         // Handle norms
         v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
         v.ui -= expAdjust << fp32FractionBits;
       } else if (v.si <= infN) {
         v.si = infN;
       } else if (v.si < nanN) {
         v.si = nanN;
       }
   
       v.ui >>= shift;
       return sign | (v.ui & 0x7fff);
     }
   
     TVM_XINLINE float half2float(const uint16_t& value) const {
       Bits v;
       v.ui = value;
       int32_t sign = v.si & signC;
       v.si ^= sign;
       sign <<= shiftSign;
       v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
       v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
       Bits s;
       s.si = mulC;
       s.f *= v.si;
       int32_t mask = -(norC > v.si);
       v.si <<= shift;
       v.si ^= (s.si ^ v.si) & mask;
       v.si |= sign;
       return v.f;
     }
   
     TVM_XINLINE float half2float(
       const volatile uint16_t& value) const volatile {
       Bits v;
       v.ui = value;
       int32_t sign = v.si & signC;
       v.si ^= sign;
       sign <<= shiftSign;
       v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
       v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
       Bits s;
       s.si = mulC;
       s.f *= v.si;
       int32_t mask = -(norC > v.si);
       v.si <<= shift;
       v.si ^= (s.si ^ v.si) & mask;
       v.si |= sign;
       return v.f;
     }
   
     template<typename T>
     TVM_XINLINE void constructor(const T& value) {
       half_ = float2half(float(value));
     }
   };
   
   TVM_HALF_OPERATOR(half, +)
   TVM_HALF_OPERATOR(half, -)
   TVM_HALF_OPERATOR(half, *)
   TVM_HALF_OPERATOR(half, /)
   TVM_HALF_OPERATOR(bool, >)
   TVM_HALF_OPERATOR(bool, <)
   TVM_HALF_OPERATOR(bool, >=)
   TVM_HALF_OPERATOR(bool, <=)
   
   TVM_XINLINE half __float2half_rn(const float a) {
     return half(a);
   }
   #endif
   
   
   // Pack two half values.
   static inline __device__ __host__ unsigned
   __pack_half2(const half x, const half y) {
     unsigned v0 = *((unsigned short *)&x);
     unsigned v1 = *((unsigned short *)&y);
     return (v1 << 16) | v0;
   }
   
   // Some fp16 math functions are not supported in cuda_fp16.h,
   // so we define them here to make sure the generated CUDA code
   // is valid.
   #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
   #define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \
   static inline __device__ __host__ half HALF_MATH_NAME(half x, half y) {   \
     float tmp_x = __half2float(x);                                          \
     float tmp_y = __half2float(y);                                          \
     float result = FP32_MATH_NAME(tmp_x, tmp_y);                            \
     return __float2half(result);                                            \
   }
   
   #define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \
   static inline __device__ __host__ half HALF_MATH_NAME(half x) {          \
     float tmp_x = __half2float(x);                                         \
     float result = FP32_MATH_NAME(tmp_x);                                  \
     return __float2half(result);                                           \
   }
   
   CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)
   CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)
   CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)
   CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)
   CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)
   
   #undef CUDA_UNSUPPORTED_HALF_MATH_BINARY
   #undef CUDA_UNSUPPORTED_HALF_MATH_UNARY
   
   #endif
   
   #ifdef _WIN32
     using uint = unsigned int;
     using uchar = unsigned char;
     using ushort = unsigned short;
     using int64_t = long long;
     using uint64_t = unsigned long long;
   #else
     #define uint unsigned int
     #define uchar unsigned char
     #define ushort unsigned short
     #define int64_t long long
     #define uint64_t unsigned long long
   #endif
   extern "C" __global__ void __launch_bounds__(1024) tvmgen_default_fused_abs_kernel0(half* __restrict__ T_abs, half* __restrict__ placeholder) {
     uint1 _1;
     uint1 _2 = ((uint1*)(placeholder + (((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 2)))))[0];
     ((half2*)(&(_1.x)))->x = hfabs(((half2*)(&(_2.x)))->x);
     ((half2*)(&(_1.x)))->y = hfabs(((half2*)(&(_2.x)))->y);
     ((uint1*)(T_abs + (((((int)blockIdx.x) * 2048) + (((int)threadIdx.x) * 2)))))[0] = _1;
   }
   
   
   Compilation error:
   /tmp/tmpvsm1thf9/my_kernel.cu(328): error: identifier "hfabs" is undefined
   
   /tmp/tmpvsm1thf9/my_kernel.cu(276): warning: function "__pack_half2" was declared but never referenced
   
   /tmp/tmpvsm1thf9/my_kernel.cu(301): warning: function "hpow" was declared but never referenced
   
   /tmp/tmpvsm1thf9/my_kernel.cu(302): warning: function "htanh" was declared but never referenced
   
   /tmp/tmpvsm1thf9/my_kernel.cu(303): warning: function "htan" was declared but never referenced
   
   /tmp/tmpvsm1thf9/my_kernel.cu(304): warning: function "hatan" was declared but never referenced
   
   /tmp/tmpvsm1thf9/my_kernel.cu(305): warning: function "herf" was declared but never referenced
   
   1 error detected in the compilation of "/tmp/tmpvsm1thf9/my_kernel.cu".
   
   
   During handling of the above exception, another exception occurred:
   
   Traceback (most recent call last):
     File "/home/hans/code/stylegan3/test.py", line 10, in <module>
       lib = relay.build(mod)
     File "/home/hans/code/tvm/python/tvm/target/target.py", line 130, in __exit__
       _ffi_api.TargetExitScope(self)
     File "tvm/_ffi/_cython/./packed_func.pxi", line 323, in tvm._ffi._cy3.core.PackedFuncBase.__call__
     File "tvm/_ffi/_cython/./packed_func.pxi", line 257, in tvm._ffi._cy3.core.FuncCall
     File "tvm/_ffi/_cython/./packed_func.pxi", line 246, in tvm._ffi._cy3.core.FuncCall3
     File "tvm/_ffi/_cython/./base.pxi", line 163, in tvm._ffi._cy3.core.CALL
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     2: TVMFuncCall
     1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::Target)>::AssignTypedLambda<void (*)(tvm::Target)>(void (*)(tvm::Target), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
     0: tvm::Target::ExitWithScope()
     File "/home/hans/code/tvm/src/target/target.cc", line 603
   TVMError: 
   ---------------------------------------------------------------
   An error occurred during the execution of TVM.
   For more information, please see: https://tvm.apache.org/docs/errors.html
   ---------------------------------------------------------------
     Check failed: (entry->context_stack.top().same_as(*this)) is false: 
   ```
   
   </details>
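   
   Note that the generated preamble above instantiates `CUDA_UNSUPPORTED_HALF_MATH_UNARY` for `htanh`, `htan`, `hatan`, and `herf`, but never for `hfabs`, even though the kernel body calls it. For illustration only (a hypothetical sketch, not the actual fix), the definition the kernel is missing would look like the expansion of that same macro:
   
   ```cuda
   // Hypothetical sketch of the definition the kernel needs, following the
   // same pattern as the hpow/htanh helpers in the generated preamble above
   // (__half2float/__float2half come from cuda_fp16.h):
   static inline __device__ __host__ half hfabs(half x) {
     float tmp_x = __half2float(x);   // widen to fp32
     float result = fabsf(tmp_x);     // fp32 absolute value
     return __float2half(result);     // narrow back to fp16
   }
   ```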
   
   ### Environment
   
   TVM 0.9.dev0 (compiled from source)
   CUDA 11.4
   
   ### Steps to reproduce
   
   ```python
   import tvm
   from tvm import relay
   
   x = relay.var('x', shape=(32, 128), dtype='float16')
   func = relay.Function([x], relay.abs(x))
   mod = tvm.IRModule.from_expr(func)
   
   with tvm.transform.PassContext(opt_level=3):
       with tvm.target.Target('cuda'):
           lib = relay.build(mod)
   ```
   




[GitHub] [tvm] masahi edited a comment on issue #10397: [Bug] Error: identifier “hfabs” is undefined

Posted by GitBox <gi...@apache.org>.
masahi edited a comment on issue #10397:
URL: https://github.com/apache/tvm/issues/10397#issuecomment-1052831955


   You can use this change https://github.com/apache/tvm/commit/c3c57c60a97761719d43781ae61bcc85ce6d11c0
   
   I'm not sure whether we support fp16 CUDA math functions properly in general, but for `habs` it should be correct. You can add my change to your PR https://github.com/apache/tvm/pull/10396
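   
   For reference, on compute capability 5.3 and newer, `cuda_fp16.h` ships a native `__habs` intrinsic, so a helper of roughly this shape suffices (a sketch only; the linked commit may differ in its details):
   
   ```cuda
   #include <cuda_fp16.h>
   
   // Sketch only, not necessarily the linked commit verbatim: wrap the
   // native fp16 abs intrinsic available on sm_53+ in recent CUDA toolkits.
   static inline __device__ half habs(half x) {
     return __habs(x);
   }
   ```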






[GitHub] [tvm] masahi closed issue #10397: [Bug] Error: identifier “hfabs” is undefined

Posted by GitBox <gi...@apache.org>.
masahi closed issue #10397:
URL: https://github.com/apache/tvm/issues/10397


   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org