You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by la...@apache.org on 2021/01/29 19:00:33 UTC
[tvm] branch main updated: [CUDA][PASS]Legalize tensorcore (#7147)
This is an automated email from the ASF dual-hosted git repository.
laurawly pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3734d5f [CUDA][PASS]Legalize tensorcore (#7147)
3734d5f is described below
commit 3734d5f7f8475a2a7897f239b9942c913256fc96
Author: Meteorix <li...@bytedance.com>
AuthorDate: Sat Jan 30 03:00:12 2021 +0800
[CUDA][PASS]Legalize tensorcore (#7147)
* add pad_to_tensorcore & legalize for dense/bmm/conv2d
* fix pad & slice
* fix comments
* fix comments
* resolve conflict
* resolve conflict
* support only fp16
* add tests/python/relay/test_pass_legalize_tensorcore.py
* add tests for legalize tensorcore
* fix pylint
* fix pylint
* code format
* use_gpu test only; fix conv2d_alter_op
* fix tests params
* revert transform fix
---
python/tvm/relay/op/nn/_nn.py | 42 ++++
python/tvm/topi/cuda/__init__.py | 1 +
python/tvm/topi/cuda/conv2d_alter_op.py | 48 +++++
python/tvm/topi/cuda/tensorcore_alter_op.py | 204 ++++++++++++++++++
python/tvm/topi/nn/batch_matmul.py | 24 +++
python/tvm/topi/nn/dense.py | 24 +++
.../python/relay/test_pass_legalize_tensorcore.py | 239 +++++++++++++++++++++
7 files changed, 582 insertions(+)
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index c5af5d8..37ee6b6 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -52,6 +52,27 @@ reg.register_schedule("nn.log_softmax", strategy.schedule_log_softmax)
reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
@reg.register_legalize("nn.dense")
def legalize_dense(attrs, inputs, types):
    """Legalize dense op.

    Forwards to the target-dispatched ``topi.nn.dense_legalize``; targets
    without a registered implementation leave the op unchanged.

    Parameters
    ----------
    attrs : tvm.ir.Attrs
        Attributes of current dense
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr, or None to keep the original op
    """
    return topi.nn.dense_legalize(attrs, inputs, types)
+
+
# dense
reg.register_strategy("nn.dense", strategy.dense_strategy)
reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -67,6 +88,27 @@ reg.register_injective_schedule("nn.fifo_buffer")
reg.register_pattern("nn.fifo_buffer", OpPattern.OPAQUE)
@reg.register_legalize("nn.batch_matmul")
def legalize_batch_matmul(attrs, inputs, types):
    """Legalize batch_matmul op.

    Forwards to the target-dispatched ``topi.nn.batch_matmul_legalize``;
    targets without a registered implementation leave the op unchanged.

    Parameters
    ----------
    attrs : tvm.ir.Attrs
        Attributes of current batch_matmul
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr, or None to keep the original op
    """
    return topi.nn.batch_matmul_legalize(attrs, inputs, types)
+
+
# batch_matmul
reg.register_strategy("nn.batch_matmul", strategy.batch_matmul_strategy)
reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
diff --git a/python/tvm/topi/cuda/__init__.py b/python/tvm/topi/cuda/__init__.py
index e0ff5a1..bf3582c 100644
--- a/python/tvm/topi/cuda/__init__.py
+++ b/python/tvm/topi/cuda/__init__.py
@@ -55,5 +55,6 @@ from .dense_tensorcore import *
from .conv2d_hwnc_tensorcore import *
from .correlation import *
from .sparse import *
+from . import tensorcore_alter_op
from .argwhere import *
from .scan import *
diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py
index 8cf0519..65bf9d1 100644
--- a/python/tvm/topi/cuda/conv2d_alter_op.py
+++ b/python/tvm/topi/cuda/conv2d_alter_op.py
@@ -24,8 +24,10 @@ from tvm import te, relay, autotvm
from .. import nn
from ..utils import get_const_tuple
from .conv2d_winograd import _infer_tile_size
+from .tensorcore_alter_op import pad_to_tensorcore
from ..nn import conv2d_legalize
+
logger = logging.getLogger("topi")
@@ -345,4 +347,50 @@ def _conv2d_legalize(attrs, inputs, arg_types):
else:
out = relay.nn.conv2d(data, kernel, **new_attrs)
return out
+ elif data_dtype in ["float16"]: # todo: support int8/int4
+ if data_layout == "NHWC" and kernel_layout == "HWIO":
+ batch = data_tensor.shape[0].value
+ in_channel = data_tensor.shape[3].value
+ out_channel = kernel_tensor.shape[3].value
+
+ if (
+ (batch % 8 == 0 and in_channel % 16 == 0 and out_channel % 32 == 0)
+ or (batch % 16 == 0 and in_channel % 16 == 0 and out_channel % 16 == 0)
+ or (batch % 32 == 0 and in_channel % 16 == 0 and out_channel % 8 == 0)
+ ):
+ # no need to pad
+ return None
+
+ (db, di, do), extra_flops = pad_to_tensorcore(batch, in_channel, out_channel)
+
+ if extra_flops > 2:
+ logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops)
+ return None
+
+ logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops)
+
+ # Pad batch size
+ if db != 0:
+ data = relay.nn.pad(data, pad_width=((0, db), (0, 0), (0, 0), (0, 0)))
+
+ # Pad input channel
+ if di != 0:
+ data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, di)))
+ kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, di), (0, 0)))
+
+ # Pad output channel
+ if do != 0:
+ kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, do)))
+
+ if do != 0:
+ new_out_channel = out_channel + do
+ new_attrs["channels"] = new_out_channel
+
+ out = relay.nn.conv2d(data, kernel, **new_attrs)
+
+ if db != 0 or do != 0:
+ original_out_shape = [x.value for x in output_tensor.shape]
+ out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape)
+
+ return out
return None
diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py
new file mode 100644
index 0000000..aec7acb
--- /dev/null
+++ b/python/tvm/topi/cuda/tensorcore_alter_op.py
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Tensorcore alter op and legalize functions for cuda backend"""
+
+import logging
+import math
+from tvm import relay
+
+from .. import nn
+
+# Module logger; uses the same "topi" logger name as conv2d_alter_op.py so the
+# pad_to_tensorcore decisions below are reported alongside the conv2d ones.
+logger = logging.getLogger("topi")
+
+
@nn.batch_matmul_legalize.register("cuda")
def _batch_matmul_legalize(attrs, inputs, arg_types):
    """Legalizes batch_matmul op.

    For float16 inputs whose (M, K, N) extents are not already a multiple of a
    tensorcore-friendly shape, pads the operands as chosen by
    ``pad_to_tensorcore`` (only when the extra work is at most twice the
    original flops) and slices the result back to the original output shape.

    Parameters
    ----------
    attrs : tvm.ir.Attrs
        Attributes of current batch_matmul
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    arg_types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr or None
        The legalized expr, or None to keep the original op (non-float16
        input, symbolic shapes, already aligned, or padding too costly)
    """
    # Collect the input tensors.
    x_tensor, y_tensor = arg_types[0], arg_types[1]
    dtype = x_tensor.dtype

    # Collect the output tensor.
    output_tensor = arg_types[2]

    # Collect the input exprs.
    x, y = inputs

    # Only float16 is padded for now. todo: support int8/int4
    if dtype not in ["float16"]:
        return None

    # x is (B, M, K) and y is (B, N, K); K is read from y, as before.
    _, M, _ = x_tensor.shape
    _, N, K = y_tensor.shape
    try:
        M = M.value
        K = K.value
        N = N.value
    except AttributeError:
        # Symbolic (dynamic) extents carry no static value, so the padding
        # amounts cannot be computed; leave the op unchanged (mirrors dense).
        return None

    # The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)
    if (
        (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
        or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
        or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
    ):
        # no need to pad
        return None

    (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N)

    if extra_flops_ratio > 2:
        logger.info("batch_matmul pad_to_tensorcore skipped, extra_flops %s", extra_flops_ratio)
        return None

    original_out_shape = None
    if dm or dn:
        # Slicing back needs a fully static output shape; the batch extent
        # is not checked above and may still be symbolic.
        try:
            original_out_shape = [dim.value for dim in output_tensor.shape]
        except AttributeError:
            return None

    logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops_ratio)
    x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk))) if dm or dk else x
    y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk))) if dn or dk else y
    out_ = relay.nn.batch_matmul(x_, y_)
    if original_out_shape is None:
        return out_
    return relay.strided_slice(out_, begin=[0, 0, 0], end=original_out_shape)
+
+
@nn.dense_legalize.register("cuda")
def _dense_legalize(attrs, inputs, arg_types):
    """Legalizes dense op.

    For 2-D float16 inputs whose (M, K, N) extents are not already a multiple
    of a tensorcore-friendly shape, pads the operands as chosen by
    ``pad_to_tensorcore`` (only when the extra work is at most twice the
    original flops) and slices the result back to the original output shape.
    The original op's ``out_dtype`` is kept on the rebuilt dense.

    Parameters
    ----------
    attrs : tvm.ir.Attrs
        Attributes of current dense
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    arg_types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr or None
        The legalized expr, or None to keep the original op
    """
    # Collect the input tensors.
    x_tensor, y_tensor = arg_types[0], arg_types[1]
    dtype = x_tensor.dtype

    # Collect the output tensor.
    output_tensor = arg_types[2]

    # Collect the input exprs.
    x, y = inputs

    # Only float16 is padded for now. todo: support int8/int4
    if dtype not in ["float16"]:
        return None

    # NOTE(review): nn.dense is documented to accept data of rank > 2
    # (..., units_in); only the plain 2-D case is handled here, since the
    # unpacking and 2-D pad/slice below assume it.
    if len(x_tensor.shape) != 2 or len(y_tensor.shape) != 2:
        return None

    M, K = x_tensor.shape
    N, K = y_tensor.shape
    try:
        M = M.value
        K = K.value
        N = N.value
    except AttributeError:
        # todo: deal with unfixed shape when compiling wdl model
        return None

    # The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)
    if (
        (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
        or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
        or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
    ):
        # no need to pad
        return None

    (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N)

    if extra_flops_ratio > 2:
        logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio)
        return None

    logger.info("dense pad_to_tensorcore, extra_flops_ratio %s", extra_flops_ratio)

    x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) if dm or dk else x
    y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk))) if dn or dk else y

    # Rebuild with the original attributes: dropping out_dtype would change the
    # result dtype (e.g. fp16 inputs accumulated into fp32), and an explicit
    # `units` must match the padded number of weight rows.
    units = attrs.units
    if units is not None and dn:
        units = N + dn
    out_ = relay.nn.dense(x_, y_, units=units, out_dtype=attrs.out_dtype)

    if dm or dn:
        original_out_shape = [dim.value for dim in output_tensor.shape]
        return relay.strided_slice(out_, begin=[0, 0], end=original_out_shape)
    return out_
+
+
+def pad_to_tensorcore(M, K, N):
+ """pad shape to enable tensorcore"""
+ candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]
+
+ flops = M * K * N
+ extra_flops = math.inf
+ best_pad = (0, 0, 0)
+ for padding in candidates:
+ dm, dk, dn = _pad_to(M, K, N, padding)
+ e = (M + dm) * (N + dn) * (K + dk) - M * N * K
+ # print(dm, dk, dn, e, flops)
+ if e < extra_flops:
+ extra_flops = e
+ best_pad = (dm, dk, dn)
+ return best_pad, extra_flops / flops
+
+
+def _pad_to(M, K, N, PADDING):
+ dm, dk, dn = 0, 0, 0
+
+ if M % PADDING[0] != 0:
+ M_ = ((M + PADDING[0]) // PADDING[0]) * PADDING[0]
+ dm = M_ - M
+ if K % PADDING[1] != 0:
+ K_ = ((K + PADDING[1]) // PADDING[1]) * PADDING[1]
+ dk = K_ - K
+ if N % PADDING[2] != 0:
+ N_ = ((N + PADDING[2]) // PADDING[2]) * PADDING[2]
+ dn = N_ - N
+
+ return dm, dk, dn
diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py
index 9ca2df7..9c58481 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -16,6 +16,7 @@
# under the License.
"""Batch matrix multiplication"""
# pylint: disable=invalid-name
+import tvm
from tvm import te, auto_scheduler
from ..utils import get_const_tuple
@@ -77,3 +78,26 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
output = auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout)
return output
+
+
@tvm.target.generic_func
def batch_matmul_legalize(attrs, inputs, types):
    """Legalizes batch_matmul op.

    This is the target-independent fallback of a generic function and never
    rewrites the op. Backends register their own version per target.

    Parameters
    ----------
    attrs : tvm.ir.Attrs
        Attributes of current batch_matmul
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        Always None here, i.e. "leave the op as it is"
    """
    del attrs, inputs, types  # the fallback deliberately ignores its inputs
    return None
diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py
index 474fea4..bb6ea90 100644
--- a/python/tvm/topi/nn/dense.py
+++ b/python/tvm/topi/nn/dense.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""TVM operator fully connected compute."""
+import tvm
from tvm import te, auto_scheduler
from .. import tag
@@ -80,3 +81,26 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo
matmul = auto_scheduler.rewrite_compute_body(matmul, auto_scheduler_rewritten_layout)
return matmul
+
+
@tvm.target.generic_func
def dense_legalize(attrs, inputs, types):
    """Legalizes dense op.

    This is the target-independent fallback of a generic function and never
    rewrites the op. Backends register their own version per target.

    Parameters
    ----------
    attrs : tvm.ir.Attrs
        Attributes of current dense
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        Always None here, i.e. "leave the op as it is"
    """
    del attrs, inputs, types  # the fallback deliberately ignores its inputs
    return None
diff --git a/tests/python/relay/test_pass_legalize_tensorcore.py b/tests/python/relay/test_pass_legalize_tensorcore.py
new file mode 100644
index 0000000..5ecda4b
--- /dev/null
+++ b/tests/python/relay/test_pass_legalize_tensorcore.py
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test legalize pass"""
+import numpy as np
+import tvm
+from tvm import te
+from tvm import topi
+from tvm import relay
+from tvm.contrib import graph_runtime
+from tvm.relay import transform, analysis
+from tvm.relay.testing.temp_op_attr import TempOpAttr
+
+
def run_opt_pass(expr, passes):
    """Apply one pass or a list of passes to ``expr`` at opt_level=3.

    Returns the optimized ``main`` function when ``expr`` is a relay.Function,
    otherwise just that function's body.
    """
    pass_list = passes if isinstance(passes, list) else [passes]
    module = tvm.IRModule.from_expr(expr)
    pipeline = tvm.transform.Sequential(pass_list)
    with tvm.transform.PassContext(opt_level=3):
        module = pipeline(module)
    main_fn = module["main"]
    if isinstance(expr, relay.Function):
        return main_fn
    return main_fn.body
+
+
@tvm.testing.uses_gpu
def test_legalize_conv2d():
    """test legalize conv2d to enable tensorcore"""

    def _test_legalize_conv2d(data_shape, kernel_shape, pad_shape, do_pad=True):
        # Expected output shape: the input shape with its last (channel) axis
        # replaced by the kernel's output-channel count.
        out_channel = kernel_shape[3]
        out_shape = list(data_shape)
        out_shape[3] = out_channel
        db, di, do = pad_shape

        def _conv(data, weight, channels):
            return relay.nn.conv2d(
                data,
                weight,
                channels=channels,
                kernel_size=(3, 3),
                padding=(1, 1),
                data_layout="NHWC",
                kernel_layout="HWIO",
            )

        def _params():
            return (
                relay.var("x", shape=data_shape, dtype="float16"),
                relay.var("weight", shape=kernel_shape, dtype="float16"),
            )

        def before():
            x, weight = _params()
            return relay.Function([x, weight], _conv(x, weight, out_channel))

        def legalize_conv2d(attrs, inputs, types):
            with tvm.target.Target("cuda"):
                return topi.nn.conv2d_legalize(attrs, inputs, types)

        def expected():
            if not do_pad:
                return before()
            x, weight = _params()
            data = x
            if db or di:
                data = relay.nn.pad(x, pad_width=((0, db), (0, 0), (0, 0), (0, di)))
            kernel = weight
            if di or do:
                kernel = relay.nn.pad(weight, pad_width=((0, 0), (0, 0), (0, di), (0, do)))
            body = _conv(data, kernel, out_channel + do)
            if db or do:
                body = relay.strided_slice(body, begin=[0, 0, 0, 0], end=out_shape)
            return relay.Function([x, weight], body)

        with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
            actual = run_opt_pass(before(), transform.Legalize())
            reference = run_opt_pass(expected(), transform.InferType())
        assert tvm.ir.structural_equal(actual, reference), (
            "Actual = \n" + str(actual) + "Expected = \n" + str(reference)
        )

    cases = [
        # conv2d pad batch
        ((7, 16, 16, 64), (3, 3, 64, 64), (1, 0, 0)),
        ((3, 16, 16, 64), (3, 3, 64, 64), (5, 0, 0)),
        ((2, 16, 16, 64), (3, 3, 64, 64), (0, 0, 0), False),
        # conv2d pad in_channel
        ((8, 16, 16, 63), (3, 3, 63, 64), (0, 1, 0)),
        ((8, 16, 16, 33), (3, 3, 33, 64), (0, 15, 0)),
        ((8, 16, 16, 13), (3, 3, 13, 64), (0, 3, 0)),
        ((8, 16, 16, 1), (3, 3, 1, 64), (0, 0, 0), False),
        # conv2d pad out_channel
        ((8, 16, 16, 64), (3, 3, 64, 63), (0, 0, 1)),
        ((8, 16, 16, 64), (3, 3, 64, 33), (0, 0, 31)),
        ((8, 16, 16, 64), (3, 3, 64, 1), (0, 0, 0), False),
    ]
    for case in cases:
        _test_legalize_conv2d(*case)
+
+
@tvm.testing.uses_gpu
def test_legalize_dense():
    """test legalize dense to enable tensorcore"""

    def _test_legalize_dense(data_shape, kernel_shape, pad_shape, do_pad=True):
        # data is (M, K) and weight is (N, K); the dense result is (M, N).
        M, _ = data_shape
        N, _ = kernel_shape
        out_shape = (M, N)
        dm, dk, dn = pad_shape

        def _params():
            return (
                relay.var("x", shape=data_shape, dtype="float16"),
                relay.var("weight", shape=kernel_shape, dtype="float16"),
            )

        def before():
            x, weight = _params()
            return relay.Function([x, weight], relay.nn.dense(x, weight))

        def legalize_dense(attrs, inputs, types):
            with tvm.target.Target("cuda"):
                return topi.nn.dense_legalize(attrs, inputs, types)

        def expected():
            if not do_pad:
                return before()
            x, weight = _params()
            lhs = relay.nn.pad(x, pad_width=((0, dm), (0, dk))) if (dm or dk) else x
            rhs = relay.nn.pad(weight, pad_width=((0, dn), (0, dk))) if (dn or dk) else weight
            body = relay.nn.dense(lhs, rhs)
            if dm or dn:
                body = relay.strided_slice(body, begin=[0, 0], end=out_shape)
            return relay.Function([x, weight], body)

        with TempOpAttr("nn.dense", "FTVMLegalize", legalize_dense):
            actual = run_opt_pass(before(), transform.Legalize())
            reference = run_opt_pass(expected(), transform.InferType())
        assert tvm.ir.structural_equal(actual, reference), (
            "Actual = \n" + str(actual) + "Expected = \n" + str(reference)
        )

    cases = [
        ((8, 16), (32, 16), (0, 0, 0), False),
        ((7, 16), (32, 16), (1, 0, 0)),
        ((8, 15), (32, 15), (0, 1, 0)),
        ((8, 16), (31, 16), (0, 0, 1)),
        ((7, 15), (31, 15), (1, 1, 1)),
        ((3, 16), (32, 16), (5, 0, 0)),
        ((2, 16), (32, 16), (0, 0, 0), False),
    ]
    for case in cases:
        _test_legalize_dense(*case)
+
+
@tvm.testing.uses_gpu
def test_legalize_batch_matmul():
    """test legalize batch_matmul to enable tensorcore"""

    def _test_legalize_batch_matmul(data_shape, kernel_shape, pad_shape, do_pad=True):
        # x is (B, M, K) and weight is (B, N, K); the result is (B, M, N).
        B, M, _ = data_shape
        _, N, _ = kernel_shape
        out_shape = (B, M, N)
        dm, dk, dn = pad_shape

        def _params():
            return (
                relay.var("x", shape=data_shape, dtype="float16"),
                relay.var("weight", shape=kernel_shape, dtype="float16"),
            )

        def before():
            x, weight = _params()
            return relay.Function([x, weight], relay.nn.batch_matmul(x, weight))

        def legalize_batch_matmul(attrs, inputs, types):
            with tvm.target.Target("cuda"):
                return topi.nn.batch_matmul_legalize(attrs, inputs, types)

        def expected():
            if not do_pad:
                return before()
            x, weight = _params()
            lhs = x
            if dm or dk:
                lhs = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
            rhs = weight
            if dn or dk:
                rhs = relay.nn.pad(weight, pad_width=((0, 0), (0, dn), (0, dk)))
            body = relay.nn.batch_matmul(lhs, rhs)
            if dm or dn:
                body = relay.strided_slice(body, begin=[0, 0, 0], end=out_shape)
            return relay.Function([x, weight], body)

        with TempOpAttr("nn.batch_matmul", "FTVMLegalize", legalize_batch_matmul):
            actual = run_opt_pass(before(), transform.Legalize())
            reference = run_opt_pass(expected(), transform.InferType())
        assert tvm.ir.structural_equal(actual, reference), (
            "Actual = \n" + str(actual) + "Expected = \n" + str(reference)
        )

    cases = [
        ((16, 8, 16), (16, 32, 16), (0, 0, 0), False),
        ((16, 7, 16), (16, 32, 16), (1, 0, 0)),
        ((16, 8, 15), (16, 32, 15), (0, 1, 0)),
        ((16, 8, 16), (16, 31, 16), (0, 0, 1)),
        ((16, 7, 15), (16, 31, 15), (1, 1, 1)),
        ((16, 3, 16), (16, 32, 16), (5, 0, 0)),
        ((16, 2, 16), (16, 32, 16), (0, 0, 0), False),
    ]
    for case in cases:
        _test_legalize_batch_matmul(*case)
+
+
if __name__ == "__main__":
    # Run every legalize test in order when executed as a script.
    for _run in (test_legalize_conv2d, test_legalize_dense, test_legalize_batch_matmul):
        _run()