Posted to commits@tvm.apache.org by co...@apache.org on 2021/11/25 23:54:24 UTC
[tvm] branch main updated: [CUTLASS] Refactor GEMM generator in preparation for conv2d (#9571)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new adf560e [CUTLASS] Refactor GEMM generator in preparation for conv2d (#9571)
adf560e is described below
commit adf560ebed8465c22bf58f406d0a8d20663cdd1d
Author: masahi <ma...@gmail.com>
AuthorDate: Fri Nov 26 08:53:55 2021 +0900
[CUTLASS] Refactor GEMM generator in preparation for conv2d (#9571)
* split non-gemm specific generator code to gen_tensor_op.py
commit 250f915652e72e0012e9aa6ce0b6ef337d3da845
Author: Masahiro Masuda <ma...@gmail.com>
Date: Sun Nov 14 06:44:52 2021 +0900
remove conv2d stuff
commit 1a6b27c438472f13acd4a0f466d78f293415e076
Author: Masahiro Masuda <ma...@gmail.com>
Date: Sun Nov 14 06:41:31 2021 +0900
remove unused import
commit f7c3b5a191b8c73e8b178c32f6d3182fb0f697d6
Author: Masahiro Masuda <ma...@gmail.com>
Date: Sun Nov 14 06:37:07 2021 +0900
add profiler boilerplate for conv2d
commit ca1ae274fb8f96a1dcde688deaf15339fe5604fb
Author: Masahiro Masuda <ma...@gmail.com>
Date: Sun Nov 14 06:22:06 2021 +0900
introduce gen_tensor_op.py
commit 37bb918e0873f04457c29479eb21a530b7052217
Author: Masahiro Masuda <ma...@gmail.com>
Date: Sun Nov 14 05:45:41 2021 +0900
more conv2d code
commit 5c00398892c99cb2a03be51f75878992663432dd
Author: Masahiro Masuda <ma...@gmail.com>
Date: Sun Nov 14 05:13:30 2021 +0900
Begin conv2d support
* fix
* use functools.partial
* remove unused import
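In effect, the refactor inverts the dependency: generate_tensor_op_common in gen_tensor_op.py no longer calls create_gemm_operator directly, but receives an op_creator callback, and gen_gemm.py binds its GEMM-only arguments up front with functools.partial. A minimal standalone sketch of the pattern follows (toy stand-ins for the real CUTLASS types; the tile strings and tuple payloads below are illustrative, not part of the patch):

    from functools import partial

    def generate_tensor_op_common(math_instructions, alignment_constraints,
                                  get_tile_descriptions, op_creator):
        # Architecture-agnostic loop; op construction is delegated to the callback.
        ops = []
        for math_inst in math_instructions:
            tiles = get_tile_descriptions(math_inst)
            data_type = [math_inst, math_inst]  # stand-in for the (a, b, acc, epilogue) dtypes
            ops.extend(op_creator(tiles, data_type, alignment_constraints))
        return ops

    def create_gemm_operator(tile_descriptions, data_type, alignment_constraints,
                             batched=False):
        # GEMM-specific construction; the (row, col, row) layouts are now fixed in here.
        return [("gemm", tile, align, batched)
                for tile in tile_descriptions
                for align in alignment_constraints]

    # gen_gemm.py side: bind the GEMM-only flag, pass the callback through.
    ops = generate_tensor_op_common(
        ["m16n8k8"], [8, 4], lambda m: ["128x128_2"],
        op_creator=partial(create_gemm_operator, batched=True),
    )
    print(ops)  # [('gemm', '128x128_2', 8, True), ('gemm', '128x128_2', 4, True)]

A future conv2d generator can then reuse the same generate_sm75/sm80 entry points by passing its own op_creator, which is the point of the split.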
---
python/tvm/contrib/cutlass/gen_gemm.py | 230 ++-------------------
.../cutlass/{gen_gemm.py => gen_tensor_op.py} | 202 +-----------------
tests/python/contrib/test_cutlass.py | 2 +-
3 files changed, 30 insertions(+), 404 deletions(-)
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py
index 1ed4bfe..4025354 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_gemm.py
@@ -15,37 +15,29 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
-"""Kernel generator and profiler for CUTLASS."""
-import logging
-import os
+"""GEMM kernel generator and profiler for CUTLASS."""
+from functools import partial
import re
-import tempfile
-import subprocess
-import multiprocessing
from .gemm_operation import GemmOperation, EmitGemmInstance
from .gemm_profiler import GemmProfilerEmitter
+from .gen_tensor_op import (
+ ProfilerEngine,
+ generate_sm75_tensor_op_1688,
+ generate_sm80_tensor_op_16816,
+)
from .library import (
EpilogueFunctor,
SwizzlingFunctor,
TensorDescription,
DataTypeTag,
LayoutType,
- MathInstruction,
- DataType,
- OpcodeClass,
- MathOperation,
- TileDescription,
)
-logger = logging.getLogger("cutlass")
-
def create_gemm_operator(
- layouts,
tile_descriptions,
data_type,
alignment_constraints,
- epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity8,
batched=False,
):
@@ -59,6 +51,10 @@ def create_gemm_operator(
if batched:
swizzling_functor = SwizzlingFunctor.Batched
+ layouts = [
+ (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
+ ]
+
for layout in layouts:
for tile_description in tile_descriptions:
for alignment in alignment_constraints:
@@ -76,7 +72,7 @@ def create_gemm_operator(
B,
C,
element_epilogue,
- epilogue_functor,
+ EpilogueFunctor.LinearCombination,
swizzling_functor,
)
op_bias = GemmOperation(
@@ -110,7 +106,6 @@ def create_gemm_operator(
swizzling_functor,
)
- kernel_emitter = EmitGemmInstance()
op_entry["op"] = op
op_entry["name"] = op.procedural_name()
op_entry["opdef"] = kernel_emitter.emit(op, batched=batched)
@@ -134,141 +129,12 @@ def create_gemm_operator(
return ret
-def generate_tensor_op_common(
- math_instructions, alignment_constraints, get_tile_descriptions, batched=False
-):
- """Common kernel generator to be used by archtecture specific generators."""
- ops = []
- layouts = [
- (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
- ]
- for math_inst in math_instructions:
- tile_descriptions = get_tile_descriptions(math_inst)
- data_type = [
- math_inst.element_a,
- math_inst.element_b,
- math_inst.element_accumulator,
- math_inst.element_accumulator,
- ]
-
- out = create_gemm_operator(
- layouts, tile_descriptions, data_type, alignment_constraints, batched=batched
- )
-
- ops.extend(out)
-
- return ops
-
-
-def generate_sm75_tensor_op_1688(out_dtype, batched=False):
- """Generate GEMM kernels for Turing."""
- assert out_dtype in ["float32", "float16"]
- math_instructions = {
- "float32": [
- MathInstruction(
- [16, 8, 8],
- DataType.f16,
- DataType.f16,
- DataType.f32,
- OpcodeClass.TensorOp,
- MathOperation.multiply_add,
- )
- ],
- "float16": [
- MathInstruction(
- [16, 8, 8],
- DataType.f16,
- DataType.f16,
- DataType.f16,
- OpcodeClass.TensorOp,
- MathOperation.multiply_add,
- )
- ],
- }[out_dtype]
-
- alignment_constraints = [8, 4, 2, 1]
-
- def get_tile_descriptions(math_inst):
- min_cc = 75
- max_cc = 1024
- return [
- TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
- TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc),
- ]
-
- return generate_tensor_op_common(
- math_instructions, alignment_constraints, get_tile_descriptions, batched
- )
-
-
-def generate_sm80_tensor_op_16816(out_dtype, batched=False):
- """Generate GEMM kernels for Ampere."""
- assert out_dtype in ["float32", "float16"]
- math_instructions = {
- "float32": [
- MathInstruction(
- [16, 8, 16],
- DataType.f16,
- DataType.f16,
- DataType.f32,
- OpcodeClass.TensorOp,
- MathOperation.multiply_add,
- )
- ],
- "float16": [
- MathInstruction(
- [16, 8, 16],
- DataType.f16,
- DataType.f16,
- DataType.f16,
- OpcodeClass.TensorOp,
- MathOperation.multiply_add,
- )
- ],
- }[out_dtype]
-
- alignment_constraints = [8, 4, 2]
-
- def get_tile_descriptions(math_inst):
- min_cc = 80
- max_cc = 1024
- max_cc_smem_limited = 80
- return [
- TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc),
- TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc),
- TileDescription([64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc),
- TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc_smem_limited),
- TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc_smem_limited),
- TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc_smem_limited),
- TileDescription([64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc_smem_limited),
- TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
- TileDescription([64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
- ]
-
- return generate_tensor_op_common(
- math_instructions, alignment_constraints, get_tile_descriptions, batched
- )
-
-
GENERATOR_FUNC_TABLE = {
75: generate_sm75_tensor_op_1688,
80: generate_sm80_tensor_op_16816,
}
+
# TODO(masahi): A sensible way to pick reasonable default kernels
DEFAULT_KERNELS = {
75: {
@@ -282,67 +148,7 @@ DEFAULT_KERNELS = {
}
-class ProfilerEngine:
- """Compile and run a given profiler executable."""
-
- def __init__(self, cuda_arch, cutlass_path, binary_prefix):
- self.cuda_arch = cuda_arch
- self.binary_prefix = binary_prefix
- self.cutlass = cutlass_path
- self.cflags = "-I{cutlass}/include -I{cutlass}/tools/util/include -O3 -std=c++11".format(
- cutlass=cutlass_path
- )
- self.cflags += " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1"
- self.cflags += " -gencode=arch=compute_{arch},code=[sm_{arch},compute_{arch}]".format(
- arch=cuda_arch
- )
- self.cflags += " -Xcompiler=-Wconversion -Xcompiler=-fno-strict-aliasing"
- self.cmd = "nvcc {cflags} {src} -o {output}"
-
- def _compile(self, op):
- os.makedirs(self.binary_prefix, exist_ok=True)
- opath = os.path.join(self.binary_prefix, op["name"])
- if os.path.exists(opath):
- return
- fi = tempfile.NamedTemporaryFile("w", delete=False, suffix=".cu")
- fi.write(op["src"])
- fi.close()
- cmd = self.cmd.format(cflags=self.cflags, src=fi.name, output=opath)
- os.system(cmd)
- os.unlink(fi.name)
-
- def compile_all(self, ops, use_multiprocessing=False):
- """Compile all profiler executables."""
- if use_multiprocessing:
- pool = multiprocessing.Pool(multiprocessing.cpu_count())
- pool.map(self._compile, ops)
- else:
- for op in ops:
- self._compile(op)
-
- def evaluate(self, op, args):
- """Run the profiler executable corresponding to op_name with args."""
- op_name = op["name"]
- opath = os.path.join(self.binary_prefix, op_name)
- if not os.path.exists(opath):
- self._compile(op)
- cmd = [opath]
- if args is not None:
- cmd.append(str(args[0]))
- cmd.append(str(args[1]))
- cmd.append(str(args[2]))
- if len(args) > 3:
- cmd.append(str(args[3]))
- try:
- sp = subprocess.run(cmd, capture_output=True, check=True)
- rt = float(sp.stdout)
- logger.info("%s, %f", op_name, rt)
- except subprocess.CalledProcessError:
- rt = -1
- return rt
-
-
-class CutlassGemmProfiler(object):
+class CutlassGemmProfiler:
"""Profile all candidate kernels and select the best one."""
def __init__(self, sm, cutlass_path, binary_path):
@@ -364,7 +170,9 @@ class CutlassGemmProfiler(object):
"""Return the default kernel for the requested architecture.
For now, the default kernel was picked arbitrary.
"""
- ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
+ ops = GENERATOR_FUNC_TABLE[self.sm](
+ out_dtype, op_creator=partial(create_gemm_operator, batched=batched)
+ )
default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
assert len(filtered) == 1
@@ -380,7 +188,9 @@ class CutlassGemmProfiler(object):
if (M, N, K) in self.cache:
return self.cache[(M, N, K)]
- ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
+ ops = GENERATOR_FUNC_TABLE[self.sm](
+ out_dtype, op_creator=partial(create_gemm_operator, batched=batched)
+ )
ops = list(filter(lambda op: self.check_align(op["name"], M), ops))
for op in ops:
diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
similarity index 52%
copy from python/tvm/contrib/cutlass/gen_gemm.py
copy to python/tvm/contrib/cutlass/gen_tensor_op.py
index 1ed4bfe..c822151 100644
--- a/python/tvm/contrib/cutlass/gen_gemm.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -15,21 +15,13 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
-"""Kernel generator and profiler for CUTLASS."""
+"""Common functions and classes for CUTLASS GEMM and Conv2d geneator."""
import logging
import os
-import re
import tempfile
import subprocess
import multiprocessing
-from .gemm_operation import GemmOperation, EmitGemmInstance
-from .gemm_profiler import GemmProfilerEmitter
from .library import (
- EpilogueFunctor,
- SwizzlingFunctor,
- TensorDescription,
- DataTypeTag,
- LayoutType,
MathInstruction,
DataType,
OpcodeClass,
@@ -40,108 +32,11 @@ from .library import (
logger = logging.getLogger("cutlass")
-def create_gemm_operator(
- layouts,
- tile_descriptions,
- data_type,
- alignment_constraints,
- epilogue_functor=EpilogueFunctor.LinearCombination,
- swizzling_functor=SwizzlingFunctor.Identity8,
- batched=False,
-):
- """Exhaustively instantiate all kernels from a given configuration."""
- ret = []
- kernel_emitter = EmitGemmInstance()
- profiler_emitter = GemmProfilerEmitter()
-
- element_a, element_b, element_c, element_epilogue = data_type
-
- if batched:
- swizzling_functor = SwizzlingFunctor.Batched
-
- for layout in layouts:
- for tile_description in tile_descriptions:
- for alignment in alignment_constraints:
- alignment_c = min(8, alignment)
-
- A = TensorDescription(element_a, layout[0], alignment)
- B = TensorDescription(element_b, layout[1], alignment)
- C = TensorDescription(element_c, layout[2], alignment_c)
-
- op_entry = {}
- op = GemmOperation(
- tile_description.minimum_compute_capability,
- tile_description,
- A,
- B,
- C,
- element_epilogue,
- epilogue_functor,
- swizzling_functor,
- )
- op_bias = GemmOperation(
- tile_description.minimum_compute_capability,
- tile_description,
- A,
- B,
- C,
- element_epilogue,
- EpilogueFunctor.LinearCombinationBias,
- swizzling_functor,
- )
- op_bias_relu = GemmOperation(
- tile_description.minimum_compute_capability,
- tile_description,
- A,
- B,
- C,
- element_epilogue,
- EpilogueFunctor.LinearCombinationRelu,
- swizzling_functor,
- )
- op_bias_gelu = GemmOperation(
- tile_description.minimum_compute_capability,
- tile_description,
- A,
- B,
- C,
- element_epilogue,
- EpilogueFunctor.LinearCombinationGelu,
- swizzling_functor,
- )
-
- kernel_emitter = EmitGemmInstance()
- op_entry["op"] = op
- op_entry["name"] = op.procedural_name()
- op_entry["opdef"] = kernel_emitter.emit(op, batched=batched)
- op_entry["opdef_bias"] = kernel_emitter.emit(
- op_bias, no_beta_scaling=True, batched=batched
- )
- op_entry["opdef_bias_relu"] = kernel_emitter.emit(
- op_bias_relu, no_beta_scaling=True, batched=batched
- )
- op_entry["opdef_bias_gelu"] = kernel_emitter.emit(op_bias_gelu, batched=batched)
- op_entry["src"] = profiler_emitter.emit(
- op.procedural_name(),
- kernel_emitter.emit(op, batched=False),
- DataTypeTag[element_a],
- DataTypeTag[element_b],
- DataTypeTag[element_c],
- op.leading_dim(),
- )
- op_entry["runtime"] = 9999999
- ret.append(op_entry)
- return ret
-
-
def generate_tensor_op_common(
- math_instructions, alignment_constraints, get_tile_descriptions, batched=False
+ math_instructions, alignment_constraints, get_tile_descriptions, op_creator
):
"""Common kernel generator to be used by archtecture specific generators."""
ops = []
- layouts = [
- (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),
- ]
for math_inst in math_instructions:
tile_descriptions = get_tile_descriptions(math_inst)
data_type = [
@@ -151,17 +46,15 @@ def generate_tensor_op_common(
math_inst.element_accumulator,
]
- out = create_gemm_operator(
- layouts, tile_descriptions, data_type, alignment_constraints, batched=batched
- )
+ out = op_creator(tile_descriptions, data_type, alignment_constraints)
ops.extend(out)
return ops
-def generate_sm75_tensor_op_1688(out_dtype, batched=False):
- """Generate GEMM kernels for Turing."""
+def generate_sm75_tensor_op_1688(out_dtype, op_creator):
+ """Generate GEMM or Conv2D kernels for Turing."""
assert out_dtype in ["float32", "float16"]
math_instructions = {
"float32": [
@@ -202,12 +95,12 @@ def generate_sm75_tensor_op_1688(out_dtype, batched=False):
]
return generate_tensor_op_common(
- math_instructions, alignment_constraints, get_tile_descriptions, batched
+ math_instructions, alignment_constraints, get_tile_descriptions, op_creator
)
-def generate_sm80_tensor_op_16816(out_dtype, batched=False):
- """Generate GEMM kernels for Ampere."""
+def generate_sm80_tensor_op_16816(out_dtype, op_creator):
+ """Generate GEMM or Conv2D kernels for Ampere."""
assert out_dtype in ["float32", "float16"]
math_instructions = {
"float32": [
@@ -260,28 +153,10 @@ def generate_sm80_tensor_op_16816(out_dtype, batched=False):
]
return generate_tensor_op_common(
- math_instructions, alignment_constraints, get_tile_descriptions, batched
+ math_instructions, alignment_constraints, get_tile_descriptions, op_creator
)
-GENERATOR_FUNC_TABLE = {
- 75: generate_sm75_tensor_op_1688,
- 80: generate_sm80_tensor_op_16816,
-}
-
-# TODO(masahi): A sensible way to pick reasonable default kernels
-DEFAULT_KERNELS = {
- 75: {
- "float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align4",
- "float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align4",
- },
- 80: {
- "float16": "cutlass_tensorop_h16816gemm_128x256_32x3_tn_align4",
- "float32": "cutlass_tensorop_s16816gemm_f16_128x128_32x3_tn_align4",
- },
-}
-
-
class ProfilerEngine:
"""Compile and run a given profiler executable."""
@@ -340,62 +215,3 @@ class ProfilerEngine:
except subprocess.CalledProcessError:
rt = -1
return rt
-
-
-class CutlassGemmProfiler(object):
- """Profile all candidate kernels and select the best one."""
-
- def __init__(self, sm, cutlass_path, binary_path):
- assert sm in GENERATOR_FUNC_TABLE and sm in DEFAULT_KERNELS, "sm%d not supported yet." % sm
- self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
- self.sm = sm
- self.cache = {}
-
- def check_align(self, op_name, M):
- """Filter out kernels that cannot be supported."""
- aligns = re.findall(r"align[1|2|4|8]", op_name)
- assert len(aligns) == 1
- align = int(aligns[0][-1])
- if M % align != 0:
- return False
- return True
-
- def get_default(self, out_dtype, batched=False):
- """Return the default kernel for the requested architecture.
- For now, the default kernel was picked arbitrarily.
- """
- ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
- default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
- filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
- assert len(filtered) == 1
- return filtered[0]
-
- def profile(
- self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False, batched=False
- ):
- """Profile and select the best kernel from candidate kernels.
- If profile_all is False, return immediately after the first applicable kernel is found.
- If use_multiprocessing is True, compile all profiler executables in parallel.
- """
- if (M, N, K) in self.cache:
- return self.cache[(M, N, K)]
-
- ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
- ops = list(filter(lambda op: self.check_align(op["name"], M), ops))
-
- for op in ops:
- op["runtime"] = -1
-
- if profile_all:
- self.engine.compile_all(ops, use_multiprocessing)
-
- for op in ops:
- out = self.engine.evaluate(op, [M, N, K])
- op["runtime"] = out
- if out > 0 and profile_all is False:
- break
-
- valid_ops = filter(lambda op: op["runtime"] > 0, ops)
- output = sorted(valid_ops, key=lambda i: i["runtime"])
- self.cache[(M, N, K)] = output[0]
- return output[0]
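The CutlassGemmProfiler class removed above now lives only in gen_gemm.py, and its caller-facing API is unchanged. A hypothetical invocation for context (the paths below are placeholders, and compiling the profiler executables requires nvcc plus a CUTLASS checkout):

    from tvm.contrib.cutlass.gen_gemm import CutlassGemmProfiler

    # Placeholders: point these at a real CUTLASS checkout and a scratch dir.
    profiler = CutlassGemmProfiler(80, "/path/to/cutlass", "/tmp/cutlass_bin")

    default = profiler.get_default("float16")             # arbitrary per-arch default kernel
    best = profiler.profile(1024, 1024, 1024, "float16")  # compile, run, pick the fastest
    print(best["name"], best["runtime"])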
diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py
index 5a1ff8b..6f27d57 100644
--- a/tests/python/contrib/test_cutlass.py
+++ b/tests/python/contrib/test_cutlass.py
@@ -213,7 +213,7 @@ def verify_batch_matmul(
np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol)
- if True:
+ if run_benchmark:
print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600))
print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600))