Posted to commits@tvm.apache.org by an...@apache.org on 2022/07/30 23:35:31 UTC

[tvm] branch main updated: [ROOFLINE] Add CUDA support to roofline analysis (#12205)

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 961a7c70d7 [ROOFLINE] Add CUDA support to roofline analysis (#12205)
961a7c70d7 is described below

commit 961a7c70d75c81503c8c1d7c2e0db66bac4a1859
Author: Tristan Konolige <tk...@octoml.ai>
AuthorDate: Sat Jul 30 16:35:25 2022 -0700

    [ROOFLINE] Add CUDA support to roofline analysis (#12205)
    
    * [ROOFLINE] Add CUDA support to roofline analysis
    
    Add functions to estimate peak flops and bandwidth for CUDA. Add a new
    registration mechanism to the roofline analysis to support adding any
    target. This mechanism uses generic functions with overrides. New
    targets only need to add `estimate_peak_bandwidth` and
    `estimate_peak_flops` functions.
    
    Also fix cuda codegen and tensorcore_infer_fragment.cc to support
    filling matrix_a and matrix_b fragments.
    
    * formatting
    
    * move statement back inside loops
    
    * print out report for debugging
    
    * default to avx2
    
    * review comments
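    
    To illustrate the new registration mechanism (a minimal sketch, not
    part of this commit): a hypothetical new target kind, say "rocm",
    would only need to register overrides for the two generic functions
    defined in python/tvm/utils/roofline/registry.py. The benchmark
    bodies below are placeholders.
    
        from typing import Optional
    
        from tvm.runtime import Device
        from tvm.rpc.client import RPCSession
        from tvm.target import Target
        from tvm.utils.roofline import registry
    
        @registry.estimate_peak_flops.register("rocm")
        def estimate_peak_flops(
            target: Target, dev: Device, remote: Optional[RPCSession]
        ) -> float:
            # run a target-specific microbenchmark, return FLOP/s
            raise NotImplementedError
    
        @registry.estimate_peak_bandwidth.register("rocm")
        def estimate_peak_bandwidth(
            target: Target, dev: Device, remote: Optional[RPCSession] = None
        ) -> float:
            # run a target-specific streaming benchmark, return bytes/second
            raise NotImplementedError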
---
 python/tvm/utils/__init__.py                       |   2 +-
 .../utils/{roofline.py => roofline/__init__.py}    | 266 +++------------------
 python/tvm/utils/roofline/cuda.py                  | 236 ++++++++++++++++++
 python/tvm/utils/roofline/registry.py              |  83 +++++++
 python/tvm/utils/roofline/x86.py                   | 254 ++++++++++++++++++++
 src/target/source/codegen_cuda.cc                  |   2 +
 src/tir/ir/specialize.cc                           |   1 +
 src/tir/transforms/tensorcore_infer_fragment.cc    |  15 +-
 tests/python/unittest/test_roofline.py             | 121 ++++++++++
 tests/python/unittest/test_runtime_profiling.py    |  98 --------
 10 files changed, 736 insertions(+), 342 deletions(-)

diff --git a/python/tvm/utils/__init__.py b/python/tvm/utils/__init__.py
index 3c1703c244..33abc352b0 100644
--- a/python/tvm/utils/__init__.py
+++ b/python/tvm/utils/__init__.py
@@ -16,4 +16,4 @@
 # under the License.
 """Utilities operating at a graph/model or other "high" level"""
 
-from .roofline import estimate_peak_bandwidth, estimate_peak_fma_flops, roofline_analysis
+from .roofline import roofline_analysis
diff --git a/python/tvm/utils/roofline.py b/python/tvm/utils/roofline/__init__.py
similarity index 51%
rename from python/tvm/utils/roofline.py
rename to python/tvm/utils/roofline/__init__.py
index 7323149193..a54f5ed41d 100644
--- a/python/tvm/utils/roofline.py
+++ b/python/tvm/utils/roofline/__init__.py
@@ -18,15 +18,17 @@
 from typing import Dict, Union, Optional
 import numpy as np
 
-from .. import auto_scheduler, relay, tir, nd, IRModule, build, topi, transform, get_global_func
-from ..target import Target
-from ..runtime import profiler_vm, profiling, Device, num_threads
-from ..script import tir as T
-from ..ir.instrument import pass_instrument
-from ..ir.expr import GlobalVar
-from ..rpc.base import RPC_SESS_MASK
-from ..rpc.client import RPCSession
-from ..contrib import utils
+from ... import auto_scheduler, relay, tir, nd, IRModule, build, topi, transform, get_global_func
+from ...target import Target
+from ...runtime import profiler_vm, profiling, Device, num_threads
+from ...script import tir as T
+from ...ir.instrument import pass_instrument
+from ...ir.expr import GlobalVar
+from ...rpc.base import RPC_SESS_MASK
+from ...rpc.client import RPCSession
+from ...contrib import utils
+
+from . import registry, cuda, x86
 
 
 def _create_args(mod: IRModule, dev: Device, func_name: str = "main", remote=None):
@@ -47,231 +49,6 @@ def _create_args(mod: IRModule, dev: Device, func_name: str = "main", remote=Non
     return args
 
 
-def _detect_vec_width_registers(
-    target: Target, vec_width: Optional[int], num_vector_registers: Optional[int]
-):
-    """Get the vector width and number of vector registers for a target.
-
-    Parameters
-    ----------
-    target : Target
-        Target to detect vector width and registers for.
-    vec_width : Optional[int]
-        If None, try and detect vector width from target. Otherwise provided input is used.
-    num_vector_registers : Optional[int]
-        If None, try and number of vector registers from target. Otherwise provided input is used.
-
-    Returns
-    -------
-    vec_width: int
-        Width of a vector register on `target`.
-    num_vector_registers: int
-        Number of vector registers on `target`.
-    """
-    if vec_width is None:
-        # Only implemented for x86 so far...
-        if (
-            str(target.kind) == "llvm"
-            and target.device_name == ""
-            and len(target.keys) == 1
-            and target.keys[0] == "cpu"
-        ):
-            with target:
-                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
-        else:
-            raise RuntimeError(f"Cannot determine vector width for target {target}")
-    if num_vector_registers is None:
-        if target.device_name == "":  # indicates x86
-            num_vector_registers = 16  # Assuming for all platforms, probably wrong on older ones
-        else:
-            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
-    return vec_width, num_vector_registers
-
-
-@T.prim_func
-def peakflops_fma_tir(
-    a: T.handle,
-    vec_width: T.int32,
-    iters: T.int32,
-    num_vector_registers: T.int32,
-    threads: T.int32,
-) -> None:
-    # pylint: disable=invalid-name, missing-function-docstring
-    A = T.match_buffer(a, [threads, num_vector_registers, vec_width], "float32")
-    for t in T.parallel(threads):
-        for _j in range(iters):
-            for l in T.unroll(num_vector_registers):
-                # We want to use as few registers as possible, so we perform
-                # all operations on the same element
-                for k in T.vectorized(vec_width):
-                    A[t, l, k] = A[t, l, k] * A[t, l, k] + A[t, l, k]
-
-
-def estimate_peak_fma_flops(
-    target: Target,
-    dev: Device,
-    vec_width: Optional[int] = None,
-    num_vector_registers: Optional[int] = None,
-    remote: Optional[RPCSession] = None,
-) -> float:
-    """
-    Estimate the maximum number of FLOP/s this target/device combo is capable
-    of reaching by running a test program. This assumes vectorized f32 FMA
-    (fused-multiply-add) instructions.
-
-
-    Parameters
-    ----------
-    target : Target
-        Target to run on. This should be as specific to the actual hardware as
-        possible to make sure that LLVM generates the best vector code.
-    dev : Device
-        Device to run on.
-    vec_width : Optional[int]
-        Vector width of SIMD units on the underlying hardware. Will try to
-        infer if no value is provided.
-    num_vector_registers : Optional[int]
-        Number of vector registers on the underlying hardware. Will try to
-        infer if no value is provided.
-    remote : Optional[RPCSession]
-      Remote session used to upload artifacts for runtime evaluation. Must be
-      the same session used to create `dev`.
-
-    Returns
-    -------
-    float
-        Approximate sustained FLOP/s of this target/device combo assuming
-        vectorized f32 FMA instructions.
-    """
-    assert str(target.kind) == "llvm", "Only llvm targets are supported"
-    vec_width, num_vector_registers = _detect_vec_width_registers(
-        target, vec_width, num_vector_registers
-    )
-    iters = 1000000
-    nthreads = num_threads()
-    specialized = peakflops_fma_tir.specialize(
-        {
-            peakflops_fma_tir.params[1]: vec_width,
-            peakflops_fma_tir.params[2]: iters,
-            peakflops_fma_tir.params[3]: num_vector_registers,
-            peakflops_fma_tir.params[4]: nthreads,
-        }
-    )
-    with transform.PassContext(opt_level=3):
-        f = build(specialized, target=target)
-
-    # upload to remote if running over rpc
-    if dev.device_type >= RPC_SESS_MASK:
-        if remote is None:
-            raise RuntimeError("A RPCSession must be provided when using a remote device.")
-        temp = utils.tempdir()
-        path = temp.relpath("peak_fma_flops.tar")
-        f.export_library(path)
-        remote.upload(path)
-        f = remote.load_module("peak_fma_flops.tar")
-        random_fill = remote.get_function("tvm.contrib.random.random_fill")
-    else:
-        random_fill = get_global_func("tvm.contrib.random.random_fill")
-    assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"
-
-    a = nd.empty((nthreads, num_vector_registers, vec_width), dtype="float32", device=dev)
-    random_fill(a)
-    times = f.time_evaluator(f.entry_name, dev, repeat=100, number=1)(a)
-    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
-    flop_s = flops / times.min
-    return flop_s
-
-
-@T.prim_func
-def peak_bandwidth_tir(a: T.handle, b: T.handle, threads: T.int32, vec_width: T.int32) -> None:
-    # pylint: disable=invalid-name, missing-function-docstring
-    N = T.var("int32")
-    A = T.match_buffer(a, [threads, N, 4, vec_width], "float32")
-    B = T.match_buffer(b, [threads, vec_width, 4], "float32")
-    # Parallelism is necessary to hit all cores/nodes
-    for i in T.parallel(threads):
-        for k in T.serial(N):
-            for l in T.unroll(4):
-                # vectorized load is necessary to hit peak bandwidth
-                for j in T.vectorized(vec_width):
-                    # += is necessary to introduce a data dependency for all
-                    # elements of A, preventing the backend from removing the
-                    # `k` loop and setting `k` to the loop extent.
-                    B[i, l, j] += A[i, k, l, j]
-
-
-def estimate_peak_bandwidth(
-    target: Target,
-    dev: Device,
-    vec_width: Optional[int] = None,
-    remote: Optional[RPCSession] = None,
-) -> float:
-    """Estimate peak memory bandwidth of a target/device combo.
-
-    Peak bandwidth is estimated by running a small experiment on the underlying
-    hardware. The peak bandwidth measurement assumes that vector instructions
-    are being used to load the data.
-
-    Parameters
-    ----------
-    target : Target
-        Target to use for measurement. This target should be as specific to the
-        underlying hardware as possible.
-    dev : Device
-        Device to measure peak bandwidth on.
-    vec_width : Optional[int]
-        Vector unit width, determined from target if not supplied.
-    remote : Optional[RPCSession]
-      Remote session used to upload artifacts for runtime evaluation. Must be
-      the same session used to create `dev`.
-
-    Returns
-    -------
-    float
-        Peak memory bandwidth in bytes/seconds.
-    """
-    # Ideally we'd be able to use this code to measure peak bandwidth of the
-    # different cache levels. If we could just generate load commands, then we
-    # could use those in a tight loop. Instead we need some code that is
-    # limited on the cache bandwidth. With the L1 cache we need an operation
-    # that has a very low arithmetic intensity and we haven't come up with one
-    # yet.
-    vec_width, _ = _detect_vec_width_registers(target, vec_width, 1)
-    specialized = peak_bandwidth_tir.specialize(
-        {
-            peak_bandwidth_tir.params[3]: vec_width,
-        }
-    )
-    with transform.PassContext(opt_level=3):
-        f = build(specialized, target=target)
-
-    # upload to remote if running over rpc
-    if dev.device_type >= RPC_SESS_MASK:
-        if remote is None:
-            raise RuntimeError("A RPCSession must be provided when using a remote device.")
-        temp = utils.tempdir()
-        path = temp.relpath("peak_bandwidth.tar")
-        f.export_library(path)
-        remote.upload(path)
-        f = remote.load_module("peak_bandwidth.tar")
-        random_fill = remote.get_function("tvm.contrib.random.random_fill")
-    else:
-        random_fill = get_global_func("tvm.contrib.random.random_fill")
-    assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"
-
-    threads = num_threads()
-    # Data size needs to be larger than last level of cache. We don't have a
-    # way of getting cache sizes, so this number should give us a large enough
-    # size.
-    size = 10**8 // (4 * threads * vec_width)
-    a = nd.empty((threads, size, 4, vec_width), dtype="float32", device=dev)
-    random_fill(a)
-    b = nd.empty((threads, vec_width, 4), dtype="float32", device=dev)
-    random_fill(b)
-    times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(a, b, threads)
-    return a.numpy().size * 4 / times.min  # 4 bytes per float32
-
-
 @pass_instrument
 class SaveLoweredTIR:
     """Save TIR functions from right before final lowering. Right now this
@@ -357,8 +134,9 @@ def roofline_from_existing(
         :py:func:`roofline_analysis` for more information on which metrics
         are included.
     """
-    peak_bandwidth = estimate_peak_bandwidth(target, dev, remote=remote)
-    peak_flops = estimate_peak_fma_flops(target, dev, remote=remote)
+    with target:
+        peak_bandwidth = registry.estimate_peak_bandwidth(target, dev, remote)
+        peak_flops = registry.estimate_peak_flops(target, dev, remote)
 
     ridge_point = peak_flops / peak_bandwidth
 
@@ -377,7 +155,19 @@ def roofline_from_existing(
             loaded_bytes = 0.0
             # assume no more than 100 buffers
             for i in range(100):
-                key = f"B{i}.bytes"
+                if str(target.kind) == "cuda":
+                    # autoscheduler features do not account for two things:
+                    # 1. global and shared memory have very different
+                    #    performance characteristics, yet both are included
+                    #    in the same bytes-touched count, and
+                    # 2. multiple threads accessing the same byte of memory
+                    #    do not use the same amount of bandwidth as threads
+                    #    accessing different bytes. We use unique bytes
+                    #    accessed here to avoid both issues, though this
+                    #    biases results towards being more compute bound.
+                    key = f"B{i}.unique_bytes"
+                else:
+                    key = f"B{i}.bytes"
                 if not key in features.keys():
                     break
                 loaded_bytes += np.sum(features[key])
@@ -401,7 +191,7 @@ def roofline_from_existing(
         else:
             new_calls.append(call)
     new_configuration = dict(report.configuration.items())
-    new_configuration["Estimated Peak FMA FLOP/s"] = profiling.Ratio(peak_flops)
+    new_configuration["Estimated Peak FLOP/s"] = profiling.Ratio(peak_flops)
     new_configuration["Estimated Peak Bandwidth (byte/second)"] = profiling.Ratio(peak_bandwidth)
     return profiling.Report(new_calls, report.device_metrics, new_configuration)
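
For context on the numbers this change attaches to the report, here is a
minimal sketch of the roofline arithmetic (an illustrative helper, not
the report code itself): the ridge point is the arithmetic intensity, in
FLOPs per byte, at which a kernel stops being memory bound.

    def classify(flops: float, loaded_bytes: float,
                 peak_flops: float, peak_bandwidth: float) -> str:
        # ridge point: FLOPs/byte where the compute and memory rooflines meet
        ridge_point = peak_flops / peak_bandwidth
        intensity = flops / loaded_bytes  # the kernel's FLOPs per byte
        if intensity < ridge_point:
            return "memory"   # attainable FLOP/s = peak_bandwidth * intensity
        return "compute"      # attainable FLOP/s = peak_flops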
 
diff --git a/python/tvm/utils/roofline/cuda.py b/python/tvm/utils/roofline/cuda.py
new file mode 100644
index 0000000000..f5a3f5e1dd
--- /dev/null
+++ b/python/tvm/utils/roofline/cuda.py
@@ -0,0 +1,236 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Estimation of peak flops and memory bandwidth for cuda devices"""
+from typing import Optional
+from ...script import tir as T
+from ... import nd, build, transform
+from ...runtime import Device
+from ...target import Target
+from ...rpc.base import RPC_SESS_MASK
+from ...rpc.client import RPCSession
+from . import registry
+from ...contrib import utils, nvcc
+
+
+@registry.estimate_peak_flops.register("cuda")
+def estimate_peak_flops_tensorcore(
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession],
+    mat_dtype: str = "float16",
+    acc_dtype: str = "float32",
+) -> float:
+    """Estimate the peak FLOP/s of a cuda device with tensorcores.
+
+    This estimate should only be used to compare with operators that can use
+    dense tensorcore mma instructions.
+
+    References
+    ----------
+    Wei Sun, Ang Li, Tong Geng, Sander Stuijk, Henk Corporaal: "Dissecting
+    Tensor Cores via Microbenchmarks: Latency, Throughput and Numerical
+    Behaviors", 2022; http://arxiv.org/abs/2206.02874
+    https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.1.pdf
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible.
+    dev : Device
+        Device to run on.
+    remote : Optional[RPCSession]
+        Remote session used to upload artifacts for runtime evaluation. Must
+        be the same session used to create `dev`.
+    mat_dtype : str
+        Dtype of matrices passed to mma instructions.
+    acc_dtype : str
+        Dtype of accumulator to use with mma instructions. Should be compatible
+        with `mat_dtype`.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        mma instructions. Addition and multiplications are each counted as
+        separate FLOPs.
+    """
+    assert str(target.kind) == "cuda", "Only CUDA devices have tensorcores"
+
+    @T.prim_func
+    def peak_flops_tensorcore_tir(
+        inp: T.Buffer((16, 16), mat_dtype),
+        out: T.Buffer((16, 16), acc_dtype),
+        n: T.int32,
+        sms: T.int32,
+    ):
+        # pylint: disable=invalid-name, missing-function-docstring
+        A = T.alloc_buffer((16, 16), dtype=mat_dtype, scope="wmma.matrix_a")
+        B = T.alloc_buffer((16, 16), dtype=mat_dtype, scope="wmma.matrix_b")
+        C = T.alloc_buffer((16, 16), dtype=acc_dtype, scope="wmma.accumulator")
+        for _ in T.thread_binding(sms, thread="blockIdx.x"):
+            for _ in T.thread_binding(
+                8, thread="threadIdx.y"
+            ):  # need 8 warps to get enough in-SM parallelism
+                for _ in T.thread_binding(32, thread="threadIdx.x"):
+                    T.evaluate(
+                        T.tvm_load_matrix_sync(
+                            A.data,
+                            16,
+                            16,
+                            16,
+                            0,
+                            T.tvm_access_ptr(
+                                T.type_annotation(dtype=mat_dtype),
+                                inp.data,
+                                0,
+                                16,
+                                1,
+                                dtype="handle",
+                            ),
+                            16,
+                            "row_major",
+                            dtype="handle",
+                        )
+                    )
+                    T.evaluate(T.tvm_fill_fragment(B.data, 16, 16, 16, 0, 0, dtype="handle"))
+                    T.evaluate(T.tvm_fill_fragment(C.data, 16, 16, 16, 0, 0, dtype="handle"))
+                    for _ in range(n):
+                        T.evaluate(
+                            T.tvm_mma_sync(
+                                C.data, 0, A.data, 0, B.data, 0, C.data, 0, dtype="handle"
+                            )
+                        )
+                    T.evaluate(
+                        T.tvm_store_matrix_sync(
+                            C.data,
+                            16,
+                            16,
+                            16,
+                            0,
+                            T.tvm_access_ptr(
+                                T.type_annotation(dtype=acc_dtype),
+                                out.data,
+                                0,
+                                16,
+                                2,
+                                dtype="handle",
+                            ),
+                            16,
+                            "row_major",
+                            dtype="handle",
+                        )
+                    )
+
+    n = 100000
+    sms = dev.multi_processor_count
+    specialized = peak_flops_tensorcore_tir.specialize(
+        {peak_flops_tensorcore_tir.params[2]: n, peak_flops_tensorcore_tir.params[3]: sms}
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+
+    # upload to remote if running over rpc
+    if dev.device_type >= RPC_SESS_MASK:
+        if remote is None:
+            raise RuntimeError("A RPCSession must be provided when using a remote device.")
+        temp = utils.tempdir()
+        path = temp.relpath("peak_fma_flops.tar")
+        f.export_library(path)
+        remote.upload(path)
+        f = remote.load_module("peak_fma_flops.tar")
+
+    x = nd.empty((16, 16), dtype=mat_dtype, device=dev)
+    y = nd.empty((16, 16), dtype=acc_dtype, device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(x, y)
+    # each mma performs 16 x 16 x 16 multiply-accumulates, counted as 2 * 16^3 FLOPs
+    return n * 16 * 16 * 16 * 2 * sms * 8 / times.min
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [blocks, N, 4, warp_size], "float32")
+    B = T.match_buffer(b, [blocks, 4, warp_size], "float32")
+    for i in T.thread_binding(blocks, "blockIdx.x"):
+        for k in T.serial(N):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.thread_binding(warp_size, "threadIdx.x"):
+                    # += is necessary to introduce a data dependency for all
+                    # elements of A, preventing the backend from removing the
+                    # `k` loop and setting `k` to the loop extent.
+                    B[i, l, j] += A[i, k, l, j]
+
+
+@registry.estimate_peak_bandwidth.register("cuda")
+def estimate_peak_bandwidth(
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession] = None,
+) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    remote : Optional[RPCSession]
+        Remote session used to upload artifacts for runtime evaluation. Must
+        be the same session used to create `dev`.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/second.
+    """
+    assert nvcc.have_tensorcore(
+        dev.compute_version
+    ), "CUDA roofline only works with devices that have tensorcores"
+    warp_size = dev.warp_size
+    # These sizes seem large enough to give the card time to hit a fixpoint on memory bandwidth
+    blocks = 1024
+    size = 1024
+
+    specialized = peak_bandwidth_tir.specialize(
+        {peak_bandwidth_tir.params[2]: blocks, peak_bandwidth_tir.params[3]: warp_size}
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+
+    # upload to remote if running over rpc
+    if dev.device_type >= RPC_SESS_MASK:
+        if remote is None:
+            raise RuntimeError("An RPCSession must be provided when using a remote device.")
+        temp = utils.tempdir()
+        path = temp.relpath("peak_bandwidth.tar")
+        f.export_library(path)
+        remote.upload(path)
+        f = remote.load_module("peak_bandwidth.tar")
+
+    a = nd.empty((blocks, size, 4, warp_size), dtype="float32", device=dev)
+    b = nd.empty((blocks, 4, warp_size), dtype="float32", device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(a, b)
+    return a.numpy().size * 4 / times.min  # 4 bytes per float32
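
A back-of-the-envelope restatement of the FLOP count returned by
estimate_peak_flops_tensorcore above (the SM count below is an
illustrative assumption): each 16x16x16 mma performs 16*16*16
multiply-accumulates, counted as two FLOPs each, and every one of the
`sms` blocks runs 8 warps that each issue `n` mma operations.

    def tensorcore_flops(n: int, sms: int, warps_per_sm: int = 8) -> int:
        flops_per_mma = 16 * 16 * 16 * 2  # 8192 FLOPs per mma
        return n * sms * warps_per_sm * flops_per_mma

    # e.g. n=100000 with 80 SMs (a V100-like count, purely illustrative):
    # 100000 * 80 * 8 * 8192 ~= 5.2e11 FLOPs per timed run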
diff --git a/python/tvm/utils/roofline/registry.py b/python/tvm/utils/roofline/registry.py
new file mode 100644
index 0000000000..b3ea522be8
--- /dev/null
+++ b/python/tvm/utils/roofline/registry.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of generic functions for estimating peak flops and bandwidth"""
+from typing import Optional
+from ...target import Target, generic_func
+from ...runtime import Device
+from ...rpc.client import RPCSession
+
+
+@generic_func
+def estimate_peak_bandwidth(
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession] = None,
+) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    remote : Optional[RPCSession]
+        Remote session used to upload artifacts for runtime evaluation. Must
+        be the same session used to create `dev`.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/second.
+    """
+    raise NotImplementedError()
+
+
+@generic_func
+def estimate_peak_flops(
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession],
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This is a generic function that
+    should be overridden for each target.
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    remote : Optional[RPCSession]
+        Remote session used to upload artifacts for runtime evaluation. Must
+        be the same session used to create `dev`.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo. Each FMA
+        operation counts as two FLOPs.
+    """
+    raise NotImplementedError()
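
On the caller side, these generic functions dispatch on the keys of the
current target, so they must be called inside the target's context, as
roofline_from_existing and the new tests do. A minimal local-CPU sketch
(no RPC session):

    import tvm
    from tvm.utils.roofline import registry

    target = tvm.target.Target("llvm -mattr=+fma,+avx2")  # keys include "cpu"
    dev = tvm.cpu()
    with target:
        # resolves to the "cpu" override registered in x86.py
        flops = registry.estimate_peak_flops(target, dev, None)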
diff --git a/python/tvm/utils/roofline/x86.py b/python/tvm/utils/roofline/x86.py
new file mode 100644
index 0000000000..d4a0e51184
--- /dev/null
+++ b/python/tvm/utils/roofline/x86.py
@@ -0,0 +1,254 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Estimate peak flops and bandwidth for x86 devices"""
+from typing import Optional
+
+from ... import nd, build, topi, transform, get_global_func
+from ...target import Target
+from ...runtime import Device, num_threads
+from ...script import tir as T
+from ...rpc.base import RPC_SESS_MASK
+from ...rpc.client import RPCSession
+from ...contrib import utils
+from . import registry
+
+
+def _detect_vec_width_registers(
+    target: Target, vec_width: Optional[int], num_vector_registers: Optional[int]
+):
+    """Get the vector width and number of vector registers for a target.
+
+    Parameters
+    ----------
+    target : Target
+        Target to detect vector width and registers for.
+    vec_width : Optional[int]
+        If None, try to detect the vector width from the target. Otherwise the provided value is used.
+    num_vector_registers : Optional[int]
+        If None, try to detect the number of registers from the target. Otherwise the provided value is used.
+
+    Returns
+    -------
+    vec_width: int
+        Width of a vector register on `target`.
+    num_vector_registers: int
+        Number of vector registers on `target`.
+    """
+    if vec_width is None:
+        # Only implemented for x86 so far...
+        if (
+            str(target.kind) == "llvm"
+            and target.device_name == ""
+            and len(target.keys) == 1
+            and target.keys[0] == "cpu"
+        ):
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            num_vector_registers = 16  # Assuming for all platforms, probably wrong on older ones
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    A = T.match_buffer(a, [threads, num_vector_registers, vec_width], "float32")
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t, l, k] = A[t, l, k] * A[t, l, k] + A[t, l, k]
+
+
+@registry.estimate_peak_flops.register("cpu")
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession],
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    remote : Optional[RPCSession]
+        Remote session used to upload artifacts for runtime evaluation. Must
+        be the same session used to create `dev`.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions. Each FMA operation counts as two FLOPs.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _detect_vec_width_registers(
+        target, vec_width, num_vector_registers
+    )
+    iters = 1000000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+
+    # upload to remote if running over rpc
+    if dev.device_type >= RPC_SESS_MASK:
+        if remote is None:
+            raise RuntimeError("An RPCSession must be provided when using a remote device.")
+        temp = utils.tempdir()
+        path = temp.relpath("peak_fma_flops.tar")
+        f.export_library(path)
+        remote.upload(path)
+        f = remote.load_module("peak_fma_flops.tar")
+        random_fill = remote.get_function("tvm.contrib.random.random_fill")
+    else:
+        random_fill = get_global_func("tvm.contrib.random.random_fill")
+    assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"
+
+    a = nd.empty((nthreads, num_vector_registers, vec_width), dtype="float32", device=dev)
+    random_fill(a)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100, number=1)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, threads: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [threads, N, 4, vec_width], "float32")
+    B = T.match_buffer(b, [threads, 4, vec_width], "float32")
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(threads):
+        for k in T.serial(N):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    # += is necessary to introduce a data dependency for all
+                    # elements of A, preventing the backend from removing the
+                    # `k` loop and setting `k` to the loop extent.
+                    B[i, l, j] += A[i, k, l, j]
+
+
+@registry.estimate_peak_bandwidth.register("cpu")
+def estimate_peak_bandwidth(
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession],
+    vec_width: Optional[int] = None,
+) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    remote : Optional[RPCSession]
+        Remote session used to upload artifacts for runtime evaluation. Must
+        be the same session used to create `dev`.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/second.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited by the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _detect_vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+
+    # upload to remote if running over rpc
+    if dev.device_type >= RPC_SESS_MASK:
+        if remote is None:
+            raise RuntimeError("An RPCSession must be provided when using a remote device.")
+        temp = utils.tempdir()
+        path = temp.relpath("peak_bandwidth.tar")
+        f.export_library(path)
+        remote.upload(path)
+        f = remote.load_module("peak_bandwidth.tar")
+        random_fill = remote.get_function("tvm.contrib.random.random_fill")
+    else:
+        random_fill = get_global_func("tvm.contrib.random.random_fill")
+    assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake"
+
+    threads = num_threads()
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * threads * vec_width)
+    a = nd.empty((threads, size, 4, vec_width), dtype="float32", device=dev)
+    random_fill(a)
+    b = nd.empty((threads, 4, vec_width), dtype="float32", device=dev)
+    random_fill(b)
+    times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(a, b, threads)
+    return a.numpy().size * 4 / times.min  # 4 bytes per float32
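
As a sanity check on the working-set sizing above (restating the code
with illustrative thread and vector-width values): the `size` expression
keeps buffer `a` at roughly 4 * 10**8 bytes, about 400 MB, which should
comfortably exceed any last-level cache.

    threads, vec_width = 8, 8  # illustrative values
    size = 10**8 // (4 * threads * vec_width)  # 390625
    elements = threads * size * 4 * vec_width  # 10**8 float32 elements
    a_bytes = elements * 4                     # 4 * 10**8 bytes (~400 MB)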
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 3ea6f8d9ed..dde1d112ed 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -1235,11 +1235,13 @@ void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const Var
   if (scope == "wmma.matrix_a") {
     need_mma_h_ = true;
     std::string layout_str = fragment_layouts[variable];
+    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_a";
     os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", " << type.str()
        << ", nvcuda::wmma::" << layout_str << ">";
   } else if (scope == "wmma.matrix_b") {
     need_mma_h_ = true;
     std::string layout_str = fragment_layouts[variable];
+    ICHECK_NE(layout_str, "") << "Layout must be defined for matrix_b";
     os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", " << type.str()
        << ", nvcuda::wmma::" << layout_str << ">";
   } else if (scope == "wmma.accumulator") {
diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc
index 1e5b2f28b2..520e3ee03c 100644
--- a/src/tir/ir/specialize.cc
+++ b/src/tir/ir/specialize.cc
@@ -363,6 +363,7 @@ PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map) {
     } else if (instance->IsInstance<PrimExprNode>()) {
       UpdateSpecializeVarMap(func, param, Downcast<PrimExpr>(instance), &var_map);
     } else {
+      CHECK(instance.defined()) << "Specialize instance is not defined for param " << param;
       LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or PrimExpr, but got "
                  << instance->GetTypeKey();
     }
diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc
index b050193076..e0ae7172ad 100644
--- a/src/tir/transforms/tensorcore_infer_fragment.cc
+++ b/src/tir/transforms/tensorcore_infer_fragment.cc
@@ -92,15 +92,14 @@ class FragmentGetter : public StmtExprVisitor {
       ICHECK(k);
 
       std::string scope = GetPtrStorageScope(GetRef<Var>(buffer_var));
-      // Only wmma.accumulator can use tvm_fill_fragment
-      ICHECK_EQ(scope, "wmma.accumulator");
       if (fragments.count(buffer_var)) {
         FragmentInfo info = fragments[buffer_var];
         ICHECK_EQ(m->value, info.m);
         ICHECK_EQ(n->value, info.n);
         ICHECK_EQ(k->value, info.k);
       } else {
-        FragmentInfo info(m->value, n->value, k->value, "", scope);
+        // default to row major ordering
+        FragmentInfo info(m->value, n->value, k->value, "row_major", scope);
         fragments[buffer_var] = info;
       }
     }
@@ -148,8 +147,14 @@ class FragmentChecker : public StmtExprVisitor {
  private:
   // A tool for checking shapes of two fragments
   bool CheckShape(const VarNode* buffer1, const VarNode* buffer2) {
-    ICHECK(fragment_getter.fragments.count(buffer1));
-    ICHECK(fragment_getter.fragments.count(buffer2));
+    CHECK(fragment_getter.fragments.count(buffer1))
+        << "Tensorecore fragment " << buffer1->name_hint
+        << " must be filled (with tvm_fill_fragment) or loaded (with tvm_load_matrix_sync) before "
+           "use.";
+    CHECK(fragment_getter.fragments.count(buffer2))
+        << "Tensorecore fragment " << buffer2->name_hint
+        << " must be filled (with tvm_fill_fragment) or loaded (with tvm_load_matrix_sync) before "
+           "use.";
     FragmentInfo info1 = fragment_getter.fragments.at(buffer1);
     FragmentInfo info2 = fragment_getter.fragments.at(buffer2);
     return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k;
diff --git a/tests/python/unittest/test_roofline.py b/tests/python/unittest/test_roofline.py
new file mode 100644
index 0000000000..e37f6e085b
--- /dev/null
+++ b/tests/python/unittest/test_roofline.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import pytest
+from io import StringIO
+import csv
+import os
+import json
+import platform
+
+import tvm.testing
+import tvm.utils
+from tvm.runtime import profiler_vm
+from tvm import relay
+from tvm.relay.testing import mlp
+from tvm.contrib.debugger import debug_executor
+from tvm import rpc
+from tvm.contrib import utils
+from tvm.runtime.profiling import Report
+from tvm.script import tir as T
+
+
+@tvm.testing.parametrize_targets("llvm", "cuda")
+def test_estimate_peak_flops(target, dev):
+    server = rpc.Server(key="roofline_flops")
+    remote = rpc.connect("127.0.0.1", server.port, key="roofline_flops")
+    dev = remote.device(target)
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"
+    target = tvm.target.Target(target)
+    with target:
+        flops = tvm.utils.roofline.registry.estimate_peak_flops(target, dev, remote)
+    if str(target.kind) == "llvm":
+        # Assume we can achieve 1 GFLOP/s per thread, which is 1 FLOP per cycle on a 1GHz cpu.
+        assert (
+            flops > 10**9 and flops < 10**14
+        ), f"FLOP/s should be between 10^9 and 10^14, but it is {flops}"
+    elif str(target.kind) == "cuda":
+        # should be able to hit a TFLOP/s with tensor cores
+        assert (
+            flops > 10**12 and flops < 10**14
+        ), f"FLOP/s should be between 10^12 and 10^14, but it is {flops}"
+    else:
+        raise RuntimeError("Unsupported target " + str(target))
+
+
+@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
+@tvm.testing.parametrize_targets("llvm", "cuda")
+def test_estimate_peak_bandwidth(target, dev):
+    server = rpc.Server(key="roofline_bandwidth")
+    remote = rpc.connect("127.0.0.1", server.port, key="roofline_bandwidth")
+    dev = remote.device(target)
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"
+    target = tvm.target.Target(target)
+    with target:
+        bandwidth = tvm.utils.roofline.registry.estimate_peak_bandwidth(target, dev, remote)
+    if str(target.kind) == "llvm":
+        # Assume we can achieve 1 GB/s. DDR2 should transfer somewhere around 6
+        # GB/s, so this should leave enough wiggle room.
+        assert (
+            bandwidth > 10**9 and bandwidth < 10**12
+        ), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
+    elif str(target.kind) == "cuda":
+        # should be able to hit 100 GB/s on a GPU. GTX 280 hits 140 GB/s and
+        # it is really old.
+        assert (
+            bandwidth > 10**11 and bandwidth < 10**13
+        ), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
+    else:
+        raise RuntimeError("Unsupported target " + str(target))
+
+
+@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
+@tvm.testing.parametrize_targets("llvm -mattr=+fma+avx2", "cuda")
+def test_roofline_analysis(target, dev):
+    a = relay.var("a", relay.TensorType((512, 512), "float32"))
+    b = relay.var("b", relay.TensorType((512, 512), "float32"))
+    c = relay.nn.dense(a, b)
+    mod = tvm.IRModule.from_expr(relay.Function([a, b], c))
+    params = {}
+
+    server = rpc.Server(key="roofline")
+    remote = rpc.connect("127.0.0.1", server.port, key="roofline")
+    dev = remote.device(target)
+
+    report = tvm.utils.roofline_analysis(mod, params, target, dev, remote=remote)
+    print(report)
+
+    assert "Bound" in report.table()
+    assert "Percent of Theoretical Optimal" in report.table()
+    for call in report.calls:
+        if "Percent of Theoretical Optimal" in call:
+            if target.startswith("llvm"):
+                # Ideally we'd like a little tighter bound here, but it is hard to
+                # know how well this dense will perform without tuning. And we
+                # don't have an operator that uses a specific number of flops.
+                assert call["Percent of Theoretical Optimal"].ratio >= 5.0
+            elif target == "cuda":
+                # The cuda gpu kernel is really poorly optimized
+                assert 90 >= call["Percent of Theoretical Optimal"].ratio >= 0.01
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_runtime_profiling.py b/tests/python/unittest/test_runtime_profiling.py
index adb5dee174..ed1841b216 100644
--- a/tests/python/unittest/test_runtime_profiling.py
+++ b/tests/python/unittest/test_runtime_profiling.py
@@ -263,103 +263,5 @@ def test_profile_function(target, dev):
     assert report[metric].value > 0
 
 
-@tvm.testing.parametrize_targets("llvm")
-def test_estimate_peak_fma_flops(target, dev):
-    # This test uses vectorized instructions so we need a target that supports them
-    if target == "llvm":
-        target = "llvm -mattr=+fma,+avx2"
-    flops = tvm.utils.estimate_peak_fma_flops(tvm.target.Target(target), dev)
-    # Assume we can achieve 1 GFLOP/s per thread, which is 1 FLOP per cycle on a 1GHz cpu.
-    assert (
-        flops > 10**9 and flops < 10**14
-    ), f"FLOP/s should be between 10^9 and 10^14, but it is {flops}"
-
-
-def test_estimate_peak_fma_flops_rpc():
-    target = "llvm -mattr=+fma,+avx2"
-    server = rpc.Server(key="profiling")
-    remote = rpc.connect("127.0.0.1", server.port, key="profiling")
-    dev = remote.device(target)
-    flops = tvm.utils.estimate_peak_fma_flops(tvm.target.Target(target), dev, remote=remote)
-    # Assume we can achieve 1 GFLOP/s per thread, which is 1 FLOP per cycle on a 1GHz cpu.
-    assert (
-        flops > 10**9 and flops < 10**14
-    ), f"FLOP/s should be between 10^9 and 10^14, but it is {flops}"
-
-
-@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
-@tvm.testing.parametrize_targets("llvm")
-def test_estimate_peak_bandwidth(target, dev):
-    # This test uses vectorized instructions so we need a target that supports them
-    if target == "llvm":
-        target = "llvm -mattr=+fma,+avx2"
-    bandwidth = tvm.utils.estimate_peak_bandwidth(tvm.target.Target(target), dev)
-    # Assume we can achieve 1 GB/s. DDR2 should transfer somewhere around 6
-    # GB/s, so this should leave enough wiggle room.
-    assert (
-        bandwidth > 10**9 and bandwidth < 10**12
-    ), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
-
-
-@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
-def test_estimate_peak_bandwidth_rpc():
-    target = "llvm -mattr=+fma,+avx2"
-    server = rpc.Server(key="profiling")
-    remote = rpc.connect("127.0.0.1", server.port, key="profiling")
-    dev = remote.device(target)
-    bandwidth = tvm.utils.estimate_peak_bandwidth(tvm.target.Target(target), dev, remote=remote)
-    # Assume we can achieve 1 GB/s. DDR2 should transfer somewhere around 6
-    # GB/s, so this should leave enough wiggle room.
-    assert (
-        bandwidth > 10**9 and bandwidth < 10**12
-    ), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
-
-
-@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
-@tvm.testing.parametrize_targets("llvm")
-def test_roofline_analysis(target, dev):
-    a = relay.var("a", relay.TensorType((512, 512), "float32"))
-    b = relay.var("b", relay.TensorType((512, 512), "float32"))
-    c = relay.nn.dense(a, b)
-    mod = tvm.IRModule.from_expr(relay.Function([a, b], c))
-    params = {}
-    report = tvm.utils.roofline_analysis(mod, params, target, dev)
-
-    assert "Bound" in report.table()
-    assert "Percent of Theoretical Optimal" in report.table()
-    for call in report.calls:
-        if "Percent of Theoretical Optimal" in call:
-            # Ideally we'd like a little tighter bound here, but it is hard to
-            # know how well this dense will perform without tuning. And we
-            # don't have an operator that uses a specific number of flops.
-            assert call["Percent of Theoretical Optimal"].ratio >= 5.0
-
-
-@tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
-def test_roofline_analysis_rpc():
-    target = "llvm"
-
-    a = relay.var("a", relay.TensorType((512, 512), "float32"))
-    b = relay.var("b", relay.TensorType((512, 512), "float32"))
-    c = relay.nn.dense(a, b)
-    mod = tvm.IRModule.from_expr(relay.Function([a, b], c))
-    params = {}
-
-    server = rpc.Server(key="profiling")
-    remote = rpc.connect("127.0.0.1", server.port, key="profiling")
-    dev = remote.device(target)
-
-    report = tvm.utils.roofline_analysis(mod, params, target, dev, remote=remote)
-
-    assert "Bound" in report.table()
-    assert "Percent of Theoretical Optimal" in report.table()
-    for call in report.calls:
-        if "Percent of Theoretical Optimal" in call:
-            # Ideally we'd like a little tighter bound here, but it is hard to
-            # know how well this dense will perform without tuning. And we
-            # don't have an operator that uses a specific number of flops.
-            assert call["Percent of Theoretical Optimal"].ratio >= 5.0
-
-
 if __name__ == "__main__":
     tvm.testing.main()