You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/05/20 08:50:40 UTC

[tvm] branch main updated: [PTX] Intrinsics for async copy from global to shared (SM80) (#11368)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7e99d30d63 [PTX] Intrinsics for async copy from global to shared (SM80) (#11368)
7e99d30d63 is described below

commit 7e99d30d63a0c20eedc247c723e2318686b815cf
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Fri May 20 17:50:32 2022 +0900

    [PTX] Intrinsics for async copy from global to shared (SM80) (#11368)
    
    * registor ptx builtin for async copy
    
    * add basic codegen
    
    * add test
    
    * update codegen
    
    * wip
    
    * codegen bug fixed, test working
    
    * add commit group
    
    * add doc
---
 include/tvm/tir/builtin.h                      | 19 +++++++
 src/target/source/codegen_cuda.cc              | 12 +++++
 src/target/source/ptx.cc                       | 26 ++++++++++
 src/target/source/ptx.h                        | 13 +++++
 src/tir/op/builtin.cc                          |  9 ++++
 tests/python/unittest/test_tir_ptx_cp_async.py | 70 ++++++++++++++++++++++++++
 6 files changed, 149 insertions(+)

diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index b166b16b77..f33432645c 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -632,6 +632,25 @@ TVM_DLL const Op& ptx_mma_sp();
  */
 TVM_DLL const Op& ptx_ldmatrix();
 
+/*!
+ * \brief tvm intrinsics for ptx async copy from global to shared memory
+ *
+ * void ptx_cp_async(Var shared_ptr, Expr shared_offset, Var global_ptr, Expr global_offset, size_t
+ * bytes);
+ *
+ */
+TVM_DLL const Op& ptx_cp_async();
+
+/*!
+ * \brief tvm intrinsics for ptx async copy commit and wait.
+ *
+ * void ptx_commit_group();
+ * void ptx_wait_group(int num);
+ *
+ */
+TVM_DLL const Op& ptx_commit_group();
+TVM_DLL const Op& ptx_wait_group();
+
 // TODO(tvm-team) replace the usage of the vector operations by Shuffle.
 /*!
  * \brief Get the high level half of the vector
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index d4ec536fb0..7459d4c250 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -821,6 +821,18 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
     std::string smem_elem_offset = this->PrintExpr(op->args[6]);
     this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
                                             smem_ptr, smem_elem_offset);
+  } else if (op->op.same_as(builtin::ptx_cp_async())) {
+    std::string dst = this->PrintExpr(op->args[0]);
+    std::string dst_offset = this->PrintExpr(op->args[1]);
+    std::string src = this->PrintExpr(op->args[2]);
+    std::string src_offset = this->PrintExpr(op->args[3]);
+    std::string size = this->PrintExpr(op->args[4]);
+    this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
+  } else if (op->op.same_as(builtin::ptx_commit_group())) {
+    this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
+  } else if (op->op.same_as(builtin::ptx_wait_group())) {
+    std::string N = this->PrintExpr(op->args[0]);
+    this->stream << "__asm__ __volatile__(\"cp.async.wait_group " + N + ";\");\n\n";
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index 02a98ffbba..71c68baed6 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -638,5 +638,31 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
   return asm_code;
 }
 
+std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
+                                 const std::string& shared_elem_offset,
+                                 const std::string& global_ptr,
+                                 const std::string& global_elem_offset, const std::string& bytes) {
+  std::string asm_code = R"(
+  {
+    unsigned int addr;
+    __asm__ __volatile__(
+      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+      : "=r"(addr)
+      : "l"((void *)({smem_addr}))
+    );
+    __asm__ __volatile__(
+      "cp.async.cg.shared.global [%0], [%1], %2;"
+       :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
+    );
+  }
+)";
+  Replacer replacer;
+  replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
+  replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
+  replacer.register_rule("{bytes}", bytes);
+  asm_code = replacer.rewrite(asm_code);
+  return asm_code;
+}
+
 }  // namespace codegen
 }  // namespace tvm
diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h
index c4255d737a..c811a1b9c1 100644
--- a/src/target/source/ptx.h
+++ b/src/target/source/ptx.h
@@ -79,6 +79,19 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
                                     const std::string& smem_ptr,
                                     const std::string& smem_elem_offset);
 
+/*!
+ * \brief Print ptx cp.async assembly string given parameters.
+ * \param shared_ptr: The pointer to the destination shared memory.
+ * \param shared_elem_offset: The offset into the shared memory.
+ * \param global_ptr: The pointer to the global memory.
+ * \param global_elem_offset: The offset into the global memory.
+ * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16.
+ */
+std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
+                                 const std::string& shared_elem_offset,
+                                 const std::string& global_ptr,
+                                 const std::string& global_elem_offset, const std::string& bytes);
+
 }  // namespace codegen
 }  // namespace tvm
 
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 4e8d83dd32..0415d1bbec 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -247,6 +247,15 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp)
 TIR_DEFINE_BUILTIN_FUNC(ptx_ldmatrix)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group)
+    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
 TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
     .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
 
diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py b/tests/python/unittest/test_tir_ptx_cp_async.py
new file mode 100644
index 0000000000..17b6088550
--- /dev/null
+++ b/tests/python/unittest/test_tir_ptx_cp_async.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm.script import tir as T
+import numpy as np
+import tvm.testing
+
+
+@T.prim_func
+def ptx_cp_async(A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(32, 128), "float16"]) -> None:
+    T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+    bx = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(bx, 1)
+    T.launch_thread(tx, 32)
+    with T.block():
+        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared")
+        T.reads(A[0:32, 0:128])
+        T.writes(B[0:32, 0:128])
+
+        for i in range(16):
+            T.evaluate(
+                T.ptx_cp_async(
+                    A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16"
+                )
+            )
+
+        # TODO(masahi): Remove dtype requirement from TVMScript parser
+        T.evaluate(T.ptx_commit_group(dtype="float16"))
+        T.evaluate(T.ptx_wait_group(0, dtype="float16"))
+
+        for i in range(128):
+            B[tx, i] = A_shared[tx, i]
+
+
+@tvm.testing.requires_cuda
+def test_ptx_cp_async():
+    f = ptx_cp_async
+    arch = tvm.contrib.nvcc.get_target_compute_version()
+    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+    if major < 8:
+        # Require at least SM80
+        return
+
+    mod = tvm.build(f, target="cuda")
+    A_np = np.random.rand(32, 128).astype("float16")
+    B_np = np.zeros((32, 128)).astype("float16")
+    dev = tvm.cuda(0)
+    A_nd = tvm.nd.array(A_np, device=dev)
+    B_nd = tvm.nd.array(B_np, device=dev)
+    mod(A_nd, B_nd)
+    tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
+if __name__ == "__main__":
+    test_ptx_cp_async()