Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/04/19 20:45:32 UTC

[GitHub] [tvm] tkonolige opened a new pull request, #11066: [PROFILER] Theoretical roofline models

tkonolige opened a new pull request, #11066:
URL: https://github.com/apache/tvm/pull/11066

   `tvm.analysis.roofline_analysis` adds estimated roofline performance to a profiling report. The roofline model measures how close an operator gets to the best possible memory bandwidth or FLOP/s, depending on whether it is memory bound or compute bound. This computation uses the runtime of the operator along with two numbers extracted from the TIR code: bytes of memory touched and number of floating point operations. Because these numbers are extracted from TIR, they may not be 100% accurate. The best possible memory bandwidth and FLOP/s are measured by running small programs that are memory bound and compute bound, respectively.
   
   For now, this function only works with LLVM CPU targets, but it should be possible to extend it to GPU targets.
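
A minimal sketch of the roofline classification described above, in plain Python. The helper `roofline_point` and its arguments are hypothetical (not part of TVM). It assumes the FLOP count, byte count, runtime, and both peak numbers are already known, and compares the operator's arithmetic intensity against the ridge point `peak_flops / peak_bandwidth`.

```python
def roofline_point(runtime_s, flops, unique_bytes, peak_flops, peak_bandwidth):
    """Classify one operator against the roofline (hypothetical helper).

    runtime_s is in seconds, peak_flops in FLOP/s, peak_bandwidth in bytes/s.
    Returns (bound, arithmetic_intensity, percent_of_theoretical_optimal).
    """
    # FLOP/byte at which the bandwidth roof meets the compute roof.
    ridge_point = peak_flops / peak_bandwidth
    # FLOP/byte performed by this operator.
    intensity = flops / unique_bytes
    if intensity < ridge_point:
        # Memory bound: compare achieved bandwidth with peak bandwidth.
        return "memory", intensity, 100.0 * (unique_bytes / runtime_s) / peak_bandwidth
    # Compute bound: compare achieved FLOP/s with peak FLOP/s.
    return "compute", intensity, 100.0 * (flops / runtime_s) / peak_flops
```

For example, with a peak of 1e12 FLOP/s and 1e11 bytes/s (ridge point 10 FLOP/byte), an operator doing 1e12 FLOPs over 1e9 bytes in 2 s is compute bound at 50% of peak.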
   
   @AndrewZhaoLuo @mbrookhart @masahi 
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859096015


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[

Review Comment:
   `+=` is required to create a data dependence between all the loads; otherwise LLVM could rewrite this to just load the last element in the loop (`k=N // nt // 4 // vec_width-1`). The arithmetic intensity of this compute is much lower than the maximum arithmetic intensity of processors, so we will be bandwidth limited (which is what we want). In this comment https://github.com/apache/tvm/pull/11066/files/0e60f8d2a3fd2c234886ea4cb2df57e646fc17ef#diff-a127e78ae9dc951f53d45d15e2a905cb559a0bee347d7fc7232334f9e69a862fR189-R194, I explain in which cases this compute does matter.
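
To illustrate the data-dependence argument, here is a small NumPy model of what `peak_bandwidth_tir` computes. It is an illustrative sketch, not part of the PR: the name `peak_bandwidth_reference` is hypothetical, it runs serially, and it starts the accumulators at zero. Every element of `a` is folded into one of the `4 * vec_width` accumulators owned by its thread, so no load can be dropped, and each 4-byte load feeds a single add (about 0.25 FLOP/byte).

```python
import numpy as np


def peak_bandwidth_reference(a, nt, vec_width):
    """Serial NumPy model of peak_bandwidth_tir (illustrative only).

    Thread i owns a contiguous chunk of `a` and accumulates it into its own
    4 * vec_width slots of `b`, mirroring B[i*vw*4 + l*vw + j] += A[...].
    """
    chunk = 4 * vec_width
    assert a.size % (nt * chunk) == 0, "size must divide evenly, as in the PR"
    b = np.zeros(nt * chunk, dtype=a.dtype)
    per_thread = a.size // nt
    for i in range(nt):
        # Rows are the k iterations; columns are the (l, j) lanes.
        block = a[i * per_thread:(i + 1) * per_thread].reshape(-1, chunk)
        # Summing over k is the dependence that keeps every load live.
        b[i * chunk:(i + 1) * chunk] += block.sum(axis=0)
    return b


# One fp32 add per 4-byte element of A; B's accumulators stay in cache.
ARITHMETIC_INTENSITY = 1 / 4  # FLOP/byte
```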





[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r858097347


##########
tests/python/unittest/test_runtime_profiling.py:
##########
@@ -257,6 +259,50 @@ def test_profile_function(target, dev):
     assert report[metric].value > 0
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_estimate_peak_fma_flops(target, dev):
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"

Review Comment:
   Turning off vectorization does have an effect: it drops throughput to ~1 MFLOP/s.
   
   I'm not sure how to check whether the processor supports an instruction; I looked through the codebase and couldn't find any examples. AVX2 and FMA are roughly 10 years old, so I think we can assume x86 targets support them (I could totally be wrong here, though).
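
One host-side option on Linux, sketched below under the assumption that the code runs on the same machine it compiles for: read the feature flags the kernel reports in `/proc/cpuinfo` (`flags` lines on x86, `Features` lines on ARM). The helper names are hypothetical, and this only inspects the local host, so it says nothing about a cross-compilation target.

```python
def cpu_flags(cpuinfo_text):
    """Collect the CPU feature flags listed in the text of /proc/cpuinfo."""
    flags = set()
    for line in cpuinfo_text.splitlines():
        key, sep, value = line.partition(":")
        # x86 kernels list features on "flags" lines, ARM kernels on "Features".
        if sep and key.strip() in ("flags", "Features"):
            flags.update(value.split())
    return flags


def host_supports(*features, path="/proc/cpuinfo"):
    """Return True only if every requested feature appears in the host's cpuinfo."""
    try:
        with open(path) as f:
            flags = cpu_flags(f.read())
    except OSError:
        # Not Linux, or cpuinfo is unreadable: we cannot confirm support.
        return False
    return all(feature in flags for feature in features)
```

A test could, for instance, fall back to plain `llvm` when `host_supports("avx2", "fma")` is False.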





[GitHub] [tvm] csullivan commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r860226117


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/second.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)
+    a = nd.array(np.ones(size, dtype="float32"), device=dev)
+    b = nd.array(np.ones(vec_width * 4 * num_threads(), dtype="float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=5, number=1)(a, b, num_threads())
+    return size * 4 / times.min  # 4 bytes per float32
+
+
+def roofline_analysis(
+    mod: IRModule, params: Dict[str, nd.NDArray], target: Union[str, Target], dev: Device
+) -> profiling.Report:
+    """
+    Create a profiling report that contains roofline and other estimated
+    statistics from running a module on the VM.
+
+    These statistics are calculated by analyzing the lowered TIR of each
+    operator, so they are estimates of the true values. The statistics are:
+      - Bound: Is the operator memory or compute bound. This is computed by
+        assuming that the operator could perfectly cache all loads -- each byte
+        of memory is only loaded once.
+      - Percent of Theoretical Optimal: What percent of theoretical optimal for
+        the bound. i.e. percent of peak memory bandwidth if memory bound,
+        percent of peak FLOP/s if compute bound.
+      - Unique Loaded Bytes: estimate of the number of bytes loaded, not
+        counting multiple accesses to the same byte.
+      - Estimated Flops: estimated number of floating point operations.
+      - Arithmetic Intensity: ratio of FLOPs per byte of data.
+      - FLOP/s: floating point operations per second.
+      - Bandwidth: Number of bytes loaded per second.
+
+    Parameters
+    ----------
+    mod : IRModule
+      Uncompiled input module.
+
+    params : Dict[str, nd.NDArray]
+
+    target : Union[str, Target]
+      Target to run on.
+
+    dev : Device
+      Device to run on.
+
+    Returns
+    -------
+
+    report : profiling.Report
+      Profiling report which includes the estimated statistics.
+    """
+    if isinstance(target, str):
+        target = Target(target)
+    peak_bandwidth = estimate_peak_bandwidth(target, dev)
+    peak_flops = estimate_peak_fma_flops(target, dev)
+
+    ridge_point = peak_flops / peak_bandwidth
+
+    all_features = _estimated_features(mod, params, target)
+
+    lib = relay.vm.compile(mod, params=params, target=target)
+    vmexec = profiler_vm.VirtualMachineProfiler(lib, dev)
+
+    args = _create_args(mod, dev)
+    report = vmexec.profile(*args)
+    new_calls = []
+    for call in report.calls:
+        if "Hash" in call.keys():
+            _, features = all_features[call["Hash"]]
+
+            flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])
+            unique_loaded_bytes = 0.0
+            # assume no more than 100 buffers
+            for i in range(100):
+                # We could use loaded bytes, but that accounts for L1 cache.
+                # If we use unique_bytes, then we are looking at how close we come
+                # to the performance assuming all data is cached perfectly.

Review Comment:
   >"given the amount of data we must load, how close are we to loading it efficiently."
   
   I see, okay. So I think you're treating this as an upper bound on arithmetic intensity with respect to memory access, i.e., how far we can push the arithmetic intensity just by accessing the data efficiently.
   
   I am mostly worried about the scenario in which we don't account for multiple loads of the same bytes in an algorithm. My feeling is that it would be desirable to see the arithmetic intensity increase when the algorithm is changed to reduce the number of loads/stores for the same number of FLOPs. That alone can change the problem from memory bound to compute bound, but it would be misleading if our algorithmic choice appeared to be compute bound only because the arithmetic intensity wasn't counting repeated loads of the same bytes.
   
   Put differently, with unique_bytes, two algorithm choices with the same total FLOPs but different numbers of loads/stores could end up with the same arithmetic intensity.
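
A back-of-the-envelope illustration of this point, using a deliberately simplified cost model for an n x n x n float32 matmul. The model and the helper `matmul_intensities` are assumptions for illustration, not what TVM's feature extraction computes: tiling changes how many bytes are actually loaded but not how many unique bytes exist, so only the total-bytes intensity reflects the improvement.

```python
def matmul_intensities(n, tile=1, bytes_per_elem=4):
    """Arithmetic intensity of an n x n x n matmul under two byte-counting rules.

    Simplified model (assumption): each multiply-add loads one element of A and
    one of B; tiling with tile size `tile` reuses each loaded element `tile`
    times, so total loads shrink by that factor. Output writes are counted once.
    Returns (intensity_from_unique_bytes, intensity_from_total_bytes).
    """
    flops = 2 * n**3  # one multiply and one add per (i, j, k)
    unique_bytes = 3 * n * n * bytes_per_elem  # A, B and C each touched once
    total_bytes = (2 * n**3 // tile + n * n) * bytes_per_elem
    return flops / unique_bytes, flops / total_bytes
```

With n = 64, the unique-bytes intensity is about 10.7 FLOP/byte for both versions, while the total-bytes intensity rises from about 0.25 (untiled) to about 3.6 (tile = 16).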





[GitHub] [tvm] csullivan commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r860085447


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86

Review Comment:
   `-device` is used by TOPI to provide a layer of specialization when the target is generic. With Mali, for example, we can have separate TOPI implementations for OpenCL. And I agree that TOPI does hardware-specific attribute lookup the same way you are doing it here; I just don't like it :).
   
   It's essentially a side channel through which information flows into the compiler from outside the target, and it assumes that the hardware we are compiling for is locally accessible, which holds only in a limited set of cases. In my view, any device-specific attribute the compiler needs should be retrievable directly from the target, and we now have a path to do this with the attribute preprocessor I mentioned above, which queries the DeviceAPI directly for this information.





[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859232848


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86

Review Comment:
   I don't think the hexagon target string you have provided there is correct. If I look at the "builtin" targets provided by TVM, nearly all of them set the `-device` flag to something: in the output below, every llvm-based target does, and only hexagon, cuda, and rocm do not.
   
   ```
   In [2]: tvm.target.arm_cpu()
   Out[2]: llvm -keys=arm_cpu,cpu -device=arm_cpu -link-params=0 -model=unknown
   
   In [3]: tvm.target.hexagon()
   Out[3]: hexagon -keys=hexagon -link-params=0 -mattr=+hvxv66,+hvx-length128b -mcpu=hexagonv66 -mtriple=hexagon
   
   In [4]: tvm.target.cuda()
   Out[4]: cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -model=unknown -thread_warp_size=32
   
   In [5]: tvm.target.bifrost()
   Out[5]: opencl -keys=bifrost,opencl,gpu -device=bifrost -max_num_threads=256 -model=unknown -thread_warp_size=1
   
   In [6]: tvm.target.intel_graphics()
   Out[6]: opencl -keys=intel_graphics,opencl,gpu -device=intel_graphics -max_num_threads=256 -model=unknown -thread_warp_size=16
   
   In [7]: tvm.target.mali()
   Out[7]: opencl -keys=mali,opencl,gpu -device=mali -max_num_threads=256 -model=unknown -thread_warp_size=1
   
   In [8]: tvm.target.rasp()
   Out[8]: llvm -keys=arm_cpu,cpu -device=arm_cpu -link-params=0 -mattr=+neon -model=bcm2837 -mtriple=armv7l-linux-gnueabihf
   
   In [9]: tvm.target.riscv_cpu()
   Out[9]: llvm -keys=arm_cpu,cpu -device=arm_cpu -link-params=0 -mabi=lp64d -mcpu=sifive-u54 -model=sifive-u54 -mtriple=riscv64-unknown-linux-gnu
   
   In [10]: tvm.target.rocm()
   Out[10]: rocm -keys=rocm,gpu -max_num_threads=256 -mcpu=gfx900 -model=unknown -mtriple=amdgcn-amd-amdhsa-hcc -thread_warp_size=64
   ```
   (I have been unable to find any documentation on what `device` actually means/does).
   
   If you have a better way to detect x86, I'd be happy to use it. The only way I've seen this detection done in the codebase is by looking at `mcpu`.
   
   I'd love to have this information available from the targets, but all the x86 code in topi appears to do it the way I do it here.
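
For reference, a string-level version of the heuristic being discussed can be sketched in plain Python. It is hypothetical and not TVM's API (real code would work on a `tvm.target.Target` object rather than a raw string): treat an `llvm` target as x86 only if it has no `-device` and its `-mtriple` is absent or names an x86 triple. Unlike a bare `device_name == ""` check, it would classify a string like `llvm -mcpu=apple-latest -mtriple=arm64-apple-macos` as non-x86.

```python
def parse_target(target_str):
    """Split a TVM-style target string into (kind, {option: value})."""
    kind, *options = target_str.split()
    attrs = {}
    for opt in options:
        # "-mattr=+fma,+avx2" -> ("mattr", "+fma,+avx2")
        key, _, value = opt.lstrip("-").partition("=")
        attrs[key] = value
    return kind, attrs


def looks_like_x86(target_str):
    """Hypothetical heuristic: llvm, no -device, and an x86 (or absent) triple."""
    kind, attrs = parse_target(target_str)
    if kind != "llvm" or attrs.get("device"):
        return False
    triple = attrs.get("mtriple", "")
    # No triple means "compile for the host", which this heuristic assumes is x86.
    return triple == "" or triple.startswith(("x86_64", "i386", "i686"))
```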





[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859228219


##########
tests/python/unittest/test_runtime_profiling.py:
##########
@@ -257,6 +259,50 @@ def test_profile_function(target, dev):
     assert report[metric].value > 0
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_estimate_peak_fma_flops(target, dev):
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"

Review Comment:
   I don't think `llvm -mcpu=apple-latest -mtriple=arm64-apple-macos` is the correct target for macOS. It is missing `-device=arm_cpu`.
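To make the concern concrete: the PR's `_vec_width_registers` treats an empty `device_name` as x86. A target string without `-device=arm_cpu` would therefore be sent down the x86 path. This sketch approximates that check on plain strings. `device_name` and `treated_as_x86` are hypothetical helpers, not TVM APIs.

```python
def device_name(target: str) -> str:
    """Return the value of -device=... in a target string, or "" if it is absent."""
    for flag in target.split()[1:]:
        name, _, value = flag.lstrip("-").partition("=")
        if name == "device":
            return value
    return ""


def treated_as_x86(target: str) -> bool:
    # Approximates `target.device_name == ""  # indicates x86` for llvm targets.
    return target.split()[0] == "llvm" and device_name(target) == ""


mac = "llvm -mcpu=apple-latest -mtriple=arm64-apple-macos"
print(treated_as_x86(mac))                       # True: would take the x86 path
print(treated_as_x86(mac + " -device=arm_cpu"))  # False
```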





[GitHub] [tvm] tqchen commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r853519434


##########
python/tvm/analysis/__init__.py:
##########
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""High level analysis functions"""

Review Comment:
   It would be good to think about the namespacing choices. 
   
   This mixes profiling with functionality that depends on TIR features. Previously, `analysis` under `tir/analysis` or `relay/analysis` was about IR-level analyses.
   
   How about `contrib/roofline` if we are aiming at roofline analysis?
   
   Could some of the features fit into `tir/profile`?
   





[GitHub] [tvm] tqchen commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r853638940


##########
python/tvm/analysis/__init__.py:
##########
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""High level analysis functions"""

Review Comment:
   The main reason was mentioned above: `analysis` is already used for IR-based analysis (in both TIR and Relay), so there might be confusion here.
   
   In this case `profile_analysis` might be a good distinction (I am not attached to this particular name, and there might be a better one). Happy to discuss other name choices as well.





[GitHub] [tvm] tqchen commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r853638940


##########
python/tvm/analysis/__init__.py:
##########
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""High level analysis functions"""

Review Comment:
   The main reason was mentioned above: `analysis` is already used for IR-based analysis (in both TIR and Relay), so there might be confusion here.
   
   I understand that it should not be part of `runtime/profile` because we need compiler support. In this case `profile_analysis` might be a good distinction (I am not attached to this particular name, and there might be a better one). Happy to discuss other name choices as well.
   
   One thing to note, though, is that we might increasingly need compiler counters for profiling. So another possibility could be `profile` (with compiler support for JIT compilation if needed) and `runtime/profile_rt` (the runtime counterpart).
   
   





[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859131336


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)

Review Comment:
   Empirically, `repeat_ms` doesn't seem necessary. There is already one warm-up iteration.
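This rough sketch uses plain Python rather than `time_evaluator` itself to illustrate the measurement strategy under discussion. The function runs once untimed, and then the minimum over many repeats is kept (the PR divides by `times.min`). Taking the minimum already filters out slow outliers, which is one plausible reason an extra minimum-duration setting such as `repeat_ms` may not change the estimate. `time_min` is a hypothetical helper, not a TVM API.

```python
import time
from typing import Callable


def time_min(fn: Callable[[], None], repeat: int = 100, warmup: int = 1) -> float:
    """Return the fastest of `repeat` timed calls to `fn`, in seconds."""
    for _ in range(warmup):
        fn()  # untimed: absorbs one-off costs such as cold caches or thread-pool start-up
    best = float("inf")
    for _ in range(repeat):
        start = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - start)
    return best


calls = []
t = time_min(lambda: calls.append(sum(range(10_000))), repeat=5)
print(len(calls))  # 6: one warm-up call plus five timed calls
```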





[GitHub] [tvm] AndrewZhaoLuo commented on pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on PR #11066:
URL: https://github.com/apache/tvm/pull/11066#issuecomment-1110164874

   @csullivan ?




[GitHub] [tvm] tqchen commented on pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tqchen commented on PR #11066:
URL: https://github.com/apache/tvm/pull/11066#issuecomment-1113606358

   Thanks @tkonolige. I have not yet explicitly approved the change because of the outstanding requests in the review comments.




[GitHub] [tvm] tqchen commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r854070352


##########
python/tvm/analysis/__init__.py:
##########
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""High level analysis functions"""

Review Comment:
   A new namespace is a major architectural choice as well as an API choice, and it would benefit from broader [deliberation and discussion](https://tvm.apache.org/docs/contribute/code_review.html#deliberate-on-api-and-data-structures). This is especially true when we are not sure about the choice of name, so it might be helpful to get more input when introducing a new top-level namespace. I spent more time deliberating about the particular choices overnight.
   
   The current module started with the LLVM CPU target and is more like standalone tooling that will be continuously improved and used. Our previous convention would start such self-contained tooling in contrib (e.g. `contrib/popen_pool`). Admittedly, in those cases contrib conveys little beyond "collection of contributed tools", but it requires less architectural deliberation and can unblock the PR. The current tooling also matches the characteristics of the other tools in those collections (it depends on a few things and aims to enable certain goals).
   
   In the meantime, if we want to think about possible ways to group things as the tooling matures, it would be great to open a discuss thread to also get some collective wisdom from community members.
   





[GitHub] [tvm] tqchen commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r854070352


##########
python/tvm/analysis/__init__.py:
##########
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""High level analysis functions"""

Review Comment:
   Coming back after more deliberation. A new namespace is a major architectural choice as well as an API choice, and it would benefit from broader [deliberation and discussion](https://tvm.apache.org/docs/contribute/code_review.html#deliberate-on-api-and-data-structures), especially when we are not sure about the choice of name.
   
   Meanwhile, this particular module started with the LLVM CPU target and is more like standalone tooling that will be continuously improved and used. Our previous convention would start such self-contained tooling in contrib, e.g. `contrib/popen_pool`, which requires less architectural deliberation and can unblock the PR. We can then open a discuss thread to get others' input on possible ways to group these tools as they mature.
   
   Please let me know if that makes sense.
   
   





[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r858242458


##########
tests/python/unittest/test_runtime_profiling.py:
##########
@@ -257,6 +259,50 @@ def test_profile_function(target, dev):
     assert report[metric].value > 0
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_estimate_peak_fma_flops(target, dev):
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"

Review Comment:
   Yeah, so this runs fine on an M1 Mac. I get around 2e11 FLOP/s.
   
   However, this is with just the `llvm` target, which is interesting. Changing the target string to any of the following does not change the results, at least not noticeably, although it does print some output:
   
   - `llvm`
   - `llvm -mcpu=apple-latest -mtriple=arm64-apple-macos`
   
   Additionally, `llvm -mattr=+fma,+avx2` generates the following printout:
   ```
   tests/python/unittest/test_runtime_profiling.py '+fma' is not a recognized feature for this target (ignoring feature)
   '+avx2' is not a recognized feature for this target (ignoring feature)
   '+fma' is not a recognized feature for this target (ignoring feature)
   '+avx2' is not a recognized feature for this target (ignoring feature)
   '+fma' is not a recognized feature for this target (ignoring feature)
   '+avx2' is not a recognized feature for this target (ignoring feature)
   '+fma' is not a recognized feature for this target (ignoring feature)
   '+avx2' is not a recognized feature for this target (ignoring feature)
   '+fma' is not a recognized feature for this target (ignoring feature)
   '+avx2' is not a recognized feature for this target (ignoring feature)
   '+fma' is not a recognized feature for this target (ignoring feature)
   '+avx2' is not a recognized feature for this target (ignoring feature)
   ```
   
   I can't find the string above by grepping the TVM source, so I'll assume it is LLVM (invoked by TVM) printing this rather than TVM itself.
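Figures like the ~2e11 above come from dividing a fixed FLOP count by the fastest run. In `estimate_peak_fma_flops` that count is `2 * vec_width * num_vector_registers * nthreads * iters`, since an FMA counts as two FLOPs. This minimal sketch of the arithmetic uses made-up inputs that are not M1 measurements. `fma_flop_s` is a hypothetical name.

```python
def fma_flop_s(vec_width: int, num_vector_registers: int, nthreads: int,
               iters: int, min_time_s: float) -> float:
    """FLOP/s as in estimate_peak_fma_flops: total FLOPs divided by the fastest time."""
    # Each vector lane performs one FMA per iteration, and an FMA counts as two FLOPs.
    flops = 2 * vec_width * num_vector_registers * nthreads * iters
    return flops / min_time_s


# Illustrative inputs only (not measured on any particular machine):
# 8 float32 lanes, 16 registers, 4 threads, 100000 iterations, 1 ms fastest run.
print(f"{fma_flop_s(8, 16, 4, 100_000, 1e-3):.3e}")  # 1.024e+11
```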





[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859233374


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")

Review Comment:
   good point, done.
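This NumPy sketch is offered as an aid for checking the indexing and the buffer-size assertion that go with this `match_buffer`. It models what the `peakflops_fma_tir` loop nest quoted earlier computes: each thread `t` owns a contiguous block of `vec_width * num_vector_registers` floats and repeatedly applies `x = x * x + x` to it. It reproduces the values only, not the register-resident FMA behavior the TIR kernel is designed to measure. `peakflops_fma_reference` is a hypothetical name.

```python
import numpy as np


def peakflops_fma_reference(a: np.ndarray, vec_width: int, iters: int,
                            num_vector_registers: int, threads: int) -> np.ndarray:
    """Model the values peakflops_fma_tir writes into `a` (in place)."""
    block = vec_width * num_vector_registers  # contiguous floats owned by each thread
    assert a.size >= threads * block, "Input vectors must be >= num_vector_registers*vec_width"
    for t in range(threads):  # T.parallel in the TIR version
        for _ in range(iters):
            for l in range(num_vector_registers):  # T.unroll in the TIR version
                start = t * block + vec_width * l
                x = a[start:start + vec_width]  # view covering the k loop (T.vectorized)
                x[:] = x * x + x  # one multiply-add per lane, written back into `a`
    return a


a = np.ones(2 * 4 * 3 + 5, dtype="float32")  # 5 spare elements past the last block
peakflops_fma_reference(a, vec_width=4, iters=2, num_vector_registers=3, threads=2)
print(a[:24].min(), a[:24].max(), a[24:].max())  # 6.0 6.0 1.0
```

Starting from ones, two iterations give 1 -> 2 -> 6 inside the blocks. Elements past `threads * num_vector_registers * vec_width` are never touched.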





[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859029029


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):

Review Comment:
   Nit: I suggest a name like `_resolve_vec_width_registers` to reflect the `None`-handling behavior.



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"

Review Comment:
   Is this meant to be uncommented?
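
The commented-out assertion guards a real edge case. Each thread's `k` loop runs `N // nt // 4 // vec_width` times, so when `N` is not a multiple of `nt * 4 * vec_width`, some elements at the end of each thread's slice of `A` are never loaded. The caller at the end of the file rounds `size` down to such a multiple, so the PR's own call path satisfies it. The plain-Python sketch below mirrors only the loop bounds (it is not TVM code) and counts how many elements are actually read; the `N`, thread count, and vector width used are arbitrary illustration values.

```python
def elements_read(n: int, nt: int, vec_width: int) -> int:
    """Count the elements of A read by peak_bandwidth_tir's loop nest.

    Mirrors the loop bounds: nt parallel threads, N // nt // 4 // vec_width
    serial iterations each, and 4 * vec_width elements per iteration.
    """
    iters_per_thread = n // nt // 4 // vec_width
    return nt * iters_per_thread * 4 * vec_width


def tail_is_skipped(n: int, nt: int, vec_width: int) -> bool:
    """True when some elements of A are never loaded."""
    return elements_read(n, nt, vec_width) < n


# N = 1000 with 3 threads and 8-wide vectors: 40 elements are never read.
print(elements_read(1000, 3, 8), tail_is_skipped(1000, 3, 8))  # 960 True
# N = 960 is a multiple of 3 * 4 * 8 = 96, so every element is read.
print(elements_read(960, 3, 8), tail_is_skipped(960, 3, 8))  # 960 False
```

Skipping nothing is equivalent to the assertion `N % (nt * 4 * vec_width) == 0`, so re-enabling it would turn a silent undercount into an explicit error.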



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,

Review Comment:
   IMO, you should just call this "input_buffer" or something.



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):

Review Comment:
   Nit: add type annotations here to be consistent with the rest of the file.



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)

Review Comment:
   Do we need something like a warmup or repeat_ms to get accurate performance numbers?
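
`time_evaluator` already has knobs in this direction: `number` (calls per measurement), `repeat`, and `min_repeat_ms`. The call under review sets only `repeat=100` and leaves the others at their defaults. The sketch below is plain Python, not TVM's implementation, and only illustrates what a warmup pass plus a minimum-duration rule buys: discard the first calls, then keep growing each timed batch until it runs long enough to measure reliably.

```python
import time
from typing import Callable, List


def time_with_warmup(
    fn: Callable[[], None],
    warmup: int = 1,
    repeat: int = 5,
    min_repeat_ms: float = 1.0,
) -> List[float]:
    """Return the mean seconds per call of ``fn`` for each of ``repeat`` batches.

    Warmup calls are discarded. Each batch doubles its call count until the
    batch runs for at least ``min_repeat_ms`` milliseconds, so very short
    kernels are not dominated by timer resolution.
    """
    for _ in range(warmup):
        fn()  # untimed: lets caches, page tables and thread pools settle
    results = []
    number = 1
    for _ in range(repeat):
        while True:
            start = time.perf_counter()
            for _ in range(number):
                fn()
            elapsed = time.perf_counter() - start
            if elapsed * 1000.0 >= min_repeat_ms:
                break
            number *= 2  # batch too short to time reliably; grow it
        results.append(elapsed / number)
    return results


if __name__ == "__main__":
    per_call = time_with_warmup(lambda: sum(range(1000)), warmup=3, repeat=5)
    print(f"best: {min(per_call) * 1e6:.2f} us/call")
```

Taking the minimum over batches, as the PR does with `times.min`, then filters out batches disturbed by other processes.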



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:

Review Comment:
   Rename nt to threads for consistency with the other function.



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[

Review Comment:
   Why += instead of =? This adds some compute, which doesn't seem good.
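
A plausible rationale, which is an assumption and not confirmed by the author: with `=`, every store to `B` in the `k` loop except the last is overwritten, so a compiler could treat the earlier loads of `A` as dead. Accumulating keeps every load live. The cost is one float32 add per 4-byte load (`B` is small), an arithmetic intensity of 0.25 FLOP/byte. Under the basic roofline model, that stays memory bound on any machine whose peak FLOP/s exceeds a quarter of its peak bytes/s. A minimal sketch of that check, using hypothetical machine numbers:

```python
from typing import Tuple


def roofline(
    flops: float, bytes_moved: float, peak_flop_s: float, peak_bytes_s: float
) -> Tuple[float, bool]:
    """Basic roofline model: return (attainable FLOP/s, is_memory_bound)."""
    intensity = flops / bytes_moved  # FLOP per byte moved
    ridge = peak_flop_s / peak_bytes_s  # intensity where the two roofs meet
    attainable = min(peak_flop_s, intensity * peak_bytes_s)
    return attainable, intensity < ridge


# The += loop: one float32 add per 4-byte load of A.
n = 10**6
# Hypothetical machine: 1 TFLOP/s peak compute, 100 GB/s peak bandwidth.
attainable, memory_bound = roofline(n, 4 * n, 1e12, 1e11)
print(attainable, memory_bound)  # 25000000000.0 True
```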



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/seconds.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)

Review Comment:
   Hmm, so correct me if I'm wrong, but the size would have to be greater than 2x the last-level cache to get no hits on consecutive runs.
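The sketch below works through the sizing arithmetic. It assumes the reviewer's 2x-LLC rule of thumb and a known last-level cache size. `llc_bytes` is a hypothetical input, since the PR has no way to query cache sizes. The result is rounded up to a multiple of `4 * nthreads * vec_width`, which mirrors the rounding on the `size = ...` line so that the `peak_bandwidth_tir` loop divides evenly. For comparison, the PR's `10**8` float32 elements come to 400 MB.

```python
def streaming_buffer_elems(llc_bytes: int, nthreads: int, vec_width: int,
                           safety_factor: int = 2, elem_bytes: int = 4) -> int:
    """Hypothetical helper: smallest float32 element count that covers at least
    safety_factor * llc_bytes and is divisible by 4 * nthreads * vec_width."""
    chunk = 4 * nthreads * vec_width  # same granularity as the PR's rounding
    min_elems = -(-safety_factor * llc_bytes // elem_bytes)  # ceiling division
    return -(-min_elems // chunk) * chunk  # round up to a multiple of chunk


# Example with an assumed 32 MiB LLC, 8 threads and 8 float32 lanes (AVX2).
elems = streaming_buffer_elems(32 * 2**20, nthreads=8, vec_width=8)
print(elems, elems * 4 / 2**20, "MiB")  # 16777216 64.0 MiB
```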





[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r860317231


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86

Review Comment:
   I've changed this check to make sure the target is llvm and that there aren't any keys besides "cpu". I think that should cover our bases for now.
   
   `get_simd_32bit_lanes` defaults to 4 if it can't figure out what the target is. Not sure this is a reasonable default, but I think we should leave changes to `get_simd_32bit_lanes` to a separate PR.
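Here is a minimal, framework-free sketch of the check described above, plus a conservative lane-count fallback. It uses plain strings in place of a target's kind name, keys and `-mattr` features. On a real `tvm.target.Target` these would presumably come from `target.kind.name`, `target.keys` and `target.mattr`. Both helper names and the feature-to-lane table are hypothetical illustrations, not TVM's actual `get_simd_32bit_lanes` implementation. The lane counts follow from dividing the register width by 32 bits. A 512-bit AVX-512 register holds 16 float32s, a 256-bit AVX/AVX2 register holds 8, and a 128-bit SSE/NEON register holds 4.

```python
from typing import Iterable


def is_plain_llvm_cpu(kind_name: str, keys: Iterable[str]) -> bool:
    """Hypothetical check: an llvm target with no keys other than "cpu"."""
    return kind_name == "llvm" and set(keys) <= {"cpu"}


def f32_lanes_from_mattr(mattr: Iterable[str], default: int = 4) -> int:
    """Hypothetical: float32 lanes per vector register inferred from -mattr flags."""
    feats = set(mattr)
    if "+avx512f" in feats:
        return 512 // 32  # 16 lanes
    if "+avx2" in feats or "+avx" in feats:
        return 256 // 32  # 8 lanes
    return default  # 128-bit width (SSE/NEON); also used when nothing is recognized


assert is_plain_llvm_cpu("llvm", ["cpu"])
assert not is_plain_llvm_cpu("llvm", ["arm_cpu", "cpu"])  # extra key: rejected
print(f32_lanes_from_mattr(["+fma", "+avx2"]))  # 8
```

Falling back to 4 lanes understates wider hardware. That may be the safer error for a peak estimate, but it means the reported peak is only a lower bound on such machines.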





[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859097753


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,

Review Comment:
   IMO `input_buffer` conveys about as much information as `a`.





[GitHub] [tvm] csullivan commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859138609


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")

Review Comment:
   nit: use an n-d buffer to simplify the TIR:
   > A = T.match_buffer(a, [threads, num_vector_registers, vec_width], "float32")



##########
include/tvm/runtime/profiling.h:
##########
@@ -459,6 +459,21 @@ class CountNode : public Object {
   TVM_DECLARE_FINAL_OBJECT_INFO(CountNode, Object);
 };
 
+/* \brief A ratio of two things. */
+class RatioNode : public Object {
+ public:
+  /* The ratio as a floating point number. */
+  double ratio;

Review Comment:
   nit: this is a double-precision floating point number; change the type or the comment to match.



##########
tests/python/unittest/test_runtime_profiling.py:
##########
@@ -257,6 +259,50 @@ def test_profile_function(target, dev):
     assert report[metric].value > 0
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_estimate_peak_fma_flops(target, dev):
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"
+    flops = tvm.utils.estimate_peak_fma_flops(tvm.target.Target(target), dev)
+    # assume we can achieve 1 GFLOP/s per thread
+    assert (
+        flops > 10**9 * tvm.runtime.num_threads() and flops < 10**14
+    ), f"FLOP/s should be between 10^9 * num_threads and 10^14, but it is {flops}"
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_estimate_peak_bandwidth(target, dev):
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"
+    bandwidth = tvm.utils.estimate_peak_bandwidth(tvm.target.Target(target), dev)
+    # assume we can achieve 1 GB/s
+    assert (
+        bandwidth > 10**9 and bandwidth < 10**12
+    ), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
+
+
+@pytest.mark.skipif(platform.machine() == "i386", reason="Cannot allocate enough memory on i386")
+@tvm.testing.parametrize_targets("llvm")
+def test_roofline_analysis(target, dev):
+    a = relay.var("a", relay.TensorType((512, 512), "float32"))
+    b = relay.var("b", relay.TensorType((512, 512), "float32"))
+    c = relay.nn.dense(a, b)
+    mod = tvm.IRModule.from_expr(relay.Function([a, b], c))
+    params = {}
+    report = tvm.utils.roofline_analysis(mod, params, target, dev)
+
+    assert "Bound" in report.table()
+    assert "Percent of Theoretical Optimal" in report.table()
+    for call in report.calls:
+        if "Percent of Theoretical Optimal" in call:
+            # Ideally we'd like a little tighter bound here, but it is hard to
+            # know how well this dense will perform without tuning. And we
+            # don't have an operator that uses a specific number of flops.
+            assert call["Percent of Theoretical Optimal"].ratio >= 0
+
+

Review Comment:
   A couple of classification tests would be rather impressive, e.g. asserting that a dense (as above) of sufficient size is compute bound and that some injective op is memory bound.
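Such classification tests would rely on the standard roofline model, which the sketch below illustrates. Under that model, an operator is compute bound when its arithmetic intensity (FLOPs per byte moved) exceeds the ridge point `peak_flops / peak_bandwidth`. The peak numbers are made-up placeholders. The byte counts are idealized minimums (each input read once, the output written once), not the values TVM's feature extraction would report.

```python
def arithmetic_intensity(flops: float, bytes_moved: float) -> float:
    return flops / bytes_moved  # FLOP per byte


def roofline_bound(flops, bytes_moved, peak_flops, peak_bandwidth) -> str:
    ridge = peak_flops / peak_bandwidth  # intensity where the two roofs meet
    return "compute" if arithmetic_intensity(flops, bytes_moved) > ridge else "memory"


def attainable_flop_s(flops, bytes_moved, peak_flops, peak_bandwidth) -> float:
    # Roofline: performance is capped by min(compute roof, bandwidth * intensity)
    return min(peak_flops, peak_bandwidth * arithmetic_intensity(flops, bytes_moved))


n = 512
PEAK_FLOPS = 1e12  # hypothetical 1 TFLOP/s
PEAK_BW = 1e11  # hypothetical 100 GB/s, giving a ridge point of 10 FLOP/byte

# dense on 512x512 float32 inputs: n**3 multiply-adds = 2 * n**3 FLOPs;
# read both inputs and write the output once.
dense_flops, dense_bytes = 2 * n**3, 3 * n * n * 4
# elementwise add of two 512x512 float32 tensors: n * n FLOPs, same traffic.
add_flops, add_bytes = n * n, 3 * n * n * 4

print(roofline_bound(dense_flops, dense_bytes, PEAK_FLOPS, PEAK_BW))  # compute
print(roofline_bound(add_flops, add_bytes, PEAK_FLOPS, PEAK_BW))  # memory
```

Under these assumptions the dense has an intensity of about 85 FLOP/byte and the add about 0.08 FLOP/byte. That puts them on opposite sides of any plausible CPU ridge point, which is what would make these two operators reasonable test cases.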



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86

Review Comment:
   Not necessarily; it could be Arm, Hexagon, etc., right?



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/seconds.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)

Review Comment:
   It would be neat to run all of the STREAM benchmarks (Copy, Scale, Add, Triad) for the memory bandwidth measurement and combine them.
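For reference, here is a rough single-threaded NumPy sketch of the four STREAM kernels. It counts bytes the way STREAM does: two arrays touched per element for Copy and Scale, three for Add and Triad. It only illustrates the idea and is no substitute for the parallel, vectorized TVM-generated kernel in the PR. NumPy's Triad here allocates a temporary for `scalar * c`, so its reported bandwidth is an underestimate. The function name and defaults are hypothetical, and taking the best rate across kernels is just one possible way to "combine" them.

```python
import time

import numpy as np


def stream_bandwidths(n: int = 20_000_000, repeats: int = 5) -> dict:
    """Hypothetical helper: best-of-`repeats` bandwidth (bytes/s) per STREAM kernel."""
    a = np.full(n, 1.0)
    b = np.full(n, 2.0)
    c = np.zeros(n)
    scalar = 3.0
    word = a.itemsize  # 8 bytes for float64, as in STREAM
    kernels = {
        # name: (operation, arrays touched per element)
        "copy": (lambda: np.copyto(c, a), 2),  # c = a
        "scale": (lambda: np.multiply(c, scalar, out=b), 2),  # b = scalar * c
        "add": (lambda: np.add(a, b, out=c), 3),  # c = a + b
        "triad": (lambda: np.add(b, scalar * c, out=a), 3),  # a = b + scalar * c
    }
    rates = {}
    for name, (op, arrays) in kernels.items():
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            op()
            best = min(best, time.perf_counter() - start)
        rates[name] = arrays * word * n / best
    return rates
```

With the default `n`, each array is about 160 MB, which should exceed typical last-level caches. `max(stream_bandwidths().values())` then gives a single "best observed" number.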



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/second.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)

Review Comment:
   I'm just ramping into the discussion here, but given cache size variability I would argue for the iterative approach: increase the array size until performance plateaus. It's easy to measure and less likely to rest on false assumptions.
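A minimal sketch of the iterative approach suggested above, assuming a hypothetical `measure(nbytes)` callable that runs a streaming kernel (such as `peak_bandwidth_tir`) over an array of `nbytes` bytes and returns bytes/s. It doubles the array size and stops once the bandwidth has stayed within `rel_tol` for `patience` consecutive doublings. Because every cache level produces its own plateau, `patience` has to span more doublings than any single cache level covers, otherwise the search can stop while the array still fits in the last-level cache. The defaults are illustrative, not tuned.

```python
from typing import Callable, List, Tuple


def find_bandwidth_plateau(
    measure: Callable[[int], float],
    start_bytes: int = 1 << 20,
    max_bytes: int = 1 << 32,
    rel_tol: float = 0.05,
    patience: int = 4,
) -> Tuple[int, float, List[Tuple[int, float]]]:
    """Double the array size until the measured bandwidth stops changing.

    `measure(nbytes)` is a hypothetical callable that runs a streaming kernel
    over `nbytes` bytes and returns a positive bandwidth in bytes/s.
    Returns (size, bandwidth, history). If no plateau is found before
    `max_bytes`, the last measurement is returned.
    """
    size = start_bytes
    prev = measure(size)
    history = [(size, prev)]
    stable = 0  # consecutive doublings whose bandwidth stayed within rel_tol
    while size * 2 <= max_bytes:
        size *= 2
        bw = measure(size)
        history.append((size, bw))
        if abs(bw - prev) / prev <= rel_tol:
            stable += 1
            if stable >= patience:
                return size, bw, history
        else:
            stable = 0  # bandwidth moved, e.g. the array fell out of a cache level
        prev = bw
    return size, prev, history


def fake_measure(nbytes: int) -> float:
    # Synthetic hierarchy: 400 GB/s up to 2 MiB, 200 GB/s up to 32 MiB, 20 GB/s beyond.
    if nbytes <= 2 * 2**20:
        return 400e9
    if nbytes <= 32 * 2**20:
        return 200e9
    return 20e9


size, bw, history = find_bandwidth_plateau(fake_measure)
print(size // 2**20, "MiB", bw / 1e9, "GB/s")  # 1024 MiB 20.0 GB/s
```

With the synthetic model, the 4 to 32 MiB plateau only lasts three doublings, so the search continues past it and settles on the DRAM plateau. A real measurement would also need repeated runs per size to keep timing noise below `rel_tol`.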
   
   



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= threads*num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/second.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)

Review Comment:
   It would also be interesting to apply the specific STREAM benchmark kernel that matches a given workload's access pattern, based on the reads and writes defined in a block. There are probably diminishing returns, though, since those kernels usually give comparable bandwidth measurements.
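As a rough illustration of the STREAM idea, the sketch below encodes the four STREAM kernels with the byte accounting STREAM uses (one word per array read or written per element; write-allocate traffic is not counted), plus a hypothetical helper that picks the kernel whose read/write mix is closest to a block's. Extracting per-element reads and writes from a TIR block is not shown, and the helper names are illustrative, not TVM APIs. `elem_size` defaults to 8 bytes because STREAM uses double-precision arrays by default; the kernel in this PR uses float32, so 4 would apply there.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class StreamKernel:
    expr: str    # what the kernel computes for each element i
    reads: int   # arrays read per element
    writes: int  # arrays written per element
    flops: int   # floating point operations per element

    def bytes_moved(self, n: int, elem_size: int = 8) -> int:
        # STREAM counts one word per array access; write-allocate is ignored.
        return n * (self.reads + self.writes) * elem_size


STREAM_KERNELS: Dict[str, StreamKernel] = {
    "copy": StreamKernel("a[i] = b[i]", reads=1, writes=1, flops=0),
    "scale": StreamKernel("a[i] = q * b[i]", reads=1, writes=1, flops=1),
    "add": StreamKernel("a[i] = b[i] + c[i]", reads=2, writes=1, flops=1),
    "triad": StreamKernel("a[i] = b[i] + q * c[i]", reads=2, writes=1, flops=2),
}


def closest_stream_kernel(reads: int, writes: int) -> str:
    """Hypothetical helper: STREAM kernel closest to a block's read/write mix.

    Ties are broken in favor of the kernel with fewer FLOPs per element, so
    the chosen kernel is as purely memory bound as possible.
    """
    def distance(name: str):
        k = STREAM_KERNELS[name]
        return (abs(k.reads - reads) + abs(k.writes - writes), k.flops)

    return min(STREAM_KERNELS, key=distance)


def bandwidth(kernel: str, n: int, seconds: float, elem_size: int = 8) -> float:
    # Bytes/s as STREAM would report it for `n` elements processed in `seconds`.
    return STREAM_KERNELS[kernel].bytes_moved(n, elem_size) / seconds


print(closest_stream_kernel(2, 1))  # add
```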



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= threads*num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/second.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)
+    a = nd.array(np.ones(size, dtype="float32"), device=dev)
+    b = nd.array(np.ones(vec_width * 4 * num_threads(), dtype="float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=5, number=1)(a, b, num_threads())
+    return size * 4 / times.min  # 4 bytes per float32
+
+
+def roofline_analysis(
+    mod: IRModule, params: Dict[str, nd.NDArray], target: Union[str, Target], dev: Device
+) -> profiling.Report:
+    """
+    Create a profiling report that contains roofline and other estimated
+    statistics from running a module on the VM.
+
+    These statistics are calculated by analyzing the lowered TIR of each
+    operator, so they are estimates of the true values. The statistics are:
+      - Bound: Whether the operator is memory or compute bound. This is computed by
+        assuming that the operator could perfectly cache all loads -- each byte
+        of memory is only loaded once.
+      - Percent of Theoretical Optimal: What percent of theoretical optimal for
+        the bound. i.e. percent of peak memory bandwidth if memory bound,
+        percent of peak FLOP/s if compute bound.
+      - Unique Loaded Bytes: estimated number of bytes loaded, not counting
+        multiple accesses to the same byte.
+      - Estimated Flops: estimated number of floating point operations.
+      - Arithmetic Intensity: ratio of FLOPs per byte of data.
+      - FLOP/s: floating point operations per second.
+      - Bandwidth: Number of bytes loaded per second.
+
+    Parameters
+    ----------
+    mod : IRModule
+      Uncompiled input module.
+
+    params : Dict[str, nd.NDArray]
+
+    target : Union[str, Target]
+      Target to run on.
+
+    dev : Device
+      Device to run on.
+
+    Returns
+    -------
+
+    report : profiling.Report
+      Profiling report which includes the estimated statistics.
+    """
+    if isinstance(target, str):
+        target = Target(target)
+    peak_bandwidth = estimate_peak_bandwidth(target, dev)
+    peak_flops = estimate_peak_fma_flops(target, dev)
+
+    ridge_point = peak_flops / peak_bandwidth
+
+    all_features = _estimated_features(mod, params, target)
+
+    lib = relay.vm.compile(mod, params=params, target=target)
+    vmexec = profiler_vm.VirtualMachineProfiler(lib, dev)
+
+    args = _create_args(mod, dev)
+    report = vmexec.profile(*args)
+    new_calls = []
+    for call in report.calls:
+        if "Hash" in call.keys():
+            _, features = all_features[call["Hash"]]
+
+            flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])
+            unique_loaded_bytes = 0.0
+            # assume no more than 100 buffers
+            for i in range(100):
+                # We could use loaded bytes, but that accounts for L1 cache.
+                # If we use unique_bytes, then we are looking at how close we come
+                # to the performance assuming all data is cached perfectly.

Review Comment:
   I don't understand why you prefer to assume that data is accessed from the cache perfectly. This will bias you toward fewer loads overall, i.e. a greater arithmetic intensity. Likely I am not understanding this comment.
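To make the concern concrete, here is a sketch using a hypothetical naive n×n×n float32 matmul. Counting each byte of A and B once (the perfect-caching assumption behind unique bytes) versus counting every inner-loop load (no reuse at all) changes the arithmetic intensity by a factor of n, which can flip the memory/compute classification against the ridge point `peak_flops / peak_bandwidth`. The peak numbers are made up for illustration, and the code in this PR takes its byte counts from TIR features rather than from a closed form like this.

```python
def matmul_intensity(n: int, elem_size: int = 4) -> tuple:
    """Arithmetic intensity (FLOPs per byte) of a naive n x n x n matmul C = A @ B.

    Returns (intensity_with_perfect_caching, intensity_with_no_reuse).
    """
    flops = 2 * n**3  # one multiply and one add per inner-loop iteration
    # Perfect caching: every element of A and B is loaded from memory once.
    unique_loaded_bytes = 2 * n * n * elem_size
    # No reuse: every inner-loop iteration loads one element of A and one of B.
    total_loaded_bytes = 2 * n**3 * elem_size
    return flops / unique_loaded_bytes, flops / total_loaded_bytes


def bound(intensity: float, peak_flops: float, peak_bandwidth: float) -> str:
    # The ridge point is the intensity where the bandwidth roof meets the compute roof.
    ridge_point = peak_flops / peak_bandwidth
    return "compute" if intensity >= ridge_point else "memory"


# Illustrative peaks: 1 TFLOP/s and 100 GB/s give a ridge point of 10 FLOPs/byte.
unique_ai, no_reuse_ai = matmul_intensity(1024)
print(unique_ai, bound(unique_ai, 1e12, 1e11))      # 256.0 compute
print(no_reuse_ai, bound(no_reuse_ai, 1e12, 1e11))  # 0.25 memory
```

Real kernels land somewhere between the two counts, so the unique-bytes figure is best read as the most optimistic intensity the operator could reach.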



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= threads*num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/second.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)
+    a = nd.array(np.ones(size, dtype="float32"), device=dev)
+    b = nd.array(np.ones(vec_width * 4 * num_threads(), dtype="float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=5, number=1)(a, b, num_threads())
+    return size * 4 / times.min  # 4 bytes per float32
+
+
+def roofline_analysis(

Review Comment:
   Equivalent functionality for profiling a TIR primfunc would be very nice, and I could imagine it as part of a standard TIR compiler toolchain (debugger, profiler, etc.). cc @supersat
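A sketch of what the per-kernel statistics could look like for a single PrimFunc, assuming its FLOP count and loaded bytes are already known (for example from the TIR features used in `roofline_analysis`) and that it has a measured runtime. It follows the definitions in the `roofline_analysis` docstring; `roofline_stats` and `RooflineStats` are hypothetical names, not existing TVM APIs.

```python
from dataclasses import dataclass


@dataclass
class RooflineStats:
    bound: str                   # "memory" or "compute"
    percent_of_optimal: float    # percent of peak bandwidth or peak FLOP/s
    flop_s: float                # achieved floating point operations per second
    bandwidth: float             # achieved bytes loaded per second
    arithmetic_intensity: float  # FLOPs per byte


def roofline_stats(
    flops: float,
    loaded_bytes: float,
    runtime_s: float,
    peak_flops: float,
    peak_bandwidth: float,
) -> RooflineStats:
    """Hypothetical helper computing roofline columns for one kernel."""
    intensity = flops / loaded_bytes
    flop_s = flops / runtime_s
    bandwidth = loaded_bytes / runtime_s
    ridge_point = peak_flops / peak_bandwidth
    if intensity < ridge_point:
        # Memory bound: compare achieved bandwidth against the bandwidth roof.
        percent = 100.0 * bandwidth / peak_bandwidth
        return RooflineStats("memory", percent, flop_s, bandwidth, intensity)
    # Compute bound: compare achieved FLOP/s against the compute roof.
    percent = 100.0 * flop_s / peak_flops
    return RooflineStats("compute", percent, flop_s, bandwidth, intensity)


# Illustrative numbers: 2 GFLOP and 400 MB loaded in 10 ms on a machine with
# 1 TFLOP/s and 100 GB/s peaks (ridge point 10 FLOPs/byte).
stats = roofline_stats(2e9, 4e8, 0.01, 1e12, 1e11)
print(stats.bound, round(stats.percent_of_optimal, 1))  # memory 40.0
```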





[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859201299


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86

Review Comment:
   My understanding is that the `llvm` target (which is where `device_name==""`) implies x86. If we look at all `llvm` strategies in topi (called `cpu` there), they assume x86.




[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

AndrewZhaoLuo commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r858102457


##########
tests/python/unittest/test_runtime_profiling.py:
##########
@@ -257,6 +259,50 @@ def test_profile_function(target, dev):
     assert report[metric].value > 0
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_estimate_peak_fma_flops(target, dev):
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"

Review Comment:
   I agree that for x86 it's probably safe, but as ARM gets more popular we should try to come up with a general way to deal with this. I'm not sure; probably best to ask others.
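One possible direction for a non-x86-specific default is to derive the vector width and register count from the target's `-mattr` features instead of from `device_name == ""`. The sketch below is hypothetical (it is not TVM's API) and only parses the target string. A real helper would also need to consult `-mcpu`/`-mtriple`, because AArch64 enables NEON by default without any `-mattr`. The table values are the architectural float32 lane counts and vector register counts for each ISA. For AVX-512 this also shows that the hard-coded 16 registers in `_vec_width_registers` undercounts.

```python
from typing import Optional, Tuple

# Hypothetical lookup: feature -> (float32 SIMD lanes, architectural vector registers).
# Ordered widest first so that the strongest enabled feature wins.
_SIMD_TABLE = [
    ("avx512f", 16, 32),  # 512-bit zmm0-zmm31
    ("avx2", 8, 16),      # 256-bit ymm0-ymm15
    ("avx", 8, 16),
    ("sse2", 4, 16),      # 128-bit xmm0-xmm15 (x86-64)
    ("neon", 4, 32),      # 128-bit v0-v31 (AArch64)
]


def simd_params_from_mattr(target_str: str) -> Optional[Tuple[int, int]]:
    """Return (lanes, registers) for the widest '+feature' in -mattr, or None if unknown."""
    enabled = set()
    for token in target_str.split():
        if token.startswith("-mattr="):
            for feat in token[len("-mattr="):].split(","):
                if feat.startswith("+"):  # "-feat" disables a feature, so ignore it
                    enabled.add(feat[1:])
    for feature, lanes, registers in _SIMD_TABLE:
        if feature in enabled:
            return lanes, registers
    return None  # caller should fall back or raise, as _vec_width_registers does today


print(simd_params_from_mattr("llvm -mattr=+fma,+avx2"))  # (8, 16)
print(simd_params_from_mattr("llvm -mtriple=aarch64-linux-gnu -mattr=+neon"))  # (4, 32)
```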




[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r860315129


##########
tests/python/unittest/test_runtime_profiling.py:
##########
@@ -257,6 +259,50 @@ def test_profile_function(target, dev):
     assert report[metric].value > 0
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_estimate_peak_fma_flops(target, dev):
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"
+    flops = tvm.utils.estimate_peak_fma_flops(tvm.target.Target(target), dev)
+    # assume we can achieve 1 GFLOP/s per thread
+    assert (
+        flops > 10**9 * tvm.runtime.num_threads() and flops < 10**14
+    ), f"FLOP/s should be between 10^9 * num_threads and 10^14, but it is {flops}"
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_estimate_peak_bandwidth(target, dev):
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"
+    bandwidth = tvm.utils.estimate_peak_bandwidth(tvm.target.Target(target), dev)
+    # assume we can achieve 1 GB/s
+    assert (
+        bandwidth > 10**9 and bandwidth < 10**12
+    ), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
+
+
+@pytest.mark.skipif(platform.machine() == "i386", reason="Cannot allocate enough memory on i386")
+@tvm.testing.parametrize_targets("llvm")
+def test_roofline_analysis(target, dev):
+    a = relay.var("a", relay.TensorType((512, 512), "float32"))
+    b = relay.var("b", relay.TensorType((512, 512), "float32"))
+    c = relay.nn.dense(a, b)
+    mod = tvm.IRModule.from_expr(relay.Function([a, b], c))
+    params = {}
+    report = tvm.utils.roofline_analysis(mod, params, target, dev)
+
+    assert "Bound" in report.table()
+    assert "Percent of Theoretical Optimal" in report.table()
+    for call in report.calls:
+        if "Percent of Theoretical Optimal" in call:
+            # Ideally we'd like a little tighter bound here, but it is hard to
+            # know how well this dense will perform without tuning. And we
+            # don't have an operator that uses a specific number of flops.
+            assert call["Percent of Theoretical Optimal"].ratio >= 0
+
+

Review Comment:
   The default (untuned) schedules for dense seem to always be memory bound instead of compute bound, so I can't really test for that without tuning. I'll try some other operators and see if the same is true for them.
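For scale, here is the arithmetic for the 512x512 dense in `test_roofline_analysis` under the perfect-caching model from the docstring. This is a sketch that assumes only the two input matrices count as loaded bytes. `relay.nn.dense(a, b)` computes `a @ b.T`, so each of the 512x512 outputs is a 512-term inner product. Whether such an operator is classified as compute or memory bound then depends on how the byte count is estimated and on the measured ridge point (`peak_flops / peak_bandwidth`).

```python
def dense_roofline_inputs(m: int, n: int, k: int, dtype_bytes: int = 4):
    """FLOPs and unique loaded bytes for C[m, n] = A[m, k] @ B[n, k]^T (sketch)."""
    flops = 2 * m * n * k  # one multiply and one add per inner-product term
    unique_loaded_bytes = (m * k + n * k) * dtype_bytes  # each input byte loaded once
    return flops, unique_loaded_bytes


flops, loaded = dense_roofline_inputs(512, 512, 512)
intensity = flops / loaded
print(flops, loaded, intensity)  # 268435456 2097152 128.0
```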




[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r860098777


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86

Review Comment:
   > In my view, any device specific attribute that the compiler needs should be retrievable from the target directly, and we now have a path to do this with the attribute preprocessor I mentioned above that directly queries the DeviceAPI for this information.
   
   I agree on this point, but I think doing this change is out of scope of this PR. I'd have to change a bunch of stuff within topi.




[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

AndrewZhaoLuo commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r858087970


##########
tests/python/unittest/test_runtime_profiling.py:
##########
@@ -257,6 +259,50 @@ def test_profile_function(target, dev):
     assert report[metric].value > 0
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_estimate_peak_fma_flops(target, dev):
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"
+    flops = tvm.utils.estimate_peak_fma_flops(tvm.target.Target(target), dev)
+    # assume we can achieve 1 GFLOP/s per thread

Review Comment:
   How is this number derived? Is it 1 FLOP per cycle on a 1 GHz processor?
   
   It might be nice to clarify this here.
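For context, the first function below reproduces the FLOP count that `estimate_peak_fma_flops` divides by the best runtime (the `flops = 2 * vec_width * ...` line in the diff). The second function is a textbook per-core peak for comparison. The two FMA units and 3 GHz clock in the example are illustrative assumptions, not measured values. A scalar core retiring one FLOP per cycle at 1 GHz gives exactly 1 GFLOP/s, which matches the per-thread lower bound used in the test and sits far below a vectorized FMA peak.

```python
def fma_kernel_flops(vec_width: int, num_vector_registers: int, nthreads: int, iters: int) -> int:
    """FLOPs executed by peakflops_fma_tir: every lane does one FMA (2 FLOPs) per step."""
    return 2 * vec_width * num_vector_registers * nthreads * iters


def theoretical_peak_flops(lanes: int, fma_units_per_core: int, ghz: float, cores: int) -> float:
    """Textbook peak: 2 FLOPs per FMA lane, per FMA unit, per cycle, per core."""
    return 2 * lanes * fma_units_per_core * ghz * 1e9 * cores


print(fma_kernel_flops(8, 16, 4, 100_000))  # 102400000
# Assumed core: AVX2 (8 float32 lanes), 2 FMA units, 3 GHz.
print(theoretical_peak_flops(8, 2, 3.0, 1) / 1e9)  # 96.0 (GFLOP/s per core)
```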



##########
src/runtime/profiling.cc:
##########
@@ -288,6 +290,8 @@ void print_metric(std::ostream& os, ObjectRef o) {
     os << "{\"microseconds\":" << std::setprecision(17) << std::fixed << n->microseconds << "}";
   } else if (const PercentNode* n = o.as<PercentNode>()) {
     os << "{\"percent\":" << std::setprecision(17) << std::fixed << n->percent << "}";
+  } else if (const RatioNode* n = o.as<RatioNode>()) {
+    os << "{\"ratio\":" << std::setprecision(17) << std::fixed << n->ratio << "}";

Review Comment:
   nit: suggest using `std::numeric_limits<double>::max_digits10` instead of `17` for this and the above usages too
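For background, `std::numeric_limits<double>::max_digits10` is 17 for IEEE-754 binary64. This is the number of *significant* decimal digits needed to round-trip any double. With `std::fixed`, however, the stream precision counts digits after the decimal point, so the two are not interchangeable for very small values. Python floats are the same binary64 type, so the distinction can be demonstrated as follows:

```python
import math
import sys

# Python floats are IEEE-754 binary64, the same type as C++ double.
assert sys.float_info.mant_dig == 53

# max_digits10 = ceil(1 + mantissa_bits * log10(2)) = 17 for binary64.
max_digits10 = math.ceil(1 + 53 * math.log10(2))


def round_trips(x: float, spec: str) -> bool:
    """True if formatting x with spec and parsing it back yields exactly x."""
    return float(format(x, spec)) == x


samples = [0.1, 1.0 / 3.0, math.pi, 1e-9, 123456.789e-7]
print(max_digits10)  # 17
print(all(round_trips(x, ".17g") for x in samples))  # True: 17 significant digits
print(round_trips(1e-20, ".17f"))  # False: fixed notation keeps 17 digits after the point
```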



##########
src/runtime/profiling.cc:
##########
@@ -343,6 +347,46 @@ String ReportNode::AsJSON() const {
   return s.str();
 }
 
+// Aggregate a set of values for a metric. Computes sum for Duration, Count,
+// and Percent; average for Ratio; and assumes all Strings are the same. All
+// ObjectRefs in metrics must have the same type.
+ObjectRef AggregateMetric(const std::vector<ObjectRef>& metrics) {
+  ICHECK_GT(metrics.size(), 0) << "Must pass a non-zero number of metrics";
+  if (metrics[0].as<DurationNode>()) {
+    double sum = 0;
+    for (auto& metric : metrics) {
+      sum += metric.as<DurationNode>()->microseconds;
+    }
+    return ObjectRef(make_object<DurationNode>(sum));
+  } else if (metrics[0].as<CountNode>()) {
+    int64_t sum = 0;
+    for (auto& metric : metrics) {
+      sum += metric.as<CountNode>()->value;
+    }
+    return ObjectRef(make_object<CountNode>(sum));
+  } else if (metrics[0].as<PercentNode>()) {
+    double sum = 0;
+    for (auto& metric : metrics) {
+      sum += metric.as<PercentNode>()->percent;
+    }
+    return ObjectRef(make_object<PercentNode>(sum));
+  } else if (metrics[0].as<RatioNode>()) {
+    double sum = 0;
+    for (auto& metric : metrics) {
+      sum += metric.as<RatioNode>()->ratio;
+    }
+    return ObjectRef(make_object<RatioNode>(sum / metrics.size()));
+  } else if (metrics[0].as<StringObj>()) {
+    // Assume all strings in metrics are the same.

Review Comment:
   I know the old code didn't check this, but it might be a good idea to check here that the assumption is correct.
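To make the suggested check concrete, here is a Python mirror of the `AggregateMetric` logic quoted above, with the string-equality check added. The dataclasses are hypothetical stand-ins for `DurationNode`, `CountNode`, `PercentNode`, and `RatioNode`; they are not TVM's Python API. In the C++ code the equivalent would presumably be an `ICHECK` inside a loop over the strings.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass(frozen=True)
class Duration:
    microseconds: float


@dataclass(frozen=True)
class Count:
    value: int


@dataclass(frozen=True)
class Percent:
    percent: float


@dataclass(frozen=True)
class Ratio:
    ratio: float


Metric = Union[Duration, Count, Percent, Ratio, str]


def aggregate_metric(metrics: List[Metric]) -> Metric:
    """Sum Duration/Count/Percent, average Ratio, and require identical strings."""
    if not metrics:
        raise ValueError("Must pass a non-zero number of metrics")
    kind = type(metrics[0])
    if any(type(m) is not kind for m in metrics):
        raise TypeError("All metrics must have the same type")
    if kind is Duration:
        return Duration(sum(m.microseconds for m in metrics))
    if kind is Count:
        return Count(sum(m.value for m in metrics))
    if kind is Percent:
        return Percent(sum(m.percent for m in metrics))
    if kind is Ratio:
        return Ratio(sum(m.ratio for m in metrics) / len(metrics))
    if kind is str:
        # The check asked for in the review: do not silently keep only the first string.
        if any(m != metrics[0] for m in metrics):
            raise ValueError(f"Expected identical strings, got {sorted(set(metrics))}")
        return metrics[0]
    raise TypeError(f"Unsupported metric type {kind.__name__}")
```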



##########
tests/python/unittest/test_runtime_profiling.py:
##########
@@ -257,6 +259,50 @@ def test_profile_function(target, dev):
     assert report[metric].value > 0
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_estimate_peak_fma_flops(target, dev):
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"
+    flops = tvm.utils.estimate_peak_fma_flops(tvm.target.Target(target), dev)
+    # assume we can achieve 1 GFLOP/s per thread
+    assert (
+        flops > 10**9 * tvm.runtime.num_threads() and flops < 10**14
+    ), f"FLOP/s should be between 10^9 * num_threads and 10^14, but it is {flops}"
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_estimate_peak_bandwidth(target, dev):
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"
+    bandwidth = tvm.utils.estimate_peak_bandwidth(tvm.target.Target(target), dev)
+    # assume we can achieve 1 GB/s

Review Comment:
   It would also be nice to see where the bandwidth assumption comes from.
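As a sanity check on the 1 GB/s lower bound, the theoretical DRAM bandwidth is transfers per second times bytes per transfer times channels. A single 64-bit DDR4-2400 channel already gives 19.2 GB/s, so the bound is more than an order of magnitude below what even one channel can deliver in theory. The memory configurations below are examples only.

```python
def dram_peak_bytes_per_s(megatransfers_per_s: float, bus_width_bits: int = 64, channels: int = 1) -> float:
    """Theoretical DRAM bandwidth: transfers/s x bytes per transfer x channels."""
    return megatransfers_per_s * 1e6 * (bus_width_bits // 8) * channels


# Single-channel DDR4-2400: 2400 MT/s x 8 bytes = 19.2 GB/s theoretical.
print(dram_peak_bytes_per_s(2400) / 1e9)  # 19.2
# Dual-channel DDR4-3200.
print(dram_peak_bytes_per_s(3200, channels=2) / 1e9)  # 51.2
```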



##########
tests/python/unittest/test_runtime_profiling.py:
##########
@@ -257,6 +259,50 @@ def test_profile_function(target, dev):
     assert report[metric].value > 0
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_estimate_peak_fma_flops(target, dev):
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"

Review Comment:
   Is it possible to turn off vectorization, as a test, and see whether that affects the measured FLOP/s?



##########
tests/python/unittest/test_runtime_profiling.py:
##########
@@ -257,6 +259,50 @@ def test_profile_function(target, dev):
     assert report[metric].value > 0
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_estimate_peak_fma_flops(target, dev):
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"

Review Comment:
   This test will also require a processor that supports these extensions, right? So I can't run it on my M1 Mac. Can we skip this test if the host does not support AVX2?
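One way to implement such a skip is sketched below. It assumes a Linux host where x86 CPUs list their features on `flags` lines of `/proc/cpuinfo`. Where that file is missing (macOS) or uses a different layout (ARM Linux reports `Features`), the helper returns False, so the test is skipped. The helper names are hypothetical, and TVM may already have a better utility for this.

```python
from pathlib import Path
from typing import Set


def cpu_flags_from_cpuinfo(text: str) -> Set[str]:
    """Collect the feature flags listed on 'flags' lines of /proc/cpuinfo text."""
    flags: Set[str] = set()
    for line in text.splitlines():
        key, sep, value = line.partition(":")
        if sep and key.strip() == "flags":
            flags.update(value.split())
    return flags


def host_has_cpu_flags(*required: str, cpuinfo: str = "/proc/cpuinfo") -> bool:
    """True only if every required flag is reported; False when cpuinfo is unavailable."""
    try:
        text = Path(cpuinfo).read_text()
    except OSError:
        return False
    return set(required) <= cpu_flags_from_cpuinfo(text)
```

In the test file, this could then be used as `@pytest.mark.skipif(not host_has_cpu_flags("avx2", "fma"), reason="requires AVX2 and FMA")` on the two tests that force `-mattr=+fma,+avx2`.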




[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

AndrewZhaoLuo commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r858244561


##########
tests/python/unittest/test_runtime_profiling.py:
##########
@@ -257,6 +259,50 @@ def test_profile_function(target, dev):
     assert report[metric].value > 0
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_estimate_peak_fma_flops(target, dev):
+    # This test uses vectorized instructions so we need a target that supports them
+    if target == "llvm":
+        target = "llvm -mattr=+fma,+avx2"

Review Comment:
   `test_vm()` does segfault on my machine, but I'm guessing it's unrelated.




[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r860314700


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= threads*num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/second.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited by the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)
+    a = nd.array(np.ones(size, dtype="float32"), device=dev)
+    b = nd.array(np.ones(vec_width * 4 * num_threads(), dtype="float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=5, number=1)(a, b, num_threads())
+    return size * 4 / times.min  # 4 bytes per float32
+
+
+def roofline_analysis(
+    mod: IRModule, params: Dict[str, nd.NDArray], target: Union[str, Target], dev: Device
+) -> profiling.Report:
+    """
+    Create a profiling report that contains roofline and other estimated
+    statistics from running a module on the VM.
+
+    These statistics are calculated by analyzing the lowered TIR of each
+    operator, so they are estimates of the true values. The statistics are:
+      - Bound: Is the operator memory or compute bound. This is computed by
+        assuming that the operator could perfectly cache all loads -- each byte
+        of memory is only loaded once.
+      - Percent of Theoretical Optimal: What percent of theoretical optimal for
+        the bound, i.e. percent of peak memory bandwidth if memory bound,
+        percent of peak FLOP/s if compute bound.
+      - Unique Loaded Bytes: estimation of the number of bytes loaded, not
+        counting multiple accesses to the same byte.
+      - Estimated Flops: estimated number of floating point operations.
+      - Arithmetic Intensity: ratio of FLOPs per byte of data.
+      - FLOP/s: floating point operations per second.
+      - Bandwidth: Number of bytes loaded per second.
+
+    Parameters
+    ----------
+    mod : IRModule
+      Uncompiled input module.
+
+    params : Dict[str, nd.NDArray]
+
+    target : Union[str, Target]
+      Target to run on.
+
+    dev : Device
+      Device to run on.
+
+    Returns
+    -------
+
+    report : profiling.Report
+      Profiling report which includes the estimated statistics.
+    """
+    if isinstance(target, str):
+        target = Target(target)
+    peak_bandwidth = estimate_peak_bandwidth(target, dev)
+    peak_flops = estimate_peak_fma_flops(target, dev)
+
+    ridge_point = peak_flops / peak_bandwidth
+
+    all_features = _estimated_features(mod, params, target)
+
+    lib = relay.vm.compile(mod, params=params, target=target)
+    vmexec = profiler_vm.VirtualMachineProfiler(lib, dev)
+
+    args = _create_args(mod, dev)
+    report = vmexec.profile(*args)
+    new_calls = []
+    for call in report.calls:
+        if "Hash" in call.keys():
+            _, features = all_features[call["Hash"]]
+
+            flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])
+            unique_loaded_bytes = 0.0
+            # assume no more than 100 buffers
+            for i in range(100):
+                # We could use loaded bytes, but that accounts for the L1 cache.
+                # If we use unique_bytes, then we are looking at how close we come
+                # to the performance assuming all data is cached perfectly.

Review Comment:
   You've convinced me, I'll switch it over.
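The `roofline_analysis` docstring quoted above defines the "Bound" and "Percent of Theoretical Optimal" columns. Below is a standalone sketch of that classification for a single operator, assuming the unique-loaded-bytes (perfect caching) model described there. The function name, the dict keys (chosen to echo the docstring's column names), and the tie-breaking at exactly the ridge point are illustrative assumptions; this is not the code in the PR.

```python
def roofline(flops: float, unique_loaded_bytes: float, runtime_s: float,
             peak_flops: float, peak_bandwidth: float) -> dict:
    """Classify one operator and report its percent of the roof for its bound (sketch)."""
    ridge_point = peak_flops / peak_bandwidth  # FLOP/byte where the two roofs meet
    intensity = flops / unique_loaded_bytes
    achieved_flops = flops / runtime_s
    achieved_bandwidth = unique_loaded_bytes / runtime_s
    if intensity < ridge_point:
        bound, percent = "memory", 100.0 * achieved_bandwidth / peak_bandwidth
    else:
        bound, percent = "compute", 100.0 * achieved_flops / peak_flops
    return {
        "Bound": bound,
        "Percent of Theoretical Optimal": percent,
        "Arithmetic Intensity": intensity,
        "FLOP/s": achieved_flops,
        "Bandwidth": achieved_bandwidth,
    }


# Hypothetical machine: 100 GFLOP/s and 20 GB/s, so the ridge point is 5 FLOP/byte.
dense = roofline(268435456, 2097152, 0.01, 100e9, 20e9)
print(dense["Bound"], round(dense["Percent of Theoretical Optimal"], 2))  # compute 26.84
```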




[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

AndrewZhaoLuo commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859206088


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86

Review Comment:
   Yeah, there is some oddness going on (see https://github.com/apache/tvm/pull/11066#discussion_r858242458). Perhaps this assumption is not entirely accurate, or maybe LLVM acts as a nice compatibility layer.




[GitHub] [tvm] tqchen commented on pull request #11066: [PROFILER] Theoretical roofline models

tqchen commented on PR #11066:
URL: https://github.com/apache/tvm/pull/11066#issuecomment-1113238429

   Thank you @tkonolige for the contribution and @csullivan @AndrewZhaoLuo for reviewing. 
   
   The current proposal now introduces a new top-level namespace (`utils`). While I personally tend to agree with the namespace choice, as stated in an earlier comment, please open a discuss thread for posterity: a new top-level namespace is a major architectural choice, so this is a case where we want community involvement and awareness.
   
   When a new namespace is introduced, it is less common for the implementation to sit directly in `__init__.py`, since a new namespace usually grows to hold further tools of a similar kind. As a result, it would be helpful to move the implementation into a separate file (or files, e.g. `roofline.py` or a better name) under the same namespace and then import it.




[GitHub] [tvm] tqchen commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r853638940


##########
python/tvm/analysis/__init__.py:
##########
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""High level analysis functions"""

Review Comment:
   The main reason was mentioned above: `analysis` is reserved for IR-based analysis itself (as in the tir/relay analysis modules), so there might be confusion here.
   
   In this case `profile_analysis` might be a good distinguishing name (probably not the best name, just the one I could come up with); happy to discuss other name choices as well.





[GitHub] [tvm] tqchen commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r854070352


##########
python/tvm/analysis/__init__.py:
##########
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""High level analysis functions"""

Review Comment:
   Following up after deliberating more about this overnight: a new namespace constitutes a major architectural choice, as well as an API choice, that would benefit from broader [deliberation and discussion](https://tvm.apache.org/docs/contribute/code_review.html#deliberate-on-api-and-data-structures), especially when we are not too sure about the name. So it would be helpful to get more input before introducing a new top-level namespace.
   
   The current module starts with the LLVM CPU target and is more like standalone tooling that will be continuously improved and used. Our previous convention has been to start such self-contained tooling in contrib (e.g. `contrib/popen_pool`). Admittedly, in those cases the name contrib conveys little beyond "collection of contributed tools", but it needs less architectural deliberation and can unblock the PR. The current tooling also matches the characteristics of the other collections there (it depends on a few things and aims to enable certain goals).
   
   In the meantime, if we want to think about possible ways to group things as the tooling matures, it would be great to open a discuss thread to also get some collective wisdom from community members.
   





[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r855336812


##########
python/tvm/relay/backend/vm.py:
##########
@@ -150,9 +150,7 @@ def lower(self, mod, target=None, target_host=None):
             target, target_host, target_is_dict_key=False
         )
 
-        tophub_context = self._tophub_context(target)

Review Comment:
   Whoops, accidentally left this in. See https://github.com/apache/tvm-rfcs/pull/59 for some context.





[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859127453


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= threads*num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"

Review Comment:
   whoops, removed
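For context on the removed `# assert N % (nt * 4 * vec_width) == 0` line: the serial `k` loop in `peak_bandwidth_tir` runs `N // nt // 4 // vec_width` times per thread, so if `N` is not a multiple of `nt * 4 * vec_width`, some elements of `A` are never read. The caller appears to avoid this by rounding `size` down to such a multiple (`10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)`). The pure-Python sketch below mirrors that index arithmetic on small sizes; the helper names `rounded_size` and `elements_read` are hypothetical and not part of the PR.

```python
def rounded_size(target_elems: int, nt: int, vec_width: int) -> int:
    """Round target_elems down to a multiple of nt * 4 * vec_width, mirroring
    `10**8 // (4 * nthreads * vec_width) * (4 * nthreads * vec_width)`."""
    chunk = nt * 4 * vec_width
    return target_elems // chunk * chunk


def elements_read(n: int, nt: int, vec_width: int) -> int:
    """Count the distinct elements of A read by the peak_bandwidth_tir loop nest."""
    touched = set()
    for i in range(nt):  # T.parallel(nt)
        for k in range(n // nt // 4 // vec_width):  # T.serial(N // nt // 4 // vec_width)
            for l in range(4):  # T.unroll(4)
                for j in range(vec_width):  # T.vectorized(vec_width)
                    touched.add(i * (n // nt) + k * vec_width * 4 + l * vec_width + j)
    return len(touched)


n = rounded_size(1000, nt=3, vec_width=4)
print(n, elements_read(n, nt=3, vec_width=4))  # 960 960
print(elements_read(1000, nt=3, vec_width=4))  # 960: 40 of the 1000 elements are skipped
```

With the rounding in place every element is read exactly once, which is what the removed assert was guarding.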





[GitHub] [tvm] csullivan commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859213141


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86

Review Comment:
   Yeah, an empty device name is really just a special case for llvm, not something specific to x86.
   
   
   > target = tvm.target.Target(
   >     "llvm -keys=hexagon,cpu -link-params=0 "
   >     "-mattr=+hvxv68,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp "
   >     "-mcpu=hexagonv68 -mtriple=hexagon"
   > )
   
   
   will also have an empty device name, similar to Andrew's comment. I think we need a different way to dispatch here. 
   
   In general, I'd like to see this information queried as attrs from the target. When an attr is not provided, it can be filled in from the device via an attr [preprocessor](https://github.com/apache/tvm/blob/main/src/target/target_kind.cc#L295) that updates the attributes for a target kind and can [query](https://github.com/apache/tvm/blob/main/src/target/target_kind.cc#L114) the local or remote device API for the necessary information.
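To illustrate dispatching on target attributes rather than on `device_name`, here is a pure-Python sketch that reads `-mattr` from an llvm target string and maps a few known vector features to float32 lane counts. The parser, the feature table, and both function names are hypothetical illustrations, not TVM APIs, and the table is deliberately incomplete.

```python
import shlex
from typing import Dict, Optional

# Assumed, illustrative mapping from an enabled LLVM feature to SIMD width in bits.
# Ordered widest first so the widest enabled feature wins.
_FEATURE_BITS = [
    ("+hvx-length128b", 1024),  # Hexagon HVX, 128-byte vectors
    ("+avx512f", 512),
    ("+hvx-length64b", 512),  # Hexagon HVX, 64-byte vectors
    ("+avx2", 256),
    ("+avx", 256),
    ("+sse2", 128),
]


def parse_target_options(target_str: str) -> Dict[str, str]:
    """Split 'llvm -key=value ...' into {'kind': 'llvm', 'key': 'value', ...}."""
    tokens = shlex.split(target_str)
    opts = {"kind": tokens[0]}
    for tok in tokens[1:]:
        key, _, value = tok.lstrip("-").partition("=")
        opts[key] = value
    return opts


def infer_f32_lanes(target_str: str) -> Optional[int]:
    """Return float32 lanes per vector register, or None if no known feature is set."""
    features = set(parse_target_options(target_str).get("mattr", "").split(","))
    for feature, bits in _FEATURE_BITS:
        if feature in features:
            return bits // 32
    return None  # unknown: the caller should require an explicit vec_width


hexagon = (
    "llvm -keys=hexagon,cpu -link-params=0 "
    "-mattr=+hvxv68,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp "
    "-mcpu=hexagonv68 -mtriple=hexagon"
)
x86 = "llvm -mcpu=skylake-avx512 -mattr=+avx512f"
for t in (hexagon, x86):
    # Neither string has -device, so a device_name == "" check cannot tell them apart.
    print(parse_target_options(t).get("device", ""), infer_f32_lanes(t))
```

A string-level table like this misses targets that only give `-mcpu` (e.g. `llvm -mcpu=skylake-avx512` alone implies AVX-512), which is one argument for populating these attrs through a preprocessor that can query the device.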
   
   





[GitHub] [tvm] csullivan commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r861277767


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= threads*num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/second.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)

Review Comment:
   Agreed, there isn't a great way to reliably determine when the measured bandwidth plateaus. I suppose we could consider letting the user define a convergence function and then use something naive as the default (e.g. walk up the size in regular steps and, after a threshold number of comparable runs, take a few additional steps at a much larger interval?).
   
   But I think I'm okay with this as-is for now. More desirable would be for this to also be a queryable attribute via the target attribute preprocessor flow.
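A minimal sketch of the naive default described above, assuming a user-supplied `measure(nbytes)` callback that returns bandwidth in bytes/second (in practice it would build and time `peak_bandwidth_tir` at that size). `find_plateau`, its parameters, and `fake_bandwidth` are all hypothetical; the synthetic curve simply stands in for a real measurement.

```python
from typing import Callable, List, Tuple


def find_plateau(
    measure: Callable[[int], float],
    start: int,
    step: int,
    rel_tol: float = 0.05,
    patience: int = 3,
    extra_steps: int = 2,
    extra_factor: int = 10,
    max_size: int = 10**10,
) -> Tuple[int, float, List[Tuple[int, float]]]:
    """Naive plateau search on a bandwidth-vs-size curve.

    Walk up the size in fixed steps until `patience` consecutive readings are
    within `rel_tol` of the previous one, then take `extra_steps` much larger
    steps (step * extra_factor) and report the last reading.
    """
    size = start
    prev = measure(size)
    history = [(size, prev)]
    stable = 0
    while stable < patience and size + step <= max_size:
        size += step
        bw = measure(size)
        history.append((size, bw))
        # Count consecutive "comparable" runs; reset on any large change.
        stable = stable + 1 if abs(bw - prev) <= rel_tol * prev else 0
        prev = bw
    big_step = step * extra_factor
    for _ in range(extra_steps):
        if size + big_step > max_size:
            break
        size += big_step
        history.append((size, measure(size)))
    return history[-1][0], history[-1][1], history


MiB = 2**20


def fake_bandwidth(nbytes: int) -> float:
    """Synthetic stand-in for a real measurement: cache-resident up to 32 MiB."""
    return 200e9 if nbytes <= 32 * MiB else 20e9


size, bw, history = find_plateau(fake_bandwidth, start=8 * MiB, step=8 * MiB)
print(size // MiB, bw)  # 192 20000000000.0
```

On this synthetic curve the fixed-step walk alone stops at 32 MiB, still in the cache-resident region; only the larger confirmation steps reach the DRAM-bound plateau, which shows why the choice of steps and tolerances matters.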





[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r854803368


##########
python/tvm/relay/backend/vm.py:
##########
@@ -150,9 +150,7 @@ def lower(self, mod, target=None, target_host=None):
             target, target_host, target_is_dict_key=False
         )
 
-        tophub_context = self._tophub_context(target)

Review Comment:
   Well TIL this exists...





[GitHub] [tvm] tqchen commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r853638940


##########
python/tvm/analysis/__init__.py:
##########
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""High level analysis functions"""

Review Comment:
   The main reason was mentioned above: `analysis` is already used by IR-based analysis (in both tir and relay), so there might be confusion here.
   
   In this case `profile_analysis` might be a good distinguishing name (probably not the best name, just the one I could come up with); happy to discuss other name choices as well.





[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859198956


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/second.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)

Review Comment:
   Personally, I think it would be nice to expose an interface for changing this size (with 10**8 as the default), at least to cover the worst case csullivan describes above.
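The sizing expression under discussion rounds 10**8 float32 elements down to a multiple of `4 * num_threads() * vec_width`. That keeps each thread's slice of `A` exactly covered by the kernel's `N // nt // 4 // vec_width` loop. Below is a minimal sketch of how the same expression could take a caller-supplied size instead of the hard-coded default. `bandwidth_buffer_elements` and its `num_elements` parameter are hypothetical names, not part of the PR.

```python
def bandwidth_buffer_elements(nthreads: int, vec_width: int, num_elements: int = 10**8) -> int:
    """Round num_elements down to a whole number of per-iteration chunks.

    Hypothetical helper: mirrors the PR's expression
    10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)
    but lets the caller override the 10**8 default.
    """
    # Each step of the serial loop reads 4 unrolled vectors of vec_width
    # float32s on every one of the nthreads parallel workers.
    chunk = 4 * nthreads * vec_width
    if num_elements < chunk:
        raise ValueError(f"num_elements must be at least {chunk}")
    return num_elements // chunk * chunk


# Example: 6 threads with 16-lane float32 vectors (e.g. AVX-512).
n = bandwidth_buffer_elements(nthreads=6, vec_width=16)
print(n, "elements,", n * 4 / 1e6, "MB")  # 4 bytes per float32
```

Keeping the rounding inside the helper means a user-supplied size can never break the divisibility the TIR kernel relies on.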




[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r860041432


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/second.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)
+    a = nd.array(np.ones(size, dtype="float32"), device=dev)
+    b = nd.array(np.ones(vec_width * 4 * num_threads(), dtype="float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=5, number=1)(a, b, num_threads())
+    return size * 4 / times.min  # 4 bytes per float32
+
+
+def roofline_analysis(

Review Comment:
   I think we could add that functionality without too much work. Maybe in a separate PR though?
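This comment is attached to `roofline_analysis`. That function compares each operator against the roofline bound described in the PR description. An operator can run no faster than the lower of two limits: peak FLOP/s, or arithmetic intensity times peak bandwidth. Below is a minimal sketch of that bound in plain Python. It assumes per-operator FLOP and byte estimates and measured peak rates are already available. `roofline_report` is a hypothetical helper, not the PR's API.

```python
def roofline_report(flops: float, bytes_moved: float, runtime_s: float,
                    peak_flop_s: float, peak_bytes_s: float) -> tuple:
    """Return ("memory" or "compute", percent of the roofline bound reached).

    Illustrative helper only; not the PR's roofline_analysis API.
    """
    intensity = flops / bytes_moved                # arithmetic intensity, FLOP per byte
    ridge = peak_flop_s / peak_bytes_s             # intensity where the two roofs meet
    attainable = min(peak_flop_s, intensity * peak_bytes_s)  # roofline bound, FLOP/s
    achieved = flops / runtime_s
    bound = "memory" if intensity < ridge else "compute"
    # When memory bound, achieved / attainable simplifies to
    # (bytes_moved / runtime_s) / peak_bytes_s, the fraction of peak bandwidth.
    return bound, 100.0 * achieved / attainable


# Hypothetical machine: 1 TFLOP/s peak compute, 100 GB/s peak bandwidth.
# Elementwise add of 1e6 float32s: 1e6 FLOPs, 3 arrays * 4 bytes * 1e6 = 12e6 bytes.
print(roofline_report(1e6, 12e6, 2.4e-4, 1e12, 1e11))
```

A single "percent of roofline" number works for both regimes. Dividing by the attainable bound automatically measures memory-bound operators against bandwidth and compute-bound operators against FLOP/s.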




[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

AndrewZhaoLuo commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859067984


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/second.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)

Review Comment:
   Or actually maybe not, since there are no guarantees on execution order. Do we need to clear the cache between runs?
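One way to probe the cache question empirically is to evict the caches before each timed repeat and check whether the minimum time changes. Below is a minimal NumPy sketch, independent of TVM's `time_evaluator`. It times an arbitrary callable, and before every repeat it writes to a scratch buffer. The 64 MiB default is an assumed value, intended to exceed the last-level cache; it is not a measured cache size. `min_time_with_flush` is a hypothetical name.

```python
import time

import numpy as np


def min_time_with_flush(fn, repeat: int = 5, flush_bytes: int = 64 * 2**20) -> float:
    """Return the fastest of `repeat` timed calls to fn(), evicting caches first.

    Writing to a scratch buffer assumed to be larger than the last-level cache
    should push most of fn's previously loaded data out of the cache hierarchy.
    """
    scratch = np.zeros(max(flush_bytes // 8, 1), dtype=np.float64)
    best = float("inf")
    for _ in range(repeat):
        scratch += 1.0  # touch every cache line of the scratch buffer
        start = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - start)
    return best


# Usage: compare flushed vs. (effectively) unflushed timings of a streaming reduction.
data = np.ones(10**7, dtype=np.float32)
flushed = min_time_with_flush(lambda: data.sum())
unflushed = min_time_with_flush(lambda: data.sum(), flush_bytes=8)
print(f"flushed {flushed:.4f}s, unflushed {unflushed:.4f}s")
```

Whether a flush matters for the bandwidth kernel likely depends on how its roughly 400 MB input compares to the machine's last-level cache. A large gap between the two timings above would suggest the working set is partly cache resident.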




[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

AndrewZhaoLuo commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859110634


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[

Review Comment:
   Ah I see.
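   For reference, the FLOP/s arithmetic in the quoted `estimate_peak_fma_flops` can be reproduced in plain Python. This sketch mirrors the two lines that compute `flops` and `flop_s`. The helper name `fma_flop_rate` is hypothetical, and `min_time_s` stands in for `times.min` from `time_evaluator`. The example values (8 float32 lanes, 16 vector registers, 4 threads, 10 ms) are assumptions, not measurements.

```python
def fma_flop_rate(vec_width: int, num_vector_registers: int, nthreads: int,
                  iters: int, min_time_s: float) -> float:
    """Hypothetical helper mirroring the quoted FLOP/s computation."""
    # Each vectorized FMA lane does one multiply and one add, i.e. two FLOPs.
    flops = 2 * vec_width * num_vector_registers * nthreads * iters
    # Divide by the fastest repeat, as the quoted code does with times.min.
    return flops / min_time_s


# Assumed AVX2-like configuration: 8 float32 lanes and 16 vector registers.
rate = fma_flop_rate(8, 16, 4, 100_000, min_time_s=0.01)
print(f"{rate:.3e} FLOP/s")
```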





[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r853527343


##########
python/tvm/analysis/__init__.py:
##########
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""High level analysis functions"""

Review Comment:
   I wasn't sure about the correct namespace either. It doesn't really fit in `tir` or `relay` because it is about high-level features, not anything specific to TIR or Relay. It doesn't fit in `runtime` because it requires the auto scheduler, so I can't put it in `runtime/profiling`. Also, it is not really profiling; it is an analysis.
   
   I don't think `contrib/roofline` is a good match because `contrib` is poorly defined and adds no extra information. `roofline` doesn't make sense because there are more functions than just the roofline model (estimated FLOP/s and bandwidth).
   
   I chose `analysis` because it seemed better than all the choices listed above. Do you have a reason not to use `analysis`?
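   For context, the roofline model combines the two estimated peaks (FLOP/s and bandwidth) with per-operator FLOP and byte counts. The sketch below shows that classification under stated assumptions. `roofline_fraction` is a hypothetical helper, not the code in this PR, and the example numbers are made up.

```python
from typing import Tuple


def roofline_fraction(flops: float, bytes_moved: float, runtime_s: float,
                      peak_flop_s: float, peak_bytes_s: float) -> Tuple[str, float]:
    """Classify an operator as memory- or compute-bound and return the
    fraction of the relevant peak it achieves (hypothetical helper)."""
    intensity = flops / bytes_moved       # FLOPs per byte touched
    ridge = peak_flop_s / peak_bytes_s    # intensity where the two roofs meet
    if intensity < ridge:
        # Memory bound: compare achieved bandwidth to peak bandwidth.
        return "memory", (bytes_moved / runtime_s) / peak_bytes_s
    # Compute bound: compare achieved FLOP/s to peak FLOP/s.
    return "compute", (flops / runtime_s) / peak_flop_s


# Assumed peaks: 100 GFLOP/s and 20 GB/s, so the ridge point is 5 FLOP/byte.
print(roofline_fraction(1e8, 1e8, 0.01, 1e11, 2e10))
print(roofline_fraction(1e9, 1e7, 0.02, 1e11, 2e10))
```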





[GitHub] [tvm] tqchen commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r853519434


##########
python/tvm/analysis/__init__.py:
##########
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""High level analysis functions"""

Review Comment:
   It would be good to think about the namespacing choices here.
   
   This is a mixture of profiling and TIR features. Previously, analyses under `tir/analysis` or `relay/analysis` were analyses of the IR, so `analysis` might create some confusion here due to ambiguity.
   
   How about `contrib/roofline` if we aimed at roofline analysis?
   
   





[GitHub] [tvm] AndrewZhaoLuo merged pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo merged PR #11066:
URL: https://github.com/apache/tvm/pull/11066




[GitHub] [tvm] tqchen commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r853519434


##########
python/tvm/analysis/__init__.py:
##########
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""High level analysis functions"""

Review Comment:
   It would be good to think about the namespacing choices here.
   
   This is a mixture of profiling and TIR features. Previously, analyses under `tir/analysis` or `relay/analysis` were analyses of the IR.
   
   So `analysis` might create some confusion here due to ambiguity (whether it is an analysis of TIR or, as in this case, a mix of runtime profiling and IR analysis).
   
   How about:
   -  `contrib/roofline` if we aimed at roofline analysis?
   -  Alternatively, `tir/profile_analysis`, given it is a profile-driven analysis of TIR
   
   





[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859198607


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)

Review Comment:
   Yeah, for something on the order of n=1. If it's easy to add, why not add it? From the discussions below on this and on cache size, it looks like some guess-and-check will be needed in the worst case, so why not expose the interface to do that (you can set default arguments in the function anyway).
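   One way to follow this suggestion is to keep the current value as a keyword default, so callers only override it when guess-and-check is needed. The sketch below is plain Python and hypothetical: `best_time` stands in for `time_evaluator(..., repeat=...)` and returns the minimum over repeats, matching how the quoted code uses `times.min`.

```python
import time
from typing import Callable


def best_time(fn: Callable[[], None], repeat: int = 100) -> float:
    """Run fn `repeat` times and return the fastest wall-clock time in seconds.

    Hypothetical stand-in for time_evaluator(...).min. The default mirrors
    repeat=100 in the quoted code; callers may pass a different value.
    """
    if repeat < 1:
        raise ValueError("repeat must be >= 1")
    best = float("inf")
    for _ in range(repeat):
        start = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - start)
    return best


calls = []
t = best_time(lambda: calls.append(1), repeat=5)
print(len(calls), t >= 0.0)
```

   Taking the minimum rather than the mean matches the goal of estimating a peak: slower repeats mostly reflect interference, not hardware limits.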





[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859129624


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/seconds.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)

Review Comment:
   I'm not sure it is worth exposing the size to the user; then they have to do more work to get the correct size. We could keep increasing the size until performance flatlines, but that would be computationally expensive, and we would need a different parameter to control when to stop.
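   The quoted `size` expression rounds 10**8 down to a multiple of `4 * num_threads() * vec_width`. That is the amount of data the loop nest in `peak_bandwidth_tir` covers per step across all threads, and it is what the commented-out `N % (nt * 4 * vec_width) == 0` assert would check. Below is a sketch of that arithmetic. It passes `nthreads` explicitly instead of calling `num_threads()`, assumes `size` counts float32 elements of `A`, and uses a hypothetical function name.

```python
def bandwidth_buffer_elems(nthreads: int, vec_width: int, target_elems: int = 10**8) -> int:
    """Round target_elems down to a multiple of 4 * nthreads * vec_width (hypothetical)."""
    chunk = 4 * nthreads * vec_width  # 4 unrolled vectors per thread per step
    size = target_elems // chunk * chunk
    if size == 0:
        raise ValueError("target_elems is smaller than one chunk")
    return size


# Assumed machine: 6 threads, 8 float32 lanes per vector.
size = bandwidth_buffer_elems(nthreads=6, vec_width=8)
print(size, size % (4 * 6 * 8), size * 4 / 1e6, "MB of float32")
```

   With these assumptions the buffer is roughly 400 MB, which is why the fixed constant is usually, but not always, larger than the last-level cache.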





[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859093053


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/seconds.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)

Review Comment:
   Do you mean execution order between threads? This assumes that threads are executing at the same time.
   
   Ideally we would have a way to make sure that each load comes directly from main memory, but there is no cross-platform/cross-architecture way of doing that (that I know of). If we make the data size large enough, then we can assume that the vast majority of loads don't hit cache. 2x the LLC size would probably be a safe default, but it does depend a bit on the caching strategy.
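   The "2x LLC size" default could be sketched as follows. This is a Linux-only illustration built on assumptions. It reads cache levels and sizes from sysfs (`/sys/devices/system/cpu/cpu0/cache/index*/level` and `.../size`, where sizes look like `32768K`). It falls back to the existing 10**8-element constant when nothing can be read. On multi-socket machines cpu0 only reports its own socket's LLC, so the result is a starting point rather than a guarantee. All helper names are hypothetical and not part of TVM.

```python
import glob
import os
from typing import Optional


def parse_cache_size(text: str) -> int:
    """Parse a sysfs cache size string such as '32768K' or '8M' into bytes."""
    text = text.strip().upper()
    units = {"K": 1024, "M": 1024**2, "G": 1024**3}
    if text and text[-1] in units:
        return int(text[:-1]) * units[text[-1]]
    return int(text)


def last_level_cache_bytes(cpu_dir: str = "/sys/devices/system/cpu/cpu0/cache") -> Optional[int]:
    """Return the size of the highest-level cache listed in sysfs, or None."""
    best_level, best_size = -1, None
    for index_dir in glob.glob(os.path.join(cpu_dir, "index*")):
        try:
            with open(os.path.join(index_dir, "level")) as f:
                level = int(f.read())
            with open(os.path.join(index_dir, "size")) as f:
                size = parse_cache_size(f.read())
        except (OSError, ValueError):
            continue  # skip unreadable or malformed entries
        if level > best_level:
            best_level, best_size = level, size
    return best_size


def buffer_elems(nthreads: int, vec_width: int, llc_bytes: Optional[int]) -> int:
    """float32 element count of about 2x LLC, rounded down to the loop chunk."""
    chunk = 4 * nthreads * vec_width
    # Fall back to the original 10**8-element constant if the LLC is unknown.
    target = (2 * llc_bytes) // 4 if llc_bytes else 10**8
    return max(chunk, target // chunk * chunk)


print(buffer_elems(4, 8, last_level_cache_bytes()))
```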





[GitHub] [tvm] AndrewZhaoLuo commented on pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on PR #11066:
URL: https://github.com/apache/tvm/pull/11066#issuecomment-1116700164

   I'm going to merge; it looks like there is little controversy in the discussion thread.




[GitHub] [tvm] AndrewZhaoLuo commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
AndrewZhaoLuo commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859110139


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/seconds.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited on the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)

Review Comment:
   Hmm, I see. Would it be worth exposing `size` to the user?
   
   That would also give an easy way to graph size vs. estimated bandwidth and see where the data becomes large enough to avoid cache hits.
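A minimal sketch of the size sweep suggested above, assuming `size` were exposed. `measure` stands in for a hypothetical function that runs the bandwidth kernel on a buffer of the given number of float32 elements and returns bytes/s; here it is replaced by a synthetic cache model so the example runs on its own. `plateau_size` reports the smallest size from which the measured bandwidth stays within a tolerance of the value at the largest size, i.e. roughly where cache effects stop mattering.

```python
from typing import Callable, List, Optional, Sequence, Tuple


def sweep_bandwidth(measure: Callable[[int], float],
                    sizes: Sequence[int]) -> List[Tuple[int, float]]:
    """Run a (hypothetical) `measure(size) -> bytes/s` for each size, ascending."""
    return [(size, measure(size)) for size in sorted(sizes)]


def plateau_size(results: Sequence[Tuple[int, float]],
                 rel_tol: float = 0.05) -> Optional[int]:
    """Smallest size from which bandwidth stays within `rel_tol` of the
    value measured at the largest size, i.e. where cache effects fade."""
    if not results:
        return None
    final_bw = results[-1][1]
    candidate: Optional[int] = None
    for size, bw in results:
        if abs(bw - final_bw) <= rel_tol * final_bw:
            if candidate is None:
                candidate = size  # start of a run of "settled" sizes
        else:
            candidate = None  # still cache-affected; reset the run
    return candidate


# Synthetic model: 100 GB/s while the float32 buffer fits in a 32 MiB cache,
# 20 GB/s once it does not. A real `measure` would time the TIR kernel.
def fake_measure(n_elems: int) -> float:
    return 100e9 if n_elems * 4 < 32 * 2**20 else 20e9


results = sweep_bandwidth(fake_measure, [2**k for k in range(20, 27)])
print(plateau_size(results))
```

The same `results` list could be handed to any plotting library to produce the size vs. bandwidth graph.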





[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r860018635


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)

Review Comment:
   I am using the minimum across all runs here, so adding `warmup_ms` actually hurts the measured performance because it forces the time evaluator to average across a couple of iterations.
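The point about taking the minimum can be illustrated without TVM. This is a minimal stdlib sketch using `time.perf_counter`, not TVM's `time_evaluator`: each call is timed separately, and the minimum is compared with the mean. Because the minimum can never exceed the mean, it yields the higher FLOP/s estimate, which is what you want when estimating a peak. `fma_flop_per_s` reproduces the FLOP count from the quoted code (two FLOPs per FMA lane); the helper names are illustrative only.

```python
import time
from statistics import mean
from typing import Callable, Tuple


def time_runs(fn: Callable[[], object], repeat: int = 100) -> Tuple[float, float]:
    """Time `fn` once per repeat; return (min, mean) in seconds."""
    samples = []
    for _ in range(repeat):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return min(samples), mean(samples)


def fma_flop_per_s(vec_width: int, num_vector_registers: int, nthreads: int,
                   iters: int, seconds: float) -> float:
    # Same count as the quoted code: each FMA lane is two FLOPs.
    flops = 2 * vec_width * num_vector_registers * nthreads * iters
    return flops / seconds


best, avg = time_runs(lambda: sum(range(10_000)), repeat=50)
# The minimum is never above the mean, so it gives the higher (closer to peak)
# FLOP/s estimate; the mean folds slow, noisy runs back in.
print(fma_flop_per_s(8, 16, 4, 100_000, best) >= fma_flop_per_s(8, 16, 4, 100_000, avg))
```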
   





[GitHub] [tvm] tqchen commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r853638940


##########
python/tvm/analysis/__init__.py:
##########
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""High level analysis functions"""

Review Comment:
   The main reason was mentioned above: `analysis` is already used for IR-based analysis (in the case of both TIR and Relay), so there might be confusion here.
   
   I understand that it should not be part of runtime/profile because we need compiler support. In this case `profile_analysis` might be a good distinction (I am not sticking to this particular name though, and there might be a better one); happy to discuss other name choices as well.
   
   One thing to note, though, is that we might increasingly need compiler counterparts for profiling. So another possibility could be `profile` (compiler support, if needed for JIT-ing) and `runtime/profile_rt` (the runtime counterpart).
   
   





[GitHub] [tvm] tqchen commented on pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tqchen commented on PR #11066:
URL: https://github.com/apache/tvm/pull/11066#issuecomment-1104464737

   `python/tvm/utils` might indeed be a better choice. In that spirit, it might be useful to introduce `python.tvm.utils.roofline`, assuming the collection of utils grows in the future.
   
   It would also be good to open a discuss thread for posterity, as this does introduce a new top-level namespace.
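For context on what a `roofline` module would compute: the PR description says each operator's runtime is combined with FLOP and byte counts extracted from TIR and with measured peak FLOP/s and bandwidth. Below is a minimal sketch of the classic roofline bound under those inputs. It is the textbook formulation, not necessarily the exact formula used in the PR, and the function names and peak numbers are illustrative.

```python
def roofline_bound(flops: float, bytes_moved: float,
                   peak_flop_s: float, peak_bytes_s: float) -> float:
    """Attainable FLOP/s under the classic roofline model."""
    intensity = flops / bytes_moved  # arithmetic intensity, FLOPs per byte
    return min(peak_flop_s, intensity * peak_bytes_s)


def roofline_fraction(flops: float, bytes_moved: float, runtime_s: float,
                      peak_flop_s: float, peak_bytes_s: float) -> float:
    """Achieved FLOP/s as a fraction of the roofline bound (1.0 = on the roof)."""
    achieved = flops / runtime_s
    return achieved / roofline_bound(flops, bytes_moved, peak_flop_s, peak_bytes_s)


# Illustrative numbers only: 100 GFLOP/s compute peak, 20 GB/s bandwidth peak.
PEAK_FLOPS, PEAK_BW = 1e11, 2e10
# Low intensity (2 FLOP/byte): memory bound, roof is 2 * 20e9 = 40 GFLOP/s.
print(roofline_bound(2e9, 1e9, PEAK_FLOPS, PEAK_BW))
print(roofline_fraction(2e9, 1e9, 0.1, PEAK_FLOPS, PEAK_BW))
```

An operator with high intensity (e.g. 100 FLOP/byte here) hits the compute roof instead, so its fraction is measured against peak FLOP/s rather than bandwidth.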




[GitHub] [tvm] tqchen commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r854070352


##########
python/tvm/analysis/__init__.py:
##########
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""High level analysis functions"""

Review Comment:
   Coming back after more deliberation: a new namespace is a major architectural choice, as well as an API choice in itself, that would benefit from broader [deliberation and discussions](https://tvm.apache.org/docs/contribute/code_review.html#deliberate-on-api-and-data-structures), especially when we are not too sure about the choice of name.
   
   In the meantime, this particular module starts with the LLVM CPU target and is more like standalone tooling that will be continuously improved and used. Our previous convention has been to start such self-contained tooling in contrib (e.g. `contrib/popen_pool`). In those cases contrib conveys little information beyond "collection of contributed tools", but it requires less architectural deliberation and can unblock the PR.
   
   Meanwhile, I would recommend opening a discuss thread to gather thoughts from other community members (collective wisdom usually helps when it comes to naming), as well as ideas about possible ways to group things as the tooling matures.
   
   





[GitHub] [tvm] jwfromm commented on pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
jwfromm commented on PR #11066:
URL: https://github.com/apache/tvm/pull/11066#issuecomment-1104204050

   I agree that the analysis namespace is confusing due to overloading. I would personally prefer either `python/tvm/utils` or `python/tvm/tools` for this.




[GitHub] [tvm] tqchen commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r853519434


##########
python/tvm/analysis/__init__.py:
##########
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""High level analysis functions"""

Review Comment:
   It would be good to think about the namespacing choices here.
   
   This is a mixture of profiling that is also related to TIR features. Previously, analyses under `tir/analysis` or `relay/analysis` were IR analyses.
   
   So `analysis` might create some confusion here due to ambiguity (whether it is analysis of TIR or, as in this case, a mix of runtime profiling and IR analysis).
   
   How about:
   -  `contrib/roofline` if we are aiming at focused tooling around roofline analysis?
   -  Alternatively, `tir/profile_analysis`, given it is profile-driven analysis of TIR
   
   





[GitHub] [tvm] tqchen commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tqchen commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r853519434


##########
python/tvm/analysis/__init__.py:
##########
@@ -0,0 +1,296 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""High level analysis functions"""

Review Comment:
   It would be good to think about the namespacing choices.
   
   This is a mixture of profiling that is also related to TIR features. Previously, analyses under `tir/analysis` or `relay/analysis` were IR analyses, so `analysis` might create some confusion here due to ambiguity.
   
   How about `contrib/roofline` if we are aiming at roofline analysis?
   
   Could some of the features fit into `tir/profile`?
   





[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r860034447


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assuming for all platforms, probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vectors must be >= num_vector_registers*vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/second.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited by the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)

Review Comment:
   > I'm just ramping into the discussion here but given cache size variability I would argue for the iterative approach for increasing the array size until performance plateaus. It's easy to measure and less likely to lead to false assumptions being made.
   
   I would like to have a more robust approach, but the iterative approach also needs some sort of configuration. How do we decide that we have plateaued at memory bandwidth? What if we plateaued at last-level cache (LLC) bandwidth? We'd have to run past a certain size to make sure it wasn't the LLC, so we'd end up running something around the size I've set here anyway. The best solution would be some way of getting the LLC size and then using a multiple of it, but we don't have a way to do that right now. Do you have a proposal for how I might determine whether the bandwidth has plateaued?
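
One possible shape for the plateau check, sketched under assumptions: sweep the buffer size geometrically and stop once the measured bandwidth stays within a relative tolerance for a few consecutive sizes. `find_bandwidth_plateau`, `measure`, `tol`, `patience` and `fake_measure` are hypothetical names, not TVM API. `measure` stands in for a `time_evaluator`-based measurement like the one in the diff. The synthetic curve is invented. It illustrates the concern above: the check alone cannot tell an LLC plateau from a DRAM plateau.

```python
from typing import Callable, List, Optional, Tuple


def find_bandwidth_plateau(
    measure: Callable[[int], float],
    start_elems: int,
    max_elems: int,
    growth: int = 2,
    tol: float = 0.05,
    patience: int = 2,
) -> Tuple[float, List[Tuple[int, float]]]:
    """Hypothetical helper: grow the buffer until bandwidth stops changing.

    `measure(n)` returns the bandwidth (bytes/second) observed when streaming
    n float32 elements. The sweep stops once `patience` consecutive steps
    differ from the previous measurement by at most `tol` (relative).
    Returns the last bandwidth measured and the (size, bandwidth) history.
    """
    if start_elems > max_elems:
        raise ValueError("start_elems must not exceed max_elems")
    history: List[Tuple[int, float]] = []
    prev: Optional[float] = None
    stable_steps = 0
    n = start_elems
    while n <= max_elems:
        bw = measure(n)
        history.append((n, bw))
        if prev is not None and abs(bw - prev) <= tol * prev:
            stable_steps += 1
            if stable_steps >= patience:
                return bw, history
        else:
            stable_steps = 0
        prev = bw
        n *= growth
    # No plateau within max_elems: fall back to the largest size measured.
    return history[-1][1], history


def fake_measure(n_elems: int) -> float:
    # Synthetic curve: 200 GB/s while the data fits a 32 MiB cache, 20 GB/s after.
    return 2e11 if n_elems * 4 <= 32 * 2**20 else 2e10


print(find_bandwidth_plateau(fake_measure, 2**20, 2**28)[0])  # cache plateau
print(find_bandwidth_plateau(fake_measure, 2**24, 2**28)[0])  # DRAM plateau
```

The first call settles on cache bandwidth because the sweep starts inside the cache. The starting size, or a known LLC size, still has to be configured somewhere.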
   
   > It would be neat to do all the STREAM benches for the memory bandwidth measurement and combine them.
   
   I'm not sure the STREAM benchmarks are a good fit here because they involve writing back to memory, which may not reflect what happens in a typical ML kernel. And how would we combine them? Just average? I've looked at what other tools do to estimate bandwidth/FLOPs, and most use handwritten assembly or optimized kernels (like the ones here) that only do loads or only do FLOPs.
   
   > Would also be interesting to apply the specific stream benchmark that applies to a given workload pattern based on the read and writes defined in a block. Probably diminishing returns given that they are usually comparable measurements of the bandwidth.
   
   I've thought about this idea, especially with regard to cache sizes and type of FLOPs, but I haven't come up with a good way to determine cache sizes (for access patterns) or the type of compute (FMA or not). It is definitely something I want to do, but I was trying to get a small working PR in first; then we can make the analysis more specific.
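
For reference, the FLOP count that `estimate_peak_fma_flops` divides by the best runtime can be reproduced by hand. The sketch below mirrors the formula in the diff, with each FMA counted as two FLOPs. `peak_fma_flops` is a hypothetical helper. The iteration count of 100000 matches the diff, but the lane and register counts, the thread count and the 1 ms runtime are invented for illustration.

```python
def peak_fma_flops(vec_width: int, num_vector_registers: int, nthreads: int,
                   iters: int, best_time_s: float) -> float:
    """FLOP/s implied by the FMA micro-benchmark (hypothetical helper)."""
    # Each thread issues iters * num_vector_registers vector FMAs of
    # vec_width lanes; each lane's FMA counts as two FLOPs (mul + add).
    flops = 2 * vec_width * num_vector_registers * nthreads * iters
    return flops / best_time_s


# Invented machine: 8 float32 lanes, 16 vector registers, 4 threads.
print(peak_fma_flops(8, 16, 4, 100_000, 0.001))  # roughly 1.024e11 FLOP/s
```

Because the peak assumes FMA, the figure may not apply to kernels that issue separate multiplies and adds. On hardware where those issue at the same rate as FMA, such a kernel gets one FLOP per instruction instead of two, so it could reach only about half of this peak. This is one reason the type of compute matters for the comparison.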
   
   > Personally I think it would be nice to expose an interface to change this (you can default it to 10 ** 8), in the worst case, for what csullivan says above
   
   I guess we can expose it, but how would the user know what to set it to?
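
If the size were exposed, one option is to let the caller pass an element count and keep the rounding the diff already does. The default would stay at the current 10**8 float32 elements, about 400 MB of data. The rounding makes the array split evenly into `nthreads` chunks, each walked in steps of 4 unrolled vectors of `vec_width` lanes. `bandwidth_buffer_elems` and `target_elems` are hypothetical names, not part of TVM.

```python
def bandwidth_buffer_elems(nthreads: int, vec_width: int, target_elems: int = 10**8) -> int:
    """Hypothetical helper: round target_elems down to a multiple of 4 * nthreads * vec_width.

    Mirrors `size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)`,
    so peak_bandwidth_tir's loop over N // nt // 4 // vec_width covers every element.
    """
    chunk = 4 * nthreads * vec_width  # 4 = unroll factor in peak_bandwidth_tir
    if target_elems < chunk:
        raise ValueError("target_elems must be at least 4 * nthreads * vec_width")
    return target_elems // chunk * chunk


n = bandwidth_buffer_elems(nthreads=6, vec_width=8)
print(n, n * 4 / 2**20)  # element count and MiB of float32 data
```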





[GitHub] [tvm] csullivan commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r860240098


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86

Review Comment:
   I agree there are lots of places in TOPI where this is assumed, and I don't want you to be burdened with that in this PR. However, when `target.device_name` is not set and the machine is not x86, the profiler will silently do the wrong thing. This applies to hexagon, cuda, opencl, rocm, and others. 
   
   At a minimum, I guess we need to check that the target is llvm and that the device name is empty. 
   
   Can we also make it stricter, so that `get_simd_32bit_lanes` does not return a default if the mcpu name is not found in its knowledge base, or is that default behavior being relied on? If it is, maybe we could simply make `get_simd_32bit_lanes` take a parameter, defaulting to True, that indicates whether it should guess the number of lanes; when it is False, an error is thrown. Then here we can set that parameter to False so that profiling doesn't make target-specific assumptions. 
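
A minimal sketch of the proposed opt-out, written as a standalone function rather than a patch to TOPI. The knowledge base, the `guess` flag, the fallback of 4 lanes and the names `simd_32bit_lanes` and `check_profilable` are assumptions for illustration. The call in the diff takes no arguments and reads the current target from the `with target:` scope. This sketch takes `mcpu` explicitly so it can run on its own.

```python
from typing import Dict, Optional

# Hypothetical knowledge base: float32 lanes per SIMD register by LLVM -mcpu.
_F32_LANES: Dict[str, int] = {
    "skylake-avx512": 16,  # AVX-512: 512-bit registers
    "cascadelake": 16,
    "core-avx2": 8,        # AVX2: 256-bit registers
    "haswell": 8,
}
_DEFAULT_LANES = 4  # assumed 128-bit fallback


def simd_32bit_lanes(mcpu: Optional[str], guess: bool = True) -> int:
    """Return float32 SIMD lanes for mcpu (hypothetical helper).

    guess=True keeps a silent default for unknown CPUs; guess=False raises
    instead, so the profiler never quietly assumes x86 characteristics.
    """
    if mcpu in _F32_LANES:
        return _F32_LANES[mcpu]
    if guess:
        return _DEFAULT_LANES
    raise ValueError(f"Unknown SIMD width for mcpu={mcpu!r}; pass vec_width explicitly")


def check_profilable(kind: str, device_name: str) -> None:
    """The minimum check suggested above: an llvm target with an empty device name."""
    if kind != "llvm" or device_name != "":
        raise RuntimeError(f"Cannot infer vector width for kind={kind!r}, device={device_name!r}")
```

With `guess=False`, an unknown CPU such as `hexagonv68` produces an error instead of 4 lanes. The llvm/empty-name check alone would still accept a Hexagon llvm target.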





[GitHub] [tvm] csullivan commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
csullivan commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r859213141


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86

Review Comment:
   Yeah, an empty device name is more of a special case for llvm in general.
   
   
   > target = tvm.target.Target(
   >     "llvm -keys=hexagon,cpu -link-params=0 "
   >     "-mattr=+hvxv68,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp "
   >     "-mcpu=hexagonv68 -mtriple=hexagon"
   > )
   
   
   will also have an empty device name, similar to Andrew's comment. I think we need a different way to dispatch here. 
   
   In general, I'd like to see this information exposed as attrs on the target. When those attrs are not provided, they can be filled in from the device via an attr [preprocessor](https://github.com/apache/tvm/blob/main/src/target/target_kind.cc#L295) that updates the attributes for a target kind and [queries](https://github.com/apache/tvm/blob/main/src/target/target_kind.cc#L114) the local or remote device API for the necessary information.
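
Until such attrs exist, one stopgap would be to dispatch on the architecture in the target triple. This is an assumption and was not proposed in this thread. The helpers fall back to the host machine only when no `-mtriple` is given. `target_arch` and `is_x86` are hypothetical helpers that take plain strings so they run without TVM. With a real target, the triple would come from its `mtriple` option.

```python
import platform
from typing import Optional

_X86_ARCHES = {"x86_64", "amd64", "i386", "i486", "i586", "i686"}


def target_arch(mtriple: str, host_machine: Optional[str] = None) -> str:
    """Architecture from an LLVM triple such as 'x86_64-pc-linux-gnu' (hypothetical).

    An empty triple means "compile for the host", so fall back to the host
    machine name (platform.machine() unless one is supplied, e.g. in tests).
    """
    if mtriple:
        return mtriple.split("-", 1)[0].lower()
    machine = host_machine if host_machine is not None else platform.machine()
    return machine.lower()


def is_x86(mtriple: str, host_machine: Optional[str] = None) -> bool:
    return target_arch(mtriple, host_machine) in _X86_ARCHES


# The Hexagon target above has an empty device name but is not x86.
print(is_x86("hexagon", host_machine="x86_64"))  # False
```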
   
   





[GitHub] [tvm] tkonolige commented on a diff in pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on code in PR #11066:
URL: https://github.com/apache/tvm/pull/11066#discussion_r860040269


##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assumed for all platforms; probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vector length must be >= threads * num_vector_registers * vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/second.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load commands, then we
+    # could use those in a tight loop. Instead we need some code that is
+    # limited by the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)
+    a = nd.array(np.ones(size, dtype="float32"), device=dev)
+    b = nd.array(np.ones(vec_width * 4 * num_threads(), dtype="float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=5, number=1)(a, b, num_threads())
+    return size * 4 / times.min  # 4 bytes per float32
+
+
+def roofline_analysis(
+    mod: IRModule, params: Dict[str, nd.NDArray], target: Union[str, Target], dev: Device
+) -> profiling.Report:
+    """
+    Create a profiling report that contains roofline and other estimated
+    statistics from running a module on the VM.
+
+    These statistics are calculated by analyzing the lowered TIR of each
+    operator, so they are estimates of the true values. The statistics are:
+      - Bound: whether the operator is memory or compute bound. This is computed by
+        assuming that the operator could perfectly cache all loads -- each byte
+        of memory is only loaded once.
+      - Percent of Theoretical Optimal: What percent of theoretical optimal for
+        the bound, i.e. percent of peak memory bandwidth if memory bound,
+        percent of peak FLOP/s if compute bound.
+      - Unique Loaded Bytes: estimated number of bytes loaded, not
+        counting multiple accesses to the same byte.
+      - Estimated Flops: estimated number of floating point operations.
+      - Arithmetic Intensity: ratio of FLOPs per byte of data.
+      - FLOP/s: floating point operations per second.
+      - Bandwidth: Number of bytes loaded per second.
+
+    Parameters
+    ----------
+    mod : IRModule
+      Uncompiled input module.
+
+    params : Dict[str, nd.NDArray]
+      Parameters for the module.
+
+    target : Union[str, Target]
+      Target to run on.
+
+    dev : Device
+      Device to run on.
+
+    Returns
+    -------
+
+    report : profiling.Report
+      Profiling report which includes the estimated statistics.
+    """
+    if isinstance(target, str):
+        target = Target(target)
+    peak_bandwidth = estimate_peak_bandwidth(target, dev)
+    peak_flops = estimate_peak_fma_flops(target, dev)
+
+    ridge_point = peak_flops / peak_bandwidth
+
+    all_features = _estimated_features(mod, params, target)
+
+    lib = relay.vm.compile(mod, params=params, target=target)
+    vmexec = profiler_vm.VirtualMachineProfiler(lib, dev)
+
+    args = _create_args(mod, dev)
+    report = vmexec.profile(*args)
+    new_calls = []
+    for call in report.calls:
+        if "Hash" in call.keys():
+            _, features = all_features[call["Hash"]]
+
+            flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])
+            unique_loaded_bytes = 0.0
+            # assume no more than 100 buffers
+            for i in range(100):
+                # We could use loaded bytes, but that accounts for L1 cache.
+                # If we use unique_bytes, then we are looking at how close we come
+                # to the performance assuming all data is cached perfectly.

Review Comment:
   I was thinking of this as a measure of "given the amount of data we must load, how close are we to loading it efficiently?" For example, we could be doing many unnecessary passes over the data just to get good bandwidth, while a single pass over the data, even a suboptimal one, may result in a faster runtime. It does bias the results a little towards higher arithmetic intensity, but it is a more accurate measure of how close the program is to optimal bandwidth usage. However, it doesn't match how FLOPs are counted, so maybe it is worth switching. Thoughts?
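
To make the trade-off concrete, the sketch below classifies an operator in the spirit of the docstring. An operator counts as compute bound when its arithmetic intensity exceeds the ridge point `peak_flops / peak_bandwidth`, and memory bound otherwise. The sketch also reports the percent of the relevant peak. `roofline` is a hypothetical helper, the exact comparison in the PR may differ, and all operator and peak numbers are invented. Counting repeated, mostly cache-served loads instead of unique bytes flips the classification and pushes the apparent bandwidth past the measured DRAM peak.

```python
from typing import Tuple


def roofline(flops: float, bytes_moved: float, runtime_s: float,
             peak_flops: float, peak_bandwidth: float) -> Tuple[str, float, float]:
    """Return (bound, arithmetic intensity, percent of the relevant peak)."""
    ridge_point = peak_flops / peak_bandwidth  # FLOP per byte
    intensity = flops / bytes_moved
    if intensity > ridge_point:
        return "compute", intensity, 100.0 * (flops / runtime_s) / peak_flops
    return "memory", intensity, 100.0 * (bytes_moved / runtime_s) / peak_bandwidth


# Invented operator: 2 GFLOP over 12 MB of unique data, but 4 GB of loads
# once repeated accesses are counted. Ridge point = 10 FLOP/byte.
PEAK_FLOPS, PEAK_BW = 1e12, 1e11
print(roofline(2e9, 12e6, 0.01, PEAK_FLOPS, PEAK_BW))  # compute bound, ~20% of peak
print(roofline(2e9, 4e9, 0.01, PEAK_FLOPS, PEAK_BW))   # memory bound, ~400% of "peak"
```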



##########
python/tvm/utils/__init__.py:
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities operating at a graph/model or other "high" level"""
+import csv
+import subprocess
+from typing import Dict, Union, Optional
+import numpy as np
+
+from .. import auto_scheduler, relay, tir, device, nd, IRModule, build, topi, transform
+from ..target import Target
+from ..runtime import profiler_vm, profiling, Device, num_threads
+from ..script import tir as T
+
+
+def _create_args(mod, dev, func_name="main"):
+    args = []
+    for arg in mod[func_name].params:
+        args.append(
+            nd.array(
+                np.zeros([x.value for x in arg.type_annotation.shape], arg.type_annotation.dtype),
+                device=dev,
+            )
+        )
+    return args
+
+
+def _estimated_features(mod, params, target):
+    comp = relay.vm.VMCompiler()
+    mod, params = comp.optimize(mod, params=params, target=target)
+    return {
+        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        for name, prim in mod.functions.items()
+        if isinstance(prim, tir.PrimFunc)
+    }
+
+
+def _vec_width_registers(target, vec_width, num_vector_registers):
+    if vec_width is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                vec_width = topi.x86.utils.get_simd_32bit_lanes()  # in number of float32s
+        else:
+            raise RuntimeError(f"Cannot determine vector width for target {target}")
+    if num_vector_registers is None:
+        if target.device_name == "":  # indicates x86
+            with target:
+                num_vector_registers = (
+                    16  # Assumed for all platforms; probably wrong on older ones
+                )
+        else:
+            raise RuntimeError(f"Cannot determine number of vector registers for target {target}")
+    return vec_width, num_vector_registers
+
+
+@T.prim_func
+def peakflops_fma_tir(
+    a: T.handle,
+    vec_width: T.int32,
+    iters: T.int32,
+    num_vector_registers: T.int32,
+    threads: T.int32,
+) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    assert (
+        N >= threads * num_vector_registers * vec_width
+    ), "Input vector length must be >= threads * num_vector_registers * vec_width"
+    for t in T.parallel(threads):
+        for _j in range(iters):
+            for l in T.unroll(num_vector_registers):
+                # We want to use as few registers as possible, so we perform
+                # all operations on the same element
+                for k in T.vectorized(vec_width):
+                    A[t * vec_width * num_vector_registers + vec_width * l + k] = (
+                        A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        * A[t * vec_width * num_vector_registers + vec_width * l + k]
+                        + A[t * vec_width * num_vector_registers + vec_width * l + k]
+                    )
+
+
+def estimate_peak_fma_flops(
+    target: Target,
+    dev: Device,
+    vec_width: Optional[int] = None,
+    num_vector_registers: Optional[int] = None,
+) -> float:
+    """
+    Estimate the maximum number of FLOP/s this target/device combo is capable
+    of reaching by running a test program. This assumes vectorized f32 FMA
+    (fused-multiply-add) instructions.
+
+
+    Parameters
+    ----------
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible to make sure that LLVM generates the best vector code.
+    dev : Device
+        Device to run on.
+    vec_width : Optional[int]
+        Vector width of SIMD units on the underlying hardware. Will try to
+        infer if no value is provided.
+    num_vector_registers : Optional[int]
+        Number of vector registers on the underlying hardware. Will try to
+        infer if no value is provided.
+
+    Returns
+    -------
+    float
+        Approximate sustained FLOP/s of this target/device combo assuming
+        vectorized f32 FMA instructions.
+    """
+    assert str(target.kind) == "llvm", "Only llvm targets are supported"
+    vec_width, num_vector_registers = _vec_width_registers(target, vec_width, num_vector_registers)
+    iters = 100000
+    nthreads = num_threads()
+    specialized = peakflops_fma_tir.specialize(
+        {
+            peakflops_fma_tir.params[1]: vec_width,
+            peakflops_fma_tir.params[2]: iters,
+            peakflops_fma_tir.params[3]: num_vector_registers,
+            peakflops_fma_tir.params[4]: nthreads,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    a = nd.array(np.ones(vec_width * num_vector_registers * nthreads).astype("float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=100)(a)
+    flops = 2 * vec_width * num_vector_registers * nthreads * iters  # fma is two flops
+    flop_s = flops / times.min
+    return flop_s
+
+
+@T.prim_func
+def peak_bandwidth_tir(a: T.handle, b: T.handle, nt: T.int32, vec_width: T.int32) -> None:
+    # pylint: disable=invalid-name, missing-function-docstring
+    N = T.var("int32")
+    A = T.match_buffer(a, [N], "float32")
+    B = T.match_buffer(b, [nt * vec_width * 4], "float32")
+    # assert N % (nt * 4 * vec_width) == 0, "div"
+    # Parallelism is necessary to hit all cores/nodes
+    for i in T.parallel(nt):
+        for k in T.serial(N // nt // 4 // vec_width):
+            for l in T.unroll(4):
+                # vectorized load is necessary to hit peak bandwidth
+                for j in T.vectorized(vec_width):
+                    B[i * vec_width * 4 + l * vec_width + j] += A[
+                        i * (N // nt) + k * vec_width * 4 + l * vec_width + j
+                    ]
+
+
+def estimate_peak_bandwidth(target: Target, dev: Device, vec_width: Optional[int] = None) -> float:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    vec_width : Optional[int]
+        Vector unit width, determined from target if not supplied.
+
+    Returns
+    -------
+    float
+        Peak memory bandwidth in bytes/second.
+    """
+    # Ideally we'd be able to use this code to measure peak bandwidth of the
+    # different cache levels. If we could just generate load instructions, then
+    # we could use those in a tight loop. Instead we need some code that is
+    # limited by the cache bandwidth. With the L1 cache we need an operation
+    # that has a very low arithmetic intensity and we haven't come up with one
+    # yet.
+    vec_width, _ = _vec_width_registers(target, vec_width, 1)
+    specialized = peak_bandwidth_tir.specialize(
+        {
+            peak_bandwidth_tir.params[3]: vec_width,
+        }
+    )
+    with transform.PassContext(opt_level=3):
+        f = build(specialized, target=target)
+    # Data size needs to be larger than last level of cache. We don't have a
+    # way of getting cache sizes, so this number should give us a large enough
+    # size.
+    size = 10**8 // (4 * num_threads() * vec_width) * (4 * num_threads() * vec_width)
+    a = nd.array(np.ones(size, dtype="float32"), device=dev)
+    b = nd.array(np.ones(vec_width * 4 * num_threads(), dtype="float32"), device=dev)
+    times = f.time_evaluator(f.entry_name, dev, repeat=5, number=1)(a, b, num_threads())
+    return size * 4 / times.min  # 4 bytes per float32
+
+
+def roofline_analysis(
+    mod: IRModule, params: Dict[str, nd.NDArray], target: Union[str, Target], dev: Device
+) -> profiling.Report:
+    """
+    Create a profiling report that contains roofline and other estimated
+    statistics from running a module on the VM.
+
+    These statistics are calculated by analyzing the lowered TIR of each
+    operator, so they are estimates of the true values. The statistics are:
+      - Bound: whether the operator is memory or compute bound. This is
+        computed by assuming that the operator could perfectly cache all
+        loads, so each byte of memory is loaded only once.
+      - Percent of Theoretical Optimal: percent of the theoretical optimum
+        for the bound, i.e. percent of peak memory bandwidth if memory bound
+        or percent of peak FLOP/s if compute bound.
+      - Unique Loaded Bytes: estimated number of bytes loaded, not counting
+        repeated accesses to the same byte.
+      - Estimated Flops: estimated number of floating point operations.
+      - Arithmetic Intensity: ratio of FLOPs per byte of data.
+      - FLOP/s: floating point operations per second.
+      - Bandwidth: Number of bytes loaded per second.
+
+    Parameters
+    ----------
+    mod : IRModule
+      Uncompiled input module.
+
+    params : Dict[str, nd.NDArray]
+      Parameters to bind to the module.
+
+    target : Union[str, Target]
+      Target to run on.
+
+    dev : Device
+      Device to run on.
+
+    Returns
+    -------
+
+    report : profiling.Report
+      Profiling report which includes the estimated statistics.
+    """
+    if isinstance(target, str):
+        target = Target(target)
+    peak_bandwidth = estimate_peak_bandwidth(target, dev)
+    peak_flops = estimate_peak_fma_flops(target, dev)
+
+    ridge_point = peak_flops / peak_bandwidth
+
+    all_features = _estimated_features(mod, params, target)
+
+    lib = relay.vm.compile(mod, params=params, target=target)
+    vmexec = profiler_vm.VirtualMachineProfiler(lib, dev)
+
+    args = _create_args(mod, dev)
+    report = vmexec.profile(*args)
+    new_calls = []
+    for call in report.calls:
+        if "Hash" in call.keys():
+            _, features = all_features[call["Hash"]]
+
+            flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])
+            unique_loaded_bytes = 0.0
+            # assume no more than 100 buffers
+            for i in range(100):
+                # We could use loaded bytes, but that accounts for the L1 cache.
+                # If we use unique_bytes, then we are looking at how close we come
+                # to the performance assuming all data is cached perfectly.
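
The timing-to-peak conversions in `estimate_peak_fma_flops` and `estimate_peak_bandwidth` above reduce to a few lines of arithmetic. The following is a minimal plain-Python sketch of that arithmetic, not TVM code; the helper names (`peak_fma_flop_s`, `bandwidth_buffer_size`, `peak_bandwidth_bytes_s`) are hypothetical, and the example inputs (8 float32 lanes, 16 vector registers, 4 threads, 0.5 s) are assumptions rather than measurements.

```python
def peak_fma_flop_s(vec_width: int, num_vector_registers: int, nthreads: int,
                    iters: int, seconds: float) -> float:
    """FLOP/s implied by the FMA microbenchmark timing (hypothetical helper).

    Each thread issues one FMA per lane of each of num_vector_registers
    vector registers per iteration; one FMA counts as two FLOPs.
    """
    flops = 2 * vec_width * num_vector_registers * nthreads * iters
    return flops / seconds


def bandwidth_buffer_size(nthreads: int, vec_width: int, target_elems: int = 10**8) -> int:
    """Round target_elems down to a multiple of 4 * nthreads * vec_width.

    With this size, the parallel, 4x-unrolled, vectorized loop in the
    bandwidth kernel covers every element of the input buffer exactly once.
    """
    chunk = 4 * nthreads * vec_width
    return target_elems // chunk * chunk


def peak_bandwidth_bytes_s(num_elems: int, seconds: float, bytes_per_elem: int = 4) -> float:
    """Bytes/second implied by streaming num_elems float32 values once."""
    return num_elems * bytes_per_elem / seconds


# Hypothetical inputs, not measurements.
flop_s = peak_fma_flop_s(8, 16, 4, iters=100_000, seconds=0.5)
n = bandwidth_buffer_size(nthreads=4, vec_width=8)
bytes_s = peak_bandwidth_bytes_s(n, seconds=0.5)
```

Rounding the buffer size down to a multiple of `4 * nthreads * vec_width` is what makes the commented-out divisibility assert in `peak_bandwidth_tir` hold, so each thread's loop covers its slice of `A` with no remainder.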

Review Comment:
   I was thinking of this as a measure of "given the amount of data we must load, how close are we to loading it efficiently." For example, we could be doing many unnecessary passes over the data just to get good bandwidth, while a single, less efficient pass over the data may result in a faster runtime. It does bias the results a little towards higher arithmetic intensity, but it is a more accurate measure of how close the program is to optimal bandwidth usage. However, it doesn't match how flops are counted, so maybe it is worth switching.
   
   Thoughts?
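
To make the trade-off in this comment concrete, here is a minimal sketch of the roofline classification described in the `roofline_analysis` docstring, evaluated once with unique bytes and once with total loaded bytes. The `roofline_stats` helper, the `RooflineStats` class, and all numbers are hypothetical illustrations, not TVM's implementation or measured data.

```python
from dataclasses import dataclass


@dataclass
class RooflineStats:
    arithmetic_intensity: float  # FLOPs per byte
    bound: str  # "compute" or "memory"
    percent_of_optimal: float  # percent of peak FLOP/s or of peak bandwidth


def roofline_stats(
    flops: float,
    bytes_moved: float,
    seconds: float,
    peak_flop_s: float,
    peak_bytes_s: float,
) -> RooflineStats:
    """Place one operator on a simple roofline (hypothetical helper).

    The ridge point, peak_flop_s / peak_bytes_s, is the arithmetic intensity
    at which the memory-bandwidth ceiling meets the compute ceiling.
    """
    intensity = flops / bytes_moved
    ridge_point = peak_flop_s / peak_bytes_s
    if intensity >= ridge_point:
        # Compute bound: compare achieved FLOP/s against peak FLOP/s.
        percent = 100.0 * (flops / seconds) / peak_flop_s
        return RooflineStats(intensity, "compute", percent)
    # Memory bound: compare achieved bandwidth against peak bandwidth.
    percent = 100.0 * (bytes_moved / seconds) / peak_bytes_s
    return RooflineStats(intensity, "memory", percent)


# Hypothetical operator: 1e9 FLOPs in 0.1 s over 1.6e8 unique bytes, with
# each byte loaded 10 times (1.6e9 loaded bytes). Peaks are made-up numbers.
PEAK_FLOP_S, PEAK_BYTES_S = 1e11, 2e10
unique = roofline_stats(1e9, 1.6e8, 0.1, PEAK_FLOP_S, PEAK_BYTES_S)
loaded = roofline_stats(1e9, 1.6e9, 0.1, PEAK_FLOP_S, PEAK_BYTES_S)
print(unique)
print(loaded)
```

With these inputs the ridge point is 5 FLOP/byte. Counting unique bytes gives 6.25 FLOP/byte, so the operator is classified as compute bound at about 10% of peak FLOP/s; counting every load gives 0.625 FLOP/byte and a memory-bound classification at about 80% of peak bandwidth. The unique-bytes view can therefore flip the bound toward compute, which is the bias toward higher arithmetic intensity mentioned above.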




[GitHub] [tvm] tkonolige commented on pull request #11066: [PROFILER] Theoretical roofline models

Posted by GitBox <gi...@apache.org>.
tkonolige commented on PR #11066:
URL: https://github.com/apache/tvm/pull/11066#issuecomment-1113539291

   @tqchen I've moved the implementation into `tvm/utils/roofline.py`, as I think you requested. In the future, could you please explicitly approve PRs or request changes? See https://tvm.apache.org/docs/contribute/code_review.html#approve-and-request-changes-explicitly.
   
   I've also posted to Discuss, as requested: https://discuss.tvm.apache.org/t/new-top-level-python-namespace-tvm-utils/12683.

