You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/04/26 22:26:06 UTC

[GitHub] [tvm] csullivan commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

csullivan commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859138609


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")

Review Comment:
   nit: use an n-dimensional buffer to simplify the TIR, e.g.
   > A = T.match_buffer(a, [threads, num_vector_registers, vec_width], "float32")



##########
include/tvm/runtime/profiling.h:
##########
@@ -459,6 +459,21 @@ class CountNode : public Object {
   TVM_DECLARE_FINAL_OBJECT_INFO(CountNode, Object);
 };
 
+/* \brief A ratio of two things. */
+class RatioNode : public Object {
+ public:
+  /* The ratio as a floating point number. */
+  double ratio;

Review Comment:
   nit: `double` is a double-precision floating point number; change either the type or the comment so that they match.



##########
tests/python/unittest/test_runtime_profiling.py:
##########
@@ -257,6 +259,50 @@ def test_profile_function(target, dev):
     assert report[metric].value > 0
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_estimate_peak_fma_flops(target, dev):
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"
+    flops = tvm.utils.estimate_peak_fma_flops(tvm.target.Target(target), dev)
+    # assume we can achieve 1 GFLOP/s per thread
+    assert (
+        flops > 10**9 * tvm.runtime.num_threads() and flops < 10**14
+    ), f"FLOP/s should be between 10^9 * num_threads and 10^14, but it is {flops}"
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_estimate_peak_bandwidth(target, dev):
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"
+    bandwidth = tvm.utils.estimate_peak_bandwidth(tvm.target.Target(target), dev)
+    # assume we can achieve 1 GB/s
+    assert (
+        bandwidth > 10**9 and bandwidth < 10**12
+    ), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
+
+
+@pytest.mark.skipif(platform.machine() == "i386", reason="Cannot allocate enough memory on i386")
+@tvm.testing.parametrize_targets("llvm")
+def test_roofline_analysis(target, dev):
+    a = relay.var("a", relay.TensorType((512, 512), "float32"))
+    b = relay.var("b", relay.TensorType((512, 512), "float32"))
+    c = relay.nn.dense(a, b)
+    mod = tvm.IRModule.from_expr(relay.Function([a, b], c))
+    params = {}
+    report = tvm.utils.roofline_analysis(mod, params, target, dev)
+
+    assert "Bound" in report.table()
+    assert "Percent of Theoretical Optimal" in report.table()
+    for call in report.calls:
+        if "Percent of Theoretical Optimal" in call:
+            # Ideally we'd like a little tighter bound here, but it is hard to
+            # know how well this dense will perform without tuning. And we
+            # don't have an operator that uses a specific number of flops.
+            assert call["Percent of Theoretical Optimal"].ratio >= 0
+
+

Review Comment:
   A couple of classification tests would be rather impressive, e.g. asserting that a dense (as above) of sufficient size is compute bound, and that some injective op is memory bound.



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86

Review Comment:
   Not necessarily; an empty `device_name` could also be ARM, Hexagon, etc., right?



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/seconds.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)

Review Comment:
   It would be neat to run all of the STREAM benchmarks for the memory bandwidth measurement and combine the results.



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/seconds.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)

Review Comment:
   I'm just ramping into the discussion here, but given cache size variability I would argue for the iterative approach of increasing the array size until performance plateaus. It's easy to measure and less likely to lead to false assumptions being made.
   
   



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/seconds.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)

Review Comment:
   It would also be interesting to apply the specific STREAM benchmark that matches a given workload pattern, based on the reads and writes defined in a block. This probably has diminishing returns, given that the benchmarks usually yield comparable measurements of the bandwidth.



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/seconds.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)
+    a = nd.array(np.ones(size, dtype="float32"), device=dev)
+    b = nd.array(np.ones(vec_width * 4 * num_threads(), dtype="float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=5, number=1)(a, b, num_threads())
+    return size * 4 / times.min  # 4 bytes per float32
+
+
+def roofline_analysis(
+    mod: IRModule, params: Dict[str, nd.NDArray], target: Union[str, Target], dev: Device
+) -> profiling.Report:
+    """
+    Create a profiling report that contains roofline and other estimated
+    statistics from running a module on the VM.
+
+    These statistics are calculated by analyzing the lowered TIR of each
+    operator, so they are estimates of the true values. The statistics are:
+      - Bound: Is the operator memory or compute bound. This is computed by
+        assuming that the operator could perfectly cache all loads -- each byte
+        of memory is only loaded once.
+      - Percent of Theoretical Optimal: What percent of theoretical optimal for
+        the bound. i.e. percent of peak memory bandwidth if memory bound,
+        percent of peak FLOP/s if compute bound.
+      - Unique Loaded Bytes: estimate of the number of bytes loaded, not
+        counting multiple accesses to the same byte.
+      - Estimated Flops: estimated number of floating point operations.
+      - Arithmetic Intensity: ratio of FLOPs per byte of data.
+      - FLOP/s: floating point operations per second.
+      - Bandwidth: Number of bytes loaded per second.
+
+    Parameters
+    ----------
+    mod : IRModule
+      Uncompiled input module.
+
+    params : Dict[str, nd.NDArray]
+
+    target : Union[str, Target]
+      Target to run on.
+
+    dev : Device
+      Device to run on.
+
+    Returns
+    -------
+
+    report : profiling.Report
+      Profiling report which includes the estimated statistics.
+    """
+    if isinstance(target, str):
+        target = Target(target)
+    peak_bandwidth = estimate_peak_bandwidth(target, dev)
+    peak_flops = estimate_peak_fma_flops(target, dev)
+
+    ridge_point = peak_flops / peak_bandwidth
+
+    all_features = _estimated_features(mod, params, target)
+
+    lib = relay.vm.compile(mod, params=params, target=target)
+    vmexec = profiler_vm.VirtualMachineProfiler(lib, dev)
+
+    args = _create_args(mod, dev)
+    report = vmexec.profile(*args)
+    new_calls = []
+    for call in report.calls:
+        if "Hash" in call.keys():
+            _, features = all_features[call["Hash"]]
+
+            flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])
+            unique_loaded_bytes = 0.0
+            # assume no more than 100 buffers
+            for i in range(100):
+                # We could use loaded bytes, but that accounts for the L1 cache.
+                # If we use unique_bytes, then we are looking at how close we come
+                # to the performance assuming all data is cached perfectly.

Review Comment:
   I don't understand why you prefer to assume that data is accessed from the cache perfectly. This will bias you towards fewer loads overall, or a greater arithmetic intensity. Most likely I am misunderstanding this comment.



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/seconds.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)
+    a = nd.array(np.ones(size, dtype="float32"), device=dev)
+    b = nd.array(np.ones(vec_width * 4 * num_threads(), dtype="float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=5, number=1)(a, b, num_threads())
+    return size * 4 / times.min  # 4 bytes per float32
+
+
+def roofline_analysis(

Review Comment:
   Equivalent functionality for profiling a TIR primfunc would be very nice, and I could imagine it as part of a standard TIR compiler toolchain (debugger, profiler, etc.). cc @supersat



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org