Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2023/01/10 21:43:06 UTC

[GitHub] [tvm] guberti opened a new pull request, #13752: [DRAFT] [microTVM] Use QNN schedules to improve performance

guberti opened a new pull request, #13752:
URL: https://github.com/apache/tvm/pull/13752

   **This pull request is not ready for review.**
   
   In #13242, I rewrote microTVM's convolution schedules to give a major improvement in performance. While I demonstrated in tests that my changes worked, they could not be used with `relay.build`.
   
   This pull request expands the functionality of #13242 and adds new `legalize` and `alter_op` passes to take advantage of the quantized schedules. This dramatically improves performance on some models, substantially cuts RAM usage, and removes the _need_ for autotuning on microTVM. More specifically, for the `vww` model from MLPerf Tiny, this pull request:
   
   - Improves **untuned** performance from `1741 ms` to `225 ms` - a roughly **7.7x** improvement!
   - Improves **tuned** performance from `337 ms` to `225 ms`.
     - This closes roughly **85%** of the performance gap between us and the current state-of-the-art (which is `205 ms`).
   - Reduces memory consumption by **73 KB** (a large amount on microcontrollers!) by eliminating intermediate buffers.
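The headline ratios follow directly from the latencies listed above. The sketch below simply recomputes them; it is plain arithmetic on the quoted figures, not a new measurement:

```python
# Latencies for the MLPerf Tiny `vww` model, as quoted in the list above (ms).
untuned_before_ms = 1741
tuned_before_ms = 337
after_ms = 225
sota_ms = 205

# Speedup = old latency / new latency.
untuned_speedup = untuned_before_ms / after_ms  # about 7.74x
tuned_speedup = tuned_before_ms / after_ms  # about 1.50x

# Fraction of the tuned-vs-SOTA latency gap that the new schedules close.
gap_closed = (tuned_before_ms - after_ms) / (tuned_before_ms - sota_ms)  # about 0.85

print(f"untuned: {untuned_speedup:.2f}x, tuned: {tuned_speedup:.2f}x, gap closed: {gap_closed:.0%}")
```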
   
   [TODO work with Mehrdad so he can sign off on these numbers]
   
   To enable the schedules that grant these performance improvements, this pull request:
   - Adds `out_layout` support to the regular and depthwise conv2d schedules from #13242.
   - Generalizes the schedules from #13242 to be more widely applicable.
   - Adds a layout alteration pass to ensure regular and depthwise conv2d schedules always get their desired input formats.
   - Adds a `conv2d -> depthwise conv2d -> unpadded conv2d` rewrite step to remove empty channels from `conv2d` operators.
   - Adds a `conv2d -> average pool -> dense` rewrite step to remove empty channels from `conv2d` operators.
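The channel-removal rewrites rest on one observation: a channel whose value is a known constant contributes a fixed amount to every output of the next linear layer, so it can be dropped and its contribution added to that layer's bias. Average pooling keeps a constant channel constant, which is why the `average pool -> dense` pattern also qualifies. Below is a simplified NumPy sketch for a dense (or 1x1 convolution) layer that ignores requantization; `fold_constant_channels` is a hypothetical helper, not code from this PR:

```python
import numpy as np


def fold_constant_channels(weight, bias, constant_channels, input_zero_point=0):
    """Drop input channels with known constant values, folding them into the bias.

    weight: (out_features, in_features) integer matrix of a dense or 1x1 conv layer.
    bias: (out_features,) integer vector.
    constant_channels: {input_channel_index: constant quantized value}.
    The layer is assumed to compute weight @ (x - input_zero_point) + bias.
    """
    keep = [i for i in range(weight.shape[1]) if i not in constant_channels]
    folded_bias = bias.astype(np.int64)  # astype returns a copy
    for i, value in constant_channels.items():
        # Each constant channel adds weight[:, i] * (value - zero_point) to every output.
        folded_bias = folded_bias + weight[:, i].astype(np.int64) * (value - input_zero_point)
    return weight[:, keep], folded_bias, keep
```

The denser layer then computes `weight[:, keep] @ (x[keep] - input_zero_point) + folded_bias`, which matches the original layer exactly whenever the dropped channels really hold their constant values.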


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] guberti commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1101903538


##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -368,3 +389,139 @@ def kernel_ptr(buffer, c, offset=0):
 def schedule_qnn_depthwise_conv2d(_attrs, _outs, _target):
     """Schedule function for qnn.depthwise_conv2d."""
     return None
+
+
+def _make_unrolled_conv2d_primfunc(

Review Comment:
   Done, for both `_make_unrolled_conv2d_primfunc` and `_make_conv2d_primfunc`.





[GitHub] [tvm] areusch merged pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "areusch (via GitHub)" <gi...@apache.org>.
areusch merged PR #13752:
URL: https://github.com/apache/tvm/pull/13752




[GitHub] [tvm] alanmacd commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "alanmacd (via GitHub)" <gi...@apache.org>.
alanmacd commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1087304976


##########
src/target/source/codegen_c.cc:
##########
@@ -631,8 +632,11 @@ void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs,
   }
 }
 
+unsigned int random_seed = 0;
 void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
-  std::string symbol_name = op->buffer_var->name_hint;
+  // Add a random suffix to eliminate duplicate global variables.
+  int suffix = rand_r(&random_seed) % (2 << 24);

Review Comment:
   Any reason not to just use a counter and add it as a suffix to each global variable name?
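Here is a minimal sketch of the counter-based alternative, written in Python for brevity (`unique_symbol` is a hypothetical name). `rand_r(...) % (2 << 24)` draws suffixes from 2^25 possible values, so two constants can still collide. A monotonically increasing counter never repeats:

```python
import itertools

# One process-wide counter shared by every constant that gets emitted.
_symbol_counter = itertools.count()


def unique_symbol(name_hint: str) -> str:
    """Append the next counter value, so no two emitted symbols share a name."""
    return f"{name_hint}_{next(_symbol_counter)}"


print(unique_symbol("fused_constant"), unique_symbol("fused_constant"))
```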





[GitHub] [tvm] guberti commented on pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on PR #13752:
URL: https://github.com/apache/tvm/pull/13752#issuecomment-1416544043

   Would love a look from @areusch and @mehrdadh!




[GitHub] [tvm] guberti commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1101021878


##########
src/relay/transforms/simplify_expr.cc:
##########
@@ -979,7 +992,16 @@ Pass SimplifyExpr() {
   return CreateFunctionPass(pass_func, 0, "SimplifyExpr", {"InferType"});
 }
 
+Pass SimplifyExprPostAlterOp() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(SimplifyExprPostAlterOp(f, m));
+      };
+  return CreateFunctionPass(pass_func, 0, "SimplifyExprPostAlterOp", {"InferType"});
+}
+
 TVM_REGISTER_GLOBAL("relay._transform.SimplifyExpr").set_body_typed(SimplifyExpr);
+// Don't globally register SimplifyExprPostAlterOp

Review Comment:
   You're right, we should register this as well. It doesn't seem useful elsewhere, but it is weird to not register just this one.





[GitHub] [tvm] guberti commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1113411550


##########
python/tvm/relay/qnn/op/_qnn.py:
##########
@@ -85,12 +91,72 @@ def simulated_dequantize_compute(attrs, inputs, output_type):
 register_strategy("qnn.conv2d", strategy.qnn_conv2d_strategy)
 
 
+def _get_clip_dtype_bounds(dtype):
+    """Returns the minimum and maximum values of a C integer data type."""
+    assert "int" in dtype
+    bits = int(dtype[dtype.find("int") + 3 :])
+
+    if dtype.startswith("int"):
+        return (-(2 ** (bits - 1)), 2 ** (bits - 1) - 1)

Review Comment:
   That's much cleaner - fixed!
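The exact change the reviewer proposed isn't quoted here. One compact way to get the same bounds without parsing the dtype string is NumPy's `np.iinfo`. The sketch below makes that assumption:

```python
import numpy as np


def get_clip_dtype_bounds(dtype: str):
    """Return (min, max) of an integer dtype such as "int8" or "uint16"."""
    info = np.iinfo(dtype)  # rejects non-integer dtypes
    return int(info.min), int(info.max)
```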





[GitHub] [tvm] guberti commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1101023183


##########
src/target/source/codegen_c.cc:
##########
@@ -631,8 +632,10 @@ void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs,
   }
 }
 
+NameSupply global_name_supply = NameSupply("");
 void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
-  std::string symbol_name = op->buffer_var->name_hint;
+  std::string symbol_name = global_name_supply->FreshName(op->buffer_var->name_hint);

Review Comment:
   Unfortunately, just calling `name_supply_->FreshName(op->buffer_var->name_hint);` does not work - we need to have a global name generator.
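This is a hedged illustration of why a per-codegen name supply might not be enough. It assumes, as the comment suggests, that each `CodeGenC` instance starts with its own supply, while the constants from all generated C files end up as globals in one linked binary. `FreshNameSupply` is a toy stand-in, not TVM's `NameSupply`:

```python
class FreshNameSupply:
    """Toy name supply: returns the hint the first time it is seen, then hint_1, hint_2, ...

    It does not guard against a hint that already ends in such a suffix.
    """

    def __init__(self):
        self._seen = {}

    def fresh_name(self, hint: str) -> str:
        count = self._seen.get(hint, 0)
        self._seen[hint] = count + 1
        return hint if count == 0 else f"{hint}_{count}"


def emit_constants(supply, hints):
    """Name every constant in one generated C module using `supply`."""
    return [supply.fresh_name(hint) for hint in hints]


# Two separately generated modules that each declare a constant with the same hint.
modules = [["fused_constant"], ["fused_constant"]]

# One supply per module: both modules emit the same global symbol, which would clash at link time.
per_module = [emit_constants(FreshNameSupply(), hints) for hints in modules]

# One supply shared by all modules: the second constant gets a distinct name.
shared_supply = FreshNameSupply()
global_names = [emit_constants(shared_supply, hints) for hints in modules]
```

The shared supply avoids the clash at the cost of process-wide mutable state. As a result, the emitted names can depend on what else was compiled earlier in the same process.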





[GitHub] [tvm] mehrdadh commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "mehrdadh (via GitHub)" <gi...@apache.org>.
mehrdadh commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1096409344


##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -368,3 +389,139 @@ def kernel_ptr(buffer, c, offset=0):
 def schedule_qnn_depthwise_conv2d(_attrs, _outs, _target):
     """Schedule function for qnn.depthwise_conv2d."""
     return None
+
+
+def _make_unrolled_conv2d_primfunc(
+    output_dimensions: Tuple[int, int, int],
+    buffer_shapes: Tuple[Tuple, Tuple, Tuple, Tuple, Tuple],
+    function_names: Dict[Tuple, str],
+    function_code: str,
+    ptr_gens: Tuple[Callable, Callable],
+    output_layout="NHWC",
+):
+    out_height, out_width, out_channels = output_dimensions
+    data_shape, kernel_shape, bias_shape, scale_shape, output_shape = buffer_shapes
+    data_ptr, kernel_ptr = ptr_gens
+
+    def output_ptr(output, y, c):
+        if output_layout == "NHWC":
+            return _make_tscript_ptr(output, y * const(out_width * out_channels) + c, 1)
+        elif output_layout == "NCHW":
+            return _make_tscript_ptr(
+                output, c * const(out_height * out_width) + y * const(out_width), 1
+            )
+        else:
+            raise TVMError(f"Unsupported out_layout '{output_layout}'!")
+
+    def make_row_call(buffers, c_var, y, c):
+        output, data, kernel, bias, scale = buffers
+        return _make_tscript_call(
+            function_names[(y + c) % 2, c % 2, 0],
+            output_ptr(output, y, c_var + c),
+            data_ptr(data, y, c_var + c, offset=(y + c) % 2),
+            kernel_ptr(kernel, c_var + c, offset=c),
+            _bias_ptr(bias, c_var + c),
+            _scale_ptr(scale, c_var + c),
+        )
+
+    @T.prim_func
+    def biased_quantized_conv2d(
+        data_handle: T.handle,
+        kernel_handle: T.handle,
+        bias_handle: T.handle,
+        scale_handle: T.handle,
+        output_handle: T.handle,
+    ) -> None:
+        # Same setup is used as in _make_conv2d_primfunc
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        data = T.match_buffer(data_handle, data_shape, dtype="int16")
+        kernel = T.match_buffer(kernel_handle, kernel_shape, dtype="int16")
+        bias = T.match_buffer(bias_handle, bias_shape, dtype="int32")
+        scale = T.match_buffer(scale_handle, scale_shape)
+        output = T.match_buffer(output_handle, output_shape, dtype="int16")
+
+        # pylint: disable=unused-variable
+        output[0, 0, 0, 0] = 0
+        __1 = data[0, 0, 0, 0]
+        __2 = kernel[0, 0, 0, 0]
+        __3 = bias[0, 0, 0, 0]
+        __4 = scale[0]
+        # pylint: enable=unused-variable
+
+        for c_ax in T.grid(out_channels // 2):
+            with T.block("conv2ds"):
+                T.block_attr({"pragma_import_c": function_code})
+                c = T.axis.remap("S", [c_ax]) * 2
+
+                # TODO how can I programatically make the right number of

Review Comment:
   Were you planning to change this part?
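For reference, the two branches of `output_ptr` compute the flat (row-major) offset of the first element of output row `y` for channel `c`, that is, the element at x = 0 in batch 0. The sketch below restates those formulas as a plain function and checks them against NumPy's `ravel_multi_index`. `row_offset` is a hypothetical name, and `_make_tscript_ptr` is not reproduced:

```python
import numpy as np


def row_offset(layout, y, c, out_height, out_width, out_channels):
    """Flat offset of output element (n=0, y, x=0, c) for the given layout."""
    if layout == "NHWC":
        return y * (out_width * out_channels) + c
    if layout == "NCHW":
        return c * (out_height * out_width) + y * out_width
    raise ValueError(f"Unsupported out_layout '{layout}'!")


def check_against_numpy(out_height=5, out_width=7, out_channels=8):
    """Compare row_offset with NumPy's row-major indexing for every (y, c)."""
    for y in range(out_height):
        for c in range(out_channels):
            nhwc = np.ravel_multi_index((0, y, 0, c), (1, out_height, out_width, out_channels))
            nchw = np.ravel_multi_index((0, c, y, 0), (1, out_channels, out_height, out_width))
            assert row_offset("NHWC", y, c, out_height, out_width, out_channels) == nhwc
            assert row_offset("NCHW", y, c, out_height, out_width, out_channels) == nchw
    return True
```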



##########
python/tvm/topi/arm_cpu/qnn_legalize.py:
##########
@@ -0,0 +1,349 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""QNN legalization transforms that help eliminate sparse channels.
+
+Some models (like MobileNetV1 when fine-tuned) have output channels in their kernels which are
+completely full of zeros. Sometimes these can be optimized away by the C compiler, but this does not
+happen when complex schedules (like the ACLE tensordot convolutions) are used.
+
+Instead, we will remove these channels by replacing blocks of operators with equivalent "denser"
+ones during legalization. This is harder than it looks - while the outputs of channels with all-zero
+kernels do not depend on the input data, they are usually not zero. We work around this by computing
+how these constant values affect subsequent operators, and "folding" these effects into a bias_add.
+
+It would eventually be nice to have a generalized, cross-target solution for removing zero channels,
+as there is no downside. This may be possible with Relax, but I'm unsure.
+"""
+
+import numpy as np
+from scipy.signal import convolve2d
+from tvm.topi.utils import get_const_tuple
+from tvm import nd, relay
+from .qnn_alter_op import prev_ops_match, edit_attrs
+from ..nn import qnn_bias_add_legalize
+
+
+def _compute_fixed_conv2d_outputs(requantize_op):
+    """Compute all conv2d output values that do not depend on the layer input."""
+    bias_add_op = requantize_op.args[0]
+    conv2d_op = bias_add_op.args[0]
+
+    assert conv2d_op.attrs.kernel_layout.isalpha()
+    assert conv2d_op.attrs.groups == 1
+    kernel = conv2d_op.args[1].data.numpy()
+    oc_axis = conv2d_op.attrs.kernel_layout.index("O")
+
+    num_channels = kernel.shape[oc_axis]
+    rq_input_scale = requantize_op.args[1].data.numpy()
+    rq_output_scale = requantize_op.args[3].data.numpy().item()
+    rq_output_zero_point = requantize_op.args[4].data.numpy().item()
+    bias_data = bias_add_op.args[1].data.numpy()
+
+    fixed_outputs = {}
+
+    for i in range(num_channels):
+        if np.any(np.take(kernel, i, axis=oc_axis)):
+            continue
+        scale = rq_input_scale[i] / rq_output_scale
+        channel_constant = round(bias_data[i] * scale + rq_output_zero_point)
+        clipped = min(127, max(-128, channel_constant))
+        fixed_outputs[i] = clipped
+
+    return fixed_outputs
+
+
+def _compute_fixed_depthwise_outputs(requantize_op, fixed_channel_inputs):
+    """Compute all depthwise conv2d output values that do not depend on the PREVIOUS layer input.
+
+    We take as input a requantize operator, and a dictionary of which inputs to our depthwise
+    operator are fixed and what values they are fixed to. However, a fixed input to one channel
+    of our depthwise operator does NOT guarantee we can remove the output, because of padding.
+    This function checks if the padding makes a difference in the outputs, and if not, removes
+    the channels from the depthwise_conv2d.
+    """
+    bias_add_op = requantize_op.args[0]
+    depthwise_op = bias_add_op.args[0]
+
+    assert depthwise_op.attrs.kernel_layout.isalpha()
+    assert depthwise_op.attrs.groups > 1
+    kernel = depthwise_op.args[1].data.numpy()
+    oc_axis = depthwise_op.attrs.kernel_layout.index("O")
+
+    conv_input_zero_point = depthwise_op.args[2].data.numpy().item()
+    rq_input_scale = requantize_op.args[1].data.numpy()
+    rq_output_scale = requantize_op.args[3].data.numpy().item()
+    rq_output_zero_point = requantize_op.args[4].data.numpy().item()
+    bias_data = bias_add_op.args[1].data.numpy()
+
+    kernel_size = get_const_tuple(depthwise_op.attrs.kernel_size)
+
+    # Make a kernel_size x kernel_size array of fixed_input
+    # Pad it with zeros usint padding
+    # Do a convolution and make sure
+
+    fixed_outputs = {}
+
+    for i, fixed_input in fixed_channel_inputs.items():
+        input_array = np.full(kernel_size, fixed_input, dtype="int32") - conv_input_zero_point
+        kernel_channel = np.take(kernel, i, axis=oc_axis).reshape(kernel_size)
+        scale = rq_input_scale[i] / rq_output_scale
+
+        convolved = convolve2d(input_array, kernel_channel, mode="same")
+        rounded = np.around((convolved + bias_data[i]) * scale).astype("int32")
+        clipped = np.clip(rounded + rq_output_zero_point, -128, 127)
+
+        # We require the ENTIRE padded convolution to all have the same clipped value before we do
+        # a replacement. This is excessive - we only have to check for the padding that will
+        # actually be performed on the depthwise convolution, which is often less. If we felt even
+        # more ambitious, we could do the replacement for "close enough" looking convolution
+        # outputs, which in theory could reduce accuracy but in practice does not. Doing this would
+        # yield a ~0.5% speed gain on MobileNetV1, and nothing on other models.
+
+        if np.all(clipped == clipped[0, 0]):
+            fixed_outputs[i] = clipped[0, 0]
+
+    # TODO look for all-zero entries in the depthwise kernel. I don't think these really occur in

Review Comment:
   Same here?
   Also, for TODOs please add your GitHub handle in parentheses, e.g. `TODO(guberti)`.
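The arithmetic behind `_compute_fixed_conv2d_outputs` works as follows. Suppose every weight in an output channel is zero and the kernel zero point is 0, as for TFLite's symmetric int8 weights. Then the int32 accumulator for that channel is just its bias, so the requantized output is the same constant at every pixel. Here is a standalone sketch with illustrative numbers (`fixed_channel_output` is a hypothetical name):

```python
def fixed_channel_output(bias, input_scale, output_scale, output_zero_point, qmin=-128, qmax=127):
    """Requantized int8 output of a conv2d channel whose kernel weights are all zero."""
    scale = input_scale / output_scale  # requantize multiplier for this channel
    value = round(bias * scale + output_zero_point)
    return min(qmax, max(qmin, value))  # saturate to the int8 range
```

For example, a bias of 1000 with scales 0.0005 and 0.02 and an output zero point of -128 gives 1000 * 0.025 - 128 = -103 for every pixel of that channel.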



##########
python/tvm/topi/arm_cpu/qnn_legalize.py:
##########
@@ -0,0 +1,349 @@
+    # Make a kernel_size x kernel_size array of fixed_input

Review Comment:
   Remove this, or if it's needed for the explanation, make it more polished?
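The commented steps (build a kernel-sized array of the fixed input, zero-pad it, convolve, and compare) can be made concrete as follows. This is a NumPy-only sketch under the same assumptions as `_compute_fixed_depthwise_outputs`. `same_conv2d` stands in for `scipy.signal.convolve2d(..., mode="same")`, and both helper names are hypothetical. Because of the zero padding, border outputs see fewer non-zero inputs than the centre, so a channel can only be removed if requantization and clipping still make every output equal:

```python
import numpy as np


def same_conv2d(image, kernel):
    """Zero-padded 2D convolution with output shaped like `image` (odd kernel sizes).

    For odd kernels this should agree with scipy.signal.convolve2d(image, kernel, mode="same").
    """
    kh, kw = kernel.shape
    padded = np.pad(image, ((kh // 2, kh // 2), (kw // 2, kw // 2)))  # pads with zeros
    flipped = kernel[::-1, ::-1]  # convolution (not correlation) flips the kernel
    out = np.zeros(image.shape, dtype=np.int64)
    for i in range(image.shape[0]):
        for j in range(image.shape[1]):
            out[i, j] = np.sum(padded[i:i + kh, j:j + kw] * flipped)
    return out


def fixed_depthwise_output(fixed_input, input_zero_point, kernel, bias, scale,
                           output_zero_point, qmin=-128, qmax=127):
    """Constant requantized output of a depthwise channel fed a constant input,
    or None if zero padding makes the outputs differ."""
    centered = np.full(kernel.shape, fixed_input, dtype=np.int64) - input_zero_point
    convolved = same_conv2d(centered, kernel)
    rounded = np.around((convolved + bias) * scale).astype(np.int64)
    clipped = np.clip(rounded + output_zero_point, qmin, qmax)
    if np.all(clipped == clipped[0, 0]):
        return int(clipped[0, 0])
    return None
```

With a 3x3 all-ones kernel and a centered input of 3, the raw sums are 12 at the corners, 18 on the edges and 27 in the centre, so the channel normally cannot be removed. A large enough scale saturates all of them to 127, though. This is why checking the clipped values, as the PR does, accepts more channels than checking the raw sums.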



##########
src/relay/transforms/simplify_expr.cc:
##########
@@ -979,7 +992,16 @@ Pass SimplifyExpr() {
   return CreateFunctionPass(pass_func, 0, "SimplifyExpr", {"InferType"});
 }
 
+Pass SimplifyExprPostAlterOp() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(SimplifyExprPostAlterOp(f, m));
+      };
+  return CreateFunctionPass(pass_func, 0, "SimplifyExprPostAlterOp", {"InferType"});
+}
+
 TVM_REGISTER_GLOBAL("relay._transform.SimplifyExpr").set_body_typed(SimplifyExpr);
+// Don't globally register SimplifyExprPostAlterOp

Review Comment:
   Why is that?





[GitHub] [tvm] alanmacd commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "alanmacd (via GitHub)" <gi...@apache.org>.
alanmacd commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1087302880


##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -368,3 +389,136 @@ def kernel_ptr(buffer, c, offset=0):
 def schedule_qnn_depthwise_conv2d(_attrs, _outs, _target):
     """Schedule function for qnn.depthwise_conv2d."""
     return None
+
+
+def _make_unrolled_conv2d_primfunc(
+    output_dimensions: Tuple[int, int, int],
+    buffer_shapes: Tuple[Tuple, Tuple, Tuple, Tuple, Tuple],
+    function_names: Dict[Tuple, str],
+    function_code: str,
+    ptr_gens: Tuple[Callable, Callable],
+    output_layout="NHWC",
+):
+    out_height, out_width, out_channels = output_dimensions
+    data_shape, kernel_shape, bias_shape, scale_shape, output_shape = buffer_shapes
+    data_ptr, kernel_ptr = ptr_gens
+
+    def output_ptr(output, y, c):

Review Comment:
   nit: inconsistent use of leading '_' for internal functions





[GitHub] [tvm] alanmacd commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "alanmacd (via GitHub)" <gi...@apache.org>.
alanmacd commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1088328877


##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -368,3 +389,136 @@ def kernel_ptr(buffer, c, offset=0):
 def schedule_qnn_depthwise_conv2d(_attrs, _outs, _target):
     """Schedule function for qnn.depthwise_conv2d."""
     return None
+
+
+def _make_unrolled_conv2d_primfunc(
+    output_dimensions: Tuple[int, int, int],
+    buffer_shapes: Tuple[Tuple, Tuple, Tuple, Tuple, Tuple],
+    function_names: Dict[Tuple, str],
+    function_code: str,
+    ptr_gens: Tuple[Callable, Callable],
+    output_layout="NHWC",
+):
+    out_height, out_width, out_channels = output_dimensions
+    data_shape, kernel_shape, bias_shape, scale_shape, output_shape = buffer_shapes
+    data_ptr, kernel_ptr = ptr_gens
+
+    def output_ptr(output, y, c):

Review Comment:
   I just noticed that inside the `def _make_unrolled_conv2d_primfunc` block there is the following:
   ```
       def output_ptr
       def _make_row_call
       def biased_quantized_conv2d
   ```
   
   But I may be missing something and I am by no means an expert on the subject 😆





[GitHub] [tvm] mkatanbaf commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "mkatanbaf (via GitHub)" <gi...@apache.org>.
mkatanbaf commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1093859994


##########
src/target/source/codegen_c.cc:
##########
@@ -631,8 +632,11 @@ void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs,
   }
 }
 
+unsigned int random_seed = 0;
 void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
-  std::string symbol_name = op->buffer_var->name_hint;
+  // Add a random suffix to eliminate duplicate global variables.
+  int suffix = rand_r(&random_seed) % (2 << 24);

Review Comment:
   I believe a similar fix is needed for the case where USMP is enabled, here: https://github.com/apache/tvm/blob/2877c5a3cf126637f0968bb9090454410c426cd0/src/target/source/source_module.cc#L295





[GitHub] [tvm] ibsidorenko commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "ibsidorenko (via GitHub)" <gi...@apache.org>.
ibsidorenko commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1111600542


##########
python/tvm/topi/nn/qnn.py:
##########
@@ -212,6 +212,48 @@ def qnn_requantize_alter_layout(_attrs, _inputs, _tinfos, _out_type):
     return None
 
 
+@tvm.target.generic_func
+def qnn_bias_add_legalize(_attrs, _inputs, _tinfos):
+    """Legalize bias_add layout.
+
+    Bias add is not a QNN-specific function, but this generic exists so that empty channels can
+    be excised from quantized conv2d operators and folded into bias adds.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current convolution
+    inputs : tvm.relay.Expr
+        Grouped input symbols
+    tinfos : list
+        Input shape and dtype
+
+    """
+    return None
+
+
+@tvm.target.generic_func
+def qnn_clip_legalize(_attrs, inputs, _tinfos, _out_type):

Review Comment:
   Do we need this? I can't find anywhere that calls `qnn_clip_legalize`.





[GitHub] [tvm] guberti commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1088251735


##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -221,23 +261,21 @@ def qnn_conv2d(attrs, inputs, out_type):
     # Make a few checks to unpack the function arguments and ensure it was called with the right
     # arguments. Note that unlike most schedules, qnn_conv2d does not use a wrapper.
     assert len(inputs) == 11
-    data, kernel, _izp, _kzp, _iscale, _kscale, bias, scale = inputs[0:8]
-    output_layout = attrs.out_layout
-    assert output_layout == "NHWC"
+    assert not any(get_const_tuple(attrs.padding))
 
+    data, kernel, _izp, _kzp, _iscale, _kscale, bias, scale = inputs[0:8]
     _, height, width, in_channels = get_const_tuple(data.shape)
     out_channels, kernel_h, kernel_w, _ = get_const_tuple(kernel.shape)
-    y_stride, x_stride = get_const_tuple(attrs.strides)
 
+    y_stride, x_stride = get_const_tuple(attrs.strides)
     out_height = _compute_output_dim(height, kernel_h, y_stride)
     out_width = _compute_output_dim(width, kernel_w, x_stride)
 
     # Decide how many sums our function should have running at the same time. Doing
     # this lets us do "more work" for each memory load, but doing too many of them causes us to run
     # out of registers. Currently this is set to either 1 or 2, but autotuning this value would

Review Comment:
   Reworded the sentence to be correct.
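`_compute_output_dim` is not shown in this hunk. The new `assert not any(get_const_tuple(attrs.padding))` means this schedule only sees unpadded convolutions. Assuming `_compute_output_dim` implements the standard formula for that case, it would look like this (`compute_output_dim` is a hypothetical name):

```python
def compute_output_dim(in_dim: int, kernel_dim: int, stride: int) -> int:
    """Output length along one axis of an unpadded ("VALID") convolution."""
    return (in_dim - kernel_dim) // stride + 1
```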





[GitHub] [tvm] alanmacd commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "alanmacd (via GitHub)" <gi...@apache.org>.
alanmacd commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1087319226


##########
tests/python/relay/strategy/arm_cpu/test_quantized_convolution.py:
##########
@@ -95,6 +95,56 @@ def _get_mobilenet_v1_layer_attributes(layer_num):
     return ((1, 1, 1, 1), (1, 1), True)
 
 
+@pytest.mark.parametrize("layer", range(2, 27, 2))
+def test_infinite_bias_detection(interpreter, layer):
+    """Some models (mainly MobileNetV1) have kernels with many output channels full entirely of
+    zeroes. The VWW model is one of these. This test confirms that the outputs of these channels,
+    as computed by TensorFlow, are indeed not dependent upon the input values.
+    """
+
+    _, kernel, bias, output = _load_tflite_layer(interpreter, layer)
+    kernel_data, kernel_quant = kernel
+    bias_data, bias_quant = bias
+    output_data, output_quant = output
+    is_depthwise = _get_mobilenet_v1_layer_attributes(layer)[2]
+    assert not is_depthwise
+    assert kernel_data.shape[1] == kernel_data.shape[2] == 1
+
+    out_channels = kernel_data.shape[3]
+    fixed_channels = {}
+
+    out_zero_point = output_quant["zero_points"][0]
+    assert out_zero_point == -128
+
+    for i in range(out_channels):
+        # Skip over output channels with data
+        if np.any(kernel_data[i, 0, 0, :]):
+            continue
+
+        scale = bias_quant["scales"][i] / output_quant["scales"][0]
+        channel_constant = round(bias_data[i] * scale + out_zero_point)
+        clipped = min(127, max(-128, channel_constant))
+
+        out_channel_values = output_data[0, :, :, i].flatten()
+        assert all(x == clipped for x in out_channel_values)
+        fixed_channels[i] = clipped
+    print(f"Layer {layer} had {len(fixed_channels)}/{out_channels} empty!")
+
+    # We now need to compute values for the following depthwise layer
+    if layer == 26:
+        return
+
+    _, kernel, bias, output = _load_tflite_layer(interpreter, layer + 1)
+    kernel_data, kernel_quant = kernel
+    bias_data, bias_quant = bias
+    output_data, output_quant = output
+    is_depthwise = _get_mobilenet_v1_layer_attributes(layer + 1)[2]

Review Comment:
   ```suggestion
       _, _, _, output = _load_tflite_layer(interpreter, layer + 1)
       output_data, output_quant = output
       is_depthwise = _get_mobilenet_v1_layer_attributes(layer + 1)[2]
   ```



[GitHub] [tvm] areusch commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "areusch (via GitHub)" <gi...@apache.org>.
areusch commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1086036906


##########
python/tvm/relay/qnn/op/_qnn.py:
##########
@@ -85,12 +91,72 @@ def simulated_dequantize_compute(attrs, inputs, output_type):
 register_strategy("qnn.conv2d", strategy.qnn_conv2d_strategy)
 
 
+def _get_clip_dtype_bounds(dtype):
+    """Returns the minimum and maximum values of a C integer data type."""
+    assert "int" in dtype
+    bits = int(dtype[dtype.find("int") + 3 :])
+
+    if dtype.startswith("int"):
+        return (-(2 ** (bits - 1)), 2 ** (bits - 1) - 1)
+    elif dtype.startswith("uint"):
+        return (0, 2**bits - 1)
+    else:
+        raise TVMError(f"Clip legalization is not supported for data type '{dtype}'!")
+
+
+@register_legalize("clip")
+def legalize_clip(attrs, inputs, tinfos):

Review Comment:
   are there tests for these?



##########
python/tvm/topi/arm_cpu/qnn_alter_op.py:
##########
@@ -19,57 +19,108 @@
 import numpy as np
 
 from tvm import nd, relay, target
-from ..nn import qnn_requantize_alter_layout, qnn_add_alter_layout
+from ..utils import get_const_tuple
+from ..nn import qnn_conv2d_alter_layout, qnn_add_alter_layout, qnn_requantize_alter_layout
 
 
-@qnn_requantize_alter_layout.register(["arm_cpu"])
-def alter_requantize_layout(attrs, inputs, _tinfos, _out_type):
-    """Changes a floating point requantize op to use int64 multiply + shift for microTVM.
+def prev_ops_match(curr_op, pattern):

Review Comment:
   can you write a docstring here?



##########
src/target/source/codegen_c.cc:
##########
@@ -631,8 +632,10 @@ void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs,
   }
 }
 
+NameSupply global_name_supply = NameSupply("");
 void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
-  std::string symbol_name = op->buffer_var->name_hint;
+  std::string symbol_name = global_name_supply->FreshName(op->buffer_var->name_hint);

Review Comment:
   Why was this needed? I think we need to initialize the NameSupply from the IRModule.



##########
python/tvm/topi/arm_cpu/qnn_legalize.py:
##########
@@ -0,0 +1,349 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""QNN legalization transforms that help eliminate sparse channels.
+
+Some models (like MobileNetV1 when fine-tuned) have output channels in their kernels which are
+completely full of zeros. Sometimes these can be optimized away by the C compiler, but this does not
+happen when complex schedules (like the ACLE tensordot convolutions) are used.
+
+Instead, we will remove these channels by replacing blocks of operators with equivalent "denser"
+ones during legalization. This is harder than it looks - while the outputs of channels with all-zero
+kernels do not depend on the input data, they are usually not zero. We work around this by computing
+how these constant values affect subsequent operators, and "folding" these effects into a bias_add.
+
+It would eventually be nice to have a generalized, cross-target solution for removing zero channels,
+as there is no downside. This may be possible with Relax, but I'm unsure.
+"""
+
+import numpy as np
+from scipy.signal import convolve2d
+from tvm.topi.utils import get_const_tuple
+from tvm import nd, relay
+from .qnn_alter_op import prev_ops_match, edit_attrs
+from ..nn import qnn_bias_add_legalize
+
+
+def _compute_fixed_conv2d_outputs(requantize_op):
+    """Compute all conv2d output values that do not depend on the layer input."""

Review Comment:
   could you document the return value, here and below?
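The folding idea described in the `qnn_legalize.py` docstring above (channels with all-zero kernels produce constants, whose effect on the next layer can be moved into a bias) can be illustrated with a minimal NumPy sketch. It assumes a pointwise (1x1) convolution written as a matrix multiply, integer accumulation with the input zero point subtracted, and no requantization. `pointwise_conv` and `fold_constant_channels` are hypothetical helpers for illustration, not functions from the PR, which also has to handle depthwise layers.

```python
import numpy as np


def pointwise_conv(x, weights, bias, input_zero_point):
    """Hypothetical integer 1x1 convolution written as a matmul.

    x: (pixels, in_channels) quantized input
    weights: (out_channels, in_channels) quantized kernel
    bias: (out_channels,) integer bias
    Returns the int64 accumulators, before any requantization.
    """
    centered = x.astype(np.int64) - input_zero_point
    return centered @ weights.T.astype(np.int64) + bias.astype(np.int64)


def fold_constant_channels(weights, bias, constant_inputs, input_zero_point):
    """Drop input channels with a known constant value and fold them into the bias.

    constant_inputs maps input channel index -> the constant quantized value that the
    previous layer always produces on that channel (e.g. because its kernel was all zeros).
    """
    folded_bias = bias.astype(np.int64).copy()
    for channel, value in constant_inputs.items():
        # This channel adds weights[:, channel] * (value - zp) to every output pixel.
        folded_bias += weights[:, channel].astype(np.int64) * (value - input_zero_point)
    keep = [c for c in range(weights.shape[1]) if c not in constant_inputs]
    return weights[:, keep], folded_bias, keep


rng = np.random.default_rng(0)
zero_point = -128
weights = rng.integers(-127, 128, size=(4, 6))
bias = rng.integers(-1000, 1000, size=4)
x = rng.integers(-128, 128, size=(10, 6))
constants = {1: -128, 3: 17}  # channels 1 and 3 are constant, whatever the model input
for channel, value in constants.items():
    x[:, channel] = value

reference = pointwise_conv(x, weights, bias, zero_point)
dense_weights, dense_bias, kept = fold_constant_channels(weights, bias, constants, zero_point)
assert np.array_equal(reference, pointwise_conv(x[:, kept], dense_weights, dense_bias, zero_point))
```

Because every folded term is an exact integer product, the smaller layer (4 input channels instead of 6 here) produces bit-identical accumulators.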



##########
tests/python/relay/strategy/arm_cpu/test_quantized_convolution.py:
##########
@@ -95,6 +95,53 @@ def _get_mobilenet_v1_layer_attributes(layer_num):
     return ((1, 1, 1, 1), (1, 1), True)
 
 
+@pytest.mark.parametrize("layer", range(2, 27, 2))
+@tvm.testing.requires_package("tensorflow")
+def test_empty_channel_detection(interpreter, layer):
+    """Some models (mainly MobileNetV1) have kernels with many output channels full entirely of
+    zeroes. The VWW model is one of these. This test confirms that the outputs of these channels,
+    as computed by TensorFlow, are indeed not dependent upon the input values.
+    """
+
+    _, kernel, bias, output = _load_tflite_layer(interpreter, layer)
+    kernel_data, _ = kernel
+    bias_data, bias_quant = bias
+    output_data, output_quant = output
+    is_depthwise = _get_mobilenet_v1_layer_attributes(layer)[2]
+    assert not is_depthwise
+    assert kernel_data.shape[1] == kernel_data.shape[2] == 1
+
+    out_channels = kernel_data.shape[3]
+    fixed_channels = {}
+
+    out_zero_point = output_quant["zero_points"][0]
+    assert out_zero_point == -128
+
+    for i in range(out_channels):
+        # Skip over output channels with data
+        if np.any(kernel_data[i, 0, 0, :]):
+            continue
+
+        scale = bias_quant["scales"][i] / output_quant["scales"][0]
+        channel_constant = round(bias_data[i] * scale + out_zero_point)
+        clipped = min(127, max(-128, channel_constant))
+
+        out_channel_values = output_data[0, :, :, i].flatten()
+        assert all(x == clipped for x in out_channel_values)
+        fixed_channels[i] = clipped
+
+    # We now need to compute values for the following depthwise layer
+    if layer == 26:

Review Comment:
   anything else you can assert here other than `== 26`?
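The property this test checks (an output channel whose kernel is all zeros yields a constant that depends only on the bias and the quantization parameters) can be sketched in isolation. The helper names below are hypothetical, per-tensor scales are assumed for simplicity, and the int8 range matches the `min(127, max(-128, ...))` clipping used in the test.

```python
import numpy as np


def empty_channel_constant(bias, bias_scale, output_scale, output_zero_point):
    """Requantized int8 value of an output channel whose kernel weights are all zero."""
    # With an all-zero kernel the accumulator is just the bias, independent of the input.
    value = int(round(float(bias) * (bias_scale / output_scale) + output_zero_point))
    return min(127, max(-128, value))


def requantize_pointwise(x, weights, bias, input_zero_point, bias_scale, output_scale,
                         output_zero_point):
    """Hypothetical int8 1x1 convolution followed by requantization back to int8."""
    acc = (x.astype(np.int64) - input_zero_point) @ weights.T.astype(np.int64) + bias
    out = np.round(acc * (bias_scale / output_scale) + output_zero_point)
    return np.clip(out, -128, 127).astype(np.int8)


rng = np.random.default_rng(1)
weights = rng.integers(-127, 128, size=(3, 8))
weights[2, :] = 0  # output channel 2 is "empty"
bias = np.array([50, -20, 10])
bias_scale, output_scale, zp_out = 0.5, 0.25, -128

expected = empty_channel_constant(bias[2], bias_scale, output_scale, zp_out)  # -108
for _ in range(5):
    x = rng.integers(-128, 128, size=(16, 8))
    out = requantize_pointwise(x, weights, bias, -128, bias_scale, output_scale, zp_out)
    assert np.all(out[:, 2] == expected)  # same constant for every random input
```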



##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -368,3 +389,139 @@ def kernel_ptr(buffer, c, offset=0):
 def schedule_qnn_depthwise_conv2d(_attrs, _outs, _target):
     """Schedule function for qnn.depthwise_conv2d."""
     return None
+
+
+def _make_unrolled_conv2d_primfunc(

Review Comment:
   can you document the parameters here?



[GitHub] [tvm] guberti commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1101019261


##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -368,3 +389,139 @@ def kernel_ptr(buffer, c, offset=0):
 def schedule_qnn_depthwise_conv2d(_attrs, _outs, _target):
     """Schedule function for qnn.depthwise_conv2d."""
     return None
+
+
+def _make_unrolled_conv2d_primfunc(
+    output_dimensions: Tuple[int, int, int],
+    buffer_shapes: Tuple[Tuple, Tuple, Tuple, Tuple, Tuple],
+    function_names: Dict[Tuple, str],
+    function_code: str,
+    ptr_gens: Tuple[Callable, Callable],
+    output_layout="NHWC",
+):
+    out_height, out_width, out_channels = output_dimensions
+    data_shape, kernel_shape, bias_shape, scale_shape, output_shape = buffer_shapes
+    data_ptr, kernel_ptr = ptr_gens
+
+    def output_ptr(output, y, c):
+        if output_layout == "NHWC":
+            return _make_tscript_ptr(output, y * const(out_width * out_channels) + c, 1)
+        elif output_layout == "NCHW":
+            return _make_tscript_ptr(
+                output, c * const(out_height * out_width) + y * const(out_width), 1
+            )
+        else:
+            raise TVMError(f"Unsupported out_layout '{output_layout}'!")
+
+    def make_row_call(buffers, c_var, y, c):
+        output, data, kernel, bias, scale = buffers
+        return _make_tscript_call(
+            function_names[(y + c) % 2, c % 2, 0],
+            output_ptr(output, y, c_var + c),
+            data_ptr(data, y, c_var + c, offset=(y + c) % 2),
+            kernel_ptr(kernel, c_var + c, offset=c),
+            _bias_ptr(bias, c_var + c),
+            _scale_ptr(scale, c_var + c),
+        )
+
+    @T.prim_func
+    def biased_quantized_conv2d(
+        data_handle: T.handle,
+        kernel_handle: T.handle,
+        bias_handle: T.handle,
+        scale_handle: T.handle,
+        output_handle: T.handle,
+    ) -> None:
+        # Same setup is used as in _make_conv2d_primfunc
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        data = T.match_buffer(data_handle, data_shape, dtype="int16")
+        kernel = T.match_buffer(kernel_handle, kernel_shape, dtype="int16")
+        bias = T.match_buffer(bias_handle, bias_shape, dtype="int32")
+        scale = T.match_buffer(scale_handle, scale_shape)
+        output = T.match_buffer(output_handle, output_shape, dtype="int16")
+
+        # pylint: disable=unused-variable
+        output[0, 0, 0, 0] = 0
+        __1 = data[0, 0, 0, 0]
+        __2 = kernel[0, 0, 0, 0]
+        __3 = bias[0, 0, 0, 0]
+        __4 = scale[0]
+        # pylint: enable=unused-variable
+
+        for c_ax in T.grid(out_channels // 2):
+            with T.block("conv2ds"):
+                T.block_attr({"pragma_import_c": function_code})
+                c = T.axis.remap("S", [c_ax]) * 2
+
+                # TODO how can I programmatically make the right number of

Review Comment:
   Fixed!
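The two branches of `output_ptr` in the diff above compute the flat offset of output row `y`, channel `c` (at `x = 0`). Assuming the output buffer is a contiguous row-major tensor with batch size 1, those formulas can be cross-checked against `np.ravel_multi_index`. The `output_offset` function below is a plain-Python restatement for illustration, not the TVMScript code itself.

```python
import numpy as np


def output_offset(y, c, out_height, out_width, out_channels, output_layout="NHWC"):
    """Flat offset of element (n=0, y, x=0, c) in a batch-1 output tensor (mirrors output_ptr)."""
    if output_layout == "NHWC":
        return y * (out_width * out_channels) + c
    if output_layout == "NCHW":
        return c * (out_height * out_width) + y * out_width
    raise ValueError(f"Unsupported out_layout '{output_layout}'!")


H, W, C = 5, 7, 4
for y in range(H):
    for c in range(C):
        nhwc = np.ravel_multi_index((0, y, 0, c), (1, H, W, C))
        nchw = np.ravel_multi_index((0, c, y, 0), (1, C, H, W))
        assert output_offset(y, c, H, W, C, "NHWC") == nhwc
        assert output_offset(y, c, H, W, C, "NCHW") == nchw
```

In NHWC, moving to the next channel advances the offset by 1; in NCHW it advances by `out_height * out_width`.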



[GitHub] [tvm] masahi commented on pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "masahi (via GitHub)" <gi...@apache.org>.
masahi commented on PR #13752:
URL: https://github.com/apache/tvm/pull/13752#issuecomment-1450973069

   My comment hasn't been addressed, but if the Relay pass addition is a minor one, I'm cool with this.


[GitHub] [tvm] mkatanbaf commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "mkatanbaf (via GitHub)" <gi...@apache.org>.
mkatanbaf commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1106329729


##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -17,25 +17,40 @@
 """Contains TVMScript implementations of some QNN operators for Arm.
 
 Currently, the only ops with compute functions are fused regular and depthwise convolutions for
-Arm Cortex-M with DSP.
+Arm Cortex-M with DSP. Additionally, these functions explicitly do not support padding - it
+must be done in a separate Relay op for memory reasons.
 """
 
-from typing import Tuple
+from typing import Callable, Dict, Tuple
 
 import tvm
-from tvm import te
-from tvm.tir import const
+from tvm import te, tir, TVMError
 from tvm.script import tir as T
+from tvm.tir import const
+
 from ..utils import get_const_tuple
 from .mprofile.dsp.micro_kernel import tensordot
 
 
-def int_ceil_division(x, y):
+def _int_ceil_division(x, y):
     return -(x // -y)
 
 
 def _compute_output_dim(data_length, kernel_length, stride):
-    return int_ceil_division(data_length + 1 - kernel_length, stride)
+    return _int_ceil_division(data_length + 1 - kernel_length, stride)
+
+
+def _pick_num_outputs(out_width):
+    """Guess a good value for num_outputs."""
+
+    assert out_width > 1
+
+    # num_outputs is capped at 8
+    for i in range(2, min(out_width + 1, 8)):
+        if out_width % i == 0:

Review Comment:
   Is this a requirement? i.e. does the `out_width` have to be divisible by `num_outputs`?



[GitHub] [tvm] guberti commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1103890434


##########
python/tvm/relay/qnn/op/_qnn.py:
##########
@@ -85,12 +91,72 @@ def simulated_dequantize_compute(attrs, inputs, output_type):
 register_strategy("qnn.conv2d", strategy.qnn_conv2d_strategy)
 
 
+def _get_clip_dtype_bounds(dtype):
+    """Returns the minimum and maximum values of a C integer data type."""
+    assert "int" in dtype
+    bits = int(dtype[dtype.find("int") + 3 :])
+
+    if dtype.startswith("int"):
+        return (-(2 ** (bits - 1)), 2 ** (bits - 1) - 1)
+    elif dtype.startswith("uint"):
+        return (0, 2**bits - 1)
+    else:
+        raise TVMError(f"Clip legalization is not supported for data type '{dtype}'!")
+
+
+@register_legalize("clip")
+def legalize_clip(attrs, inputs, tinfos):

Review Comment:
   Added some tests to `tests/python/relay/qnn/test_clip_legalization`!



[GitHub] [tvm] areusch commented on pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "areusch (via GitHub)" <gi...@apache.org>.
areusch commented on PR #13752:
URL: https://github.com/apache/tvm/pull/13752#issuecomment-1450971382

   @masahi could you have another look?
   
   @guberti I think the main comment of mine that still needs resolving is https://github.com/apache/tvm/pull/13752#discussion_r1096416812


[GitHub] [tvm] guberti commented on pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on PR #13752:
URL: https://github.com/apache/tvm/pull/13752#issuecomment-1452006337

   @masahi I don't think it makes sense to split the Relay pass addition, as it is a narrow change that only affects Cortex-M. Without this Relay pass change, my schedule changes would make performance worse. 


[GitHub] [tvm] alanmacd commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "alanmacd (via GitHub)" <gi...@apache.org>.
alanmacd commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1087301065


##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -221,23 +261,21 @@ def qnn_conv2d(attrs, inputs, out_type):
     # Make a few checks to unpack the function arguments and ensure it was called with the right
     # arguments. Note that unlike most schedules, qnn_conv2d does not use a wrapper.
     assert len(inputs) == 11
-    data, kernel, _izp, _kzp, _iscale, _kscale, bias, scale = inputs[0:8]
-    output_layout = attrs.out_layout
-    assert output_layout == "NHWC"
+    assert not any(get_const_tuple(attrs.padding))
 
+    data, kernel, _izp, _kzp, _iscale, _kscale, bias, scale = inputs[0:8]
     _, height, width, in_channels = get_const_tuple(data.shape)
     out_channels, kernel_h, kernel_w, _ = get_const_tuple(kernel.shape)
-    y_stride, x_stride = get_const_tuple(attrs.strides)
 
+    y_stride, x_stride = get_const_tuple(attrs.strides)
     out_height = _compute_output_dim(height, kernel_h, y_stride)
     out_width = _compute_output_dim(width, kernel_w, x_stride)
 
     # Decide how many sums our function should have running at the same time. Doing
     # this lets us do "more work" for each memory load, but doing too many of them causes us to run
     # out of registers. Currently this is set to either 1 or 2, but autotuning this value would

Review Comment:
   Can this sentence in the comment be removed since (I think?) it is no longer true?
   ```suggestion
       # out of registers. 
   ```



[GitHub] [tvm] guberti commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1088261659


##########
src/target/source/codegen_c.cc:
##########
@@ -631,8 +632,11 @@ void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs,
   }
 }
 
+unsigned int random_seed = 0;
 void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
-  std::string symbol_name = op->buffer_var->name_hint;
+  // Add a random suffix to eliminate duplicate global variables.
+  int suffix = rand_r(&random_seed) % (2 << 24);

Review Comment:
   A counter would work fine - this is just a hack to fix the duplicate global variables problem. Would love @areusch's take on the right way to fix this long term.
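The counter-based alternative mentioned above can be sketched in Python (rather than in the C++ codegen) as a name supply that remembers which symbols it has already handed out and appends an incrementing suffix on collision. `CounterNameSupply` is a hypothetical class; TVM's own `NameSupply` may format its suffixes differently. Unlike a random suffix, which only makes duplicates unlikely, this never hands out the same name twice.

```python
class CounterNameSupply:
    """Hypothetical helper: hands out unique symbol names by appending a counter."""

    def __init__(self):
        self._used = set()
        self._next_suffix = {}

    def fresh_name(self, hint):
        name = hint
        while name in self._used:
            # Try hint_1, hint_2, ... until an unused name is found.
            self._next_suffix[hint] = self._next_suffix.get(hint, 0) + 1
            name = f"{hint}_{self._next_suffix[hint]}"
        self._used.add(name)
        return name


supply = CounterNameSupply()
print(supply.fresh_name("fused_constant"))  # fused_constant
print(supply.fresh_name("fused_constant"))  # fused_constant_1
print(supply.fresh_name("fused_constant"))  # fused_constant_2
```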



[GitHub] [tvm] guberti commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1113408170


##########
python/tvm/relay/qnn/strategy/arm_cpu.py:
##########
@@ -21,9 +21,55 @@
 regular/depthwise conv2d is supported, but qnn_dense will be added eventually."""
 
 from tvm import topi, TVMError
-from .generic import qnn_conv2d_strategy
+from tvm.topi.utils import get_const_tuple
 from ... import op as _op
 from ...op.strategy.generic import is_depthwise_conv2d
+from .generic import (
+    qnn_conv2d_strategy,
+    qnn_dense_strategy,
+    qnn_dequantize_strategy,
+    qnn_quantize_strategy,
+    wrap_compute_dequantize,
+    wrap_compute_quantize,
+    wrap_topi_qnn_dense,
+    wrap_topi_schedule,
+)
+
+
+@qnn_quantize_strategy.register("arm_cpu")
+def qnn_quantize_strategy_arm_cpu(_attrs, _inputs, _out_type, _target):
+    """qnn.quantize strategy for arm_cpu"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_quantize(topi.hexagon.qnn_quantize),
+        wrap_topi_schedule(topi.hexagon.schedule_qnn_quantize),
+        name="qnn_quantize.arm_cpu",
+    )
+    return strategy
+
+
+@qnn_dequantize_strategy.register("arm_cpu")
+def qnn_dequantize_strategy_arm_cpu(_attrs, _inputs, _out_type, _target):
+    """qnn.dequantize strategy for arm_cpu"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_dequantize(topi.hexagon.qnn_dequantize),
+        wrap_topi_schedule(topi.hexagon.schedule_qnn_dequantize),
+        name="qnn_dequantize.arm_cpu",
+    )
+    return strategy
+
+
+@qnn_dense_strategy.register("arm_cpu")
+def qnn_dense_strategy_arm_cpu(_attrs, _inputs, _out_type, _target):
+    """qnn.dense strategy for arm_cpu"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_topi_qnn_dense(topi.hexagon.qnn_dense),
+        wrap_topi_schedule(topi.hexagon.schedule_qnn_dense),

Review Comment:
   It's fine for the time being. I know @mkatanbaf is working on a Cortex-M schedule for `dense`, but these operations do not take very much time on convolutional models.



[GitHub] [tvm] ibsidorenko commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "ibsidorenko (via GitHub)" <gi...@apache.org>.
ibsidorenko commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1111618732


##########
python/tvm/relay/qnn/strategy/arm_cpu.py:
##########
@@ -21,9 +21,55 @@
 regular/depthwise conv2d is supported, but qnn_dense will be added eventually."""
 
 from tvm import topi, TVMError
-from .generic import qnn_conv2d_strategy
+from tvm.topi.utils import get_const_tuple
 from ... import op as _op
 from ...op.strategy.generic import is_depthwise_conv2d
+from .generic import (
+    qnn_conv2d_strategy,
+    qnn_dense_strategy,
+    qnn_dequantize_strategy,
+    qnn_quantize_strategy,
+    wrap_compute_dequantize,
+    wrap_compute_quantize,
+    wrap_topi_qnn_dense,
+    wrap_topi_schedule,
+)
+
+
+@qnn_quantize_strategy.register("arm_cpu")
+def qnn_quantize_strategy_arm_cpu(_attrs, _inputs, _out_type, _target):
+    """qnn.quantize strategy for arm_cpu"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_quantize(topi.hexagon.qnn_quantize),
+        wrap_topi_schedule(topi.hexagon.schedule_qnn_quantize),
+        name="qnn_quantize.arm_cpu",
+    )
+    return strategy
+
+
+@qnn_dequantize_strategy.register("arm_cpu")
+def qnn_dequantize_strategy_arm_cpu(_attrs, _inputs, _out_type, _target):
+    """qnn.dequantize strategy for arm_cpu"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_dequantize(topi.hexagon.qnn_dequantize),
+        wrap_topi_schedule(topi.hexagon.schedule_qnn_dequantize),
+        name="qnn_dequantize.arm_cpu",
+    )
+    return strategy
+
+
+@qnn_dense_strategy.register("arm_cpu")
+def qnn_dense_strategy_arm_cpu(_attrs, _inputs, _out_type, _target):
+    """qnn.dense strategy for arm_cpu"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_topi_qnn_dense(topi.hexagon.qnn_dense),
+        wrap_topi_schedule(topi.hexagon.schedule_qnn_dense),

Review Comment:
   As I see it, you reuse the compute/schedule from Hexagon. These schedules are not optimized and have a very naive implementation. Is that acceptable for you?



##########
python/tvm/topi/nn/qnn.py:
##########
@@ -212,6 +212,48 @@ def qnn_requantize_alter_layout(_attrs, _inputs, _tinfos, _out_type):
     return None
 
 
+@tvm.target.generic_func
+def qnn_bias_add_legalize(_attrs, _inputs, _tinfos):
+    """Legalize bias_add layout.
+
+    Bias add is not a QNN-specific function, but this generic exists so that empty channels can
+    be excised from quantized conv2d operators and folded into bias adds.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current convolution
+    inputs : tvm.relay.Expr
+        Grouped input symbols
+    tinfos : list
+        Input shape and dtype
+
+    """
+    return None
+
+
+@tvm.target.generic_func
+def qnn_clip_legalize(_attrs, inputs, _tinfos, _out_type):
+    """Change clip layout.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current convolution
+    inputs : tvm.relay.Expr
+        Grouped input symbols
+    tinfos : list
+        Input shape and dtype
+    out_type: type
+        The output type
+
+    Note
+    ----
+    Unlike other TOPI functions, this function operates on both graph level and operator level.
+    """
+    return inputs[0]
+
+
 @tvm.target.generic_func
 def qnn_add_alter_layout(_attrs, _inputs, _tinfos, _out_type):

Review Comment:
   Same here. How about renaming `qnn_add_alter_layout` to `add_alter_layout`, since we do it for nn.add (not qnn.add)?



##########
python/tvm/topi/nn/qnn.py:
##########
@@ -212,6 +212,48 @@ def qnn_requantize_alter_layout(_attrs, _inputs, _tinfos, _out_type):
     return None
 
 
+@tvm.target.generic_func
+def qnn_bias_add_legalize(_attrs, _inputs, _tinfos):

Review Comment:
   How about renaming this to `bias_add_legalize`? The current name is confusing, since we do legalization for `nn.bias_add` (not a QNN bias_add).



##########
python/tvm/relay/qnn/op/_qnn.py:
##########
@@ -85,12 +91,72 @@ def simulated_dequantize_compute(attrs, inputs, output_type):
 register_strategy("qnn.conv2d", strategy.qnn_conv2d_strategy)
 
 
+def _get_clip_dtype_bounds(dtype):
+    """Returns the minimum and maximum values of a C integer data type."""
+    assert "int" in dtype
+    bits = int(dtype[dtype.find("int") + 3 :])
+
+    if dtype.startswith("int"):
+        return (-(2 ** (bits - 1)), 2 ** (bits - 1) - 1)

Review Comment:
   Just a nit: are `np.iinfo(dtype).min` / `np.iinfo(dtype).max` not suitable here?



[GitHub] [tvm] guberti commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1113409585


##########
python/tvm/topi/nn/qnn.py:
##########
@@ -212,6 +212,48 @@ def qnn_requantize_alter_layout(_attrs, _inputs, _tinfos, _out_type):
     return None
 
 
+@tvm.target.generic_func
+def qnn_bias_add_legalize(_attrs, _inputs, _tinfos):
+    """Legalize bias_add layout.
+
+    Bias add is not a QNN-specific function, but this generic exists so that empty channels can
+    be excised from quantized conv2d operators and folded into bias adds.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current convolution
+    inputs : tvm.relay.Expr
+        Grouped input symbols
+    tinfos : list
+        Input shape and dtype
+
+    """
+    return None
+
+
+@tvm.target.generic_func
+def qnn_clip_legalize(_attrs, inputs, _tinfos, _out_type):

Review Comment:
   You're right, removed.



[GitHub] [tvm] guberti commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1113880045


##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -118,38 +133,89 @@ def _make_tscript_ptr(buffer, offset, length, dtype="int16"):
     )
 
 
+def _bias_ptr(bias, c):
+    return _make_tscript_ptr(bias, c, 1, dtype="int32")
+
+
+def _scale_ptr(scale, c):
+    return _make_tscript_ptr(scale, c, 1, dtype="int32")
+
+
 def _make_tscript_call(func_name, *args):
     return T.evaluate(T.call_extern(func_name, *args, dtype="int32"))
 
 
 def _make_conv2d_primfunc(
-    call_dimensions: Tuple,
-    buffer_shapes: Tuple[Tuple, Tuple, Tuple, Tuple, Tuple],
+    output_dimensions: Tuple[int, int, int, int],
+    buffer_shapes: Tuple,
     aligned_func: Tuple[str, str],
     offset_func: Tuple[str, str],
-    ptr_gens: Tuple,
-):
-    height, width, out_channels = call_dimensions
+    ptr_gens: Tuple[Callable, Callable],
+    output_layout: str = "NHWC",
+) -> tir.function.PrimFunc:
+    """Makes a TIR PrimFunc computing Conv2D using a call to tensordot.
+
+    Can be used to generate regular, depthwise, and grouped Conv2D operators by passing different
+    arguments and ptr_gen functions. However, it only works for Conv2D operators where the height
+    stride of the tensor is divisible by two.
+
+    Parameters
+    ----------
+    output_dimensions : Tuple[int, int, int, int]
+        A tuple containing the out_height, out_width, out_channels, and desired num_outputs values
+        in that order.
+
+    buffer_shapes: Tuple[tvm.ir.container.Array]
+        The shapes of the data, kernel, bias, scale, and output tensors, in that order. Each shape
+        should be a TVM Array.
+
+    aligned_func: Tuple[str, str]
+        A tuple containing the (name, C implementation) of a word-aligned tensordot operator.
+
+    offset_func: Tuple[str, str]
+        A tuple containing the (name, C implementation) of a word-unaligned tensordot operator. Can
+        be a tuple of empty strings if the Conv2D in question does not need an unaligned operator.
+
+    ptr_gens: Tuple[Callable, Callable]
+        A tuple of two functions to generate data and kernel access pointers. They should take as
+        inputs the buffer, (y, x, c) indices, and an alignment offset. They should return a
+        T.tvm_access_ptr object which can be used in T.call_extern.
+
+    output_layout: str
+        The tensor layout that will be produced by the generated PrimFunc. Should be NHWC or NCHW.
+    """
+
+    out_height, out_width, out_channels, num_outputs = output_dimensions
     data_shape, kernel_shape, bias_shape, scale_shape, output_shape = buffer_shapes
     aligned_func_name, aligned_func_code = aligned_func
     offset_func_name, offset_func_code = offset_func
-    output_ptr, data_ptr, kernel_ptr = ptr_gens
+    data_ptr, kernel_ptr = ptr_gens
 
     # If the functions are identical, we can skip the second loop
     if aligned_func_name == offset_func_name:
         aligned_channels = out_channels
-        offset_channels = tvm.tir.const(0)
-        c_step = tvm.tir.const(1)
+        offset_channels = 0
+        c_step = const(1)
     else:
         aligned_channels = out_channels // 2
         offset_channels = out_channels // 2
-        c_step = tvm.tir.const(2)
-
-    def bias_ptr(bias, c):
-        return _make_tscript_ptr(bias, c, 1, dtype="int32")
-
-    def scale_ptr(scale, c):
-        return _make_tscript_ptr(scale, c, 1, dtype="int32")
+        c_step = const(2)

Review Comment:
   Ideally yes, but it makes the implementation harder in the depthwise_conv2d case. Luckily, no MLPerf models use an odd number of output channels.
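The odd-channel concern can be made concrete with a small Python model of the loop bounds in the quoted hunk. This is a sketch, not the PR's code. It assumes the aligned loop visits even channels starting at 0 and the offset loop visits odd channels starting at 1, each stepping by `c_step`. The helper names `channels_covered` and `uncovered_channels` are hypothetical.

```python
def channels_covered(out_channels: int, funcs_identical: bool):
    """Return the channel indices visited by the aligned and offset loops.

    Hypothetical model of the loop bounds in the diff: when both tensordot
    functions are identical, one loop walks every channel with step 1.
    Otherwise each loop runs out_channels // 2 times with step 2. We ASSUME
    the aligned loop starts at channel 0 and the offset loop at channel 1.
    """
    if funcs_identical:
        aligned_channels, offset_channels, c_step = out_channels, 0, 1
    else:
        aligned_channels = offset_channels = out_channels // 2
        c_step = 2
    aligned = [i * c_step for i in range(aligned_channels)]
    offset = [i * c_step + 1 for i in range(offset_channels)]
    return aligned, offset


def uncovered_channels(out_channels: int, funcs_identical: bool):
    """List the output channels that neither loop would ever compute."""
    aligned, offset = channels_covered(out_channels, funcs_identical)
    return sorted(set(range(out_channels)) - set(aligned) - set(offset))
```

With 7 channels and distinct aligned/offset functions, `out_channels // 2` is 3 for each loop, so channel 6 is never visited. With a single function, every channel is covered.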





[GitHub] [tvm] areusch commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "areusch (via GitHub)" <gi...@apache.org>.
areusch commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1123499276


##########
src/target/source/codegen_c.cc:
##########
@@ -631,8 +632,10 @@ void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs,
   }
 }
 
+NameSupply global_name_supply = NameSupply("");
 void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
-  std::string symbol_name = op->buffer_var->name_hint;
+  std::string symbol_name = global_name_supply->FreshName(op->buffer_var->name_hint);

Review Comment:
   I discussed this with @gigiblender, and in doing so we realized that there is already `var_idmap_`, which uses another NameSupply. Does `AllocVarID(op->buffer_var.get())` work here? This parallels what was done in [AllocateNode](https://github.com/apache/tvm/blob/bf86d9f8c25e105a12c1b96ab0a4cc977ee088b7/src/target/source/codegen_c.cc#L857).





[GitHub] [tvm] alanmacd commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "alanmacd (via GitHub)" <gi...@apache.org>.
alanmacd commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1087307499


##########
tests/python/relay/strategy/arm_cpu/test_quantized_convolution.py:
##########
@@ -95,6 +95,56 @@ def _get_mobilenet_v1_layer_attributes(layer_num):
     return ((1, 1, 1, 1), (1, 1), True)
 
 
+@pytest.mark.parametrize("layer", range(2, 27, 2))
+def test_infinite_bias_detection(interpreter, layer):
+    """Some models (mainly MobileNetV1) have kernels with many output channels full entirely of
+    zeroes. The VWW mdoel is one of these. This test confirms that the outputs of these channels,

Review Comment:
   nit: model





[GitHub] [tvm] guberti commented on pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on PR #13752:
URL: https://github.com/apache/tvm/pull/13752#issuecomment-1402606269

   ## Next steps
   
   The `137 ms` performance for the `vww` model is impressive, and beats the current state-of-the-art (`205 ms`) by a good margin. However, there is still a lot of room to improve our MLPerf Tiny performance even further:
   
   - Measure performance for the `ic` and `kws` MLPerf Tiny models. The changes in this pull request should dramatically improve performance on these as well, and we should be able to use them with only minor tinkering. I'm currently working on a follow-up PR to add this functionality.
   - Add support for **autotuning** to my tensordot schedules. Specifically, tuning `num_outputs` will give a substantial performance improvement with very little work. If we're willing to be a little more ambitious, we could use tuning to reorder the assembly code in the generated `tensordot` functions (this would especially help Cortex-M7 performance).
   - Add a second, word-unaligned copy of convolution kernels when it would help, and add support for this to `tensordot.py`.
   - Skip padding steps by folding padding into previous operators (this should be enabled by Relax).
   - See if we can use floor instead of rounding in `tensordot.py`'s requantization implementation. This should shave a couple of `ms` off the runtimes of `ic`, `kws`, and `vww`, but it might hurt accuracy slightly.
   - Write a Cortex-M schedule for `qnn_dense`. This will improve performance for `ic`, `kws`, and `vww` by a tiny amount, but it will dramatically improve `ad` performance (which currently still sucks).
   - Generalize `tensordot.py` to support Cortex-M CPUs _without_ the DSP extension. This would allow us to give good performance for Cortex-M0, M0+, M1, and M3 devices (this PR only improves performance for M4 and M7).
   - Fix the bug with Arduino Cortex-M performance. Currently, this bug makes the Arduino implementation comically slow.
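For the floor-versus-rounding item above, the difference can be sketched with a generic fixed-point requantization step. This is an assumption about the general technique, not a copy of `tensordot.py`: the `requantize` helper, its multiplier/shift parameters, and the round-half-up choice are all illustrative.

```python
def requantize(acc: int, multiplier: int, shift: int, zero_point: int, use_floor: bool) -> int:
    """Hypothetical fixed-point requantization of an int32 accumulator to int8.

    Computes (acc * multiplier) / 2**shift, either rounding half up (add half
    an output step before shifting) or flooring (plain arithmetic right shift),
    then adds the output zero point and saturates to the int8 range.
    """
    if shift < 1:
        raise ValueError("shift must be at least 1")
    product = acc * multiplier
    if not use_floor:
        product += 1 << (shift - 1)  # rounding costs one extra add per output
    scaled = product >> shift  # Python's >> on ints is an arithmetic (flooring) shift
    return max(-128, min(127, scaled + zero_point))
```

Flooring drops the add before the shift. The two results never differ by more than one output step, which is why the accuracy impact is expected to be slight.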
   
   _Note: adding proper Helium support would require rewriting our `tensordot` implementation and rewriting our legalization and alter_op passes as well. Helium is very cool, but proper support would take a lot of effort._
   
   ## Generalization of changes
   
   Some of the `legalization` and `alter_op` changes would be useful very broadly in TVM, but are currently only enabled for Arm Cortex-M. This includes our output layout rewriting for `conv -> depthwise` convolution patterns, our stripping of empty channels from `conv2d` operators, and stripping `pad` out into a separate Relay operator (the last one only helps _in some cases_). However, I would want to write more general passes before doing this, and I'm not sure how these would interact with Relax. I'll hold off on this for now.
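Empty channels can be stripped because a channel whose weights are all zero accumulates only its bias, so its requantized output is a single constant across the whole feature map. The sketch below mirrors the arithmetic in `test_infinite_bias_detection` from this PR. The function name `empty_channel_constants` is hypothetical. It assumes output channels on kernel axis 0 (as the test's `kernel_data[i, 0, 0, :]` indexing does), a kernel zero point of 0 (as with TFLite per-channel int8 weights), and int8 outputs.

```python
import numpy as np


def empty_channel_constants(kernel, bias, bias_scales, output_scale, output_zero_point):
    """Map each all-zero output channel to the constant int8 value it produces.

    ASSUMES kernel has output channels on axis 0 and a weight zero point of 0,
    so an all-zero channel's accumulator equals its bias at every position.
    """
    constants = {}
    for c in range(kernel.shape[0]):
        if np.any(kernel[c]):
            continue  # channel has real weights; its output depends on the input
        scale = bias_scales[c] / output_scale
        value = round(bias[c] * scale + output_zero_point)
        constants[c] = int(min(127, max(-128, value)))  # saturate to int8
    return constants
```

The resulting constants are what a rewrite could fold into the following operator's bias instead of computing the channel at runtime.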




[GitHub] [tvm] mkatanbaf commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "mkatanbaf (via GitHub)" <gi...@apache.org>.
mkatanbaf commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1106334576


##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -118,38 +133,89 @@ def _make_tscript_ptr(buffer, offset, length, dtype="int16"):
     )
 
 
+def _bias_ptr(bias, c):
+    return _make_tscript_ptr(bias, c, 1, dtype="int32")
+
+
+def _scale_ptr(scale, c):
+    return _make_tscript_ptr(scale, c, 1, dtype="int32")
+
+
 def _make_tscript_call(func_name, *args):
     return T.evaluate(T.call_extern(func_name, *args, dtype="int32"))
 
 
 def _make_conv2d_primfunc(
-    call_dimensions: Tuple,
-    buffer_shapes: Tuple[Tuple, Tuple, Tuple, Tuple, Tuple],
+    output_dimensions: Tuple[int, int, int, int],
+    buffer_shapes: Tuple,
     aligned_func: Tuple[str, str],
     offset_func: Tuple[str, str],
-    ptr_gens: Tuple,
-):
-    height, width, out_channels = call_dimensions
+    ptr_gens: Tuple[Callable, Callable],
+    output_layout: str = "NHWC",
+) -> tir.function.PrimFunc:
+    """Makes a TIR PrimFunc computing Conv2D using a call to tensordot.
+
+    Can be used to generate regular, depthwise, and grouped Conv2D operators by passing different
+    arguments and ptr_gen functions. However, it only works for Conv2D operators where the height
+    stride of the tensor is divisible by two.
+
+    Parameters
+    ----------
+    output_dimensions : Tuple[int, int, int, int]
+        A tuple containing the out_height, out_width, out_channels, and desired num_outputs values
+        in that order.
+
+    buffer_shapes: Tuple[tvm.ir.container.Array]
+        The shapes of the data, kernel, bias, scale, and output tensors, in that order. Each shape
+        should be a TVM Array.
+
+    aligned_func: Tuple[str, str]
+        A tuple containing the (name, C implementation) of a word-aligned tensordot operator.
+
+    offset_func: Tuple[str, str]
+        A tuple containing the (name, C implementation) of a word-unaligned tensordot operator. Can
+        be a tuple of empty strings if the Conv2D in question does not need an unaligned operator.
+
+    ptr_gens: Tuple[Callable, Callable]
+        A tuple of two functions to generate data and kernel access pointers. They should take as
+        inputs the buffer, (y, x, c) indices, and an alignment offset. They should return a
+        T.tvm_access_ptr object which can be used in T.call_extern.
+
+    output_layout: str
+        The tensor layout that will be produced by the generated PrimFunc. Should be NHWC or NCHW.
+    """
+
+    out_height, out_width, out_channels, num_outputs = output_dimensions
     data_shape, kernel_shape, bias_shape, scale_shape, output_shape = buffer_shapes
     aligned_func_name, aligned_func_code = aligned_func
     offset_func_name, offset_func_code = offset_func
-    output_ptr, data_ptr, kernel_ptr = ptr_gens
+    data_ptr, kernel_ptr = ptr_gens
 
     # If the functions are identical, we can skip the second loop
     if aligned_func_name == offset_func_name:
         aligned_channels = out_channels
-        offset_channels = tvm.tir.const(0)
-        c_step = tvm.tir.const(1)
+        offset_channels = 0
+        c_step = const(1)
     else:
         aligned_channels = out_channels // 2
         offset_channels = out_channels // 2
-        c_step = tvm.tir.const(2)
-
-    def bias_ptr(bias, c):
-        return _make_tscript_ptr(bias, c, 1, dtype="int32")
-
-    def scale_ptr(scale, c):
-        return _make_tscript_ptr(scale, c, 1, dtype="int32")
+        c_step = const(2)

Review Comment:
   Do we need to consider cases where the `out_channels` is an odd number?





[GitHub] [tvm] alanmacd commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "alanmacd (via GitHub)" <gi...@apache.org>.
alanmacd commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1087302880


##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -368,3 +389,136 @@ def kernel_ptr(buffer, c, offset=0):
 def schedule_qnn_depthwise_conv2d(_attrs, _outs, _target):
     """Schedule function for qnn.depthwise_conv2d."""
     return None
+
+
+def _make_unrolled_conv2d_primfunc(
+    output_dimensions: Tuple[int, int, int],
+    buffer_shapes: Tuple[Tuple, Tuple, Tuple, Tuple, Tuple],
+    function_names: Dict[Tuple, str],
+    function_code: str,
+    ptr_gens: Tuple[Callable, Callable],
+    output_layout="NHWC",
+):
+    out_height, out_width, out_channels = output_dimensions
+    data_shape, kernel_shape, bias_shape, scale_shape, output_shape = buffer_shapes
+    data_ptr, kernel_ptr = ptr_gens
+
+    def output_ptr(output, y, c):

Review Comment:
   nit: inconsistent use of leading '_' for internal function names





[GitHub] [tvm] alanmacd commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "alanmacd (via GitHub)" <gi...@apache.org>.
alanmacd commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1087319226


##########
tests/python/relay/strategy/arm_cpu/test_quantized_convolution.py:
##########
@@ -95,6 +95,56 @@ def _get_mobilenet_v1_layer_attributes(layer_num):
     return ((1, 1, 1, 1), (1, 1), True)
 
 
+@pytest.mark.parametrize("layer", range(2, 27, 2))
+def test_infinite_bias_detection(interpreter, layer):
+    """Some models (mainly MobileNetV1) have kernels with many output channels full entirely of
+    zeroes. The VWW mdoel is one of these. This test confirms that the outputs of these channels,
+    as computed by TensorFlow, are indeed not dependent upon the input values.
+    """
+
+    _, kernel, bias, output = _load_tflite_layer(interpreter, layer)
+    kernel_data, kernel_quant = kernel
+    bias_data, bias_quant = bias
+    output_data, output_quant = output
+    is_depthwise = _get_mobilenet_v1_layer_attributes(layer)[2]
+    assert not is_depthwise
+    assert kernel_data.shape[1] == kernel_data.shape[2] == 1
+
+    out_channels = kernel_data.shape[3]
+    fixed_channels = {}
+
+    out_zero_point = output_quant["zero_points"][0]
+    assert out_zero_point == -128
+
+    for i in range(out_channels):
+        # Skip over output channels with data
+        if np.any(kernel_data[i, 0, 0, :]):
+            continue
+
+        scale = bias_quant["scales"][i] / output_quant["scales"][0]
+        channel_constant = round(bias_data[i] * scale + out_zero_point)
+        clipped = min(127, max(-128, channel_constant))
+
+        out_channel_values = output_data[0, :, :, i].flatten()
+        assert all(x == clipped for x in out_channel_values)
+        fixed_channels[i] = clipped
+    print(f"Layer {layer} had {len(fixed_channels)}/{out_channels} empty!")
+
+    # We now need to compute values for the following depthwise layer
+    if layer == 26:
+        return
+
+    _, kernel, bias, output = _load_tflite_layer(interpreter, layer + 1)
+    kernel_data, kernel_quant = kernel
+    bias_data, bias_quant = bias
+    output_data, output_quant = output
+    is_depthwise = _get_mobilenet_v1_layer_attributes(layer + 1)[2]

Review Comment:
   are some of these variable assignments unused in this block?
   ```suggestion
       _, _, _, output = _load_tflite_layer(interpreter, layer + 1)
       output_data, output_quant = output
       is_depthwise = _get_mobilenet_v1_layer_attributes(layer + 1)[2]
   ```





[GitHub] [tvm] alanmacd commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "alanmacd (via GitHub)" <gi...@apache.org>.
alanmacd commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1087304976


##########
src/target/source/codegen_c.cc:
##########
@@ -631,8 +632,11 @@ void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs,
   }
 }
 
+unsigned int random_seed = 0;
 void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
-  std::string symbol_name = op->buffer_var->name_hint;
+  // Add a random suffix to eliminate duplicate global variables.
+  int suffix = rand_r(&random_seed) % (2 << 24);

Review Comment:
   any reason to not just use a counter and add that as suffix to each global var name?
   
   
   (rand_r() also breaks windows build as-is)
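Following the counter suggestion, a deterministic alternative to `rand_r()` can be sketched as a name supply that appends an increasing suffix only when a name collides. This is a hypothetical Python illustration of the idea, not TVM's `NameSupply` implementation, whose suffix format may differ. In the C++ codegen, the in-tree options discussed in this thread are `NameSupply::FreshName` and `AllocVarID`.

```python
from collections import defaultdict


class CounterNameSupply:
    """Hypothetical name supply that hands out unique symbol names.

    The first request for a name returns it unchanged; later requests for the
    same hint get a numeric suffix, skipping any name already handed out.
    """

    def __init__(self):
        self._used = set()
        self._counters = defaultdict(int)

    def fresh_name(self, hint: str) -> str:
        name = hint
        while name in self._used:
            self._counters[hint] += 1
            name = f"{hint}_{self._counters[hint]}"
        self._used.add(name)
        return name
```

Because the suffixes depend only on the order in which constants are visited, the generated symbol names are reproducible between builds. This approach also avoids `rand_r()`, which is a POSIX function and is unavailable on Windows toolchains.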





[GitHub] [tvm] guberti commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1088266485


##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -368,3 +389,136 @@ def kernel_ptr(buffer, c, offset=0):
 def schedule_qnn_depthwise_conv2d(_attrs, _outs, _target):
     """Schedule function for qnn.depthwise_conv2d."""
     return None
+
+
+def _make_unrolled_conv2d_primfunc(
+    output_dimensions: Tuple[int, int, int],
+    buffer_shapes: Tuple[Tuple, Tuple, Tuple, Tuple, Tuple],
+    function_names: Dict[Tuple, str],
+    function_code: str,
+    ptr_gens: Tuple[Callable, Callable],
+    output_layout="NHWC",
+):
+    out_height, out_width, out_channels = output_dimensions
+    data_shape, kernel_shape, bias_shape, scale_shape, output_shape = buffer_shapes
+    data_ptr, kernel_ptr = ptr_gens
+
+    def output_ptr(output, y, c):

Review Comment:
   My understanding was that `_` should be used for global internal functions, but that `_` should not be used for local functions. Correct me if I'm wrong, though.





[GitHub] [tvm] tvm-bot commented on pull request #13752: [DRAFT] [microTVM] Use QNN schedules to improve performance

Posted by GitBox <gi...@apache.org>.
tvm-bot commented on PR #13752:
URL: https://github.com/apache/tvm/pull/13752#issuecomment-1377924716

   
   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.
   
   
   
   <sub>Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)</sub>




[GitHub] [tvm] guberti commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1088372130


##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -368,3 +389,136 @@ def kernel_ptr(buffer, c, offset=0):
 def schedule_qnn_depthwise_conv2d(_attrs, _outs, _target):
     """Schedule function for qnn.depthwise_conv2d."""
     return None
+
+
+def _make_unrolled_conv2d_primfunc(
+    output_dimensions: Tuple[int, int, int],
+    buffer_shapes: Tuple[Tuple, Tuple, Tuple, Tuple, Tuple],
+    function_names: Dict[Tuple, str],
+    function_code: str,
+    ptr_gens: Tuple[Callable, Callable],
+    output_layout="NHWC",
+):
+    out_height, out_width, out_channels = output_dimensions
+    data_shape, kernel_shape, bias_shape, scale_shape, output_shape = buffer_shapes
+    data_ptr, kernel_ptr = ptr_gens
+
+    def output_ptr(output, y, c):

Review Comment:
   Oops, you're right - `_make_row_call` should be `make_row_call`. Thanks!





[GitHub] [tvm] alanmacd commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "alanmacd (via GitHub)" <gi...@apache.org>.
alanmacd commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1087304976


##########
src/target/source/codegen_c.cc:
##########
@@ -631,8 +632,11 @@ void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs,
   }
 }
 
+unsigned int random_seed = 0;
 void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
-  std::string symbol_name = op->buffer_var->name_hint;
+  // Add a random suffix to eliminate duplicate global variables.
+  int suffix = rand_r(&random_seed) % (2 << 24);

Review Comment:
   any reason to not just use a counter and add that as suffix to each global var name?
   
   
   (this also breaks windows build as-is)





[GitHub] [tvm] guberti commented on pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on PR #13752:
URL: https://github.com/apache/tvm/pull/13752#issuecomment-1402556346

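   Note: to use these changes, you will need to disable QNN legalization. This can be done by calling `relay.build` as follows:
   ```python
   import tvm
   from tvm import meta_schedule, relay

   # mod, params, target, crt_runtime, executor, and schedule_fn are assumed
   # to be defined beforehand.
   with tvm.transform.PassContext(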
       opt_level=3,
       config={
           "tir.disable_vectorize": True,
           "relay.backend.use_meta_schedule": True,
           "relay.backend.tir_converter": "allow_extern",
       },
       disabled_pass=["qnn.Legalize"],
   ), meta_schedule.database.ScheduleFnDatabase(schedule_fn):
       lowered = tvm.relay.build(
           mod,
           target=target,
           params=params,
           runtime=crt_runtime,
           executor=executor,
       )
   ```




[GitHub] [tvm] areusch commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "areusch (via GitHub)" <gi...@apache.org>.
areusch commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1109019885


##########
src/target/source/codegen_c.cc:
##########
@@ -631,8 +632,10 @@ void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs,
   }
 }
 
+NameSupply global_name_supply = NameSupply("");
 void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
-  std::string symbol_name = op->buffer_var->name_hint;
+  std::string symbol_name = global_name_supply->FreshName(op->buffer_var->name_hint);

Review Comment:
   @gigiblender could you advise what you think is best to do here?



##########
src/relay/qnn/op/dense.cc:
##########
@@ -242,7 +242,7 @@ RELAY_REGISTER_OP("qnn.dense")
                   "The quantization zero_point of the weight tensor.")
     .set_support_level(11)
     .add_type_rel("QDense", QnnDenseRel)
-    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnDenseInferCorrectLayout)
+//    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", QnnDenseInferCorrectLayout)

Review Comment:
   revert?





[GitHub] [tvm] guberti commented on a diff in pull request #13752: [microTVM] Use QNN schedules to give SOTA performance

Posted by "guberti (via GitHub)" <gi...@apache.org>.
guberti commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1113879454


##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -17,25 +17,40 @@
 """Contains TVMScript implementations of some QNN operators for Arm.
 
 Currently, the only ops with compute functions are fused regular and depthwise convolutions for
-Arm Cortex-M with DSP.
+Arm Cortex-M with DSP. Additionally, these functions explicitly do not support padding - it
+must be done in a separate Relay op for memory reasons.
 """
 
-from typing import Tuple
+from typing import Callable, Dict, Tuple
 
 import tvm
-from tvm import te
-from tvm.tir import const
+from tvm import te, tir, TVMError
 from tvm.script import tir as T
+from tvm.tir import const
+
 from ..utils import get_const_tuple
 from .mprofile.dsp.micro_kernel import tensordot
 
 
-def int_ceil_division(x, y):
+def _int_ceil_division(x, y):
     return -(x // -y)
 
 
 def _compute_output_dim(data_length, kernel_length, stride):
-    return int_ceil_division(data_length + 1 - kernel_length, stride)
+    return _int_ceil_division(data_length + 1 - kernel_length, stride)
+
+
+def _pick_num_outputs(out_width):
+    """Guess a good value for num_outputs."""
+
+    assert out_width > 1
+
+    # num_outputs is capped at 8
+    for i in range(2, min(out_width + 1, 8)):
+        if out_width % i == 0:

Review Comment:
   Yep! Otherwise we would have to implement a "special case" for left-over values. At some point, this will have to be done.
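The divisibility requirement shows up when one output row is split into tensordot calls of `num_outputs` outputs each. The `row_calls` helper below is hypothetical and only models the call pattern, not the schedule itself. When `num_outputs` does not divide `out_width`, the last call is a shorter remainder. That remainder is the "special case" that would need its own tensordot variant.

```python
def row_calls(out_width: int, num_outputs: int):
    """Hypothetical: split one output row into (start, count) tensordot calls.

    Each call computes `count` adjacent outputs starting at `start`. If
    num_outputs divides out_width, every call has the same size and a single
    generated kernel suffices; otherwise the final call is a shorter remainder.
    """
    full, remainder = divmod(out_width, num_outputs)
    calls = [(i * num_outputs, num_outputs) for i in range(full)]
    if remainder:
        calls.append((full * num_outputs, remainder))
    return calls
```

For example, `row_calls(12, 4)` needs only one call shape, while `row_calls(10, 4)` ends with a two-output call.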


