You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/04/03 17:14:46 UTC

[tvm] branch main updated: [PTX] `ldmatrix` builtin to accelerate copying data from shared memory to warp memory (#10855)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 966d018  [PTX] `ldmatrix` builtin to accelerate copying data from shared memory to warp memory (#10855)
966d018 is described below

commit 966d018da8c553e1870433f4cdedfbc03bfaa39b
Author: Zihao Ye <ex...@outlook.com>
AuthorDate: Sun Apr 3 10:13:39 2022 -0700

    [PTX] `ldmatrix` builtin to accelerate copying data from shared memory to warp memory (#10855)
    
    We already have PTX mma and mma.sp builtin support from #9909 and #10339. However, we have not yet supported the corresponding data movement builtins for these mma instructions, so the data movement would not be as fast as wmma.
    
    This PR brings the `ldmatrix` builtin, which is a native PTX warp-level instruction (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix), and we can use it to load several (1/2/4) 8x8 matrices from shared memory to warp memory.
---
 include/tvm/tir/builtin.h                      |   9 ++
 src/target/source/codegen_cuda.cc              |  26 ++++-
 src/target/source/{ptx_mma.cc => ptx.cc}       | 126 ++++++++++++++++++++-----
 src/target/source/{ptx_mma.h => ptx.h}         |  38 +++++---
 src/tir/op/builtin.cc                          |   3 +
 tests/python/unittest/test_tir_ptx_ldmatrix.py | 101 ++++++++++++++++++++
 6 files changed, 263 insertions(+), 40 deletions(-)

diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index c42d44f..b166b16 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -623,6 +623,15 @@ TVM_DLL const Op& ptx_mma();
  */
 TVM_DLL const Op& ptx_mma_sp();
 
+/*!
+ * \brief tvm intrinsic for the PTX ldmatrix instruction (load matrices from shared memory).
+ *
+ * void ptx_ldmatrix(Bool trans, IntImm num, StringImm type,
+ *                   Var local_ptr, Expr local_offset,
+ *                   Var smem_ptr, Expr smem_offset);
+ */
+TVM_DLL const Op& ptx_ldmatrix();
+
 // TODO(tvm-team) replace the usage of the vector operations by Shuffle.
 /*!
  * \brief Get the high level half of the vector
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index f74d5cf..d4ec536 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -33,7 +33,7 @@
 #include <vector>
 
 #include "literal/cuda_half_t.h"
-#include "ptx_mma.h"
+#include "ptx.h"
 
 namespace tvm {
 namespace codegen {
@@ -772,11 +772,11 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
     // arg 3: A precision: fp16, fp32, ...
     // arg 4: B precision: fp16, fp32, ...
     // arg 5: C precision: fp16, fp32, ...
-    // arg 6: A multiplicand
+    // arg 6: A multiplicand pointer
     // arg 7: A multiplicand index
-    // arg 8: B multiplicand
+    // arg 8: B multiplicand pointer
     // arg 9: B multiplicand index
-    // arg 10: C accumulator
+    // arg 10: C accumulator pointer
     // arg 11: C accumulator index
     // arg 12: metadata
     // arg 13: metadata index
@@ -803,6 +803,24 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
         shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset,
         c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate);
     this->stream << asm_code;
+  } else if (op->op.same_as(builtin::ptx_ldmatrix())) {
+    // arg 0: whether the matrix is loaded in column major format or not.
+    // arg 1: number of matrices to load.
+    // arg 2: The data type in the matrix, .b16 is the only accepted data type.
+    // arg 3: pointer to local buffer.
+    // arg 4: The offset of the element to store in the local buffer.
+    // arg 5: pointer to the shared memory buffer to load.
+    // arg 6: The offset of the start element of the row to load in shared memory.
+    ICHECK_EQ(op->args.size(), 7U);
+    bool trans = Downcast<Bool>(op->args[0])->value;
+    int num = Downcast<Integer>(op->args[1])->value;
+    std::string type = Downcast<StringImm>(op->args[2])->value;
+    std::string local_ptr = this->PrintExpr(op->args[3]);
+    std::string local_elem_offset = this->PrintExpr(op->args[4]);
+    std::string smem_ptr = this->PrintExpr(op->args[5]);
+    std::string smem_elem_offset = this->PrintExpr(op->args[6]);
+    this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
+                                            smem_ptr, smem_elem_offset);
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
diff --git a/src/target/source/ptx_mma.cc b/src/target/source/ptx.cc
similarity index 81%
rename from src/target/source/ptx_mma.cc
rename to src/target/source/ptx.cc
index d04c018..02a98ff 100644
--- a/src/target/source/ptx_mma.cc
+++ b/src/target/source/ptx.cc
@@ -18,10 +18,10 @@
  */
 
 /*!
- * \file ptx_mma.cc
+ * \file ptx.cc
  */
 
-#include "ptx_mma.h"
+#include "ptx.h"
 
 #include <algorithm>
 #include <string>
@@ -60,13 +60,18 @@ enum class DataType : int {
   kFloat32 = 13,
   kTensorFloat32 = 14,
   kFloat64 = 15,
-  kBit1 = 16
+  kBit1 = 16,
+  kBit8 = 17,
+  kBit16 = 18,
+  kBit32 = 19,
+  kBit64 = 20,
 };
 
-static const char* dtype_str[] = {".s4",    ".u4",  ".s8",   ".u8",  ".s16", ".u16",
-                                  ".s32",   ".u32", ".s64",  ".u64", ".f16", ".bf16",
-                                  ".f16x2", ".f32", ".tf32", ".f64", ".b1"};
-static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16, 16, 32, 32, 32, 64, 1};
+static const char* dtype_str[] = {".s4",   ".u4",  ".s8",  ".u8",  ".s16",  ".u16",   ".s32",
+                                  ".u32",  ".s64", ".u64", ".f16", ".bf16", ".f16x2", ".f32",
+                                  ".tf32", ".f64", ".b1",  ".b8",  ".b16",  ".b32",   ".b64"};
+static const uint32_t num_bits[] = {4,  4,  8,  8,  16, 16, 32, 32, 64, 64, 16,
+                                    16, 32, 32, 32, 64, 1,  8,  16, 32, 64};
 
 /*!
  * \brief Create PTX data type from string.
@@ -106,6 +111,14 @@ inline DataType DTypeFromString(const std::string str) {
     return DataType::kFloat64;
   } else if (str == "int1" || str == ".b1") {
     return DataType::kBit1;
+  } else if (str == ".b8") {
+    return DataType::kBit8;
+  } else if (str == ".b16") {
+    return DataType::kBit16;
+  } else if (str == ".b32") {
+    return DataType::kBit32;
+  } else if (str == ".b64") {
+    return DataType::kBit64;
   } else {
     LOG(FATAL) << "Unrecognized PTX data type " << str;
     return DataType(0);
@@ -360,6 +373,7 @@ inline FragAttrs GetFragAttrs(DataType dtype) {
     case DataType::kUInt4:
     case DataType::kInt8:
     case DataType::kUInt8:
+    case DataType::kBit16:
     case DataType::kFloat16:  // .f16x2 register
     case DataType::kBFloat16:
     case DataType::kTensorFloat32:
@@ -508,9 +522,9 @@ inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, i
 std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
                              const std::string& B_layout, const std::string& A_dtype,
                              const std::string& B_dtype, const std::string& C_dtype,
-                             const std::string& a_ref, const std::string& a_offset,
-                             const std::string& b_ref, const std::string& b_offset,
-                             const std::string& c_ref, const std::string& c_offset,
+                             const std::string& a_ptr, const std::string& a_elem_offset,
+                             const std::string& b_ptr, const std::string& b_elem_offset,
+                             const std::string& c_ptr, const std::string& c_elem_offset,
                              const std::string& metadata, const std::string& metadata_offset,
                              const std::string& sparsity_selector, const std::string& bit_op,
                              bool sparse, bool saturate) {
@@ -525,7 +539,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
   std::string asm_code = R"(
   {
     __asm__ __volatile__(
-      "mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{saturate}{dtype}{atype}{btype}{ctype}{bitop}"
+      "mma{.sparse}.sync.aligned{.shape}{.alayout}{.blayout}{.saturate}{.dtype}{.atype}{.btype}{.ctype}{.bitop}"
       "{templates};\n"
       : {outputs}
       : {inputs});
@@ -537,30 +551,92 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
 
   // replace patterns
   Replacer replacer;
-  replacer.register_rule("{sparse}", sparse ? ".sp" : "");
-  replacer.register_rule("{shape}", shape);
-  replacer.register_rule("{saturate}", saturate ? ".satfinite" : "");
-  replacer.register_rule("{alayout}", A_layout);
-  replacer.register_rule("{blayout}", B_layout);
-  replacer.register_rule("{atype}", ptx::DTypeToString(dtype_a));
-  replacer.register_rule("{btype}", ptx::DTypeToString(dtype_b));
-  replacer.register_rule("{ctype}", ptx::DTypeToString(dtype_c));
-  replacer.register_rule("{dtype}", ptx::DTypeToString(dtype_c));
-  replacer.register_rule("{bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc");
+  replacer.register_rule("{.sparse}", sparse ? ".sp" : "");
+  replacer.register_rule("{.shape}", "." + shape);
+  replacer.register_rule("{.saturate}", saturate ? ".satfinite" : "");
+  replacer.register_rule("{.alayout}", "." + A_layout);
+  replacer.register_rule("{.blayout}", "." + B_layout);
+  replacer.register_rule("{.atype}", ptx::DTypeToString(dtype_a));
+  replacer.register_rule("{.btype}", ptx::DTypeToString(dtype_b));
+  replacer.register_rule("{.ctype}", ptx::DTypeToString(dtype_c));
+  replacer.register_rule("{.dtype}", ptx::DTypeToString(dtype_c));
+  replacer.register_rule("{.bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc");
   replacer.register_rule("{templates}", templates_str);
   replacer.register_rule("{outputs}", outputs_str);
   replacer.register_rule("{inputs}", inputs_str);
   asm_code = replacer.rewrite(asm_code);
   replacer.empty_rules();
-  replacer.register_rule("A", a_ref + " + " + a_offset);
-  replacer.register_rule("B", b_ref + " + " + b_offset);
-  replacer.register_rule("C", c_ref + " + " + c_offset);
-  replacer.register_rule("D", c_ref + " + " + c_offset);
+  replacer.register_rule("A", a_ptr + " + " + a_elem_offset);
+  replacer.register_rule("B", b_ptr + " + " + b_elem_offset);
+  replacer.register_rule("C", c_ptr + " + " + c_elem_offset);
+  replacer.register_rule("D", c_ptr + " + " + c_elem_offset);
   replacer.register_rule("E", metadata + " + " + metadata_offset);
   replacer.register_rule("F", sparsity_selector);
   asm_code = replacer.rewrite(asm_code);
   return asm_code;
 }
 
+inline std::tuple<std::string, std::string> GetLoadMatrixOperands(
+    int num, const std::string& local_ptr, const std::string& local_elem_offset) {
+  std::stringstream templates, outputs;
+  int arg_counter = 0;  // index of the next inline-asm operand placeholder (%0, %1, ...)
+  // Operand template: "{%0, ..., %(num-1)}, [%num]" -- num destinations, then one address.
+  templates << "{%" << arg_counter++;
+  for (int i = 1; i < num; ++i) {
+    templates << ", %" << arg_counter++;
+  }
+  templates << "}, [%" << arg_counter++ << "]";
+  // Outputs: one "=r" constraint per matrix: ((unsigned *)(local_ptr + local_elem_offset))[i].
+  std::string ptr_type = "(unsigned *)";
+  for (int i = 0; i < num; ++i) {
+    if (i != 0) {
+      outputs << ", ";
+    }
+    outputs << "\"=r\"((" << ptr_type << "(" << local_ptr << " + " << local_elem_offset << "))["
+            << i << "])";
+  }
+  return std::make_tuple(templates.str(), outputs.str());  // (templates, outputs)
+}
+
+std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type,
+                                    const std::string& local_ptr,
+                                    const std::string& local_elem_offset,
+                                    const std::string& smem_ptr,
+                                    const std::string& smem_elem_offset) {
+  CHECK(num == 1 || num == 2 || num == 4) << "ldmatrix only accept loading 1/2/4 matrices.";
+  ptx::DataType data_type = ptx::DTypeFromString(type);  // LOG(FATAL) on unknown type string
+  CHECK(data_type == ptx::DataType::kBit16) << "ldmatrix only accept matrix with type .b16.";
+  std::string asm_code = R"(
+  {
+    unsigned int addr;
+    __asm__ __volatile__(
+      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+      : "=r"(addr)
+      : "l"((void *)({smem_addr}))
+    );
+    __asm__ __volatile__(
+      "ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}"
+      "{templates};\n"
+      : {outputs}
+      : "r"(addr)
+    );
+  }
+)";
+  std::string templates_str, outputs_str;  // filled from the (templates, outputs) tuple below
+  std::tie(templates_str, outputs_str) = GetLoadMatrixOperands(num, local_ptr, local_elem_offset);
+
+  Replacer replacer;  // substitutes the {...} placeholders in the asm template above
+  replacer.register_rule("{.shape}", ".m8n8");  // shape qualifier is hard-coded
+  replacer.register_rule("{.num}", ".x" + std::to_string(num));  // .x1 / .x2 / .x4
+  replacer.register_rule("{.trans}", trans ? ".trans" : "");  // omitted when trans is false
+  replacer.register_rule("{.ss}", ".shared");  // state space qualifier is hard-coded
+  replacer.register_rule("{.type}", ptx::DTypeToString(data_type));  // kBit16, checked above
+  replacer.register_rule("{smem_addr}", smem_ptr + " + " + smem_elem_offset);  // "ptr + offset"
+  replacer.register_rule("{templates}", templates_str);
+  replacer.register_rule("{outputs}", outputs_str);
+  asm_code = replacer.rewrite(asm_code);
+  return asm_code;
+}
+
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/target/source/ptx_mma.h b/src/target/source/ptx.h
similarity index 63%
rename from src/target/source/ptx_mma.h
rename to src/target/source/ptx.h
index 728478c..c4255d7 100644
--- a/src/target/source/ptx_mma.h
+++ b/src/target/source/ptx.h
@@ -18,11 +18,11 @@
  */
 
 /*!
- * \file ptx_mma.h
- * \brief MMA code generation with inlined PTX code.
+ * \file ptx.h
+ * \brief Code generation with inlined PTX code.
  */
-#ifndef TVM_TARGET_SOURCE_PTX_MMA_H_
-#define TVM_TARGET_SOURCE_PTX_MMA_H_
+#ifndef TVM_TARGET_SOURCE_PTX_H_
+#define TVM_TARGET_SOURCE_PTX_H_
 
 #include <tvm/runtime/logging.h>
 
@@ -40,11 +40,11 @@ namespace codegen {
  * \param A_dtype The data type of multiplicand A.
  * \param B_dtype The data type of multiplicand B.
  * \param C_dtype The data type of multiplicand C.
- * \param a_ref Pointer to buffer A.
+ * \param a_ptr Pointer to buffer A.
  * \param a_offset The offset of element in A.
- * \param b_ref Pointer to buffer B.
+ * \param b_ptr Pointer to buffer B.
  * \param b_offset The offset of element in B.
- * \param c_ref Pointer to buffer C.
+ * \param c_ptr Pointer to buffer C.
  * \param c_offset The offset of element in C.
  * \param metadata Pointer to metadata buffer (only used for sparse mma).
  * \param metadata_offset The offset of element in metadata.
@@ -56,14 +56,30 @@ namespace codegen {
 std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
                              const std::string& B_layout, const std::string& A_dtype,
                              const std::string& B_dtype, const std::string& C_dtype,
-                             const std::string& a_ref, const std::string& a_offset,
-                             const std::string& b_ref, const std::string& b_offset,
-                             const std::string& c_ref, const std::string& c_offset,
+                             const std::string& a_ptr, const std::string& a_offset,
+                             const std::string& b_ptr, const std::string& b_offset,
+                             const std::string& c_ptr, const std::string& c_offset,
                              const std::string& metadata, const std::string& metadata_offset,
                              const std::string& sparsity_selector, const std::string& bit_op,
                              bool sparse, bool saturate);
 
+/*!
+ * \brief Print ldmatrix assembly string given parameters.
+ * \param trans Whether the matrix is loaded in column major format or not.
+ * \param num The number of matrices to load (1, 2 or 4).
+ * \param type The data type in the matrix, .b16 is the only accepted data type.
+ * \param local_ptr Pointer to the local buffer.
+ * \param local_elem_offset The offset of the element to store in the local buffer.
+ * \param smem_ptr Pointer to the shared memory buffer to load.
+ * \param smem_elem_offset The offset of the start element of the row to load in shared memory.
+ */
+std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type,
+                                    const std::string& local_ptr,
+                                    const std::string& local_elem_offset,
+                                    const std::string& smem_ptr,
+                                    const std::string& smem_elem_offset);
+
 }  // namespace codegen
 }  // namespace tvm
 
-#endif  // TVM_TARGET_SOURCE_PTX_MMA_H_
+#endif  // TVM_TARGET_SOURCE_PTX_H_
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 465428e..4e8d83d 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -244,6 +244,9 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr<TCallEffectKind>("TCallEffectKind",
 TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_BUILTIN_FUNC(ptx_ldmatrix)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
diff --git a/tests/python/unittest/test_tir_ptx_ldmatrix.py b/tests/python/unittest/test_tir_ptx_ldmatrix.py
new file mode 100644
index 0000000..f718082
--- /dev/null
+++ b/tests/python/unittest/test_tir_ptx_ldmatrix.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm.script import tir as T
+import numpy as np
+import tvm.testing
+
+
+@T.prim_func
+def ptx_ldmatrix(
+    A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(16, 16), "float16"], num: T.int32, trans: T.uint8
+) -> None:
+    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+    bx = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(bx, 1)  # a single thread block
+    T.launch_thread(tx, 32)  # 32 threads (one warp)
+    with T.block():
+        A_shared = T.alloc_buffer([16, 16], "float16", scope="shared")
+        A_local = T.alloc_buffer([8], "float16", scope="local")
+
+        for i in range(8):  # 32 threads x 8 iterations copy all 16 x 16 elements of A
+            A_shared[i * 2 + tx // 16, tx % 16] = A[i * 2 + tx // 16, tx % 16]
+
+        T.evaluate(
+            T.ptx_ldmatrix(
+                trans,  # arg 0: trans
+                num,  # arg 1: number of matrices to load
+                ".b16",  # arg 2: type
+                A_local.data,  # arg 3: local_ptr
+                0,  # arg 4: local_offset
+                A_shared.data,  # arg 5: smem_ptr
+                16 * (tx % 16) + 8 * (tx // 16),  # arg 6: smem_offset, per-thread row start
+                dtype="float16",
+            )
+        )
+
+        for k in range(2):  # scatter the 8 per-thread elements of A_local into B
+            for j in range(2):
+                for i in range(2):
+                    B[8 * j + tx // 4, 8 * k + (tx % 4) * 2 + i] = A_local[4 * k + 2 * j + i]
+
+
+@tvm.testing.requires_cuda
+def test_ptx_ldmatrix():
+    f = ptx_ldmatrix
+    _, _, param_num, param_trans = f.params  # params are (A, B, num, trans)
+    arch = tvm.contrib.nvcc.get_target_compute_version()
+    major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
+    if major * 10 + minor < 75:
+        # Require at least SM75; on older devices the test returns without checking anything
+        return
+    for num in [1, 2, 4]:  # every matrix count accepted by ldmatrix
+        for trans in [False, True]:
+            mod = tvm.build(f.specialize({param_num: num, param_trans: trans}), target="cuda")
+            A_np = np.random.rand(16, 16).astype("float16")
+            A_mask_np = np.zeros_like(A_np)  # expected B; NOTE(review): unloaded tiles assumed 0
+            if num == 1:
+                if trans:
+                    A_mask_np[:8, :8] = A_np[:8, :8].T
+                else:
+                    A_mask_np[:8, :8] = A_np[:8, :8]
+            elif num == 2:
+                if trans:
+                    A_mask_np[:8, :8] = A_np[:8, :8].T
+                    A_mask_np[8:16, :8] = A_np[8:16, :8].T
+                else:
+                    A_mask_np[:16, :8] = A_np[:16, :8]
+            else:  # num == 4
+                if trans:  # each 8x8 tile is transposed in place, tiles keep their position
+                    A_mask_np[:8, :8] = A_np[:8, :8].T
+                    A_mask_np[8:16, :8] = A_np[8:16, :8].T
+                    A_mask_np[:8, 8:16] = A_np[:8, 8:16].T
+                    A_mask_np[8:16, 8:16] = A_np[8:16, 8:16].T
+                else:
+                    A_mask_np[:16, :16] = A_np[:16, :16]
+            B_np = np.zeros((16, 16)).astype("float16")
+            dev = tvm.cuda(0)
+            A_nd = tvm.nd.array(A_np, device=dev)
+            B_nd = tvm.nd.array(B_np, device=dev)
+            mod(A_nd, B_nd)  # run the kernel; results are written to B_nd
+            tvm.testing.assert_allclose(B_nd.numpy(), A_mask_np)
+
+
+if __name__ == "__main__":
+    test_ptx_ldmatrix()