Posted to commits@tvm.apache.org by co...@apache.org on 2021/12/02 01:35:08 UTC
[tvm] branch main updated: [CUTLASS] Initial conv2d support (#9595)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dc988b2 [CUTLASS] Initial conv2d support (#9595)
dc988b2 is described below
commit dc988b288d85660822e0fbadbf1fc74326e763e5
Author: masahi <ma...@gmail.com>
AuthorDate: Thu Dec 2 10:34:46 2021 +0900
[CUTLASS] Initial conv2d support (#9595)
* Add initial conv generator
* added conv2d pattern
* profile by gemm profiler
* remove conv2d profiler for now
* remove unused code
* add default
* minor fix, profiling working
* start codegen
* generated code compiled
* fixed layout initialization
* matched with autotvm tensorcore result
* test refactor
* minor cleanup
* remove iteration algo "Analytic"
* add test for dynamic batch conv2d
* pass dl tensor as output too
* support conv2d dynamic shape in codegen
* test working
* lint
* simplify codegen
* fix weird formatting
* typo fix
* check if cutlass is enabled in the test
* simplify gen_conv2d.py
---
python/tvm/contrib/cutlass/build.py | 90 +++++++--
python/tvm/contrib/cutlass/conv2d_operation.py | 240 ++++++++++++++++++++++++
python/tvm/contrib/cutlass/gen_conv2d.py | 147 +++++++++++++++
python/tvm/contrib/cutlass/gen_gemm.py | 3 +
python/tvm/contrib/cutlass/library.py | 57 +++++-
python/tvm/relay/op/contrib/cutlass.py | 7 +
python/tvm/relay/op/nn/_nn.py | 15 ++
src/relay/backend/contrib/codegen_c/codegen_c.h | 12 +-
src/relay/backend/contrib/cutlass/codegen.cc | 134 ++++++++++++-
tests/python/contrib/test_cutlass.py | 96 ++++++++++
10 files changed, 776 insertions(+), 25 deletions(-)
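For orientation, the new path can be exercised end to end roughly as follows. This is a minimal sketch distilled from tests/python/contrib/test_cutlass.py in this commit; the shapes and sm=80 are illustrative, a CUTLASS checkout and a tensor-core GPU are assumed, and the imports assume tvm.contrib.cutlass re-exports tune_cutlass_kernels the way the test file uses it:

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.cutlass import partition_for_cutlass
    from tvm.contrib.cutlass import tune_cutlass_kernels

    # An fp16 NCHW conv2d, converted to the NHWC/OHWI layout the CUTLASS path expects.
    data = relay.var("data", shape=(16, 16, 32, 32), dtype="float16")
    weight = relay.var("weight", shape=(32, 16, 3, 3), dtype="float16")
    conv = relay.nn.conv2d(
        data, weight, kernel_size=(3, 3), channels=32, padding=(1, 1), out_dtype="float16"
    )
    mod = tvm.IRModule.from_expr(conv)
    with tvm.transform.PassContext(opt_level=3):
        mod = relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"]})(mod)

    # Offload the matched conv2d pattern to CUTLASS and select a kernel by profiling.
    mod = partition_for_cutlass(mod)
    mod, num_cutlass_partition = tune_cutlass_kernels(mod, sm=80, tmp_dir="./tmp")
    assert num_cutlass_partition > 0

The profile_and_build helper in the test file below wraps these steps together with compilation of the generated CUTLASS code.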
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 615b900..c3a8fdc 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -23,6 +23,7 @@ import tvm
from tvm import runtime, relay
from tvm.contrib.nvcc import find_cuda_path, get_cuda_version
from .gen_gemm import CutlassGemmProfiler
+from .gen_conv2d import CutlassConv2DProfiler
logger = logging.getLogger("cutlass")
@@ -65,7 +66,7 @@ def _get_cutlass_compile_options(sm, threads):
return kwargs
-class GemmAnnotator(tvm.relay.ExprVisitor):
+class OpAnnotator(tvm.relay.ExprVisitor):
"""Annotates partitioned functions with shape and dtype information."""
def __init__(self):
@@ -81,6 +82,10 @@ class GemmAnnotator(tvm.relay.ExprVisitor):
self.signature["arg%d_dtype" % i] = arg.checked_type.dtype
self.signature["ret_shape"] = op.ret_type.shape
self.signature["ret_dtype"] = op.ret_type.dtype
+ self.visit(op.body)
+
+ if str(op) == "nn.conv2d":
+ self.op_attrs = call.attrs
def select_gemm_kernel(
@@ -125,6 +130,8 @@ def handle_batch_matmul(
else:
raise ValueError("%s pattern is not implemented." % op_type)
+ assert "tn_align" in out["name"], "Only supports (row_major, col_major) input layout for now."
+
return {
"batch": arg0_shape[0],
"batch_stride_A": arg0_shape[1] * arg0_shape[2],
@@ -132,6 +139,9 @@ def handle_batch_matmul(
"batch_stride_C": arg0_shape[1] * arg1_shape[1],
"cutlass_op_def": cutlass_op_def,
"cutlass_op_name": out["name"],
+ "lda": "K",
+ "ldb": "K",
+ "ldc": "N",
}
@@ -158,6 +168,50 @@ def handle_dense(
else:
raise ValueError("%s pattern is not implemented." % op_type)
+ assert "tn_align" in out["name"], "Only supports (row_major, col_major) input layout for now."
+
+ return {
+ "cutlass_op_def": cutlass_op_def,
+ "cutlass_op_name": out["name"],
+ "lda": "K",
+ "ldb": "K",
+ "ldc": "N",
+ }
+
+
+def handle_conv2d(
+ cutlass_profiler,
+ op_type,
+ d_shape,
+ w_shape,
+ out_shape,
+ out_dtype,
+ profile_all,
+ use_multiprocessing,
+):
+ """Profile and select a kernel for conv2d op workload."""
+ if any(isinstance(s, tvm.tir.Any) for s in d_shape):
+ out = cutlass_profiler.get_default(out_dtype)
+ logger.info("Picked the default kernel %s", out["name"])
+ else:
+ out = cutlass_profiler.profile(
+ d_shape,
+ w_shape,
+ out_shape,
+ out_dtype,
+ profile_all=profile_all,
+ use_multiprocessing=use_multiprocessing,
+ )
+ if profile_all:
+ logger.info("The best kernel is %s", out["name"])
+ else:
+ logger.info("Picked the first kernel found %s", out["name"])
+
+ if op_type == "cutlass.conv2d":
+ cutlass_op_def = out["opdef"]
+ else:
+ raise ValueError("%s pattern is not implemented." % op_type)
+
return {
"cutlass_op_def": cutlass_op_def,
"cutlass_op_name": out["name"],
@@ -195,12 +249,13 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
num_cutlass_partition : int
The number of partitioned functions created for CUTLASS.
"""
- cutlass_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir)
+ gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir)
+ conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), tmp_dir)
num_cutlass_partition = 0
for var in mod.get_global_vars():
fun_name = var.name_hint
func = mod[fun_name]
- annotator = GemmAnnotator()
+ annotator = OpAnnotator()
if "cutlass" in fun_name:
num_cutlass_partition += 1
annotator.visit(func)
@@ -213,10 +268,26 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
arg0_shape = new_attrs["arg0_shape"]
arg1_shape = new_attrs["arg1_shape"]
- if "batch_matmul" in op_type:
+ if "conv2d" in op_type:
+ new_attrs["padding"] = annotator.op_attrs.padding
+ new_attrs["strides"] = annotator.op_attrs.strides
+ new_attrs["dilation"] = annotator.op_attrs.dilation
+ new_attrs.update(
+ handle_conv2d(
+ conv2d_profiler,
+ op_type,
+ arg0_shape,
+ arg1_shape,
+ annotator.signature["ret_shape"],
+ out_dtype,
+ profile_all,
+ use_multiprocessing,
+ )
+ )
+ elif "batch_matmul" in op_type:
new_attrs.update(
handle_batch_matmul(
- cutlass_profiler,
+ gemm_profiler,
op_type,
arg0_shape,
arg1_shape,
@@ -228,7 +299,7 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
elif "dense" in op_type:
new_attrs.update(
handle_dense(
- cutlass_profiler,
+ gemm_profiler,
op_type,
arg0_shape,
arg1_shape,
@@ -240,13 +311,6 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
else:
raise ValueError("%s unsupported composite" % op_type)
- if new_attrs["cutlass_op_name"].find("_tn_align") > 0:
- new_attrs["lda"] = "K"
- new_attrs["ldb"] = "K"
- new_attrs["ldc"] = "N"
- else:
- raise ValueError("%s unsupported operation" % new_attrs["cutlass_op_name"])
-
new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs)
new_func = relay.Function(
func.params,
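For a cutlass.conv2d partition, the DictAttrs handed to the C++ codegen end up carrying roughly the keys below. This is an illustrative sketch, not captured output: the values correspond to the 3x3 test case, the dtype entries recorded by OpAnnotator are omitted, and the cutlass_op_name/cutlass_op_def values are hypothetical examples of what the profiler returns (op_def is the full template instantiation emitted by EmitConv2dInstance in the new file below):

    new_attrs = {
        "op_type": "cutlass.conv2d",
        "arg0_shape": [16, 32, 32, 16],  # NHWC data
        "arg1_shape": [32, 3, 3, 16],    # OHWI weight
        "ret_shape": [16, 32, 32, 32],   # NHWC output
        "padding": [1, 1],
        "strides": [1, 1],
        "dilation": [1, 1],
        "cutlass_op_name": "cutlass_tensorop_h16816fprop_optimized_128x128_32x2_nhwc_align8",
        "cutlass_op_def": "using cutlass_tensorop_h16816fprop_... = typename "
        "cutlass::conv::kernel::DefaultConv2dFprop<...>::Kernel;",
    }

Note that, unlike the gemm path, no lda/ldb/ldc entries are needed, which is why the blanket tn_align check removed above no longer applies; the conv2d strides are derived from the NHWC layouts in the generated code.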
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py
new file mode 100644
index 0000000..8a886ff
--- /dev/null
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -0,0 +1,240 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import
+"""Generator for CUTLASS Conv2D kernels."""
+from .library import *
+
+
+class Conv2dOperation:
+ """Describes various attributes for instantiating Conv2d kernels."""
+
+ def __init__(
+ self,
+ conv_kind,
+ iterator_algorithm,
+ arch,
+ tile_description,
+ A,
+ B,
+ C,
+ element_epilogue,
+ stride_support,
+ epilogue_functor=EpilogueFunctor.LinearCombination,
+ swizzling_functor=SwizzlingFunctor.Identity1,
+ ):
+ self.operation_kind = OperationKind.Conv2d
+ self.arch = arch
+ self.tile_description = tile_description
+ self.conv_kind = conv_kind
+ self.A = A
+ self.B = B
+ self.C = C
+ self.element_epilogue = element_epilogue
+ self.epilogue_functor = epilogue_functor
+ self.iterator_algorithm = iterator_algorithm
+ self.stride_support = stride_support
+ self.swizzling_functor = swizzling_functor
+
+ def accumulator_type(self):
+ return self.tile_description.math_instruction.element_accumulator
+
+ def core_name(self):
+ """ The basic operation kind is prefixed with a letter indicating the accumulation type. """
+ intermediate_type = ""
+
+ if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
+ inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
+ if (
+ self.tile_description.math_instruction.element_a != self.A.element
+ and self.tile_description.math_instruction.element_a != self.accumulator_type()
+ ):
+ intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
+ else:
+ inst_shape = ""
+
+ return "%s%s%s%s_%s" % (
+ ShortDataTypeNames[self.accumulator_type()],
+ inst_shape,
+ intermediate_type,
+ ConvKindNames[self.conv_kind],
+ IteratorAlgorithmNames[self.iterator_algorithm],
+ )
+
+ def extended_name(self):
+ """ Append data types if they differ from compute type. """
+ if (
+ self.C.element != self.tile_description.math_instruction.element_accumulator
+ and self.A.element != self.tile_description.math_instruction.element_accumulator
+ ):
+ extended_name = "${element_c}_${core_name}_${element_a}"
+ elif (
+ self.C.element == self.tile_description.math_instruction.element_accumulator
+ and self.A.element != self.tile_description.math_instruction.element_accumulator
+ ):
+ extended_name = "${core_name}_${element_a}"
+ else:
+ extended_name = "${core_name}"
+
+ extended_name = substitute_template(
+ extended_name,
+ {
+ "element_a": DataTypeNames[self.A.element],
+ "element_c": DataTypeNames[self.C.element],
+ "core_name": self.core_name(),
+ },
+ )
+
+ return extended_name
+
+ def layout_name(self):
+ return "%s" % (ShortLayoutTypeNames[self.A.layout])
+
+ def procedural_name(self):
+ """
+ The full procedural name indicates architecture, extended name, tile size, and layout.
+ """
+ opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
+
+ threadblock = "%dx%d_%dx%d" % (
+ self.tile_description.threadblock_shape[0],
+ self.tile_description.threadblock_shape[1],
+ self.tile_description.threadblock_shape[2],
+ self.tile_description.stages,
+ )
+
+ if self.stride_support == StrideSupport.Unity:
+ configuration_name = (
+ "cutlass_${opcode_class}_${extended_name}_${threadblock}"
+ "_${layout}_align${alignment}_unity_stride"
+ )
+ else:
+ configuration_name = (
+ "cutlass_${opcode_class}_${extended_name}_${threadblock}"
+ "_${layout}_align${alignment}"
+ )
+
+ return substitute_template(
+ configuration_name,
+ {
+ "opcode_class": opcode_class_name,
+ "extended_name": self.extended_name(),
+ "threadblock": threadblock,
+ "layout": self.layout_name(),
+ "alignment": "%d" % self.A.alignment,
+ },
+ )
+
+
+class EmitConv2dInstance:
+ """ Responsible for emitting a CUTLASS template definition."""
+
+ def __init__(self):
+ self.template = """
+ // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
+ using ${operation_name} =
+ typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
+ ${element_a},
+ ${layout_a},
+ ${element_b},
+ ${layout_b},
+ ${element_c},
+ ${layout_c},
+ ${element_accumulator},
+ ${opcode_class},
+ ${arch},
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
+ ${epilogue_functor}<
+ ${element_c},
+ ${epilogue_vector_length},
+ ${element_accumulator},
+ ${element_epilogue}
+ >,
+ ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
+ ${stages},
+ ${math_operator},
+ ${iterator_algorithm},
+ ${stride_support},
+ ${align_a},
+ ${align_b}
+ >::Kernel;
+"""
+
+ def emit(self, operation):
+ """Instantiate a Conv2d kernel from given `operation`."""
+ warp_shape = [
+ int(
+ operation.tile_description.threadblock_shape[idx]
+ / operation.tile_description.warp_count[idx]
+ )
+ for idx in range(3)
+ ]
+
+ epilogue_vector_length = int(
+ min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
+ / DataTypeSize[operation.C.element]
+ )
+
+ values = {
+ "operation_name": operation.procedural_name(),
+ "conv_kind": ConvKindTag[operation.conv_kind],
+ "conv_kind_name": ConvKindNames[operation.conv_kind].capitalize(),
+ "element_a": DataTypeTag[operation.A.element],
+ "layout_a": LayoutTag[operation.A.layout],
+ "element_b": DataTypeTag[operation.B.element],
+ "layout_b": LayoutTag[operation.B.layout],
+ "element_c": DataTypeTag[operation.C.element],
+ "layout_c": LayoutTag[operation.C.layout],
+ "element_accumulator": DataTypeTag[operation.accumulator_type()],
+ "opcode_class": OpcodeClassTag[
+ operation.tile_description.math_instruction.opcode_class
+ ],
+ "arch": "cutlass::arch::Sm%d" % operation.arch,
+ "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
+ "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
+ "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
+ "warp_shape_m": str(warp_shape[0]),
+ "warp_shape_n": str(warp_shape[1]),
+ "warp_shape_k": str(warp_shape[2]),
+ "instruction_shape_m": str(
+ operation.tile_description.math_instruction.instruction_shape[0]
+ ),
+ "instruction_shape_n": str(
+ operation.tile_description.math_instruction.instruction_shape[1]
+ ),
+ "instruction_shape_k": str(
+ operation.tile_description.math_instruction.instruction_shape[2]
+ ),
+ "epilogue_vector_length": str(epilogue_vector_length),
+ "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
+ "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
+ "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
+ "stages": str(operation.tile_description.stages),
+ "iterator_algorithm": IteratorAlgorithmTag[operation.iterator_algorithm],
+ "iterator_algorithm_name": IteratorAlgorithmNames[
+ operation.iterator_algorithm
+ ].capitalize(),
+ "stride_support": StrideSupportTag[operation.stride_support],
+ "math_operator": MathOperationTag[
+ operation.tile_description.math_instruction.math_operation
+ ],
+ "align_a": str(operation.A.alignment),
+ "align_b": str(operation.B.alignment),
+ }
+
+ return substitute_template(self.template, values)
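The emitter can be exercised on its own. A minimal sketch, assuming library.py keeps the constructor signatures of the CUTLASS generator scripts it mirrors, i.e. MathInstruction(instruction_shape, element_a, element_b, element_accumulator, opcode_class, math_operation) and TileDescription(threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute):

    from tvm.contrib.cutlass.conv2d_operation import Conv2dOperation, EmitConv2dInstance
    from tvm.contrib.cutlass.library import (
        ConvKind, DataType, IteratorAlgorithm, LayoutType, MathInstruction,
        MathOperation, OpcodeClass, StrideSupport, TensorDescription, TileDescription,
    )

    math_inst = MathInstruction(
        [16, 8, 16], DataType.f16, DataType.f16, DataType.f16,
        OpcodeClass.TensorOp, MathOperation.multiply_add,
    )
    tile = TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, 80, 1024)
    A = TensorDescription(DataType.f16, LayoutType.TensorNHWC, 8)
    B = TensorDescription(DataType.f16, LayoutType.TensorNHWC, 8)
    C = TensorDescription(DataType.f16, LayoutType.TensorNHWC, 8)

    op = Conv2dOperation(
        ConvKind.Fprop, IteratorAlgorithm.Optimized, 80, tile,
        A, B, C, DataType.f16, StrideSupport.Strided,
    )
    # "cutlass_tensorop_h16816fprop_optimized_128x128_32x2_nhwc_align8"
    print(op.procedural_name())
    # Full DefaultConv2dFprop<...>::Kernel template instantiation, ready to compile.
    print(EmitConv2dInstance().emit(op))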
diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py
new file mode 100644
index 0000000..d24e988
--- /dev/null
+++ b/python/tvm/contrib/cutlass/gen_conv2d.py
@@ -0,0 +1,147 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Conv2d kernel generator and profiler for CUTLASS."""
+from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
+from .gen_gemm import CutlassGemmProfiler
+from .library import (
+ EpilogueFunctor,
+ SwizzlingFunctor,
+ TensorDescription,
+ LayoutType,
+ ConvKind,
+ StrideSupport,
+ IteratorAlgorithm,
+)
+
+
+def create_conv2d_operator(
+ tile_descriptions,
+ data_type,
+ alignment_constraints,
+ swizzling_functor=SwizzlingFunctor.Identity4,
+):
+ """Exhaustively instantiate all kernels from a given configuration."""
+ ret = []
+
+ kernel_emitter = EmitConv2dInstance()
+
+ element_a, element_b, element_c, element_epilogue = data_type
+ iterator_algorithms = [IteratorAlgorithm.Optimized]
+
+ layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
+ for tile in tile_descriptions:
+ for alignment in alignment_constraints:
+
+ alignment_c = min(8, alignment)
+
+ A = TensorDescription(element_a, layout[0], alignment)
+ B = TensorDescription(element_b, layout[1], alignment)
+ C = TensorDescription(element_c, layout[2], alignment_c)
+
+ swizzling_functor_ = swizzling_functor
+
+ for iterator_algorithm in iterator_algorithms:
+ op_entry = {}
+
+ op = Conv2dOperation(
+ ConvKind.Fprop,
+ iterator_algorithm,
+ tile.minimum_compute_capability,
+ tile,
+ A,
+ B,
+ C,
+ element_epilogue,
+ StrideSupport.Strided,
+ EpilogueFunctor.LinearCombination,
+ swizzling_functor_,
+ )
+
+ # TODO(masahi): Add profiler source here
+ op_entry["opdef"] = kernel_emitter.emit(op)
+ op_entry["op"] = op
+ op_entry["name"] = op.procedural_name()
+ op_entry["runtime"] = 9999999
+
+ # fused ops
+ for epilogue, opdef in zip(
+ [
+ EpilogueFunctor.LinearCombinationBias,
+ EpilogueFunctor.LinearCombinationRelu,
+ ],
+ ["opdef_bias", "opdef_bias_relu"],
+ ):
+ op = Conv2dOperation(
+ ConvKind.Fprop,
+ iterator_algorithm,
+ tile.minimum_compute_capability,
+ tile,
+ A,
+ B,
+ C,
+ element_epilogue,
+ StrideSupport.Strided,
+ epilogue,
+ swizzling_functor_,
+ )
+
+ op_entry[opdef] = kernel_emitter.emit(op)
+
+ ret.append(op_entry)
+
+ return ret
+
+
+class CutlassConv2DProfiler:
+ """Profile all candidate kernels and select the best one."""
+
+ def __init__(self, sm, cutlass_path, binary_path):
+ self.gemm_profiler = CutlassGemmProfiler(sm, cutlass_path, binary_path)
+ self.sm = sm
+
+ def get_default(self, out_dtype):
+ gemm_profile_result = self.gemm_profiler.get_default(out_dtype)
+ tile_description = gemm_profile_result["tile_description"]
+ alignment = gemm_profile_result["alignment"]
+ data_type = gemm_profile_result["data_type"]
+ return create_conv2d_operator([tile_description], data_type, [alignment])[0]
+
+ def profile(
+ self, d_shape, w_shape, out_shape, out_dtype, profile_all=True, use_multiprocessing=False
+ ):
+ """Profile and select the best kernel from candidate kernels.
+ If profile_all is False, return immediately after the first applicable kernel is found.
+ If use_multiprocessing is True, compile all profiler executables in parallel.
+ """
+ B, H, W, C = d_shape
+ K, R, S, _ = w_shape
+ _, P, Q, _ = out_shape
+
+ M = B * H * W
+ K = R * S * C
+ N = B * P * Q
+
+ gemm_profile_result = self.gemm_profiler.profile(
+ M, K, N, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
+ )
+
+ tile_description = gemm_profile_result["tile_description"]
+ alignment = gemm_profile_result["alignment"]
+ data_type = gemm_profile_result["data_type"]
+
+ return create_conv2d_operator([tile_description], data_type, [alignment])[0]
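The profiler itself can be driven directly; a sketch with placeholder paths (binary_path is the cache directory for compiled profiler executables). Internally the conv2d workload is mapped onto a GEMM-shaped problem so the existing GEMM profiler picks the tile configuration ("profile by gemm profiler" in the commit log):

    from tvm.contrib.cutlass.gen_conv2d import CutlassConv2DProfiler

    profiler = CutlassConv2DProfiler(sm=80, cutlass_path="/path/to/cutlass", binary_path="./tmp")

    out = profiler.profile(
        d_shape=(16, 32, 32, 16),    # NHWC data
        w_shape=(32, 3, 3, 16),      # OHWI weight
        out_shape=(16, 32, 32, 32),  # NHWC output
        out_dtype="float16",
        profile_all=False,           # stop at the first applicable kernel
    )
    print(out["name"])    # procedural name of the selected kernel
    print(out["opdef"])   # CUTLASS template definition consumed by codegen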
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py
index 4025354..cec64f0 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -125,6 +125,9 @@ def create_gemm_operator(
op.leading_dim(),
)
op_entry["runtime"] = 9999999
+ op_entry["tile_description"] = tile_description
+ op_entry["alignment"] = alignment
+ op_entry["data_type"] = data_type
ret.append(op_entry)
return ret
diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py
index a3b90ff..902dc57 100644
--- a/python/tvm/contrib/cutlass/library.py
+++ b/python/tvm/contrib/cutlass/library.py
@@ -64,23 +64,27 @@ MathOperationTag = {
class LayoutType(enum.Enum):
ColumnMajor = enum_auto()
RowMajor = enum_auto()
+ TensorNHWC = enum_auto()
LayoutTag = {
LayoutType.ColumnMajor: "cutlass::layout::ColumnMajor",
LayoutType.RowMajor: "cutlass::layout::RowMajor",
+ LayoutType.TensorNHWC: "cutlass::layout::TensorNHWC",
}
TransposedLayout = {
LayoutType.ColumnMajor: LayoutType.RowMajor,
LayoutType.RowMajor: LayoutType.ColumnMajor,
+ LayoutType.TensorNHWC: LayoutType.TensorNHWC,
}
ShortLayoutTypeNames = {
LayoutType.ColumnMajor: "n",
LayoutType.RowMajor: "t",
+ LayoutType.TensorNHWC: "nhwc",
}
@@ -105,11 +109,10 @@ OpcodeClassTag = {
class OperationKind(enum.Enum):
Gemm = enum_auto()
+ Conv2d = enum_auto()
-OperationKindNames = {
- OperationKind.Gemm: "gemm",
-}
+OperationKindNames = {OperationKind.Gemm: "gemm", OperationKind.Conv2d: "conv2d"}
class Target(enum.Enum):
@@ -172,6 +175,54 @@ SwizzlingFunctorTag = {
}
+class ConvKind(enum.Enum):
+ Fprop = enum_auto()
+
+
+ConvKindTag = {
+ ConvKind.Fprop: "cutlass::conv::Operator::kFprop",
+}
+
+
+ConvKindNames = {
+ ConvKind.Fprop: "fprop",
+}
+
+
+class StrideSupport(enum.Enum):
+ Strided = enum_auto()
+ Unity = enum_auto()
+
+
+StrideSupportTag = {
+ StrideSupport.Strided: "cutlass::conv::StrideSupport::kStrided",
+ StrideSupport.Unity: "cutlass::conv::StrideSupport::kUnity",
+}
+
+
+StrideSupportNames = {
+ StrideSupport.Strided: "",
+ StrideSupport.Unity: "unity_stride",
+}
+
+
+class IteratorAlgorithm(enum.Enum):
+ Analytic = enum_auto()
+ Optimized = enum_auto()
+
+
+IteratorAlgorithmTag = {
+ IteratorAlgorithm.Analytic: "cutlass::conv::IteratorAlgorithm::kAnalytic",
+ IteratorAlgorithm.Optimized: "cutlass::conv::IteratorAlgorithm::kOptimized",
+}
+
+
+IteratorAlgorithmNames = {
+ IteratorAlgorithm.Analytic: "analytic",
+ IteratorAlgorithm.Optimized: "optimized",
+}
+
+
class MathInstruction:
"""Describe characteristics of a math instruction."""
diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py
index 8ed3718..4ae529e 100644
--- a/python/tvm/relay/op/contrib/cutlass.py
+++ b/python/tvm/relay/op/contrib/cutlass.py
@@ -55,6 +55,11 @@ def make_batch_matmul_pattern():
return is_op("nn.batch_matmul")(wildcard(), wildcard())
+def make_conv2d_pattern():
+ # TODO(masahi): Check layout and alignment
+ return is_op("nn.conv2d")(wildcard(), wildcard())
+
+
def partition_for_cutlass(mod):
"""Partition the input module into CUTLASS-supported subgraphs."""
dense_pat = ("cutlass.dense", make_gemm_pattern(False, None))
@@ -72,6 +77,8 @@ def partition_for_cutlass(mod):
dense_bias_pat,
dense_pat,
("cutlass.batch_matmul", make_batch_matmul_pattern()),
+ # TODO(masahi): Add more conv2d patterns
+ ("cutlass.conv2d", make_conv2d_pattern()),
]
mod = transform.MergeComposite(cutlass_patterns)(mod)
mod = transform.AnnotateTarget(["cutlass"])(mod)
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 17f75a0..8357f28 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -1090,6 +1090,19 @@ def _conv_shape_func_nhwc_hwoi(dshape, kshape, strides, padding, dilation):
return out
+@script
+def _conv_shape_func_nhwc_ohwi(dshape, kshape, strides, padding, dilation):
+ """Shape function for conv*d op with nhwc & ohwi layout."""
+ out = output_tensor((dshape.shape[0],), "int64")
+ out[0] = dshape[0]
+ out[dshape.shape[0] - 1] = kshape[0]
+
+ for i in const_range(dshape.shape[0] - 2):
+ dilated_k = (kshape[i + 1] - 1) * dilation[i] + 1
+ out[i + 1] = (dshape[i + 1] + 2 * padding[i] - dilated_k) // strides[i] + 1
+ return out
+
+
def conv_shape_func(attrs, inputs, _):
"""Shape function for conv*d op."""
strides = get_const_tuple(attrs.strides)
@@ -1103,6 +1116,8 @@ def conv_shape_func(attrs, inputs, _):
shape_func = _conv_shape_func_nhwc_hwio
elif attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWOI":
shape_func = _conv_shape_func_nhwc_hwoi
+ elif attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "OHWI":
+ shape_func = _conv_shape_func_nhwc_ohwi
else:
raise ValueError(
"Unsupported data/kernel layout: %s, %s"
diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h
index 964d7de..3c6f810 100644
--- a/src/relay/backend/contrib/codegen_c/codegen_c.h
+++ b/src/relay/backend/contrib/codegen_c/codegen_c.h
@@ -191,10 +191,18 @@ class CodegenCBase {
PrintIndents();
}
for (size_t i = 0; i < outs.size() - 1; i++) {
- code_stream_ << "(" << outs[i].dtype << "*)(out" << i << "->data),\n";
+ if (pass_dl_tensor) {
+ code_stream_ << "out" << i << ",\n";
+ } else {
+ code_stream_ << "(" << outs[i].dtype << "*)(out" << i << "->data),\n";
+ }
PrintIndents();
}
- code_stream_ << "(" << outs.back().dtype << "*)(out" << outs.size() - 1 << "->data));\n";
+ if (pass_dl_tensor) {
+ code_stream_ << "out" << outs.size() - 1 << ");\n";
+ } else {
+ code_stream_ << "(" << outs.back().dtype << "*)(out" << outs.size() - 1 << "->data));\n";
+ }
PrintIndents();
code_stream_ << "return 0;\n";
ExitScope();
diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc
index f154f86..c226da5 100644
--- a/src/relay/backend/contrib/cutlass/codegen.cc
+++ b/src/relay/backend/contrib/cutlass/codegen.cc
@@ -61,7 +61,7 @@ inline void CutlassPrint(std::ostringstream& os, const std::string& stmt, int in
os << stmt;
}
-Str2StrMap GemmArgsCommon(const Map<String, ObjectRef>& attrs) {
+Str2StrMap ArgsCommon(const Map<String, ObjectRef>& attrs) {
Str2StrMap args;
auto arg0_dtype = std::string(attrs["arg0_dtype"].as<StringObj>()->data);
auto arg1_dtype = std::string(attrs["arg1_dtype"].as<StringObj>()->data);
@@ -72,6 +72,11 @@ Str2StrMap GemmArgsCommon(const Map<String, ObjectRef>& attrs) {
args["op_def"] = std::string(attrs["cutlass_op_def"].as<StringObj>()->data);
args["op_name"] = std::string(attrs["cutlass_op_name"].as<StringObj>()->data);
args["op_type"] = std::string(attrs["op_type"].as<StringObj>()->data);
+ return args;
+}
+
+Str2StrMap GemmArgsCommon(const Map<String, ObjectRef>& attrs) {
+ Str2StrMap args = ArgsCommon(attrs);
args["lda"] = std::string(attrs["lda"].as<StringObj>()->data);
args["ldb"] = std::string(attrs["ldb"].as<StringObj>()->data);
args["ldc"] = std::string(attrs["ldc"].as<StringObj>()->data);
@@ -110,7 +115,7 @@ void AppendPrologue(std::ostringstream& gemm_decl, const Str2StrMap& attrs,
CutlassPrint(gemm_decl, "using ElementOutput = " + attrs.at("ElementOutput") + ";\n");
CutlassPrint(gemm_decl, "using ElementComputeEpilogue = " + attrs.at("ElementOutput") + ";\n");
CutlassPrint(gemm_decl, attrs.at("op_def"));
- CutlassPrint(gemm_decl, "using Gemm = Operation_" + attrs.at("op_name") + ";\n");
+ CutlassPrint(gemm_decl, "using " + kernel + " = Operation_" + attrs.at("op_name") + ";\n");
auto get_dim = [&attrs, &func_args](const std::string& axis, int arg_idx, int axis_idx) {
if (attrs.at(axis) == kAnyDim) {
@@ -139,9 +144,8 @@ void AppendPrologue(std::ostringstream& gemm_decl, const Str2StrMap& attrs,
CutlassPrint(gemm_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n");
}
- CutlassPrint(gemm_decl, "void* ptr_out = (void*)(out0);\n");
+ CutlassPrint(gemm_decl, "void* ptr_out = (void*)(out0->data);\n");
- CutlassPrint(gemm_decl, "using " + kernel + " = Operation_" + attrs.at("op_name") + ";\n");
CutlassPrint(gemm_decl, "typename " + kernel + "::Arguments arguments{\n");
CutlassPrint(gemm_decl, " problem_size,\n");
}
@@ -234,6 +238,112 @@ std::string BatchMatmulOp(std::string id, const Str2StrMap& attrs,
return gemm_decl.str();
}
+Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {
+ Str2StrMap args = ArgsCommon(attrs);
+ auto arg0_shape = attrs["arg0_shape"].as<ArrayNode>();
+ auto arg1_shape = attrs["arg1_shape"].as<ArrayNode>();
+ auto out_shape = attrs["ret_shape"].as<ArrayNode>();
+ args["N"] = GetDimAsStr(arg0_shape->at(0));
+ args["H"] = GetDimAsStr(arg0_shape->at(1));
+ args["W"] = GetDimAsStr(arg0_shape->at(2));
+ args["C"] = GetDimAsStr(arg0_shape->at(3));
+ args["K"] = GetDimAsStr(arg1_shape->at(0));
+ args["R"] = GetDimAsStr(arg1_shape->at(1));
+ args["S"] = GetDimAsStr(arg1_shape->at(1));
+ args["P"] = GetDimAsStr(out_shape->at(1));
+ args["Q"] = GetDimAsStr(out_shape->at(2));
+ args["pad_h"] = GetDimAsStr(attrs["padding"].as<ArrayNode>()->at(0));
+ args["pad_w"] = GetDimAsStr(attrs["padding"].as<ArrayNode>()->at(1));
+ args["stride_h"] = GetDimAsStr(attrs["strides"].as<ArrayNode>()->at(0));
+ args["stride_w"] = GetDimAsStr(attrs["strides"].as<ArrayNode>()->at(1));
+ args["dilation_h"] = GetDimAsStr(attrs["dilation"].as<ArrayNode>()->at(0));
+ args["dilation_w"] = GetDimAsStr(attrs["dilation"].as<ArrayNode>()->at(1));
+ return args;
+}
+
+std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
+ const std::vector<std::string>& func_args) {
+ std::ostringstream conv2d_decl;
+ CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n");
+ CutlassPrint(conv2d_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n");
+ CutlassPrint(conv2d_decl, "using ElementOutput = " + attrs.at("ElementOutput") + ";\n");
+
+ CutlassPrint(conv2d_decl, "using ElementComputeEpilogue = " + attrs.at("ElementOutput") + ";\n");
+ CutlassPrint(conv2d_decl, attrs.at("op_def"));
+ CutlassPrint(conv2d_decl, "using Operation_" + attrs.at("op_name") +
+ " = cutlass::conv::device::ImplicitGemmConvolution<" +
+ attrs.at("op_name") + ">;\n");
+ CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + attrs.at("op_name") + ";\n");
+
+ auto get_dim = [&attrs](const std::string& axis, const std::string& var_name, int axis_idx) {
+ if (attrs.at(axis) == kAnyDim) {
+ return var_name + "->shape[" + std::to_string(axis_idx) + "]";
+ } else {
+ return attrs.at(axis);
+ }
+ };
+
+ CutlassPrint(conv2d_decl, "int N = " + get_dim("N", func_args[0], 0) + ";\n");
+ CutlassPrint(conv2d_decl, "int H = " + get_dim("H", func_args[0], 1) + ";\n");
+ CutlassPrint(conv2d_decl, "int W = " + get_dim("W", func_args[0], 2) + ";\n");
+ CutlassPrint(conv2d_decl, "int C = " + attrs.at("C") + ";\n");
+ CutlassPrint(conv2d_decl, "int K = " + attrs.at("K") + ";\n");
+ CutlassPrint(conv2d_decl, "int R = " + attrs.at("R") + ";\n");
+ CutlassPrint(conv2d_decl, "int S = " + attrs.at("S") + ";\n");
+ CutlassPrint(conv2d_decl, "int P = " + get_dim("P", "out0", 1) + ";\n");
+ CutlassPrint(conv2d_decl, "int Q = " + get_dim("Q", "out0", 2) + ";\n");
+ CutlassPrint(conv2d_decl, "int pad_h = " + attrs.at("pad_h") + ";\n");
+ CutlassPrint(conv2d_decl, "int pad_w = " + attrs.at("pad_w") + ";\n");
+ CutlassPrint(conv2d_decl, "int stride_h = " + attrs.at("stride_h") + ";\n");
+ CutlassPrint(conv2d_decl, "int stride_w = " + attrs.at("stride_w") + ";\n");
+ CutlassPrint(conv2d_decl, "int dilation_h = " + attrs.at("dilation_h") + ";\n");
+ CutlassPrint(conv2d_decl, "int dilation_w = " + attrs.at("dilation_w") + ";\n");
+
+ CutlassPrint(
+ conv2d_decl,
+ "cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, Q, pad_h, pad_w, "
+ "stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, 1);\n");
+
+ ICHECK(func_args.size() >= 2);
+ CutlassPrint(conv2d_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n");
+ CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n");
+ CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n");
+ CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n");
+ CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n");
+
+ CutlassPrint(conv2d_decl, "using cutlass::layout::TensorNHWC;\n");
+ CutlassPrint(conv2d_decl,
+ "TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(N, H, W, C)));\n");
+ CutlassPrint(conv2d_decl,
+ "TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(K, R, S, C)));\n");
+ CutlassPrint(conv2d_decl,
+ "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n");
+ CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n");
+ CutlassPrint(conv2d_decl, " problem_size,\n");
+ CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
+ CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");
+ CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
+ CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
+ CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
+ CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");
+
+ CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n");
+ // Allocate workspace memory
+ CutlassPrint(conv2d_decl,
+ "cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);\n");
+ // Check the problem size is supported or not
+ CutlassPrint(conv2d_decl, "cutlass::Status status = conv2d_op.can_implement(arguments);\n");
+ CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+ // Initialize CUTLASS kernel with arguments and workspace pointer
+ CutlassPrint(conv2d_decl, "status = conv2d_op.initialize(arguments, workspace.get());\n");
+ CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+ // Launch initialized CUTLASS kernel
+ CutlassPrint(conv2d_decl, "status = conv2d_op();\n");
+ CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+
+ return conv2d_decl.str();
+}
+
class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
public:
CodegenCutlass(const std::string& id, const Map<String, ObjectRef>& attrs) {
@@ -268,9 +378,9 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
code_stream_ << "DLTensor* " << arg->name_hint() << ", ";
}
for (size_t i = 0; i < out.size() - 1; ++i) {
- code_stream_ << out[i].dtype << "* out" << i << ", ";
+ code_stream_ << "DLTensor* out" << i << ", ";
}
- code_stream_ << out.back().dtype << "* out" << out.size() - 1 << ") {\n";
+ code_stream_ << "DLTensor* out" << out.size() - 1 << ") {\n";
this->EnterScope();
// Function body
@@ -347,7 +457,12 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
GetRootCall(callee->body.as<CallNode>(), 0, {"nn.batch_matmul"});
return GenerateBody(batch_matmul_call, "cutlass_batch_matmul", GetArgumentNames(caller),
BatchMatmulArgs(std::ref(attrs_)));
+ } else if (pattern_name == "cutlass.conv2d") {
+ const auto* conv2d_call = GetRootCall(callee->body.as<CallNode>(), 0, {"nn.conv2d"});
+ return GenerateBody(conv2d_call, "cutlass_conv2d", GetArgumentNames(caller),
+ Conv2dArgs(std::ref(attrs_)));
}
+
LOG(FATAL) << "Unknown composite function: " << pattern_name;
return {};
}
@@ -392,7 +507,10 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
} else if (func_name == "cutlass_batch_matmul") {
ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
+ } else if (func_name == "cutlass_conv2d") {
+ ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
}
+
return ret;
}
/*! \brief The id of the external cutlass ext_func. */
@@ -441,10 +559,12 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase {
// cutlass header
code_stream_ << "#include <cuda_fp16.h>\n";
code_stream_ << "#include <cutlass/cutlass.h>\n";
+ code_stream_ << "#include <cutlass/coord.h>\n";
code_stream_ << "#include <cutlass/util/host_tensor.h>\n";
- code_stream_ << "#include <cutlass/util/reference/host/tensor_fill.h>\n";
code_stream_ << "#include <cutlass/gemm/device/gemm.h>\n";
code_stream_ << "#include <cutlass/gemm/device/gemm_batched.h>\n";
+ code_stream_ << "#include <cutlass/conv/kernel/default_conv2d_fprop.h>\n";
+ code_stream_ << "#include <cutlass/conv/device/implicit_gemm_convolution.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_bias_relu.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_gelu.h>\n";
diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py
index 6f27d57..a258da3 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -110,6 +110,22 @@ def get_batch_matmul(batch, M, N, K, out_dtype="float16"):
return get_batch_matmul_with_shape((batch, M, K), (batch, N, K), out_dtype="float16")
+def get_conv2d_nchw(d_shape, w_shape):
+ data = relay.var("data", shape=d_shape, dtype="float16")
+ weight = relay.var("weight", shape=w_shape, dtype="float16")
+ out_channel = w_shape[0]
+ return tvm.IRModule.from_expr(
+ relay.nn.conv2d(
+ data=data,
+ weight=weight,
+ kernel_size=(3, 3),
+ channels=out_channel,
+ padding=(1, 1),
+ out_dtype="float16",
+ )
+ )
+
+
def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"):
mod = partition_for_cutlass(mod)
mod, num_cutlass_partition = tune_cutlass_kernels(
@@ -289,5 +305,85 @@ def test_batch_matmul():
)
+def convert_conv2d_layout(mod, desired_layouts):
+ with tvm.transform.PassContext(opt_level=3):
+ seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
+ return seq(mod)
+
+
+def verify_conv2d(
+ mod_nchw,
+ mod_ref,
+ d_shape,
+ w_shape,
+ sm=80,
+ atol=1e-5,
+ rtol=1e-5,
+ run_benchmark=False,
+):
+ if not has_cutlass():
+ return
+
+ np_data = np.random.uniform(-1, 1, d_shape).astype("float16")
+ np_weight = np.random.uniform(-1, 1, w_shape).astype("float16")
+
+ params = {"weight": np_weight}
+
+ typ = relay.transform.InferType()(mod_nchw)["main"].body.checked_type
+ use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape)
+
+ if use_vm:
+ rt_mod, dev, num_cutlass_partition = profile_and_build_vm(
+ convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}), params, sm
+ )
+ out = get_output_vm(rt_mod, ["data"], [np_data])
+ else:
+ rt_mod, dev, num_cutlass_partition = profile_and_build(
+ convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}),
+ params,
+ sm,
+ )
+ out = get_output(rt_mod, ["data"], [np_data])
+
+ assert num_cutlass_partition > 0
+
+ rt_mod_ref, _ = get_ref_rt_mod(
+ convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}),
+ params,
+ target="cuda",
+ )
+ ref_out = get_output(rt_mod_ref, ["data"], [np_data])
+
+ np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
+
+ if run_benchmark:
+ print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600))
+ print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600))
+
+
+def test_conv2d():
+ d_shape = (16, 16, 32, 32)
+ w_shape = (32, 16, 3, 3)
+ mod_nchw = get_conv2d_nchw(d_shape, w_shape)
+
+ verify_conv2d(
+ mod_nchw,
+ mod_nchw,
+ d_shape,
+ w_shape,
+ sm=80,
+ atol=1e-5,
+ rtol=1e-5,
+ run_benchmark=False,
+ )
+
+ dyn_batch_shape = (relay.Any(),) + d_shape[1:]
+ mod_dyn = get_conv2d_nchw(dyn_batch_shape, w_shape)
+
+ verify_conv2d(
+ mod_dyn, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
+ )
+
+
if __name__ == "__main__":
pytest.main([__file__])