You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/12/02 01:35:08 UTC

[tvm] branch main updated: [CUTLASS] Initial conv2d support (#9595)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dc988b2  [CUTLASS] Initial conv2d support (#9595)
dc988b2 is described below

commit dc988b288d85660822e0fbadbf1fc74326e763e5
Author: masahi <ma...@gmail.com>
AuthorDate: Thu Dec 2 10:34:46 2021 +0900

    [CUTLASS] Initial conv2d support (#9595)
    
    * Add initial conv generator
    
    * added conv2d pattern
    
    * profile by gemm profiler
    
    * remove conv2d profiler for now
    
    * remove unused code
    
    * add default
    
    * minor fix, profiling working
    
    * start codegen
    
    * generated code compiled
    
    * fixed layout initialization
    
    * matched with autotvm tensorcore result
    
    * test refactor
    
    * minor cleanup
    
    * remove iteration algo "Analytic"
    
    * add test for dynamic batch conv2d
    
    * pass dl tensor as output too
    
    * support conv2d dynamic shape in codegen
    
    * test working
    
    * lint
    
    * simplify codegen
    
    * fix weird formatting
    
    * typo fix
    
    * check if cutlass is enabled in the test
    
    * simplify gen_conv2d.py
---
 python/tvm/contrib/cutlass/build.py             |  90 +++++++--
 python/tvm/contrib/cutlass/conv2d_operation.py  | 240 ++++++++++++++++++++++++
 python/tvm/contrib/cutlass/gen_conv2d.py        | 147 +++++++++++++++
 python/tvm/contrib/cutlass/gen_gemm.py          |   3 +
 python/tvm/contrib/cutlass/library.py           |  57 +++++-
 python/tvm/relay/op/contrib/cutlass.py          |   7 +
 python/tvm/relay/op/nn/_nn.py                   |  15 ++
 src/relay/backend/contrib/codegen_c/codegen_c.h |  12 +-
 src/relay/backend/contrib/cutlass/codegen.cc    | 134 ++++++++++++-
 tests/python/contrib/test_cutlass.py            |  96 ++++++++++
 10 files changed, 776 insertions(+), 25 deletions(-)

diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 615b900..c3a8fdc 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -23,6 +23,7 @@ import tvm
 from tvm import runtime, relay
 from tvm.contrib.nvcc import find_cuda_path, get_cuda_version
 from .gen_gemm import CutlassGemmProfiler
+from .gen_conv2d import CutlassConv2DProfiler
 
 logger = logging.getLogger("cutlass")
 
@@ -65,7 +66,7 @@ def _get_cutlass_compile_options(sm, threads):
     return kwargs
 
 
-class GemmAnnotator(tvm.relay.ExprVisitor):
+class OpAnnotator(tvm.relay.ExprVisitor):
     """Annotates partitioned functions with shape and dtype information."""
 
     def __init__(self):
@@ -81,6 +82,10 @@ class GemmAnnotator(tvm.relay.ExprVisitor):
                 self.signature["arg%d_dtype" % i] = arg.checked_type.dtype
             self.signature["ret_shape"] = op.ret_type.shape
             self.signature["ret_dtype"] = op.ret_type.dtype
+            self.visit(op.body)
+
+        if str(op) == "nn.conv2d":
+            self.op_attrs = call.attrs
 
 
 def select_gemm_kernel(
@@ -125,6 +130,8 @@ def handle_batch_matmul(
     else:
         raise ValueError("%s pattern is not implemented." % op_type)
 
+    assert "tn_align" in out["name"], "Only supports (row_major, col_major) input layout for now."
+
     return {
         "batch": arg0_shape[0],
         "batch_stride_A": arg0_shape[1] * arg0_shape[2],
@@ -132,6 +139,9 @@ def handle_batch_matmul(
         "batch_stride_C": arg0_shape[1] * arg1_shape[1],
         "cutlass_op_def": cutlass_op_def,
         "cutlass_op_name": out["name"],
+        "lda": "K",
+        "ldb": "K",
+        "ldc": "N",
     }
 
 
@@ -158,6 +168,50 @@ def handle_dense(
     else:
         raise ValueError("%s pattern is not implemented." % op_type)
 
+    assert "tn_align" in out["name"], "Only supports (row_major, col_major) input layout for now."
+
+    return {
+        "cutlass_op_def": cutlass_op_def,
+        "cutlass_op_name": out["name"],
+        "lda": "K",
+        "ldb": "K",
+        "ldc": "N",
+    }
+
+
+def handle_conv2d(
+    cutlass_profiler,
+    op_type,
+    d_shape,
+    w_shape,
+    out_shape,
+    out_dtype,
+    profile_all,
+    use_multiprocessing,
+):
+    """Profile and select a kernel for conv2d op workload."""
+    if any(isinstance(s, tvm.tir.Any) for s in d_shape):
+        out = cutlass_profiler.get_default(out_dtype)
+        logger.info("Picked the default kernel %s", out["name"])
+    else:
+        out = cutlass_profiler.profile(
+            d_shape,
+            w_shape,
+            out_shape,
+            out_dtype,
+            profile_all=profile_all,
+            use_multiprocessing=use_multiprocessing,
+        )
+        if profile_all:
+            logger.info("The best kernel is %s", out["name"])
+        else:
+            logger.info("Picked the first kernel found %s", out["name"])
+
+    if op_type == "cutlass.conv2d":
+        cutlass_op_def = out["opdef"]
+    else:
+        raise ValueError("%s pattern is not implemented." % op_type)
+
     return {
         "cutlass_op_def": cutlass_op_def,
         "cutlass_op_name": out["name"],
@@ -195,12 +249,13 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
     num_cutlass_partition : int
         The number of partitioned functions created for CUTLASS.
     """
-    cutlass_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir)
+    gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir)
+    conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), tmp_dir)
     num_cutlass_partition = 0
     for var in mod.get_global_vars():
         fun_name = var.name_hint
         func = mod[fun_name]
-        annotator = GemmAnnotator()
+        annotator = OpAnnotator()
         if "cutlass" in fun_name:
             num_cutlass_partition += 1
             annotator.visit(func)
@@ -213,10 +268,26 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
             arg0_shape = new_attrs["arg0_shape"]
             arg1_shape = new_attrs["arg1_shape"]
 
-            if "batch_matmul" in op_type:
+            if "conv2d" in op_type:
+                new_attrs["padding"] = annotator.op_attrs.padding
+                new_attrs["strides"] = annotator.op_attrs.strides
+                new_attrs["dilation"] = annotator.op_attrs.dilation
+                new_attrs.update(
+                    handle_conv2d(
+                        conv2d_profiler,
+                        op_type,
+                        arg0_shape,
+                        arg1_shape,
+                        annotator.signature["ret_shape"],
+                        out_dtype,
+                        profile_all,
+                        use_multiprocessing,
+                    )
+                )
+            elif "batch_matmul" in op_type:
                 new_attrs.update(
                     handle_batch_matmul(
-                        cutlass_profiler,
+                        gemm_profiler,
                         op_type,
                         arg0_shape,
                         arg1_shape,
@@ -228,7 +299,7 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
             elif "dense" in op_type:
                 new_attrs.update(
                     handle_dense(
-                        cutlass_profiler,
+                        gemm_profiler,
                         op_type,
                         arg0_shape,
                         arg1_shape,
@@ -240,13 +311,6 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
             else:
                 raise ValueError("%s unsupported composite" % op_type)
 
-            if new_attrs["cutlass_op_name"].find("_tn_align") > 0:
-                new_attrs["lda"] = "K"
-                new_attrs["ldb"] = "K"
-                new_attrs["ldc"] = "N"
-            else:
-                raise ValueError("%s unsupported operation" % new_attrs["cutlass_op_name"])
-
             new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs)
             new_func = relay.Function(
                 func.params,
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py
new file mode 100644
index 0000000..8a886ff
--- /dev/null
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -0,0 +1,240 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+"""Generator for CUTLASS Conv2D kernels."""
+from .library import *
+
+
+class Conv2dOperation:
+    """Describes the attributes needed to instantiate one CUTLASS Conv2d kernel."""
+
+    def __init__(
+        self,
+        conv_kind,  # ConvKind, e.g. ConvKind.Fprop
+        iterator_algorithm,  # IteratorAlgorithm.Analytic or IteratorAlgorithm.Optimized
+        arch,  # SM number, emitted as cutlass::arch::Sm<arch>
+        tile_description,
+        A,  # TensorDescription of operand A (element, layout, alignment)
+        B,  # TensorDescription of operand B
+        C,  # TensorDescription of the output operand C
+        element_epilogue,
+        stride_support,  # StrideSupport.Strided or StrideSupport.Unity
+        epilogue_functor=EpilogueFunctor.LinearCombination,
+        swizzling_functor=SwizzlingFunctor.Identity1,
+    ):
+        self.operation_kind = OperationKind.Conv2d
+        self.arch = arch
+        self.tile_description = tile_description
+        self.conv_kind = conv_kind
+        self.A = A
+        self.B = B
+        self.C = C
+        self.element_epilogue = element_epilogue
+        self.epilogue_functor = epilogue_functor
+        self.iterator_algorithm = iterator_algorithm
+        self.stride_support = stride_support
+        self.swizzling_functor = swizzling_functor
+
+    def accumulator_type(self):  # element type accumulated by the math instruction
+        return self.tile_description.math_instruction.element_accumulator
+
+    def core_name(self):
+        """Return "<acc><inst_shape><intermediate_type><conv_kind>_<iterator_algorithm>"."""
+        intermediate_type = ""
+
+        if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
+            inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
+            if (  # math input type differs from both operand A and the accumulator
+                self.tile_description.math_instruction.element_a != self.A.element
+                and self.tile_description.math_instruction.element_a != self.accumulator_type()
+            ):
+                intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
+        else:
+            inst_shape = ""  # non-TensorOp kernels omit the instruction shape
+
+        return "%s%s%s%s_%s" % (
+            ShortDataTypeNames[self.accumulator_type()],
+            inst_shape,
+            intermediate_type,
+            ConvKindNames[self.conv_kind],
+            IteratorAlgorithmNames[self.iterator_algorithm],
+        )
+
+    def extended_name(self):
+        """Add A/C element type names around core_name() if they differ from the accumulator."""
+        if (
+            self.C.element != self.tile_description.math_instruction.element_accumulator
+            and self.A.element != self.tile_description.math_instruction.element_accumulator
+        ):
+            extended_name = "${element_c}_${core_name}_${element_a}"
+        elif (
+            self.C.element == self.tile_description.math_instruction.element_accumulator
+            and self.A.element != self.tile_description.math_instruction.element_accumulator
+        ):
+            extended_name = "${core_name}_${element_a}"
+        else:
+            extended_name = "${core_name}"  # also reached when only C differs
+
+        extended_name = substitute_template(
+            extended_name,
+            {
+                "element_a": DataTypeNames[self.A.element],
+                "element_c": DataTypeNames[self.C.element],
+                "core_name": self.core_name(),
+            },
+        )
+
+        return extended_name
+
+    def layout_name(self):  # short name of operand A's layout only, e.g. "nhwc"
+        return "%s" % (ShortLayoutTypeNames[self.A.layout])
+
+    def procedural_name(self):
+        """
+        Full kernel name: opcode class, extended name, threadblock tile/stages, layout, alignment.
+        """
+        opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
+
+        threadblock = "%dx%d_%dx%d" % (  # "<tb_m>x<tb_n>_<tb_k>x<stages>"
+            self.tile_description.threadblock_shape[0],
+            self.tile_description.threadblock_shape[1],
+            self.tile_description.threadblock_shape[2],
+            self.tile_description.stages,
+        )
+
+        if self.stride_support == StrideSupport.Unity:  # unity-stride kernels get a suffix
+            configuration_name = (
+                "cutlass_${opcode_class}_${extended_name}_${threadblock}"
+                "_${layout}_align${alignment}_unity_stride"
+            )
+        else:
+            configuration_name = (
+                "cutlass_${opcode_class}_${extended_name}_${threadblock}"
+                "_${layout}_align${alignment}"
+            )
+
+        return substitute_template(
+            configuration_name,
+            {
+                "opcode_class": opcode_class_name,
+                "extended_name": self.extended_name(),
+                "threadblock": threadblock,
+                "layout": self.layout_name(),
+                "alignment": "%d" % self.A.alignment,  # operand A's alignment only
+            },
+        )
+
+
+class EmitConv2dInstance:
+    """Emits the C++ `using` alias that instantiates a CUTLASS DefaultConv2d<Kind> kernel."""
+
+    def __init__(self):  # the template's ${...} fields are filled in by emit()
+        self.template = """
+  // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
+  using ${operation_name} =
+  typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
+    ${element_a},
+    ${layout_a},
+    ${element_b},
+    ${layout_b},
+    ${element_c},
+    ${layout_c},
+    ${element_accumulator},
+    ${opcode_class},
+    ${arch},
+    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
+    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
+    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
+    ${epilogue_functor}<
+      ${element_c},
+      ${epilogue_vector_length},
+      ${element_accumulator},
+      ${element_epilogue}
+    >,
+    ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
+    ${stages},
+    ${math_operator},
+    ${iterator_algorithm},
+    ${stride_support},
+    ${align_a},
+    ${align_b}
+  >::Kernel;
+"""
+
+    def emit(self, operation):
+        """Return the C++ kernel-instance definition for the given Conv2dOperation."""
+        warp_shape = [  # per dimension: threadblock shape divided by warp count
+            int(
+                operation.tile_description.threadblock_shape[idx]
+                / operation.tile_description.warp_count[idx]
+            )
+            for idx in range(3)
+        ]
+
+        epilogue_vector_length = int(  # min(C.alignment, 128 / DataTypeSize[C]) elements
+            min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
+            / DataTypeSize[operation.C.element]
+        )
+
+        values = {
+            "operation_name": operation.procedural_name(),
+            "conv_kind": ConvKindTag[operation.conv_kind],  # not referenced by self.template
+            "conv_kind_name": ConvKindNames[operation.conv_kind].capitalize(),  # e.g. "Fprop"
+            "element_a": DataTypeTag[operation.A.element],
+            "layout_a": LayoutTag[operation.A.layout],
+            "element_b": DataTypeTag[operation.B.element],
+            "layout_b": LayoutTag[operation.B.layout],
+            "element_c": DataTypeTag[operation.C.element],
+            "layout_c": LayoutTag[operation.C.layout],
+            "element_accumulator": DataTypeTag[operation.accumulator_type()],
+            "opcode_class": OpcodeClassTag[
+                operation.tile_description.math_instruction.opcode_class
+            ],
+            "arch": "cutlass::arch::Sm%d" % operation.arch,
+            "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
+            "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
+            "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
+            "warp_shape_m": str(warp_shape[0]),
+            "warp_shape_n": str(warp_shape[1]),
+            "warp_shape_k": str(warp_shape[2]),
+            "instruction_shape_m": str(
+                operation.tile_description.math_instruction.instruction_shape[0]
+            ),
+            "instruction_shape_n": str(
+                operation.tile_description.math_instruction.instruction_shape[1]
+            ),
+            "instruction_shape_k": str(
+                operation.tile_description.math_instruction.instruction_shape[2]
+            ),
+            "epilogue_vector_length": str(epilogue_vector_length),
+            "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
+            "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
+            "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
+            "stages": str(operation.tile_description.stages),
+            "iterator_algorithm": IteratorAlgorithmTag[operation.iterator_algorithm],
+            "iterator_algorithm_name": IteratorAlgorithmNames[  # only used in the comment line
+                operation.iterator_algorithm
+            ].capitalize(),
+            "stride_support": StrideSupportTag[operation.stride_support],
+            "math_operator": MathOperationTag[
+                operation.tile_description.math_instruction.math_operation
+            ],
+            "align_a": str(operation.A.alignment),
+            "align_b": str(operation.B.alignment),
+        }
+
+        return substitute_template(self.template, values)
diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py
new file mode 100644
index 0000000..d24e988
--- /dev/null
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -0,0 +1,147 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Conv2d kernel generator and profiler for CUTLASS."""
+from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
+from .gen_gemm import CutlassGemmProfiler
+from .library import (
+    EpilogueFunctor,
+    SwizzlingFunctor,
+    TensorDescription,
+    LayoutType,
+    ConvKind,
+    StrideSupport,
+    IteratorAlgorithm,
+)
+
+
+def create_conv2d_operator(
+    tile_descriptions,
+    data_type,
+    alignment_constraints,
+    swizzling_functor=SwizzlingFunctor.Identity4,
+):
+    """Build one op-entry dict per (tile, alignment, iterator algorithm) combination."""
+    ret = []  # dicts keyed "opdef", "op", "name", "runtime", "opdef_bias", "opdef_bias_relu"
+
+    kernel_emitter = EmitConv2dInstance()
+
+    element_a, element_b, element_c, element_epilogue = data_type  # unpacked as a 4-tuple
+    iterator_algorithms = [IteratorAlgorithm.Optimized]  # Analytic kernels are not generated
+
+    layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)  # A, B, C
+    for tile in tile_descriptions:
+        for alignment in alignment_constraints:
+
+            alignment_c = min(8, alignment)  # C's alignment never exceeds 8
+
+            A = TensorDescription(element_a, layout[0], alignment)
+            B = TensorDescription(element_b, layout[1], alignment)
+            C = TensorDescription(element_c, layout[2], alignment_c)
+
+            swizzling_functor_ = swizzling_functor  # NOTE(review): redundant alias
+
+            for iterator_algorithm in iterator_algorithms:
+                op_entry = {}
+
+                op = Conv2dOperation(
+                    ConvKind.Fprop,
+                    iterator_algorithm,
+                    tile.minimum_compute_capability,
+                    tile,
+                    A,
+                    B,
+                    C,
+                    element_epilogue,
+                    StrideSupport.Strided,
+                    EpilogueFunctor.LinearCombination,
+                    swizzling_functor_,
+                )
+
+                # TODO(masahi): Add profiler source here
+                op_entry["opdef"] = kernel_emitter.emit(op)
+                op_entry["op"] = op
+                op_entry["name"] = op.procedural_name()
+                op_entry["runtime"] = 9999999  # placeholder; nothing is profiled here
+
+                # fused ops: re-emit the same configuration with each epilogue zipped below
+                for epilogue, opdef in zip(
+                    [
+                        EpilogueFunctor.LinearCombinationBias,
+                        EpilogueFunctor.LinearCombinationRelu,
+                    ],
+                    ["opdef_bias", "opdef_bias_relu"],
+                ):
+                    op = Conv2dOperation(
+                        ConvKind.Fprop,
+                        iterator_algorithm,
+                        tile.minimum_compute_capability,
+                        tile,
+                        A,
+                        B,
+                        C,
+                        element_epilogue,
+                        StrideSupport.Strided,
+                        epilogue,
+                        swizzling_functor_,
+                    )
+
+                    op_entry[opdef] = kernel_emitter.emit(op)
+
+                ret.append(op_entry)
+
+    return ret
+
+
+class CutlassConv2DProfiler:
+    """Select conv2d kernels by profiling the equivalent implicit GEMM problem."""
+
+    def __init__(self, sm, cutlass_path, binary_path):
+        self.gemm_profiler = CutlassGemmProfiler(sm, cutlass_path, binary_path)
+        self.sm = sm
+
+    def get_default(self, out_dtype):
+        """Build a conv2d kernel from the GEMM profiler's default kernel for out_dtype."""
+        gemm_profile_result = self.gemm_profiler.get_default(out_dtype)
+        tile_description = gemm_profile_result["tile_description"]
+        alignment = gemm_profile_result["alignment"]
+        data_type = gemm_profile_result["data_type"]
+        return create_conv2d_operator([tile_description], data_type, [alignment])[0]
+
+    def profile(
+        self, d_shape, w_shape, out_shape, out_dtype, profile_all=True, use_multiprocessing=False
+    ):
+        """Profile and select the best kernel from candidate kernels.
+        If profile_all is False, return immediately after the first applicable kernel is found.
+        If use_multiprocessing is True, compile all profiler executables in parallel.
+        """
+        B, _, _, C = d_shape  # data unpacked as (batch, H, W, C)
+        K, R, S, _ = w_shape  # weight unpacked as (K, R, S, C)
+        _, P, Q, _ = out_shape  # output unpacked as (batch, P, Q, K)
+        # Fprop as implicit GEMM: GEMM M = batch*P*Q, GEMM N = K, GEMM K = R*S*C.
+        M = B * P * Q
+        N = K  # output channels; must be read before K is rebound below
+        K = R * S * C
+        # NOTE(review): assumes CutlassGemmProfiler.profile takes (M, N, K) in this order.
+        gemm_profile_result = self.gemm_profiler.profile(
+            M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
+        )
+
+        tile_description = gemm_profile_result["tile_description"]
+        alignment = gemm_profile_result["alignment"]
+        data_type = gemm_profile_result["data_type"]
+        return create_conv2d_operator([tile_description], data_type, [alignment])[0]
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py
index 4025354..cec64f0 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -125,6 +125,9 @@ def create_gemm_operator(
                     op.leading_dim(),
                 )
                 op_entry["runtime"] = 9999999
+                op_entry["tile_description"] = tile_description
+                op_entry["alignment"] = alignment
+                op_entry["data_type"] = data_type
                 ret.append(op_entry)
     return ret
 
diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py
index a3b90ff..902dc57 100644
--- a/python/tvm/contrib/cutlass/library.py
+++ b/python/tvm/contrib/cutlass/library.py
@@ -64,23 +64,27 @@ MathOperationTag = {
 class LayoutType(enum.Enum):
     ColumnMajor = enum_auto()
     RowMajor = enum_auto()
+    TensorNHWC = enum_auto()  # used by the conv2d generator for A, B and C
 
 
 LayoutTag = {
     LayoutType.ColumnMajor: "cutlass::layout::ColumnMajor",
     LayoutType.RowMajor: "cutlass::layout::RowMajor",
+    LayoutType.TensorNHWC: "cutlass::layout::TensorNHWC",
 }
 
 
 TransposedLayout = {
     LayoutType.ColumnMajor: LayoutType.RowMajor,
     LayoutType.RowMajor: LayoutType.ColumnMajor,
+    LayoutType.TensorNHWC: LayoutType.TensorNHWC,  # no transposed form; maps to itself
 }
 
 
 ShortLayoutTypeNames = {
     LayoutType.ColumnMajor: "n",
     LayoutType.RowMajor: "t",
+    LayoutType.TensorNHWC: "nhwc",  # appears in conv2d kernel procedural names
 }
 
 
@@ -105,11 +109,10 @@ OpcodeClassTag = {
 
 class OperationKind(enum.Enum):
     Gemm = enum_auto()
+    Conv2d = enum_auto()  # assigned by Conv2dOperation.__init__
 
 
-OperationKindNames = {
-    OperationKind.Gemm: "gemm",
-}
+OperationKindNames = {OperationKind.Gemm: "gemm", OperationKind.Conv2d: "conv2d"}
 
 
 class Target(enum.Enum):
@@ -172,6 +175,54 @@ SwizzlingFunctorTag = {
 }
 
 
+class ConvKind(enum.Enum):
+    Fprop = enum_auto()  # the only conv kind defined so far
+
+
+ConvKindTag = {
+    ConvKind.Fprop: "cutlass::conv::Operator::kFprop",
+}
+
+
+ConvKindNames = {
+    ConvKind.Fprop: "fprop",  # capitalized by the emitter to pick DefaultConv2dFprop
+}
+
+
+class StrideSupport(enum.Enum):
+    Strided = enum_auto()  # the variant create_conv2d_operator instantiates
+    Unity = enum_auto()  # adds a "_unity_stride" suffix to procedural names
+
+
+StrideSupportTag = {
+    StrideSupport.Strided: "cutlass::conv::StrideSupport::kStrided",
+    StrideSupport.Unity: "cutlass::conv::StrideSupport::kUnity",
+}
+
+
+StrideSupportNames = {  # NOTE(review): unused by Conv2dOperation.procedural_name
+    StrideSupport.Strided: "",
+    StrideSupport.Unity: "unity_stride",
+}
+
+
+class IteratorAlgorithm(enum.Enum):
+    Analytic = enum_auto()  # defined, but not generated by create_conv2d_operator
+    Optimized = enum_auto()  # the only algorithm create_conv2d_operator generates
+
+
+IteratorAlgorithmTag = {
+    IteratorAlgorithm.Analytic: "cutlass::conv::IteratorAlgorithm::kAnalytic",
+    IteratorAlgorithm.Optimized: "cutlass::conv::IteratorAlgorithm::kOptimized",
+}
+
+
+IteratorAlgorithmNames = {
+    IteratorAlgorithm.Analytic: "analytic",
+    IteratorAlgorithm.Optimized: "optimized",
+}
+
+
 class MathInstruction:
     """Describe characteristics of a math instruction."""
 
diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py
index 8ed3718..4ae529e 100644
--- a/python/tvm/relay/op/contrib/cutlass.py
+++ b/python/tvm/relay/op/contrib/cutlass.py
@@ -55,6 +55,11 @@ def make_batch_matmul_pattern():
     return is_op("nn.batch_matmul")(wildcard(), wildcard())
 
 
+def make_conv2d_pattern():  # only NHWC data / OHWI kernel conv2d is matched
+    conv2d = is_op("nn.conv2d")(wildcard(), wildcard())  # TODO(masahi): Check alignment
+    return conv2d.has_attr({"data_layout": "NHWC", "kernel_layout": "OHWI"})
+
+
 def partition_for_cutlass(mod):
     """Partition the input module into CUTLASS-supported subgraphs."""
     dense_pat = ("cutlass.dense", make_gemm_pattern(False, None))
@@ -72,6 +77,8 @@ def partition_for_cutlass(mod):
         dense_bias_pat,
         dense_pat,
         ("cutlass.batch_matmul", make_batch_matmul_pattern()),
+        # TODO(masahi): Add more conv2d patterns
+        ("cutlass.conv2d", make_conv2d_pattern()),
     ]
     mod = transform.MergeComposite(cutlass_patterns)(mod)
     mod = transform.AnnotateTarget(["cutlass"])(mod)
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 17f75a0..8357f28 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -1090,6 +1090,19 @@ def _conv_shape_func_nhwc_hwoi(dshape, kshape, strides, padding, dilation):
     return out
 
 
+@script
+def _conv_shape_func_nhwc_ohwi(dshape, kshape, strides, padding, dilation):
+    """Shape function for conv*d op with NHWC data & OHWI kernel (symmetric padding assumed)."""
+    out = output_tensor((dshape.shape[0],), "int64")
+    out[0] = dshape[0]  # batch
+    out[dshape.shape[0] - 1] = kshape[0]  # output channels = O dim of the OHWI kernel
+
+    for i in const_range(dshape.shape[0] - 2):  # one iteration per spatial dimension
+        dilated_k = (kshape[i + 1] - 1) * dilation[i] + 1
+        out[i + 1] = (dshape[i + 1] + 2 * padding[i] - dilated_k) // strides[i] + 1
+    return out
+
+
 def conv_shape_func(attrs, inputs, _):
     """Shape function for conv*d op."""
     strides = get_const_tuple(attrs.strides)
@@ -1103,6 +1116,8 @@ def conv_shape_func(attrs, inputs, _):
         shape_func = _conv_shape_func_nhwc_hwio
     elif attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWOI":
         shape_func = _conv_shape_func_nhwc_hwoi
+    elif attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "OHWI":
+        shape_func = _conv_shape_func_nhwc_ohwi
     else:
         raise ValueError(
             "Unsupported data/kernel layout: %s, %s"
diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h
index 964d7de..3c6f810 100644
--- a/src/relay/backend/contrib/codegen_c/codegen_c.h
+++ b/src/relay/backend/contrib/codegen_c/codegen_c.h
@@ -191,10 +191,18 @@ class CodegenCBase {
       PrintIndents();
     }
     for (size_t i = 0; i < outs.size() - 1; i++) {
-      code_stream_ << "(" << outs[i].dtype << "*)(out" << i << "->data),\n";
+      if (pass_dl_tensor) {
+        code_stream_ << "out" << i << ",\n";
+      } else {
+        code_stream_ << "(" << outs[i].dtype << "*)(out" << i << "->data),\n";
+      }
       PrintIndents();
     }
-    code_stream_ << "(" << outs.back().dtype << "*)(out" << outs.size() - 1 << "->data));\n";
+    if (pass_dl_tensor) {
+      code_stream_ << "out" << outs.size() - 1 << ");\n";
+    } else {
+      code_stream_ << "(" << outs.back().dtype << "*)(out" << outs.size() - 1 << "->data));\n";
+    }
     PrintIndents();
     code_stream_ << "return 0;\n";
     ExitScope();
diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc
index f154f86..c226da5 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -61,7 +61,7 @@ inline void CutlassPrint(std::ostringstream& os, const std::string& stmt, int in
   os << stmt;
 }
 
-Str2StrMap GemmArgsCommon(const Map<String, ObjectRef>& attrs) {
+Str2StrMap ArgsCommon(const Map<String, ObjectRef>& attrs) {
   Str2StrMap args;
   auto arg0_dtype = std::string(attrs["arg0_dtype"].as<StringObj>()->data);
   auto arg1_dtype = std::string(attrs["arg1_dtype"].as<StringObj>()->data);
@@ -72,6 +72,11 @@ Str2StrMap GemmArgsCommon(const Map<String, ObjectRef>& attrs) {
   args["op_def"] = std::string(attrs["cutlass_op_def"].as<StringObj>()->data);
   args["op_name"] = std::string(attrs["cutlass_op_name"].as<StringObj>()->data);
   args["op_type"] = std::string(attrs["op_type"].as<StringObj>()->data);
+  return args;
+}
+
+Str2StrMap GemmArgsCommon(const Map<String, ObjectRef>& attrs) {
+  Str2StrMap args = ArgsCommon(attrs);
   args["lda"] = std::string(attrs["lda"].as<StringObj>()->data);
   args["ldb"] = std::string(attrs["ldb"].as<StringObj>()->data);
   args["ldc"] = std::string(attrs["ldc"].as<StringObj>()->data);
@@ -110,7 +115,7 @@ void AppendPrologue(std::ostringstream& gemm_decl, const Str2StrMap& attrs,
   CutlassPrint(gemm_decl, "using ElementOutput = " + attrs.at("ElementOutput") + ";\n");
   CutlassPrint(gemm_decl, "using ElementComputeEpilogue = " + attrs.at("ElementOutput") + ";\n");
   CutlassPrint(gemm_decl, attrs.at("op_def"));
-  CutlassPrint(gemm_decl, "using Gemm = Operation_" + attrs.at("op_name") + ";\n");
+  CutlassPrint(gemm_decl, "using " + kernel + " = Operation_" + attrs.at("op_name") + ";\n");
 
   auto get_dim = [&attrs, &func_args](const std::string& axis, int arg_idx, int axis_idx) {
     if (attrs.at(axis) == kAnyDim) {
@@ -139,9 +144,8 @@ void AppendPrologue(std::ostringstream& gemm_decl, const Str2StrMap& attrs,
     CutlassPrint(gemm_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n");
   }
 
-  CutlassPrint(gemm_decl, "void* ptr_out = (void*)(out0);\n");
+  CutlassPrint(gemm_decl, "void* ptr_out = (void*)(out0->data);\n");
 
-  CutlassPrint(gemm_decl, "using " + kernel + " = Operation_" + attrs.at("op_name") + ";\n");
   CutlassPrint(gemm_decl, "typename " + kernel + "::Arguments arguments{\n");
   CutlassPrint(gemm_decl, " problem_size,\n");
 }
@@ -234,6 +238,112 @@ std::string BatchMatmulOp(std::string id, const Str2StrMap& attrs,
   return gemm_decl.str();
 }
 
+Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {  // conv2d shape/attr strings
+  Str2StrMap args = ArgsCommon(attrs);
+  auto arg0_shape = attrs["arg0_shape"].as<ArrayNode>();  // data, indexed as (N, H, W, C)
+  auto arg1_shape = attrs["arg1_shape"].as<ArrayNode>();  // weight, indexed as (K, R, S, C)
+  auto out_shape = attrs["ret_shape"].as<ArrayNode>();    // output, indexed as (N, P, Q, K)
+  args["N"] = GetDimAsStr(arg0_shape->at(0));
+  args["H"] = GetDimAsStr(arg0_shape->at(1));
+  args["W"] = GetDimAsStr(arg0_shape->at(2));
+  args["C"] = GetDimAsStr(arg0_shape->at(3));
+  args["K"] = GetDimAsStr(arg1_shape->at(0));
+  args["R"] = GetDimAsStr(arg1_shape->at(1));
+  args["S"] = GetDimAsStr(arg1_shape->at(2));  // kernel width: axis 2, distinct from R's axis 1
+  args["P"] = GetDimAsStr(out_shape->at(1));
+  args["Q"] = GetDimAsStr(out_shape->at(2));
+  args["pad_h"] = GetDimAsStr(attrs["padding"].as<ArrayNode>()->at(0));  // NOTE(review): only
+  args["pad_w"] = GetDimAsStr(attrs["padding"].as<ArrayNode>()->at(1));  // padding[0..1] used
+  args["stride_h"] = GetDimAsStr(attrs["strides"].as<ArrayNode>()->at(0));
+  args["stride_w"] = GetDimAsStr(attrs["strides"].as<ArrayNode>()->at(1));
+  args["dilation_h"] = GetDimAsStr(attrs["dilation"].as<ArrayNode>()->at(0));
+  args["dilation_w"] = GetDimAsStr(attrs["dilation"].as<ArrayNode>()->at(1));
+  return args;
+}
+
+std::string Conv2dOp(std::string id, const Str2StrMap& attrs,  // id is currently unused
+                     const std::vector<std::string>& func_args) {
+  ICHECK(func_args.size() >= 2);  // data/weight names are indexed below; check before any access
+  std::ostringstream conv2d_decl;
+  CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n");
+  CutlassPrint(conv2d_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n");
+  CutlassPrint(conv2d_decl, "using ElementOutput = " + attrs.at("ElementOutput") + ";\n");
+
+  CutlassPrint(conv2d_decl, "using ElementComputeEpilogue = " + attrs.at("ElementOutput") + ";\n");
+  CutlassPrint(conv2d_decl, attrs.at("op_def"));
+  CutlassPrint(conv2d_decl, "using Operation_" + attrs.at("op_name") +
+                                " = cutlass::conv::device::ImplicitGemmConvolution<" +
+                                attrs.at("op_name") + ">;\n");
+  CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + attrs.at("op_name") + ";\n");
+
+  auto get_dim = [&attrs](const std::string& axis, const std::string& var_name, int axis_idx) {
+    if (attrs.at(axis) == kAnyDim) {  // dynamic axis: read the DLTensor shape at run time
+      return var_name + "->shape[" + std::to_string(axis_idx) + "]";
+    } else {
+      return attrs.at(axis);
+    }
+  };
+
+  CutlassPrint(conv2d_decl, "int N = " + get_dim("N", func_args[0], 0) + ";\n");
+  CutlassPrint(conv2d_decl, "int H = " + get_dim("H", func_args[0], 1) + ";\n");
+  CutlassPrint(conv2d_decl, "int W = " + get_dim("W", func_args[0], 2) + ";\n");
+  CutlassPrint(conv2d_decl, "int C = " + attrs.at("C") + ";\n");  // C, K, R, S: static only
+  CutlassPrint(conv2d_decl, "int K = " + attrs.at("K") + ";\n");
+  CutlassPrint(conv2d_decl, "int R = " + attrs.at("R") + ";\n");
+  CutlassPrint(conv2d_decl, "int S = " + attrs.at("S") + ";\n");
+  CutlassPrint(conv2d_decl, "int P = " + get_dim("P", "out0", 1) + ";\n");
+  CutlassPrint(conv2d_decl, "int Q = " + get_dim("Q", "out0", 2) + ";\n");
+  CutlassPrint(conv2d_decl, "int pad_h = " + attrs.at("pad_h") + ";\n");
+  CutlassPrint(conv2d_decl, "int pad_w = " + attrs.at("pad_w") + ";\n");
+  CutlassPrint(conv2d_decl, "int stride_h = " + attrs.at("stride_h") + ";\n");
+  CutlassPrint(conv2d_decl, "int stride_w = " + attrs.at("stride_w") + ";\n");
+  CutlassPrint(conv2d_decl, "int dilation_h = " + attrs.at("dilation_h") + ";\n");
+  CutlassPrint(conv2d_decl, "int dilation_w = " + attrs.at("dilation_w") + ";\n");
+
+  CutlassPrint(
+      conv2d_decl,
+      "cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, Q, pad_h, pad_w, "
+      "stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, 1);\n");
+
+  CutlassPrint(conv2d_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n");
+  CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n");
+  CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n");
+  CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n");
+  CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n");
+
+  CutlassPrint(conv2d_decl, "using cutlass::layout::TensorNHWC;\n");
+  CutlassPrint(conv2d_decl,
+               "TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(N, H, W, C)));\n");
+  CutlassPrint(conv2d_decl,
+               "TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(K, R, S, C)));\n");
+  CutlassPrint(conv2d_decl,
+               "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n");
+  CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n");
+  CutlassPrint(conv2d_decl, " problem_size,\n");
+  CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
+  CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");
+  CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");  // C, aliased
+  CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");  // D (output)
+  CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");  // with beta = 0, aliased C should not affect D
+  CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");
+
+  CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n");
+  // Allocate device workspace for this call
+  CutlassPrint(conv2d_decl,
+               "cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);\n");
+  // Check whether this problem size is supported by the selected kernel
+  CutlassPrint(conv2d_decl, "cutlass::Status status = conv2d_op.can_implement(arguments);\n");
+  CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+  // Initialize the CUTLASS kernel with the arguments and workspace pointer
+  CutlassPrint(conv2d_decl, "status = conv2d_op.initialize(arguments, workspace.get());\n");
+  CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+  // Launch the initialized CUTLASS kernel
+  CutlassPrint(conv2d_decl, "status = conv2d_op();\n");
+  CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+
+  return conv2d_decl.str();
+}
+
 class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
  public:
   CodegenCutlass(const std::string& id, const Map<String, ObjectRef>& attrs) {
@@ -268,9 +378,9 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
       code_stream_ << "DLTensor* " << arg->name_hint() << ", ";
     }
     for (size_t i = 0; i < out.size() - 1; ++i) {
-      code_stream_ << out[i].dtype << "* out" << i << ", ";
+      code_stream_ << "DLTensor* out" << i << ", ";
     }
-    code_stream_ << out.back().dtype << "* out" << out.size() - 1 << ") {\n";
+    code_stream_ << "DLTensor* out" << out.size() - 1 << ") {\n";
     this->EnterScope();
 
     // Function body
@@ -347,7 +457,12 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
           GetRootCall(callee->body.as<CallNode>(), 0, {"nn.batch_matmul"});
       return GenerateBody(batch_matmul_call, "cutlass_batch_matmul", GetArgumentNames(caller),
                           BatchMatmulArgs(std::ref(attrs_)));
+    } else if (pattern_name == "cutlass.conv2d") {
+      const auto* conv2d_call = GetRootCall(callee->body.as<CallNode>(), 0, {"nn.conv2d"});
+      return GenerateBody(conv2d_call, "cutlass_conv2d", GetArgumentNames(caller),
+                          Conv2dArgs(std::ref(attrs_)));
     }
+
     LOG(FATAL) << "Unknown composite function: " << pattern_name;
     return {};
   }
@@ -392,7 +507,10 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
       ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
     } else if (func_name == "cutlass_batch_matmul") {
       ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
+    } else if (func_name == "cutlass_conv2d") {
+      ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
     }
+
     return ret;
   }
   /*! \brief The id of the external cutlass ext_func. */
@@ -441,10 +559,12 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase {
     // cutlass header
     code_stream_ << "#include <cuda_fp16.h>\n";
     code_stream_ << "#include <cutlass/cutlass.h>\n";
+    code_stream_ << "#include <cutlass/coord.h>\n";
     code_stream_ << "#include <cutlass/util/host_tensor.h>\n";
-    code_stream_ << "#include <cutlass/util/reference/host/tensor_fill.h>\n";
     code_stream_ << "#include <cutlass/gemm/device/gemm.h>\n";
     code_stream_ << "#include <cutlass/gemm/device/gemm_batched.h>\n";
+    code_stream_ << "#include <cutlass/conv/kernel/default_conv2d_fprop.h>\n";
+    code_stream_ << "#include <cutlass/conv/device/implicit_gemm_convolution.h>\n";
     code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_bias_relu.h>\n";
     code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_gelu.h>\n";
 
diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py
index 6f27d57..a258da3 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -110,6 +110,22 @@ def get_batch_matmul(batch, M, N, K, out_dtype="float16"):
     return get_batch_matmul_with_shape((batch, M, K), (batch, N, K), out_dtype="float16")
 
 
+def get_conv2d_nchw(d_shape, w_shape):  # fp16 NCHW data / OIHW weight conv2d, padding (1, 1)
+    data = relay.var("data", shape=d_shape, dtype="float16")
+    weight = relay.var("weight", shape=w_shape, dtype="float16")
+    out_channel, kernel_size = w_shape[0], tuple(w_shape[2:4])  # OIHW: O and (kH, kW)
+    return tvm.IRModule.from_expr(
+        relay.nn.conv2d(
+            data=data,
+            weight=weight,
+            kernel_size=kernel_size,
+            channels=out_channel,
+            padding=(1, 1),
+            out_dtype="float16",
+        )
+    )
+
+
 def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"):
     mod = partition_for_cutlass(mod)
     mod, num_cutlass_partition = tune_cutlass_kernels(
@@ -289,5 +305,85 @@ def test_batch_matmul():
         )
 
 
+def convert_conv2d_layout(mod, desired_layouts):  # apply relay ConvertLayout to mod
+    with tvm.transform.PassContext(opt_level=3):
+        seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
+        return seq(mod)  # the layout-converted module
+
+
+def verify_conv2d(  # build with CUTLASS, compare against TVM's own CUDA conv2d
+    mod_nchw,
+    mod_ref,  # reference module; run via get_output, so presumably static-shaped
+    d_shape,  # concrete input shape, also used when mod_nchw has a dynamic batch
+    w_shape,
+    sm=80,
+    atol=1e-5,
+    rtol=1e-5,
+    run_benchmark=False,
+):
+    if not has_cutlass():  # returns without checking anything when CUTLASS is unavailable
+        return
+
+    np_data = np.random.uniform(-1, 1, d_shape).astype("float16")
+    np_weight = np.random.uniform(-1, 1, w_shape).astype("float16")
+
+    params = {"weight": np_weight}  # weight bound as a param; only "data" is fed at run time
+
+    typ = relay.transform.InferType()(mod_nchw)["main"].body.checked_type
+    use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape)  # dynamic output dims -> VM
+
+    if use_vm:
+        rt_mod, dev, num_cutlass_partition = profile_and_build_vm(
+            convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}), params, sm
+        )
+        out = get_output_vm(rt_mod, ["data"], [np_data])
+    else:
+        rt_mod, dev, num_cutlass_partition = profile_and_build(
+            convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}),  # CUTLASS layout
+            params,
+            sm,
+        )
+        out = get_output(rt_mod, ["data"], [np_data])
+
+    assert num_cutlass_partition > 0  # at least one op must have been offloaded to CUTLASS
+
+    rt_mod_ref, _ = get_ref_rt_mod(
+        convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}),  # reference layout
+        params,
+        target="cuda",
+    )
+    ref_out = get_output(rt_mod_ref, ["data"], [np_data])
+
+    np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
+
+    if run_benchmark:
+        print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600))
+        print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600))
+
+
+def test_conv2d():  # static-batch case, then the same conv2d with a dynamic (Any) batch
+    d_shape = (16, 16, 32, 32)  # NCHW
+    w_shape = (32, 16, 3, 3)  # OIHW
+    mod_nchw = get_conv2d_nchw(d_shape, w_shape)
+
+    verify_conv2d(
+        mod_nchw,
+        mod_nchw,
+        d_shape,
+        w_shape,
+        sm=80,
+        atol=1e-5,
+        rtol=1e-5,
+        run_benchmark=False,
+    )
+
+    dyn_batch_shape = (relay.Any(),) + d_shape[1:]  # batch dimension left dynamic
+    mod_dyn = get_conv2d_nchw(dyn_batch_shape, w_shape)
+
+    verify_conv2d(  # the static module serves as the reference for the dynamic one
+        mod_dyn, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])