Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/12/22 05:39:21 UTC

[GitHub] [tvm] Meteorix opened a new pull request #7146: [CUDA]batch_matmul tensorcore schedule

Meteorix opened a new pull request #7146:
URL: https://github.com/apache/tvm/pull/7146


   Add a batch_matmul TensorCore schedule for BERT inference. It shows better performance than the cuBLAS batch_matmul kernel.
   
   @jcf94 @merrymercy could you help review this PR?


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] jcf94 commented on a change in pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
jcf94 commented on a change in pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#discussion_r547085365



##########
File path: python/tvm/relay/op/strategy/cuda.py
##########
@@ -657,6 +657,20 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
             name="batch_matmul_cublas.cuda",
             plevel=15,
         )
+    if target.kind.name == "cuda" and nvcc.have_tensorcore(tvm.gpu(0).compute_version):
+        x, y = inputs
+        B, M, K = get_const_tuple(x.shape)
+        B, N, K = get_const_tuple(y.shape)
+        # "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
+        if ((M % 8 == 0 and K % 16 == 0 and N % 32 == 0) or \
+                (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) or \
+                (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)):

Review comment:
       Would it be better to also add a data type check here, or expose some other user-defined option?
   TensorCore needs to compute in float16, but I'm not sure whether simply converting all float32 batch_matmul ops to lower precision would cause a loss of accuracy.
   Besides, TensorCore can also support data types such as int8 on newer CUDA versions.
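
For illustration, a minimal sketch of what such a dtype guard could look like in the strategy shown above. It is a fragment of `batch_matmul_strategy_cuda`, reusing names already imported in `python/tvm/relay/op/strategy/cuda.py` (`topi`, `nvcc`, `get_const_tuple`, and the `wrap_*` helpers); the exact check that was eventually merged may differ.

```python
# Sketch only: dispatch to the TensorCore schedule for WMMA-friendly dtypes,
# so float32 workloads are not silently downcast to float16.
if target.kind.name == "cuda" and nvcc.have_tensorcore(tvm.gpu(0).compute_version):
    x, y = inputs
    _, M, K = get_const_tuple(x.shape)
    _, N, _ = get_const_tuple(y.shape)
    if x.dtype in ("float16", "int8", "uint8") and (
        (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
        or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
        or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
    ):
        strategy.add_implementation(
            wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore),
            wrap_topi_schedule(topi.cuda.schedule_batch_matmul_tensorcore),
            name="batch_matmul_tensorcore.cuda",
            plevel=20,
        )
```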

##########
File path: python/tvm/topi/cuda/batch_matmul_tensorcore.py
##########
@@ -0,0 +1,274 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
+"""cuda batch_matmul operators"""
+import tvm
+from tvm import autotvm
+from tvm import te
+from ..utils import traverse_inline, get_const_tuple
+from .tensor_intrin import intrin_wmma_load_matrix_A, \
+        intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm
+
+@autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
+def batch_matmul_tensorcore(cfg, x, y, out_shape=None):
+    """batch matmul tensorcore operator on cuda"""
+    # todo: deal with out_shape for broadcast, liuxin.ai
+    return batch_matmul_tensorcore_cuda(x, y)
+
+
+@autotvm.register_topi_schedule("batch_matmul_tensorcore.cuda")
+def schedule_batch_matmul_tensorcore(cfg, outs):
+    """Schedule for batch_matmul operator using Tensorcore
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of batch_matmul
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _schedule(cfg, s, C):
+        A, B = s[C].op.input_tensors
+        batch, m_dim, k_dim = get_const_tuple(A.shape)
+        batch, n_dim, k_dim = get_const_tuple(B.shape)
+        out_dtype = C.dtype
+        # inline astype fp16
+        s[A].compute_inline()
+        s[B].compute_inline()
+
+        # Explicit memory access
+        AS = s.cache_read(A, 'shared', [C])
+        BS = s.cache_read(B, 'shared', [C])
+        AF = s.cache_read(AS, 'wmma.matrix_a', [C])
+        BF = s.cache_read(BS, 'wmma.matrix_b', [C])
+        CF = s.cache_write(C, 'wmma.accumulator')
+        CS = s.cache_read(CF, 'shared', [C])
+
+        # fallback support
+        target = tvm.target.Target.current()
+        if cfg.is_fallback:
+            ref_log = autotvm.tophub.load_reference_log(
+                target.kind.name, target.model, 'batch_matmul_tensorcore.cuda')
+            cfg.fallback_with_reference_log(ref_log)
+
+        # Deal with op fusion, such as bias/relu and slice after padding
+        if C.op not in s.outputs and "injective" in s.outputs[0].tag:
+            s[C].compute_inline()
+            C = s.outputs[0].output(0)
+
+        # create tuning space
+        cfg.define_knob("block_row_warps", [1, 2, 4])
+        cfg.define_knob("block_col_warps", [1, 2, 4])
+        cfg.define_knob("warp_row_tiles", [1, 2, 4])
+        cfg.define_knob("warp_col_tiles", [1, 2, 4])
+        cfg.define_knob("chunk", [1, 2, 4, 8])
+        cfg.define_knob("offset", [0, 8])
+        cfg.define_knob("offsetCS", [0, 8])
+        cfg.define_knob("vec", [1, 2, 4, 8])
+
+        # Ensure that the default parameters are applicable when autotvm is not in use
+        if (m_dim % 32 == 0 and n_dim % 8 == 0):
+            cfg.define_knob("wmma_m", [32, 16, 8])
+        elif (m_dim % 16 == 0 and n_dim % 16 == 0):
+            cfg.define_knob("wmma_m", [16, 8, 32])
+        elif (m_dim % 8 == 0 and n_dim % 32 == 0):
+            cfg.define_knob("wmma_m", [8, 16, 32])
+
+        warp_size = 32
+        wmma_k = 16
+        block_row_warps = cfg["block_row_warps"].val
+        block_col_warps = cfg["block_col_warps"].val
+        warp_row_tiles = cfg["warp_row_tiles"].val
+        warp_col_tiles = cfg["warp_col_tiles"].val
+        chunk = cfg["chunk"].val
+        offset = cfg["offset"].val
+        offsetCS = cfg["offsetCS"].val
+        wmma_m = cfg["wmma_m"].val
+        vec = cfg["vec"].val
+
+        if wmma_m == 16:
+            wmma_n = 16
+        elif wmma_m == 8:
+            wmma_n = 32
+        elif wmma_m == 32:
+            wmma_n = 8
+
+        # Define the stride of intrin functions
+        AS_align = chunk * wmma_k + offset
+        BS_align = chunk * wmma_k + offset
+        CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS
+        AS_stride = [AS_align, 1]
+        BS_stride = [BS_align, 1]
+        AF_stride = [wmma_k, 1]
+        BF_stride = [wmma_k, 1]
+        CF_stride = [warp_col_tiles * wmma_n, 1]
+        CS_stride = [CS_align, 1]
+
+        block_x = te.thread_axis('blockIdx.x')
+        block_y = te.thread_axis('blockIdx.y')
+        block_z = te.thread_axis('blockIdx.z')
+        thread_x = te.thread_axis('threadIdx.x')
+        thread_y = te.thread_axis('threadIdx.y')
+        thread_z = te.thread_axis('threadIdx.z')
+
+        # Schedule for dense computation
+        block_factor_m = wmma_m * warp_row_tiles * block_row_warps
+        block_factor_n = wmma_n * warp_col_tiles * block_col_warps
+        b, m, n = C.op.axis
+        block_i, bc = s[C].split(m, factor=block_factor_m)
+        block_j, oc = s[C].split(n, factor=block_factor_n)
+        s[C].reorder(b, block_i, block_j, bc, oc)
+        t = s[C].fuse(bc, oc)
+        t, vi = s[C].split(t, factor=vec)
+        t, tx = s[C].split(t, factor=warp_size)
+        t, ty = s[C].split(t, factor=block_row_warps)
+        t, tz = s[C].split(t, factor=block_col_warps)
+        s[C].bind(block_i, block_x)
+        s[C].bind(block_j, block_y)
+        s[C].bind(b, block_z)
+        s[C].bind(tz, thread_z)
+        s[C].bind(ty, thread_y)
+        s[C].bind(tx, thread_x)
+        s[C].vectorize(vi)
+
+        # Schedule for wmma store
+        s[CS].compute_at(s[C], block_j)
+        bs, bb, oo = CS.op.axis
+        s[CS].storage_align(bb, CS_align - 1, CS_align)
+        bb, bbi = s[CS].split(bb, factor=wmma_m)
+        oo, ooi = s[CS].split(oo, factor=wmma_n)
+        bb, bbii = s[CS].split(bb, factor=warp_row_tiles)
+        oo, ooii = s[CS].split(oo, factor=warp_col_tiles)
+        s[CS].reorder(bs, bb, oo, bbii, ooii, bbi, ooi)
+
+        # Schedule for wmma computation
+        s[CF].compute_at(s[CS], oo)
+        bs, warp_i, warp_j = CF.op.axis
+        warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
+        warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
+        k, = CF.op.reduce_axis
+        k, _k = s[CF].split(k, factor=wmma_k)
+        ko, ki = s[CF].split(k, factor=chunk)
+        s[CF].reorder(bs, ko, ki, warp_i, warp_j, _ii, _jj, _k)
+
+        # Schedule for  wmma_matrix_a load
+        s[AF].compute_at(s[CF], ki)
+        bs, b, i = AF.op.axis
+        b, b_ii = s[AF].split(b, factor=wmma_m)
+        i, i_jj = s[AF].split(i, factor=wmma_k)
+        s[AF].reorder(bs, b, i, b_ii, i_jj)
+
+        # Schedule for  wmma_matrix_b load
+        s[BF].compute_at(s[CF], ki)
+        bs, o, i = BF.op.axis
+        o, o_ii = s[BF].split(o, factor=wmma_n)
+        i, i_ii = s[BF].split(i, factor=wmma_k)
+        s[BF].reorder(bs, o, i, o_ii, i_ii)
+
+        # Schedule for A's(B's) shared memory load
+        def shared_shedule(stage, strides):
+            s[stage].compute_at(s[CF], ko)
+            bs, xo, yo = stage.op.axis
+            s[stage].storage_align(xo, strides - 1, strides)
+            t = s[stage].fuse(xo, yo)
+            t, vi = s[stage].split(t, factor=vec)
+            t, tx = s[stage].split(t, factor=warp_size)
+            t, ty = s[stage].split(t, factor=block_row_warps)
+            _, tz = s[stage].split(t, factor=block_col_warps)
+            s[stage].bind(ty, thread_y)
+            s[stage].bind(tz, thread_z)
+            s[stage].bind(tx, thread_x)
+            s[stage].vectorize(vi)
+
+        shared_shedule(AS, AS_align)
+        shared_shedule(BS, BS_align)
+
+        shape = (wmma_m, wmma_n, wmma_k)
+        in_dtype = 'float16'

Review comment:
       Same concern about the data type as above.
   It's fine for this PR, but it would be better to add more checks, or at least a comment noting that TensorCore requires a specific data type, so that anyone who runs into trouble knows what to check.
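
As an illustration of that suggestion, a hypothetical comment-plus-check inside `batch_matmul_tensorcore_cuda` (not the change that was actually merged) could look like:

```python
# Hypothetical fragment of batch_matmul_tensorcore_cuda.
# TensorCore (WMMA) fragments are computed in float16, so the inputs are cast
# to float16 below; make the implicit precision loss explicit for the user.
assert x.dtype in ("float16", "float32"), (
    "batch_matmul_tensorcore casts inputs to float16 for WMMA, got %s" % x.dtype
)
```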

##########
File path: python/tvm/topi/cuda/batch_matmul_tensorcore.py
##########
@@ -0,0 +1,274 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
+"""cuda batch_matmul operators"""
+import tvm
+from tvm import autotvm
+from tvm import te
+from ..utils import traverse_inline, get_const_tuple
+from .tensor_intrin import intrin_wmma_load_matrix_A, \
+        intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm
+
+@autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
+def batch_matmul_tensorcore(cfg, x, y, out_shape=None):
+    """batch matmul tensorcore operator on cuda"""
+    # todo: deal with out_shape for broadcast, liuxin.ai
+    return batch_matmul_tensorcore_cuda(x, y)
+
+
+@autotvm.register_topi_schedule("batch_matmul_tensorcore.cuda")
+def schedule_batch_matmul_tensorcore(cfg, outs):
+    """Schedule for batch_matmul operator using Tensorcore
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of batch_matmul
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _schedule(cfg, s, C):
+        A, B = s[C].op.input_tensors
+        batch, m_dim, k_dim = get_const_tuple(A.shape)
+        batch, n_dim, k_dim = get_const_tuple(B.shape)
+        out_dtype = C.dtype
+        # inline astype fp16
+        s[A].compute_inline()
+        s[B].compute_inline()
+
+        # Explicit memory access
+        AS = s.cache_read(A, 'shared', [C])
+        BS = s.cache_read(B, 'shared', [C])
+        AF = s.cache_read(AS, 'wmma.matrix_a', [C])
+        BF = s.cache_read(BS, 'wmma.matrix_b', [C])
+        CF = s.cache_write(C, 'wmma.accumulator')
+        CS = s.cache_read(CF, 'shared', [C])
+
+        # fallback support
+        target = tvm.target.Target.current()
+        if cfg.is_fallback:
+            ref_log = autotvm.tophub.load_reference_log(
+                target.kind.name, target.model, 'batch_matmul_tensorcore.cuda')
+            cfg.fallback_with_reference_log(ref_log)
+
+        # Deal with op fusion, such as bias/relu and slice after padding
+        if C.op not in s.outputs and "injective" in s.outputs[0].tag:
+            s[C].compute_inline()
+            C = s.outputs[0].output(0)
+
+        # create tuning space
+        cfg.define_knob("block_row_warps", [1, 2, 4])
+        cfg.define_knob("block_col_warps", [1, 2, 4])
+        cfg.define_knob("warp_row_tiles", [1, 2, 4])
+        cfg.define_knob("warp_col_tiles", [1, 2, 4])
+        cfg.define_knob("chunk", [1, 2, 4, 8])
+        cfg.define_knob("offset", [0, 8])
+        cfg.define_knob("offsetCS", [0, 8])
+        cfg.define_knob("vec", [1, 2, 4, 8])
+
+        # Ensure that the default parameters are applicable when autotvm is not in use
+        if (m_dim % 32 == 0 and n_dim % 8 == 0):
+            cfg.define_knob("wmma_m", [32, 16, 8])
+        elif (m_dim % 16 == 0 and n_dim % 16 == 0):
+            cfg.define_knob("wmma_m", [16, 8, 32])
+        elif (m_dim % 8 == 0 and n_dim % 32 == 0):
+            cfg.define_knob("wmma_m", [8, 16, 32])
+
+        warp_size = 32
+        wmma_k = 16
+        block_row_warps = cfg["block_row_warps"].val
+        block_col_warps = cfg["block_col_warps"].val
+        warp_row_tiles = cfg["warp_row_tiles"].val
+        warp_col_tiles = cfg["warp_col_tiles"].val
+        chunk = cfg["chunk"].val
+        offset = cfg["offset"].val
+        offsetCS = cfg["offsetCS"].val
+        wmma_m = cfg["wmma_m"].val
+        vec = cfg["vec"].val
+
+        if wmma_m == 16:
+            wmma_n = 16
+        elif wmma_m == 8:
+            wmma_n = 32
+        elif wmma_m == 32:
+            wmma_n = 8
+
+        # Define the stride of intrin functions
+        AS_align = chunk * wmma_k + offset
+        BS_align = chunk * wmma_k + offset
+        CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS
+        AS_stride = [AS_align, 1]
+        BS_stride = [BS_align, 1]
+        AF_stride = [wmma_k, 1]
+        BF_stride = [wmma_k, 1]
+        CF_stride = [warp_col_tiles * wmma_n, 1]
+        CS_stride = [CS_align, 1]
+
+        block_x = te.thread_axis('blockIdx.x')
+        block_y = te.thread_axis('blockIdx.y')
+        block_z = te.thread_axis('blockIdx.z')
+        thread_x = te.thread_axis('threadIdx.x')
+        thread_y = te.thread_axis('threadIdx.y')
+        thread_z = te.thread_axis('threadIdx.z')
+
+        # Schedule for dense computation
+        block_factor_m = wmma_m * warp_row_tiles * block_row_warps
+        block_factor_n = wmma_n * warp_col_tiles * block_col_warps
+        b, m, n = C.op.axis
+        block_i, bc = s[C].split(m, factor=block_factor_m)
+        block_j, oc = s[C].split(n, factor=block_factor_n)
+        s[C].reorder(b, block_i, block_j, bc, oc)
+        t = s[C].fuse(bc, oc)
+        t, vi = s[C].split(t, factor=vec)
+        t, tx = s[C].split(t, factor=warp_size)
+        t, ty = s[C].split(t, factor=block_row_warps)
+        t, tz = s[C].split(t, factor=block_col_warps)
+        s[C].bind(block_i, block_x)
+        s[C].bind(block_j, block_y)
+        s[C].bind(b, block_z)
+        s[C].bind(tz, thread_z)
+        s[C].bind(ty, thread_y)
+        s[C].bind(tx, thread_x)
+        s[C].vectorize(vi)
+
+        # Schedule for wmma store
+        s[CS].compute_at(s[C], block_j)
+        bs, bb, oo = CS.op.axis
+        s[CS].storage_align(bb, CS_align - 1, CS_align)
+        bb, bbi = s[CS].split(bb, factor=wmma_m)
+        oo, ooi = s[CS].split(oo, factor=wmma_n)
+        bb, bbii = s[CS].split(bb, factor=warp_row_tiles)
+        oo, ooii = s[CS].split(oo, factor=warp_col_tiles)
+        s[CS].reorder(bs, bb, oo, bbii, ooii, bbi, ooi)
+
+        # Schedule for wmma computation
+        s[CF].compute_at(s[CS], oo)
+        bs, warp_i, warp_j = CF.op.axis
+        warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
+        warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
+        k, = CF.op.reduce_axis
+        k, _k = s[CF].split(k, factor=wmma_k)
+        ko, ki = s[CF].split(k, factor=chunk)
+        s[CF].reorder(bs, ko, ki, warp_i, warp_j, _ii, _jj, _k)
+
+        # Schedule for  wmma_matrix_a load
+        s[AF].compute_at(s[CF], ki)
+        bs, b, i = AF.op.axis
+        b, b_ii = s[AF].split(b, factor=wmma_m)
+        i, i_jj = s[AF].split(i, factor=wmma_k)
+        s[AF].reorder(bs, b, i, b_ii, i_jj)
+
+        # Schedule for  wmma_matrix_b load
+        s[BF].compute_at(s[CF], ki)
+        bs, o, i = BF.op.axis
+        o, o_ii = s[BF].split(o, factor=wmma_n)
+        i, i_ii = s[BF].split(i, factor=wmma_k)
+        s[BF].reorder(bs, o, i, o_ii, i_ii)
+
+        # Schedule for A's(B's) shared memory load
+        def shared_shedule(stage, strides):
+            s[stage].compute_at(s[CF], ko)
+            bs, xo, yo = stage.op.axis
+            s[stage].storage_align(xo, strides - 1, strides)
+            t = s[stage].fuse(xo, yo)
+            t, vi = s[stage].split(t, factor=vec)
+            t, tx = s[stage].split(t, factor=warp_size)
+            t, ty = s[stage].split(t, factor=block_row_warps)
+            _, tz = s[stage].split(t, factor=block_col_warps)
+            s[stage].bind(ty, thread_y)
+            s[stage].bind(tz, thread_z)
+            s[stage].bind(tx, thread_x)
+            s[stage].vectorize(vi)
+
+        shared_shedule(AS, AS_align)
+        shared_shedule(BS, BS_align)
+
+        shape = (wmma_m, wmma_n, wmma_k)
+        in_dtype = 'float16'
+        AL_gemm = te.placeholder((wmma_m, wmma_k), name='AL_gemm', dtype=in_dtype)
+        BL_gemm = te.placeholder((wmma_n, wmma_k), name='BL_gemm', dtype=in_dtype)
+        k_gemm = te.reduce_axis((0, wmma_k), name='k_gemm')
+        CL_compute = te.compute((wmma_m, wmma_n), lambda ii, jj:
+        te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) * \
+               BL_gemm[jj, k_gemm].astype(out_dtype), \
+               axis=k_gemm), name='CL_compute')
+
+        # lower the computation loops down to TensorCore hardware intrinsics
+        # by mapping the dense tensorcore to tensor intrinsics
+        s[AF].tensorize(b_ii, intrin_wmma_load_matrix_A( \
+            AF_stride, AS_stride, shape, "row_major", \
+            (wmma_m, wmma_k), (wmma_m, wmma_k), 'float16'))
+        s[BF].tensorize(o_ii, intrin_wmma_load_matrix_W( \
+            BF_stride, BS_stride, shape, "col_major", \
+            (wmma_n, wmma_k), (wmma_n, wmma_k), 'float16'))
+        s[CF].tensorize(_ii, intrin_wmma_gemm( \
+            AL_gemm, BL_gemm, CL_compute, AF_stride, BF_stride, CF_stride, shape))
+        s[CS].tensorize(bbi, intrin_wmma_store_matrix( \
+            CS_stride, CF_stride, shape, out_dtype, (wmma_m, wmma_n), (wmma_m, wmma_n)))
+
+    def _callback(op):
+        if "batch_matmul_tensorcore" in op.tag:
+            _schedule(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+def batch_matmul_tensorcore_cuda(x, y):
+    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+    data in batch.
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        3-D with shape [batch, M, K]
+
+    y : tvm.te.Tensor
+        3-D with shape [batch, N, K]
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        3-D with shape [batch, M, N]
+    """
+    assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
+    x_shape = get_const_tuple(x.shape)
+    y_shape = get_const_tuple(y.shape)
+    assert x_shape[0] == y_shape[0], "batch dimension doesn't match"
+    assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
+    batch, M, K = x.shape
+    N = y.shape[1]
+    out_dtype = x.dtype
+
+    assert ((M % 8 == 0 and K % 16 == 0 and N % 32 == 0) or \
+            (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) or \
+            (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)), \
+        "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
+
+    x_16 = te.compute((batch, M, K), lambda b, i, k: x[b, i, k].astype('float16'))
+    y_16 = te.compute((batch, N, K), lambda b, j, k: y[b, j, k].astype('float16'))

Review comment:
       ditto







[GitHub] [tvm] jcf94 merged pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
jcf94 merged pull request #7146:
URL: https://github.com/apache/tvm/pull/7146


   





[GitHub] [tvm] merrymercy commented on pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
merrymercy commented on pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#issuecomment-750325168


   cc tensor core maintainers @vinx13 @Laurawly  @Hzfengsy 





[GitHub] [tvm] Meteorix commented on pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
Meteorix commented on pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#issuecomment-750787545


   > @Meteorix out of curiosity can you share some of your benchmarking results? I'd love to know how much faster this performs than cublas.
   
   @jwfromm the following are some of the benchmarks (tuned for 1000 trials). This schedule beats cuBLAS on some shapes, which is also why I made `batch_matmul_cublas` autotunable in this PR.
   
   ```
   Shape: [1, 64, 1024] [1, 4096, 1024]
   batch_matmul_tensorcore.cuda   2.9238894640234948e-05
   batch_matmul_cublas.cuda       2.7487557097865394e-05 
   batch_matmul.cuda              0.00014189747117647058
   
   Shape: [1, 64, 1024] [1, 1024, 1024]
   batch_matmul_tensorcore.cuda   1.5578384301061096e-05 
   batch_matmul_cublas.cuda       2.041829239101948e-05
   batch_matmul.cuda              6.108717968157696e-05
   
   Shape: [1, 128, 1024] [1, 4096, 1024]
   batch_matmul_tensorcore.cuda   0.00011345079327976625
   batch_matmul_cublas.cuda       0.00011074180193236715 
   batch_matmul.cuda              0.00024510443407707913
   
   Shape: [1, 128, 4096] [1, 1024, 4096]
   batch_matmul_tensorcore.cuda   0.00017083510384959715
   batch_matmul_cublas.cuda       0.00010608833085714285 
   batch_matmul.cuda              0.00035638234315169367
   
   Shape: [16, 128, 64] [16, 128, 64]
   batch_matmul_cublas.cuda       6.046038943091678e-06
   batch_matmul_tensorcore.cuda   4.134768131265665e-06 
   batch_matmul.cuda              1.2430305571941866e-05
   
   Shape: [16, 128, 128] [16, 64, 128]
   batch_matmul_tensorcore.cuda   4.74178964860194e-06 
   batch_matmul_cublas.cuda       9.463372359711623e-06
   batch_matmul.cuda              1.4179731404708587e-05
   
   Shape: [1, 128, 1024] [1, 1024, 1024]
   batch_matmul_tensorcore.cuda   3.857668104222821e-05
   batch_matmul_cublas.cuda       2.3704257450575394e-05 
   batch_matmul.cuda              0.0002515613367983368
   ```
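
For readers who want to reproduce numbers like these, below is an illustrative sketch (not the author's benchmark script) that builds and times the TensorCore schedule for one of the shapes above with the fallback config; the reported results additionally apply a 1000-trial autotvm log, e.g. via `autotvm.apply_history_best`.

```python
# Timing sketch for one shape; assumes a TensorCore-capable GPU.
import numpy as np
import tvm
from tvm import te, topi

batch, M, N, K = 16, 128, 128, 64  # one of the shapes listed above
x = te.placeholder((batch, M, K), name="x", dtype="float32")
y = te.placeholder((batch, N, K), name="y", dtype="float32")

target = tvm.target.Target("cuda")
ctx = tvm.gpu(0)
with target:
    out = topi.cuda.batch_matmul_tensorcore(x, y)
    s = topi.cuda.schedule_batch_matmul_tensorcore([out])

func = tvm.build(s, [x, y, out], target)
a = tvm.nd.array(np.random.uniform(size=(batch, M, K)).astype("float32"), ctx)
b = tvm.nd.array(np.random.uniform(size=(batch, N, K)).astype("float32"), ctx)
c = tvm.nd.array(np.zeros((batch, M, N), dtype="float32"), ctx)

evaluator = func.time_evaluator(func.entry_name, ctx, number=100)
print("batch_matmul_tensorcore.cuda", evaluator(a, b, c).mean)
```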





[GitHub] [tvm] merrymercy commented on pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
merrymercy commented on pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#issuecomment-750324982


   cc tensor core maintainers @vinx13 @Laurawly 





[GitHub] [tvm] tqchen commented on pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#issuecomment-758346419


   @jcf94 @Meteorix @jwfromm because our TOPI test stage does not guarantee a TensorCore GPU (we have two Pascal GPUs), it would be useful to optionally skip this test to avoid flaky CI errors on main.
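
One possible shape for that guard, sketched under the assumption that the skip happens inside the test itself (CI may end up using a dedicated decorator instead):

```python
import pytest
import tvm
from tvm.contrib import nvcc

def skip_unless_tensorcore():
    """Skip the calling test when no TensorCore-capable GPU is available."""
    ctx = tvm.gpu(0)
    if not ctx.exist or not nvcc.have_tensorcore(ctx.compute_version):
        pytest.skip("TensorCore not supported on this GPU (e.g. Pascal); skipping test")
```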





[GitHub] [tvm] Laurawly commented on pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
Laurawly commented on pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#issuecomment-750795971


   > > We should fix the type issue mentioned by @jcf94.
   > > The existing `dense_tensorcore` is buggy in my view. We should fix it instead of following it.
   > > This small bug can lead to potential accuracy loss that is very hard to debug.
   > 
   > @merrymercy I see your point. Maybe we can discuss it with other tensor core maintainers and file another pr to resolve this issue?
   
   I agree with @merrymercy and think we should fix the type issue that we overlooked before. We can either fix it in this PR or in a separate parallel PR. I'd like to help with that.





[GitHub] [tvm] merrymercy commented on a change in pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
merrymercy commented on a change in pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#discussion_r547983792



##########
File path: python/tvm/relay/op/strategy/cuda.py
##########
@@ -657,6 +657,20 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
             name="batch_matmul_cublas.cuda",
             plevel=15,
         )
+    if target.kind.name == "cuda" and nvcc.have_tensorcore(tvm.gpu(0).compute_version):
+        x, y = inputs
+        B, M, K = get_const_tuple(x.shape)
+        B, N, K = get_const_tuple(y.shape)
+        # "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
+        if ((M % 8 == 0 and K % 16 == 0 and N % 32 == 0) or \
+                (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) or \
+                (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)):

Review comment:
       I think it is a bug in the dense_tensorcore. We should not follow that.

##########
File path: python/tvm/relay/op/strategy/cuda.py
##########
@@ -657,6 +657,20 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
             name="batch_matmul_cublas.cuda",
             plevel=15,
         )
+    if target.kind.name == "cuda" and nvcc.have_tensorcore(tvm.gpu(0).compute_version):
+        x, y = inputs
+        B, M, K = get_const_tuple(x.shape)
+        B, N, K = get_const_tuple(y.shape)
+        # "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
+        if ((M % 8 == 0 and K % 16 == 0 and N % 32 == 0) or \
+                (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) or \
+                (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)):

Review comment:
       I think it is a bug in `dense_tensorcore`. We should not follow that.







[GitHub] [tvm] Meteorix commented on pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
Meteorix commented on pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#issuecomment-750788368


   > We should fix the type issue mentioned by @jcf94.
   > The existing `dense_tensorcore` is buggy in my view. We should fix it instead of following it.
   > This small bug can lead to potential accuracy loss that is very hard to debug.
   
   @merrymercy I see your point. Maybe we can discuss it with other tensor core maintainers and file another pr to resolve this issue?





[GitHub] [tvm] Laurawly commented on a change in pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
Laurawly commented on a change in pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#discussion_r550363118



##########
File path: python/tvm/relay/op/strategy/cuda.py
##########
@@ -657,6 +657,23 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
             name="batch_matmul_cublas.cuda",
             plevel=15,
         )
+    if target.kind.name == "cuda" and nvcc.have_tensorcore(tvm.gpu(0).compute_version):

Review comment:
       It's better to use `nvcc.have_tensorcore(target=target)` here since `tvm.gpu(0)` might not exist.
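
A small sketch of the suggested call, assuming `have_tensorcore` accepts a `target` keyword as the comment implies, so the check works on build machines without a local GPU (the example target string is only illustrative):

```python
import tvm
from tvm.contrib import nvcc

# In the strategy function `target` is passed in; here we assume an example target.
target = tvm.target.Target("cuda -arch=sm_70")

# Derive the compute capability from the compile target instead of tvm.gpu(0),
# which may not exist on the machine doing the compilation.
if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):
    print("TensorCore strategies can be registered for this target")
```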







[GitHub] [tvm] tqchen commented on pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#issuecomment-759827498


   created https://github.com/apache/tvm/issues/7277 to track the issue





[GitHub] [tvm] jcf94 commented on pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
jcf94 commented on pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#issuecomment-749364296


   @Meteorix, many thanks for your PR! The code looks good to me.





[GitHub] [tvm] jcf94 commented on pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
jcf94 commented on pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#issuecomment-750862121


   Thanks! @Laurawly @merrymercy 
   I think it's fine to fix them in a new PR.
   
   @Meteorix If we're not going to finish these here, you can add some TODO comments in the code and create a new issue for tracking. Please fix the CI problem and we can merge this.
   #7147 is also fine; it just needs some unit tests for these modifications.





[GitHub] [tvm] Meteorix commented on a change in pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
Meteorix commented on a change in pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#discussion_r547093078



##########
File path: python/tvm/topi/cuda/batch_matmul_tensorcore.py
##########
@@ -0,0 +1,274 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
+"""cuda batch_matmul operators"""
+import tvm
+from tvm import autotvm
+from tvm import te
+from ..utils import traverse_inline, get_const_tuple
+from .tensor_intrin import intrin_wmma_load_matrix_A, \
+        intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm
+
+@autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
+def batch_matmul_tensorcore(cfg, x, y, out_shape=None):
+    """batch matmul tensorcore operator on cuda"""
+    # todo: deal with out_shape for broadcast, liuxin.ai
+    return batch_matmul_tensorcore_cuda(x, y)
+
+
+@autotvm.register_topi_schedule("batch_matmul_tensorcore.cuda")
+def schedule_batch_matmul_tensorcore(cfg, outs):
+    """Schedule for batch_matmul operator using Tensorcore
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of batch_matmul
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _schedule(cfg, s, C):
+        A, B = s[C].op.input_tensors
+        batch, m_dim, k_dim = get_const_tuple(A.shape)
+        batch, n_dim, k_dim = get_const_tuple(B.shape)
+        out_dtype = C.dtype
+        # inline astype fp16
+        s[A].compute_inline()
+        s[B].compute_inline()
+
+        # Explicit memory access
+        AS = s.cache_read(A, 'shared', [C])
+        BS = s.cache_read(B, 'shared', [C])
+        AF = s.cache_read(AS, 'wmma.matrix_a', [C])
+        BF = s.cache_read(BS, 'wmma.matrix_b', [C])
+        CF = s.cache_write(C, 'wmma.accumulator')
+        CS = s.cache_read(CF, 'shared', [C])
+
+        # fallback support
+        target = tvm.target.Target.current()
+        if cfg.is_fallback:
+            ref_log = autotvm.tophub.load_reference_log(
+                target.kind.name, target.model, 'batch_matmul_tensorcore.cuda')
+            cfg.fallback_with_reference_log(ref_log)
+
+        # Deal with op fusion, such as bias/relu and slice after padding
+        if C.op not in s.outputs and "injective" in s.outputs[0].tag:
+            s[C].compute_inline()
+            C = s.outputs[0].output(0)
+
+        # create tuning space
+        cfg.define_knob("block_row_warps", [1, 2, 4])
+        cfg.define_knob("block_col_warps", [1, 2, 4])
+        cfg.define_knob("warp_row_tiles", [1, 2, 4])
+        cfg.define_knob("warp_col_tiles", [1, 2, 4])
+        cfg.define_knob("chunk", [1, 2, 4, 8])
+        cfg.define_knob("offset", [0, 8])
+        cfg.define_knob("offsetCS", [0, 8])
+        cfg.define_knob("vec", [1, 2, 4, 8])
+
+        # Ensure that the default parameters are applicable when autotvm is not in use
+        if (m_dim % 32 == 0 and n_dim % 8 == 0):
+            cfg.define_knob("wmma_m", [32, 16, 8])
+        elif (m_dim % 16 == 0 and n_dim % 16 == 0):
+            cfg.define_knob("wmma_m", [16, 8, 32])
+        elif (m_dim % 8 == 0 and n_dim % 32 == 0):
+            cfg.define_knob("wmma_m", [8, 16, 32])
+
+        warp_size = 32
+        wmma_k = 16
+        block_row_warps = cfg["block_row_warps"].val
+        block_col_warps = cfg["block_col_warps"].val
+        warp_row_tiles = cfg["warp_row_tiles"].val
+        warp_col_tiles = cfg["warp_col_tiles"].val
+        chunk = cfg["chunk"].val
+        offset = cfg["offset"].val
+        offsetCS = cfg["offsetCS"].val
+        wmma_m = cfg["wmma_m"].val
+        vec = cfg["vec"].val
+
+        if wmma_m == 16:
+            wmma_n = 16
+        elif wmma_m == 8:
+            wmma_n = 32
+        elif wmma_m == 32:
+            wmma_n = 8
+
+        # Define the stride of intrin functions
+        AS_align = chunk * wmma_k + offset
+        BS_align = chunk * wmma_k + offset
+        CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS
+        AS_stride = [AS_align, 1]
+        BS_stride = [BS_align, 1]
+        AF_stride = [wmma_k, 1]
+        BF_stride = [wmma_k, 1]
+        CF_stride = [warp_col_tiles * wmma_n, 1]
+        CS_stride = [CS_align, 1]
+
+        block_x = te.thread_axis('blockIdx.x')
+        block_y = te.thread_axis('blockIdx.y')
+        block_z = te.thread_axis('blockIdx.z')
+        thread_x = te.thread_axis('threadIdx.x')
+        thread_y = te.thread_axis('threadIdx.y')
+        thread_z = te.thread_axis('threadIdx.z')
+
+        # Schedule for dense computation
+        block_factor_m = wmma_m * warp_row_tiles * block_row_warps
+        block_factor_n = wmma_n * warp_col_tiles * block_col_warps
+        b, m, n = C.op.axis
+        block_i, bc = s[C].split(m, factor=block_factor_m)
+        block_j, oc = s[C].split(n, factor=block_factor_n)
+        s[C].reorder(b, block_i, block_j, bc, oc)
+        t = s[C].fuse(bc, oc)
+        t, vi = s[C].split(t, factor=vec)
+        t, tx = s[C].split(t, factor=warp_size)
+        t, ty = s[C].split(t, factor=block_row_warps)
+        t, tz = s[C].split(t, factor=block_col_warps)
+        s[C].bind(block_i, block_x)
+        s[C].bind(block_j, block_y)
+        s[C].bind(b, block_z)
+        s[C].bind(tz, thread_z)
+        s[C].bind(ty, thread_y)
+        s[C].bind(tx, thread_x)
+        s[C].vectorize(vi)
+
+        # Schedule for wmma store
+        s[CS].compute_at(s[C], block_j)
+        bs, bb, oo = CS.op.axis
+        s[CS].storage_align(bb, CS_align - 1, CS_align)
+        bb, bbi = s[CS].split(bb, factor=wmma_m)
+        oo, ooi = s[CS].split(oo, factor=wmma_n)
+        bb, bbii = s[CS].split(bb, factor=warp_row_tiles)
+        oo, ooii = s[CS].split(oo, factor=warp_col_tiles)
+        s[CS].reorder(bs, bb, oo, bbii, ooii, bbi, ooi)
+
+        # Schedule for wmma computation
+        s[CF].compute_at(s[CS], oo)
+        bs, warp_i, warp_j = CF.op.axis
+        warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
+        warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
+        k, = CF.op.reduce_axis
+        k, _k = s[CF].split(k, factor=wmma_k)
+        ko, ki = s[CF].split(k, factor=chunk)
+        s[CF].reorder(bs, ko, ki, warp_i, warp_j, _ii, _jj, _k)
+
+        # Schedule for  wmma_matrix_a load
+        s[AF].compute_at(s[CF], ki)
+        bs, b, i = AF.op.axis
+        b, b_ii = s[AF].split(b, factor=wmma_m)
+        i, i_jj = s[AF].split(i, factor=wmma_k)
+        s[AF].reorder(bs, b, i, b_ii, i_jj)
+
+        # Schedule for  wmma_matrix_b load
+        s[BF].compute_at(s[CF], ki)
+        bs, o, i = BF.op.axis
+        o, o_ii = s[BF].split(o, factor=wmma_n)
+        i, i_ii = s[BF].split(i, factor=wmma_k)
+        s[BF].reorder(bs, o, i, o_ii, i_ii)
+
+        # Schedule for A's(B's) shared memory load
+        def shared_shedule(stage, strides):
+            s[stage].compute_at(s[CF], ko)
+            bs, xo, yo = stage.op.axis
+            s[stage].storage_align(xo, strides - 1, strides)
+            t = s[stage].fuse(xo, yo)
+            t, vi = s[stage].split(t, factor=vec)
+            t, tx = s[stage].split(t, factor=warp_size)
+            t, ty = s[stage].split(t, factor=block_row_warps)
+            _, tz = s[stage].split(t, factor=block_col_warps)
+            s[stage].bind(ty, thread_y)
+            s[stage].bind(tz, thread_z)
+            s[stage].bind(tx, thread_x)
+            s[stage].vectorize(vi)
+
+        shared_shedule(AS, AS_align)
+        shared_shedule(BS, BS_align)
+
+        shape = (wmma_m, wmma_n, wmma_k)
+        in_dtype = 'float16'

Review comment:
       This is also kept the same as https://github.com/apache/tvm/blob/main/python/tvm/topi/cuda/dense_tensorcore.py#L72







[GitHub] [tvm] jwfromm commented on pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
jwfromm commented on pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#issuecomment-749941299


   @Meteorix out of curiosity can you share some of your benchmarking results? I'd love to know how much faster this performs than cublas.





[GitHub] [tvm] merrymercy removed a comment on pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
merrymercy removed a comment on pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#issuecomment-750324982


   cc tensor core maintainers @vinx13 @Laurawly 





[GitHub] [tvm] Meteorix commented on pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
Meteorix commented on pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#issuecomment-756577894


   @jcf94 @merrymercy @Laurawly finally the CI passed. I have also fixed the dtype check for batch_matmul. Please review this PR again.





[GitHub] [tvm] Meteorix commented on a change in pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
Meteorix commented on a change in pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#discussion_r547091707



##########
File path: python/tvm/topi/cuda/batch_matmul_tensorcore.py
##########
@@ -0,0 +1,274 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
+"""cuda batch_matmul operators"""
+import tvm
+from tvm import autotvm
+from tvm import te
+from ..utils import traverse_inline, get_const_tuple
+from .tensor_intrin import intrin_wmma_load_matrix_A, \
+        intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm
+
+@autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
+def batch_matmul_tensorcore(cfg, x, y, out_shape=None):
+    """batch matmul tensorcore operator on cuda"""
+    # todo: deal with out_shape for broadcast, liuxin.ai
+    return batch_matmul_tensorcore_cuda(x, y)
+
+
+@autotvm.register_topi_schedule("batch_matmul_tensorcore.cuda")
+def schedule_batch_matmul_tensorcore(cfg, outs):
+    """Schedule for batch_matmul operator using Tensorcore
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of batch_matmul
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _schedule(cfg, s, C):
+        A, B = s[C].op.input_tensors
+        batch, m_dim, k_dim = get_const_tuple(A.shape)
+        batch, n_dim, k_dim = get_const_tuple(B.shape)
+        out_dtype = C.dtype
+        # inline astype fp16
+        s[A].compute_inline()
+        s[B].compute_inline()
+
+        # Explicit memory access
+        AS = s.cache_read(A, 'shared', [C])
+        BS = s.cache_read(B, 'shared', [C])
+        AF = s.cache_read(AS, 'wmma.matrix_a', [C])
+        BF = s.cache_read(BS, 'wmma.matrix_b', [C])
+        CF = s.cache_write(C, 'wmma.accumulator')
+        CS = s.cache_read(CF, 'shared', [C])
+
+        # fallback support
+        target = tvm.target.Target.current()
+        if cfg.is_fallback:
+            ref_log = autotvm.tophub.load_reference_log(
+                target.kind.name, target.model, 'batch_matmul_tensorcore.cuda')
+            cfg.fallback_with_reference_log(ref_log)
+
+        # Deal with op fusion, such as bias/relu and slice after padding
+        if C.op not in s.outputs and "injective" in s.outputs[0].tag:
+            s[C].compute_inline()
+            C = s.outputs[0].output(0)
+
+        # create tuning space
+        cfg.define_knob("block_row_warps", [1, 2, 4])
+        cfg.define_knob("block_col_warps", [1, 2, 4])
+        cfg.define_knob("warp_row_tiles", [1, 2, 4])
+        cfg.define_knob("warp_col_tiles", [1, 2, 4])
+        cfg.define_knob("chunk", [1, 2, 4, 8])
+        cfg.define_knob("offset", [0, 8])
+        cfg.define_knob("offsetCS", [0, 8])
+        cfg.define_knob("vec", [1, 2, 4, 8])
+
+        # Ensure that the default parameters are applicable when autotvm is not in use
+        if (m_dim % 32 == 0 and n_dim % 8 == 0):
+            cfg.define_knob("wmma_m", [32, 16, 8])
+        elif (m_dim % 16 == 0 and n_dim % 16 == 0):
+            cfg.define_knob("wmma_m", [16, 8, 32])
+        elif (m_dim % 8 == 0 and n_dim % 32 == 0):
+            cfg.define_knob("wmma_m", [8, 16, 32])
+
+        warp_size = 32
+        wmma_k = 16
+        block_row_warps = cfg["block_row_warps"].val
+        block_col_warps = cfg["block_col_warps"].val
+        warp_row_tiles = cfg["warp_row_tiles"].val
+        warp_col_tiles = cfg["warp_col_tiles"].val
+        chunk = cfg["chunk"].val
+        offset = cfg["offset"].val
+        offsetCS = cfg["offsetCS"].val
+        wmma_m = cfg["wmma_m"].val
+        vec = cfg["vec"].val
+
+        if wmma_m == 16:
+            wmma_n = 16
+        elif wmma_m == 8:
+            wmma_n = 32
+        elif wmma_m == 32:
+            wmma_n = 8
+
+        # Define the stride of intrin functions
+        AS_align = chunk * wmma_k + offset
+        BS_align = chunk * wmma_k + offset
+        CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS
+        AS_stride = [AS_align, 1]
+        BS_stride = [BS_align, 1]
+        AF_stride = [wmma_k, 1]
+        BF_stride = [wmma_k, 1]
+        CF_stride = [warp_col_tiles * wmma_n, 1]
+        CS_stride = [CS_align, 1]
+
+        block_x = te.thread_axis('blockIdx.x')
+        block_y = te.thread_axis('blockIdx.y')
+        block_z = te.thread_axis('blockIdx.z')
+        thread_x = te.thread_axis('threadIdx.x')
+        thread_y = te.thread_axis('threadIdx.y')
+        thread_z = te.thread_axis('threadIdx.z')
+
+        # Schedule for dense computation
+        block_factor_m = wmma_m * warp_row_tiles * block_row_warps
+        block_factor_n = wmma_n * warp_col_tiles * block_col_warps
+        b, m, n = C.op.axis
+        block_i, bc = s[C].split(m, factor=block_factor_m)
+        block_j, oc = s[C].split(n, factor=block_factor_n)
+        s[C].reorder(b, block_i, block_j, bc, oc)
+        t = s[C].fuse(bc, oc)
+        t, vi = s[C].split(t, factor=vec)
+        t, tx = s[C].split(t, factor=warp_size)
+        t, ty = s[C].split(t, factor=block_row_warps)
+        t, tz = s[C].split(t, factor=block_col_warps)
+        s[C].bind(block_i, block_x)
+        s[C].bind(block_j, block_y)
+        s[C].bind(b, block_z)
+        s[C].bind(tz, thread_z)
+        s[C].bind(ty, thread_y)
+        s[C].bind(tx, thread_x)
+        s[C].vectorize(vi)
+
+        # Schedule for wmma store
+        s[CS].compute_at(s[C], block_j)
+        bs, bb, oo = CS.op.axis
+        s[CS].storage_align(bb, CS_align - 1, CS_align)
+        bb, bbi = s[CS].split(bb, factor=wmma_m)
+        oo, ooi = s[CS].split(oo, factor=wmma_n)
+        bb, bbii = s[CS].split(bb, factor=warp_row_tiles)
+        oo, ooii = s[CS].split(oo, factor=warp_col_tiles)
+        s[CS].reorder(bs, bb, oo, bbii, ooii, bbi, ooi)
+
+        # Schedule for wmma computation
+        s[CF].compute_at(s[CS], oo)
+        bs, warp_i, warp_j = CF.op.axis
+        warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
+        warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
+        k, = CF.op.reduce_axis
+        k, _k = s[CF].split(k, factor=wmma_k)
+        ko, ki = s[CF].split(k, factor=chunk)
+        s[CF].reorder(bs, ko, ki, warp_i, warp_j, _ii, _jj, _k)
+
+        # Schedule for  wmma_matrix_a load
+        s[AF].compute_at(s[CF], ki)
+        bs, b, i = AF.op.axis
+        b, b_ii = s[AF].split(b, factor=wmma_m)
+        i, i_jj = s[AF].split(i, factor=wmma_k)
+        s[AF].reorder(bs, b, i, b_ii, i_jj)
+
+        # Schedule for  wmma_matrix_b load
+        s[BF].compute_at(s[CF], ki)
+        bs, o, i = BF.op.axis
+        o, o_ii = s[BF].split(o, factor=wmma_n)
+        i, i_ii = s[BF].split(i, factor=wmma_k)
+        s[BF].reorder(bs, o, i, o_ii, i_ii)
+
+        # Schedule for A's(B's) shared memory load
+        def shared_shedule(stage, strides):
+            s[stage].compute_at(s[CF], ko)
+            bs, xo, yo = stage.op.axis
+            s[stage].storage_align(xo, strides - 1, strides)
+            t = s[stage].fuse(xo, yo)
+            t, vi = s[stage].split(t, factor=vec)
+            t, tx = s[stage].split(t, factor=warp_size)
+            t, ty = s[stage].split(t, factor=block_row_warps)
+            _, tz = s[stage].split(t, factor=block_col_warps)
+            s[stage].bind(ty, thread_y)
+            s[stage].bind(tz, thread_z)
+            s[stage].bind(tx, thread_x)
+            s[stage].vectorize(vi)
+
+        shared_shedule(AS, AS_align)
+        shared_shedule(BS, BS_align)
+
+        shape = (wmma_m, wmma_n, wmma_k)
+        in_dtype = 'float16'

Review comment:
       I just kept it the same as the dense_tensorcore code: https://github.com/apache/tvm/blob/main/python/tvm/relay/op/strategy/cuda.py#L679







[GitHub] [tvm] jwfromm commented on a change in pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
jwfromm commented on a change in pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#discussion_r547662242



##########
File path: tests/python/topi/python/test_topi_batch_matmul_tensorcore.py
##########
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for batch_matmul operator"""
+import numpy as np
+import tvm
+from tvm import te
+from tvm import topi
+import tvm.topi.testing
+from tvm.topi.utils import get_const_tuple
+from tvm.contrib.pickle_memoize import memoize
+
+import tvm.testing
+
+_batch_matmul_implement = {
+    "gpu": (topi.cuda.batch_matmul_tensorcore, topi.cuda.schedule_batch_matmul_tensorcore),
+}
+
+
+def verify_batch_matmul(x_batch, y_batch, M, N, K):
+    x = te.placeholder((x_batch, M, K), name="x")
+    y = te.placeholder((y_batch, N, K), name="y")
+    dtype = x.dtype

Review comment:
       It may be worth testing other datatypes as well, especially `float16`.
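       For instance, a rough sketch of a dtype-parametrized check (the helper name, shapes and tolerance are illustrative assumptions, not taken from this PR, and it needs a tensorcore-capable GPU):

           import numpy as np
           import tvm
           import tvm.testing
           from tvm import te, topi

           def check_batch_matmul(batch, M, N, K, dtype):
               x = te.placeholder((batch, M, K), name="x", dtype=dtype)
               y = te.placeholder((batch, N, K), name="y", dtype=dtype)
               with tvm.target.Target("cuda"):
                   out = topi.cuda.batch_matmul_tensorcore(x, y)
                   s = topi.cuda.schedule_batch_matmul_tensorcore([out])
               func = tvm.build(s, [x, y, out], "cuda")

               ctx = tvm.gpu(0)
               x_np = np.random.uniform(size=(batch, M, K)).astype(dtype)
               y_np = np.random.uniform(size=(batch, N, K)).astype(dtype)
               # batch_matmul computes x @ y^T for every batch
               ref = np.einsum("bmk,bnk->bmn",
                               x_np.astype("float32"), y_np.astype("float32"))
               out_nd = tvm.nd.empty((batch, M, N), out.dtype, ctx)
               func(tvm.nd.array(x_np, ctx), tvm.nd.array(y_np, ctx), out_nd)
               # loose tolerance: the tensorcore path feeds fp16 into the mma
               tvm.testing.assert_allclose(out_nd.asnumpy(),
                                           ref.astype(out.dtype), rtol=1e-3)

           for dtype in ["float32", "float16"]:
               check_batch_matmul(1, 128, 128, 64, dtype)

       Whether the float16 path goes through cleanly as-is is exactly what such a test would tell us.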







[GitHub] [tvm] jcf94 commented on a change in pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
jcf94 commented on a change in pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#discussion_r547081523



##########
File path: python/tvm/topi/cuda/batch_matmul_tensorcore.py
##########
@@ -0,0 +1,275 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
+"""cuda batch_matmul operators"""
+import tvm
+from tvm import autotvm
+from tvm import te
+from ..utils import traverse_inline, get_const_tuple
+from .tensor_intrin import intrin_wmma_load_matrix_A, \
+        intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm
+
+@autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
+def batch_matmul_tensorcore(cfg, x, y, out_shape=None):
+    """batch matmul tensorcore operator on cuda"""
+    # todo: deal with out_shape for broadcast, liuxin.ai
+    return batch_matmul_tensorcore_cuda(x, y)
+
+
+@autotvm.register_topi_schedule("batch_matmul_tensorcore.cuda")
+def schedule_batch_matmul_tensorcore(cfg, outs):
+    """Schedule for batch_matmul operator using Tensorcore
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of batch_matmul
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _schedule(cfg, s, C):
+        A, B = s[C].op.input_tensors
+        batch, m_dim, k_dim = get_const_tuple(A.shape)
+        batch, n_dim, k_dim = get_const_tuple(B.shape)
+        out_dtype = C.dtype
+        # inline astype fp16
+        s[A].compute_inline()
+        s[B].compute_inline()
+
+        # Explicit memory access
+        AS = s.cache_read(A, 'shared', [C])
+        BS = s.cache_read(B, 'shared', [C])
+        AF = s.cache_read(AS, 'wmma.matrix_a', [C])
+        BF = s.cache_read(BS, 'wmma.matrix_b', [C])
+        CF = s.cache_write(C, 'wmma.accumulator')
+        CS = s.cache_read(CF, 'shared', [C])
+
+        # fallback support
+        target = tvm.target.Target.current()
+        if cfg.is_fallback:
+            ref_log = autotvm.tophub.load_reference_log(
+                target.kind.name, target.model, 'batch_matmul_tensorcore.cuda')
+            cfg.fallback_with_reference_log(ref_log)
+
+        # ??? Deal with op fusion, such as bias and relu ??? is this needed?

Review comment:
       typo?







[GitHub] [tvm] Meteorix commented on a change in pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
Meteorix commented on a change in pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#discussion_r547091793



##########
File path: python/tvm/relay/op/strategy/cuda.py
##########
@@ -657,6 +657,20 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
             name="batch_matmul_cublas.cuda",
             plevel=15,
         )
+    if target.kind.name == "cuda" and nvcc.have_tensorcore(tvm.gpu(0).compute_version):
+        x, y = inputs
+        B, M, K = get_const_tuple(x.shape)
+        B, N, K = get_const_tuple(y.shape)
+        # "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
+        if ((M % 8 == 0 and K % 16 == 0 and N % 32 == 0) or \
+                (M % 16 == 0 and K % 16 == 0 and N % 16 == 0) or \
+                (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)):

Review comment:
       I just kept this the same as the code for dense_tensorcore: https://github.com/apache/tvm/blob/main/python/tvm/relay/op/strategy/cuda.py#L679
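       For readers of this check, the condition above can be folded into a small predicate; a minimal sketch (the helper name and example shapes are made up for illustration):

           def tensorcore_shapes_ok(M, K, N):
               # (M, K, N) must align with one of the wmma tile shapes
               # (16, 16, 16), (32, 16, 8) or (8, 16, 32)
               return K % 16 == 0 and (
                   (M % 16 == 0 and N % 16 == 0)
                   or (M % 32 == 0 and N % 8 == 0)
                   or (M % 8 == 0 and N % 32 == 0)
               )

           # e.g. a BERT-style attention batch_matmul: seq_len 128, head dim 64
           assert tensorcore_shapes_ok(128, 64, 128)
           assert not tensorcore_shapes_ok(127, 64, 128)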




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] Meteorix commented on a change in pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
Meteorix commented on a change in pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#discussion_r547086764



##########
File path: python/tvm/topi/cuda/batch_matmul_tensorcore.py
##########
@@ -0,0 +1,275 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
+"""cuda batch_matmul operators"""
+import tvm
+from tvm import autotvm
+from tvm import te
+from ..utils import traverse_inline, get_const_tuple
+from .tensor_intrin import intrin_wmma_load_matrix_A, \
+        intrin_wmma_load_matrix_W, intrin_wmma_store_matrix, intrin_wmma_gemm
+
+@autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
+def batch_matmul_tensorcore(cfg, x, y, out_shape=None):
+    """batch matmul tensorcore operator on cuda"""
+    # todo: deal with out_shape for broadcast, liuxin.ai
+    return batch_matmul_tensorcore_cuda(x, y)
+
+
+@autotvm.register_topi_schedule("batch_matmul_tensorcore.cuda")
+def schedule_batch_matmul_tensorcore(cfg, outs):
+    """Schedule for batch_matmul operator using Tensorcore
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of batch_matmul
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _schedule(cfg, s, C):
+        A, B = s[C].op.input_tensors
+        batch, m_dim, k_dim = get_const_tuple(A.shape)
+        batch, n_dim, k_dim = get_const_tuple(B.shape)
+        out_dtype = C.dtype
+        # inline astype fp16
+        s[A].compute_inline()
+        s[B].compute_inline()
+
+        # Explicit memory access
+        AS = s.cache_read(A, 'shared', [C])
+        BS = s.cache_read(B, 'shared', [C])
+        AF = s.cache_read(AS, 'wmma.matrix_a', [C])
+        BF = s.cache_read(BS, 'wmma.matrix_b', [C])
+        CF = s.cache_write(C, 'wmma.accumulator')
+        CS = s.cache_read(CF, 'shared', [C])
+
+        # fallback support
+        target = tvm.target.Target.current()
+        if cfg.is_fallback:
+            ref_log = autotvm.tophub.load_reference_log(
+                target.kind.name, target.model, 'batch_matmul_tensorcore.cuda')
+            cfg.fallback_with_reference_log(ref_log)
+
+        # ??? Deal with op fusion, such as bias and relu ??? is this needed?

Review comment:
       fixed







[GitHub] [tvm] Laurawly commented on a change in pull request #7146: [CUDA]batch_matmul tensorcore schedule

Posted by GitBox <gi...@apache.org>.
Laurawly commented on a change in pull request #7146:
URL: https://github.com/apache/tvm/pull/7146#discussion_r550363118



##########
File path: python/tvm/relay/op/strategy/cuda.py
##########
@@ -657,6 +657,23 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
             name="batch_matmul_cublas.cuda",
             plevel=15,
         )
+    if target.kind.name == "cuda" and nvcc.have_tensorcore(tvm.gpu(0).compute_version):

Review comment:
       Maybe it's better to use `nvcc.have_tensorcore(target=target)` here since `tvm.gpu(0)` might not exist?
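       In diff form the suggestion would be roughly (assuming `have_tensorcore` accepts a `target` keyword, as indicated above):

           -    if target.kind.name == "cuda" and nvcc.have_tensorcore(tvm.gpu(0).compute_version):
           +    if target.kind.name == "cuda" and nvcc.have_tensorcore(target=target):

       That way the capability check is derived from the compilation target rather than from a physical device, so the strategy can still be selected on machines without a local GPU.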



