You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/23 10:23:04 UTC

[GitHub] [incubator-tvm] mbaret commented on a change in pull request #6445: Add dot product support for quantized convolution.

mbaret commented on a change in pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#discussion_r493346458



##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -135,20 +135,29 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     name="conv2d_direct_simd.micro_dev",
                 )
             elif kernel_layout == "HWIO":
-                is_aarch64 = "aarch64" in str(isa.target)
-
+                is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
+                has_dot_prod = topi.arm_cpu.arm_utils.is_fast_int8_on_arm()
+                if has_dot_prod and data.dtype in ["int8", "uint8"]:
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_hybrid),

Review comment:
       "hybrid" typically refers to something implemented using hybrid script; you might want a different name.

##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -135,20 +135,29 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     name="conv2d_direct_simd.micro_dev",
                 )
             elif kernel_layout == "HWIO":
-                is_aarch64 = "aarch64" in str(isa.target)
-
+                is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
+                has_dot_prod = topi.arm_cpu.arm_utils.is_fast_int8_on_arm()
+                if has_dot_prod and data.dtype in ["int8", "uint8"]:
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_hybrid),
+                        wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_hybrid),
+                        name="conv2d_NHWC_quantized_hybrid.arm_cpu",
+                    )
                 if is_aarch64 and data.dtype in ["int8", "uint8"]:
                     strategy.add_implementation(
                         wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized),
                         wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
                         name="conv2d_NHWC_quantized.arm_cpu",
                     )
-
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack),
-                    wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
-                    name="conv2d_nhwc_spatial_pack.arm_cpu",
-                )
+                if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]):
+                    # TODO

Review comment:
       Please assign the TODO to an owner.

##########
File path: python/tvm/topi/arm_cpu/arm_utils.py
##########
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Arm target utility functions"""
+
+import tvm
+
+
+def is_fast_int8_on_arm():
+    """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
+    target = tvm.target.Target.current(allow_none=False)
+    return "+v8.2a" in target.mattr and "+dotprod" in target.mattr

Review comment:
       @u99127 Could you take a look at whether this check will be sufficient in the general (and future) case?

##########
File path: python/tvm/topi/arm_cpu/conv2d_alter_op.py
##########
@@ -27,10 +27,52 @@
 from ..nn import conv2d_alter_layout
 from ..util import get_const_tuple
 from ..x86.conv2d import _get_default_config as _get_x86_default_config
+from .arm_utils import is_fast_int8_on_arm
 
 logger = logging.getLogger("topi")
 
 
+def interleave_transpose_B(inputs, data, kernel, interleave_A):

Review comment:
       Document the parameters here

##########
File path: python/tvm/topi/arm_cpu/conv2d_alter_op.py
##########
@@ -280,43 +322,36 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         dispatch_ctx.update(target, new_workload, cfg)
         return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs)
     if topi_tmpl == "conv2d_NHWC_quantized.arm_cpu":
-        assert (
-            data.dtype == "int8"
-            and kernel.dtype == "int8"
-            or data.dtype == "uint8"
-            and kernel.dtype == "uint8"
-        )
         assert data_layout == "NHWC" and kernel_layout == "HWIO"
         KH, KW, IC, OC = get_const_tuple(kernel.shape)
-        K = KH * KW * IC
         N = OC
-
-        tile_rows = 4
-        tile_cols = 16
-        pad_K = 0
-        pad_N = 0
-
-        if N % tile_rows != 0:
-            pad_N = tile_rows - (N % tile_rows)
-        if K % tile_cols != 0:
-            pad_K = tile_cols - (K % tile_cols)
-
-        N_padded = N + pad_N
-        K_padded = K + pad_K
-        kernel_expr = relay.nn.contrib_conv2d_gemm_weight_transform(inputs[1], tile_rows, tile_cols)
-        new_kernel = te.placeholder(
-            (N_padded // tile_rows, K_padded // tile_cols, tile_rows, tile_cols), kernel.dtype
-        )
-
         new_workload_name = "conv2d_NHWC_quantized_without_transform.arm_cpu"
+        new_kernel, new_kernel_expr = interleave_transpose_B(
+            inputs, data, kernel, interleave_A=True
+        )
         new_workload = autotvm.task.args_to_workload(
             [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC],
             new_workload_name,
         )
         dispatch_ctx.update(target, new_workload, cfg)
 
         return relay.nn.contrib_conv2d_gemm_without_weight_transform(
-            inputs[0], kernel_expr, **new_attrs
+            inputs[0], new_kernel_expr, **new_attrs
+        )
+    if topi_tmpl == "conv2d_NHWC_quantized_hybrid.arm_cpu":
+        assert data_layout == "NHWC" and kernel_layout == "HWIO"
+        KH, KW, IC, OC = get_const_tuple(kernel.shape)
+        N = OC

Review comment:
       Not needed

##########
File path: python/tvm/topi/arm_cpu/conv2d_alter_op.py
##########
@@ -27,10 +27,52 @@
 from ..nn import conv2d_alter_layout
 from ..util import get_const_tuple
 from ..x86.conv2d import _get_default_config as _get_x86_default_config
+from .arm_utils import is_fast_int8_on_arm
 
 logger = logging.getLogger("topi")
 
 
+def interleave_transpose_B(inputs, data, kernel, interleave_A):
+    """Return the new placeholder and the expression that represent
+    the matrix B transposed and interleaved"""
+
+    assert (
+        data.dtype == "int8"
+        and kernel.dtype == "int8"
+        or data.dtype == "uint8"
+        and kernel.dtype == "uint8"
+    )
+
+    KH, KW, IC, OC = get_const_tuple(kernel.shape)
+    K = KH * KW * IC
+    N = OC
+
+    if is_fast_int8_on_arm():
+        tile_rows_B = 12 if interleave_A else 16
+        tile_cols_B = 4
+    else:
+        tile_rows_B = 4
+        tile_cols_B = 16

Review comment:
       Some additional documentation here to explain the difference would be helpful

##########
File path: python/tvm/topi/arm_cpu/arm_utils.py
##########
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Arm target utility functions"""
+
+import tvm
+
+
+def is_fast_int8_on_arm():

Review comment:
       Recommend making this more explicitly related to the dot product.

##########
File path: python/tvm/topi/arm_cpu/conv2d_gemm.py
##########
@@ -74,11 +108,7 @@ def compute_conv2d_gemm_without_weight_transform(
 
     A_shape = (batches, M, K)
     if K_AREA == 1:

Review comment:
       This would clash less with later definitions if it were named `kernel_area`.

##########
File path: python/tvm/topi/arm_cpu/conv2d_gemm.py
##########
@@ -90,82 +120,100 @@ def compute_conv2d_gemm_without_weight_transform(
             ],
             name="data_im2col",
         )
-    N_transformed = B_interleaved_t.shape[0]
 
     # --- Pad if necessary
-    idxm = tvm.tir.indexmod
+    N_transformed = B_interleaved_t.shape[0]
+    tile_rows_B = B_interleaved_t.shape[2]
+    tile_cols_B = B_interleaved_t.shape[3]
+
+    if is_fast_int8_on_arm() and interleave_A:
+        tile_rows_A = 8
+        tile_cols_A = 4
+    else:
+        tile_rows_A = 4
+        tile_cols_A = 16
 
     pad_m = 0
     pad_k = 0
 
-    if M % 4 != 0:
-        pad_m = 4 - (M % 4)
+    if M % tile_rows_A != 0:
+        pad_m = tile_rows_A - (M % tile_rows_A)
 
-    if K % 16 != 0:
-        pad_k = 16 - (K % 16)
+    if K % tile_cols_A != 0:
+        pad_k = tile_cols_A - (K % tile_cols_A)
 
     M_padded = M + pad_m
     K_padded = K + pad_k
+    N_padded = N_transformed * tile_rows_B
 
     pad_before = (0, 0, 0)
     pad_after = (0, pad_m, pad_k)
 
     if pad_m != 0 or pad_k != 0:
         A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")
 
-    # --- GEMM: A*B'
+    idxm = tvm.tir.indexmod
     k = te.reduce_axis((0, K_padded), "k")
 
-    A_interleaved = te.compute(
-        (batches, M_padded // 4, K_padded // 16, 4, 16),
-        lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y],
-        name="A_interleaved",
-    )
-
-    C_interleaved = te.compute(
-        (batches, M_padded // 4, N_transformed, 4, 4),
-        lambda b, x, y, w, z: te.sum(
-            A_interleaved[b, x, k // 16, w, idxm(k, 16)].astype(out_dtype)
-            * B_interleaved_t[y, k // 16, z, idxm(k, 16)].astype(out_dtype),
-            axis=k,
-        ),
-        name="C_interleaved",
-    )
+    if interleave_A:
+        # Configuration space
+        configure_knobs(cfg, M_padded, K_padded)
 
-    # --- Unpack C
-    C = te.compute(
-        (batches, M, N),
-        lambda b, x, y: C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
-        name="C",
-    )
+        # Pack A
+        A_interleaved = te.compute(
+            (batches, M_padded // tile_rows_A, K_padded // tile_cols_A, tile_rows_A, tile_cols_A),
+            lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
+            name="A_interleaved",
+        )
+        # Compute C

Review comment:
       Mention this is doing the GEMM

##########
File path: python/tvm/topi/arm_cpu/conv2d_gemm.py
##########
@@ -90,82 +120,100 @@ def compute_conv2d_gemm_without_weight_transform(
             ],
             name="data_im2col",
         )
-    N_transformed = B_interleaved_t.shape[0]
 
     # --- Pad if necessary
-    idxm = tvm.tir.indexmod
+    N_transformed = B_interleaved_t.shape[0]
+    tile_rows_B = B_interleaved_t.shape[2]
+    tile_cols_B = B_interleaved_t.shape[3]
+
+    if is_fast_int8_on_arm() and interleave_A:
+        tile_rows_A = 8
+        tile_cols_A = 4
+    else:
+        tile_rows_A = 4
+        tile_cols_A = 16
 
     pad_m = 0
     pad_k = 0
 
-    if M % 4 != 0:
-        pad_m = 4 - (M % 4)
+    if M % tile_rows_A != 0:
+        pad_m = tile_rows_A - (M % tile_rows_A)
 
-    if K % 16 != 0:
-        pad_k = 16 - (K % 16)
+    if K % tile_cols_A != 0:
+        pad_k = tile_cols_A - (K % tile_cols_A)
 
     M_padded = M + pad_m
     K_padded = K + pad_k
+    N_padded = N_transformed * tile_rows_B
 
     pad_before = (0, 0, 0)
     pad_after = (0, pad_m, pad_k)
 
     if pad_m != 0 or pad_k != 0:
         A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")
 
-    # --- GEMM: A*B'
+    idxm = tvm.tir.indexmod
     k = te.reduce_axis((0, K_padded), "k")
 
-    A_interleaved = te.compute(
-        (batches, M_padded // 4, K_padded // 16, 4, 16),
-        lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y],
-        name="A_interleaved",
-    )
-
-    C_interleaved = te.compute(
-        (batches, M_padded // 4, N_transformed, 4, 4),
-        lambda b, x, y, w, z: te.sum(
-            A_interleaved[b, x, k // 16, w, idxm(k, 16)].astype(out_dtype)
-            * B_interleaved_t[y, k // 16, z, idxm(k, 16)].astype(out_dtype),
-            axis=k,
-        ),
-        name="C_interleaved",
-    )
+    if interleave_A:
+        # Configuration space
+        configure_knobs(cfg, M_padded, K_padded)
 
-    # --- Unpack C
-    C = te.compute(
-        (batches, M, N),
-        lambda b, x, y: C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
-        name="C",
-    )
+        # Pack A
+        A_interleaved = te.compute(
+            (batches, M_padded // tile_rows_A, K_padded // tile_cols_A, tile_rows_A, tile_cols_A),
+            lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
+            name="A_interleaved",
+        )
+        # Compute C
+        C_interleaved = te.compute(
+            (batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B),
+            lambda b, x, y, w, z: te.sum(
+                A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32")
+                * B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"),
+                axis=k,
+            ),
+            name="C_interleaved",
+        )
+        # Unpack C
+        C = te.compute(
+            (batches, M, N),
+            lambda b, x, y: C_interleaved[
+                b, x // tile_rows_A, y // tile_rows_B, idxm(x, tile_rows_A), idxm(y, tile_rows_B)
+            ].astype(out_dtype),
+            name="C",
+        )
+        zero = tvm.tir.const(0)
+    else:
+        # No need to pack/unpack
+        C = te.compute(
+            (batches, M_padded, N_padded),
+            lambda b, x, y: te.sum(
+                A[b, x, k].astype("int32")
+                * B_interleaved_t[
+                    y // tile_rows_B, k // tile_cols_B, idxm(y, tile_rows_B), idxm(k, tile_cols_B)
+                ].astype("int32"),
+                axis=k,
+            ),
+            name="C",
+        )
+        zero = (
+            tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
+            - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
+        )
 
     # --- Produce the conv output
     out_shape = (batches, OH, OW, OC)
-    out = te.compute(out_shape, lambda b, x, y, z: C(b, y + OW * x, z), name="conv2d_gemm_output")
-
-    # Configuration space
-    x, y = cfg.axis(M_padded // 4), cfg.axis(K_padded // 16)
-    cfg.define_reorder("reorder_gemm", [x, y], policy="candidate", candidate=[[x, y], [y, x]])
-
-    outer_loop, inner_loop = cfg.axis(4), cfg.axis(16)
-    cfg.define_annotate(
-        "A_interleaved_unroll_vec", [outer_loop, inner_loop], policy="try_unroll_vec"
+    out = te.compute(
+        out_shape,
+        lambda b, x, y, z: (C(b, y + OW * x, z) + zero).astype(out_dtype),
+        name="conv2d_gemm_output",
     )
-    cfg.define_knob("gemm_quantized_unroll", [True, False])
-    cfg.define_knob("gemm_quantized_interleave", [True, False])
-
-    # Fallback configuration
-    if cfg.is_fallback:
-        cfg["reorder_gemm"] = ReorderEntity([0, 1])
-        cfg["A_interleaved_unroll_vec"] = AnnotateEntity(["unroll", "vec"])
-        cfg["gemm_quantized_unroll"] = OtherOptionEntity(False)
-        cfg["gemm_quantized_interleave"] = OtherOptionEntity(True)
     return out
 
 
-# Schedules
 def schedule_conv2d_gemm(cfg, s, out, final_out):

Review comment:
       Rename to schedule_conv2d_gemm_interleaved

##########
File path: python/tvm/topi/arm_cpu/conv2d_gemm.py
##########
@@ -90,82 +120,100 @@ def compute_conv2d_gemm_without_weight_transform(
             ],
             name="data_im2col",
         )
-    N_transformed = B_interleaved_t.shape[0]
 
     # --- Pad if necessary
-    idxm = tvm.tir.indexmod
+    N_transformed = B_interleaved_t.shape[0]
+    tile_rows_B = B_interleaved_t.shape[2]
+    tile_cols_B = B_interleaved_t.shape[3]
+
+    if is_fast_int8_on_arm() and interleave_A:
+        tile_rows_A = 8
+        tile_cols_A = 4
+    else:
+        tile_rows_A = 4
+        tile_cols_A = 16
 
     pad_m = 0
     pad_k = 0
 
-    if M % 4 != 0:
-        pad_m = 4 - (M % 4)
+    if M % tile_rows_A != 0:
+        pad_m = tile_rows_A - (M % tile_rows_A)
 
-    if K % 16 != 0:
-        pad_k = 16 - (K % 16)
+    if K % tile_cols_A != 0:
+        pad_k = tile_cols_A - (K % tile_cols_A)
 
     M_padded = M + pad_m
     K_padded = K + pad_k
+    N_padded = N_transformed * tile_rows_B
 
     pad_before = (0, 0, 0)
     pad_after = (0, pad_m, pad_k)
 
     if pad_m != 0 or pad_k != 0:
         A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")
 
-    # --- GEMM: A*B'
+    idxm = tvm.tir.indexmod
     k = te.reduce_axis((0, K_padded), "k")
 
-    A_interleaved = te.compute(
-        (batches, M_padded // 4, K_padded // 16, 4, 16),
-        lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y],
-        name="A_interleaved",
-    )
-
-    C_interleaved = te.compute(
-        (batches, M_padded // 4, N_transformed, 4, 4),
-        lambda b, x, y, w, z: te.sum(
-            A_interleaved[b, x, k // 16, w, idxm(k, 16)].astype(out_dtype)
-            * B_interleaved_t[y, k // 16, z, idxm(k, 16)].astype(out_dtype),
-            axis=k,
-        ),
-        name="C_interleaved",
-    )
+    if interleave_A:
+        # Configuration space
+        configure_knobs(cfg, M_padded, K_padded)
 
-    # --- Unpack C
-    C = te.compute(
-        (batches, M, N),
-        lambda b, x, y: C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
-        name="C",
-    )
+        # Pack A
+        A_interleaved = te.compute(
+            (batches, M_padded // tile_rows_A, K_padded // tile_cols_A, tile_rows_A, tile_cols_A),
+            lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
+            name="A_interleaved",
+        )
+        # Compute C
+        C_interleaved = te.compute(
+            (batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B),
+            lambda b, x, y, w, z: te.sum(
+                A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32")
+                * B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"),
+                axis=k,
+            ),
+            name="C_interleaved",
+        )
+        # Unpack C
+        C = te.compute(
+            (batches, M, N),
+            lambda b, x, y: C_interleaved[
+                b, x // tile_rows_A, y // tile_rows_B, idxm(x, tile_rows_A), idxm(y, tile_rows_B)
+            ].astype(out_dtype),
+            name="C",
+        )
+        zero = tvm.tir.const(0)
+    else:
+        # No need to pack/unpack
+        C = te.compute(
+            (batches, M_padded, N_padded),
+            lambda b, x, y: te.sum(
+                A[b, x, k].astype("int32")
+                * B_interleaved_t[
+                    y // tile_rows_B, k // tile_cols_B, idxm(y, tile_rows_B), idxm(k, tile_cols_B)
+                ].astype("int32"),
+                axis=k,
+            ),
+            name="C",
+        )
+        zero = (
+            tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
+            - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
+        )
 
     # --- Produce the conv output

Review comment:
       Document why this is different to the previous unpacking step.

##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -135,20 +135,29 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     name="conv2d_direct_simd.micro_dev",
                 )
             elif kernel_layout == "HWIO":
-                is_aarch64 = "aarch64" in str(isa.target)
-
+                is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
+                has_dot_prod = topi.arm_cpu.arm_utils.is_fast_int8_on_arm()
+                if has_dot_prod and data.dtype in ["int8", "uint8"]:
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_hybrid),
+                        wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_hybrid),
+                        name="conv2d_NHWC_quantized_hybrid.arm_cpu",
+                    )
                 if is_aarch64 and data.dtype in ["int8", "uint8"]:
                     strategy.add_implementation(
                         wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized),

Review comment:
       Maybe change to quantized_interleaved to align with later definitions.

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,223 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(i) = v[i:i+3]
+    replicated_v(i) = [vsub(i), vsub(i), vsub(i), vsub(i)]
+
+    Note that i can vary between 0 and 3

Review comment:
       0 <= i <= 12

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,223 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):

Review comment:
       Document params

##########
File path: python/tvm/topi/arm_cpu/conv2d_gemm.py
##########
@@ -90,82 +120,100 @@ def compute_conv2d_gemm_without_weight_transform(
             ],
             name="data_im2col",
         )
-    N_transformed = B_interleaved_t.shape[0]
 
     # --- Pad if necessary
-    idxm = tvm.tir.indexmod
+    N_transformed = B_interleaved_t.shape[0]
+    tile_rows_B = B_interleaved_t.shape[2]
+    tile_cols_B = B_interleaved_t.shape[3]
+
+    if is_fast_int8_on_arm() and interleave_A:
+        tile_rows_A = 8
+        tile_cols_A = 4
+    else:
+        tile_rows_A = 4
+        tile_cols_A = 16

Review comment:
       Comment on magic numbers

##########
File path: python/tvm/topi/arm_cpu/conv2d_int8.py
##########
@@ -41,6 +46,20 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype):
         )
 
 
+def get_tiling_B(interleave_A):

Review comment:
       Move to utils and use in alter_op_layout

##########
File path: python/tvm/topi/arm_cpu/conv2d_gemm.py
##########
@@ -90,82 +120,100 @@ def compute_conv2d_gemm_without_weight_transform(
             ],
             name="data_im2col",
         )
-    N_transformed = B_interleaved_t.shape[0]
 
     # --- Pad if necessary
-    idxm = tvm.tir.indexmod
+    N_transformed = B_interleaved_t.shape[0]
+    tile_rows_B = B_interleaved_t.shape[2]
+    tile_cols_B = B_interleaved_t.shape[3]
+
+    if is_fast_int8_on_arm() and interleave_A:
+        tile_rows_A = 8
+        tile_cols_A = 4
+    else:
+        tile_rows_A = 4
+        tile_cols_A = 16
 
     pad_m = 0
     pad_k = 0

Review comment:
       Maybe capitalize `m` and `k` here to match the later convention.

##########
File path: python/tvm/topi/arm_cpu/conv2d_gemm.py
##########
@@ -90,82 +120,100 @@ def compute_conv2d_gemm_without_weight_transform(
             ],
             name="data_im2col",
         )
-    N_transformed = B_interleaved_t.shape[0]
 
     # --- Pad if necessary
-    idxm = tvm.tir.indexmod
+    N_transformed = B_interleaved_t.shape[0]
+    tile_rows_B = B_interleaved_t.shape[2]
+    tile_cols_B = B_interleaved_t.shape[3]
+
+    if is_fast_int8_on_arm() and interleave_A:
+        tile_rows_A = 8
+        tile_cols_A = 4
+    else:
+        tile_rows_A = 4
+        tile_cols_A = 16
 
     pad_m = 0
     pad_k = 0
 
-    if M % 4 != 0:
-        pad_m = 4 - (M % 4)
+    if M % tile_rows_A != 0:
+        pad_m = tile_rows_A - (M % tile_rows_A)
 
-    if K % 16 != 0:
-        pad_k = 16 - (K % 16)
+    if K % tile_cols_A != 0:
+        pad_k = tile_cols_A - (K % tile_cols_A)
 
     M_padded = M + pad_m
     K_padded = K + pad_k
+    N_padded = N_transformed * tile_rows_B
 
     pad_before = (0, 0, 0)
     pad_after = (0, pad_m, pad_k)
 
     if pad_m != 0 or pad_k != 0:
         A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")
 
-    # --- GEMM: A*B'
+    idxm = tvm.tir.indexmod
     k = te.reduce_axis((0, K_padded), "k")
 
-    A_interleaved = te.compute(
-        (batches, M_padded // 4, K_padded // 16, 4, 16),
-        lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y],
-        name="A_interleaved",
-    )
-
-    C_interleaved = te.compute(
-        (batches, M_padded // 4, N_transformed, 4, 4),
-        lambda b, x, y, w, z: te.sum(
-            A_interleaved[b, x, k // 16, w, idxm(k, 16)].astype(out_dtype)
-            * B_interleaved_t[y, k // 16, z, idxm(k, 16)].astype(out_dtype),
-            axis=k,
-        ),
-        name="C_interleaved",
-    )
+    if interleave_A:
+        # Configuration space
+        configure_knobs(cfg, M_padded, K_padded)
 
-    # --- Unpack C
-    C = te.compute(
-        (batches, M, N),
-        lambda b, x, y: C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
-        name="C",
-    )
+        # Pack A
+        A_interleaved = te.compute(
+            (batches, M_padded // tile_rows_A, K_padded // tile_cols_A, tile_rows_A, tile_cols_A),
+            lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
+            name="A_interleaved",
+        )
+        # Compute C
+        C_interleaved = te.compute(
+            (batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B),
+            lambda b, x, y, w, z: te.sum(
+                A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32")
+                * B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"),
+                axis=k,
+            ),
+            name="C_interleaved",
+        )
+        # Unpack C
+        C = te.compute(
+            (batches, M, N),
+            lambda b, x, y: C_interleaved[
+                b, x // tile_rows_A, y // tile_rows_B, idxm(x, tile_rows_A), idxm(y, tile_rows_B)
+            ].astype(out_dtype),
+            name="C",
+        )
+        zero = tvm.tir.const(0)
+    else:
+        # No need to pack/unpack
+        C = te.compute(
+            (batches, M_padded, N_padded),
+            lambda b, x, y: te.sum(
+                A[b, x, k].astype("int32")
+                * B_interleaved_t[
+                    y // tile_rows_B, k // tile_cols_B, idxm(y, tile_rows_B), idxm(k, tile_cols_B)
+                ].astype("int32"),
+                axis=k,
+            ),
+            name="C",
+        )
+        zero = (

Review comment:
       Document this trick




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org