Posted to commits@tvm.apache.org by ma...@apache.org on 2022/05/20 08:50:40 UTC
[tvm] branch main updated: [PTX] Intrinsics for async copy from global to shared (SM80) (#11368)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7e99d30d63 [PTX] Intrinsics for async copy from global to shared (SM80) (#11368)
7e99d30d63 is described below
commit 7e99d30d63a0c20eedc247c723e2318686b815cf
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Fri May 20 17:50:32 2022 +0900
[PTX] Intrinsics for async copy from global to shared (SM80) (#11368)
* register ptx builtin for async copy
* add basic codegen
* add test
* update codegen
* wip
* codegen bug fixed, test working
* add commit group
* add doc
---
include/tvm/tir/builtin.h | 19 +++++++
src/target/source/codegen_cuda.cc | 12 +++++
src/target/source/ptx.cc | 26 ++++++++++
src/target/source/ptx.h | 13 +++++
src/tir/op/builtin.cc | 9 ++++
tests/python/unittest/test_tir_ptx_cp_async.py | 70 ++++++++++++++++++++++++++
6 files changed, 149 insertions(+)
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index b166b16b77..f33432645c 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -632,6 +632,25 @@ TVM_DLL const Op& ptx_mma_sp();
*/
TVM_DLL const Op& ptx_ldmatrix();
+/*!
+ * \brief tvm intrinsic for ptx async copy from global to shared memory
+ *
+ * void ptx_cp_async(Var shared_ptr, Expr shared_offset, Var global_ptr, Expr global_offset, size_t
+ * bytes);
+ *
+ */
+TVM_DLL const Op& ptx_cp_async();
+
+/*!
+ * \brief tvm intrinsics for ptx async copy commit and wait.
+ *
+ * void ptx_commit_group();
+ * void ptx_wait_group(int num);
+ *
+ */
+TVM_DLL const Op& ptx_commit_group();
+TVM_DLL const Op& ptx_wait_group();
+
// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
/*!
* \brief Get the high level half of the vector
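For reference, the three intrinsics together model the standard SM80 cp.async
pattern: issue asynchronous global-to-shared copies, commit them as a group,
then wait for the group before reading shared memory. Below is a minimal
hand-written CUDA sketch of that pattern; the kernel name, buffer shapes, and
thread layout are illustrative only, not part of this commit.

#include <cuda_fp16.h>

// Each of 32 threads copies one 16-byte (8 x half) chunk from global to
// shared memory asynchronously, then commits and waits before reading back.
__global__ void cp_async_sketch(const half* __restrict__ src, half* dst) {
  __shared__ half smem[32 * 8];
  unsigned int addr;
  // Convert the generic shared-memory pointer to a 32-bit shared-space
  // address, exactly as the emitted assembly in this patch does.
  __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void*)(smem + threadIdx.x * 8)));
  // ptx_cp_async: issue a 16-byte asynchronous global->shared copy.
  __asm__ __volatile__("cp.async.cg.shared.global [%0], [%1], %2;"
                       :: "r"(addr), "l"((void*)(src + threadIdx.x * 8)), "n"(16));
  // ptx_commit_group: commit all prior cp.async ops as one group.
  __asm__ __volatile__("cp.async.commit_group;");
  // ptx_wait_group(0): wait until no committed group is still in flight.
  __asm__ __volatile__("cp.async.wait_group 0;");
  for (int i = 0; i < 8; ++i) {
    dst[threadIdx.x * 8 + i] = smem[threadIdx.x * 8 + i];
  }
}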
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index d4ec536fb0..7459d4c250 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -821,6 +821,18 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string smem_elem_offset = this->PrintExpr(op->args[6]);
this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
smem_ptr, smem_elem_offset);
+ } else if (op->op.same_as(builtin::ptx_cp_async())) {
+ std::string dst = this->PrintExpr(op->args[0]);
+ std::string dst_offset = this->PrintExpr(op->args[1]);
+ std::string src = this->PrintExpr(op->args[2]);
+ std::string src_offset = this->PrintExpr(op->args[3]);
+ std::string size = this->PrintExpr(op->args[4]);
+ this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
+ } else if (op->op.same_as(builtin::ptx_commit_group())) {
+ this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
+ } else if (op->op.same_as(builtin::ptx_wait_group())) {
+ std::string N = this->PrintExpr(op->args[0]);
+ this->stream << "__asm__ __volatile__(\"cp.async.wait_group " + N + ";\");\n\n";
} else {
CodeGenC::VisitExpr_(op, os);
}
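A note on the wait semantics: cp.async.wait_group N blocks until at most N of
the most recently committed groups are still pending, so N = 0 waits for every
outstanding copy. A sketch of the two statements these branches emit for
ptx_commit_group() followed by ptx_wait_group(0):

// Emitted for ptx_commit_group(): batch all prior cp.async ops into a group.
__asm__ __volatile__("cp.async.commit_group;");

// Emitted for ptx_wait_group(0): N = 0 means no committed group may remain
// in flight, i.e. wait for every pending copy to complete.
__asm__ __volatile__("cp.async.wait_group 0;");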
diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index 02a98ffbba..71c68baed6 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -638,5 +638,31 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
return asm_code;
}
+std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
+ const std::string& shared_elem_offset,
+ const std::string& global_ptr,
+ const std::string& global_elem_offset, const std::string& bytes) {
+ std::string asm_code = R"(
+ {
+ unsigned int addr;
+ __asm__ __volatile__(
+ "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+ : "=r"(addr)
+ : "l"((void *)({smem_addr}))
+ );
+ __asm__ __volatile__(
+ "cp.async.cg.shared.global [%0], [%1], %2;"
+ :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
+ );
+ }
+)";
+ Replacer replacer;
+ replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
+ replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
+ replacer.register_rule("{bytes}", bytes);
+ asm_code = replacer.rewrite(asm_code);
+ return asm_code;
+}
+
} // namespace codegen
} // namespace tvm
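For concreteness, with illustrative arguments shared_ptr = "A_shared",
shared_elem_offset = "0", global_ptr = "A", global_elem_offset = "0", and
bytes = "16", the string returned by PrintCpAsyncAssembly after the
replacer rules run is:

  {
    unsigned int addr;
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)(A_shared + 0))
    );
    __asm__ __volatile__(
      "cp.async.cg.shared.global [%0], [%1], %2;"
      :: "r"(addr), "l"((void*)(A + 0)), "n"(16)
    );
  }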
diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h
index c4255d737a..c811a1b9c1 100644
--- a/src/target/source/ptx.h
+++ b/src/target/source/ptx.h
@@ -79,6 +79,19 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
const std::string& smem_ptr,
const std::string& smem_elem_offset);
+/*!
+ * \brief Print ptx cp.async assembly string given parameters.
+ * \param shared_ptr: The pointer to the destination shared memory.
+ * \param shared_elem_offset: The offset into the shared memory.
+ * \param global_ptr: The pointer to the global memory.
+ * \param global_elem_offset: The offset into the global memory.
+ * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16.
+ */
+std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
+ const std::string& shared_elem_offset,
+ const std::string& global_ptr,
+ const std::string& global_elem_offset, const std::string& bytes);
+
} // namespace codegen
} // namespace tvm
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 4e8d83dd32..0415d1bbec 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -247,6 +247,15 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp)
TIR_DEFINE_BUILTIN_FUNC(ptx_ldmatrix)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py b/tests/python/unittest/test_tir_ptx_cp_async.py
new file mode 100644
index 0000000000..17b6088550
--- /dev/null
+++ b/tests/python/unittest/test_tir_ptx_cp_async.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm.script import tir as T
+import numpy as np
+import tvm.testing
+
+
+@T.prim_func
+def ptx_cp_async(A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(32, 128), "float16"]) -> None:
+ T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+ bx = T.env_thread("blockIdx.x")
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(bx, 1)
+ T.launch_thread(tx, 32)
+ with T.block():
+ A_shared = T.alloc_buffer([32, 128], "float16", scope="shared")
+ T.reads(A[0:32, 0:128])
+ T.writes(B[0:32, 0:128])
+
+ for i in range(16):
+ T.evaluate(
+ T.ptx_cp_async(
+ A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16"
+ )
+ )
+
+ # TODO(masahi): Remove dtype requirement from TVMScript parser
+ T.evaluate(T.ptx_commit_group(dtype="float16"))
+ T.evaluate(T.ptx_wait_group(0, dtype="float16"))
+
+ for i in range(128):
+ B[tx, i] = A_shared[tx, i]
+
+
+@tvm.testing.requires_cuda
+def test_ptx_cp_async():
+ f = ptx_cp_async
+ arch = tvm.contrib.nvcc.get_target_compute_version()
+ major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+ if major < 8:
+ # Require at least SM80
+ return
+
+ mod = tvm.build(f, target="cuda")
+ A_np = np.random.rand(32, 128).astype("float16")
+ B_np = np.zeros((32, 128)).astype("float16")
+ dev = tvm.cuda(0)
+ A_nd = tvm.nd.array(A_np, device=dev)
+ B_nd = tvm.nd.array(B_np, device=dev)
+ mod(A_nd, B_nd)
+ tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
+if __name__ == "__main__":
+ test_ptx_cp_async()