You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by la...@apache.org on 2021/01/29 19:00:33 UTC

[tvm] branch main updated: [CUDA][PASS]Legalize tensorcore (#7147)

This is an automated email from the ASF dual-hosted git repository.

laurawly pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3734d5f  [CUDA][PASS]Legalize tensorcore (#7147)
3734d5f is described below

commit 3734d5f7f8475a2a7897f239b9942c913256fc96
Author: Meteorix <li...@bytedance.com>
AuthorDate: Sat Jan 30 03:00:12 2021 +0800

    [CUDA][PASS]Legalize tensorcore (#7147)
    
    * add pad_to_tensorcore & legalize for dense/bmm/conv2d
    
    * fix pad & slice
    
    * fix comments
    
    * fix comments
    
    * resolve conflict
    
    * resolve conflict
    
    * support only fp16
    
    * add tests/python/relay/test_pass_legalize_tensorcore.py
    
    * add tests for legalize tensorcore
    
    * fix pylint
    
    * fix pylint
    
    * code format
    
    * use_gpu test only; fix conv2d_alter_op
    
    * fix tests params
    
    * revert transform fix
---
 python/tvm/relay/op/nn/_nn.py                      |  42 ++++
 python/tvm/topi/cuda/__init__.py                   |   1 +
 python/tvm/topi/cuda/conv2d_alter_op.py            |  48 +++++
 python/tvm/topi/cuda/tensorcore_alter_op.py        | 204 ++++++++++++++++++
 python/tvm/topi/nn/batch_matmul.py                 |  24 +++
 python/tvm/topi/nn/dense.py                        |  24 +++
 .../python/relay/test_pass_legalize_tensorcore.py  | 239 +++++++++++++++++++++
 7 files changed, 582 insertions(+)

diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index c5af5d8..37ee6b6 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -52,6 +52,27 @@ reg.register_schedule("nn.log_softmax", strategy.schedule_log_softmax)
 reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
 
 
+@reg.register_legalize("nn.dense")
+def legalize_dense(attrs, inputs, types):
+    """Legalize dense op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current dense
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr, or None when the target keeps the op unchanged
+    """
+    return topi.nn.dense_legalize(attrs, inputs, types)
+
+
 # dense
 reg.register_strategy("nn.dense", strategy.dense_strategy)
 reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -67,6 +88,27 @@ reg.register_injective_schedule("nn.fifo_buffer")
 reg.register_pattern("nn.fifo_buffer", OpPattern.OPAQUE)
 
 
+@reg.register_legalize("nn.batch_matmul")
+def legalize_batch_matmul(attrs, inputs, types):
+    """Legalize batch_matmul op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current batch_matmul
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr, or None when the target keeps the op unchanged
+    """
+    return topi.nn.batch_matmul_legalize(attrs, inputs, types)
+
+
 # batch_matmul
 reg.register_strategy("nn.batch_matmul", strategy.batch_matmul_strategy)
 reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index e0ff5a1..bf3582c 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -55,5 +55,6 @@ from .dense_tensorcore import *
 from .conv2d_hwnc_tensorcore import *
 from .correlation import *
 from .sparse import *
+from . import tensorcore_alter_op
 from .argwhere import *
 from .scan import *
diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py
index 8cf0519..65bf9d1 100644
--- a/python/tvm/topi/cuda/conv2d_alter_op.py
+++ b/python/tvm/topi/cuda/conv2d_alter_op.py
@@ -24,8 +24,10 @@ from tvm import te, relay, autotvm
 from .. import nn
 from ..utils import get_const_tuple
 from .conv2d_winograd import _infer_tile_size
+from .tensorcore_alter_op import pad_to_tensorcore
 from ..nn import conv2d_legalize
 
+
 logger = logging.getLogger("topi")
 
 
@@ -345,4 +347,50 @@ def _conv2d_legalize(attrs, inputs, arg_types):
             else:
                 out = relay.nn.conv2d(data, kernel, **new_attrs)
             return out
+    elif data_dtype in ["float16"]:  # todo: support int8/int4
+        if data_layout == "NHWC" and kernel_layout == "HWIO":
+            batch = data_tensor.shape[0].value
+            in_channel = data_tensor.shape[3].value
+            out_channel = kernel_tensor.shape[3].value
+
+            if (
+                (batch % 8 == 0 and in_channel % 16 == 0 and out_channel % 32 == 0)
+                or (batch % 16 == 0 and in_channel % 16 == 0 and out_channel % 16 == 0)
+                or (batch % 32 == 0 and in_channel % 16 == 0 and out_channel % 8 == 0)
+            ):
+                # no need to pad
+                return None
+
+            (db, di, do), extra_flops = pad_to_tensorcore(batch, in_channel, out_channel)
+
+            if extra_flops > 2:
+                logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops)
+                return None
+
+            logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops)
+
+            # Pad batch size
+            if db != 0:
+                data = relay.nn.pad(data, pad_width=((0, db), (0, 0), (0, 0), (0, 0)))
+
+            # Pad input channel
+            if di != 0:
+                data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, di)))
+                kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, di), (0, 0)))
+
+            # Pad output channel
+            if do != 0:
+                kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, do)))
+
+            if do != 0:
+                new_out_channel = out_channel + do
+                new_attrs["channels"] = new_out_channel
+
+            out = relay.nn.conv2d(data, kernel, **new_attrs)
+
+            if db != 0 or do != 0:
+                original_out_shape = [x.value for x in output_tensor.shape]
+                out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape)
+
+            return out
     return None
diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py
new file mode 100644
index 0000000..aec7acb
--- /dev/null
+++ b/python/tvm/topi/cuda/tensorcore_alter_op.py
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Tensorcore alter op and legalize functions for cuda backend"""
+
+import logging
+import math
+from tvm import relay
+
+from .. import nn
+
+logger = logging.getLogger("topi")
+
+
+@nn.batch_matmul_legalize.register("cuda")
+def _batch_matmul_legalize(attrs, inputs, arg_types):
+    """Legalizes batch_matmul op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current batch_matmul
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    arg_types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr, or None to keep the original op
+    """
+    # Collect the input tensors.
+    x_tensor, y_tensor = arg_types[0], arg_types[1]
+    dtype = x_tensor.dtype
+
+    # Collect the output tensor.
+    output_tensor = arg_types[2]
+
+    # Collect the input exprs.
+    x, y = inputs
+
+    # Pad input and output channels to use tensorcore schedule.
+    if dtype in ["float16"]:  # todo: support int8/int4
+        B, M, K = x_tensor.shape
+        B, N, K = y_tensor.shape
+        if not all(hasattr(dim, "value") for dim in (*x_tensor.shape, *y_tensor.shape)):
+            return None  # todo: support dims without a static value (dynamic shapes)
+        M, K, N = M.value, K.value, N.value
+
+        # The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)
+        if (
+            (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
+            or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
+            or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
+        ):
+            # no need to pad
+            return None
+
+        (dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N)
+
+        if extra_flops > 2:
+            logger.info("batch_matmul pad_to_tensorcore skipped, extra_flops %s", extra_flops)
+            return None
+
+        logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops)
+        if dm or dk:
+            x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
+        else:
+            x_ = x
+        if dn or dk:
+            y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk)))
+        else:
+            y_ = y
+        out_ = relay.nn.batch_matmul(x_, y_)
+        if dm or dn:
+            original_out_shape = [dim.value for dim in output_tensor.shape]
+            out = relay.strided_slice(out_, begin=[0, 0, 0], end=original_out_shape)
+        else:
+            out = out_
+        return out
+    return None
+
+
+@nn.dense_legalize.register("cuda")
+def _dense_legalize(attrs, inputs, arg_types):
+    """Legalizes dense op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current dense
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    arg_types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr, or None to keep the original op
+    """
+    # Collect the input tensors.
+    x_tensor, y_tensor = arg_types[0], arg_types[1]
+    dtype = x_tensor.dtype
+
+    # Collect the output tensor.
+    output_tensor = arg_types[2]
+
+    # Collect the input exprs.
+    x, y = inputs
+
+    # Pad input and output channels to use tensorcore schedule.
+    if dtype in ["float16"]:  # todo: support int8/int4
+        M, K = x_tensor.shape
+        N, K = y_tensor.shape
+        try:
+            M = M.value
+            K = K.value
+            N = N.value
+        except AttributeError:
+            # todo: deal with unfixed shape when compiling wdl model
+            return None
+
+        # The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)
+        if (
+            (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
+            or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
+            or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
+        ):
+            # no need to pad
+            return None
+
+        (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N)
+
+        if extra_flops_ratio > 2:
+            logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio)
+            return None
+
+        logger.info("dense pad_to_tensorcore, extra_flops_ratio %s", extra_flops_ratio)
+
+        if dm or dk:
+            x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk)))
+        else:
+            x_ = x
+        if dn or dk:
+            y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk)))
+        else:
+            y_ = y
+        out_ = relay.nn.dense(x_, y_, out_dtype=attrs.out_dtype)  # keep original out_dtype
+        if dm or dn:
+            original_out_shape = [dim.value for dim in output_tensor.shape]
+            out = relay.strided_slice(out_, begin=[0, 0], end=original_out_shape)
+        else:
+            out = out_
+        return out
+    return None
+
+
+def pad_to_tensorcore(M, K, N):
+    """Return ((dm, dk, dn), extra_flops_ratio) for the cheapest tensorcore padding."""
+    candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]
+
+    flops = M * K * N
+    extra_flops = math.inf
+    best_pad = (0, 0, 0)
+    for padding in candidates:
+        dm, dk, dn = _pad_to(M, K, N, padding)
+        e = (M + dm) * (N + dn) * (K + dk) - M * N * K
+        if e < extra_flops:
+            extra_flops = e
+            best_pad = (dm, dk, dn)
+    # Zero-sized shapes have no flops to compare against: report inf so callers skip padding
+    return best_pad, (extra_flops / flops if flops else math.inf)
+
+
+def _pad_to(M, K, N, PADDING):
+    """Compute how much to grow (M, K, N) so each is a multiple of its PADDING entry.
+
+    PADDING is a (pm, pk, pn) tile shape. Returns (dm, dk, dn); an entry is 0
+    when the matching dimension is already a multiple of its tile size.
+    """
+    pm, pk, pn = PADDING[0], PADDING[1], PADDING[2]
+
+    # With Python's floor-mod, -x % p is 0 when p divides x and otherwise the
+    # distance from x up to the next multiple of p.
+    dm = -M % pm
+    dk = -K % pk
+    dn = -N % pn
+    return dm, dk, dn
diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py
index 9ca2df7..9c58481 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -16,6 +16,7 @@
 # under the License.
 """Batch matrix multiplication"""
 # pylint: disable=invalid-name
+import tvm
 from tvm import te, auto_scheduler
 from ..utils import get_const_tuple
 
@@ -77,3 +78,26 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
         output = auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout)
 
     return output
+
+
+@tvm.target.generic_func
+def batch_matmul_legalize(attrs, inputs, types):
+    """Legalizes batch_matmul op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current batch_matmul
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr, or None to keep the op unchanged
+    """
+    # Generic fallback: no rewrite; targets override this via .register(target)
+    # pylint: disable=unused-argument
+    return None
diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py
index 474fea4..bb6ea90 100644
--- a/python/tvm/topi/nn/dense.py
+++ b/python/tvm/topi/nn/dense.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """TVM operator fully connected compute."""
+import tvm
 from tvm import te, auto_scheduler
 from .. import tag
 
@@ -80,3 +81,26 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo
         matmul = auto_scheduler.rewrite_compute_body(matmul, auto_scheduler_rewritten_layout)
 
     return matmul
+
+
+@tvm.target.generic_func
+def dense_legalize(attrs, inputs, types):
+    """Legalizes dense op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current dense
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr, or None to keep the op unchanged
+    """
+    # Generic fallback: no rewrite; targets override this via .register(target)
+    # pylint: disable=unused-argument
+    return None
diff --git a/tests/python/relay/test_pass_legalize_tensorcore.py b/tests/python/relay/test_pass_legalize_tensorcore.py
new file mode 100644
index 0000000..5ecda4b
--- /dev/null
+++ b/tests/python/relay/test_pass_legalize_tensorcore.py
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test legalize pass"""
+import numpy as np
+import tvm
+from tvm import te
+from tvm import topi
+from tvm import relay
+from tvm.contrib import graph_runtime
+from tvm.relay import transform, analysis
+from tvm.relay.testing.temp_op_attr import TempOpAttr
+
+
+def run_opt_pass(expr, passes):
+    """Apply `passes` at opt_level=3; return a Function if given one, else its body."""
+    pipeline = tvm.transform.Sequential(passes if isinstance(passes, list) else [passes])
+    module = tvm.IRModule.from_expr(expr)
+    with tvm.transform.PassContext(opt_level=3):
+        module = pipeline(module)
+    main_fn = module["main"]
+    return main_fn if isinstance(expr, relay.Function) else main_fn.body
+
+
+@tvm.testing.uses_gpu
+def test_legalize_conv2d():
+    """Check that fp16 NHWC/HWIO conv2d is padded for tensorcore, or left as-is."""
+
+    def _test_legalize_conv2d(data_shape, kernel_shape, pad_shape, do_pad=True):
+        out_channel = kernel_shape[3]
+        out_shape = list(data_shape)
+        out_shape[3] = out_channel
+        db, di, do = pad_shape
+
+        def before():
+            x = relay.var("x", shape=data_shape, dtype="float16")
+            weight = relay.var("weight", shape=kernel_shape, dtype="float16")
+            y = relay.nn.conv2d(
+                x,
+                weight,
+                channels=out_channel,
+                kernel_size=(3, 3),
+                padding=(1, 1),
+                data_layout="NHWC",
+                kernel_layout="HWIO",
+            )
+            y = relay.Function([x, weight], y)
+            return y
+
+        def legalize_conv2d(attrs, inputs, types):
+            with tvm.target.Target("cuda"):
+                return topi.nn.conv2d_legalize(attrs, inputs, types)
+
+        def expected():
+            if not do_pad:
+                return before()
+            x = relay.var("x", shape=data_shape, dtype="float16")
+            if db or di:
+                x_pad = relay.nn.pad(x, pad_width=((0, db), (0, 0), (0, 0), (0, di)))
+            else:
+                x_pad = x
+            weight = relay.var("weight", shape=(kernel_shape), dtype="float16")
+            if di or do:
+                weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, 0), (0, di), (0, do)))
+            else:
+                weight_pad = weight
+            y_pad = relay.nn.conv2d(
+                x_pad,
+                weight=weight_pad,
+                channels=out_channel + do,
+                kernel_size=(3, 3),
+                padding=(1, 1),
+                data_layout="NHWC",
+                kernel_layout="HWIO",
+            )
+            if db or do:
+                y = relay.strided_slice(y_pad, begin=[0, 0, 0, 0], end=out_shape)
+            else:
+                y = y_pad
+            y = relay.Function([x, weight], y)
+            return y
+
+        with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
+            a = before()
+            a = run_opt_pass(a, transform.Legalize())
+            b = run_opt_pass(expected(), transform.InferType())
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b)
+
+    # pad batch (expected() fuses data pads into one nn.pad, so each case pads one dim)
+    _test_legalize_conv2d((7, 16, 16, 64), (3, 3, 64, 64), (1, 0, 0))
+    _test_legalize_conv2d((3, 16, 16, 64), (3, 3, 64, 64), (5, 0, 0))
+    _test_legalize_conv2d((2, 16, 16, 64), (3, 3, 64, 64), (0, 0, 0), False)
+    # pad in_channel; in_channel=1 would cost >2x extra flops, so it stays unpadded
+    _test_legalize_conv2d((8, 16, 16, 63), (3, 3, 63, 64), (0, 1, 0))
+    _test_legalize_conv2d((8, 16, 16, 33), (3, 3, 33, 64), (0, 15, 0))
+    _test_legalize_conv2d((8, 16, 16, 13), (3, 3, 13, 64), (0, 3, 0))
+    _test_legalize_conv2d((8, 16, 16, 1), (3, 3, 1, 64), (0, 0, 0), False)
+    # pad out_channel; out_channel=1 would cost >2x extra flops, so it stays unpadded
+    _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 63), (0, 0, 1))
+    _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 33), (0, 0, 31))
+    _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 1), (0, 0, 0), False)
+
+
+@tvm.testing.uses_gpu
+def test_legalize_dense():
+    def _test_legalize_dense(data_shape, kernel_shape, pad_shape, do_pad=True):
+        """Compare legalized dense(data_shape, kernel_shape) with a hand-padded graph."""
+        M, K = data_shape
+        N, _ = kernel_shape
+        out_shape = (M, N)
+        dm, dk, dn = pad_shape
+
+        def before():
+            x = relay.var("x", shape=data_shape, dtype="float16")
+            weight = relay.var("weight", shape=kernel_shape, dtype="float16")
+            y = relay.nn.dense(x, weight)
+            y = relay.Function([x, weight], y)
+            return y
+
+        def legalize_dense(attrs, inputs, types):
+            with tvm.target.Target("cuda"):
+                return topi.nn.dense_legalize(attrs, inputs, types)
+
+        def expected():
+            if not do_pad:
+                return before()
+            x = relay.var("x", shape=data_shape, dtype="float16")
+            if dm or dk:
+                x_pad = relay.nn.pad(x, pad_width=((0, dm), (0, dk)))
+            else:
+                x_pad = x
+            weight = relay.var("weight", shape=(kernel_shape), dtype="float16")
+            if dn or dk:
+                weight_pad = relay.nn.pad(weight, pad_width=((0, dn), (0, dk)))
+            else:
+                weight_pad = weight
+            y_pad = relay.nn.dense(
+                x_pad,
+                weight_pad,
+            )
+            if dm or dn:
+                y = relay.strided_slice(y_pad, begin=[0, 0], end=out_shape)
+            else:
+                y = y_pad
+            y = relay.Function([x, weight], y)
+            return y
+
+        with TempOpAttr("nn.dense", "FTVMLegalize", legalize_dense):
+            a = before()
+            a = run_opt_pass(a, transform.Legalize())
+            b = run_opt_pass(expected(), transform.InferType())
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b)
+
+    # do_pad=False: (8, 16) is already aligned; (2, 16) would cost >2x extra flops
+    _test_legalize_dense((8, 16), (32, 16), (0, 0, 0), False)
+    _test_legalize_dense((7, 16), (32, 16), (1, 0, 0))
+    _test_legalize_dense((8, 15), (32, 15), (0, 1, 0))
+    _test_legalize_dense((8, 16), (31, 16), (0, 0, 1))
+    _test_legalize_dense((7, 15), (31, 15), (1, 1, 1))
+    _test_legalize_dense((3, 16), (32, 16), (5, 0, 0))
+    _test_legalize_dense((2, 16), (32, 16), (0, 0, 0), False)
+
+
+@tvm.testing.uses_gpu
+def test_legalize_batch_matmul():
+    def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, do_pad=True):
+        """test legalize batch_matmul to enable tensorcore"""
+        B, M, _ = data_shape
+        _, N, _ = kernel_shape
+        out_shape = (B, M, N)
+        dm, dk, dn = pad_shape
+
+        def before():
+            x = relay.var("x", shape=data_shape, dtype="float16")
+            weight = relay.var("weight", shape=kernel_shape, dtype="float16")
+            y = relay.nn.batch_matmul(x, weight)
+            y = relay.Function([x, weight], y)
+            return y
+
+        def legalize_batch_matmul(attrs, inputs, types):
+            with tvm.target.Target("cuda"):
+                return topi.nn.batch_matmul_legalize(attrs, inputs, types)
+
+        def expected():
+            if not do_pad:
+                return before()
+            x = relay.var("x", shape=data_shape, dtype="float16")
+            if dm or dk:
+                x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
+            else:
+                x_pad = x
+            weight = relay.var("weight", shape=(kernel_shape), dtype="float16")
+            if dn or dk:
+                weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk)))
+            else:
+                weight_pad = weight
+            y_pad = relay.nn.batch_matmul(
+                x_pad,
+                weight_pad,
+            )
+            if dm or dn:
+                y = relay.strided_slice(y_pad, begin=[0, 0, 0], end=out_shape)
+            else:
+                y = y_pad
+            y = relay.Function([x, weight], y)
+            return y
+
+        with TempOpAttr("nn.batch_matmul", "FTVMLegalize", legalize_batch_matmul):
+            a = before()
+            a = run_opt_pass(a, transform.Legalize())
+            b = run_opt_pass(expected(), transform.InferType())
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b)
+
+    _test_legalize_batch_matmul((16, 8, 16), (16, 32, 16), (0, 0, 0), False)  # already aligned
+    _test_legalize_batch_matmul((16, 7, 16), (16, 32, 16), (1, 0, 0))
+    _test_legalize_batch_matmul((16, 8, 15), (16, 32, 15), (0, 1, 0))
+    _test_legalize_batch_matmul((16, 8, 16), (16, 31, 16), (0, 0, 1))
+    _test_legalize_batch_matmul((16, 7, 15), (16, 31, 15), (1, 1, 1))
+    _test_legalize_batch_matmul((16, 3, 16), (16, 32, 16), (5, 0, 0))
+    _test_legalize_batch_matmul((16, 2, 16), (16, 32, 16), (0, 0, 0), False)  # >2x extra flops
+
+
+if __name__ == "__main__":
+    # Running the file directly executes every test in order.
+    for _test in (test_legalize_conv2d, test_legalize_dense, test_legalize_batch_matmul):
+        _test()