Posted to commits@tvm.apache.org by jc...@apache.org on 2021/03/06 10:25:04 UTC

[tvm] branch main updated: [CUDA] BF16 support (#7014)

This is an automated email from the ASF dual-hosted git repository.

jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8aa2a7c  [CUDA] BF16 support (#7014)
8aa2a7c is described below

commit 8aa2a7cdbc81a0633b1f78ab28f31921e9fa9e98
Author: Huang, Guangtai <hg...@foxmail.com>
AuthorDate: Sat Mar 6 18:24:49 2021 +0800

    [CUDA] BF16 support (#7014)
---
 include/tvm/runtime/data_type.h                   |  9 ++-
 python/tvm/contrib/nvcc.py                        | 16 ++++-
 python/tvm/runtime/ndarray.py                     |  4 +-
 src/target/source/codegen_cuda.cc                 | 72 ++++++++++++++++++++++-
 src/target/source/codegen_cuda.h                  |  4 +-
 src/target/source/intrin_rule_cuda.cc             |  2 +
 src/target/source/literal/cuda_half_t.h           | 24 ++++++++
 tests/python/unittest/test_target_codegen_cuda.py | 50 +++++++++++++++-
 8 files changed, 174 insertions(+), 7 deletions(-)

diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index d705be6..7d914ce 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -160,13 +160,20 @@ class DataType {
    */
   static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); }
   /*!
-   * \brief Construct an uint type.
+   * \brief Construct a float type.
    * \param bits The number of bits in the type.
    * \param lanes The number of lanes
    * \return The constructed data type.
    */
   static DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); }
   /*!
+   * \brief Construct a bfloat type.
+   * \param bits The number of bits in the type.
+   * \param lanes The number of lanes
+   * \return The constructed data type.
+   */
+  static DataType BFloat(int bits, int lanes = 1) { return DataType(kDLBfloat, bits, lanes); }
+  /*!
    * \brief Construct a bool type.
    * \param lanes The number of lanes
    * \return The constructed data type.
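
For reference, the new constructor surfaces on the Python side as a plain
dtype string parsed into the kDLBfloat type code. A minimal sketch (assuming
tvm.runtime.DataType and its bits/lanes fields, which are not part of this
diff):

    # Sketch: Python-level view of the new bfloat16 dtype; assumes
    # tvm.runtime.DataType parses "bfloat16xN" strings as in TVM of this era.
    from tvm.runtime import DataType

    t = DataType("bfloat16x4")
    print(t.bits, t.lanes)  # expected: 16 4
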
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index 2a97b0b..f33603b 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -302,8 +302,22 @@ def have_tensorcore(compute_version=None, target=None):
             major, minor = compute_version.split("_")[1]
             compute_version = major + "." + minor
     major, _ = parse_compute_version(compute_version)
+    if major >= 7:
+        return True
+
+    return False
+
+
+def have_bf16(compute_version):
+    """Either bf16 support is provided in the compute capability or not
 
-    if major == 7:
+    Parameters
+    ----------
+    compute_version : str
+        compute capability of a GPU (e.g. "8.0")
+    """
+    major, _ = parse_compute_version(compute_version)
+    if major >= 8:
         return True
 
     return False
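
have_bf16 mirrors have_fp16/have_int8 and is meant as a runtime guard; the
test added below uses it exactly this way. A minimal sketch:

    # Sketch: skip bf16 paths unless the GPU is compute capability >= 8.0.
    import tvm
    from tvm.contrib.nvcc import have_bf16

    if not have_bf16(tvm.gpu(0).compute_version):
        print("skip because gpu does not support bf16")
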
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 75da3d4..5c60515 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -148,7 +148,9 @@ class NDArray(NDArrayBase):
                     source_array.shape, shape
                 )
             )
-        source_array = np.ascontiguousarray(source_array, dtype=dtype)
+        source_array = np.ascontiguousarray(
+            source_array, dtype="uint16" if dtype == "bfloat16" else dtype
+        )
         assert source_array.flags["C_CONTIGUOUS"]
         data = source_array.ctypes.data_as(ctypes.c_void_p)
         nbytes = ctypes.c_size_t(source_array.size * source_array.dtype.itemsize)
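
Since numpy has no native bfloat16 dtype, bf16 payloads travel through
NDArray.copyfrom as uint16 bit patterns. A minimal numpy round trip under
that convention (sketch; truncation only, while the test below rounds to
nearest even):

    # Sketch: float32 <-> bfloat16 bit patterns, truncating the low 16
    # mantissa bits on the way down and zero-filling on the way back up.
    import numpy as np

    f32 = np.float32([1.5, 2.25])
    bf16_bits = np.right_shift(f32.view("<u4"), 16).astype("uint16")
    back = np.left_shift(bf16_bits.astype("uint32"), 16).view("<f4")
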
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 2e9baba..e54acd2 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -61,6 +61,18 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << _cuda_half_util;
   }
 
+  if (enable_bf16_) {
+    decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)\n";
+    decl_stream << "#include <cuda_bf16.h>\n";
+    decl_stream << "__device__ nv_bfloat16 max"
+                << "(nv_bfloat16 a, nv_bfloat16 b)\n"
+                << "{\n  return __hgt(a, b) ? a : b;\n}\n";
+    decl_stream << "__device__ nv_bfloat16 min(nv_bfloat16 a, nv_bfloat16 b)\n"
+                << "{\n  return __hlt(a, b) ? a : b;\n}\n";
+    decl_stream << "#endif\n\n";
+    decl_stream << _cuda_bfloat16_util;
+  }
+
   if (enable_warp_shuffle_) {
     decl_stream << _cuda_warp_intrinsic_util;
   }
@@ -170,6 +182,17 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
       os << lanes;
       return;
     }
+  } else if (t.is_bfloat16()) {
+    enable_bf16_ = true;
+    if (t.is_scalar()) {
+      os << "nv_bfloat16";
+    } else if (lanes <= 8) {
+      ICHECK_EQ(lanes % 2, 0) << "only support even lane for bfloat16 type";
+      os << "uint" << lanes / 2;
+    } else {
+      fail = true;
+    }
+    if (!fail) return;
   } else if (t == DataType::Bool()) {
     os << "bool";
     return;
@@ -382,6 +405,8 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
     }
   } else if (t.is_float16()) {
     os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
+  } else if (t.is_bfloat16()) {
+    os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
   } else if (t.lanes() > 4 && t.lanes() <= 8) {
     std::string type_name;
     if (t.bits() == 16) {
@@ -427,6 +452,9 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
   } else if (t.is_float16()) {
     stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
            << value << ";\n";
+  } else if (t.is_bfloat16()) {
+    stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2]
+           << " = " << value << ";\n";
   } else if (t.lanes() > 4 && t.lanes() <= 8) {
     std::string type_name;
     if (t.bits() == 16) {
@@ -687,7 +715,8 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
     if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
       ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
              op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) ||
-             op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1))
+             op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1) ||
+             op->dtype == DataType::BFloat(16))
           << "Matrix_a and matrix_b only support half or char or unsigned char "
           << "or uint4 or int4 or int1 type for now";
     } else {
@@ -767,6 +796,19 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NO
     return;
   }
 
+  if (op->dtype.is_bfloat16()) {
+    std::string v = PrintExpr(op->value);
+    os << "make_";
+    PrintType(op->dtype, os);
+    os << '(';
+    for (int i = 0; i < op->lanes / 2; ++i) {
+      if (i != 0) os << ", ";
+      os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
+    }
+    os << ')';
+    return;
+  }
+
   std::string v = PrintExpr(op->value);
   os << "make_";
   PrintType(op->dtype, os);
@@ -836,6 +878,13 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) {
 }
 
 inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) {  // NOLINT(*)
+  // Type code is kBFloat
+  if (op->dtype.is_bfloat16()) {
+    os << "__float2bfloat16_rn";
+    os << '(' << std::scientific << op->value << 'f' << ')';
+    return;
+  }
+  // Type code is kFloat
   switch (op->dtype.bits()) {
     case 64:
     case 32: {
@@ -938,7 +987,7 @@ void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const LoadNode*
   // Cast away volatile qualifier for fp16 types. That is, only loads and
   // stores are volatile. The loaded objects are not marked as volatile.
   //
-  if (op->dtype.is_float16() && IsVolatile(op->buffer_var.get())) {
+  if ((op->dtype.is_float16() || op->dtype.is_bfloat16()) && IsVolatile(op->buffer_var.get())) {
     os << "(";
     PrintType(op->dtype, os);
     os << ")(" << value << ")";
@@ -979,6 +1028,25 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
     return;
   }
 
+  if (t.is_bfloat16()) {
+    if (i == 0) {
+      os << "make_";
+      PrintType(t, os);
+      os << '(';
+    }
+    if (i % 2 == 0) {
+      os << "__pack_bfloat162(" << value;
+    } else {
+      os << "," << value << ")";
+      if (i != t.lanes() - 1) {
+        os << ",";
+      } else {
+        os << ")";
+      }
+    }
+    return;
+  }
+
   if (i == 0) {
     os << "make_";
     PrintType(t, os);
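
At the vector level, a bfloat16xN value is stored as N/2 packed 32-bit words
(uint1/uint2/uint4) with element access through nv_bfloat162, so broadcasts
and vector loads are assembled with __pack_nv_bfloat162. What that helper
computes, emulated in Python (sketch):

    # Sketch: __pack_nv_bfloat162 packs two 16-bit bf16 payloads into one
    # 32-bit word, first argument in the low half ((v1 << 16) | v0).
    def pack_bf16x2(x_bits: int, y_bits: int) -> int:
        return ((y_bits & 0xFFFF) << 16) | (x_bits & 0xFFFF)

    assert pack_bf16x2(0x3FC0, 0x4010) == 0x40103FC0  # 1.5 and 2.25 packed
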
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 3cde8e3..2098b8a 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -42,7 +42,7 @@ class CodeGenCUDA final : public CodeGenC {
   void Init(bool output_ssa);
   std::string Finish();
   bool need_include_path() {
-    return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
+    return (enable_fp16_ || enable_bf16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
   }
   // override behavior
   void PrintFuncPrefix() final;
@@ -88,6 +88,8 @@ class CodeGenCUDA final : public CodeGenC {
   std::string vid_global_barrier_expect_;
   // whether enable fp16
   bool enable_fp16_{false};
+  // whether enable bf16
+  bool enable_bf16_{false};
   // whether enable int8
   bool enable_int8_{false};
   // whether enable warp shuffle intrinsics
diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc
index 5c562f7..965b86c 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -43,6 +43,8 @@ struct CUDAMath {
         default:
           return "";
       }
+    } else if (t.is_bfloat16()) {
+      return 'h' + name;
     }
     return "";
   }
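
The new branch reuses the 'h'-prefixed half-precision intrinsic names for
bf16. The resulting name selection, sketched in Python:

    # Sketch of CUDAMath's dispatch after this change: float64 keeps the
    # plain name, float32 appends "f", and fp16/bf16 get the "h" prefix.
    def cuda_math_name(name: str, dtype: str) -> str:
        if dtype in ("float16", "bfloat16"):
            return "h" + name  # e.g. "sqrt" -> "hsqrt"
        if dtype == "float32":
            return name + "f"  # e.g. "sqrt" -> "sqrtf"
        return name            # float64
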
diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h
index f8e92d5..3888f3a 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -311,6 +311,30 @@ static inline __device__ __host__ half htanh(half x) {
 #endif
 )";
 
+static constexpr const char* _cuda_bfloat16_util = R"(
+// Pack two bfloat16 values.
+static inline __device__ __host__ unsigned
+__pack_nv_bfloat162(const nv_bfloat16 x, const nv_bfloat16 y) {
+  unsigned v0 = *((unsigned short *)&x);
+  unsigned v1 = *((unsigned short *)&y);
+  return (v1 << 16) | v0;
+}
+
+// fix undefined bf16 math function
+static inline __device__ __host__ nv_bfloat16 hpow(nv_bfloat16 x, nv_bfloat16 y) {
+  float tmp_x = __bfloat162float(x);
+  float tmp_y = __bfloat162float(y);
+  float result = powf(tmp_x, tmp_y);
+  return __float2bfloat16(result);
+}
+
+static inline __device__ __host__ nv_bfloat16 htanh(nv_bfloat16 x) {
+  float tmp_x = __bfloat162float(x);
+  float result = tanhf(tmp_x);
+  return __float2bfloat16(result);
+}
+)";
+
 static constexpr const char* _cuda_warp_intrinsic_util = R"(
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)
 #define __shfl_sync(mask, var, lane, width) \
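
CUDA provides no native bf16 pow/tanh, so hpow/htanh above widen to float32,
call the fp32 routine, and narrow back. The same fallback in numpy terms
(sketch; truncating where __float2bfloat16 rounds to nearest):

    # Sketch: the widen -> fp32 math -> narrow fallback used by hpow/htanh.
    import numpy as np

    def bf16_tanh(bits):
        f = np.left_shift(bits.astype("uint32"), 16).view("<f4")   # widen
        r = np.tanh(f)                                             # fp32 math
        return np.right_shift(r.view("<u4"), 16).astype("uint16")  # narrow
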
diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py
index a228a64..06d7cb4 100644
--- a/tests/python/unittest/test_target_codegen_cuda.py
+++ b/tests/python/unittest/test_target_codegen_cuda.py
@@ -19,7 +19,7 @@ from tvm import te
 import numpy as np
 from tvm import topi
 import unittest
-from tvm.contrib.nvcc import have_fp16, have_int8
+from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
 from tvm.contrib import nvcc
 import tvm.testing
 
@@ -69,6 +69,53 @@ def test_cuda_vectorize_add():
 
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
+def test_cuda_bf16_vectorize_add():
+    if not have_bf16(tvm.gpu(0).compute_version):
+        print("skip because gpu does not support bf16")
+        return
+    num_thread = 8
+
+    def np_float2np_bf16(arr):
+        """Convert a numpy array of float to a numpy array
+        of bf16 in uint16"""
+        orig = arr.view("<u4")
+        bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
+        return np.right_shift(orig + bias, 16).astype("uint16")
+
+    def np_bf162np_float(arr):
+        """Convert a numpy array of bf16 (uint16) to a numpy array
+        of float"""
+        u32 = np.left_shift(arr.astype("uint32"), 16)
+        return u32.view("<f4")
+
+    def check_cuda(n, lanes):
+        A = te.placeholder((n,), name="A", dtype="bfloat16x%d" % lanes)
+        B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, A.dtype), name="B")
+        s = te.create_schedule(B.op)
+        xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
+        s[B].bind(xo, bx)
+        s[B].bind(xi, tx)
+        with tvm.transform.PassContext(
+            disabled_pass=["tir.BF16Promote", "tir.BF16CastElimination", "tir.BF16TypeLowering"]
+        ):
+            fun = tvm.build(s, [A, B], "cuda")
+        ctx = tvm.gpu(0)
+        np_a = np.random.uniform(size=(n, lanes)).astype("float32")
+        np_a = np_bf162np_float(np_float2np_bf16(np_a))
+        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(np_float2np_bf16(np_a))
+        c = tvm.nd.empty((n,), B.dtype, ctx)
+        fun(a, c)
+        c = tvm.nd.empty((n, lanes), "uint16", ctx).copyfrom(c)
+        np.testing.assert_equal(c.asnumpy(), np_float2np_bf16(np_a + 1))
+
+    check_cuda(64, 2)
+    check_cuda(64, 4)
+    check_cuda(64, 6)
+    check_cuda(64, 8)
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.requires_cuda
 def test_cuda_multiply_add():
     num_thread = 8
 
@@ -922,6 +969,7 @@ def test_unrolled_vectorization():
 
 if __name__ == "__main__":
     test_cuda_vectorize_add()
+    test_cuda_bf16_vectorize_add()
     test_cuda_multiply_add()
     test_cuda_vectorize_load()
     test_cuda_make_int8()
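
For reference, a worked example of the round-to-nearest-even trick used by
np_float2np_bf16 above (sketch): a low half above 0x8000 rounds the kept
half up, and an exact 0x8000 tie with an even kept half stays put.

    import numpy as np

    def f2bf(arr):  # same bias trick as np_float2np_bf16 above
        orig = arr.view("<u4")
        bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
        return np.right_shift(orig + bias, 16).astype("uint16")

    x = np.array([0x3F80C000, 0x3F808000], dtype="<u4").view("<f4")
    print([hex(v) for v in f2bf(x)])  # ['0x3f81', '0x3f80']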