Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/11/23 12:05:36 UTC

[GitHub] [tvm] Mousius commented on a diff in pull request #13242: [microTVM] [WIP] Modernize Arm Cortex-M convolution schedules

Mousius commented on code in PR #13242:
URL: https://github.com/apache/tvm/pull/13242#discussion_r1030362145


##########
python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py:
##########
@@ -14,142 +14,333 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
 
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout NHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""
+
+from itertools import chain
 import textwrap
+from typing import Iterator, Tuple
 
-from tvm import te, tir
 
-from .common import num_simd_lanes_per_word
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    """Generates a C function name for tensordot.
 
+    We do not need a suffix, as the generated function will have an #include guard. Unlike other
+    microTVM operators, _get_c_function_name is never called externally.
+    """
+    tensor_w, kernel_h, kernel_w = dimensions
+    return (
+        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
+        + f"{kernel_h}x{kernel_w}_"
+        + "".join(map(str, offsets))
+        + (f"_{x_strides[0]}_{x_strides[1]}" if split_size > 1 else "")
+    )
 
-def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
-    """Gets the C function name of the tensordot function."""
-    return f"tensordot_{in_dtype}_h{tensor_h}_j{jump}_w{tensor_w}_{suffix}"
 
+def _init_biased_accumulators(split_size):
+    """Generates code to load the bias into the accumulators.
 
-def make_intrin_tensordot(slices, strides, tensordot_params):
-    """Helper function for constructing tensordot intrinsic. We can't construct the whole thing here
-    (as multiple schedules use tensordot and each must build the intrinstic differently) but we can
-    build part here to simplify the code."""
+    Addition is commutative, so we could add the bias before, during, or after performing our
+    multiply-accumulate operations. Where we add the bias does not change the overflow behavior.
 
-    # in_dtype, tensor_h, jump, tensor_w, suffix = tensordot_params
-    data, kernel, output = slices
-    data_strides, kernel_strides = strides
+    Doing the bias add takes one cycle either way (if done at the beginning we can't use a SMULXY
+    trick to set sum_i to zero for "free"). However, doing it at the beginning frees up a register,
+    so we'll do it first.
+    """
+    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
+    joined_assignments = ", ".join(assignments)
+    return f"int {joined_assignments};"
 
-    data_buf = tir.decl_buffer(
-        data.shape, data.dtype, name="data", offset_factor=1, strides=data_strides
-    )
-    kernel_buf = tir.decl_buffer(
-        kernel.shape,
-        kernel.dtype,
-        name="kernel",
-        offset_factor=1,
-        strides=kernel_strides,
-    )
-    output_buf = tir.decl_buffer(
-        output.shape, output.dtype, name="output", offset_factor=1, strides=[1]
-    )
 
-    def intrin_func(ins, outs):
-        builder = tir.ir_builder.create()
-        builder.emit(
-            tir.call_extern(
-                "int32",
-                _get_func_name(*tensordot_params),
-                outs[0].access_ptr("w"),
-                ins[0].access_ptr("r"),
-                ins[1].access_ptr("r"),
-            )
-        )
-        return builder.get()
+def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator:
+    tensor_w, kernel_h, kernel_w = dimensions
 
-    return te.decl_tensor_intrin(
-        output.op,
-        intrin_func,
-        binds={data: data_buf, kernel: kernel_buf, output: output_buf},
-    )
+    split_max = (split_size - 1) * in_stride
+    for y in range(kernel_h):
+        if y * tensor_w % 2 + offset == 1:
+            yield None
+        for x in range(kernel_w + split_max):
+            yield (y, x)
+        if (y * tensor_w + kernel_w + split_max + offset) % 2 == 1:
+            yield None
+
+
+def _get_kernel_halfwords(dimensions, offset) -> Iterator:
+    _, kernel_h, kernel_w = dimensions
+    if offset == 1:
+        yield None
+    for y in range(kernel_h):
+        for x in range(kernel_w):
+            yield (y, x)
+    if (kernel_h * kernel_w + offset) % 2 == 1:
+        yield None
+
+
+def _get_int16_alias(position) -> str:
+    if not position:
+        return "unknown"
+    y, x = position
+    return f"y{y:0>2x}_x{x:0>2x}"
+
+
+def _load_tensor_vars(halfwords, tensor_w) -> Iterator[str]:
+    assert len(halfwords) % 2 == 0
+    offset = int(not bool(halfwords[0]))
+
+    for i in range(0, len(halfwords), 2):
+        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))
+        y, x = halfwords[i + 1] or halfwords[i]
+        tensor_index = (y * tensor_w + x + offset) // 2
+        yield f"int tensor__{var_name} = tensor[{tensor_index}];"
+
+
+def _load_kernel_vars(halfwords) -> Iterator[str]:
+    assert len(halfwords) % 2 == 0
+    for i in range(0, len(halfwords), 2):
+        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))
+        yield f"int kernel__{var_name} = kernel[{i // 2}];"
 
 
-def tensordot_impl(in_dtype: str, tensor_h: int, jump: int, tensor_w: int, suffix: str) -> str:
-    """Generates C code for taking the dot products of two `tensor_h` * `tensor_w` tensors. Also has
-    a `jump` argument that advances the pointer of one tensor by that many words after each row. The
-    `jump` and `tensor_w` values must be word-aligned for the input data type, as non-word-aligned
-    memory access is slow on the Cortex-M series. Depending on the input datatype, the code may
-    contain DSP instructions for Arm v7e-m. C code contains DSP instructions for Arm v7e-m. See
-    the below pseudocode for reference:
-
-    tensordot(out_ptr, dat_ptr, ker_ptr) {
-        sum = 0;
-        for (i = 0; i < tensor_h; i++) {
-            for (j = 0; j < tensor_w; j++) {
-                sum += (*dat_ptr++) * (*ker_ptr++);
-            }
-            dat_ptr += jump;
-        }
-        *out_ptr = sum;
-    }
+def _get_draft_macs(kernel_dims, tensor_halfwords, kernel_halfwords, offset) -> Iterator[Tuple]:
+    """Generates an un-optimized list of multiply-accumulate instructions.
+
+    We will optimize these into SIMD instructions later. The tuples in the returned iterator are
+    organized as:
+
+    (instruction, (arg1_y, arg1_x), (arg2_y, arg2_x))
+
+    We return an iterator so that optimizations may be done by iterator chaining.
+    """
+
+    def get_var(y, x, halfwords):
+        i = halfwords.index((y, x))
+        if i % 2 == 0:
+            return f"{_get_int16_alias((y, x))}__{_get_int16_alias(halfwords[i + 1])}", "b"
+        return f"{_get_int16_alias(halfwords[i - 1])}__{_get_int16_alias((y, x))}", "t"
+
+    kernel_h, kernel_w = kernel_dims
+    for y in range(kernel_h):
+        for x in range(kernel_w):
+            tensor_var, tensor_half = get_var(y, x + offset, tensor_halfwords)
+            kernel_var, kernel_half = get_var(y, x, kernel_halfwords)
+            instruction = f"smla{tensor_half}{kernel_half}"
+            yield instruction, f"tensor__{tensor_var}", f"kernel__{kernel_var}"
+
+
+def _apply_simd_optimizations(instruction_tuples) -> Iterator[Tuple]:
+    """When possible, fuses single MACs into SIMD MAC instructions.
+
+    The compiler cannot do this automatically, as calling __builtin_arm_smlaxy forces the SMLAxy
+    instruction to be used. This function takes as input an iterator of (instruction, var1, var2)
+    tuples, and returns an iterator of (instruction, var1, var2) tuples.
     """
+    curr_tuple = next(instruction_tuples, None)
+    while curr_tuple:
+        next_tuple = next(instruction_tuples, None)
+        if not next_tuple:
+            yield curr_tuple
+            break
 
-    simd_lanes = num_simd_lanes_per_word(in_dtype)
-    assert tensor_w % simd_lanes == 0
-    assert jump % simd_lanes == 0
+        if curr_tuple[1:] == next_tuple[1:]:
+            if set([curr_tuple[0], next_tuple[0]]) == set(["smlatt", "smlabb"]):
+                yield ("smlad", *curr_tuple[1:])
+                next_tuple = next(instruction_tuples, None)
+            elif set([curr_tuple[0], next_tuple[0]]) == set(["smlatb", "smlabt"]):
+                yield ("smladx", *curr_tuple[1:])
+                next_tuple = next(instruction_tuples, None)
+            else:
+                yield curr_tuple
 
-    if in_dtype == "int8":
-        inner_loop = """
-              uint32_t tensor_c20 = __SXTB16(tensor_batch);
-              uint32_t kernel_c20 = __SXTB16(kernel_batch);
-              sum = __SMLAD(tensor_c20, kernel_c20, sum);
+        else:
+            yield curr_tuple
+        curr_tuple = next_tuple
 
-              uint32_t tensor_c31 = __SXTB16(__ROR(tensor_batch, 8));
-              uint32_t kernel_c31 = __SXTB16(__ROR(kernel_batch, 8));
-              sum = __SMLAD(tensor_c31, kernel_c31, sum);"""
 
-    elif in_dtype == "int16":
-        inner_loop = """
-              sum = __SMLAD(tensor_batch, kernel_batch, sum);"""
+def _expand_instruction_tuples(instruction_tuples, index) -> Iterator[str]:
+    """Converts a series of (instruction, var1, var2) tuples into lines of C code.
 
-    elif in_dtype == "int32":
-        inner_loop = """
-              // Compiles to a single MAC instruction
-              sum += tensor_batch * kernel_batch;"""
+    We want the compiler to re-order these with the memory loads, so we generate them as a series of
+    calls to instruction aliases instead of as a single `asm` block.
+    """
+
+    for instruction, op1, op2 in instruction_tuples:
+        assert "smla" in instruction
+
+        # Arm GCC does not have `__builtin_arm_smlabt`, even though `__builtin_arm_smlatt`,
+        # `__builtin_arm_smlatb`, `__builtin_arm_smlad` and so on all exist. Perhaps this is a
+        # choice, since we can just use `smlatb` with the argument order swapped instead? Note that
+        # `__builtin_arm_smlabt` exists on most compilers (e.g. Clang) - this is just a GCC thing.
+        if instruction == "smlabt":
+            yield f"sum_{index} = __builtin_arm_smlatb({op2}, {op1}, sum_{index});"
+        else:
+            yield f"sum_{index} = __builtin_arm_{instruction}({op1}, {op2}, sum_{index});"
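
For orientation, each tuple yielded above becomes one line of C in the generated kernel. Assuming zero offsets and accumulator index 0, a fused ("smlad", ...) tuple covering the first two kernel positions would come out roughly as follows (illustrative sketch, not captured output from the PR):

    sum_0 = __builtin_arm_smlad(tensor__y00_x00__y00_x01, kernel__y00_x00__y00_x01, sum_0);
    /* "smlabt" tuples are emitted with the operands swapped into __builtin_arm_smlatb. */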

Review Comment:
   I believe this is because you're using the builtins directly rather than the ACLE interface
   (https://arm-software.github.io/acle/main/acle.html#accumulating-multiplications). I'm unsure how
   much guarantee you get with built-ins, so I would move to the ACLE interface anyway.
   
   Also see: https://github.com/gcc-mirror/gcc/blob/master/gcc/config/arm/arm_acle.h#L661-L675 😸 
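
For reference, the ACLE interface pointed to above exposes these operations as ordinary functions in arm_acle.h. A minimal sketch of what the generated code could call instead of the raw builtins, assuming the target defines __ARM_FEATURE_DSP and __ARM_FEATURE_SIMD32 (the function and variable names here are illustrative, not taken from the PR):

    #include <arm_acle.h>
    #include <stdint.h>

    /* Accumulate products of packed int16 halfwords using the ACLE names
       rather than __builtin_arm_* directly. */
    int32_t acle_mac_example(int32_t tensor_word, int32_t kernel_word, int32_t acc)
    {
        acc = __smlad(tensor_word, kernel_word, acc);  /* lo*lo + hi*hi + acc */
        acc = __smlatb(tensor_word, kernel_word, acc); /* tensor hi * kernel lo + acc */
        acc = __smlabt(tensor_word, kernel_word, acc); /* tensor lo * kernel hi + acc;
                                                          the GCC header provides this even
                                                          though __builtin_arm_smlabt is absent */
        return acc;
    }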
   
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org