Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/11/30 19:12:00 UTC

[GitHub] [tvm] areusch commented on a diff in pull request #13242: [microTVM] Modernize Arm Cortex-M convolution schedules

areusch commented on code in PR #13242:
URL: https://github.com/apache/tvm/pull/13242#discussion_r1036216601


##########
python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py:
##########
@@ -14,142 +14,334 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
 
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout NHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""
+
+from collections import namedtuple
+from itertools import chain
 import textwrap
+from typing import Iterator, Tuple
 
-from tvm import te, tir
+SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
 
-from .common import num_simd_lanes_per_word
 
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):

Review Comment:
   I think everything here is pretty well documented. One thing that would help even more is to add type annotations. This could be a follow-up; what do you think of that idea?
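   For illustration, type annotations on this helper might look like the sketch below; the exact
   parameter types are assumptions inferred from how the function is called elsewhere in this diff,
   and `Tuple` comes from the `typing` import the diff already adds.
   ```python
   def _get_c_function_name(
       split_size: int,
       dimensions: Tuple[int, int, int],
       offsets: Tuple[int, int, int],
       x_strides: Tuple[int, int],
   ) -> str:
       ...
   ```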



##########
python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py:
##########
@@ -14,142 +14,334 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
 
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout NHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""
+
+from collections import namedtuple
+from itertools import chain
 import textwrap
+from typing import Iterator, Tuple
 
-from tvm import te, tir
+SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
 
-from .common import num_simd_lanes_per_word
 
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    """Generates a C function name for tensordot.
 
-def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
-    """Gets the C function name of the tensordot function."""
-    return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}"
+    We do not need a suffix, as the generated function will have an #include guard. Unlike other
+    microTVM operators, _get_c_function_name is never called externally.
+    """
+    tensor_w, kernel_h, kernel_w = dimensions
+    return (
+        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
+        + f"{kernel_h}x{kernel_w}_"
+        + "".join(map(str, offsets))
+        + (f"_{x_strides[0]}_{x_strides[1]}" if split_size > 1 else "")
+    )
 
 
-def make_intrin_tensordot(slices, strides, tensordot_params):
-    """Helper function for constructing tensordot intrinsic. We can't construct the whole thing here
-    (as multiple schedules use tensordot and each must build the intrinstic differently) but we can
-    build part here to simplify the code."""
+def _init_biased_accumulators(split_size):

Review Comment:
   Suggest adopting the same var name everywhere: you call this either split_size or num_sums depending on which function we're in.
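   As a sketch of that suggestion, picking one of the two names (say, `num_sums`) and using it in
   every helper signature in this file would look like:
   ```python
   def _get_c_function_name(num_sums, dimensions, offsets, x_strides): ...
   def _init_biased_accumulators(num_sums): ...
   def _get_tensor_halfwords(dimensions, offset, num_sums, in_stride): ...
   ```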



##########
python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py:
##########
@@ -14,142 +14,334 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
 
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout NHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""
+
+from collections import namedtuple
+from itertools import chain
 import textwrap
+from typing import Iterator, Tuple
 
-from tvm import te, tir
+SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
 
-from .common import num_simd_lanes_per_word
 
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    """Generates a C function name for tensordot.
 
-def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
-    """Gets the C function name of the tensordot function."""
-    return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}"
+    We do not need a suffix, as the generated function will have an #include guard. Unlike other
+    microTVM operators, _get_c_function_name is never called externally.
+    """
+    tensor_w, kernel_h, kernel_w = dimensions
+    return (
+        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
+        + f"{kernel_h}x{kernel_w}_"
+        + "".join(map(str, offsets))
+        + (f"_{x_strides[0]}_{x_strides[1]}" if split_size > 1 else "")
+    )
 
 
-def make_intrin_tensordot(slices, strides, tensordot_params):
-    """Helper function for constructing tensordot intrinsic. We can't construct the whole thing here
-    (as multiple schedules use tensordot and each must build the intrinstic differently) but we can
-    build part here to simplify the code."""
+def _init_biased_accumulators(split_size):
+    """Generates code to load the bias into the accumulators.
 
-    # in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params
-    data, kernel, output = slices
-    data_strides, kernel_strides = strides
+    Addition is commutative, so we could add the bias before, during, or after performing our
+    multiply-accumulate operations. Where we add the bias does not change the overflow behavior.
 
-    data_buf = tir.decl_buffer(
-        data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides
-    )
-    kernel_buf = tir.decl_buffer(
-        kernel.shape,
-        kernel.dtype,
-        name="kernel",
-        offset_factor=1,
-        strides=kernel_strides,
-    )
-    output_buf = tir.decl_buffer(
-        output.shape, output.dtype, name="output", offset_factor=1, strides=[1]
-    )
+    Doing the bias add takes one cycle either way (if done at the beginning we can't use a SMULXY
+    trick to set sum_i to zero for "free"). However, doing it at the beginning frees up a register,
+    so we'll do it first.
+    """
+    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
+    joined_assignments = ", ".join(assignments)
+    return f"int {joined_assignments};"
 
-    def intrin_func(ins, outs):
-        builder = tir.ir_builder.create()
-        builder.emit(
-            tir.call_extern(
-                "int32",
-                _get_func_name(*tensordot_params),
-                outs[0].access_ptr("w"),
-                ins[0].access_ptr("r"),
-                ins[1].access_ptr("r"),
-            )
-        )
-        return builder.get()
 
-    return te.decl_tensor_intrin(
-        output.op,
-        intrin_func,
-        binds={data: data_buf, kernel: kernel_buf, output: output_buf},
-    )
+def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator:

Review Comment:
   Just curious, why make this an Iterator if you're just going to pass it to list()? I think it'd be clearer/faster to just construct the list here.
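   A sketch of the list-building alternative, shown here on the simpler sibling helper
   `_get_kernel_halfwords` (its generator body appears later in this diff); behavior is unchanged,
   only a list is returned instead of an iterator:
   ```python
   def _get_kernel_halfwords(dimensions, offset) -> list:
       _, kernel_h, kernel_w = dimensions
       halfwords = []
       if offset == 1:
           halfwords.append(None)  # alignment padding halfword, loaded but unused
       halfwords.extend((y, x) for y in range(kernel_h) for x in range(kernel_w))
       if (kernel_h * kernel_w + offset) % 2 == 1:
           halfwords.append(None)  # pad to an even number of halfwords
       return halfwords
   ```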



##########
python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py:
##########
@@ -14,142 +14,334 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
 
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout NHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""
+
+from collections import namedtuple
+from itertools import chain
 import textwrap
+from typing import Iterator, Tuple
 
-from tvm import te, tir
+SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
 
-from .common import num_simd_lanes_per_word
 
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    """Generates a C function name for tensordot.
 
-def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
-    """Gets the C function name of the tensordot function."""
-    return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}"
+    We do not need a suffix, as the generated function will have an #include guard. Unlike other
+    microTVM operators, _get_c_function_name is never called externally.
+    """
+    tensor_w, kernel_h, kernel_w = dimensions
+    return (
+        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
+        + f"{kernel_h}x{kernel_w}_"
+        + "".join(map(str, offsets))
+        + (f"_{x_strides[0]}_{x_strides[1]}" if split_size > 1 else "")
+    )
 
 
-def make_intrin_tensordot(slices, strides, tensordot_params):
-    """Helper function for constructing tensordot intrinsic. We can't construct the whole thing here
-    (as multiple schedules use tensordot and each must build the intrinstic differently) but we can
-    build part here to simplify the code."""
+def _init_biased_accumulators(split_size):
+    """Generates code to load the bias into the accumulators.
 
-    # in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params
-    data, kernel, output = slices
-    data_strides, kernel_strides = strides
+    Addition is commutative, so we could add the bias before, during, or after performing our
+    multiply-accumulate operations. Where we add the bias does not change the overflow behavior.
 
-    data_buf = tir.decl_buffer(
-        data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides
-    )
-    kernel_buf = tir.decl_buffer(
-        kernel.shape,
-        kernel.dtype,
-        name="kernel",
-        offset_factor=1,
-        strides=kernel_strides,
-    )
-    output_buf = tir.decl_buffer(
-        output.shape, output.dtype, name="output", offset_factor=1, strides=[1]
-    )
+    Doing the bias add takes one cycle either way (if done at the beginning we can't use a SMULXY
+    trick to set sum_i to zero for "free"). However, doing it at the beginning frees up a register,
+    so we'll do it first.
+    """
+    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
+    joined_assignments = ", ".join(assignments)
+    return f"int {joined_assignments};"
 
-    def intrin_func(ins, outs):
-        builder = tir.ir_builder.create()
-        builder.emit(
-            tir.call_extern(
-                "int32",
-                _get_func_name(*tensordot_params),
-                outs[0].access_ptr("w"),
-                ins[0].access_ptr("r"),
-                ins[1].access_ptr("r"),
-            )
-        )
-        return builder.get()
 
-    return te.decl_tensor_intrin(
-        output.op,
-        intrin_func,
-        binds={data: data_buf, kernel: kernel_buf, output: output_buf},
-    )
+def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator:
+    tensor_w, kernel_h, kernel_w = dimensions
+
+    split_max = (split_size - 1) * in_stride
+    for y in range(kernel_h):
+        if y * tensor_w % 2 + offset == 1:

Review Comment:
   Can you add parens to make the precedence in `y * tensor_w % 2` explicit, and a brief comment explaining why to yield None here?
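   A sketch of what that could look like; the comment wording is an assumption about why the
   padding entries exist, and the added parentheses preserve the existing precedence:
   ```python
   def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator:
       tensor_w, kernel_h, kernel_w = dimensions

       split_max = (split_size - 1) * in_stride
       for y in range(kernel_h):
           # None marks an alignment padding halfword: it is loaded as part of a
           # word-aligned read but never used in a multiply-accumulate.
           if (y * tensor_w) % 2 + offset == 1:
               yield None
           for x in range(kernel_w + split_max):
               yield (y, x)
           if (y * tensor_w + kernel_w + split_max + offset) % 2 == 1:
               yield None
   ```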



##########
python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py:
##########
@@ -14,142 +14,334 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
 
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout NHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""
+
+from collections import namedtuple
+from itertools import chain
 import textwrap
+from typing import Iterator, Tuple
 
-from tvm import te, tir
+SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
 
-from .common import num_simd_lanes_per_word
 
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    """Generates a C function name for tensordot.
 
-def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
-    """Gets the C function name of the tensordot function."""
-    return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}"
+    We do not need a suffix, as the generated function will have an #include guard. Unlike other
+    microTVM operators, _get_c_function_name is never called externally.
+    """
+    tensor_w, kernel_h, kernel_w = dimensions
+    return (
+        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
+        + f"{kernel_h}x{kernel_w}_"
+        + "".join(map(str, offsets))
+        + (f"_{x_strides[0]}_{x_strides[1]}" if split_size > 1 else "")
+    )
 
 
-def make_intrin_tensordot(slices, strides, tensordot_params):
-    """Helper function for constructing tensordot intrinsic. We can't construct the whole thing here
-    (as multiple schedules use tensordot and each must build the intrinstic differently) but we can
-    build part here to simplify the code."""
+def _init_biased_accumulators(split_size):
+    """Generates code to load the bias into the accumulators.
 
-    # in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params
-    data, kernel, output = slices
-    data_strides, kernel_strides = strides
+    Addition is commutative, so we could add the bias before, during, or after performing our
+    multiply-accumulate operations. Where we add the bias does not change the overflow behavior.
 
-    data_buf = tir.decl_buffer(
-        data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides
-    )
-    kernel_buf = tir.decl_buffer(
-        kernel.shape,
-        kernel.dtype,
-        name="kernel",
-        offset_factor=1,
-        strides=kernel_strides,
-    )
-    output_buf = tir.decl_buffer(
-        output.shape, output.dtype, name="output", offset_factor=1, strides=[1]
-    )
+    Doing the bias add takes one cycle either way (if done at the beginning we can't use a SMULXY
+    trick to set sum_i to zero for "free"). However, doing it at the beginning frees up a register,
+    so we'll do it first.
+    """
+    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
+    joined_assignments = ", ".join(assignments)
+    return f"int {joined_assignments};"
 
-    def intrin_func(ins, outs):
-        builder = tir.ir_builder.create()
-        builder.emit(
-            tir.call_extern(
-                "int32",
-                _get_func_name(*tensordot_params),
-                outs[0].access_ptr("w"),
-                ins[0].access_ptr("r"),
-                ins[1].access_ptr("r"),
-            )
-        )
-        return builder.get()
 
-    return te.decl_tensor_intrin(
-        output.op,
-        intrin_func,
-        binds={data: data_buf, kernel: kernel_buf, output: output_buf},
-    )
+def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator:
+    tensor_w, kernel_h, kernel_w = dimensions
+
+    split_max = (split_size - 1) * in_stride
+    for y in range(kernel_h):
+        if y * tensor_w % 2 + offset == 1:
+            yield None
+        for x in range(kernel_w + split_max):
+            yield (y, x)
+        if (y * tensor_w + kernel_w + split_max + offset) % 2 == 1:
+            yield None
+
+
+def _get_kernel_halfwords(dimensions, offset) -> Iterator:
+    _, kernel_h, kernel_w = dimensions
+    if offset == 1:
+        yield None
+    for y in range(kernel_h):
+        for x in range(kernel_w):
+            yield (y, x)
+    if (kernel_h * kernel_w + offset) % 2 == 1:
+        yield None
+
+
+def _get_int16_alias(position) -> str:
+    if not position:
+        return "unknown"
+    y, x = position
+    return f"y{y:0>2x}_x{x:0>2x}"
+
+
+def _load_tensor_vars(halfwords, tensor_w) -> Iterator[str]:
+    assert len(halfwords) % 2 == 0
+    offset = int(not bool(halfwords[0]))
+
+    for i in range(0, len(halfwords), 2):
+        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))

Review Comment:
   The more Pythonic option here would be a generator expression, but since it is just two elements, I suggest just writing it out:
   ```suggestion
           var_name = f"{_get_int16_alias(halfwords[i])}_{_get_int16_alias(halfwords[i+1])}"
   ```
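   For reference, the generator-expression form mentioned above could read as follows, keeping the
   double-underscore separator used by the generated variable names elsewhere in this file:
   ```python
   var_name = "__".join(_get_int16_alias(h) for h in halfwords[i : i + 2])
   ```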



##########
python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py:
##########
@@ -14,142 +14,334 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
 
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout NHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""
+
+from collections import namedtuple
+from itertools import chain
 import textwrap
+from typing import Iterator, Tuple
 
-from tvm import te, tir
+SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
 
-from .common import num_simd_lanes_per_word
 
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    """Generates a C function name for tensordot.
 
-def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
-    """Gets the C function name of the tensordot function."""
-    return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}"
+    We do not need a suffix, as the generated function will have an #include guard. Unlike other
+    microTVM operators, _get_c_function_name is never called externally.
+    """
+    tensor_w, kernel_h, kernel_w = dimensions
+    return (
+        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
+        + f"{kernel_h}x{kernel_w}_"
+        + "".join(map(str, offsets))
+        + (f"_{x_strides[0]}_{x_strides[1]}" if split_size > 1 else "")
+    )
 
 
-def make_intrin_tensordot(slices, strides, tensordot_params):
-    """Helper function for constructing tensordot intrinsic. We can't construct the whole thing here
-    (as multiple schedules use tensordot and each must build the intrinstic differently) but we can
-    build part here to simplify the code."""
+def _init_biased_accumulators(split_size):
+    """Generates code to load the bias into the accumulators.
 
-    # in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params
-    data, kernel, output = slices
-    data_strides, kernel_strides = strides
+    Addition is commutative, so we could add the bias before, during, or after performing our
+    multiply-accumulate operations. Where we add the bias does not change the overflow behavior.
 
-    data_buf = tir.decl_buffer(
-        data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides
-    )
-    kernel_buf = tir.decl_buffer(
-        kernel.shape,
-        kernel.dtype,
-        name="kernel",
-        offset_factor=1,
-        strides=kernel_strides,
-    )
-    output_buf = tir.decl_buffer(
-        output.shape, output.dtype, name="output", offset_factor=1, strides=[1]
-    )
+    Doing the bias add takes one cycle either way (if done at the beginning we can't use a SMULXY
+    trick to set sum_i to zero for "free"). However, doing it at the beginning frees up a register,
+    so we'll do it first.
+    """
+    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
+    joined_assignments = ", ".join(assignments)
+    return f"int {joined_assignments};"
 
-    def intrin_func(ins, outs):
-        builder = tir.ir_builder.create()
-        builder.emit(
-            tir.call_extern(
-                "int32",
-                _get_func_name(*tensordot_params),
-                outs[0].access_ptr("w"),
-                ins[0].access_ptr("r"),
-                ins[1].access_ptr("r"),
-            )
-        )
-        return builder.get()
 
-    return te.decl_tensor_intrin(
-        output.op,
-        intrin_func,
-        binds={data: data_buf, kernel: kernel_buf, output: output_buf},
-    )
+def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator:

Review Comment:
   are there any unittests we could look at to reason about the correctness of these functions?
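   No tests appear in this hunk; as a minimal sketch, a unit test for one of these helpers could
   look like the following, assuming the private functions are reachable from the test module:
   ```python
   from tvm.topi.arm_cpu.mprofile.dsp.micro_kernel import tensordot


   def test_get_kernel_halfwords_pads_to_even_length():
       # 3x3 kernel, tensor width 48, no leading offset
       halfwords = list(tensordot._get_kernel_halfwords((48, 3, 3), 0))
       assert halfwords[0] == (0, 0)
       assert halfwords[-1] is None      # nine kernel entries, so one padding halfword
       assert len(halfwords) % 2 == 0    # every load covers a whole 32-bit word
   ```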



##########
python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py:
##########
@@ -14,142 +14,334 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
 
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout NHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""
+
+from collections import namedtuple
+from itertools import chain
 import textwrap
+from typing import Iterator, Tuple
 
-from tvm import te, tir
+SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
 
-from .common import num_simd_lanes_per_word
 
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    """Generates a C function name for tensordot.
 
-def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
-    """Gets the C function name of the tensordot function."""
-    return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}"
+    We do not need a suffix, as the generated function will have an #include guard. Unlike other
+    microTVM operators, _get_c_function_name is never called externally.
+    """
+    tensor_w, kernel_h, kernel_w = dimensions
+    return (
+        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
+        + f"{kernel_h}x{kernel_w}_"
+        + "".join(map(str, offsets))
+        + (f"_{x_strides[0]}_{x_strides[1]}" if split_size > 1 else "")
+    )
 
 
-def make_intrin_tensordot(slices, strides, tensordot_params):
-    """Helper function for constructing tensordot intrinsic. We can't construct the whole thing here
-    (as multiple schedules use tensordot and each must build the intrinstic differently) but we can
-    build part here to simplify the code."""
+def _init_biased_accumulators(split_size):
+    """Generates code to load the bias into the accumulators.
 
-    # in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params
-    data, kernel, output = slices
-    data_strides, kernel_strides = strides
+    Addition is commutative, so we could add the bias before, during, or after performing our
+    multiply-accumulate operations. Where we add the bias does not change the overflow behavior.
 
-    data_buf = tir.decl_buffer(
-        data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides
-    )
-    kernel_buf = tir.decl_buffer(
-        kernel.shape,
-        kernel.dtype,
-        name="kernel",
-        offset_factor=1,
-        strides=kernel_strides,
-    )
-    output_buf = tir.decl_buffer(
-        output.shape, output.dtype, name="output", offset_factor=1, strides=[1]
-    )
+    Doing the bias add takes one cycle either way (if done at the beginning we can't use a SMULXY
+    trick to set sum_i to zero for "free"). However, doing it at the beginning frees up a register,
+    so we'll do it first.
+    """
+    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
+    joined_assignments = ", ".join(assignments)
+    return f"int {joined_assignments};"
 
-    def intrin_func(ins, outs):
-        builder = tir.ir_builder.create()
-        builder.emit(
-            tir.call_extern(
-                "int32",
-                _get_func_name(*tensordot_params),
-                outs[0].access_ptr("w"),
-                ins[0].access_ptr("r"),
-                ins[1].access_ptr("r"),
-            )
-        )
-        return builder.get()
 
-    return te.decl_tensor_intrin(
-        output.op,
-        intrin_func,
-        binds={data: data_buf, kernel: kernel_buf, output: output_buf},
-    )
+def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator:
+    tensor_w, kernel_h, kernel_w = dimensions
+
+    split_max = (split_size - 1) * in_stride
+    for y in range(kernel_h):
+        if y * tensor_w % 2 + offset == 1:
+            yield None
+        for x in range(kernel_w + split_max):
+            yield (y, x)
+        if (y * tensor_w + kernel_w + split_max + offset) % 2 == 1:
+            yield None
+
+
+def _get_kernel_halfwords(dimensions, offset) -> Iterator:
+    _, kernel_h, kernel_w = dimensions
+    if offset == 1:
+        yield None
+    for y in range(kernel_h):
+        for x in range(kernel_w):
+            yield (y, x)
+    if (kernel_h * kernel_w + offset) % 2 == 1:
+        yield None

Review Comment:
   same here



##########
python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py:
##########
@@ -14,142 +14,334 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
 
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout NHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""
+
+from collections import namedtuple
+from itertools import chain
 import textwrap
+from typing import Iterator, Tuple
 
-from tvm import te, tir
+SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
 
-from .common import num_simd_lanes_per_word
 
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    """Generates a C function name for tensordot.
 
-def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
-    """Gets the C function name of the tensordot function."""
-    return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}"
+    We do not need a suffix, as the generated function will have an #include guard. Unlike other
+    microTVM operators, _get_c_function_name is never called externally.
+    """
+    tensor_w, kernel_h, kernel_w = dimensions
+    return (
+        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
+        + f"{kernel_h}x{kernel_w}_"
+        + "".join(map(str, offsets))
+        + (f"_{x_strides[0]}_{x_strides[1]}" if split_size > 1 else "")
+    )
 
 
-def make_intrin_tensordot(slices, strides, tensordot_params):
-    """Helper function for constructing tensordot intrinsic. We can't construct the whole thing here
-    (as multiple schedules use tensordot and each must build the intrinstic differently) but we can
-    build part here to simplify the code."""
+def _init_biased_accumulators(split_size):
+    """Generates code to load the bias into the accumulators.
 
-    # in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params
-    data, kernel, output = slices
-    data_strides, kernel_strides = strides
+    Addition is commutative, so we could add the bias before, during, or after performing our
+    multiply-accumulate operations. Where we add the bias does not change the overflow behavior.
 
-    data_buf = tir.decl_buffer(
-        data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides
-    )
-    kernel_buf = tir.decl_buffer(
-        kernel.shape,
-        kernel.dtype,
-        name="kernel",
-        offset_factor=1,
-        strides=kernel_strides,
-    )
-    output_buf = tir.decl_buffer(
-        output.shape, output.dtype, name="output", offset_factor=1, strides=[1]
-    )
+    Doing the bias add takes one cycle either way (if done at the beginning we can't use a SMULXY
+    trick to set sum_i to zero for "free"). However, doing it at the beginning frees up a register,
+    so we'll do it first.
+    """
+    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
+    joined_assignments = ", ".join(assignments)
+    return f"int {joined_assignments};"
 
-    def intrin_func(ins, outs):
-        builder = tir.ir_builder.create()
-        builder.emit(
-            tir.call_extern(
-                "int32",
-                _get_func_name(*tensordot_params),
-                outs[0].access_ptr("w"),
-                ins[0].access_ptr("r"),
-                ins[1].access_ptr("r"),
-            )
-        )
-        return builder.get()
 
-    return te.decl_tensor_intrin(
-        output.op,
-        intrin_func,
-        binds={data: data_buf, kernel: kernel_buf, output: output_buf},
-    )
+def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator:
+    tensor_w, kernel_h, kernel_w = dimensions
+
+    split_max = (split_size - 1) * in_stride
+    for y in range(kernel_h):
+        if y * tensor_w % 2 + offset == 1:
+            yield None
+        for x in range(kernel_w + split_max):
+            yield (y, x)
+        if (y * tensor_w + kernel_w + split_max + offset) % 2 == 1:

Review Comment:
   same here



##########
python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py:
##########
@@ -14,142 +14,334 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
 
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout NHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""
+
+from collections import namedtuple
+from itertools import chain
 import textwrap
+from typing import Iterator, Tuple
 
-from tvm import te, tir
+SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
 
-from .common import num_simd_lanes_per_word
 
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    """Generates a C function name for tensordot.
 
-def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
-    """Gets the C function name of the tensordot function."""
-    return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}"
+    We do not need a suffix, as the generated function will have an #include guard. Unlike other
+    microTVM operators, _get_c_function_name is never called externally.
+    """
+    tensor_w, kernel_h, kernel_w = dimensions
+    return (
+        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
+        + f"{kernel_h}x{kernel_w}_"
+        + "".join(map(str, offsets))
+        + (f"_{x_strides[0]}_{x_strides[1]}" if split_size > 1 else "")
+    )
 
 
-def make_intrin_tensordot(slices, strides, tensordot_params):
-    """Helper function for constructing tensordot intrinsic. We can't construct the whole thing here
-    (as multiple schedules use tensordot and each must build the intrinstic differently) but we can
-    build part here to simplify the code."""
+def _init_biased_accumulators(split_size):
+    """Generates code to load the bias into the accumulators.
 
-    # in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params
-    data, kernel, output = slices
-    data_strides, kernel_strides = strides
+    Addition is commutative, so we could add the bias before, during, or after performing our
+    multiply-accumulate operations. Where we add the bias does not change the overflow behavior.
 
-    data_buf = tir.decl_buffer(
-        data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides
-    )
-    kernel_buf = tir.decl_buffer(
-        kernel.shape,
-        kernel.dtype,
-        name="kernel",
-        offset_factor=1,
-        strides=kernel_strides,
-    )
-    output_buf = tir.decl_buffer(
-        output.shape, output.dtype, name="output", offset_factor=1, strides=[1]
-    )
+    Doing the bias add takes one cycle either way (if done at the beginning we can't use a SMULXY
+    trick to set sum_i to zero for "free"). However, doing it at the beginning frees up a register,
+    so we'll do it first.
+    """
+    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
+    joined_assignments = ", ".join(assignments)
+    return f"int {joined_assignments};"
 
-    def intrin_func(ins, outs):
-        builder = tir.ir_builder.create()
-        builder.emit(
-            tir.call_extern(
-                "int32",
-                _get_func_name(*tensordot_params),
-                outs[0].access_ptr("w"),
-                ins[0].access_ptr("r"),
-                ins[1].access_ptr("r"),
-            )
-        )
-        return builder.get()
 
-    return te.decl_tensor_intrin(
-        output.op,
-        intrin_func,
-        binds={data: data_buf, kernel: kernel_buf, output: output_buf},
-    )
+def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator:
+    tensor_w, kernel_h, kernel_w = dimensions
+
+    split_max = (split_size - 1) * in_stride
+    for y in range(kernel_h):
+        if y * tensor_w % 2 + offset == 1:
+            yield None
+        for x in range(kernel_w + split_max):
+            yield (y, x)
+        if (y * tensor_w + kernel_w + split_max + offset) % 2 == 1:
+            yield None
+
+
+def _get_kernel_halfwords(dimensions, offset) -> Iterator:
+    _, kernel_h, kernel_w = dimensions
+    if offset == 1:
+        yield None
+    for y in range(kernel_h):
+        for x in range(kernel_w):
+            yield (y, x)
+    if (kernel_h * kernel_w + offset) % 2 == 1:
+        yield None
+
+
+def _get_int16_alias(position) -> str:
+    if not position:
+        return "unknown"
+    y, x = position
+    return f"y{y:0>2x}_x{x:0>2x}"
+
+
+def _load_tensor_vars(halfwords, tensor_w) -> Iterator[str]:
+    assert len(halfwords) % 2 == 0
+    offset = int(not bool(halfwords[0]))
+
+    for i in range(0, len(halfwords), 2):
+        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))
+        y, x = halfwords[i + 1] or halfwords[i]
+        tensor_index = (y * tensor_w + x + offset) // 2
+        yield f"int tensor__{var_name} = tensor[{tensor_index}];"
+
 
+def _load_kernel_vars(halfwords) -> Iterator[str]:
+    assert len(halfwords) % 2 == 0
+    for i in range(0, len(halfwords), 2):
+        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))
+        yield f"int kernel__{var_name} = kernel[{i // 2}];"
 
-def tensordot_impl(in_dtype: str, tensor_h: int, jump: int, tensor_w: int, suffix: str) -> str:
-    """Generates C code for taking the dot products of two `tensor_h` * `tensor_w` tensors. Also has
-    a `jump` argument that advances the pointer of one tensor by that many words after each row. The
-    `jump` and `tensor_w` values must be word-aligned for the input data type, as non-word-aligned
-    memory access is slow on the Cortex-M series. Depending on the input datatype, the code may
-    contain DSP instructions for Arm v7e-m. C code contains DSP instructions for Arm v7e-m. See
-    the below pseudocode for reference:
-
-    tensordot(out_ptr, dat_ptr, ker_ptr) {
-        sum = 0;
-        for (i = 0; i < tensor_h; i++) {
-            for (j = 0; j < tensor_w; j++) {
-                sum += (*dat_ptr++) * (*ker_ptr++);
-            }
-            dat_ptr += jump;
-        }
-        *out_ptr = sum;
-    }
+
+def _get_draft_macs(
+    kernel_dims, tensor_halfwords, kernel_halfwords, offset
+) -> Iterator[SMLAInstruction]:
+    """Generates unrolled MAC instructions to compute one tensordot sum.
+
+    Unrolling these loops increases code size a tiny bit (< 0.02 KB), but makes the generated code
+    much faster. The generated code does not use SIMD instructions - they are added later by
+    _apply_simd_optimizations.
+
+    We return an iterator of SMLAInstruction named tuples. Returning an iterator lets us do
+    optimizations by iterator chaining.
+    """
+
+    def get_var(y, x, halfwords) -> Tuple[str, str]:
+        i = halfwords.index((y, x))
+        if i % 2 == 0:
+            return f"{_get_int16_alias((y, x))}__{_get_int16_alias(halfwords[i + 1])}", "b"
+        return f"{_get_int16_alias(halfwords[i - 1])}__{_get_int16_alias((y, x))}", "t"
+
+    kernel_h, kernel_w = kernel_dims
+    for y in range(kernel_h):
+        for x in range(kernel_w):
+            tensor_var, tensor_half = get_var(y, x + offset, tensor_halfwords)
+            kernel_var, kernel_half = get_var(y, x, kernel_halfwords)
+            instruction = f"smla{tensor_half}{kernel_half}"
+            yield SMLAInstruction(instruction, f"tensor__{tensor_var}", f"kernel__{kernel_var}")
+
+
+def _apply_simd_optimizations(instruction_tuples) -> Iterator[SMLAInstruction]:
+    """When possible, fuses single MACs into SIMD MAC instructions.
+
+    The compiler cannot do this automatically, as calling __smlaxy forces the SMLAxy instruction to
+    be used. This function takes as input an iterator of SMLAInstructions and returns an iterator of
+    SMLAInstructions (possibly of different length).
     """
+    curr_tuple = next(instruction_tuples, None)
+    while curr_tuple:
+        next_tuple = next(instruction_tuples, None)
+        if not next_tuple:
+            yield curr_tuple
+            break
 
-    simd_lanes = num_simd_lanes_per_word(in_dtype)
-    assert tensor_w % simd_lanes == 0
-    assert jump % simd_lanes == 0
+        if curr_tuple[1:] == next_tuple[1:]:

Review Comment:
   Can you use the named accessors here? I realize that complicates the if statement; you could also make a method on the dataclass I discussed earlier.
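   One possible shape for that, as a sketch: keep tuple behavior (so the existing
   `curr_tuple[1:]` slicing still works) by moving to `typing.NamedTuple` and adding a small
   comparison method; the method name is hypothetical.
   ```python
   from typing import NamedTuple


   class SMLAInstruction(NamedTuple):
       """A single SMLAxy multiply-accumulate and its two operand variable names."""

       instruction: str
       tensor_var: str
       kernel_var: str

       def has_same_operands(self, other: "SMLAInstruction") -> bool:
           # Two MACs reading the same tensor and kernel words are candidates for
           # fusion into one SIMD instruction (smlad / smladx).
           return (self.tensor_var, self.kernel_var) == (other.tensor_var, other.kernel_var)
   ```
   The check would then become `if curr_tuple.has_same_operands(next_tuple):`.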



##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -0,0 +1,369 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Contains TVMScript implementations of some QNN operators for Arm.
+
+Currently, the only ops with compute functions are fused regular and depthwise convolutions for
+Arm Cortex-M with DSP.
+"""
+
+from typing import Tuple
+
+import tvm
+from tvm import te
+from tvm.tir import const
+from tvm.script import tir as T
+from ..utils import get_const_tuple
+from .mprofile.dsp.micro_kernel import tensordot
+
+
+def int_ceil_division(x, y):
+    return -(x // -y)
+
+
+def _compute_output_dim(data_length, kernel_length, stride):
+    return int_ceil_division(data_length + 1 - kernel_length, stride)
+
+
+def _pick_tensordot_impl(attrs, inputs, num_sums=2, is_depthwise=False):
+    """Helper function that chooses the right implementation of micro_kernel.tensordot.
+
+    Takes as input the parameters of the conv2d, and returns a tuple of TWO (function_name,
+    function_code). The first pair (the aligned one) is for even numbered output channels, and the
+    second pair (the offset one) is for odd-numbered output channels. This function is used for
+    regular and depthwise convolutions.
+
+    We need different implementations for even vs odd numbered output channels, because the "start"
+    of an odd output channel in the data tensor or kernel might or might not be on a word boundary,
+    and the tensordot code expects all input pointers to be word-aligned.
+    """
+    data, kernel = inputs[0:2]
+    rq_output_zero_point_const = inputs[10]
+    assert len(rq_output_zero_point_const.op.body) == 1
+    output_zero_point = rq_output_zero_point_const.op.body[0]
+
+    _, stride_w = get_const_tuple(attrs.strides)
+
+    if is_depthwise:
+        assert attrs.data_layout == "NCHW"
+        assert attrs.kernel_layout == "IOHW"
+        _, _, height, width = get_const_tuple(data.shape)
+        _, out_channels, kernel_h, kernel_w = get_const_tuple(kernel.shape)
+
+        dimensions = (width, kernel_h, kernel_w)
+        in_stride = stride_w
+        data_per_oc_size = height * width
+    else:
+        assert attrs.data_layout == "NHWC"
+        assert attrs.kernel_layout == "OHWI"
+        _, height, width, in_channels = get_const_tuple(data.shape)
+        out_channels, kernel_h, kernel_w, _ = get_const_tuple(kernel.shape)
+
+        dimensions = (width * in_channels, kernel_h, kernel_w * in_channels)
+        in_stride = in_channels * stride_w
+        data_per_oc_size = 0
+
+    assert attrs.out_layout is not None
+    if attrs.out_layout == "NHWC":
+        out_stride = out_channels
+    elif attrs.out_layout == "NCHW":
+        out_stride = 1
+    else:
+        raise ValueError(f"Unsupported output layout {attrs.out_layout}!")
+
+    x_strides = (in_stride, out_stride)
+    aligned_func = tensordot.tensordot_int16_impl(
+        num_sums,
+        dimensions,
+        (0, 0, 0),
+        x_strides,
+        output_zero_point=output_zero_point,
+    )
+
+    kernel_per_oc_size = dimensions[1] * dimensions[2]
+
+    offsets = (data_per_oc_size % 2, kernel_per_oc_size % 2, 0)
+    offset_func = tensordot.tensordot_int16_impl(
+        num_sums,
+        dimensions,
+        offsets,
+        x_strides,
+        output_zero_point=output_zero_point,
+    )
+
+    return (aligned_func, offset_func)
+
+
+def _make_tscript_ptr(buffer, offset, length, dtype="int16"):
+    return T.tvm_access_ptr(
+        T.type_annotation(dtype=dtype),
+        buffer.data,
+        offset,
+        length,
+        1,
+        dtype="handle",
+    )
+
+
+def _make_tscript_call(func_name, *args):
+    return T.evaluate(T.call_extern(func_name, *args, dtype="int32"))
+
+
+def _make_conv2d_primfunc(
+    call_dimensions: Tuple,
+    buffer_shapes: Tuple[Tuple, Tuple, Tuple, Tuple, Tuple],
+    aligned_func: Tuple[str, str],
+    offset_func: Tuple[str, str],
+    ptr_gens: Tuple,
+):
+    height, width, out_channels = call_dimensions
+    data_shape, kernel_shape, bias_shape, scale_shape, output_shape = buffer_shapes
+    aligned_func_name, aligned_func_code = aligned_func
+    offset_func_name, offset_func_code = offset_func
+    output_ptr, data_ptr, kernel_ptr = ptr_gens
+
+    # If the functions are identical, we can skip the second loop
+    if aligned_func_name == offset_func_name:
+        aligned_channels = out_channels
+        offset_channels = tvm.tir.const(0)
+        c_step = tvm.tir.const(1)
+    else:
+        aligned_channels = out_channels // 2
+        offset_channels = out_channels // 2
+        c_step = tvm.tir.const(2)
+
+    def bias_ptr(bias, c):
+        return _make_tscript_ptr(bias, c, 1, dtype="int32")
+
+    def scale_ptr(scale, c):
+        return _make_tscript_ptr(scale, c, 1, dtype="int32")
+
+    @T.prim_func
+    def biased_quantized_conv2d(
+        data_handle: T.handle,
+        kernel_handle: T.handle,
+        bias_handle: T.handle,
+        scale_handle: T.handle,
+        output_handle: T.handle,
+    ) -> None:
+
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        data = T.match_buffer(data_handle, data_shape, dtype="int16")
+        kernel = T.match_buffer(kernel_handle, kernel_shape, dtype="int16")
+        bias = T.match_buffer(bias_handle, bias_shape, dtype="int32")
+
+        # We don't specify a data type for the requantization scale, even though we will read it as
+        # an int32. This is because we must pretend it is a float32, as Relay's requantize op only
+        # allows floating point scales.
+        scale = T.match_buffer(scale_handle, scale_shape)
+        output = T.match_buffer(output_handle, output_shape, dtype="int16")
+
+        # This hack prevents TVM from seeing these variables as "unused". I should be using T.reads

Review Comment:
   can you file a bug for this?



##########
python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py:
##########
@@ -14,142 +14,334 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
 
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout NHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""
+
+from collections import namedtuple
+from itertools import chain
 import textwrap
+from typing import Iterator, Tuple
 
-from tvm import te, tir
+SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
 
-from .common import num_simd_lanes_per_word
 
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    """Generates a C function name for tensordot.
 
-def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
-    """Gets the C function name of the tensordot function."""
-    return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}"
+    We do not need a suffix, as the generated function will have an #include guard. Unlike other
+    microTVM operators, _get_c_function_name is never called externally.
+    """
+    tensor_w, kernel_h, kernel_w = dimensions
+    return (
+        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
+        + f"{kernel_h}x{kernel_w}_"
+        + "".join(map(str, offsets))
+        + (f"_{x_strides[0]}_{x_strides[1]}" if split_size > 1 else "")
+    )
 
 
-def make_intrin_tensordot(slices, strides, tensordot_params):
-    """Helper function for constructing tensordot intrinsic. We can't construct the whole thing here
-    (as multiple schedules use tensordot and each must build the intrinstic differently) but we can
-    build part here to simplify the code."""
+def _init_biased_accumulators(split_size):
+    """Generates code to load the bias into the accumulators.
 
-    # in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params
-    data, kernel, output = slices
-    data_strides, kernel_strides = strides
+    Addition is commutative, so we could add the bias before, during, or after performing our
+    multiply-accumulate operations. Where we add the bias does not change the overflow behavior.
 
-    data_buf = tir.decl_buffer(
-        data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides
-    )
-    kernel_buf = tir.decl_buffer(
-        kernel.shape,
-        kernel.dtype,
-        name="kernel",
-        offset_factor=1,
-        strides=kernel_strides,
-    )
-    output_buf = tir.decl_buffer(
-        output.shape, output.dtype, name="output", offset_factor=1, strides=[1]
-    )
+    Doing the bias add takes one cycle either way (if done at the beginning we can't use a SMULXY
+    trick to set sum_i to zero for "free"). However, doing it at the beginning frees up a register,
+    so we'll do it first.
+    """
+    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
+    joined_assignments = ", ".join(assignments)
+    return f"int {joined_assignments};"
 
-    def intrin_func(ins, outs):
-        builder = tir.ir_builder.create()
-        builder.emit(
-            tir.call_extern(
-                "int32",
-                _get_func_name(*tensordot_params),
-                outs[0].access_ptr("w"),
-                ins[0].access_ptr("r"),
-                ins[1].access_ptr("r"),
-            )
-        )
-        return builder.get()
 
-    return te.decl_tensor_intrin(
-        output.op,
-        intrin_func,
-        binds={data: data_buf, kernel: kernel_buf, output: output_buf},
-    )
+def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator:
+    tensor_w, kernel_h, kernel_w = dimensions
+
+    split_max = (split_size - 1) * in_stride
+    for y in range(kernel_h):
+        if y * tensor_w % 2 + offset == 1:
+            yield None
+        for x in range(kernel_w + split_max):
+            yield (y, x)
+        if (y * tensor_w + kernel_w + split_max + offset) % 2 == 1:
+            yield None
+
+
+def _get_kernel_halfwords(dimensions, offset) -> Iterator:
+    _, kernel_h, kernel_w = dimensions
+    if offset == 1:
+        yield None
+    for y in range(kernel_h):
+        for x in range(kernel_w):
+            yield (y, x)
+    if (kernel_h * kernel_w + offset) % 2 == 1:
+        yield None
+
+
+def _get_int16_alias(position) -> str:
+    if not position:
+        return "unknown"
+    y, x = position
+    return f"y{y:0>2x}_x{x:0>2x}"
+
+
+def _load_tensor_vars(halfwords, tensor_w) -> Iterator[str]:
+    assert len(halfwords) % 2 == 0
+    offset = int(not bool(halfwords[0]))
+
+    for i in range(0, len(halfwords), 2):
+        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))
+        y, x = halfwords[i + 1] or halfwords[i]
+        tensor_index = (y * tensor_w + x + offset) // 2
+        yield f"int tensor__{var_name} = tensor[{tensor_index}];"
+
 
+def _load_kernel_vars(halfwords) -> Iterator[str]:
+    assert len(halfwords) % 2 == 0
+    for i in range(0, len(halfwords), 2):
+        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))
+        yield f"int kernel__{var_name} = kernel[{i // 2}];"
 
-def tensordot_impl(in_dtype: str, tensor_h: int, jump: int, tensor_w: int, suffix: str) -> str:
-    """Generates C code for taking the dot products of two `tensor_h` * `tensor_w` tensors. Also has
-    a `jump` argument that advances the pointer of one tensor by that many words after each row. The
-    `jump` and `tensor_w` values must be word-aligned for the input data type, as non-word-aligned
-    memory access is slow on the Cortex-M series. Depending on the input datatype, the code may
-    contain DSP instructions for Arm v7e-m. C code contains DSP instructions for Arm v7e-m. See
-    the below pseudocode for reference:
-
-    tensordot(out_ptr, dat_ptr, ker_ptr) {
-        sum = 0;
-        for (i = 0; i < tensor_h; i++) {
-            for (j = 0; j < tensor_w; j++) {
-                sum += (*dat_ptr++) * (*ker_ptr++);
-            }
-            dat_ptr += jump;
-        }
-        *out_ptr = sum;
-    }
+
+def _get_draft_macs(
+    kernel_dims, tensor_halfwords, kernel_halfwords, offset
+) -> Iterator[SMLAInstruction]:
+    """Generates unrolled MAC instructions to compute one tensordot sum.
+
+    Unrolling these loops increases code size a tiny bit (< 0.02 KB), but makes the generated code
+    much faster. The generated code does not use SIMD instructions - they are added later by
+    _apply_simd_optimizations.
+
+    We return an iterator of SMLAInstruction named tuples. Returning an iterator lets us do
+    optimizations by iterator chaining.
+    """
+
+    def get_var(y, x, halfwords) -> Tuple[str, str]:
+        i = halfwords.index((y, x))
+        if i % 2 == 0:
+            return f"{_get_int16_alias((y, x))}__{_get_int16_alias(halfwords[i + 1])}", "b"
+        return f"{_get_int16_alias(halfwords[i - 1])}__{_get_int16_alias((y, x))}", "t"
+
+    kernel_h, kernel_w = kernel_dims
+    for y in range(kernel_h):
+        for x in range(kernel_w):
+            tensor_var, tensor_half = get_var(y, x + offset, tensor_halfwords)
+            kernel_var, kernel_half = get_var(y, x, kernel_halfwords)
+            instruction = f"smla{tensor_half}{kernel_half}"
+            yield SMLAInstruction(instruction, f"tensor__{tensor_var}", f"kernel__{kernel_var}")
+
+
+def _apply_simd_optimizations(instruction_tuples) -> Iterator[SMLAInstruction]:
+    """When possible, fuses single MACs into SIMD MAC instructions.
+
+    The compiler cannot do this automatically, as calling __smlaxy forces the SMLAxy instruction to
+    be used. This function takes as input an iterator of SMLAInstructions and returns an iterator of
+    SMLAInstructions (possibly of different length).
     """
+    curr_tuple = next(instruction_tuples, None)
+    while curr_tuple:
+        next_tuple = next(instruction_tuples, None)
+        if not next_tuple:
+            yield curr_tuple
+            break
 
-    simd_lanes = num_simd_lanes_per_word(in_dtype)
-    assert tensor_w % simd_lanes == 0
-    assert jump % simd_lanes == 0
+        if curr_tuple[1:] == next_tuple[1:]:
+            if set([curr_tuple[0], next_tuple[0]]) == set(["smlatt", "smlabb"]):
+                yield SMLAInstruction("smlad", *curr_tuple[1:])
+                next_tuple = next(instruction_tuples, None)
+            elif set([curr_tuple[0], next_tuple[0]]) == set(["smlatb", "smlabt"]):
+                yield SMLAInstruction("smladx", *curr_tuple[1:])
+                next_tuple = next(instruction_tuples, None)
+            else:
+                yield curr_tuple
 
-    if in_dtype == "int8":
-        inner_loop = """
-              uint32_t tensor_c20 = __SXTB16(tensor_batch);
-              uint32_t kernel_c20 = __SXTB16(kernel_batch);
-              sum = __SMLAD(tensor_c20, kernel_c20, sum);
+        else:
+            yield curr_tuple
+        curr_tuple = next_tuple
 
-              uint32_t tensor_c31 = __SXTB16(__ROR(tensor_batch, 8));
-              uint32_t kernel_c31 = __SXTB16(__ROR(kernel_batch, 8));
-              sum = __SMLAD(tensor_c31, kernel_c31, sum);"""
 
-    elif in_dtype == "int16":
-        inner_loop = """
-              sum = __SMLAD(tensor_batch, kernel_batch, sum);"""
+def _expand_instruction_tuples(instruction_tuples, index) -> Iterator[str]:
+    """Converts an iterator of SMLAInstructions into lines of C code.
 
-    elif in_dtype == "int32":
-        inner_loop = """
-              // Compiles to a single MAC instruction
-              sum += tensor_batch * kernel_batch;"""
+    We want the compiler to re-order these with the memory loads, so we generate them as a series of
+    calls to instruction aliases instead of as a single `asm` block.
+    """
+
+    for instruction, op1, op2 in instruction_tuples:
+        assert "smla" in instruction
+
+        # We call the instruction using the Arm C Language Extensions. Using ACLE gives better
+        # cross-compiler compatibility than using __builtin functions.
+        yield f"sum_{index} = __{instruction}({op1}, {op2}, sum_{index});"
+
+
+def _requantize_sums(num_sums, requantize_shift, output_zero_point) -> Iterator[str]:
+    """Generates code to requantize the accumulator values.
+
+    The generated code does not use floating point instructions, as it simulates floating point
+    multiplication with an int64 multiply + shift. The bias is added at the beginning, so we can
+    skip doing it now. The shift is hard-coded, as this saves a few cycles without hurting accuracy
+    in "most" cases.
+
+    It's *possible* we could save one more cycle here by pre-multiplying the bias with the
+    requantize multiplier, and then doing the bias addition and shift in the same cycle (via <op2>).
+    However, it's complicated and only saves one cycle.
+
+    It's also worth noting the SSAT16 operation doesn't help us here. The data isn't stored as two
+    halfwords in a word, and rearranging it would take at least one cycle. Two SSAT operations are
+    just as good.
+
+    Calling __ssat directly is a little bit gross, but GCC and Clang are unreliable about compiling
+    other ways of writing this. Both the multiply + shift and shift + saturation combine to one
+    instruction each.
+    """
+
+    yield "int scale_val = *scale;"
+    for i in range(num_sums):
+        yield f"int requant_{i} = (sum_{i} * (long long) scale_val) >> {requantize_shift - 1};"
+        yield f"requant_{i} = (requant_{i} + 1) >> 1;"
+        yield f"requant_{i} = __ssat(requant_{i} + {output_zero_point}, 8);"
+
+
+def _write_sums_to_memory(num_sums, offset, stride) -> Iterator[str]:
+    """Generates code to write the requantized sums to memory.
+
+    Note that halfword packing here *does* help, even though it seems like it shouldn't: two
+    pipelined int16 stores take two cycles - the same as halfword packing plus a pipelined int32
+    store. We still use plain int16 stores when there is an output stride.
+
+    However, packing gives the compiler more freedom to re-order instructions, as it does not
+    like breaking apart the store instructions (doing so messes up pipelining).
+    """
+
+    if stride > 1:
+        for i in range(num_sums):
+            yield f"((short*) output)[{i * stride + offset}] = (short) requant_{i};"

Review Comment:
   how come you use `short` here? int16_t better?
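   e.g. a fixed-width variant might look roughly like this (untested sketch; assumes the generated C translation unit includes <stdint.h>):
   
   ```
   from typing import Iterator
   
   def _write_sums_to_memory(num_sums: int, offset: int, stride: int) -> Iterator[str]:
       # Hypothetical rewrite of the strided path only, using fixed-width types instead of `short`.
       if stride > 1:
           for i in range(num_sums):
               yield f"((int16_t*) output)[{i * stride + offset}] = (int16_t) requant_{i};"
   ```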



##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -0,0 +1,369 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Contains TVMScript implementations of some QNN operators for Arm.
+
+Currently, the only ops with compute functions are fused regular and depthwise convolutions for
+Arm Cortex-M with DSP.
+"""
+
+from typing import Tuple
+
+import tvm
+from tvm import te
+from tvm.tir import const
+from tvm.script import tir as T
+from ..utils import get_const_tuple
+from .mprofile.dsp.micro_kernel import tensordot
+
+
+def int_ceil_division(x, y):
+    return -(x // -y)
+
+
+def _compute_output_dim(data_length, kernel_length, stride):
+    return int_ceil_division(data_length + 1 - kernel_length, stride)
+
+
+def _pick_tensordot_impl(attrs, inputs, num_sums=2, is_depthwise=False):
+    """Helper function that chooses the right implementation of micro_kernel.tensordot.
+
+    Takes as input the parameters of the conv2d, and returns a tuple of TWO (function_name,
+    function_code) pairs. The first pair (the aligned one) is for even-numbered output channels, and the
+    second pair (the offset one) is for odd-numbered output channels. This function is used for
+    regular and depthwise convolutions.
+
+    We need different implementations for even vs odd numbered output channels, because the "start"
+    of an odd output channel in the data tensor or kernel might or might not be on a word boundary,
+    and the tensordot code expects all input pointers to be word-aligned.
+    """
+    data, kernel = inputs[0:2]
+    rq_output_zero_point_const = inputs[10]
+    assert len(rq_output_zero_point_const.op.body) == 1
+    output_zero_point = rq_output_zero_point_const.op.body[0]
+
+    _, stride_w = get_const_tuple(attrs.strides)
+
+    if is_depthwise:
+        assert attrs.data_layout == "NCHW"
+        assert attrs.kernel_layout == "IOHW"
+        _, _, height, width = get_const_tuple(data.shape)
+        _, out_channels, kernel_h, kernel_w = get_const_tuple(kernel.shape)
+
+        dimensions = (width, kernel_h, kernel_w)
+        in_stride = stride_w
+        data_per_oc_size = height * width
+    else:
+        assert attrs.data_layout == "NHWC"
+        assert attrs.kernel_layout == "OHWI"
+        _, height, width, in_channels = get_const_tuple(data.shape)
+        out_channels, kernel_h, kernel_w, _ = get_const_tuple(kernel.shape)
+
+        dimensions = (width * in_channels, kernel_h, kernel_w * in_channels)
+        in_stride = in_channels * stride_w
+        data_per_oc_size = 0
+
+    assert attrs.out_layout is not None
+    if attrs.out_layout == "NHWC":
+        out_stride = out_channels
+    elif attrs.out_layout == "NCHW":
+        out_stride = 1
+    else:
+        raise ValueError(f"Unsupported output layout {attrs.out_layout}!")
+
+    x_strides = (in_stride, out_stride)
+    aligned_func = tensordot.tensordot_int16_impl(
+        num_sums,
+        dimensions,
+        (0, 0, 0),
+        x_strides,
+        output_zero_point=output_zero_point,
+    )
+
+    kernel_per_oc_size = dimensions[1] * dimensions[2]
+
+    offsets = (data_per_oc_size % 2, kernel_per_oc_size % 2, 0)
+    offset_func = tensordot.tensordot_int16_impl(
+        num_sums,
+        dimensions,
+        offsets,
+        x_strides,
+        output_zero_point=output_zero_point,
+    )
+
+    return (aligned_func, offset_func)
+
+
+def _make_tscript_ptr(buffer, offset, length, dtype="int16"):
+    return T.tvm_access_ptr(
+        T.type_annotation(dtype=dtype),
+        buffer.data,
+        offset,
+        length,
+        1,
+        dtype="handle",
+    )
+
+
+def _make_tscript_call(func_name, *args):
+    return T.evaluate(T.call_extern(func_name, *args, dtype="int32"))
+
+
+def _make_conv2d_primfunc(
+    call_dimensions: Tuple,
+    buffer_shapes: Tuple[Tuple, Tuple, Tuple, Tuple, Tuple],
+    aligned_func: Tuple[str, str],
+    offset_func: Tuple[str, str],
+    ptr_gens: Tuple,
+):
+    height, width, out_channels = call_dimensions
+    data_shape, kernel_shape, bias_shape, scale_shape, output_shape = buffer_shapes
+    aligned_func_name, aligned_func_code = aligned_func
+    offset_func_name, offset_func_code = offset_func
+    output_ptr, data_ptr, kernel_ptr = ptr_gens
+
+    # If the functions are identical, we can skip the second loop
+    if aligned_func_name == offset_func_name:
+        aligned_channels = out_channels
+        offset_channels = tvm.tir.const(0)
+        c_step = tvm.tir.const(1)
+    else:
+        aligned_channels = out_channels // 2
+        offset_channels = out_channels // 2
+        c_step = tvm.tir.const(2)
+
+    def bias_ptr(bias, c):
+        return _make_tscript_ptr(bias, c, 1, dtype="int32")
+
+    def scale_ptr(scale, c):
+        return _make_tscript_ptr(scale, c, 1, dtype="int32")
+
+    @T.prim_func
+    def biased_quantized_conv2d(
+        data_handle: T.handle,
+        kernel_handle: T.handle,
+        bias_handle: T.handle,
+        scale_handle: T.handle,
+        output_handle: T.handle,
+    ) -> None:
+
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        data = T.match_buffer(data_handle, data_shape, dtype="int16")
+        kernel = T.match_buffer(kernel_handle, kernel_shape, dtype="int16")
+        bias = T.match_buffer(bias_handle, bias_shape, dtype="int32")
+
+        # We don't specify a data type for the requantization scale, even though we will read it as
+        # an int32. This is because we must pretend it is a float32, as Relay's requantize op only
+        # allows floating point scales.
+        scale = T.match_buffer(scale_handle, scale_shape)
+        output = T.match_buffer(output_handle, output_shape, dtype="int16")
+
+        # This hack prevents TVM from seeing these variables as "unused". I should be using T.reads
+        # and T.writes, but they don't work. I think it's an issue with BufferTouchedDomain.
+        # pylint: disable=unused-variable
+        output[0, 0, 0, 0] = 0
+        __1 = data[0, 0, 0, 0]
+        __2 = kernel[0, 0, 0, 0]
+        __3 = bias[0, 0, 0, 0]
+        __4 = scale[0]
+        # pylint: enable=unused-variable
+
+        for c_ax, y_ax, x_ax in T.grid(aligned_channels, height, width):
+            with T.block("conv2d_aligned"):
+                T.block_attr({"pragma_import_c": aligned_func_code})
+                y, x, c = T.axis.remap("SSS", [y_ax, x_ax, c_ax])
+                _make_tscript_call(
+                    aligned_func_name,
+                    output_ptr(output, y, x, c * c_step),
+                    data_ptr(data, y, x, c * c_step),
+                    kernel_ptr(kernel, c * c_step),
+                    bias_ptr(bias, c * c_step),
+                    scale_ptr(scale, c * c_step),
+                )
+
+        for c_ax, y_ax, x_ax in T.grid(offset_channels, height, width):
+            with T.block("conv2d_offset"):
+                T.block_attr({"pragma_import_c": offset_func_code})
+                y, x, c = T.axis.remap("SSS", [y_ax, x_ax, c_ax])
+                _make_tscript_call(
+                    offset_func_name,
+                    output_ptr(output, y, x, c * c_step + 1),
+                    data_ptr(data, y, x, c * c_step + 1, offset=1),
+                    kernel_ptr(kernel, c * c_step + 1, offset=1),
+                    bias_ptr(bias, c * c_step + 1),
+                    scale_ptr(scale, c * c_step + 1),
+                )
+
+    return biased_quantized_conv2d
+
+
+def qnn_conv2d(attrs, inputs, out_type):
+    """Compute for qnn.conv2d with NHWC layout.
+
+    Note that this is a DIFFERENT layout from the Hexagon variant, because they have special
+    instructions Cortex-M doesn't have. We expect the kernel to have OHWI layout. We also assume
+    that padding is not necessary, as it will have been done by another pass.
+    """
+
+    # Make a few checks to unpack the function arguments and ensure it was called with the right
+    # arguments. Note that unlike most schedules, qnn_conv2d does not use a wrapper.
+    assert len(inputs) == 11
+    data, kernel, _izp, _kzp, _iscale, _kscale, bias, scale = inputs[0:8]
+    output_layout = attrs.out_layout
+    assert output_layout == "NHWC"
+
+    _, height, width, in_channels = get_const_tuple(data.shape)
+    out_channels, kernel_h, kernel_w, _ = get_const_tuple(kernel.shape)
+    y_stride, x_stride = get_const_tuple(attrs.strides)
+
+    out_height = _compute_output_dim(height, kernel_h, y_stride)
+    out_width = _compute_output_dim(width, kernel_w, x_stride)
+
+    # Decide how many sums our function should have running at the same time. Doing
+    # this lets us do "more work" for each memory load, but doing too many of them causes us to run
+    # out of registers. Currently this is set to either 1 or 2, but autotuning this value would

Review Comment:
   could you file a bug to track doing this later, whenever that is?
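   e.g. if this ever gets wired into an AutoTVM template, the knob might look roughly like this (hypothetical sketch, not part of this PR):
   
   ```
   from tvm import autotvm
   
   # Hypothetical: expose the number of simultaneous sums as a tunable knob
   # instead of hard-coding it to 1 or 2.
   cfg = autotvm.get_config()
   cfg.define_knob("num_sums", [1, 2])
   num_sums = cfg["num_sums"].val
   ```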



##########
python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py:
##########
@@ -14,142 +14,334 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
 
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout NHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""
+
+from collections import namedtuple
+from itertools import chain
 import textwrap
+from typing import Iterator, Tuple
 
-from tvm import te, tir
+SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
 
-from .common import num_simd_lanes_per_word
 
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    """Generates a C function name for tensordot.
 
-def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
-    """Gets the C function name of the tensordot function."""
-    return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}"
+    We do not need a suffix, as the generated function will have an #include guard. Unlike other
+    microTVM operators, _get_c_function_name is never called externally.
+    """
+    tensor_w, kernel_h, kernel_w = dimensions
+    return (
+        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
+        + f"{kernel_h}x{kernel_w}_"
+        + "".join(map(str, offsets))
+        + (f"_{x_strides[0]}_{x_strides[1]}" if split_size > 1 else "")
+    )
 
 
-def make_intrin_tensordot(slices, strides, tensordot_params):
-    """Helper function for constructing tensordot intrinsic. We can't construct the whole thing here
-    (as multiple schedules use tensordot and each must build the intrinstic differently) but we can
-    build part here to simplify the code."""
+def _init_biased_accumulators(split_size):
+    """Generates code to load the bias into the accumulators.
 
-    # in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params
-    data, kernel, output = slices
-    data_strides, kernel_strides = strides
+    Addition is commutative, so we could add the bias before, during, or after performing our
+    multiply-accumulate operations. Where we add the bias does not change the overflow behavior.
 
-    data_buf = tir.decl_buffer(
-        data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides
-    )
-    kernel_buf = tir.decl_buffer(
-        kernel.shape,
-        kernel.dtype,
-        name="kernel",
-        offset_factor=1,
-        strides=kernel_strides,
-    )
-    output_buf = tir.decl_buffer(
-        output.shape, output.dtype, name="output", offset_factor=1, strides=[1]
-    )
+    Doing the bias add takes one cycle either way (if done at the beginning we can't use a SMULXY
+    trick to set sum_i to zero for "free"). However, doing it at the beginning frees up a register,
+    so we'll do it first.
+    """
+    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
+    joined_assignments = ", ".join(assignments)
+    return f"int {joined_assignments};"
 
-    def intrin_func(ins, outs):
-        builder = tir.ir_builder.create()
-        builder.emit(
-            tir.call_extern(
-                "int32",
-                _get_func_name(*tensordot_params),
-                outs[0].access_ptr("w"),
-                ins[0].access_ptr("r"),
-                ins[1].access_ptr("r"),
-            )
-        )
-        return builder.get()
 
-    return te.decl_tensor_intrin(
-        output.op,
-        intrin_func,
-        binds={data: data_buf, kernel: kernel_buf, output: output_buf},
-    )
+def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator:
+    tensor_w, kernel_h, kernel_w = dimensions
+
+    split_max = (split_size - 1) * in_stride
+    for y in range(kernel_h):
+        if y * tensor_w % 2 + offset == 1:
+            yield None
+        for x in range(kernel_w + split_max):
+            yield (y, x)
+        if (y * tensor_w + kernel_w + split_max + offset) % 2 == 1:
+            yield None
+
+
+def _get_kernel_halfwords(dimensions, offset) -> Iterator:
+    _, kernel_h, kernel_w = dimensions
+    if offset == 1:
+        yield None
+    for y in range(kernel_h):
+        for x in range(kernel_w):
+            yield (y, x)
+    if (kernel_h * kernel_w + offset) % 2 == 1:
+        yield None
+
+
+def _get_int16_alias(position) -> str:
+    if not position:
+        return "unknown"
+    y, x = position
+    return f"y{y:0>2x}_x{x:0>2x}"
+
+
+def _load_tensor_vars(halfwords, tensor_w) -> Iterator[str]:
+    assert len(halfwords) % 2 == 0
+    offset = int(not bool(halfwords[0]))
+
+    for i in range(0, len(halfwords), 2):
+        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))
+        y, x = halfwords[i + 1] or halfwords[i]
+        tensor_index = (y * tensor_w + x + offset) // 2
+        yield f"int tensor__{var_name} = tensor[{tensor_index}];"
+
 
+def _load_kernel_vars(halfwords) -> Iterator[str]:
+    assert len(halfwords) % 2 == 0
+    for i in range(0, len(halfwords), 2):
+        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))
+        yield f"int kernel__{var_name} = kernel[{i // 2}];"
 
-def tensordot_impl(in_dtype: str, tensor_h: int, jump: int, tensor_w: int, suffix: str) -> str:
-    """Generates C code for taking the dot products of two `tensor_h` * `tensor_w` tensors. Also has
-    a `jump` argument that advances the pointer of one tensor by that many words after each row. The
-    `jump` and `tensor_w` values must be word-aligned for the input data type, as non-word-aligned
-    memory access is slow on the Cortex-M series. Depending on the input datatype, the code may
-    contain DSP instructions for Arm v7e-m. C code contains DSP instructions for Arm v7e-m. See
-    the below pseudocode for reference:
-
-    tensordot(out_ptr, dat_ptr, ker_ptr) {
-        sum = 0;
-        for (i = 0; i < tensor_h; i++) {
-            for (j = 0; j < tensor_w; j++) {
-                sum += (*dat_ptr++) * (*ker_ptr++);
-            }
-            dat_ptr += jump;
-        }
-        *out_ptr = sum;
-    }
+
+def _get_draft_macs(
+    kernel_dims, tensor_halfwords, kernel_halfwords, offset
+) -> Iterator[SMLAInstruction]:
+    """Generates unrolled MAC instructions to compute one tensordot sum.
+
+    Unrolling these loops increases code size a tiny bit (< 0.02 KB), but makes the generated code
+    much faster. The generated code does not use SIMD instructions - they are added later by
+    _apply_simd_optimizations.
+
+    We return an iterator of SMLAInstruction named tuples. Returning an iterator lets us do
+    optimizations by iterator chaining.
+    """
+
+    def get_var(y, x, halfwords) -> Tuple[str, str]:
+        i = halfwords.index((y, x))
+        if i % 2 == 0:
+            return f"{_get_int16_alias((y, x))}__{_get_int16_alias(halfwords[i + 1])}", "b"
+        return f"{_get_int16_alias(halfwords[i - 1])}__{_get_int16_alias((y, x))}", "t"
+
+    kernel_h, kernel_w = kernel_dims
+    for y in range(kernel_h):
+        for x in range(kernel_w):
+            tensor_var, tensor_half = get_var(y, x + offset, tensor_halfwords)
+            kernel_var, kernel_half = get_var(y, x, kernel_halfwords)
+            instruction = f"smla{tensor_half}{kernel_half}"
+            yield SMLAInstruction(instruction, f"tensor__{tensor_var}", f"kernel__{kernel_var}")
+
+
+def _apply_simd_optimizations(instruction_tuples) -> Iterator[SMLAInstruction]:
+    """When possible, fuses single MACs into SIMD MAC instructions.
+
+    The compiler cannot do this automatically, as calling __smlaxy forces the SMLAxy instruction to
+    be used. This function takes as input an iterator of SMLAInstructions and returns an iterator of
+    SMLAInstructions (possibly of different length).
     """
+    curr_tuple = next(instruction_tuples, None)
+    while curr_tuple:
+        next_tuple = next(instruction_tuples, None)
+        if not next_tuple:

Review Comment:
   ```suggestion
           if next_tuple is None:
   ```
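   (for consistency, the surrounding loop could use the same explicit check - rough sketch, fusion logic unchanged:)
   
   ```
   def _apply_simd_optimizations(instruction_tuples) -> Iterator[SMLAInstruction]:
       curr_tuple = next(instruction_tuples, None)
       while curr_tuple is not None:
           next_tuple = next(instruction_tuples, None)
           if next_tuple is None:
               yield curr_tuple
               break
   
           if curr_tuple[1:] == next_tuple[1:]:
               if {curr_tuple[0], next_tuple[0]} == {"smlatt", "smlabb"}:
                   yield SMLAInstruction("smlad", *curr_tuple[1:])
                   next_tuple = next(instruction_tuples, None)
               elif {curr_tuple[0], next_tuple[0]} == {"smlatb", "smlabt"}:
                   yield SMLAInstruction("smladx", *curr_tuple[1:])
                   next_tuple = next(instruction_tuples, None)
               else:
                   yield curr_tuple
           else:
               yield curr_tuple
           curr_tuple = next_tuple
   ```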



##########
python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py:
##########
@@ -14,142 +14,334 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
 
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout NHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""
+
+from collections import namedtuple
+from itertools import chain
 import textwrap
+from typing import Iterator, Tuple
 
-from tvm import te, tir
+SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
 
-from .common import num_simd_lanes_per_word
 
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    """Generates a C function name for tensordot.
 
-def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
-    """Gets the C function name of the tensordot function."""
-    return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}"
+    We do not need a suffix, as the generated function will have an #include guard. Unlike other
+    microTVM operators, _get_c_function_name is never called externally.
+    """
+    tensor_w, kernel_h, kernel_w = dimensions
+    return (
+        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
+        + f"{kernel_h}x{kernel_w}_"
+        + "".join(map(str, offsets))
+        + (f"_{x_strides[0]}_{x_strides[1]}" if split_size > 1 else "")
+    )
 
 
-def make_intrin_tensordot(slices, strides, tensordot_params):
-    """Helper function for constructing tensordot intrinsic. We can't construct the whole thing here
-    (as multiple schedules use tensordot and each must build the intrinstic differently) but we can
-    build part here to simplify the code."""
+def _init_biased_accumulators(split_size):
+    """Generates code to load the bias into the accumulators.
 
-    # in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params
-    data, kernel, output = slices
-    data_strides, kernel_strides = strides
+    Addition is commutative, so we could add the bias before, during, or after performing our
+    multiply-accumulate operations. Where we add the bias does not change the overflow behavior.
 
-    data_buf = tir.decl_buffer(
-        data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides
-    )
-    kernel_buf = tir.decl_buffer(
-        kernel.shape,
-        kernel.dtype,
-        name="kernel",
-        offset_factor=1,
-        strides=kernel_strides,
-    )
-    output_buf = tir.decl_buffer(
-        output.shape, output.dtype, name="output", offset_factor=1, strides=[1]
-    )
+    Doing the bias add takes one cycle either way (if done at the beginning we can't use a SMULXY
+    trick to set sum_i to zero for "free"). However, doing it at the beginning frees up a register,
+    so we'll do it first.
+    """
+    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
+    joined_assignments = ", ".join(assignments)
+    return f"int {joined_assignments};"
 
-    def intrin_func(ins, outs):
-        builder = tir.ir_builder.create()
-        builder.emit(
-            tir.call_extern(
-                "int32",
-                _get_func_name(*tensordot_params),
-                outs[0].access_ptr("w"),
-                ins[0].access_ptr("r"),
-                ins[1].access_ptr("r"),
-            )
-        )
-        return builder.get()
 
-    return te.decl_tensor_intrin(
-        output.op,
-        intrin_func,
-        binds={data: data_buf, kernel: kernel_buf, output: output_buf},
-    )
+def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator:
+    tensor_w, kernel_h, kernel_w = dimensions
+
+    split_max = (split_size - 1) * in_stride
+    for y in range(kernel_h):
+        if y * tensor_w % 2 + offset == 1:
+            yield None
+        for x in range(kernel_w + split_max):
+            yield (y, x)
+        if (y * tensor_w + kernel_w + split_max + offset) % 2 == 1:
+            yield None
+
+
+def _get_kernel_halfwords(dimensions, offset) -> Iterator:
+    _, kernel_h, kernel_w = dimensions
+    if offset == 1:
+        yield None

Review Comment:
   can you add a comment explaining why you yield None here?
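   e.g. something along these lines (my reading of the alignment logic - please double-check):
   
   ```
   def _get_kernel_halfwords(dimensions, offset) -> Iterator:
       _, kernel_h, kernel_w = dimensions
       # The kernel is read from memory one 32-bit word (two int16 halfwords) at a time. When the
       # kernel does not start on a word boundary (offset == 1), the first halfword of the first
       # word is not part of this kernel, so yield None as a placeholder for that unused halfword.
       # The same applies to the final halfword when the total halfword count is odd.
       if offset == 1:
           yield None
       for y in range(kernel_h):
           for x in range(kernel_w):
               yield (y, x)
       if (kernel_h * kernel_w + offset) % 2 == 1:
           yield None
   ```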



##########
python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py:
##########
@@ -14,142 +14,334 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
 
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout NHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""
+
+from collections import namedtuple
+from itertools import chain
 import textwrap
+from typing import Iterator, Tuple
 
-from tvm import te, tir
+SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
 
-from .common import num_simd_lanes_per_word
 
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    """Generates a C function name for tensordot.
 
-def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
-    """Gets the C function name of the tensordot function."""
-    return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}"
+    We do not need a suffix, as the generated function will have an #include guard. Unlike other
+    microTVM operators, _get_c_function_name is never called externally.
+    """
+    tensor_w, kernel_h, kernel_w = dimensions
+    return (
+        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
+        + f"{kernel_h}x{kernel_w}_"
+        + "".join(map(str, offsets))
+        + (f"_{x_strides[0]}_{x_strides[1]}" if split_size > 1 else "")
+    )
 
 
-def make_intrin_tensordot(slices, strides, tensordot_params):
-    """Helper function for constructing tensordot intrinsic. We can't construct the whole thing here
-    (as multiple schedules use tensordot and each must build the intrinstic differently) but we can
-    build part here to simplify the code."""
+def _init_biased_accumulators(split_size):
+    """Generates code to load the bias into the accumulators.
 
-    # in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params
-    data, kernel, output = slices
-    data_strides, kernel_strides = strides
+    Addition is commutative, so we could add the bias before, during, or after performing our
+    multiply-accumulate operations. Where we add the bias does not change the overflow behavior.
 
-    data_buf = tir.decl_buffer(
-        data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides
-    )
-    kernel_buf = tir.decl_buffer(
-        kernel.shape,
-        kernel.dtype,
-        name="kernel",
-        offset_factor=1,
-        strides=kernel_strides,
-    )
-    output_buf = tir.decl_buffer(
-        output.shape, output.dtype, name="output", offset_factor=1, strides=[1]
-    )
+    Doing the bias add takes one cycle either way (if done at the beginning we can't use a SMULXY
+    trick to set sum_i to zero for "free"). However, doing it at the beginning frees up a register,
+    so we'll do it first.
+    """
+    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
+    joined_assignments = ", ".join(assignments)
+    return f"int {joined_assignments};"
 
-    def intrin_func(ins, outs):
-        builder = tir.ir_builder.create()
-        builder.emit(
-            tir.call_extern(
-                "int32",
-                _get_func_name(*tensordot_params),
-                outs[0].access_ptr("w"),
-                ins[0].access_ptr("r"),
-                ins[1].access_ptr("r"),
-            )
-        )
-        return builder.get()
 
-    return te.decl_tensor_intrin(
-        output.op,
-        intrin_func,
-        binds={data: data_buf, kernel: kernel_buf, output: output_buf},
-    )
+def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator:
+    tensor_w, kernel_h, kernel_w = dimensions
+
+    split_max = (split_size - 1) * in_stride
+    for y in range(kernel_h):
+        if y * tensor_w % 2 + offset == 1:
+            yield None
+        for x in range(kernel_w + split_max):
+            yield (y, x)
+        if (y * tensor_w + kernel_w + split_max + offset) % 2 == 1:
+            yield None
+
+
+def _get_kernel_halfwords(dimensions, offset) -> Iterator:
+    _, kernel_h, kernel_w = dimensions
+    if offset == 1:
+        yield None
+    for y in range(kernel_h):
+        for x in range(kernel_w):
+            yield (y, x)
+    if (kernel_h * kernel_w + offset) % 2 == 1:
+        yield None
+
+
+def _get_int16_alias(position) -> str:
+    if not position:
+        return "unknown"
+    y, x = position
+    return f"y{y:0>2x}_x{x:0>2x}"
+
+
+def _load_tensor_vars(halfwords, tensor_w) -> Iterator[str]:
+    assert len(halfwords) % 2 == 0
+    offset = int(not bool(halfwords[0]))
+
+    for i in range(0, len(halfwords), 2):
+        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))
+        y, x = halfwords[i + 1] or halfwords[i]
+        tensor_index = (y * tensor_w + x + offset) // 2
+        yield f"int tensor__{var_name} = tensor[{tensor_index}];"
+
 
+def _load_kernel_vars(halfwords) -> Iterator[str]:
+    assert len(halfwords) % 2 == 0
+    for i in range(0, len(halfwords), 2):
+        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))
+        yield f"int kernel__{var_name} = kernel[{i // 2}];"
 
-def tensordot_impl(in_dtype: str, tensor_h: int, jump: int, tensor_w: int, suffix: str) -> str:
-    """Generates C code for taking the dot products of two `tensor_h` * `tensor_w` tensors. Also has
-    a `jump` argument that advances the pointer of one tensor by that many words after each row. The
-    `jump` and `tensor_w` values must be word-aligned for the input data type, as non-word-aligned
-    memory access is slow on the Cortex-M series. Depending on the input datatype, the code may
-    contain DSP instructions for Arm v7e-m. C code contains DSP instructions for Arm v7e-m. See
-    the below pseudocode for reference:
-
-    tensordot(out_ptr, dat_ptr, ker_ptr) {
-        sum = 0;
-        for (i = 0; i < tensor_h; i++) {
-            for (j = 0; j < tensor_w; j++) {
-                sum += (*dat_ptr++) * (*ker_ptr++);
-            }
-            dat_ptr += jump;
-        }
-        *out_ptr = sum;
-    }
+
+def _get_draft_macs(
+    kernel_dims, tensor_halfwords, kernel_halfwords, offset
+) -> Iterator[SMLAInstruction]:
+    """Generates unrolled MAC instructions to compute one tensordot sum.
+
+    Unrolling these loops increases code size a tiny bit (< 0.02 KB), but makes the generated code
+    much faster. The generated code does not use SIMD instructions - they are added later by
+    _apply_simd_optimizations.
+
+    We return an iterator of SMLAInstruction named tuples. Returning an iterator lets us do
+    optimizations by iterator chaining.
+    """
+
+    def get_var(y, x, halfwords) -> Tuple[str, str]:
+        i = halfwords.index((y, x))
+        if i % 2 == 0:
+            return f"{_get_int16_alias((y, x))}__{_get_int16_alias(halfwords[i + 1])}", "b"
+        return f"{_get_int16_alias(halfwords[i - 1])}__{_get_int16_alias((y, x))}", "t"
+
+    kernel_h, kernel_w = kernel_dims
+    for y in range(kernel_h):
+        for x in range(kernel_w):
+            tensor_var, tensor_half = get_var(y, x + offset, tensor_halfwords)
+            kernel_var, kernel_half = get_var(y, x, kernel_halfwords)
+            instruction = f"smla{tensor_half}{kernel_half}"
+            yield SMLAInstruction(instruction, f"tensor__{tensor_var}", f"kernel__{kernel_var}")
+
+
+def _apply_simd_optimizations(instruction_tuples) -> Iterator[SMLAInstruction]:
+    """When possible, fuses single MACs into SIMD MAC instructions.
+
+    The compiler cannot do this automatically, as calling __smlaxy forces the SMLAxy instruction to
+    be used. This function takes as input an iterator of SMLAInstructions and returns an iterator of
+    SMLAInstructions (possibly of different length).
     """
+    curr_tuple = next(instruction_tuples, None)
+    while curr_tuple:
+        next_tuple = next(instruction_tuples, None)
+        if not next_tuple:
+            yield curr_tuple
+            break
 
-    simd_lanes = num_simd_lanes_per_word(in_dtype)
-    assert tensor_w % simd_lanes == 0
-    assert jump % simd_lanes == 0
+        if curr_tuple[1:] == next_tuple[1:]:
+            if set([curr_tuple[0], next_tuple[0]]) == set(["smlatt", "smlabb"]):
+                yield SMLAInstruction("smlad", *curr_tuple[1:])
+                next_tuple = next(instruction_tuples, None)
+            elif set([curr_tuple[0], next_tuple[0]]) == set(["smlatb", "smlabt"]):
+                yield SMLAInstruction("smladx", *curr_tuple[1:])
+                next_tuple = next(instruction_tuples, None)
+            else:
+                yield curr_tuple
 
-    if in_dtype == "int8":
-        inner_loop = """
-              uint32_t tensor_c20 = __SXTB16(tensor_batch);
-              uint32_t kernel_c20 = __SXTB16(kernel_batch);
-              sum = __SMLAD(tensor_c20, kernel_c20, sum);
+        else:
+            yield curr_tuple
+        curr_tuple = next_tuple
 
-              uint32_t tensor_c31 = __SXTB16(__ROR(tensor_batch, 8));
-              uint32_t kernel_c31 = __SXTB16(__ROR(kernel_batch, 8));
-              sum = __SMLAD(tensor_c31, kernel_c31, sum);"""
 
-    elif in_dtype == "int16":
-        inner_loop = """
-              sum = __SMLAD(tensor_batch, kernel_batch, sum);"""
+def _expand_instruction_tuples(instruction_tuples, index) -> Iterator[str]:
+    """Converts an iterator of SMLAInstructions into lines of C code.
 
-    elif in_dtype == "int32":
-        inner_loop = """
-              // Compiles to a single MAC instruction
-              sum += tensor_batch * kernel_batch;"""
+    We want the compiler to re-order these with the memory loads, so we generate them as a series of
+    calls to instruction aliases instead of as a single `asm` block.
+    """
+
+    for instruction, op1, op2 in instruction_tuples:
+        assert "smla" in instruction
+
+        # We call the instruction using the Arm C Language Extensions. Using ACLE gives better
+        # cross-compiler compatibility than using __builtin functions.
+        yield f"sum_{index} = __{instruction}({op1}, {op2}, sum_{index});"

Review Comment:
   i suggest you put this part as a method closer to the definition of SMLAInstruction. doing this will help to clarify the inputs to SMLAInstruction and the overall objective of that class.  you could either create a class out of SMLAInstruction namedtuple:
   
   ```
   _SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
   class SMLAInstruction(_SMLAInstruction):
     def codegen(self, sum_register):
        return f"{sum_register} = __{self.instruction}({self.tensor_var}, {self.kernel_var}, {sum_register});"
   ```
   
   or use a dataclass here which is the Python 3.7+ way to do it
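   roughly, the dataclass version might be (untested):
   
   ```
   from dataclasses import dataclass
   
   @dataclass(frozen=True)
   class SMLAInstruction:
       """One multiply-accumulate call in the generated C code."""
   
       instruction: str
       tensor_var: str
       kernel_var: str
   
       def codegen(self, sum_register: str) -> str:
           # Emits the ACLE intrinsic call, e.g. "sum_0 = __smlad(tensor__..., kernel__..., sum_0);"
           return f"{sum_register} = __{self.instruction}({self.tensor_var}, {self.kernel_var}, {sum_register});"
   ```
   
   (note that `_apply_simd_optimizations` currently slices instructions with `[1:]`, so that indexing would need a small tweak if the namedtuple goes away)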



##########
python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py:
##########
@@ -14,142 +14,334 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
 
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout NHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""
+
+from collections import namedtuple
+from itertools import chain
 import textwrap
+from typing import Iterator, Tuple
 
-from tvm import te, tir
+SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
 
-from .common import num_simd_lanes_per_word
 
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    """Generates a C function name for tensordot.
 
-def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
-    """Gets the C function name of the tensordot function."""
-    return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}"
+    We do not need a suffix, as the generated function will have an #include guard. Unlike other
+    microTVM operators, _get_c_function_name is never called externally.
+    """
+    tensor_w, kernel_h, kernel_w = dimensions
+    return (
+        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
+        + f"{kernel_h}x{kernel_w}_"
+        + "".join(map(str, offsets))
+        + (f"_{x_strides[0]}_{x_strides[1]}" if split_size > 1 else "")
+    )
 
 
-def make_intrin_tensordot(slices, strides, tensordot_params):
-    """Helper function for constructing tensordot intrinsic. We can't construct the whole thing here
-    (as multiple schedules use tensordot and each must build the intrinstic differently) but we can
-    build part here to simplify the code."""
+def _init_biased_accumulators(split_size):
+    """Generates code to load the bias into the accumulators.
 
-    # in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params
-    data, kernel, output = slices
-    data_strides, kernel_strides = strides
+    Addition is commutative, so we could add the bias before, during, or after performing our
+    multiply-accumulate operations. Where we add the bias does not change the overflow behavior.
 
-    data_buf = tir.decl_buffer(
-        data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides
-    )
-    kernel_buf = tir.decl_buffer(
-        kernel.shape,
-        kernel.dtype,
-        name="kernel",
-        offset_factor=1,
-        strides=kernel_strides,
-    )
-    output_buf = tir.decl_buffer(
-        output.shape, output.dtype, name="output", offset_factor=1, strides=[1]
-    )
+    Doing the bias add takes one cycle either way (if done at the beginning we can't use a SMULXY
+    trick to set sum_i to zero for "free"). However, doing it at the beginning frees up a register,
+    so we'll do it first.
+    """
+    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
+    joined_assignments = ", ".join(assignments)
+    return f"int {joined_assignments};"
 
-    def intrin_func(ins, outs):
-        builder = tir.ir_builder.create()
-        builder.emit(
-            tir.call_extern(
-                "int32",
-                _get_func_name(*tensordot_params),
-                outs[0].access_ptr("w"),
-                ins[0].access_ptr("r"),
-                ins[1].access_ptr("r"),
-            )
-        )
-        return builder.get()
 
-    return te.decl_tensor_intrin(
-        output.op,
-        intrin_func,
-        binds={data: data_buf, kernel: kernel_buf, output: output_buf},
-    )
+def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator:
+    tensor_w, kernel_h, kernel_w = dimensions
+
+    split_max = (split_size - 1) * in_stride
+    for y in range(kernel_h):
+        if y * tensor_w % 2 + offset == 1:
+            yield None
+        for x in range(kernel_w + split_max):
+            yield (y, x)
+        if (y * tensor_w + kernel_w + split_max + offset) % 2 == 1:
+            yield None
+
+
+def _get_kernel_halfwords(dimensions, offset) -> Iterator:
+    _, kernel_h, kernel_w = dimensions
+    if offset == 1:
+        yield None
+    for y in range(kernel_h):
+        for x in range(kernel_w):
+            yield (y, x)
+    if (kernel_h * kernel_w + offset) % 2 == 1:
+        yield None
+
+
+def _get_int16_alias(position) -> str:
+    if not position:
+        return "unknown"
+    y, x = position
+    return f"y{y:0>2x}_x{x:0>2x}"
+
+
+def _load_tensor_vars(halfwords, tensor_w) -> Iterator[str]:
+    assert len(halfwords) % 2 == 0
+    offset = int(not bool(halfwords[0]))
+
+    for i in range(0, len(halfwords), 2):
+        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))
+        y, x = halfwords[i + 1] or halfwords[i]
+        tensor_index = (y * tensor_w + x + offset) // 2
+        yield f"int tensor__{var_name} = tensor[{tensor_index}];"
+
 
+def _load_kernel_vars(halfwords) -> Iterator[str]:
+    assert len(halfwords) % 2 == 0
+    for i in range(0, len(halfwords), 2):
+        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))
+        yield f"int kernel__{var_name} = kernel[{i // 2}];"
 
-def tensordot_impl(in_dtype: str, tensor_h: int, jump: int, tensor_w: int, suffix: str) -> str:
-    """Generates C code for taking the dot products of two `tensor_h` * `tensor_w` tensors. Also has
-    a `jump` argument that advances the pointer of one tensor by that many words after each row. The
-    `jump` and `tensor_w` values must be word-aligned for the input data type, as non-word-aligned
-    memory access is slow on the Cortex-M series. Depending on the input datatype, the code may
-    contain DSP instructions for Arm v7e-m. C code contains DSP instructions for Arm v7e-m. See
-    the below pseudocode for reference:
-
-    tensordot(out_ptr, dat_ptr, ker_ptr) {
-        sum = 0;
-        for (i = 0; i < tensor_h; i++) {
-            for (j = 0; j < tensor_w; j++) {
-                sum += (*dat_ptr++) * (*ker_ptr++);
-            }
-            dat_ptr += jump;
-        }
-        *out_ptr = sum;
-    }
+
+def _get_draft_macs(
+    kernel_dims, tensor_halfwords, kernel_halfwords, offset
+) -> Iterator[SMLAInstruction]:
+    """Generates unrolled MAC instructions to compute one tensordot sum.
+
+    Unrolling these loops increases code size a tiny bit (< 0.02 KB), but makes the generated code
+    much faster. The generated code does not use SIMD instructions - they are added later by
+    _apply_simd_optimizations.
+
+    We return an iterator of SMLAInstruction named tuples. Returning an iterator lets us do
+    optimizations by iterator chaining.
+    """
+
+    def get_var(y, x, halfwords) -> Tuple[str, str]:
+        i = halfwords.index((y, x))
+        if i % 2 == 0:
+            return f"{_get_int16_alias((y, x))}__{_get_int16_alias(halfwords[i + 1])}", "b"
+        return f"{_get_int16_alias(halfwords[i - 1])}__{_get_int16_alias((y, x))}", "t"
+
+    kernel_h, kernel_w = kernel_dims
+    for y in range(kernel_h):
+        for x in range(kernel_w):
+            tensor_var, tensor_half = get_var(y, x + offset, tensor_halfwords)
+            kernel_var, kernel_half = get_var(y, x, kernel_halfwords)
+            instruction = f"smla{tensor_half}{kernel_half}"
+            yield SMLAInstruction(instruction, f"tensor__{tensor_var}", f"kernel__{kernel_var}")
+
+
+def _apply_simd_optimizations(instruction_tuples) -> Iterator[SMLAInstruction]:
+    """When possible, fuses single MACs into SIMD MAC instructions.
+
+    The compiler cannot do this automatically, as calling __smlaxy forces the SMLAxy instruction to
+    be used. This function takes as input an iterator of SMLAInstructions and returns an iterator of
+    SMLAInstructions (possibly of different length).
     """
+    curr_tuple = next(instruction_tuples, None)
+    while curr_tuple:
+        next_tuple = next(instruction_tuples, None)
+        if not next_tuple:
+            yield curr_tuple
+            break
 
-    simd_lanes = num_simd_lanes_per_word(in_dtype)
-    assert tensor_w % simd_lanes == 0
-    assert jump % simd_lanes == 0
+        if curr_tuple[1:] == next_tuple[1:]:
+            if set([curr_tuple[0], next_tuple[0]]) == set(["smlatt", "smlabb"]):
+                yield SMLAInstruction("smlad", *curr_tuple[1:])
+                next_tuple = next(instruction_tuples, None)
+            elif set([curr_tuple[0], next_tuple[0]]) == set(["smlatb", "smlabt"]):
+                yield SMLAInstruction("smladx", *curr_tuple[1:])
+                next_tuple = next(instruction_tuples, None)
+            else:
+                yield curr_tuple
 
-    if in_dtype == "int8":
-        inner_loop = """
-              uint32_t tensor_c20 = __SXTB16(tensor_batch);
-              uint32_t kernel_c20 = __SXTB16(kernel_batch);
-              sum = __SMLAD(tensor_c20, kernel_c20, sum);
+        else:
+            yield curr_tuple
+        curr_tuple = next_tuple
 
-              uint32_t tensor_c31 = __SXTB16(__ROR(tensor_batch, 8));
-              uint32_t kernel_c31 = __SXTB16(__ROR(kernel_batch, 8));
-              sum = __SMLAD(tensor_c31, kernel_c31, sum);"""
 
-    elif in_dtype == "int16":
-        inner_loop = """
-              sum = __SMLAD(tensor_batch, kernel_batch, sum);"""
+def _expand_instruction_tuples(instruction_tuples, index) -> Iterator[str]:
+    """Converts an iterator of SMLAInstructions into lines of C code.
 
-    elif in_dtype == "int32":
-        inner_loop = """
-              // Compiles to a single MAC instruction
-              sum += tensor_batch * kernel_batch;"""
+    We want the compiler to re-order these with the memory loads, so we generate them as a series of
+    calls to instruction aliases instead of as a single `asm` block.
+    """
+
+    for instruction, op1, op2 in instruction_tuples:
+        assert "smla" in instruction
+
+        # We call the instruction using the Arm C Language Extensions. Using ACLE gives better
+        # cross-compiler compatibility than using __builtin functions.
+        yield f"sum_{index} = __{instruction}({op1}, {op2}, sum_{index});"
+
+
+def _requantize_sums(num_sums, requantize_shift, output_zero_point) -> Iterator[str]:
+    """Generates code to requantize the accumulator values.
+
+    The generated code does not use floating point instructions, as it simulates floating point
+    multiplication with an int64 multiply + shift. The bias is added at the beginning, so we can
+    skip doing it now. The shift is hard-coded, as this saves a few cycles without hurting accuracy
+    in "most" cases.
+
+    It's *possible* we could save one more cycle here by pre-multiplying the bias with the
+    requantize multiplier, and then doing the bias addition and shift in the same cycle (via <op2>).
+    However, it's complicated and only saves one cycle.
+
+    It's also worth noting the SSAT16 operation doesn't help us here. The data isn't stored as two
+    halfwords in a word, and rearranging it would take at least one cycle. Two SSAT operations are
+    just as good.
+
+    Calling __ssat directly is a little bit gross, but GCC and Clang are unreliable about compiling
+    other ways of writing this. Both the multiply + shift and shift + saturation combine to one
+    instruction each.
+    """
+
+    yield "int scale_val = *scale;"
+    for i in range(num_sums):
+        yield f"int requant_{i} = (sum_{i} * (long long) scale_val) >> {requantize_shift - 1};"

Review Comment:
   s/long long/int64_t/? also why int64_t?
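   e.g. (sketch; assumes <stdint.h> is available in the generated code - the 64-bit intermediate is presumably there so the full sum * scale product survives before the shift):
   
   ```
   from typing import Iterator
   
   def _requantize_sums(num_sums, requantize_shift, output_zero_point) -> Iterator[str]:
       # Hypothetical fixed-width variant: keep the product in an int64_t before shifting.
       yield "int scale_val = *scale;"
       for i in range(num_sums):
           yield f"int requant_{i} = (int) (((int64_t) sum_{i} * scale_val) >> {requantize_shift - 1});"
           yield f"requant_{i} = (requant_{i} + 1) >> 1;"
           yield f"requant_{i} = __ssat(requant_{i} + {output_zero_point}, 8);"
   ```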



##########
python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py:
##########
@@ -14,142 +14,334 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
 
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout NHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""
+
+from collections import namedtuple
+from itertools import chain
 import textwrap
+from typing import Iterator, Tuple
 
-from tvm import te, tir
+SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
 
-from .common import num_simd_lanes_per_word
 
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    """Generates a C function name for tensordot.
 
-def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
-    """Gets the C function name of the tensordot function."""
-    return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}"
+    We do not need a suffix, as the generated function will have an #include guard. Unlike other
+    microTVM operators, _get_c_function_name is never called externally.
+    """
+    tensor_w, kernel_h, kernel_w = dimensions
+    return (
+        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
+        + f"{kernel_h}x{kernel_w}_"
+        + "".join(map(str, offsets))
+        + (f"_{x_strides[0]}_{x_strides[1]}" if split_size > 1 else "")
+    )
 
 
-def make_intrin_tensordot(slices, strides, tensordot_params):
-    """Helper function for constructing tensordot intrinsic. We can't construct the whole thing here
-    (as multiple schedules use tensordot and each must build the intrinstic differently) but we can
-    build part here to simplify the code."""
+def _init_biased_accumulators(split_size):
+    """Generates code to load the bias into the accumulators.
 
-    # in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params
-    data, kernel, output = slices
-    data_strides, kernel_strides = strides
+    Addition is commutative, so we could add the bias before, during, or after performing our
+    multiply-accumulate operations. Where we add the bias does not change the overflow behavior.
 
-    data_buf = tir.decl_buffer(
-        data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides
-    )
-    kernel_buf = tir.decl_buffer(
-        kernel.shape,
-        kernel.dtype,
-        name="kernel",
-        offset_factor=1,
-        strides=kernel_strides,
-    )
-    output_buf = tir.decl_buffer(
-        output.shape, output.dtype, name="output", offset_factor=1, strides=[1]
-    )
+    Doing the bias add takes one cycle either way (if done at the beginning we can't use a SMULXY
+    trick to set sum_i to zero for "free"). However, doing it at the beginning frees up a register,
+    so we'll do it first.
+    """
+    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
+    joined_assignments = ", ".join(assignments)
+    return f"int {joined_assignments};"
 
-    def intrin_func(ins, outs):
-        builder = tir.ir_builder.create()
-        builder.emit(
-            tir.call_extern(
-                "int32",
-                _get_func_name(*tensordot_params),
-                outs[0].access_ptr("w"),
-                ins[0].access_ptr("r"),
-                ins[1].access_ptr("r"),
-            )
-        )
-        return builder.get()
 
-    return te.decl_tensor_intrin(
-        output.op,
-        intrin_func,
-        binds={data: data_buf, kernel: kernel_buf, output: output_buf},
-    )
+def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator:
+    tensor_w, kernel_h, kernel_w = dimensions
+
+    split_max = (split_size - 1) * in_stride
+    for y in range(kernel_h):
+        if y * tensor_w % 2 + offset == 1:
+            yield None
+        for x in range(kernel_w + split_max):
+            yield (y, x)
+        if (y * tensor_w + kernel_w + split_max + offset) % 2 == 1:
+            yield None
+
+
+def _get_kernel_halfwords(dimensions, offset) -> Iterator:
+    _, kernel_h, kernel_w = dimensions
+    if offset == 1:
+        yield None
+    for y in range(kernel_h):
+        for x in range(kernel_w):
+            yield (y, x)
+    if (kernel_h * kernel_w + offset) % 2 == 1:
+        yield None
+
+
+def _get_int16_alias(position) -> str:
+    if not position:
+        return "unknown"
+    y, x = position
+    return f"y{y:0>2x}_x{x:0>2x}"
+
+
+def _load_tensor_vars(halfwords, tensor_w) -> Iterator[str]:
+    assert len(halfwords) % 2 == 0
+    offset = int(not bool(halfwords[0]))
+
+    for i in range(0, len(halfwords), 2):
+        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))
+        y, x = halfwords[i + 1] or halfwords[i]
+        tensor_index = (y * tensor_w + x + offset) // 2
+        yield f"int tensor__{var_name} = tensor[{tensor_index}];"
+
 
+def _load_kernel_vars(halfwords) -> Iterator[str]:
+    assert len(halfwords) % 2 == 0
+    for i in range(0, len(halfwords), 2):
+        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))
+        yield f"int kernel__{var_name} = kernel[{i // 2}];"
 
-def tensordot_impl(in_dtype: str, tensor_h: int, jump: int, tensor_w: int, suffix: str) -> str:
-    """Generates C code for taking the dot products of two `tensor_h` * `tensor_w` tensors. Also has
-    a `jump` argument that advances the pointer of one tensor by that many words after each row. The
-    `jump` and `tensor_w` values must be word-aligned for the input data type, as non-word-aligned
-    memory access is slow on the Cortex-M series. Depending on the input datatype, the
-    generated C code may contain DSP instructions for Arm v7e-m. See
-    the below pseudocode for reference:
-
-    tensordot(out_ptr, dat_ptr, ker_ptr) {
-        sum = 0;
-        for (i = 0; i < tensor_h; i++) {
-            for (j = 0; j < tensor_w; j++) {
-                sum += (*dat_ptr++) * (*ker_ptr++);
-            }
-            dat_ptr += jump;
-        }
-        *out_ptr = sum;
-    }
+
+def _get_draft_macs(
+    kernel_dims, tensor_halfwords, kernel_halfwords, offset
+) -> Iterator[SMLAInstruction]:
+    """Generates unrolled MAC instructions to compute one tensordot sum.
+
+    Unrolling these loops increases code size a tiny bit (< 0.02 KB), but makes the generated code
+    much faster. The generated code does not use SIMD instructions - they are added later by
+    _apply_simd_optimizations.
+
+    We return an iterator of SMLAInstruction named tuples. Returning an iterator lets us do
+    optimizations by iterator chaining.
+    """
+
+    def get_var(y, x, halfwords) -> Tuple[str, str]:
+        i = halfwords.index((y, x))
+        if i % 2 == 0:
+            return f"{_get_int16_alias((y, x))}__{_get_int16_alias(halfwords[i + 1])}", "b"
+        return f"{_get_int16_alias(halfwords[i - 1])}__{_get_int16_alias((y, x))}", "t"
+
+    kernel_h, kernel_w = kernel_dims
+    for y in range(kernel_h):
+        for x in range(kernel_w):
+            tensor_var, tensor_half = get_var(y, x + offset, tensor_halfwords)
+            kernel_var, kernel_half = get_var(y, x, kernel_halfwords)
+            instruction = f"smla{tensor_half}{kernel_half}"
+            yield SMLAInstruction(instruction, f"tensor__{tensor_var}", f"kernel__{kernel_var}")
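+    # Illustration (assuming a 3x3 kernel and an offset of 0 for both the tensor and kernel
+    # halfword layouts): the first instruction drafted above would be
+    # SMLAInstruction("smlabb", "tensor__y00_x00__y00_x01", "kernel__y00_x00__y00_x01").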
+
+
+def _apply_simd_optimizations(instruction_tuples) -> Iterator[SMLAInstruction]:
+    """When possible, fuses single MACs into SIMD MAC instructions.
+
+    The compiler cannot do this automatically, as calling __smlaxy forces the SMLAxy instruction to
+    be used. This function takes as input an iterator of SMLAInstructions and returns an iterator of
+    SMLAInstructions (possibly of different length).
     """
+    curr_tuple = next(instruction_tuples, None)
+    while curr_tuple:
+        next_tuple = next(instruction_tuples, None)
+        if not next_tuple:
+            yield curr_tuple
+            break
 
-    simd_lanes = num_simd_lanes_per_word(in_dtype)
-    assert tensor_w % simd_lanes == 0
-    assert jump % simd_lanes == 0
+        if curr_tuple[1:] == next_tuple[1:]:
+            if set([curr_tuple[0], next_tuple[0]]) == set(["smlatt", "smlabb"]):
+                yield SMLAInstruction("smlad", *curr_tuple[1:])

Review Comment:
   rather than `*curr_tuple[1:]`, suggest writing out the whole thing for clarity
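
   For example, spelled out with the `SMLAInstruction` field names (a sketch, assuming the
   field names stay as defined above), the line would read:

       yield SMLAInstruction("smlad", curr_tuple.tensor_var, curr_tuple.kernel_var)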


