Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/06/08 19:02:25 UTC

[GitHub] [tvm] csullivan opened a new pull request #8216: [ROCM][TOPI] AutoTVM support for MFMA tensorization available on AMD gfx908

csullivan opened a new pull request #8216:
URL: https://github.com/apache/tvm/pull/8216


   This PR adds TOPI schedules for dense and batch matmul that target the AMDGCN XDLOP intrinsic `llvm.amdgcn.mfma.f32.16x16x16f16`, which performs an M=N=K=16 matrix multiply-accumulate on fp16 inputs with fp32 accumulation.
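As a point of reference, the sketch below models in plain Python what a single `llvm.amdgcn.mfma.f32.16x16x16f16` call computes at the matrix level: a 16x16 fp32 accumulator updated with the product of a 16x16 (M x K) fp16 tile and a 16x16 (K x N) fp16 tile. It deliberately ignores how the operands are distributed across the 64 lanes of a wavefront. The name `mfma_f32_16x16x16f16_ref` is hypothetical, and accumulation here uses Python floats (double precision) rather than the hardware's fp32, so results may differ in the last bits.

```python
import struct

M = N = K = 16  # fixed tile shape of this MFMA variant


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def mfma_f32_16x16x16f16_ref(a, b, c):
    """Return c + a @ b for 16x16 tiles, rounding a and b to fp16 first.

    a is M x K, b is K x N, c is the M x N accumulator (nested lists).
    Accumulation uses Python floats (double precision), not true fp32.
    """
    assert len(a) == M and all(len(row) == K for row in a)
    assert len(b) == K and all(len(row) == N for row in b)
    assert len(c) == M and all(len(row) == N for row in c)
    a16 = [[to_fp16(v) for v in row] for row in a]
    b16 = [[to_fp16(v) for v in row] for row in b]
    return [
        [c[i][j] + sum(a16[i][k] * b16[k][j] for k in range(K)) for j in range(N)]
        for i in range(M)
    ]


ones = [[1.0] * 16 for _ in range(16)]
result = mfma_f32_16x16x16f16_ref(ones, ones, ones)
# every entry is 1 + 16 * (1 * 1) = 17.0
```

A full tile update on hardware is one such call per wavefront, so an M x N x K product with each dimension a multiple of 16 needs (M/16)(N/16)(K/16) of them.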


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] csullivan commented on pull request #8216: [ROCM][TOPI] AutoTVM support for MFMA tensorization available on AMD gfx908

Posted by GitBox <gi...@apache.org>.
csullivan commented on pull request #8216:
URL: https://github.com/apache/tvm/pull/8216#issuecomment-857115444


   cc @mvermeulen @vinx13: please review, or add other reviewers as you see fit. Thanks!





[GitHub] [tvm] jroesch closed pull request #8216: [ROCM][TOPI] AutoTVM support for MFMA tensorization available on AMD gfx908

Posted by GitBox <gi...@apache.org>.
jroesch closed pull request #8216:
URL: https://github.com/apache/tvm/pull/8216


   





[GitHub] [tvm] vinx13 commented on a change in pull request #8216: [ROCM][TOPI] AutoTVM support for MFMA tensorization available on AMD gfx908

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #8216:
URL: https://github.com/apache/tvm/pull/8216#discussion_r647798845



##########
File path: python/tvm/topi/rocm/batch_matmul_mfma.py
##########
@@ -0,0 +1,286 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-locals, too-many-statements, unused-argument
+"""Compute and Schedule definition for batch_matmul using MFMA on the ROCm backend"""
+from __future__ import absolute_import as _abs
+from tvm import te
+import tvm.autotvm as autotvm
+from ..utils import traverse_inline, get_const_tuple
+from .tensor_intrin import (
+    intrin_mfma_load_matrix,
+    intrin_mfma_store_matrix,
+    intrin_mfma_gemm,
+)
+
+
+@autotvm.register_topi_compute("batch_matmul_mfma.rocm")
+def batch_matmul_mfma(cfg, x, y, out_shape=None):
+    """Computes matrix multiplication of `x` and `y` when
+    `x` and `y` are batched matrices via Matrix FMA on ROCM.
+
+    Parameters
+    ----------
+    cfg : ConfigSpace
+        Autotvm tuning config
+    x : tvm.te.Tensor
+        3-D with shape [batch, M, K]
+    y : tvm.te.Tensor
+        3-D with shape [batch, N, K]
+    Returns
+    -------
+    output : tvm.te.Tensor
+        3-D with shape [batch, M, N]
+    """
+    batch, M, _ = get_const_tuple(x.shape)
+    _, N, _ = get_const_tuple(y.shape)
+    if out_shape is not None:
+        assert out_shape[0] == batch, "Input and output batch sizes must match"
+        assert out_shape[1] == M and out_shape[2] == N, "Invalid output shape"
+    matmul = batch_matmul_mfma_rocm(x, y)
+    return matmul
+
+
+@autotvm.register_topi_schedule("batch_matmul_mfma.rocm")
+def schedule_batch_matmul_mfma(cfg, outs):
+    """Schedule batch_matmul operator using MFMA"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == "batch_matmul_mfma":
+            _schedule_batch_matmul_mfma(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def batch_matmul_mfma_rocm(A, B):
+    """Batch Matmul MFMA operator on ROCM"""
+    assert len(A.shape) == 3 and len(B.shape) == 3, "only support 3-dim batch_matmul"
+
+    batch, M, K = get_const_tuple(A.shape)
+    # B is transposed by default in relay
+    _, N, _ = get_const_tuple(B.shape)
+
+    assert M % 16 == 0 and N % 16 == 0 and K % 16 == 0, "M, N, and K each must be a multiple of 16"
+    A_16 = te.compute((batch, M, K), lambda b, i, j: A[b, i, j].astype("float16"))
+    B_16 = te.compute((batch, N, K), lambda b, i, j: B[b, i, j].astype("float16"))
+
+    k = te.reduce_axis((0, K), name="k")
+    C = te.compute(
+        (batch, M, N),
+        lambda b, i, j: te.sum(
+            A_16[b, i, k].astype("float") * B_16[b, j, k].astype("float"), axis=[k]
+        ),
+        name="T_batch_matmul",
+        tag="batch_matmul_mfma",
+    )
+
+    return C
+
+
+def _schedule_batch_matmul_mfma(cfg, s, C):
+    """Schedule batch_matmul operator using MFMA"""
+    A, B = s[C].op.input_tensors
+
+    s[A].compute_inline()
+    s[B].compute_inline()
+
+    # Explicit memory access
+    AS = s.cache_read(A, "shared", [C])
+    BS = s.cache_read(B, "shared", [C])
+    AF = s.cache_read(AS, "local", [C])
+    BF = s.cache_read(BS, "local", [C])
+    CF = s.cache_write(C, "local")
+    CS = s.cache_read(CF, "shared", [C])
+
+    # # Support op fusion
+    # if C.op not in s.outputs:
+    #     s[C].compute_inline()
+    #     C = s.outputs[0].output(0)
+
+    cfg.define_knob("block_row_warps", [1, 2, 4])
+    cfg.define_knob("block_col_warps", [1, 2, 4])
+    cfg.define_knob("warp_row_tiles", [1, 2, 4])
+    cfg.define_knob("warp_col_tiles", [1, 2, 4])
+    cfg.define_knob("chunk", [1, 2, 4, 8])
+    cfg.define_knob("offset", [0, 8])
+    cfg.define_knob("offsetCS", [0, 8])
+    cfg.define_knob("vec", [1, 2, 4, 8])
+
+    warp_size = 64
+    mfma_m = 16
+    mfma_n = 16
+    mfma_k = 16
+    block_row_warps = cfg["block_row_warps"].val
+    block_col_warps = cfg["block_col_warps"].val
+    warp_row_tiles = cfg["warp_row_tiles"].val
+    warp_col_tiles = cfg["warp_col_tiles"].val
+    chunk = cfg["chunk"].val
+    offset = cfg["offset"].val
+    offsetCS = cfg["offsetCS"].val
+    vec = cfg["vec"].val
+
+    # Define the stride for tensorization
+    AS_align = chunk * mfma_k + offset
+    BS_align = chunk * mfma_k + offset
+    CS_align = warp_col_tiles * block_col_warps * mfma_n + offsetCS
+    AS_stride = [AS_align, 1]
+    BS_stride = [BS_align, 1]
+    AF_stride = [mfma_k, 1]
+    BF_stride = [mfma_k, 1]
+    CF_stride = [warp_col_tiles * mfma_n, 1]
+    CS_stride = [CS_align, 1]
+
+    block_x = te.thread_axis("blockIdx.x")
+    block_y = te.thread_axis("blockIdx.y")
+    block_z = te.thread_axis("blockIdx.z")
+    thread_x = te.thread_axis("threadIdx.x")
+    thread_y = te.thread_axis("threadIdx.y")
+    thread_z = te.thread_axis("threadIdx.z")
+
+    # Schedule for batch_matmul computation
+    block_factor_row = mfma_m * warp_row_tiles * block_row_warps
+    block_factor_col = mfma_n * warp_col_tiles * block_col_warps
+    # b, o = C.op.axis
+    b, i, j = s[C].op.axis
+    block_i, i = s[C].split(i, factor=block_factor_row)
+    block_j, j = s[C].split(j, factor=block_factor_col)
+    s[C].reorder(block_i, block_j, i, j)
+    t = s[C].fuse(i, j)
+    t, vi = s[C].split(t, factor=vec)
+    t, tx = s[C].split(t, factor=warp_size)
+    t, ty = s[C].split(t, factor=block_row_warps)
+    t, tz = s[C].split(t, factor=block_col_warps)
+    s[C].bind(b, block_z)
+    s[C].bind(block_i, block_x)
+    s[C].bind(block_j, block_y)
+    s[C].bind(tz, thread_z)
+    s[C].bind(ty, thread_y)
+    s[C].bind(tx, thread_x)
+    s[C].vectorize(vi)
+
+    # Schedule for fragment store
+    s[CS].compute_at(s[C], block_j)
+    _, _m, _n = CS.op.axis
+    s[CS].storage_align(_m, CS_align - 1, CS_align)
+    _m, _mi = s[CS].split(_m, factor=mfma_m)
+    _n, _ni = s[CS].split(_n, factor=mfma_n)
+    _m, _mo = s[CS].split(_m, factor=warp_row_tiles)
+    _n, _no = s[CS].split(_n, factor=warp_col_tiles)
+    s[CS].reorder(_m, _n, _mo, _no, _mi, _ni)
+    # s[CS].compute_at(s[C], block_j)

Review comment:
       remove these lines
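The tile sizes and thread counts implied by a tuning config in `_schedule_batch_matmul_mfma` can be worked out without TVM. The helper below is hypothetical (not part of the PR); it mirrors the arithmetic of `block_factor_row`, `block_factor_col`, the thread bindings, and the `vec` / `warp_size` / `block_row_warps` / `block_col_warps` splits of the fused `(i, j)` loop, and reports whether those splits divide the block tile evenly. It models loop extents only and makes no claim about how TVM handles an uneven split.

```python
def mfma_launch_config(M, N, batch, block_row_warps, block_col_warps,
                       warp_row_tiles, warp_col_tiles, vec,
                       warp_size=64, mfma_m=16, mfma_n=16):
    """Hypothetical helper mirroring the tiling arithmetic of the MFMA schedule."""
    # Output elements covered by one thread block along i (rows) and j (cols).
    block_factor_row = mfma_m * warp_row_tiles * block_row_warps
    block_factor_col = mfma_n * warp_col_tiles * block_col_warps
    # blockIdx.x <- block_i, blockIdx.y <- block_j, blockIdx.z <- b.
    grid = (-(-M // block_factor_row), -(-N // block_factor_col), batch)
    # threadIdx.x <- tx (warp_size), threadIdx.y <- ty, threadIdx.z <- tz.
    block = (warp_size, block_row_warps, block_col_warps)
    threads_per_block = warp_size * block_row_warps * block_col_warps
    # The fused (i, j) loop is split by vec and then by each thread dimension;
    # what remains is a serial outer loop that every thread executes.
    tile_elems = block_factor_row * block_factor_col
    elems_per_pass = threads_per_block * vec
    return {
        "block_tile": (block_factor_row, block_factor_col),
        "grid": grid,
        "block": block,
        "serial_passes": -(-tile_elems // elems_per_pass),
        "evenly_divided": tile_elems % elems_per_pass == 0,
    }


cfg = mfma_launch_config(M=512, N=512, batch=8, block_row_warps=2, block_col_warps=2,
                         warp_row_tiles=2, warp_col_tiles=2, vec=4)
# block_tile (64, 64), grid (8, 8, 8), block (64, 2, 2), 4 serial passes, evenly divided
```

With every warp knob at 1 and `vec=8`, the 16x16 block tile has 256 elements while a single pass of 64 threads x 8 vector lanes covers 512, so that point in the search space does not split evenly.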

##########
File path: tests/python/unittest/test_amd_xdlops.py
##########
@@ -0,0 +1,289 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# 'License'); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import te
+import numpy as np
+from tvm import relay
+
+
+def intrin_mfma_load_matrix(shape, matrix, thread=None, strides=None):

Review comment:
       Can we reuse the tensor intrinsics defined in topi/rocm/tensor_intrin.py?

##########
File path: python/tvm/topi/rocm/batch_matmul_mfma.py
##########
@@ -0,0 +1,286 @@
+def _schedule_batch_matmul_mfma(cfg, s, C):
+    """Schedule batch_matmul operator using MFMA"""
+    A, B = s[C].op.input_tensors
+
+    s[A].compute_inline()
+    s[B].compute_inline()
+
+    # Explicit memory access
+    AS = s.cache_read(A, "shared", [C])
+    BS = s.cache_read(B, "shared", [C])
+    AF = s.cache_read(AS, "local", [C])
+    BF = s.cache_read(BS, "local", [C])
+    CF = s.cache_write(C, "local")
+    CS = s.cache_read(CF, "shared", [C])
+
+    # # Support op fusion
+    # if C.op not in s.outputs:

Review comment:
       We need to support op fusion here.
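Once fusion is supported, a correctness test needs a reference for `batch_matmul` followed by the elementwise consumer that gets fused into it (for example a bias add and ReLU). The pure-Python sketch below follows the layout of `batch_matmul_mfma_rocm` in this file: `x` is `[batch, M, K]`, `y` is `[batch, N, K]` (already transposed, as the Relay comment notes), inputs are rounded to fp16, and M, N, K must be multiples of 16. The function name and its `epilogue` parameter are hypothetical, and accumulation uses Python floats rather than fp32.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def batch_matmul_nt_ref(x, y, epilogue=None):
    """out[b][i][j] = epilogue(sum_k x[b][i][k] * y[b][j][k], b, i, j).

    x: [batch][M][K], y: [batch][N][K] nested lists. Without an epilogue,
    the raw accumulated value is returned.
    """
    batch, M, K = len(x), len(x[0]), len(x[0][0])
    N = len(y[0])
    assert len(y) == batch and len(y[0][0]) == K, "batch and K must match"
    assert M % 16 == 0 and N % 16 == 0 and K % 16 == 0, "M, N, and K each must be a multiple of 16"
    out = []
    for b in range(batch):
        xb = [[to_fp16(v) for v in row] for row in x[b]]
        yb = [[to_fp16(v) for v in row] for row in y[b]]
        plane = []
        for i in range(M):
            row = []
            for j in range(N):
                acc = sum(xb[i][k] * yb[j][k] for k in range(K))
                row.append(epilogue(acc, b, i, j) if epilogue is not None else acc)
            plane.append(row)
        out.append(plane)
    return out
```

For example, `epilogue=lambda v, b, i, j: max(v + bias[j], 0.0)` models a fused bias add followed by ReLU, which is the kind of consumer the commented-out fusion branch would need to handle.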







[GitHub] [tvm] masahi commented on a change in pull request #8216: [ROCM][TOPI] AutoTVM support for MFMA tensorization available on AMD gfx908

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #8216:
URL: https://github.com/apache/tvm/pull/8216#discussion_r647910036



##########
File path: python/tvm/topi/rocm/tensor_intrin.py
##########
@@ -0,0 +1,148 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unnecessary-lambda, too-many-arguments
+"""MFMA Tensor intrinsics for GFX908."""
+import tvm
+from tvm import te
+
+
+def intrin_mfma_load_matrix(shape, matrix, thread=None, strides_src=None, strides_dst=None):
+    """Intrin function for loading thread registers for mfma tensorization"""
+    M, N, K = shape
+    assert M == 16
+    assert N == 16
+    assert K == 16
+    if matrix in ("A", "BT", "W"):
+        row, col = M, K
+    elif matrix == "B":
+        row, col = K, N
+    output_shape = (row, col)
+
+    A = te.placeholder(output_shape, name=matrix, dtype="float16")
+    BA = tvm.tir.decl_buffer(A.shape, A.dtype, scope="shared", offset_factor=1, strides=strides_src)
+
+    C = te.compute(output_shape, lambda i, j: A[i, j], name="C")
+
+    BC = tvm.tir.decl_buffer(C.shape, C.dtype, scope="local", offset_factor=1, strides=strides_dst)
+
+    def intrin_func(ins, outs):
+        ib = tvm.tir.ir_builder.create()
+
+        BA = ins[0]
+        BC = outs[0]
+
+        tx = thread
+        if tx is None:
+            tx = te.thread_axis("threadIdx.x")
+            ib.scope_attr(tx, "thread_extent", 64)
+
+        blk_td = tx % 16
+        offset = tx // 16
+        # TODO(csullivan): Using offset works, but using tx directly does not, fix this

Review comment:
       How about `tx + 0`? I've hit a similar issue and that was my workaround:
   https://github.com/apache/tvm/blob/f4ec5fd4ae346dbdd8e915c048aeed94b44f6776/python/tvm/topi/cuda/nms.py#L358
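The `blk_td = tx % 16` and `offset = tx // 16` split in `intrin_mfma_load_matrix` divides the 64 lanes of a wavefront into four groups of 16. Assuming each lane holds four consecutive K elements of one row of the 16x16 A tile (a common description of the `16x16x16f16` operand layout, but an assumption here, since the rest of the intrinsic is not shown), the mapping covers every tile element exactly once, as this sketch checks. `lane_elements` is a hypothetical helper.

```python
WAVEFRONT = 64
TILE = 16
ELEMS_PER_LANE = TILE * TILE // WAVEFRONT  # 4 fp16 values per lane


def lane_elements(tx):
    """Hypothetical (row, k) coordinates of the A-tile elements held by lane tx."""
    blk_td = tx % TILE   # row of the tile this lane serves
    offset = tx // TILE  # which group of ELEMS_PER_LANE columns along K
    return [(blk_td, offset * ELEMS_PER_LANE + e) for e in range(ELEMS_PER_LANE)]


coords = [c for tx in range(WAVEFRONT) for c in lane_elements(tx)]
# No element is loaded twice and all 256 are covered: a bijection onto the tile.
assert len(coords) == len(set(coords)) == TILE * TILE
```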







[GitHub] [tvm] jroesch commented on pull request #8216: [ROCM][TOPI] AutoTVM support for MFMA tensorization available on AMD gfx908

Posted by GitBox <gi...@apache.org>.
jroesch commented on pull request #8216:
URL: https://github.com/apache/tvm/pull/8216#issuecomment-1016831177


   This PR appears to be out of date; please feel free to reopen it if this is not the case.
   
   As part of the new year, we are attempting to triage the project's open pull requests to ensure that code which
   is ready for review and/or merging receives adequate attention.
   
   Thanks again for your contribution, and feel free to reach out to discuss these changes.





[GitHub] [tvm] masahi commented on a change in pull request #8216: [ROCM][TOPI] AutoTVM support for MFMA tensorization available on AMD gfx908

Posted by GitBox <gi...@apache.org>.
masahi commented on a change in pull request #8216:
URL: https://github.com/apache/tvm/pull/8216#discussion_r647909253



##########
File path: python/tvm/contrib/rocm.py
##########
@@ -168,3 +168,34 @@ def callback_rocm_bitcode_path(rocdl_dir=None):
             raise RuntimeError("could not find bitcode " + n)
 
     return tvm.runtime.convert(bitcode_files)
+
+
+def parse_compute_version(compute_version):
+    """Parse a compute version string into its major and minor version numbers
+
+    Parameters
+    ----------
+    compute_version : str
+        GFX version of a GPU (e.g. "6.0")
+
+    Returns
+    -------
+    major : int
+        major version number
+    minor : int
+        minor version number
+    """
+    split_ver = compute_version.split(".")
+    try:
+        major = int(split_ver[0])
+        minor = int(split_ver[1])
+        return major, minor
+    except (IndexError, ValueError) as err:
+        raise RuntimeError("Compute version parsing error: " + str(err))
+
+
+def support_mfma(device):
+    major, minor = parse_compute_version(device.compute_version)
+    if major >= 9 and minor >= 8:

Review comment:
       Since MFMA is probably only available on gfx908, how about comparing directly against "908"? Newer architectures such as Navi 2 (gfx1030) do not have MFMA, even though their version number is higher.
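Following that suggestion, an explicit allowlist avoids inferring MFMA support from version ordering, under which 1030 ranks above 908 even though gfx1030 has no MFMA units. The sketch below assumes the architecture is available as a string such as `"gfx908"`, possibly carrying AMDGPU target-ID feature suffixes like `:sramecc+:xnack-`; how to obtain that string from a TVM device or target is not shown, and `MFMA_ARCHS` and `support_mfma_arch` are hypothetical names.

```python
# Architectures treated as MFMA-capable. Per the review, gfx908 is the only one
# this PR targets; extend the set deliberately rather than by numeric ordering.
MFMA_ARCHS = frozenset({"908"})


def support_mfma_arch(arch):
    """Return True if an AMDGPU arch string such as "gfx908" is in MFMA_ARCHS."""
    name = arch.strip().lower()
    if name.startswith("gfx"):
        name = name[3:]
    # Drop target-ID feature suffixes such as ":sramecc+:xnack-" if present.
    name = name.split(":", 1)[0]
    return name in MFMA_ARCHS
```

An exact membership test fails closed: an unknown or newer architecture is reported as unsupported until it is added on purpose.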



