Posted to commits@tvm.apache.org by ju...@apache.org on 2022/04/03 17:14:46 UTC
[tvm] branch main updated: [PTX] `ldmatrix` builtin to accelerate copying data from shared memory to warp memory (#10855)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 966d018 [PTX] `ldmatrix` builtin to accelerate copying data from shared memory to warp memory (#10855)
966d018 is described below
commit 966d018da8c553e1870433f4cdedfbc03bfaa39b
Author: Zihao Ye <ex...@outlook.com>
AuthorDate: Sun Apr 3 10:13:39 2022 -0700
[PTX] `ldmatrix` builtin to accelerate copying data from shared memory to warp memory (#10855)
We already have PTX `mma` and `mma.sp` builtin support from #9909 and #10339. However, we had not yet supported the corresponding data-movement builtins for these mma instructions, so data movement could not be as fast as with wmma.
This PR brings the `ldmatrix` builtin, which maps to the native PTX warp-level instruction of the same name (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix); it can load 1, 2, or 4 8x8 matrices from shared memory to warp memory.
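For illustration, here is a minimal TVMScript sketch of the builtin in action, adapted from the unit test added in this commit (kernel body mirrors tests/python/unittest/test_tir_ptx_ldmatrix.py; requires CUDA with SM75+):

    # A minimal sketch, adapted from the test added in this commit.
    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def ldmatrix_copy(
        A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(16, 16), "float16"],
        num: T.int32, trans: T.uint8,
    ) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        bx = T.env_thread("blockIdx.x")
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(bx, 1)
        T.launch_thread(tx, 32)
        with T.block():
            A_shared = T.alloc_buffer([16, 16], "float16", scope="shared")
            A_local = T.alloc_buffer([8], "float16", scope="local")
            for i in range(8):  # stage the input tile into shared memory
                A_shared[i * 2 + tx // 16, tx % 16] = A[i * 2 + tx // 16, tx % 16]
            # One warp-wide ldmatrix: each thread contributes a row address and
            # receives its fragment of the loaded 8x8 matrices in A_local.
            T.evaluate(
                T.ptx_ldmatrix(
                    trans, num, ".b16",
                    A_local.data, 0,
                    A_shared.data, 16 * (tx % 16) + 8 * (tx // 16),
                    dtype="float16",
                )
            )
            for k in range(2):  # write each thread's fragment back to B
                for j in range(2):
                    for i in range(2):
                        B[8 * j + tx // 4, 8 * k + (tx % 4) * 2 + i] = A_local[4 * k + 2 * j + i]

    _, _, param_num, param_trans = ldmatrix_copy.params
    mod = tvm.build(ldmatrix_copy.specialize({param_num: 4, param_trans: False}), target="cuda")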
---
include/tvm/tir/builtin.h | 9 ++
src/target/source/codegen_cuda.cc | 26 ++++-
src/target/source/{ptx_mma.cc => ptx.cc} | 126 ++++++++++++++++++++-----
src/target/source/{ptx_mma.h => ptx.h} | 38 +++++---
src/tir/op/builtin.cc | 3 +
tests/python/unittest/test_tir_ptx_ldmatrix.py | 101 ++++++++++++++++++++
6 files changed, 263 insertions(+), 40 deletions(-)
diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h
index c42d44f..b166b16 100644
--- a/include/tvm/tir/builtin.h
+++ b/include/tvm/tir/builtin.h
@@ -623,6 +623,15 @@ TVM_DLL const Op& ptx_mma();
*/
TVM_DLL const Op& ptx_mma_sp();
+/*!
+ * \brief TVM intrinsic for loading matrices from shared memory with the PTX ldmatrix instruction.
+ *
+ * void ptx_ldmatrix(Bool trans, IntImm num, StringImm type,
+ * Var local_ptr, Expr local_offset,
+ * Var smem_ptr, Expr smem_offset);
+ */
+TVM_DLL const Op& ptx_ldmatrix();
+
// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
/*!
* \brief Get the high level half of the vector
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index f74d5cf..d4ec536 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -33,7 +33,7 @@
#include <vector>
#include "literal/cuda_half_t.h"
-#include "ptx_mma.h"
+#include "ptx.h"
namespace tvm {
namespace codegen {
@@ -772,11 +772,11 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
// arg 3: A precision: fp16, fp32, ...
// arg 4: B precision: fp16, fp32, ...
// arg 5: C precision: fp16, fp32, ...
- // arg 6: A multiplicand
+ // arg 6: A multiplicand pointer
// arg 7: A multiplicand index
- // arg 8: B multiplicand
+ // arg 8: B multiplicand pointer
// arg 9: B multiplicand index
- // arg 10: C accumulator
+ // arg 10: C accumulator pointer
// arg 11: C accumulator index
// arg 12: metadata
// arg 13: metadata index
@@ -803,6 +803,24 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset,
c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate);
this->stream << asm_code;
+ } else if (op->op.same_as(builtin::ptx_ldmatrix())) {
+ // arg 0: whether the matrix is loaded in column-major format.
+ // arg 1: number of matrices to load.
+ // arg 2: the data type in the matrix; .b16 is the only accepted data type.
+ // arg 3: pointer to the local buffer.
+ // arg 4: the offset of the element to store in the local buffer.
+ // arg 5: pointer to the shared memory buffer to load from.
+ // arg 6: the offset of the first element of the row to load in shared memory.
+ ICHECK_EQ(op->args.size(), 7U);
+ bool trans = Downcast<Bool>(op->args[0])->value;
+ int num = Downcast<Integer>(op->args[1])->value;
+ std::string type = Downcast<StringImm>(op->args[2])->value;
+ std::string local_ptr = this->PrintExpr(op->args[3]);
+ std::string local_elem_offset = this->PrintExpr(op->args[4]);
+ std::string smem_ptr = this->PrintExpr(op->args[5]);
+ std::string smem_elem_offset = this->PrintExpr(op->args[6]);
+ this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
+ smem_ptr, smem_elem_offset);
} else {
CodeGenC::VisitExpr_(op, os);
}
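Continuing the sketch above, the inline PTX this codegen branch emits can be inspected by dumping the generated CUDA source of the built module:

    # Sketch; assumes `mod` was built with target="cuda" as in the example above.
    cuda_src = mod.imported_modules[0].get_source()
    assert "ldmatrix.sync.aligned.m8n8" in cuda_src  # the opcode assembled below
    print(cuda_src)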
diff --git a/src/target/source/ptx_mma.cc b/src/target/source/ptx.cc
similarity index 81%
rename from src/target/source/ptx_mma.cc
rename to src/target/source/ptx.cc
index d04c018..02a98ff 100644
--- a/src/target/source/ptx_mma.cc
+++ b/src/target/source/ptx.cc
@@ -18,10 +18,10 @@
*/
/*!
- * \file ptx_mma.cc
+ * \file ptx.cc
*/
-#include "ptx_mma.h"
+#include "ptx.h"
#include <algorithm>
#include <string>
@@ -60,13 +60,18 @@ enum class DataType : int {
kFloat32 = 13,
kTensorFloat32 = 14,
kFloat64 = 15,
- kBit1 = 16
+ kBit1 = 16,
+ kBit8 = 17,
+ kBit16 = 18,
+ kBit32 = 19,
+ kBit64 = 20,
};
-static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16",
- ".s32", ".u32", ".s64", ".u64", ".f16", ".bf16",
- ".f16x2", ".f32", ".tf32", ".f64", ".b1"};
-static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16, 16, 32, 32, 32, 64, 1};
+static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16", ".s32",
+ ".u32", ".s64", ".u64", ".f16", ".bf16", ".f16x2", ".f32",
+ ".tf32", ".f64", ".b1", ".b8", ".b16", ".b32", ".b64"};
+static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16,
+ 16, 32, 32, 32, 64, 1, 8, 16, 32, 64};
/*!
* \brief Create PTX data type from string.
@@ -106,6 +111,14 @@ inline DataType DTypeFromString(const std::string str) {
return DataType::kFloat64;
} else if (str == "int1" || str == ".b1") {
return DataType::kBit1;
+ } else if (str == ".b8") {
+ return DataType::kBit8;
+ } else if (str == ".b16") {
+ return DataType::kBit16;
+ } else if (str == ".b32") {
+ return DataType::kBit32;
+ } else if (str == ".b64") {
+ return DataType::kBit64;
} else {
LOG(FATAL) << "Unrecognized PTX data type " << str;
return DataType(0);
@@ -360,6 +373,7 @@ inline FragAttrs GetFragAttrs(DataType dtype) {
case DataType::kUInt4:
case DataType::kInt8:
case DataType::kUInt8:
+ case DataType::kBit16:
case DataType::kFloat16: // .f16x2 register
case DataType::kBFloat16:
case DataType::kTensorFloat32:
@@ -508,9 +522,9 @@ inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, i
std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
const std::string& B_layout, const std::string& A_dtype,
const std::string& B_dtype, const std::string& C_dtype,
- const std::string& a_ref, const std::string& a_offset,
- const std::string& b_ref, const std::string& b_offset,
- const std::string& c_ref, const std::string& c_offset,
+ const std::string& a_ptr, const std::string& a_elem_offset,
+ const std::string& b_ptr, const std::string& b_elem_offset,
+ const std::string& c_ptr, const std::string& c_elem_offset,
const std::string& metadata, const std::string& metadata_offset,
const std::string& sparsity_selector, const std::string& bit_op,
bool sparse, bool saturate) {
@@ -525,7 +539,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
std::string asm_code = R"(
{
__asm__ __volatile__(
- "mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{saturate}{dtype}{atype}{btype}{ctype}{bitop}"
+ "mma{.sparse}.sync.aligned{.shape}{.alayout}{.blayout}{.saturate}{.dtype}{.atype}{.btype}{.ctype}{.bitop}"
"{templates};\n"
: {outputs}
: {inputs});
@@ -537,30 +551,92 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
// replace patterns
Replacer replacer;
- replacer.register_rule("{sparse}", sparse ? ".sp" : "");
- replacer.register_rule("{shape}", shape);
- replacer.register_rule("{saturate}", saturate ? ".satfinite" : "");
- replacer.register_rule("{alayout}", A_layout);
- replacer.register_rule("{blayout}", B_layout);
- replacer.register_rule("{atype}", ptx::DTypeToString(dtype_a));
- replacer.register_rule("{btype}", ptx::DTypeToString(dtype_b));
- replacer.register_rule("{ctype}", ptx::DTypeToString(dtype_c));
- replacer.register_rule("{dtype}", ptx::DTypeToString(dtype_c));
- replacer.register_rule("{bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc");
+ replacer.register_rule("{.sparse}", sparse ? ".sp" : "");
+ replacer.register_rule("{.shape}", "." + shape);
+ replacer.register_rule("{.saturate}", saturate ? ".satfinite" : "");
+ replacer.register_rule("{.alayout}", "." + A_layout);
+ replacer.register_rule("{.blayout}", "." + B_layout);
+ replacer.register_rule("{.atype}", ptx::DTypeToString(dtype_a));
+ replacer.register_rule("{.btype}", ptx::DTypeToString(dtype_b));
+ replacer.register_rule("{.ctype}", ptx::DTypeToString(dtype_c));
+ replacer.register_rule("{.dtype}", ptx::DTypeToString(dtype_c));
+ replacer.register_rule("{.bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc");
replacer.register_rule("{templates}", templates_str);
replacer.register_rule("{outputs}", outputs_str);
replacer.register_rule("{inputs}", inputs_str);
asm_code = replacer.rewrite(asm_code);
replacer.empty_rules();
- replacer.register_rule("A", a_ref + " + " + a_offset);
- replacer.register_rule("B", b_ref + " + " + b_offset);
- replacer.register_rule("C", c_ref + " + " + c_offset);
- replacer.register_rule("D", c_ref + " + " + c_offset);
+ replacer.register_rule("A", a_ptr + " + " + a_elem_offset);
+ replacer.register_rule("B", b_ptr + " + " + b_elem_offset);
+ replacer.register_rule("C", c_ptr + " + " + c_elem_offset);
+ replacer.register_rule("D", c_ptr + " + " + c_elem_offset);
replacer.register_rule("E", metadata + " + " + metadata_offset);
replacer.register_rule("F", sparsity_selector);
asm_code = replacer.rewrite(asm_code);
return asm_code;
}
+inline std::tuple<std::string, std::string> GetLoadMatrixOperands(
+ int num, const std::string& local_ptr, const std::string& local_elem_offset) {
+ std::stringstream templates, outputs;
+ int arg_counter = 0;
+ // generate templates
+ templates << "{%" << arg_counter++;
+ for (int i = 1; i < num; ++i) {
+ templates << ", %" << arg_counter++;
+ }
+ templates << "}, [%" << arg_counter++ << "]";
+ // generate outputs
+ std::string ptr_type = "(unsigned *)";
+ for (int i = 0; i < num; ++i) {
+ if (i != 0) {
+ outputs << ", ";
+ }
+ outputs << "\"=r\"((" << ptr_type << "(" << local_ptr << " + " << local_elem_offset << "))["
+ << i << "])";
+ }
+ return std::make_tuple(templates.str(), outputs.str());
+}
+
+std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type,
+ const std::string& local_ptr,
+ const std::string& local_elem_offset,
+ const std::string& smem_ptr,
+ const std::string& smem_elem_offset) {
+ CHECK(num == 1 || num == 2 || num == 4) << "ldmatrix only accepts loading 1/2/4 matrices.";
+ ptx::DataType data_type = ptx::DTypeFromString(type);
+ CHECK(data_type == ptx::DataType::kBit16) << "ldmatrix only accepts matrices with type .b16.";
+ std::string asm_code = R"(
+ {
+ unsigned int addr;
+ __asm__ __volatile__(
+ "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+ : "=r"(addr)
+ : "l"((void *)({smem_addr}))
+ );
+ __asm__ __volatile__(
+ "ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}"
+ "{templates};\n"
+ : {outputs}
+ : "r"(addr)
+ );
+ }
+)";
+ std::string templates_str, outputs_str;
+ std::tie(templates_str, outputs_str) = GetLoadMatrixOperands(num, local_ptr, local_elem_offset);
+
+ Replacer replacer;
+ replacer.register_rule("{.shape}", ".m8n8");
+ replacer.register_rule("{.num}", ".x" + std::to_string(num));
+ replacer.register_rule("{.trans}", trans ? ".trans" : "");
+ replacer.register_rule("{.ss}", ".shared");
+ replacer.register_rule("{.type}", ptx::DTypeToString(data_type));
+ replacer.register_rule("{smem_addr}", smem_ptr + " + " + smem_elem_offset);
+ replacer.register_rule("{templates}", templates_str);
+ replacer.register_rule("{outputs}", outputs_str);
+ asm_code = replacer.rewrite(asm_code);
+ return asm_code;
+}
+
} // namespace codegen
} // namespace tvm
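For clarity, the placeholder rewrite performed by `Replacer` above assembles the final opcode as follows (a Python stand-in for the C++ logic, shown for num=4, trans=false, type=".b16"):

    # Illustration of the {.field} substitution; not the actual C++ code.
    template = "ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}"
    rules = {
        "{.shape}": ".m8n8",
        "{.num}": ".x4",
        "{.trans}": "",      # becomes ".trans" when loading column-major
        "{.ss}": ".shared",
        "{.type}": ".b16",
    }
    for pattern, replacement in rules.items():
        template = template.replace(pattern, replacement)
    print(template)  # -> ldmatrix.sync.aligned.m8n8.x4.shared.b16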
diff --git a/src/target/source/ptx_mma.h b/src/target/source/ptx.h
similarity index 63%
rename from src/target/source/ptx_mma.h
rename to src/target/source/ptx.h
index 728478c..c4255d7 100644
--- a/src/target/source/ptx_mma.h
+++ b/src/target/source/ptx.h
@@ -18,11 +18,11 @@
*/
/*!
- * \file ptx_mma.h
- * \brief MMA code generation with inlined PTX code.
+ * \file ptx.h
+ * \brief Code generation with inlined PTX code.
*/
-#ifndef TVM_TARGET_SOURCE_PTX_MMA_H_
-#define TVM_TARGET_SOURCE_PTX_MMA_H_
+#ifndef TVM_TARGET_SOURCE_PTX_H_
+#define TVM_TARGET_SOURCE_PTX_H_
#include <tvm/runtime/logging.h>
@@ -40,11 +40,11 @@ namespace codegen {
* \param A_dtype The data type of multiplicand A.
* \param B_dtype The data type of multiplicand B.
* \param C_dtype The data type of multiplicand C.
- * \param a_ref Pointer to buffer A.
+ * \param a_ptr Pointer to buffer A.
* \param a_offset The offset of element in A.
- * \param b_ref Pointer to buffer B.
+ * \param b_ptr Pointer to buffer B.
* \param b_offset The offset of element in B.
- * \param c_ref Pointer to buffer C.
+ * \param c_ptr Pointer to buffer C.
* \param c_offset The offset of element in C.
* \param metadata Pointer to metadata buffer (only used for sparse mma).
* \param metadata_offset The offset of element in metadata.
@@ -56,14 +56,30 @@ namespace codegen {
std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
const std::string& B_layout, const std::string& A_dtype,
const std::string& B_dtype, const std::string& C_dtype,
- const std::string& a_ref, const std::string& a_offset,
- const std::string& b_ref, const std::string& b_offset,
- const std::string& c_ref, const std::string& c_offset,
+ const std::string& a_ptr, const std::string& a_offset,
+ const std::string& b_ptr, const std::string& b_offset,
+ const std::string& c_ptr, const std::string& c_offset,
const std::string& metadata, const std::string& metadata_offset,
const std::string& sparsity_selector, const std::string& bit_op,
bool sparse, bool saturate);
+/*!
+ * \brief Print ldmatrix assembly string given parameters.
+ * \param trans Whether the matrix is loaded in column-major format.
+ * \param num Number of matrices to load.
+ * \param type The data type in the matrix; .b16 is the only accepted data type.
+ * \param local_ptr Pointer to the local buffer.
+ * \param local_elem_offset The offset of the element to store in the local buffer.
+ * \param smem_ptr Pointer to the shared memory buffer to load from.
+ * \param smem_elem_offset The offset of the first element of the row to load in shared memory.
+ */
+std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type,
+ const std::string& local_ptr,
+ const std::string& local_elem_offset,
+ const std::string& smem_ptr,
+ const std::string& smem_elem_offset);
+
} // namespace codegen
} // namespace tvm
-#endif // TVM_TARGET_SOURCE_PTX_MMA_H_
+#endif // TVM_TARGET_SOURCE_PTX_H_
diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc
index 465428e..4e8d83d 100644
--- a/src/tir/op/builtin.cc
+++ b/src/tir/op/builtin.cc
@@ -244,6 +244,9 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr<TCallEffectKind>("TCallEffectKind",
TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_BUILTIN_FUNC(ptx_ldmatrix)
+ .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
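Once registered, the builtin is also discoverable from Python; a quick sanity check:

    import tvm
    op = tvm.ir.Op.get("tir.ptx_ldmatrix")
    print(op.name)  # tir.ptx_ldmatrix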
diff --git a/tests/python/unittest/test_tir_ptx_ldmatrix.py b/tests/python/unittest/test_tir_ptx_ldmatrix.py
new file mode 100644
index 0000000..f718082
--- /dev/null
+++ b/tests/python/unittest/test_tir_ptx_ldmatrix.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm.script import tir as T
+import numpy as np
+import tvm.testing
+
+
+@T.prim_func
+def ptx_ldmatrix(
+ A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(16, 16), "float16"], num: T.int32, trans: T.uint8
+) -> None:
+ T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
+ bx = T.env_thread("blockIdx.x")
+ tx = T.env_thread("threadIdx.x")
+ T.launch_thread(bx, 1)
+ T.launch_thread(tx, 32)
+ with T.block():
+ A_shared = T.alloc_buffer([16, 16], "float16", scope="shared")
+ A_local = T.alloc_buffer([8], "float16", scope="local")
+
+ for i in range(8):
+ A_shared[i * 2 + tx // 16, tx % 16] = A[i * 2 + tx // 16, tx % 16]
+
+ T.evaluate(
+ T.ptx_ldmatrix(
+ trans,
+ num,
+ ".b16",
+ A_local.data,
+ 0,
+ A_shared.data,
+ 16 * (tx % 16) + 8 * (tx // 16),
+ dtype="float16",
+ )
+ )
+
+ for k in range(2):
+ for j in range(2):
+ for i in range(2):
+ B[8 * j + tx // 4, 8 * k + (tx % 4) * 2 + i] = A_local[4 * k + 2 * j + i]
+
+
+@tvm.testing.requires_cuda
+def test_ptx_ldmatrix():
+ f = ptx_ldmatrix
+ _, _, param_num, param_trans = f.params
+ arch = tvm.contrib.nvcc.get_target_compute_version()
+ major, minor = tvm.contrib.nvcc.parse_compute_version(arch)
+ if major * 10 + minor < 75:
+ # Require at least SM75
+ return
+ for num in [1, 2, 4]:
+ for trans in [False, True]:
+ mod = tvm.build(f.specialize({param_num: num, param_trans: trans}), target="cuda")
+ A_np = np.random.rand(16, 16).astype("float16")
+ A_mask_np = np.zeros_like(A_np)
+ if num == 1:
+ if trans:
+ A_mask_np[:8, :8] = A_np[:8, :8].T
+ else:
+ A_mask_np[:8, :8] = A_np[:8, :8]
+ elif num == 2:
+ if trans:
+ A_mask_np[:8, :8] = A_np[:8, :8].T
+ A_mask_np[8:16, :8] = A_np[8:16, :8].T
+ else:
+ A_mask_np[:16, :8] = A_np[:16, :8]
+ else: # num == 4
+ if trans:
+ A_mask_np[:8, :8] = A_np[:8, :8].T
+ A_mask_np[8:16, :8] = A_np[8:16, :8].T
+ A_mask_np[:8, 8:16] = A_np[:8, 8:16].T
+ A_mask_np[8:16, 8:16] = A_np[8:16, 8:16].T
+ else:
+ A_mask_np[:16, :16] = A_np[:16, :16]
+ B_np = np.zeros((16, 16)).astype("float16")
+ dev = tvm.cuda(0)
+ A_nd = tvm.nd.array(A_np, device=dev)
+ B_nd = tvm.nd.array(B_np, device=dev)
+ mod(A_nd, B_nd)
+ tvm.testing.assert_allclose(B_nd.numpy(), A_mask_np)
+
+
+if __name__ == "__main__":
+ test_ptx_ldmatrix()
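For readers unfamiliar with the fragment layout the test relies on, the following numpy emulation of the non-transposed .x4 case reproduces the test's index math end to end (derived from the PTX documentation; an illustration, not TVM code):

    # numpy emulation of ldmatrix.sync.aligned.m8n8.x4.shared.b16 (non-transposed).
    import numpy as np

    A = np.random.rand(16, 16).astype("float16")
    # Per-thread shared-memory row address, as computed in the kernel above.
    addrs = [(tx % 16, 8 * (tx // 16)) for tx in range(32)]
    # Threads 8*m .. 8*m+7 supply the eight rows of 8x8 matrix m.
    mats = [
        np.stack([A[r, c:c + 8] for (r, c) in addrs[8 * m:8 * m + 8]])
        for m in range(4)
    ]
    # After the load, register m of thread tx holds two fp16 values of matrix m:
    # row tx // 4, columns 2*(tx % 4) and 2*(tx % 4) + 1.
    regs = np.zeros((32, 8), dtype="float16")
    for tx in range(32):
        for m in range(4):
            r, c = tx // 4, 2 * (tx % 4)
            regs[tx, 2 * m:2 * m + 2] = mats[m][r, c:c + 2]
    # Reproduce the test's write-back and confirm a faithful 16x16 copy.
    B = np.zeros_like(A)
    for tx in range(32):
        for k in range(2):
            for j in range(2):
                for i in range(2):
                    B[8 * j + tx // 4, 8 * k + (tx % 4) * 2 + i] = regs[tx, 4 * k + 2 * j + i]
    np.testing.assert_array_equal(A, B)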