Posted to commits@tvm.apache.org by co...@apache.org on 2021/11/25 23:54:24 UTC

[tvm] branch main updated: [CUTLASS] Refactor GEMM generator in preparation for conv2d (#9571)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new adf560e  [CUTLASS] Refactor GEMM generator in preparation for conv2d (#9571)
adf560e is described below

commit adf560ebed8465c22bf58f406d0a8d20663cdd1d
Author: masahi <ma...@gmail.com>
AuthorDate: Fri Nov 26 08:53:55 2021 +0900

    [CUTLASS] Refactor GEMM generator in preparation for conv2d (#9571)
    
    * split non-gemm specific generator code to gen_tensor_op.py
    
    commit 250f915652e72e0012e9aa6ce0b6ef337d3da845
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Sun Nov 14 06:44:52 2021 +0900
    
        remove conv2d stuff
    
    commit 1a6b27c438472f13acd4a0f466d78f293415e076
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Sun Nov 14 06:41:31 2021 +0900
    
        remove unused import
    
    commit f7c3b5a191b8c73e8b178c32f6d3182fb0f697d6
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Sun Nov 14 06:37:07 2021 +0900
    
        add profiler boilerplate for conv2d
    
    commit ca1ae274fb8f96a1dcde688deaf15339fe5604fb
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Sun Nov 14 06:22:06 2021 +0900
    
        introduce gen_tensor_op.py
    
    commit 37bb918e0873f04457c29479eb21a530b7052217
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Sun Nov 14 05:45:41 2021 +0900
    
        more conv2d code
    
    commit 5c00398892c99cb2a03be51f75878992663432dd
    Author: Masahiro Masuda <ma...@gmail.com>
    Date:   Sun Nov 14 05:13:30 2021 +0900
    
        Begin conv2d support
    
    * fix
    
    * use functools.partial
    
    * remove unused import
---
 python/tvm/contrib/cutlass/gen_gemm.py             | 230 ++-------------------
 .../cutlass/{gen_gemm.py => gen_tensor_op.py}      | 202 +-----------------
 tests/python/contrib/test_cutlass.py               |   2 +-
 3 files changed, 30 insertions(+), 404 deletions(-)
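
The core of the refactor: the per-architecture generators that moved to gen_tensor_op.py no longer take a batched flag and instead accept an op_creator callback, while gen_gemm.py binds its GEMM-specific arguments with functools.partial. A minimal sketch of the new call pattern, using only names touched by this commit (enumeration only, no profiling, so no GPU should be required):

    from functools import partial
    from tvm.contrib.cutlass.gen_gemm import GENERATOR_FUNC_TABLE, create_gemm_operator

    # Bind the GEMM-specific option up front; the shared generator only needs a
    # callable of the form op_creator(tile_descriptions, data_type, alignment_constraints).
    op_creator = partial(create_gemm_operator, batched=False)

    # Enumerate all candidate Turing (sm75) kernels for float16 output.
    ops = GENERATOR_FUNC_TABLE[75]("float16", op_creator=op_creator)
    print(len(ops), ops[0]["name"])

The same table entry works for sm80; only the tile descriptions and alignment constraints differ. The intent, per the commit title, is that conv2d support can later plug in its own op_creator without touching the shared enumeration loop.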

diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py
index 1ed4bfe..4025354 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -15,37 +15,29 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name
-"""Kernel generator and profiler for CUTLASS."""
-import logging
-import os
+"""GEMM kernel generator and profiler for CUTLASS."""
+from functools import partial
 import re
-import tempfile
-import subprocess
-import multiprocessing
 from .gemm_operation import GemmOperation, EmitGemmInstance
 from .gemm_profiler import GemmProfilerEmitter
+from .gen_tensor_op import (
+    ProfilerEngine,
+    generate_sm75_tensor_op_1688,
+    generate_sm80_tensor_op_16816,
+)
 from .library import (
     EpilogueFunctor,
     SwizzlingFunctor,
     TensorDescription,
     DataTypeTag,
     LayoutType,
-    MathInstruction,
-    DataType,
-    OpcodeClass,
-    MathOperation,
-    TileDescription,
 )
 
-logger = logging.getLogger("cutlass")
-
 
 def create_gemm_operator(
-    layouts,
     tile_descriptions,
     data_type,
     alignment_constraints,
-    epilogue_functor=EpilogueFunctor.LinearCombination,
     swizzling_functor=SwizzlingFunctor.Identity8,
     batched=False,
 ):
@@ -59,6 +51,10 @@ def create_gemm_operator(
     if batched:
         swizzling_functor = SwizzlingFunctor.Batched
 
+    layouts = [
+        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
+    ]
+
     for layout in layouts:
         for tile_description in tile_descriptions:
             for alignment in alignment_constraints:
@@ -76,7 +72,7 @@ def create_gemm_operator(
                     B,
                     C,
                     element_epilogue,
-                    epilogue_functor,
+                    EpilogueFunctor.LinearCombination,
                     swizzling_functor,
                 )
                 op_bias = GemmOperation(
@@ -110,7 +106,6 @@ def create_gemm_operator(
                     swizzling_functor,
                 )
 
-                kernel_emitter = EmitGemmInstance()
                 op_entry["op"] = op
                 op_entry["name"] = op.procedural_name()
                 op_entry["opdef"] = kernel_emitter.emit(op, batched=batched)
@@ -134,141 +129,12 @@ def create_gemm_operator(
     return ret
 
 
-def generate_tensor_op_common(
-    math_instructions, alignment_constraints, get_tile_descriptions, batched=False
-):
-    """Common kernel generator to be used by archtecture specific generators."""
-    ops = []
-    layouts = [
-        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
-    ]
-    for math_inst in math_instructions:
-        tile_descriptions = get_tile_descriptions(math_inst)
-        data_type = [
-            math_inst.element_a,
-            math_inst.element_b,
-            math_inst.element_accumulator,
-            math_inst.element_accumulator,
-        ]
-
-        out = create_gemm_operator(
-            layouts, tile_descriptions, data_type, alignment_constraints, batched=batched
-        )
-
-        ops.extend(out)
-
-    return ops
-
-
-def generate_sm75_tensor_op_1688(out_dtype, batched=False):
-    """Generate GEMM kernels for Turing."""
-    assert out_dtype in ["float32", "float16"]
-    math_instructions = {
-        "float32": [
-            MathInstruction(
-                [16, 8, 8],
-                DataType.f16,
-                DataType.f16,
-                DataType.f32,
-                OpcodeClass.TensorOp,
-                MathOperation.multiply_add,
-            )
-        ],
-        "float16": [
-            MathInstruction(
-                [16, 8, 8],
-                DataType.f16,
-                DataType.f16,
-                DataType.f16,
-                OpcodeClass.TensorOp,
-                MathOperation.multiply_add,
-            )
-        ],
-    }[out_dtype]
-
-    alignment_constraints = [8, 4, 2, 1]
-
-    def get_tile_descriptions(math_inst):
-        min_cc = 75
-        max_cc = 1024
-        return [
-            TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc),
-        ]
-
-    return generate_tensor_op_common(
-        math_instructions, alignment_constraints, get_tile_descriptions, batched
-    )
-
-
-def generate_sm80_tensor_op_16816(out_dtype, batched=False):
-    """Generate GEMM kernels for Ampere."""
-    assert out_dtype in ["float32", "float16"]
-    math_instructions = {
-        "float32": [
-            MathInstruction(
-                [16, 8, 16],
-                DataType.f16,
-                DataType.f16,
-                DataType.f32,
-                OpcodeClass.TensorOp,
-                MathOperation.multiply_add,
-            )
-        ],
-        "float16": [
-            MathInstruction(
-                [16, 8, 16],
-                DataType.f16,
-                DataType.f16,
-                DataType.f16,
-                OpcodeClass.TensorOp,
-                MathOperation.multiply_add,
-            )
-        ],
-    }[out_dtype]
-
-    alignment_constraints = [8, 4, 2]
-
-    def get_tile_descriptions(math_inst):
-        min_cc = 80
-        max_cc = 1024
-        max_cc_smem_limited = 80
-        return [
-            TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc),
-            TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc),
-            TileDescription([64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
-            TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
-            TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
-            TileDescription([64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
-            TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
-            TileDescription([64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
-        ]
-
-    return generate_tensor_op_common(
-        math_instructions, alignment_constraints, get_tile_descriptions, batched
-    )
-
-
 GENERATOR_FUNC_TABLE = {
     75: generate_sm75_tensor_op_1688,
     80: generate_sm80_tensor_op_16816,
 }
 
+
 # TODO(masahi): A sensible way to pick reasonable default kernels
 DEFAULT_KERNELS = {
     75: {
@@ -282,67 +148,7 @@ DEFAULT_KERNELS = {
 }
 
 
-class ProfilerEngine:
-    """Compile and run a given profiler executable."""
-
-    def __init__(self, cuda_arch, cutlass_path, binary_prefix):
-        self.cuda_arch = cuda_arch
-        self.binary_prefix = binary_prefix
-        self.cutlass = cutlass_path
-        self.cflags = "-I{cutlass}/include -I{cutlass}/tools/util/include -O3 -std=c++11".format(
-            cutlass=cutlass_path
-        )
-        self.cflags += " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
-        self.cflags += " -gencode=arch=compute_{arch},code=[sm_{arch},compute_{arch}]".format(
-            arch=cuda_arch
-        )
-        self.cflags += " -Xcompiler=-Wconversion -Xcompiler=-fno-strict-aliasing"
-        self.cmd = "nvcc {cflags} {src} -o {output}"
-
-    def _compile(self, op):
-        os.makedirs(self.binary_prefix, exist_ok=True)
-        opath = os.path.join(self.binary_prefix, op["name"])
-        if os.path.exists(opath):
-            return
-        fi = tempfile.NamedTemporaryFile("w", delete=False, suffix=".cu")
-        fi.write(op["src"])
-        fi.close()
-        cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath)
-        os.system(cmd)
-        os.unlink(fi.name)
-
-    def compile_all(self, ops, use_multiprocessing=False):
-        """Compile all profiler executables."""
-        if use_multiprocessing:
-            pool = multiprocessing.Pool(multiprocessing.cpu_count())
-            pool.map(self._compile, ops)
-        else:
-            for op in ops:
-                self._compile(op)
-
-    def evaluate(self, op, args):
-        """Run the profiler executable corresponding to op_name with args."""
-        op_name = op["name"]
-        opath = os.path.join(self.binary_prefix, op_name)
-        if not os.path.exists(opath):
-            self._compile(op)
-        cmd = [opath]
-        if args is not None:
-            cmd.append(str(args[0]))
-            cmd.append(str(args[1]))
-            cmd.append(str(args[2]))
-            if len(args) > 3:
-                cmd.append(str(args[3]))
-        try:
-            sp = subprocess.run(cmd, capture_output=True, check=True)
-            rt = float(sp.stdout)
-            logger.info("%s, %f", op_name, rt)
-        except subprocess.CalledProcessError:
-            rt = -1
-        return rt
-
-
-class CutlassGemmProfiler(object):
+class CutlassGemmProfiler:
     """Profile all candidate kernels and select the best one."""
 
     def __init__(self, sm, cutlass_path, binary_path):
@@ -364,7 +170,9 @@ class CutlassGemmProfiler(object):
         """Return the default kernel for the requested architecture.
         For now, the default kernel was picked arbitrarily.
         """
-        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
+        ops = GENERATOR_FUNC_TABLE[self.sm](
+            out_dtype, op_creator=partial(create_gemm_operator, batched=batched)
+        )
         default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
         filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
         assert len(filtered) == 1
@@ -380,7 +188,9 @@ class CutlassGemmProfiler(object):
         if (M, N, K) in self.cache:
             return self.cache[(M, N, K)]
 
-        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
+        ops = GENERATOR_FUNC_TABLE[self.sm](
+            out_dtype, op_creator=partial(create_gemm_operator, batched=batched)
+        )
         ops = list(filter(lambda op: self.check_align(op["name"], M), ops))
 
         for op in ops:
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
similarity index 52%
copy from python/tvm/contrib/cutlass/gen_gemm.py
copy to python/tvm/contrib/cutlass/gen_tensor_op.py
index 1ed4bfe..c822151 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -15,21 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name
-"""Kernel generator and profiler for CUTLASS."""
+"""Common functions and classes for CUTLASS GEMM and Conv2d geneator."""
 import logging
 import os
-import re
 import tempfile
 import subprocess
 import multiprocessing
-from .gemm_operation import GemmOperation, EmitGemmInstance
-from .gemm_profiler import GemmProfilerEmitter
 from .library import (
-    EpilogueFunctor,
-    SwizzlingFunctor,
-    TensorDescription,
-    DataTypeTag,
-    LayoutType,
     MathInstruction,
     DataType,
     OpcodeClass,
@@ -40,108 +32,11 @@ from .library import (
 logger = logging.getLogger("cutlass")
 
 
-def create_gemm_operator(
-    layouts,
-    tile_descriptions,
-    data_type,
-    alignment_constraints,
-    epilogue_functor=EpilogueFunctor.LinearCombination,
-    swizzling_functor=SwizzlingFunctor.Identity8,
-    batched=False,
-):
-    """Exhaustively instantiate all kernels from a given configuration."""
-    ret = []
-    kernel_emitter = EmitGemmInstance()
-    profiler_emitter = GemmProfilerEmitter()
-
-    element_a, element_b, element_c, element_epilogue = data_type
-
-    if batched:
-        swizzling_functor = SwizzlingFunctor.Batched
-
-    for layout in layouts:
-        for tile_description in tile_descriptions:
-            for alignment in alignment_constraints:
-                alignment_c = min(8, alignment)
-
-                A = TensorDescription(element_a, layout[0], alignment)
-                B = TensorDescription(element_b, layout[1], alignment)
-                C = TensorDescription(element_c, layout[2], alignment_c)
-
-                op_entry = {}
-                op = GemmOperation(
-                    tile_description.minimum_compute_capability,
-                    tile_description,
-                    A,
-                    B,
-                    C,
-                    element_epilogue,
-                    epilogue_functor,
-                    swizzling_functor,
-                )
-                op_bias = GemmOperation(
-                    tile_description.minimum_compute_capability,
-                    tile_description,
-                    A,
-                    B,
-                    C,
-                    element_epilogue,
-                    EpilogueFunctor.LinearCombinationBias,
-                    swizzling_functor,
-                )
-                op_bias_relu = GemmOperation(
-                    tile_description.minimum_compute_capability,
-                    tile_description,
-                    A,
-                    B,
-                    C,
-                    element_epilogue,
-                    EpilogueFunctor.LinearCombinationRelu,
-                    swizzling_functor,
-                )
-                op_bias_gelu = GemmOperation(
-                    tile_description.minimum_compute_capability,
-                    tile_description,
-                    A,
-                    B,
-                    C,
-                    element_epilogue,
-                    EpilogueFunctor.LinearCombinationGelu,
-                    swizzling_functor,
-                )
-
-                kernel_emitter = EmitGemmInstance()
-                op_entry["op"] = op
-                op_entry["name"] = op.procedural_name()
-                op_entry["opdef"] = kernel_emitter.emit(op, batched=batched)
-                op_entry["opdef_bias"] = kernel_emitter.emit(
-                    op_bias, no_beta_scaling=True, batched=batched
-                )
-                op_entry["opdef_bias_relu"] = kernel_emitter.emit(
-                    op_bias_relu, no_beta_scaling=True, batched=batched
-                )
-                op_entry["opdef_bias_gelu"] = kernel_emitter.emit(op_bias_gelu, batched=batched)
-                op_entry["src"] = profiler_emitter.emit(
-                    op.procedural_name(),
-                    kernel_emitter.emit(op, batched=False),
-                    DataTypeTag[element_a],
-                    DataTypeTag[element_b],
-                    DataTypeTag[element_c],
-                    op.leading_dim(),
-                )
-                op_entry["runtime"] = 9999999
-                ret.append(op_entry)
-    return ret
-
-
 def generate_tensor_op_common(
-    math_instructions, alignment_constraints, get_tile_descriptions, batched=False
+    math_instructions, alignment_constraints, get_tile_descriptions, op_creator
 ):
     """Common kernel generator to be used by archtecture specific generators."""
     ops = []
-    layouts = [
-        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
-    ]
     for math_inst in math_instructions:
         tile_descriptions = get_tile_descriptions(math_inst)
         data_type = [
@@ -151,17 +46,15 @@ def generate_tensor_op_common(
             math_inst.element_accumulator,
         ]
 
-        out = create_gemm_operator(
-            layouts, tile_descriptions, data_type, alignment_constraints, batched=batched
-        )
+        out = op_creator(tile_descriptions, data_type, alignment_constraints)
 
         ops.extend(out)
 
     return ops
 
 
-def generate_sm75_tensor_op_1688(out_dtype, batched=False):
-    """Generate GEMM kernels for Turing."""
+def generate_sm75_tensor_op_1688(out_dtype, op_creator):
+    """Generate GEMM or Conv2D kernels for Turing."""
     assert out_dtype in ["float32", "float16"]
     math_instructions = {
         "float32": [
@@ -202,12 +95,12 @@ def generate_sm75_tensor_op_1688(out_dtype, batched=False):
         ]
 
     return generate_tensor_op_common(
-        math_instructions, alignment_constraints, get_tile_descriptions, batched
+        math_instructions, alignment_constraints, get_tile_descriptions, op_creator
     )
 
 
-def generate_sm80_tensor_op_16816(out_dtype, batched=False):
-    """Generate GEMM kernels for Ampere."""
+def generate_sm80_tensor_op_16816(out_dtype, op_creator):
+    """Generate GEMM or Conv2D kernels for Ampere."""
     assert out_dtype in ["float32", "float16"]
     math_instructions = {
         "float32": [
@@ -260,28 +153,10 @@ def generate_sm80_tensor_op_16816(out_dtype, batched=False):
         ]
 
     return generate_tensor_op_common(
-        math_instructions, alignment_constraints, get_tile_descriptions, batched
+        math_instructions, alignment_constraints, get_tile_descriptions, op_creator
     )
 
 
-GENERATOR_FUNC_TABLE = {
-    75: generate_sm75_tensor_op_1688,
-    80: generate_sm80_tensor_op_16816,
-}
-
-# TODO(masahi): A sensible way to pick reasonable default kernels
-DEFAULT_KERNELS = {
-    75: {
-        "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4",
-        "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align4",
-    },
-    80: {
-        "float16": "cutlass_tensorop_h16816gemm_128x256_32x3_tn_align4",
-        "float32": "cutlass_tensorop_s16816gemm_f16_128x128_32x3_tn_align4",
-    },
-}
-
-
 class ProfilerEngine:
     """Compile and run a given profiler executable."""
 
@@ -340,62 +215,3 @@ class ProfilerEngine:
         except subprocess.CalledProcessError:
             rt = -1
         return rt
-
-
-class CutlassGemmProfiler(object):
-    """Profile all candidate kernels and select the best one."""
-
-    def __init__(self, sm, cutlass_path, binary_path):
-        assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, "sm%d not supported yet." % sm
-        self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
-        self.sm = sm
-        self.cache = {}
-
-    def check_align(self, op_name, M):
-        """Filter out kernels that cannot be supported."""
-        aligns = re.findall(r"align[1|2|4|8]", op_name)
-        assert len(aligns) == 1
-        align = int(aligns[0][-1])
-        if M % align != 0:
-            return False
-        return True
-
-    def get_default(self, out_dtype, batched=False):
-        """Return the default kernel for the requested architecture.
-        For now, the default kernel was picked arbitrary.
-        """
-        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
-        default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
-        filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
-        assert len(filtered) == 1
-        return filtered[0]
-
-    def profile(
-        self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False, batched=False
-    ):
-        """Profile and select the best kernel from candidate kernels.
-        If profile_all is False, return immediately after the first applicable kernel is found.
-        If use_multiprocessing is True, compile all profiler executables in parallel.
-        """
-        if (M, N, K) in self.cache:
-            return self.cache[(M, N, K)]
-
-        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
-        ops = list(filter(lambda op: self.check_align(op["name"], M), ops))
-
-        for op in ops:
-            op["runtime"] = -1
-
-        if profile_all:
-            self.engine.compile_all(ops, use_multiprocessing)
-
-        for op in ops:
-            out = self.engine.evaluate(op, [M, N, K])
-            op["runtime"] = out
-            if out > 0 and profile_all is False:
-                break
-
-        valid_ops = filter(lambda op: op["runtime"] > 0, ops)
-        output = sorted(valid_ops, key=lambda i: i["runtime"])
-        self.cache[(M, N, K)] = output[0]
-        return output[0]
diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py
index 5a1ff8b..6f27d57 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -213,7 +213,7 @@ def verify_batch_matmul(
 
     np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
 
-    if True:
+    if run_benchmark:
         print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600))
         print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600))