You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/27 18:15:09 UTC
[tvm] branch main updated: [TIR] Expose MMA-related PTX builtins (#12623)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 534412896e [TIR] Expose MMA-related PTX builtins (#12623)
534412896e is described below
commit 534412896e6d39ee4f830d63370d02e8e5f09050
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Sat Aug 27 11:14:58 2022 -0700
[TIR] Expose MMA-related PTX builtins (#12623)
Expose MMA-related PTX builtins
This PR exposes the following TIR operations in Python:
`ptx_mma`: tested
`ptx_mma_sp`: tested
`mma_store`: add new unittest
`mma_fill`: add new unittest
Co-authored-by: yongwww <yo...@gmail.com>
Co-authored-by: yongwww <yo...@gmail.com>
---
python/tvm/tir/__init__.py | 1 +
python/tvm/tir/op.py | 287 +++++++++++++++++++++++++++++
tests/python/unittest/test_tir_op_types.py | 75 ++++++++
3 files changed, 363 insertions(+)
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 4a6f32d03a..8e637d2d65 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -59,6 +59,7 @@ from .op import (
tvm_bmma_sync,
tvm_fill_fragment,
)
+from .op import ptx_mma, ptx_mma_sp, mma_store, mma_fill
from .op import ptx_ldmatrix, ptx_cp_async, ptx_commit_group, ptx_wait_group
from .op import vectorlow, vectorhigh, vectorcombine
from .op import infinity, reinterpret
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index e510f68a68..1fd3050c0a 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -831,6 +831,293 @@ def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
)
def ptx_mma(
    dtype,
    shape,
    A_layout,
    B_layout,
    A_dtype,
    B_dtype,
    C_dtype,
    multiplicand_a,
    a_index,
    multiplicand_b,
    b_index,
    accumulator,
    c_index,
    saturate,
    operator=None,
):
    """TVM intrinsic for ptx tensor core mma instructions
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma

    Parameters
    ----------
    dtype : str
        The data type of the result.

    shape : str
        The shape of mma fragment.

    A_layout : Literal["row", "col"]
        The layout of multiplicand fragment A.

    B_layout : Literal["row", "col"]
        The layout of multiplicand fragment B.

    A_dtype : str
        The data type of multiplicand fragment A.

    B_dtype : str
        The data type of multiplicand fragment B.

    C_dtype : str
        The data type of accumulator fragment C.

    multiplicand_a : Var
        The multiplicand fragment A variable.

    a_index : Expr
        The index of multiplicand fragment A.

    multiplicand_b : Var
        The multiplicand fragment B variable.

    b_index : Expr
        The index of multiplicand fragment B.

    accumulator : Var
        The accumulator fragment C variable.

    c_index : Expr
        The index of accumulator fragment C.

    saturate : bool
        The optional saturation at the output.

    operator : Optional[Literal["xor", "and"]]
        The 1-bit operator.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # Assemble the intrinsic operands once; the optional 1-bit operator is
    # appended only when supplied, so existing callers (and the lowered
    # intrinsic arity) are unchanged.
    args = [
        shape,
        A_layout,
        B_layout,
        A_dtype,
        B_dtype,
        C_dtype,
        multiplicand_a,
        a_index,
        multiplicand_b,
        b_index,
        accumulator,
        c_index,
        saturate,
    ]
    if operator is not None:
        args.append(operator)
    return call_intrin(dtype, "tir.ptx_mma", *args)
+
+
def ptx_mma_sp(
    dtype,
    shape,
    A_layout,
    B_layout,
    A_dtype,
    B_dtype,
    C_dtype,
    multiplicand_a,
    a_index,
    multiplicand_b,
    b_index,
    accumulator,
    c_index,
    metadata,
    meta_index,
    sparse_selector,
    saturate,
):
    """TVM intrinsic for sparse tensor core ptx instructions
    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma

    Parameters
    ----------
    dtype : str
        The data type of the result.

    shape : str
        The shape of mma fragment.

    A_layout : Literal["row", "col"]
        The layout of multiplicand fragment A.

    B_layout : Literal["row", "col"]
        The layout of multiplicand fragment B.

    A_dtype : str
        The data type of multiplicand fragment A.

    B_dtype : str
        The data type of multiplicand fragment B.

    C_dtype : str
        The data type of accumulator fragment C.

    multiplicand_a : Var
        The multiplicand fragment A variable.

    a_index : Expr
        The index of multiplicand fragment A.

    multiplicand_b : Var
        The multiplicand fragment B variable.

    b_index : Expr
        The index of multiplicand fragment B.

    accumulator : Var
        The accumulator fragment C variable.

    c_index : Expr
        The index of accumulator fragment C.

    metadata : Expr
        The metadata of operand.

    meta_index : Expr
        The metadata index of operand.

    sparse_selector : Expr
        The sparse selector indicating the thread that stores the metadata.

    saturate : bool
        The optional saturation at the output.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        dtype,
        "tir.ptx_mma_sp",
        shape,
        A_layout,
        B_layout,
        A_dtype,
        B_dtype,
        C_dtype,
        multiplicand_a,
        a_index,
        multiplicand_b,
        b_index,
        accumulator,
        c_index,
        metadata,
        meta_index,
        sparse_selector,
        saturate,
    )
+
+
def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
    """TVM intrinsic for storing the result of PTX MMA into a destination pointer

    Parameters
    ----------
    dtype : str
        The data type of the result.

    m : IntImm
        The shape of mma fragment.

    n : IntImm
        The shape of mma fragment.

    dst_ptr : Var
        The destination pointer variable.

    src_ptr : Var
        The source pointer variable.

    src_offset : Expr
        The source offset.

    dst_stride : Var
        The destination stride.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # Forward the operands verbatim to the tir.mma_store builtin.
    operands = (m, n, dst_ptr, src_ptr, src_offset, dst_stride)
    return call_intrin(dtype, "tir.mma_store", *operands)
+
+
def mma_fill(dtype, local_size, local_ptr, offset):
    """TVM intrinsic for zero-initializing an MMA accumulation register

    Parameters
    ----------
    dtype : str
        The data type of the result.

    local_size : IntImm
        The number of elements.

    local_ptr : Var
        The destination pointer variable.

    offset : Expr
        The destination offset.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        dtype,
        "tir.mma_fill",
        local_size,
        local_ptr,
        offset,
    )
+
+
def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset):
"""TVM intrinsic for ptx load matrix from shared memory
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
diff --git a/tests/python/unittest/test_tir_op_types.py b/tests/python/unittest/test_tir_op_types.py
index f8e8de074c..23a264bef7 100644
--- a/tests/python/unittest/test_tir_op_types.py
+++ b/tests/python/unittest/test_tir_op_types.py
@@ -143,6 +143,81 @@ def test_tir_op_tvm_fill_fragment():
assert expr.op.name == "tir.tvm_fill_fragment"
def test_tir_op_ptx_mma():
    # Local fragments for an int4 x uint4 -> int32 m8n8k32 MMA.
    frag_a = tir.decl_buffer([32], "int4", scope="local")
    frag_b = tir.decl_buffer([16], "uint4", scope="local")
    frag_c = tir.decl_buffer([4], "int32", scope="local")
    expr = tir.ptx_mma(
        "int32",
        "m8n8k32",
        "row",
        "col",
        "int4",
        "uint4",
        "int32",
        frag_a.data,
        0,
        frag_b.data,
        0,
        frag_c.data,
        0,
        False,
    )
    assert expr.op.name == "tir.ptx_mma"
+
+
def test_tir_op_ptx_mma_sp():
    # Sparse MMA fragments plus the uint32 metadata buffer.
    frag_a = tir.decl_buffer([32], "int4", scope="local")
    frag_b = tir.decl_buffer([16], "uint4", scope="local")
    frag_c = tir.decl_buffer([4], "int32", scope="local")
    meta = tir.decl_buffer([1], "uint32", scope="local")
    expr = tir.ptx_mma_sp(
        "int32",
        "m8n8k32",
        "row",
        "col",
        "int4",
        "uint4",
        "int32",
        frag_a.data,
        0,
        frag_b.data,
        0,
        frag_c.data,
        0,
        meta.data,
        0,
        0,
        False,
    )
    assert expr.op.name == "tir.ptx_mma_sp"
+
+
def test_tir_op_mma_store():
    stride_x = tir.Var("x", dtype="int32")
    stride_y = tir.Var("y", dtype="int32")
    # Warp-scope accumulator source and a strided global destination.
    warp_buf = tir.decl_buffer([16, 8], dtype="int32", scope="warp", offset_factor=1)
    dst_buf = tir.decl_buffer(
        [16, 16], dtype="int32", scope="global", offset_factor=1, strides=[stride_x, stride_y]
    )
    expr = tir.mma_store(
        "int32",
        16,
        16,
        dst_buf.access_ptr("w"),
        warp_buf.data,
        warp_buf.elem_offset,
        stride_x,
    )
    assert expr.op.name == "tir.mma_store"
+
+
def test_tir_op_mma_fill():
    warp_buf = tir.decl_buffer([16, 8], dtype="int32", scope="warp", offset_factor=1)
    fill_expr = tir.mma_fill("int32", 8, warp_buf.data, warp_buf.elem_offset)
    assert fill_expr.op.name == "tir.mma_fill"
+
+
def test_op_ptx_ldmatrix():
buffer_shared = tir.decl_buffer([16, 16], "float16", scope="shared")
buffer_local = tir.decl_buffer([8], "float16", scope="local")