You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/07/30 23:35:31 UTC
[tvm] branch main updated: [ROOFLINE] Add CUDA support to roofline analysis (#12205)
This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 961a7c70d7 [ROOFLINE] Add CUDA support to roofline analysis (#12205)
961a7c70d7 is described below
commit 961a7c70d75c81503c8c1d7c2e0db66bac4a1859
Author: Tristan Konolige <tk...@octoml.ai>
AuthorDate: Sat Jul 30 16:35:25 2022 -0700
[ROOFLINE] Add CUDA support to roofline analysis (#12205)
* [ROOFLINE] Add CUDA support to roofline analysis
Add functions to estimate peak flops and bandwidth for CUDA. Add a new
registration mechanism to the roofline analysis to support adding any
target. This mechanism uses generic functions with overrides. New
targets only need to add `estimate_peak_bandwidth` and
`estimate_peak_flops` functions.
Also fix CUDA codegen and tensorcore_infer_fragment.cc to support
filling matrix_a and matrix_b fragments.
* formatting
* move statement back inside loops
* print out report for debugging
* default to avx2
* review comments
---
python/tvm/utils/__init__.py | 2 +-
.../utils/{roofline.py => roofline/__init__.py} | 266 +++------------------
python/tvm/utils/roofline/cuda.py | 236 ++++++++++++++++++
python/tvm/utils/roofline/registry.py | 83 +++++++
python/tvm/utils/roofline/x86.py | 254 ++++++++++++++++++++
src/target/source/codegen_cuda.cc | 2 +
src/tir/ir/specialize.cc | 1 +
src/tir/transforms/tensorcore_infer_fragment.cc | 15 +-
tests/python/unittest/test_roofline.py | 121 ++++++++++
tests/python/unittest/test_runtime_profiling.py | 98 --------
10 files changed, 736 insertions(+), 342 deletions(-)
diff --git a/python/tvm/utils/__init__.py b/python/tvm/utils/__init__.py
index 3c1703c244..33abc352b0 100644
--- a/python/tvm/utils/__init__.py
+++ b/python/tvm/utils/__init__.py
@@ -16,4 +16,4 @@
# under the License.
"""Utilities operating at a graph/model or other "high" level"""
-from .roofline import estimate_peak_bandwidth, estimate_peak_fma_flops, roofline_analysis
+from .roofline import roofline_analysis
diff --git a/python/tvm/utils/roofline.py b/python/tvm/utils/roofline/__init__.py
similarity index 51%
rename from python/tvm/utils/roofline.py
rename to python/tvm/utils/roofline/__init__.py
index 7323149193..a54f5ed41d 100644
--- a/python/tvm/utils/roofline.py
+++ b/python/tvm/utils/roofline/__init__.py
@@ -18,15 +18,17 @@
from typing import Dict, Union, Optional
import numpy as np
-from .. import auto_scheduler, relay, tir, nd, IRModule, build, topi, transform, get_global_func
-from ..target import Target
-from ..runtime import profiler_vm, profiling, Device, num_threads
-from ..script import tir as T
-from ..ir.instrument import pass_instrument
-from ..ir.expr import GlobalVar
-from ..rpc.base import RPC_SESS_MASK
-from ..rpc.client import RPCSession
-from ..contrib import utils
+from ... import auto_scheduler, relay, tir, nd, IRModule, build, topi, transform, get_global_func
+from ...target import Target
+from ...runtime import profiler_vm, profiling, Device, num_threads
+from ...script import tir as T
+from ...ir.instrument import pass_instrument
+from ...ir.expr import GlobalVar
+from ...rpc.base import RPC_SESS_MASK
+from ...rpc.client import RPCSession
+from ...contrib import utils
+
+from . import registry, cuda, x86
def _create_args(mod: IRModule, dev: Device, func_name: str = "main", remote=None):
@@ -47,231 +49,6 @@ def _create_args(mod: IRModule, dev: Device, func_name: str = "main", remote=Non
return args
-def _detect_vec_width_registers(
- target: Target, vec_width: Optional[int], num_vector_registers: Optional[int]
-):
- """Get the vector width and number of vector registers for a target.
-
- Parameters
- ----------
- target : Target
- Target to detect vector width and registers for.
- vec_width : Optional[int]
- If None, try and detect vector width from target. Otherwise provided input is used.
- num_vector_registers : Optional[int]
- If None, try and number of vector registers from target. Otherwise provided input is used.
-
- Returns
- -------
- vec_width: int
- Width of a vector register on `target`.
- num_vector_registers: int
- Number of vector registers on `target`.
- """
- if vec_width is None:
- # Only implemented for x86 so far...
- if (
- str(target.kind) == "llvm"
- and target.device_name == ""
- and len(target.keys) == 1
- and target.keys[0] == "cpu"
- ):
- with target:
- vec_width = topi.x86.utils.get_simd_32bit_lanes() # in number of float32s
- else:
- raise RuntimeError(f"Cannot determine vector width for target {target}")
- if num_vector_registers is None:
- if target.device_name == "": # indicates x86
- num_vector_registers = 16 # Assuming for all platforms, probably wrong on older ones
- else:
- raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
- return vec_width, num_vector_registers
-
-
-@T.prim_func
-def peakflops_fma_tir(
- a: T.handle,
- vec_width: T.int32,
- iters: T.int32,
- num_vector_registers: T.int32,
- threads: T.int32,
-) -> None:
- # pylint: disable=invalid-name, missing-function-docstring
- A = T.match_buffer(a, [threads, num_vector_registers, vec_width], "float32")
- for t in T.parallel(threads):
- for _j in range(iters):
- for l in T.unroll(num_vector_registers):
- # We want to use as few registers as possible, so we perform
- # all operations on the same element
- for k in T.vectorized(vec_width):
- A[t, l, k] = A[t, l, k] * A[t, l, k] + A[t, l, k]
-
-
-def estimate_peak_fma_flops(
- target: Target,
- dev: Device,
- vec_width: Optional[int] = None,
- num_vector_registers: Optional[int] = None,
- remote: Optional[RPCSession] = None,
-) -> float:
- """
- Estimate the maximum number of FLOP/s this target/device combo is capable
- of reaching by running a test program. This assumes vectorized f32 FMA
- (fused-multiply-add) instructions.
-
-
- Parameters
- ----------
- target : Target
- Target to run on. This should be as specific to the actual hardware as
- possible to make sure that LLVM generates the best vector code.
- dev : Device
- Device to run on.
- vec_width : Optional[int]
- Vector width of SIMD units on the underlying hardware. Will try to
- infer if no value is provided.
- num_vector_registers : Optional[int]
- Number of vector registers on the underlying hardware. Will try to
- infer if no value is provided.
- remote : Optional[RPCSession]
- Remote session used to upload artifacts for runtime evaluation. Must be
- the same session used to create `dev`.
-
- Returns
- -------
- float
- Approximate sustained FLOP/s of this target/device combo assuming
- vectorized f32 FMA instructions.
- """
- assert str(target.kind) == "llvm", "Only llvm targets are supported"
- vec_width, num_vector_registers = _detect_vec_width_registers(
- target, vec_width, num_vector_registers
- )
- iters = 1000000
- nthreads = num_threads()
- specialized = peakflops_fma_tir.specialize(
- {
- peakflops_fma_tir.params[1]: vec_width,
- peakflops_fma_tir.params[2]: iters,
- peakflops_fma_tir.params[3]: num_vector_registers,
- peakflops_fma_tir.params[4]: nthreads,
- }
- )
- with transform.PassContext(opt_level=3):
- f = build(specialized, target=target)
-
- # upload to remote if running over rpc
- if dev.device_type >= RPC_SESS_MASK:
- if remote is None:
- raise RuntimeError("A RPCSession must be provided when using a remote device.")
- temp = utils.tempdir()
- path = temp.relpath("peak_fma_flops.tar")
- f.export_library(path)
- remote.upload(path)
- f = remote.load_module("peak_fma_flops.tar")
- random_fill = remote.get_function("tvm.contrib.random.random_fill")
- else:
- random_fill = get_global_func("tvm.contrib.random.random_fill")
- assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"
-
- a = nd.empty((nthreads, num_vector_registers, vec_width), dtype="float32", device=dev)
- random_fill(a)
- times = f.time_evaluator(f.entry_name, dev, repeat=100, number=1)(a)
- flops = 2 * vec_width * num_vector_registers * nthreads * iters # fma is two flops
- flop_s = flops / times.min
- return flop_s
-
-
-@T.prim_func
-def peak_bandwidth_tir(a: T.handle, b: T.handle, threads: T.int32, vec_width: T.int32) -> None:
- # pylint: disable=invalid-name, missing-function-docstring
- N = T.var("int32")
- A = T.match_buffer(a, [threads, N, 4, vec_width], "float32")
- B = T.match_buffer(b, [threads, vec_width, 4], "float32")
- # Parallelism is necessary to hit all cores/nodes
- for i in T.parallel(threads):
- for k in T.serial(N):
- for l in T.unroll(4):
- # vectorized load is necessary to hit peak bandwidth
- for j in T.vectorized(vec_width):
- # += is necessary to introduce a data dependency for all
- # elements of A, preventing the backend from removing the
- # `k` loop and setting `k` to the loop extent.
- B[i, l, j] += A[i, k, l, j]
-
-
-def estimate_peak_bandwidth(
- target: Target,
- dev: Device,
- vec_width: Optional[int] = None,
- remote: Optional[RPCSession] = None,
-) -> float:
- """Estimate peak memory bandwidth of a target/device combo.
-
- Peak bandwidth is estimated by running a small experiment on the underlying
- hardware. The peak bandwidth measurement assumes that vector instructions
- are being used to load the data.
-
- Parameters
- ----------
- target : Target
- Target to use for measurement. This target should be as specific to the
- underlying hardware as possible.
- dev : Device
- Device to measure peak bandwidth on.
- vec_width : Optional[int]
- Vector unit width, determined from target if not supplied.
- remote : Optional[RPCSession]
- Remote session used to upload artifacts for runtime evaluation. Must be
- the same session used to create `dev`.
-
- Returns
- -------
- float
- Peak memory bandwidth in bytes/seconds.
- """
- # Ideally we'd be able to use this code to measure peak bandwidth of the
- # different cache levels. If we could just generate load commands, then we
- # could use those in a tight loop. Instead we need some code that is
- # limited on the cache bandwidth. With the L1 cache we need an operation
- # that has a very low arithmetic intensity and we haven't come up with one
- # yet.
- vec_width, _ = _detect_vec_width_registers(target, vec_width, 1)
- specialized = peak_bandwidth_tir.specialize(
- {
- peak_bandwidth_tir.params[3]: vec_width,
- }
- )
- with transform.PassContext(opt_level=3):
- f = build(specialized, target=target)
-
- # upload to remote if running over rpc
- if dev.device_type >= RPC_SESS_MASK:
- if remote is None:
- raise RuntimeError("A RPCSession must be provided when using a remote device.")
- temp = utils.tempdir()
- path = temp.relpath("peak_bandwidth.tar")
- f.export_library(path)
- remote.upload(path)
- f = remote.load_module("peak_bandwidth.tar")
- random_fill = remote.get_function("tvm.contrib.random.random_fill")
- else:
- random_fill = get_global_func("tvm.contrib.random.random_fill")
- assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"
-
- threads = num_threads()
- # Data size needs to be larger than last level of cache. We don't have a
- # way of getting cache sizes, so this number should give us a large enough
- # size.
- size = 10**8 // (4 * threads * vec_width)
- a = nd.empty((threads, size, 4, vec_width), dtype="float32", device=dev)
- random_fill(a)
- b = nd.empty((threads, vec_width, 4), dtype="float32", device=dev)
- random_fill(b)
- times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(a, b, threads)
- return a.numpy().size * 4 / times.min # 4 bytes per float32
-
-
@pass_instrument
class SaveLoweredTIR:
"""Save TIR functions from right before final lowering. Right now this
@@ -357,8 +134,9 @@ def roofline_from_existing(
:py:func:`roofline_analysis` for more information on which metrics
are included.
"""
- peak_bandwidth = estimate_peak_bandwidth(target, dev, remote=remote)
- peak_flops = estimate_peak_fma_flops(target, dev, remote=remote)
+ with target:
+ peak_bandwidth = registry.estimate_peak_bandwidth(target, dev, remote)
+ peak_flops = registry.estimate_peak_flops(target, dev, remote)
ridge_point = peak_flops / peak_bandwidth
@@ -377,7 +155,19 @@ def roofline_from_existing(
loaded_bytes = 0.0
# assume no more than 100 buffers
for i in range(100):
- key = f"B{i}.bytes"
+ if str(target.kind) == "cuda":
+ # autoscheduler features do not take into account that 1.
+ # global and shared memory have very different performance
+ # characteristics -- both are included in the same bytes
+ # touched count 2. multiple threads accessing the same byte
+ # of memory does not use the same amount of bandwidth as
+ # multiple threads accessing different bytes of memory. We
+ # use unique bytes accessed here to avoid these two issues,
+ # but this does bias results towards being more compute
+ # bound.
+ key = f"B{i}.unique_bytes"
+ else:
+ key = f"B{i}.bytes"
if not key in features.keys():
break
loaded_bytes += np.sum(features[key])
@@ -401,7 +191,7 @@ def roofline_from_existing(
else:
new_calls.append(call)
new_configuration = dict(report.configuration.items())
- new_configuration["Estimated Peak FMA FLOP/s"] = profiling.Ratio(peak_flops)
+ new_configuration["Estimated Peak FLOP/s"] = profiling.Ratio(peak_flops)
new_configuration["Estimated Peak Bandwidth (byte/second)"] = profiling.Ratio(peak_bandwidth)
return profiling.Report(new_calls, report.device_metrics, new_configuration)
diff --git a/python/tvm/utils/roofline/cuda.py b/python/tvm/utils/roofline/cuda.py
new file mode 100644
index 0000000000..f5a3f5e1dd
--- /dev/null
+++ b/python/tvm/utils/roofline/cuda.py
@@ -0,0 +1,236 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Estimation of peak flops and memory bandwidth for cuda devices"""
+from typing import Optional
+from ...script import tir as T
+from ... import nd, build, transform
+from ...runtime import Device
+from ...target import Target
+from ...rpc.base import RPC_SESS_MASK
+from ...rpc.client import RPCSession
+from . import registry
+from ...contrib import utils, nvcc
+
+
+@registry.estimate_peak_flops.register("cuda")
+def estimate_peak_flops_tensorcore(
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession],
+    mat_dtype: str = "float16",
+    acc_dtype: str = "float32",
+) -> float:
+    """Estimate the peak FLOP/s of a cuda device with tensorcores.
+
+    This estimate should only be used to compare with operators that can use
+    dense tensorcore mma instructions.
+
+    The benchmark launches one thread block per SM
+    (``dev.multi_processor_count``) with 8 warps each. Every warp loads a
+    16x16 matrix_a fragment once and zero-fills the matrix_b and accumulator
+    fragments. It then issues ``n`` back-to-back ``tvm_mma_sync`` calls that
+    accumulate into the same fragment, and stores the result once at the end.
+    Memory traffic is therefore limited to one load and one store per warp,
+    outside the timed loop.
+
+    References
+    ----------
+    Wei Sun, Ang Li, Tong Geng, Sander Stuijk, Henk Corporaal: "Dissecting
+    Tensor Cores via Microbenchmarks: Latency, Throughput and Numerical
+    Behaviors", 2022; http://arxiv.org/abs/2206.02874
+    https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.1.pdf
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible.
+    dev : Device
+        Device to run on.
+    remote : Optional[RPCSession]
+        Remote session used to upload artifacts for runtime evaluation. Must be
+        the same session used to create `dev`.
+    mat_dtype : str
+        Dtype of matrices passed to mma instructions.
+    acc_dtype : str
+        Dtype of accumulator to use with mma instructions. Should be compatible
+        with `mat_dtype`.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        mma instructions. Addition and multiplications are each counted as
+        separate FLOPs.
+
+    Raises
+    ------
+    AssertionError
+        If `target` is not a cuda target.
+    RuntimeError
+        If `dev` is remote and no `remote` session is provided.
+    """
+    assert str(target.kind) == "cuda", "Only CUDA devices have tensorcores"
+
+    # Defined inside the function so the buffer annotations can close over
+    # `mat_dtype` and `acc_dtype`. `n` and `sms` are fixed via `specialize`
+    # below before building.
+    @T.prim_func
+    def peak_flops_tensorcore_tir(
+        inp: T.Buffer((16, 16), mat_dtype),
+        out: T.Buffer((16, 16), acc_dtype),
+        n: T.int32,
+        sms: T.int32,
+    ):
+        # pylint: disable=invalid-name, missing-function-docstring
+        A = T.alloc_buffer((16, 16), dtype=mat_dtype, scope="wmma.matrix_a")
+        B = T.alloc_buffer((16, 16), dtype=mat_dtype, scope="wmma.matrix_b")
+        C = T.alloc_buffer((16, 16), dtype=acc_dtype, scope="wmma.accumulator")
+        for _ in T.thread_binding(sms, thread="blockIdx.x"):
+            for _ in T.thread_binding(
+                8, thread="threadIdx.y"
+            ):  # need 8 warps to get enough in-SM parallelism
+                for _ in T.thread_binding(32, thread="threadIdx.x"):
+                    # Load the 16x16 `inp` tile (row stride 16) into the
+                    # matrix_a fragment once per warp.
+                    # NOTE(review): the access_ptr extent is 16 while the load
+                    # reads a full 16x16 tile; confirm whether 256 is intended.
+                    T.evaluate(
+                        T.tvm_load_matrix_sync(
+                            A.data,
+                            16,
+                            16,
+                            16,
+                            0,
+                            T.tvm_access_ptr(
+                                T.type_annotation(dtype=mat_dtype),
+                                inp.data,
+                                0,
+                                16,
+                                1,
+                                dtype="handle",
+                            ),
+                            16,
+                            "row_major",
+                            dtype="handle",
+                        )
+                    )
+                    # B and C are zero-filled; their values do not affect
+                    # throughput.
+                    T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, 0, dtype="handle"))
+                    T.evaluate(T.tvm_fill_fragment(C.data, 16, 16, 16, 0, 0, dtype="handle"))
+                    # C is both the input and the output accumulator, so
+                    # successive mma calls within a warp depend on each
+                    # other and cannot be removed.
+                    for _ in range(n):
+                        T.evaluate(
+                            T.tvm_mma_sync(
+                                C.data, 0, A.data, 0, B.data, 0, C.data, 0, dtype="handle"
+                            )
+                        )
+                    # Store the accumulator once so the result is observable.
+                    T.evaluate(
+                        T.tvm_store_matrix_sync(
+                            C.data,
+                            16,
+                            16,
+                            16,
+                            0,
+                            T.tvm_access_ptr(
+                                T.type_annotation(dtype=acc_dtype),
+                                out.data,
+                                0,
+                                16,
+                                2,
+                                dtype="handle",
+                            ),
+                            16,
+                            "row_major",
+                            dtype="handle",
+                        )
+                    )
+
+    n = 100000
+    sms = dev.multi_processor_count
+    specialized = peak_flops_tensorcore_tir.specialize(
+        {peak_flops_tensorcore_tir.params[2]: n, peak_flops_tensorcore_tir.params[3]: sms}
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+
+    # upload to remote if running over rpc
+    if dev.device_type >= RPC_SESS_MASK:
+        if remote is None:
+            raise RuntimeError("A RPCSession must be provided when using a remote device.")
+        temp = utils.tempdir()
+        path = temp.relpath("peak_fma_flops.tar")
+        f.export_library(path)
+        remote.upload(path)
+        f = remote.load_module("peak_fma_flops.tar")
+
+    # Contents are left uninitialized; only the timing is used.
+    x = nd.empty((16, 16), dtype=mat_dtype, device=dev)
+    y = nd.empty((16, 16), dtype=acc_dtype, device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(x, y)
+    # each mma operation performs 16 x 16 x 16 multiply-adds = 2 * 16**3 FLOPs;
+    # n mmas per warp, 8 warps per block, `sms` blocks
+    return n * 16 * 16 * 16 * 2 * sms * 8 / times.min
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    # Memory-bound kernel: block `i` reads every element of its
+    # [N, 4, warp_size] slice of A once and accumulates it into B[i].
+    # `blocks` and `warp_size` are expected to be fixed via `specialize`;
+    # N is taken from the shape of the buffer passed at call time.
+    N = T.var("int32")
+    A = T.match_buffer(a, [blocks, N, 4, warp_size], "float32")
+    B = T.match_buffer(b, [blocks, 4, warp_size], "float32")
+    for i in T.thread_binding(blocks, "blockIdx.x"):
+        for k in T.serial(N):
+            for l in T.unroll(4):
+                # j is bound to threadIdx.x and is the innermost dimension of
+                # A, so adjacent threads read adjacent float32 elements. This
+                # is intended to produce coalesced loads.
+                for j in T.thread_binding(warp_size, "threadIdx.x"):
+                    # += is necessary to introduce a data dependency for all
+                    # elements of A, preventing the backend from removing the
+                    # `k` loop and setting `k` to the loop extent.
+                    B[i, l, j] += A[i, k, l, j]
+
+
+@registry.estimate_peak_bandwidth.register("cuda")
+def estimate_peak_bandwidth(
+ target: Target,
+ dev: Device,
+ remote: Optional[RPCSession] = None,
+) -> float:
+ """Estimate peak memory bandwidth of a target/device combo.
+
+ Peak bandwidth is estimated by running a small experiment on the underlying
+ hardware. The peak bandwidth measurement assumes that vector instructions
+ are being used to load the data.
+
+ Parameters
+ ----------
+ target : Target
+ Target to use for measurement. This target should be as specific to the
+ underlying hardware as possible.
+ dev : Device
+ Device to measure peak bandwidth on.
+ remote : Optional[RPCSession]
+ Remote session used to upload artifacts for runtime evaluation. Must be
+ the same session used to create `dev`.
+
+ Returns
+ -------
+ float
+ Peak memory bandwidth in bytes/seconds.
+ """
+ assert nvcc.have_tensorcore(
+ dev.compute_version
+ ), "CUDA roofline only works with devices that have tensorcores"
+ warp_size = dev.warp_size
+ # These sizes seem large enough to give the card time to hit a fixpoint on memory bandwidth
+ blocks = 1024
+ size = 1024
+
+ specialized = peak_bandwidth_tir.specialize(
+ {peak_bandwidth_tir.params[2]: blocks, peak_bandwidth_tir.params[3]: warp_size}
+ )
+ with transform.PassContext(opt_level=3):
+ f = build(specialized, target=target)
+
+ # upload to remote if running over rpc
+ if dev.device_type >= RPC_SESS_MASK:
+ if remote is None:
+ raise RuntimeError("A RPCSession must be provided when using a remote device.")
+ temp = utils.tempdir()
+ path = temp.relpath("peak_bandwidth.tar")
+ f.export_library(path)
+ remote.upload(path)
+ f = remote.load_module("peak_bandwidth.tar")
+
+ a = nd.empty((blocks, size, 4, warp_size), dtype="float32", device=dev)
+ b = nd.empty((blocks, 4, warp_size), dtype="float32", device=dev)
+ times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(a, b)
+ return a.numpy().size * 4 / times.min # 4 bytes per float32
diff --git a/python/tvm/utils/roofline/registry.py b/python/tvm/utils/roofline/registry.py
new file mode 100644
index 0000000000..b3ea522be8
--- /dev/null
+++ b/python/tvm/utils/roofline/registry.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of generic functions for estimating peak flops and bandwidth"""
+from typing import Optional
+from ...target import Target, generic_func
+from ...runtime import Device
+from ...rpc.client import RPCSession
+
+
+@generic_func
+def estimate_peak_bandwidth(
+ target: Target,
+ dev: Device,
+ remote: Optional[RPCSession] = None,
+) -> float:
+ """Estimate peak memory bandwidth of a target/device combo.
+
+ Peak bandwidth is estimated by running a small experiment on the underlying
+ hardware. The peak bandwidth measurement assumes that vector instructions
+ are being used to load the data.
+
+ Parameters
+ ----------
+ target : Target
+ Target to use for measurement. This target should be as specific to the
+ underlying hardware as possible.
+ dev : Device
+ Device to measure peak bandwidth on.
+ remote : Optional[RPCSession]
+ Remote session used to upload artifacts for runtime evaluation. Must be
+ the same session used to create `dev`.
+
+ Returns
+ -------
+ float
+ Peak memory bandwidth in bytes/seconds.
+ """
+ raise NotImplementedError()
+
+
+@generic_func
+def estimate_peak_flops(
+ target: Target,
+ dev: Device,
+ remote: Optional[RPCSession],
+) -> float:
+ """
+ Estimate the maximum number of FLOP/s this target/device combo is capable
+ of reaching by running a test program. This is a generic function that
+ should be overridden for each target.
+
+ Parameters
+ ----------
+ target : Target
+ Target to run on. This should be as specific to the actual hardware as
+ possible to make sure that LLVM generates the best vector code.
+ dev : Device
+ Device to run on.
+ remote : Optional[RPCSession]
+ Remote session used to upload artifacts for runtime evaluation. Must be
+ the same session used to create `dev`.
+
+ Returns
+ -------
+ float
+ Approximate sustained FLOP/s of this target/device combo. Each FMA
+ operation counts as two FLOPs.
+ """
+ raise NotImplementedError()
diff --git a/python/tvm/utils/roofline/x86.py b/python/tvm/utils/roofline/x86.py
new file mode 100644
index 0000000000..d4a0e51184
--- /dev/null
+++ b/python/tvm/utils/roofline/x86.py
@@ -0,0 +1,254 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Estimate peak flops and bandwidth for x86 devices"""
+from typing import Optional
+
+from ... import nd, build, topi, transform, get_global_func
+from ...target import Target
+from ...runtime import Device, num_threads
+from ...script import tir as T
+from ...rpc.base import RPC_SESS_MASK
+from ...rpc.client import RPCSession
+from ...contrib import utils
+from . import registry
+
+
+def _detect_vec_width_registers(
+ target: Target, vec_width: Optional[int], num_vector_registers: Optional[int]
+):
+ """Get the vector width and number of vector registers for a target.
+
+ Parameters
+ ----------
+ target : Target
+ Target to detect vector width and registers for.
+ vec_width : Optional[int]
+ If None, try and detect vector width from target. Otherwise provided input is used.
+ num_vector_registers : Optional[int]
+ If None, try and number of vector registers from target. Otherwise provided input is used.
+
+ Returns
+ -------
+ vec_width: int
+ Width of a vector register on `target`.
+ num_vector_registers: int
+ Number of vector registers on `target`.
+ """
+ if vec_width is None:
+ # Only implemented for x86 so far...
+ if (
+ str(target.kind) == "llvm"
+ and target.device_name == ""
+ and len(target.keys) == 1
+ and target.keys[0] == "cpu"
+ ):
+ with target:
+ vec_width = topi.x86.utils.get_simd_32bit_lanes() # in number of float32s
+ else:
+ raise RuntimeError(f"Cannot determine vector width for target {target}")
+ if num_vector_registers is None:
+ if target.device_name == "": # indicates x86
+ num_vector_registers = 16 # Assuming for all platforms, probably wrong on older ones
+ else:
+ raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+ return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    # Compute-bound FMA microbenchmark. Each of `threads` parallel workers
+    # updates its own [num_vector_registers, vec_width] tile in place with
+    # x = x * x + x, `iters` times. The scalar parameters are expected to be
+    # fixed via `specialize` before building.
+    A = T.match_buffer(a, [threads, num_vector_registers, vec_width], "float32")
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            # Rows `l` never read each other, so the unrolled body presumably
+            # gives the CPU independent FMA chains to overlap.
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t, l, k] = A[t, l, k] * A[t, l, k] + A[t, l, k]
+
+
+@registry.estimate_peak_flops.register("cpu")
+def estimate_peak_fma_flops(
+ target: Target,
+ dev: Device,
+ remote: Optional[RPCSession],
+ vec_width: Optional[int] = None,
+ num_vector_registers: Optional[int] = None,
+) -> float:
+ """
+ Estimate the maximum number of FLOP/s this target/device combo is capable
+ of reaching by running a test program. This assumes vectorized f32 FMA
+ (fused-multiply-add) instructions.
+
+
+ Parameters
+ ----------
+ target : Target
+ Target to run on. This should be as specific to the actual hardware as
+ possible to make sure that LLVM generates the best vector code.
+ dev : Device
+ Device to run on.
+ remote : Optional[RPCSession]
+ Remote session used to upload artifacts for runtime evaluation. Must be
+ the same session used to create `dev`.
+ vec_width : Optional[int]
+ Vector width of SIMD units on the underlying hardware. Will try to
+ infer if no value is provided.
+ num_vector_registers : Optional[int]
+ Number of vector registers on the underlying hardware. Will try to
+ infer if no value is provided.
+
+ Returns
+ -------
+ float
+ Approximate sustained FLOP/s of this target/device combo assuming
+ vectorized f32 FMA instructions. Each FMA operation counts as two FLOPs.
+ """
+ assert str(target.kind) == "llvm", "Only llvm targets are supported"
+ vec_width, num_vector_registers = _detect_vec_width_registers(
+ target, vec_width, num_vector_registers
+ )
+ iters = 1000000
+ nthreads = num_threads()
+ specialized = peakflops_fma_tir.specialize(
+ {
+ peakflops_fma_tir.params[1]: vec_width,
+ peakflops_fma_tir.params[2]: iters,
+ peakflops_fma_tir.params[3]: num_vector_registers,
+ peakflops_fma_tir.params[4]: nthreads,
+ }
+ )
+ with transform.PassContext(opt_level=3):
+ f = build(specialized, target=target)
+
+ # upload to remote if running over rpc
+ if dev.device_type >= RPC_SESS_MASK:
+ if remote is None:
+ raise RuntimeError("A RPCSession must be provided when using a remote device.")
+ temp = utils.tempdir()
+ path = temp.relpath("peak_fma_flops.tar")
+ f.export_library(path)
+ remote.upload(path)
+ f = remote.load_module("peak_fma_flops.tar")
+ random_fill = remote.get_function("tvm.contrib.random.random_fill")
+ else:
+ random_fill = get_global_func("tvm.contrib.random.random_fill")
+ assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"
+
+ a = nd.empty((nthreads, num_vector_registers, vec_width), dtype="float32", device=dev)
+ random_fill(a)
+ times = f.time_evaluator(f.entry_name, dev, repeat=100, number=1)(a)
+ flops = 2 * vec_width * num_vector_registers * nthreads * iters # fma is two flops
+ flop_s = flops / times.min
+ return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, threads: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    # Memory-bound kernel: worker `i` reads every element of its
+    # [N, 4, vec_width] slice of A once and accumulates it into B[i].
+    # N is taken from the shape of the buffer passed at call time.
+    N = T.var("int32")
+    A = T.match_buffer(a, [threads, N, 4, vec_width], "float32")
+    B = T.match_buffer(b, [threads, 4, vec_width], "float32")
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(threads):
+        for k in T.serial(N):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    # += is necessary to introduce a data dependency for all
+                    # elements of A, preventing the backend from removing the
+                    # `k` loop and setting `k` to the loop extent.
+                    B[i, l, j] += A[i, k, l, j]
+
+
+@registry.estimate_peak_bandwidth.register("cpu")
+def estimate_peak_bandwidth(
+ target: Target,
+ dev: Device,
+ remote: Optional[RPCSession],
+ vec_width: Optional[int] = None,
+) -> float:
+ """Estimate peak memory bandwidth of a target/device combo.
+
+ Peak bandwidth is estimated by running a small experiment on the underlying
+ hardware. The peak bandwidth measurement assumes that vector instructions
+ are being used to load the data.
+
+ Parameters
+ ----------
+ target : Target
+ Target to use for measurement. This target should be as specific to the
+ underlying hardware as possible.
+ dev : Device
+ Device to measure peak bandwidth on.
+ remote : Optional[RPCSession]
+ Remote session used to upload artifacts for runtime evaluation. Must be
+ the same session used to create `dev`.
+ vec_width : Optional[int]
+ Vector unit width, determined from target if not supplied.
+
+ Returns
+ -------
+ float
+ Peak memory bandwidth in bytes/seconds.
+ """
+ # Ideally we'd be able to use this code to measure peak bandwidth of the
+ # different cache levels. If we could just generate load commands, then we
+ # could use those in a tight loop. Instead we need some code that is
+ # limited on the cache bandwidth. With the L1 cache we need an operation
+ # that has a very low arithmetic intensity and we haven't come up with one
+ # yet.
+ vec_width, _ = _detect_vec_width_registers(target, vec_width, 1)
+ specialized = peak_bandwidth_tir.specialize(
+ {
+ peak_bandwidth_tir.params[3]: vec_width,
+ }
+ )
+ with transform.PassContext(opt_level=3):
+ f = build(specialized, target=target)
+
+ # upload to remote if running over rpc
+ if dev.device_type >= RPC_SESS_MASK:
+ if remote is None:
+ raise RuntimeError("A RPCSession must be provided when using a remote device.")
+ temp = utils.tempdir()
+ path = temp.relpath("peak_bandwidth.tar")
+ f.export_library(path)
+ remote.upload(path)
+ f = remote.load_module("peak_bandwidth.tar")
+ random_fill = remote.get_function("tvm.contrib.random.random_fill")
+ else:
+ random_fill = get_global_func("tvm.contrib.random.random_fill")
+ assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"
+
+ threads = num_threads()
+ # Data size needs to be larger than last level of cache. We don't have a
+ # way of getting cache sizes, so this number should give us a large enough
+ # size.
+ size = 10**8 // (4 * threads * vec_width)
+ a = nd.empty((threads, size, 4, vec_width), dtype="float32", device=dev)
+ random_fill(a)
+ b = nd.empty((threads, 4, vec_width), dtype="float32", device=dev)
+ random_fill(b)
+ times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(a, b, threads)
+ return a.numpy().size * 4 / times.min # 4 bytes per float32
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 3ea6f8d9ed..dde1d112ed 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -1235,11 +1235,13 @@ void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const Var
if (scope == "wmma.matrix_a") {
need_mma_h_ = true;
std::string layout_str = fragment_layouts[variable];
+ ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a";
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", " << type.str()
<< ", nvcuda::wmma::" << layout_str << ">";
} else if (scope == "wmma.matrix_b") {
need_mma_h_ = true;
std::string layout_str = fragment_layouts[variable];
+ ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b";
os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", " << type.str()
<< ", nvcuda::wmma::" << layout_str << ">";
} else if (scope == "wmma.accumulator") {
diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc
index 1e5b2f28b2..520e3ee03c 100644
--- a/src/tir/ir/specialize.cc
+++ b/src/tir/ir/specialize.cc
@@ -363,6 +363,7 @@ PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map) {
} else if (instance->IsInstance<PrimExprNode>()) {
UpdateSpecializeVarMap(func, param, Downcast<PrimExpr>(instance), &var_map);
} else {
+ CHECK(instance.defined()) << "Specialize instance is not defined for param " << param;
LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or PrimExpr, but got "
<< instance->GetTypeKey();
}
diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc
index b050193076..e0ae7172ad 100644
--- a/src/tir/transforms/tensorcore_infer_fragment.cc
+++ b/src/tir/transforms/tensorcore_infer_fragment.cc
@@ -92,15 +92,14 @@ class FragmentGetter : public StmtExprVisitor {
ICHECK(k);
std::string scope = GetPtrStorageScope(GetRef<Var>(buffer_var));
- // Only wmma.accumulator can use tvm_fill_fragment
- ICHECK_EQ(scope, "wmma.accumulator");
if (fragments.count(buffer_var)) {
FragmentInfo info = fragments[buffer_var];
ICHECK_EQ(m->value, info.m);
ICHECK_EQ(n->value, info.n);
ICHECK_EQ(k->value, info.k);
} else {
- FragmentInfo info(m->value, n->value, k->value, "", scope);
+ // default to row major ordering
+ FragmentInfo info(m->value, n->value, k->value, "row_major", scope);
fragments[buffer_var] = info;
}
}
@@ -148,8 +147,14 @@ class FragmentChecker : public StmtExprVisitor {
private:
// A tool for checking shapes of two fragments
bool CheckShape(const VarNode* buffer1, const VarNode* buffer2) {
- ICHECK(fragment_getter.fragments.count(buffer1));
- ICHECK(fragment_getter.fragments.count(buffer2));
+ CHECK(fragment_getter.fragments.count(buffer1))
+ << "Tensorecore fragment " << buffer1->name_hint
+ << " must be filled (with tvm_fill_fragment) or loaded (with tvm_load_matrix_sync) before "
+ "use.";
+ CHECK(fragment_getter.fragments.count(buffer2))
+ << "Tensorecore fragment " << buffer2->name_hint
+ << " must be filled (with tvm_fill_fragment) or loaded (with tvm_load_matrix_sync) before "
+ "use.";
FragmentInfo info1 = fragment_getter.fragments.at(buffer1);
FragmentInfo info2 = fragment_getter.fragments.at(buffer2);
return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k;
diff --git a/tests/python/unittest/test_roofline.py b/tests/python/unittest/test_roofline.py
new file mode 100644
index 0000000000..e37f6e085b
--- /dev/null
+++ b/tests/python/unittest/test_roofline.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+from io import StringIO
+import csv
+import os
+import json
+import platform
+
+import tvm.testing
+import tvm.utils
+from tvm.runtime import profiler_vm
+from tvm import relay
+from tvm.relay.testing import mlp
+from tvm.contrib.debugger import debug_executor
+from tvm import rpc
+from tvm.contrib import utils
+from tvm.runtime.profiling import Report
+from tvm.script import tir as T
+
+
+@tvm.testing.parametrize_targets("llvm", "cuda")
+def test_estimate_peak_flops(target, dev):
+ server = rpc.Server(key="roofline_flops")
+ remote = rpc.connect("127.0.0.1", server.port, key="roofline_flops")
+ dev = remote.device(target)
+ # This test uses vectorized instructions so we need a target that supports them
+ if target == "llvm":
+ target = "llvm -mattr=+fma,+avx2"
+ target = tvm.target.Target(target)
+ with target:
+ flops = tvm.utils.roofline.registry.estimate_peak_flops(target, dev, remote)
+ if str(target.kind) == "llvm":
+ # Assume we can achieve 1 GFLOP/s per thread, which is 1 FLOP per cycle on a 1GHz cpu.
+ assert (
+ flops > 10**9 and flops < 10**14
+ ), f"FLOP/s should be between 10^9 and 10^14, but it is {flops}"
+ elif str(target.kind) == "cuda":
+ # should be able to hit a TFLOP/s with tensor cores
+ assert (
+ flops > 10**12 and flops < 10**14
+ ), f"FLOP/s should be between 10^12 and 10^14, but it is {flops}"
+ else:
+ raise RuntimeError("Unsupported target " + str(target))
+
+
+@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
+@tvm.testing.parametrize_targets("llvm", "cuda")
+def test_estimate_peak_bandwidth(target, dev):
+ server = rpc.Server(key="roofline_bandwidth")
+ remote = rpc.connect("127.0.0.1", server.port, key="roofline_bandwidth")
+ dev = remote.device(target)
+ # This test uses vectorized instructions so we need a target that supports them
+ if target == "llvm":
+ target = "llvm -mattr=+fma,+avx2"
+ target = tvm.target.Target(target)
+ with target:
+ bandwidth = tvm.utils.roofline.registry.estimate_peak_bandwidth(target, dev, remote)
+ if str(target.kind) == "llvm":
+ # Assume we can achieve 1 GB/s. DDR2 should transfer somewhere around 6
+ # GB/s, so this should leave enough wiggle room.
+ assert (
+ bandwidth > 10**9 and bandwidth < 10**12
+ ), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
+ elif str(target.kind) == "cuda":
+ # should be able to hit a 100 GB/s on a GPU. GTX 280 hits 140 GB/s and
+ # it is really old.
+ assert (
+ bandwidth > 10**11 and bandwidth < 10**13
+ ), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
+ else:
+ raise RuntimeError("Unsupported target " + str(target))
+
+
+@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
+@tvm.testing.parametrize_targets("llvm -mattr=+fma+avx2", "cuda")
+def test_roofline_analysis(target, dev):
+ a = relay.var("a", relay.TensorType((512, 512), "float32"))
+ b = relay.var("b", relay.TensorType((512, 512), "float32"))
+ c = relay.nn.dense(a, b)
+ mod = tvm.IRModule.from_expr(relay.Function([a, b], c))
+ params = {}
+
+ server = rpc.Server(key="roofline")
+ remote = rpc.connect("127.0.0.1", server.port, key="roofline")
+ dev = remote.device(target)
+
+ report = tvm.utils.roofline_analysis(mod, params, target, dev, remote=remote)
+ print(report)
+
+ assert "Bound" in report.table()
+ assert "Percent of Theoretical Optimal" in report.table()
+ for call in report.calls:
+ if "Percent of Theoretical Optimal" in call:
+ if target.startswith("llvm"):
+ # Ideally we'd like a little tighter bound here, but it is hard to
+ # know how well this dense will perform without tuning. And we
+ # don't have an operator that uses a specific number of flops.
+ assert call["Percent of Theoretical Optimal"].ratio >= 5.0
+ elif target == "cuda":
+ # The cuda gpu kernel is really poorly optimized
+ assert 90 >= call["Percent of Theoretical Optimal"].ratio >= 0.01
+
+
+if __name__ == "__main__":  # allow running this test file directly as a script
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_runtime_profiling.py b/tests/python/unittest/test_runtime_profiling.py
index adb5dee174..ed1841b216 100644
--- a/tests/python/unittest/test_runtime_profiling.py
+++ b/tests/python/unittest/test_runtime_profiling.py
@@ -263,103 +263,5 @@ def test_profile_function(target, dev):
assert report[metric].value > 0
-@tvm.testing.parametrize_targets("llvm")
-def test_estimate_peak_fma_flops(target, dev):
- # This test uses vectorized instructions so we need a target that supports them
- if target == "llvm":
- target = "llvm -mattr=+fma,+avx2"
- flops = tvm.utils.estimate_peak_fma_flops(tvm.target.Target(target), dev)
- # Assume we can achieve 1 GFLOP/s per thread, which is 1 FLOP per cycle on a 1GHz cpu.
- assert (
- flops > 10**9 and flops < 10**14
- ), f"FLOP/s should be between 10^9 and 10^14, but it is {flops}"
-
-
-def test_estimate_peak_fma_flops_rpc():
- target = "llvm -mattr=+fma,+avx2"
- server = rpc.Server(key="profiling")
- remote = rpc.connect("127.0.0.1", server.port, key="profiling")
- dev = remote.device(target)
- flops = tvm.utils.estimate_peak_fma_flops(tvm.target.Target(target), dev, remote=remote)
- # Assume we can achieve 1 GFLOP/s per thread, which is 1 FLOP per cycle on a 1GHz cpu.
- assert (
- flops > 10**9 and flops < 10**14
- ), f"FLOP/s should be between 10^9 and 10^14, but it is {flops}"
-
-
-@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
-@tvm.testing.parametrize_targets("llvm")
-def test_estimate_peak_bandwidth(target, dev):
- # This test uses vectorized instructions so we need a target that supports them
- if target == "llvm":
- target = "llvm -mattr=+fma,+avx2"
- bandwidth = tvm.utils.estimate_peak_bandwidth(tvm.target.Target(target), dev)
- # Assume we can achieve 1 GB/s. DDR2 should transfer somewhere around 6
- # GB/s, so this should leave enough wiggle room.
- assert (
- bandwidth > 10**9 and bandwidth < 10**12
- ), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
-
-
-@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
-def test_estimate_peak_bandwidth_rpc():
- target = "llvm -mattr=+fma,+avx2"
- server = rpc.Server(key="profiling")
- remote = rpc.connect("127.0.0.1", server.port, key="profiling")
- dev = remote.device(target)
- bandwidth = tvm.utils.estimate_peak_bandwidth(tvm.target.Target(target), dev, remote=remote)
- # Assume we can achieve 1 GB/s. DDR2 should transfer somewhere around 6
- # GB/s, so this should leave enough wiggle room.
- assert (
- bandwidth > 10**9 and bandwidth < 10**12
- ), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
-
-
-@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
-@tvm.testing.parametrize_targets("llvm")
-def test_roofline_analysis(target, dev):
- a = relay.var("a", relay.TensorType((512, 512), "float32"))
- b = relay.var("b", relay.TensorType((512, 512), "float32"))
- c = relay.nn.dense(a, b)
- mod = tvm.IRModule.from_expr(relay.Function([a, b], c))
- params = {}
- report = tvm.utils.roofline_analysis(mod, params, target, dev)
-
- assert "Bound" in report.table()
- assert "Percent of Theoretical Optimal" in report.table()
- for call in report.calls:
- if "Percent of Theoretical Optimal" in call:
- # Ideally we'd like a little tighter bound here, but it is hard to
- # know how well this dense will perform without tuning. And we
- # don't have an operator that uses a specific number of flops.
- assert call["Percent of Theoretical Optimal"].ratio >= 5.0
-
-
-@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
-def test_roofline_analysis_rpc():
- target = "llvm"
-
- a = relay.var("a", relay.TensorType((512, 512), "float32"))
- b = relay.var("b", relay.TensorType((512, 512), "float32"))
- c = relay.nn.dense(a, b)
- mod = tvm.IRModule.from_expr(relay.Function([a, b], c))
- params = {}
-
- server = rpc.Server(key="profiling")
- remote = rpc.connect("127.0.0.1", server.port, key="profiling")
- dev = remote.device(target)
-
- report = tvm.utils.roofline_analysis(mod, params, target, dev, remote=remote)
-
- assert "Bound" in report.table()
- assert "Percent of Theoretical Optimal" in report.table()
- for call in report.calls:
- if "Percent of Theoretical Optimal" in call:
- # Ideally we'd like a little tighter bound here, but it is hard to
- # know how well this dense will perform without tuning. And we
- # don't have an operator that uses a specific number of flops.
- assert call["Percent of Theoretical Optimal"].ratio >= 5.0
-
-
if __name__ == "__main__":
tvm.testing.main()