You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/27 18:15:09 UTC

[tvm] branch main updated: [TIR] Expose MMA-related PTX builtins (#12623)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 534412896e [TIR] Expose MMA-related PTX builtins (#12623)
534412896e is described below

commit 534412896e6d39ee4f830d63370d02e8e5f09050
Author: Yaxing Cai <ca...@gmail.com>
AuthorDate: Sat Aug 27 11:14:58 2022 -0700

    [TIR] Expose MMA-related PTX builtins (#12623)
    
    Expose MMA-related PTX builtins
    
    This PR exposes the following TIR operation in python:
    
    `ptx_mma`: tested
    `ptx_mma_sp`: tested
    `mma_store`: add new unittest
    `mma_fill`: add new unittest
    
    Co-authored-by: yongwww <yo...@gmail.com>
    
    Co-authored-by: yongwww <yo...@gmail.com>
---
 python/tvm/tir/__init__.py                 |   1 +
 python/tvm/tir/op.py                       | 287 +++++++++++++++++++++++++++++
 tests/python/unittest/test_tir_op_types.py |  75 ++++++++
 3 files changed, 363 insertions(+)

diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 4a6f32d03a..8e637d2d65 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -59,6 +59,7 @@ from .op import (
     tvm_bmma_sync,
     tvm_fill_fragment,
 )
+from .op import ptx_mma, ptx_mma_sp, mma_store, mma_fill
 from .op import ptx_ldmatrix, ptx_cp_async, ptx_commit_group, ptx_wait_group
 from .op import vectorlow, vectorhigh, vectorcombine
 from .op import infinity, reinterpret
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index e510f68a68..1fd3050c0a 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -831,6 +831,293 @@ def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
     )
 
 
+def ptx_mma(
+    dtype,
+    shape,
+    A_layout,
+    B_layout,
+    A_dtype,
+    B_dtype,
+    C_dtype,
+    multiplicand_a,
+    a_index,
+    multiplicand_b,
+    b_index,
+    accumulator,
+    c_index,
+    saturate,
+    operator=None,
+):
+    """TVM intrinsic for ptx tensor core mma instructions
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
+
+    Parameters
+    ----------
+    dtype : str
+        The data type of the result.
+
+    shape : str
+        The shape of mma fragment.
+
+    A_layout : Literal["row", "col"]
+        The layout of multiplicand fragment A.
+
+    B_layout : Literal["row", "col"]
+        The layout of multiplicand fragment B.
+
+    A_dtype : str
+        The data type of multiplicand fragment A.
+
+    B_dtype : str
+        The data type of multiplicand fragment B.
+
+    C_dtype : str
+        The data type of accumulator fragment C.
+
+    multiplicand_a : Var
+        The multiplicand fragment A variable.
+
+    a_index : Expr
+        The index of multiplicand fragment A.
+
+    multiplicand_b : Var
+        The multiplicand fragment B variable.
+
+    b_index : Expr
+        The index of multiplicand fragment B.
+
+    accumulator : Var
+        The accumulator fragment C variable.
+
+    c_index : Expr
+        The index of accumulator fragment C.
+
+    saturate : bool
+        The optional saturation at the output.
+
+
+    operator : Optional[Literal["xor", "and"]]
+        The 1-bit operator.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    if operator is None:
+        return call_intrin(
+            dtype,
+            "tir.ptx_mma",
+            shape,
+            A_layout,
+            B_layout,
+            A_dtype,
+            B_dtype,
+            C_dtype,
+            multiplicand_a,
+            a_index,
+            multiplicand_b,
+            b_index,
+            accumulator,
+            c_index,
+            saturate,
+        )
+    return call_intrin(
+        dtype,
+        "tir.ptx_mma",
+        shape,
+        A_layout,
+        B_layout,
+        A_dtype,
+        B_dtype,
+        C_dtype,
+        multiplicand_a,
+        a_index,
+        multiplicand_b,
+        b_index,
+        accumulator,
+        c_index,
+        saturate,
+        operator,
+    )
+
+
+def ptx_mma_sp(
+    dtype,
+    shape,
+    A_layout,
+    B_layout,
+    A_dtype,
+    B_dtype,
+    C_dtype,
+    multiplicand_a,
+    a_index,
+    multiplicand_b,
+    b_index,
+    accumulator,
+    c_index,
+    metadata,
+    meta_index,
+    sparse_selector,
+    saturate,
+):
+    """TVM intrinsic for sparse tensor core ptx instructions
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma
+
+    Parameters
+    ----------
+    dtype : str
+        The data type of the result.
+
+    shape : str
+        The shape of mma fragment.
+
+    A_layout : Literal["row", "col"]
+        The layout of multiplicand fragment A.
+
+    B_layout : Literal["row", "col"]
+        The layout of multiplicand fragment B.
+
+    A_dtype : str
+        The data type of multiplicand fragment A.
+
+    B_dtype : str
+        The data type of multiplicand fragment B.
+
+    C_dtype : str
+        The data type of accumulator fragment C.
+
+    multiplicand_a : Var
+        The multiplicand fragment A variable.
+
+    a_index : Expr
+        The index of multiplicand fragment A.
+
+    multiplicand_b : Var
+        The multiplicand fragment B variable.
+
+    b_index : Expr
+        The index of multiplicand fragment B.
+
+    accumulator : Var
+        The accumulator fragment C variable.
+
+    c_index : Expr
+        The index of accumulator fragment C.
+
+    metadata : Expr
+        The metadata of operand.
+
+    meta_index : Expr
+        The metadata index of operand.
+
+    sparse_selector : Expr
+        The sparse selector indicating the thread that stores the metadata.
+
+    saturate : bool
+        The optional saturation at the output.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        dtype,
+        "tir.ptx_mma_sp",
+        shape,
+        A_layout,
+        B_layout,
+        A_dtype,
+        B_dtype,
+        C_dtype,
+        multiplicand_a,
+        a_index,
+        multiplicand_b,
+        b_index,
+        accumulator,
+        c_index,
+        metadata,
+        meta_index,
+        sparse_selector,
+        saturate,
+    )
+
+
+def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride):
+    """TVM intrinsic for storing the result of PTX MMA into a destination pointer
+
+    Parameters
+    ----------
+    dtype : str
+        The data type of the result.
+
+    m : IntImm
+        The shape of mma fragment.
+
+    n : IntImm
+        The shape of mma fragment.
+
+    dst_ptr : Var
+        The destination pointer variable.
+
+    src_ptr : Var
+        The source pointer variable.
+
+    src_offset : Expr
+        The source offset.
+
+    dst_stride : Var
+        The destination stride.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        dtype,
+        "tir.mma_store",
+        m,
+        n,
+        dst_ptr,
+        src_ptr,
+        src_offset,
+        dst_stride,
+    )
+
+
+def mma_fill(dtype, local_size, local_ptr, offset):
+    """TVM intrinsic for zero-initializing an MMA accumulation register
+
+    Parameters
+    ----------
+    dtype : str
+        The data type of the result.
+
+    local_size : IntImm
+        The number of elements.
+
+    local_ptr : Var
+        The destination pointer variable.
+
+    offset : Expr
+        The destination offset.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(
+        dtype,
+        "tir.mma_fill",
+        local_size,
+        local_ptr,
+        offset,
+    )
+
+
 def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset):
     """TVM intrinsic for ptx load matrix from shared memory
     https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
diff --git a/tests/python/unittest/test_tir_op_types.py b/tests/python/unittest/test_tir_op_types.py
index f8e8de074c..23a264bef7 100644
--- a/tests/python/unittest/test_tir_op_types.py
+++ b/tests/python/unittest/test_tir_op_types.py
@@ -143,6 +143,81 @@ def test_tir_op_tvm_fill_fragment():
     assert expr.op.name == "tir.tvm_fill_fragment"
 
 
+def test_tir_op_ptx_mma():
+    buffer_a = tir.decl_buffer([32], "int4", scope="local")
+    buffer_b = tir.decl_buffer([16], "uint4", scope="local")
+    buffer_c = tir.decl_buffer([4], "int32", scope="local")
+    expr = tir.ptx_mma(
+        "int32",
+        "m8n8k32",
+        "row",
+        "col",
+        "int4",
+        "uint4",
+        "int32",
+        buffer_a.data,
+        0,
+        buffer_b.data,
+        0,
+        buffer_c.data,
+        0,
+        False,
+    )
+    assert expr.op.name == "tir.ptx_mma"
+
+
+def test_tir_op_ptx_mma_sp():
+    buffer_a = tir.decl_buffer([32], "int4", scope="local")
+    buffer_b = tir.decl_buffer([16], "uint4", scope="local")
+    buffer_c = tir.decl_buffer([4], "int32", scope="local")
+    buffer_d = tir.decl_buffer([1], "uint32", scope="local")
+    expr = tir.ptx_mma_sp(
+        "int32",
+        "m8n8k32",
+        "row",
+        "col",
+        "int4",
+        "uint4",
+        "int32",
+        buffer_a.data,
+        0,
+        buffer_b.data,
+        0,
+        buffer_c.data,
+        0,
+        buffer_d.data,
+        0,
+        0,
+        False,
+    )
+    assert expr.op.name == "tir.ptx_mma_sp"
+
+
+def test_tir_op_mma_store():
+    x = tir.Var("x", dtype="int32")
+    y = tir.Var("y", dtype="int32")
+    buffer_w = tir.decl_buffer([16, 8], dtype="int32", scope="warp", offset_factor=1)
+    buffer = tir.decl_buffer(
+        [16, 16], dtype="int32", scope="global", offset_factor=1, strides=[x, y]
+    )
+    expr = tir.mma_store(
+        "int32",
+        16,
+        16,
+        buffer.access_ptr("w"),
+        buffer_w.data,
+        buffer_w.elem_offset,
+        x,
+    )
+    assert expr.op.name == "tir.mma_store"
+
+
+def test_tir_op_mma_fill():
+    buffer_w = tir.decl_buffer([16, 8], dtype="int32", scope="warp", offset_factor=1)
+    expr = tir.mma_fill("int32", 8, buffer_w.data, buffer_w.elem_offset)
+    assert expr.op.name == "tir.mma_fill"
+
+
 def test_op_ptx_ldmatrix():
     buffer_shared = tir.decl_buffer([16, 16], "float16", scope="shared")
     buffer_local = tir.decl_buffer([8], "float16", scope="local")