Posted to commits@tvm.apache.org by ma...@apache.org on 2022/04/12 19:25:55 UTC

[tvm] branch main updated: [CUDNN] Add partitioning support for conv2d and log_softmax (#10961)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 98fc6495bb [CUDNN] Add partitioning support for conv2d and log_softmax (#10961)
98fc6495bb is described below

commit 98fc6495bbf9f6d1ae68ed2a495e87c4b469fd67
Author: Matthew Barrett <55...@users.noreply.github.com>
AuthorDate: Tue Apr 12 20:25:50 2022 +0100

    [CUDNN] Add partitioning support for conv2d and log_softmax (#10961)
---
 python/tvm/relay/op/contrib/cudnn.py | 63 ++++++++++++++++++++++++++++++++++
 tests/python/contrib/test_cudnn.py   | 66 +++++++++++++++++++++++++++++++++++-
 2 files changed, 128 insertions(+), 1 deletion(-)
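
For context, a minimal sketch of how the new entry point might be exercised. The relay program and its shapes are illustrative only (they are not part of the commit); partition_for_cudnn and its params handling are taken from the diff below:

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.cudnn import partition_for_cudnn

    # A small NCHW/OIHW conv2d, the only layout combination the new
    # check_conv2d predicate accepts.
    data = tvm.relay.var("data", relay.TensorType((1, 8, 16, 20), "float32"))
    weight = tvm.relay.var("weight", relay.TensorType((16, 8, 3, 3), "float32"))
    out = relay.nn.conv2d(
        data,
        weight,
        channels=16,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NCHW",
        kernel_layout="OIHW",
    )
    mod = tvm.IRModule.from_expr(out)

    # Offload supported operators to cuDNN. If params are supplied, they are
    # bound into the module first (the bind_params_by_name change below).
    mod = partition_for_cudnn(mod)
    print(mod)  # conv2d now sits inside a "cudnn.conv2d" composite function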

diff --git a/python/tvm/relay/op/contrib/cudnn.py b/python/tvm/relay/op/contrib/cudnn.py
index 591178e6f8..9714a0b87d 100644
--- a/python/tvm/relay/op/contrib/cudnn.py
+++ b/python/tvm/relay/op/contrib/cudnn.py
@@ -24,6 +24,7 @@ from tvm import relay
 from tvm import te
 from tvm.relay import transform
 from tvm.contrib import cudnn
+from tvm.relay.build_module import bind_params_by_name
 
 from ...dataflow_pattern import is_op, wildcard
 from .te_target import lower_composite, relay_to_runtime
@@ -50,6 +51,8 @@ def partition_for_cudnn(
     tvm.IRModule
         The partitioned module.
     """
+    if params:
+        mod["main"] = bind_params_by_name(mod["main"], params)
 
     seq = tvm.transform.Sequential(
         [
@@ -71,6 +74,14 @@ def pattern_table() -> List[Tuple[str, relay.Pattern, Callable[[relay.Call], boo
         """Create pattern for softmax."""
         return is_op("nn.softmax")(wildcard())
 
+    def log_softmax_pattern() -> relay.Pattern:
+        """Create pattern for log_softmax."""
+        return is_op("nn.log_softmax")(wildcard())
+
+    def conv2d_pattern() -> relay.Pattern:
+        """Create pattern for conv2d."""
+        return is_op("nn.conv2d")(wildcard(), wildcard())
+
     def check_softmax(matched: relay.Call) -> bool:
         """Check if softmax is supported by cuDNN."""
         if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
@@ -78,8 +89,37 @@ def pattern_table() -> List[Tuple[str, relay.Pattern, Callable[[relay.Call], boo
 
         return True
 
+    def check_log_softmax(matched: relay.Call) -> bool:
+        """Check if log_softmax is supported by cuDNN."""
+        if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
+            return False
+
+        if len(matched.args[0].checked_type.shape) != 2:
+            return False
+
+        if matched.attrs["axis"] not in (1, -1):
+            return False
+
+        return True
+
+    def check_conv2d(matched: relay.Call) -> bool:
+        """Check if conv2d is supported by cuDNN."""
+        if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
+            return False
+
+        if matched.attrs["data_layout"] != "NCHW" or matched.attrs["kernel_layout"] != "OIHW":
+            return False
+
+        padding = matched.attrs["padding"]
+        if padding[0] != padding[2] or padding[1] != padding[3]:
+            return False
+
+        return True
+
     return [
         ("cudnn.softmax", softmax_pattern(), check_softmax),
+        ("cudnn.log_softmax", log_softmax_pattern(), check_log_softmax),
+        ("cudnn.conv2d", conv2d_pattern(), check_conv2d),
     ]
 
 
@@ -87,3 +127,26 @@ def pattern_table() -> List[Tuple[str, relay.Pattern, Callable[[relay.Call], boo
 def _lower_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     """Lower a softmax using cuDNN."""
     return cudnn.softmax(inputs[0], axis=op.attrs["axis"])
+
+
+@lower_composite("cudnn.log_softmax")
+def _lower_log_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
+    """Lower a log_softmax using cuDNN."""
+    return cudnn.log_softmax(inputs[0], axis=op.attrs["axis"])
+
+
+@lower_composite("cudnn.conv2d")
+def _lower_conv2d(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
+    """Lower a conv2d using cuDNN."""
+    return cudnn.conv_forward(
+        inputs[0],
+        inputs[1],
+        pad=op.attrs["padding"],
+        stride=op.attrs["strides"],
+        dilation=op.attrs["dilation"],
+        conv_mode=1,
+        tensor_format=0,
+        algo=1,
+        conv_dtype=op.checked_type.dtype,
+        groups=op.attrs["groups"],
+    )
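
As a gloss on the fixed arguments in _lower_conv2d, here is a rough TE-level equivalent. The shapes are hypothetical, a cuDNN-enabled build of TVM is assumed, and the enum readings in the comments are taken from the cuDNN API rather than from this commit:

    from tvm import te
    from tvm.contrib import cudnn

    x = te.placeholder((1, 8, 16, 20), name="x", dtype="float32")
    w = te.placeholder((16, 8, 3, 3), name="w", dtype="float32")
    y = cudnn.conv_forward(
        x,
        w,
        pad=(1, 1),
        stride=(1, 1),
        dilation=(1, 1),
        conv_mode=1,      # CUDNN_CROSS_CORRELATION, what relay's conv2d computes
        tensor_format=0,  # CUDNN_TENSOR_NCHW, matching the data_layout check
        algo=1,           # CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
        conv_dtype="float32",
        groups=1,
    )
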
diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py
index 45ca7c9171..8ca3df343d 100644
--- a/tests/python/contrib/test_cudnn.py
+++ b/tests/python/contrib/test_cudnn.py
@@ -484,7 +484,7 @@ def _verify_cudnn_relay(expr):
     tvm.testing.assert_allclose(
         outputs[0],
         outputs[1],
-        rtol=1e-3,
+        rtol=1e-2,
     )
 
 
@@ -513,5 +513,69 @@ def test_relay_cudnn_softmax(shape, axis, dtype):
     _verify_cudnn_relay(softmax)
 
 
+@tvm.testing.requires_cuda
+@pytest.mark.parametrize(
+    "shape,axis",
+    [
+        ((32, 16), -1),
+        ((13, 27), 1),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        "float32",
+        "float16",
+        "float64",
+    ],
+)
+def test_relay_cudnn_log_softmax(shape, axis, dtype):
+    x = tvm.relay.var("x", tvm.relay.TensorType(shape, dtype))
+    log_softmax = relay.op.nn.log_softmax(x, axis=axis)
+    _verify_cudnn_relay(log_softmax)
+
+
+@tvm.testing.requires_cuda
+@pytest.mark.parametrize(
+    "n,h,w,ci,co,groups",
+    [
+        (1, 16, 20, 8, 16, 1),
+        (10, 17, 19, 16, 8, 4),
+    ],
+)
+@pytest.mark.parametrize(
+    "kh,kw,padding",
+    [
+        (1, 1, (3, 1, 3, 1)),
+        (3, 3, (1, 2)),
+        (7, 2, (0, 0)),
+    ],
+)
+@pytest.mark.parametrize(
+    "strides,dilation,dtype",
+    [
+        ((1, 1), (1, 1), "float32"),
+        ((2, 1), (2, 2), "float16"),
+        ((3, 3), (1, 2), "float64"),
+    ],
+)
+def test_relay_cudnn_conv2d(n, h, w, ci, co, kh, kw, strides, dilation, padding, groups, dtype):
+    data = tvm.relay.var("data", tvm.relay.TensorType((n, ci, h, w), dtype))
+    weight = tvm.relay.var("weight", tvm.relay.TensorType((co, ci // groups, kh, kw), dtype))
+    conv2d = relay.op.nn.conv2d(
+        data,
+        weight,
+        groups=groups,
+        channels=co,
+        kernel_size=(kh, kw),
+        strides=strides,
+        dilation=dilation,
+        padding=padding,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+    )
+    _verify_cudnn_relay(conv2d)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(sys.argv))
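
The new cases can also be run on their own with pytest's -k filter, assuming a CUDA-enabled build of TVM with cuDNN (per the @tvm.testing.requires_cuda marker above):

    pytest tests/python/contrib/test_cudnn.py -k "log_softmax or conv2d"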