You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/06/18 06:58:45 UTC

[tvm] branch main updated: [CUDA][PASS] conv2d NWHC/HWNC legalize tensorcore (#8222)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5f94c1e  [CUDA][PASS] conv2d NWHC/HWNC legalize tensorcore (#8222)
5f94c1e is described below

commit 5f94c1e0fa7ccd3e6947074a594f136a43da9cce
Author: Wang Yucheng <wy...@163.com>
AuthorDate: Fri Jun 18 14:58:31 2021 +0800

    [CUDA][PASS] conv2d NWHC/HWNC legalize tensorcore (#8222)
    
    * add conv2d leg
    
    * minor fix
    
    * fix pylint
    
    * fix pylint
    
    Co-authored-by: wangyucheng <wa...@sensetime.com>
---
 python/tvm/topi/cuda/conv2d_alter_op.py            | 169 ++++++++++++++++++---
 python/tvm/topi/cuda/tensorcore_alter_op.py        |  10 +-
 .../python/relay/test_pass_legalize_tensorcore.py  | 127 +++++++++++++---
 3 files changed, 263 insertions(+), 43 deletions(-)

diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py
index 067f272..4863a06 100644
--- a/python/tvm/topi/cuda/conv2d_alter_op.py
+++ b/python/tvm/topi/cuda/conv2d_alter_op.py
@@ -270,6 +270,60 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
     return None
 
 
+def _pad_conv2d_HWNC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor):
+    """Pad an HWNC/HWOI conv2d by (db, di, do), then slice the output to its original shape."""
+    # db/di/do are appended to batch/in-channel/out-channel; new_attrs is updated in place.
+    # Pad batch size (axis 2 of HWNC data)
+    if db != 0:
+        data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, db), (0, 0)))
+
+    # Pad input channel (axis 3 of both HWNC data and HWOI kernel)
+    if di != 0:
+        data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, di)))
+        kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, di)))
+
+    # Pad output channel (axis 2 of HWOI kernel) and keep the "channels" attr in sync
+    if do != 0:
+        kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, do), (0, 0)))
+        new_attrs["channels"] = out_channel + do
+
+    out = relay.nn.conv2d(data, kernel, **new_attrs)
+
+    # Extra batch/output-channel entries enlarge the result, so slice them off again
+    if db != 0 or do != 0:
+        original_out_shape = [x.value for x in output_tensor.shape]
+        out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape)
+
+    return out
+
+
+def _pad_conv2d_NHWC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor):
+    """Pad an NHWC/HWIO conv2d by (db, di, do), then slice the output to its original shape."""
+    # db/di/do are appended to batch/in-channel/out-channel; new_attrs is updated in place.
+    # Pad batch size (axis 0 of NHWC data)
+    if db != 0:
+        data = relay.nn.pad(data, pad_width=((0, db), (0, 0), (0, 0), (0, 0)))
+
+    # Pad input channel (axis 3 of NHWC data, axis 2 of HWIO kernel)
+    if di != 0:
+        data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, di)))
+        kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, di), (0, 0)))
+
+    # Pad output channel (axis 3 of HWIO kernel) and keep the "channels" attr in sync
+    if do != 0:
+        kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, do)))
+        new_attrs["channels"] = out_channel + do
+
+    out = relay.nn.conv2d(data, kernel, **new_attrs)
+
+    # Extra batch/output-channel entries enlarge the result, so slice them off again
+    if db != 0 or do != 0:
+        original_out_shape = [x.value for x in output_tensor.shape]
+        out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape)
+
+    return out
+
+
 @conv2d_legalize.register("cuda")
 def _conv2d_legalize(attrs, inputs, arg_types):
     """Legalizes Conv2D op.
@@ -347,7 +401,7 @@ def _conv2d_legalize(attrs, inputs, arg_types):
             else:
                 out = relay.nn.conv2d(data, kernel, **new_attrs)
             return out
-    elif data_dtype in ["float16"]:  # todo: support int8/int4
+
         if data_layout == "NHWC" and kernel_layout == "HWIO":
             batch = data_tensor.shape[0].value
             in_channel = data_tensor.shape[3].value
@@ -361,7 +415,10 @@ def _conv2d_legalize(attrs, inputs, arg_types):
                 # no need to pad
                 return None
 
-            (db, di, do), extra_flops = pad_to_tensorcore(batch, in_channel, out_channel)
+            candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]
+            (db, di, do), extra_flops = pad_to_tensorcore(
+                batch, in_channel, out_channel, candidates
+            )
 
             if extra_flops > 2:
                 logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops)
@@ -369,28 +426,100 @@ def _conv2d_legalize(attrs, inputs, arg_types):
 
             logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops)
 
-            # Pad batch size
-            if db != 0:
-                data = relay.nn.pad(data, pad_width=((0, db), (0, 0), (0, 0), (0, 0)))
+            return _pad_conv2d_NHWC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor)
 
-            # Pad input channel
-            if di != 0:
-                data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, di)))
-                kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, di), (0, 0)))
+        if data_layout == "HWNC" and kernel_layout == "HWOI":
+            batch = data_tensor.shape[2].value
+            in_channel = data_tensor.shape[3].value
+            out_channel = kernel_tensor.shape[2].value
 
-            # Pad output channel
-            if do != 0:
-                kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, do)))
+            if batch % 8 == 0 and in_channel % 16 == 0 and out_channel % 32 == 0:
+                return None
 
-            if do != 0:
-                new_out_channel = out_channel + do
-                new_attrs["channels"] = new_out_channel
+            candidates = [(8, 16, 32)]
+            (db, di, do), extra_flops = pad_to_tensorcore(
+                batch, in_channel, out_channel, candidates
+            )
 
-            out = relay.nn.conv2d(data, kernel, **new_attrs)
+            if extra_flops > 2:
+                logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops)
+                return None
+            logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops)
 
-            if db != 0 or do != 0:
-                original_out_shape = [x.value for x in output_tensor.shape]
-                out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape)
+            return _pad_conv2d_HWNC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor)
+
+    elif data_dtype in ["float16"]:
+        if data_layout == "NHWC" and kernel_layout == "HWIO":
+            batch = data_tensor.shape[0].value
+            in_channel = data_tensor.shape[3].value
+            out_channel = kernel_tensor.shape[3].value
+
+            if (
+                (batch % 8 == 0 and in_channel % 16 == 0 and out_channel % 32 == 0)
+                or (batch % 16 == 0 and in_channel % 16 == 0 and out_channel % 16 == 0)
+                or (batch % 32 == 0 and in_channel % 16 == 0 and out_channel % 8 == 0)
+            ):
+                # no need to pad
+                return None
+
+            candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]
+            (db, di, do), extra_flops = pad_to_tensorcore(
+                batch, in_channel, out_channel, candidates
+            )
+
+            if extra_flops > 2:
+                logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops)
+                return None
+
+            logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops)
+
+            return _pad_conv2d_NHWC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor)
+
+    elif data_dtype in ["int4", "uint4"]:
+        if data_layout == "NHWC" and kernel_layout == "HWIO":
+            batch = data_tensor.shape[0].value
+            in_channel = data_tensor.shape[3].value
+            out_channel = kernel_tensor.shape[3].value
+
+            if (
+                (batch % 8 == 0 and in_channel % 16 == 0 and out_channel % 32 == 0)
+                or (batch % 16 == 0 and in_channel % 16 == 0 and out_channel % 16 == 0)
+                or (batch % 32 == 0 and in_channel % 16 == 0 and out_channel % 8 == 0)
+            ):
+                # no need to pad
+                return None
+
+            candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]
+            (db, di, do), extra_flops = pad_to_tensorcore(
+                batch, in_channel, out_channel, candidates
+            )
+
+            if extra_flops > 2:
+                logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops)
+                return None
+
+            logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops)
+
+            return _pad_conv2d_NHWC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor)
+
+        if data_layout == "HWNC" and kernel_layout == "HWOI":
+            batch = data_tensor.shape[2].value
+            in_channel = data_tensor.shape[3].value
+            out_channel = kernel_tensor.shape[2].value
+
+            if batch % 8 == 0 and in_channel % 32 == 0 and out_channel % 8 == 0:
+                return None
+
+            candidates = [(8, 32, 8)]
+            (db, di, do), extra_flops = pad_to_tensorcore(
+                batch, in_channel, out_channel, candidates
+            )
+
+            if extra_flops > 2:
+                logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops)
+                return None
+            logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops)
+
+            return _pad_conv2d_HWNC(db, di, do, data, kernel, out_channel, new_attrs, output_tensor)
 
-            return out
     return None
diff --git a/python/tvm/topi/cuda/tensorcore_alter_op.py b/python/tvm/topi/cuda/tensorcore_alter_op.py
index aec7acb..eb7c71d 100644
--- a/python/tvm/topi/cuda/tensorcore_alter_op.py
+++ b/python/tvm/topi/cuda/tensorcore_alter_op.py
@@ -71,7 +71,8 @@ def _batch_matmul_legalize(attrs, inputs, arg_types):
             # no need to pad
             return None
 
-        (dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N)
+        candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]
+        (dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N, candidates)
 
         if extra_flops > 2:
             logger.info("batch_matmul pad_to_tensorcore skipped, extra_flops %s", extra_flops)
@@ -145,7 +146,8 @@ def _dense_legalize(attrs, inputs, arg_types):
             # no need to pad
             return None
 
-        (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N)
+        candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]
+        (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N, candidates)
 
         if extra_flops_ratio > 2:
             logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio)
@@ -171,10 +173,8 @@ def _dense_legalize(attrs, inputs, arg_types):
     return None
 
 
-def pad_to_tensorcore(M, K, N):
+def pad_to_tensorcore(M, K, N, candidates):
     """pad shape to enable tensorcore"""
-    candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]
-
     flops = M * K * N
     extra_flops = math.inf
     best_pad = (0, 0, 0)
diff --git a/tests/python/relay/test_pass_legalize_tensorcore.py b/tests/python/relay/test_pass_legalize_tensorcore.py
index f45e390..1312b39 100644
--- a/tests/python/relay/test_pass_legalize_tensorcore.py
+++ b/tests/python/relay/test_pass_legalize_tensorcore.py
@@ -36,18 +36,18 @@ def run_opt_pass(expr, passes):
 
 
 @tvm.testing.uses_gpu
-def test_legalize_conv2d():
-    """test legalize conv2d to enable tensorcore"""
+def test_legalize_conv2d_NHWC():
+    """test legalize NHWC conv2d to enable tensorcore"""
 
-    def _test_legalize_conv2d(data_shape, kernel_shape, pad_shape, do_pad=True):
+    def _test_legalize_conv2d(data_shape, kernel_shape, pad_shape, dtype, do_pad=True):
         out_channel = kernel_shape[3]
         out_shape = list(data_shape)
         out_shape[3] = out_channel
         db, di, do = pad_shape
 
         def before():
-            x = relay.var("x", shape=data_shape, dtype="float16")
-            weight = relay.var("weight", shape=kernel_shape, dtype="float16")
+            x = relay.var("x", shape=data_shape, dtype=dtype)
+            weight = relay.var("weight", shape=kernel_shape, dtype=dtype)
             y = relay.nn.conv2d(
                 x,
                 weight,
@@ -67,12 +67,12 @@ def test_legalize_conv2d():
         def expected():
             if not do_pad:
                 return before()
-            x = relay.var("x", shape=data_shape, dtype="float16")
+            x = relay.var("x", shape=data_shape, dtype=dtype)
             if db or di:
                 x_pad = relay.nn.pad(x, pad_width=((0, db), (0, 0), (0, 0), (0, di)))
             else:
                 x_pad = x
-            weight = relay.var("weight", shape=(kernel_shape), dtype="float16")
+            weight = relay.var("weight", shape=(kernel_shape), dtype=dtype)
             if di or do:
                 weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, 0), (0, di), (0, do)))
             else:
@@ -99,19 +99,109 @@ def test_legalize_conv2d():
             b = run_opt_pass(expected(), transform.InferType())
         assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b)
 
+    for dtype in ["float16", "int8", "int4"]:
+        # conv2d pad batch
+        _test_legalize_conv2d((7, 16, 16, 64), (3, 3, 64, 64), (1, 0, 0), dtype)
+        _test_legalize_conv2d((3, 16, 16, 64), (3, 3, 64, 64), (5, 0, 0), dtype)
+        _test_legalize_conv2d((2, 16, 16, 64), (3, 3, 64, 64), (0, 0, 0), dtype, False)
+        # conv2d pad in_channel
+        _test_legalize_conv2d((8, 16, 16, 63), (3, 3, 63, 64), (0, 1, 0), dtype)
+        _test_legalize_conv2d((8, 16, 16, 33), (3, 3, 33, 64), (0, 15, 0), dtype)
+        _test_legalize_conv2d((8, 16, 16, 13), (3, 3, 13, 64), (0, 3, 0), dtype)
+        _test_legalize_conv2d((8, 16, 16, 1), (3, 3, 1, 64), (0, 0, 0), dtype, False)
+        # conv2d pad out_channel
+        _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 63), (0, 0, 1), dtype)
+        _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 33), (0, 0, 31), dtype)
+        _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 1), (0, 0, 0), dtype, False)
+
+
+@tvm.testing.uses_gpu
+def test_legalize_conv2d_HWNC():
+    """test legalize HWNC conv2d (batch on axis 2, HWOI kernel) to enable tensorcore"""
+
+    def _test_legalize_conv2d(data_shape, kernel_shape, pad_shape, dtype, do_pad=True):
+        out_channel = kernel_shape[2]
+        out_shape = list(data_shape)
+        out_shape[3] = out_channel
+        db, di, do = pad_shape
+
+        def before():
+            x = relay.var("x", shape=data_shape, dtype=dtype)
+            weight = relay.var("weight", shape=kernel_shape, dtype=dtype)
+            y = relay.nn.conv2d(
+                x,
+                weight,
+                channels=out_channel,
+                kernel_size=(3, 3),
+                padding=(1, 1),
+                data_layout="HWNC",
+                kernel_layout="HWOI",
+            )
+            y = relay.Function([x, weight], y)
+            return y
+
+        def legalize_conv2d(attrs, inputs, types):
+            with tvm.target.Target("cuda"):
+                return topi.nn.conv2d_legalize(attrs, inputs, types)
+
+        def expected():
+            if not do_pad:
+                return before()
+            x = relay.var("x", shape=data_shape, dtype=dtype)
+            if db or di:
+                x_pad = relay.nn.pad(x, pad_width=((0, 0), (0, 0), (0, db), (0, di)))
+            else:
+                x_pad = x
+            weight = relay.var("weight", shape=(kernel_shape), dtype=dtype)
+            if di or do:
+                weight_pad = relay.nn.pad(weight, pad_width=((0, 0), (0, 0), (0, do), (0, di)))
+            else:
+                weight_pad = weight
+            y_pad = relay.nn.conv2d(
+                x_pad,
+                weight=weight_pad,
+                channels=out_channel + do,
+                kernel_size=(3, 3),
+                padding=(1, 1),
+                data_layout="HWNC",
+                kernel_layout="HWOI",
+            )
+            if db or do:
+                y = relay.strided_slice(y_pad, begin=[0, 0, 0, 0], end=out_shape)
+            else:
+                y = y_pad
+            y = relay.Function([x, weight], y)
+            return y
+
+        with TempOpAttr("nn.conv2d", "FTVMLegalize", legalize_conv2d):
+            a = before()
+            a = run_opt_pass(a, transform.Legalize())
+            b = run_opt_pass(expected(), transform.InferType())
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b)
+
     # conv2d pad batch
-    _test_legalize_conv2d((7, 16, 16, 64), (3, 3, 64, 64), (1, 0, 0))
-    _test_legalize_conv2d((3, 16, 16, 64), (3, 3, 64, 64), (5, 0, 0))
-    _test_legalize_conv2d((2, 16, 16, 64), (3, 3, 64, 64), (0, 0, 0), False)
+    _test_legalize_conv2d((16, 16, 7, 64), (3, 3, 64, 64), (1, 0, 0), "int8")
+    _test_legalize_conv2d((16, 16, 3, 64), (3, 3, 64, 64), (5, 0, 0), "int8")
+    _test_legalize_conv2d((16, 16, 2, 64), (3, 3, 64, 64), (0, 0, 0), "int8", False)
+    _test_legalize_conv2d((16, 16, 7, 64), (3, 3, 64, 64), (1, 0, 0), "int4")
+    _test_legalize_conv2d((16, 16, 3, 64), (3, 3, 64, 64), (5, 0, 0), "int4")
+    _test_legalize_conv2d((16, 16, 2, 64), (3, 3, 64, 64), (0, 0, 0), "int4", False)
     # conv2d pad in_channel
-    _test_legalize_conv2d((8, 16, 16, 63), (3, 3, 63, 64), (0, 1, 0))
-    _test_legalize_conv2d((8, 16, 16, 33), (3, 3, 33, 64), (0, 15, 0))
-    _test_legalize_conv2d((8, 16, 16, 13), (3, 3, 13, 64), (0, 3, 0))
-    _test_legalize_conv2d((8, 16, 16, 1), (3, 3, 1, 64), (0, 0, 0), False)
+    _test_legalize_conv2d((16, 16, 8, 63), (3, 3, 64, 63), (0, 1, 0), "int8")
+    _test_legalize_conv2d((16, 16, 8, 33), (3, 3, 64, 33), (0, 15, 0), "int8")
+    _test_legalize_conv2d((16, 16, 8, 13), (3, 3, 64, 13), (0, 3, 0), "int8")
+    _test_legalize_conv2d((16, 16, 8, 1), (3, 3, 64, 1), (0, 0, 0), "int8", False)
+    _test_legalize_conv2d((16, 16, 8, 63), (3, 3, 64, 63), (0, 1, 0), "int4")
+    _test_legalize_conv2d((16, 16, 8, 33), (3, 3, 64, 33), (0, 31, 0), "int4")
+    _test_legalize_conv2d((16, 16, 8, 13), (3, 3, 64, 13), (0, 19, 0), "int4")
+    _test_legalize_conv2d((16, 16, 8, 1), (3, 3, 64, 1), (0, 0, 0), "int4", False)
     # conv2d pad out_channel
-    _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 63), (0, 0, 1))
-    _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 33), (0, 0, 31))
-    _test_legalize_conv2d((8, 16, 16, 64), (3, 3, 64, 1), (0, 0, 0), False)
+    _test_legalize_conv2d((16, 16, 8, 64), (3, 3, 63, 64), (0, 0, 1), "int8")
+    _test_legalize_conv2d((16, 16, 8, 64), (3, 3, 33, 64), (0, 0, 31), "int8")
+    _test_legalize_conv2d((16, 16, 8, 64), (3, 3, 1, 64), (0, 0, 0), "int8", False)
+    _test_legalize_conv2d((16, 16, 8, 64), (3, 3, 63, 64), (0, 0, 1), "int4")
+    _test_legalize_conv2d((16, 16, 8, 64), (3, 3, 33, 64), (0, 0, 7), "int4")
+    _test_legalize_conv2d((16, 16, 8, 64), (3, 3, 1, 64), (0, 0, 0), "int4", False)
 
 
 @tvm.testing.uses_gpu
@@ -234,6 +324,7 @@ def test_legalize_batch_matmul():
 
 
 if __name__ == "__main__":
-    test_legalize_conv2d()
+    test_legalize_conv2d_NHWC()
+    test_legalize_conv2d_HWNC()
     test_legalize_dense()
     test_legalize_batch_matmul()