Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/10 17:00:46 UTC

[GitHub] [incubator-tvm] giuseros opened a new pull request #6445: Add dot product support for quantized convolution.

giuseros opened a new pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445


   ### High level description of the submission
   We added two new intrinsics in `topi/arm_cpu/tensor_intrin.py`, namely:
   - `mmla4x4`: computes a matrix multiplication between tile `A(4,4)` and tile
     `B(4,4)`
   - `mmla16x4`: computes a matrix multiplication between tile `A(rows,4)` and tile
     `B(4,16)`
   We then used those intrinsics in two separate strategies. We added the
   strategies in `topi/arm_cpu/conv2d_int8.py` and implemented the schedules
   in `topi/arm_cpu/conv2d_gemm.py`. In particular:
   - `schedule_conv2d_gemm`, when accelerated, packs matrix `A`, computes the GEMM,
     and unpacks the resulting matrix. It uses the `mmla4x4` intrinsic
   - `schedule_conv2d_gemm_hybrid` doesn't do any packing on `A` and `C`, which
     stay in their native form. It uses the `mmla16x4` intrinsic
   
   Please note that, because of the limitations of `tensorize`, we need to pad
   matrix `A` in both cases (when its dimensions are not a multiple of the tiling
   shape).
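
   To make the intended usage concrete, here is a minimal, illustrative sketch
   (not the schedule code of this PR) of how the 4x4 intrinsic -- defined in the
   diff as `gemm_acc_4x4_int8_int8_int32` -- is meant to be applied: the GEMM loop
   nest is tiled to the intrinsic shape and the inner 4x4 tile is replaced through
   `tensorize`. The GEMM sizes and the split factor below are made-up values, and
   the actual schedules live in `topi/arm_cpu/conv2d_gemm.py`.

   ```python
   import tvm
   from tvm import te
   from tvm.topi.arm_cpu.tensor_intrin import gemm_acc_4x4_int8_int8_int32

   # Hypothetical GEMM sizes, already padded to multiples of the 4x4 tile.
   M, N, K = 16, 4, 4
   A = te.placeholder((M, K), dtype="int8", name="A")
   # B is laid out with its rows transposed, as the intrinsic expects.
   B = te.placeholder((N, K), dtype="int8", name="B")
   k = te.reduce_axis((0, K), name="k")
   C = te.compute(
       (M, N),
       lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
       name="C",
   )

   s = te.create_schedule(C.op)
   # Expose a 4x4 output tile and hand the inner tile to the dot-product intrinsic.
   outer, inner = s[C].split(C.op.axis[0], factor=4)
   s[C].tensorize(inner, gemm_acc_4x4_int8_int8_int32("int8"))
   print(tvm.lower(s, [A, B, C], simple_mode=True))
   ```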
   
   ### RFC
   This PR is based on the following RFC: https://discuss.tvm.apache.org/t/rfc-accelerate-quantized-convolution-through-dot-product/7873
   
   Change-Id: Id0d818d84ffc458c6dad7983fd350a0f3d5db395


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] giuseros commented on pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
giuseros commented on pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#issuecomment-704289187


   Thanks @FrozenGene ! 





[GitHub] [incubator-tvm] giuseros commented on a change in pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
giuseros commented on a change in pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#discussion_r500281070



##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -445,7 +443,7 @@ def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type):
     )
 
     c_buffer = tvm.tir.decl_buffer(
-        C.shape, dtype=out_type, name="c_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+        C.shape, dtype="int32", name="c_buffer", offset_factor=1, strides=[te.var("sc"), 1]

Review comment:
       Hi @FrozenGene, the problem is the following: in quantized `conv2d` we do `conv2d` and then `requantization` (those are two different relay operators). Conv2d goes from `int8 -> int32`, and requantization goes from `int32 -> int8`. So in theory this would work with `out_type`.
   
   However, in some tests (pre-existing my changes, which I run internally) I noticed that they set the (`conv2d`) `out_type` to `int8` (or `uint8`). In this case the intrinsic still needs to produce an `int32` value, and the cast to `int8` (or `uint8`) needs to happen at a later stage.
   
   This change is basically saying: no matter the `out_type`, the intrinsic will produce an `int32` result. If we want the output to be `int8` (which would be wrong, but some tests do it to simplify the testing), the conversion needs to happen later.
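
   To make the dtype flow concrete, here is a small illustrative sketch (plain NumPy
   rather than TVM code; the scale and zero point are made-up values) of the pipeline
   described above: the `int8` inputs are accumulated in `int32`, and the cast back to
   `int8` only happens in the separate requantization step.

   ```python
   import numpy as np

   # int8 operands, stand-ins for the quantized conv2d/GEMM inputs.
   a = np.random.randint(-128, 128, size=(4, 4), dtype=np.int8)
   b = np.random.randint(-128, 128, size=(4, 4), dtype=np.int8)

   # Step 1: the GEMM accumulates in int32 -- this is what the intrinsic
   # produces no matter which out_type was requested.
   acc = a.astype(np.int32) @ b.astype(np.int32).T

   # Step 2: requantization is a later, separate step that brings the int32
   # accumulator back down to int8 (hypothetical scale / zero point).
   scale, zero_point = 0.05, 0
   out = np.clip(np.round(acc * scale) + zero_point, -128, 127).astype(np.int8)
   ```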







[GitHub] [incubator-tvm] giuseros commented on pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
giuseros commented on pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#issuecomment-698874588


   Hi @mbaret,
   Thanks for the review! 
   
   I addressed the comments and added compilation tests to verify the compilation flow with dot-product. 
   





[GitHub] [incubator-tvm] giuseros commented on pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
giuseros commented on pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#issuecomment-703764060


   Hi @mbaret,
   Thank you for the careful review!
   
   @FrozenGene, @anijain2305, should we merge this in?





[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#discussion_r500261911



##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -445,7 +443,7 @@ def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type):
     )
 
     c_buffer = tvm.tir.decl_buffer(
-        C.shape, dtype=out_type, name="c_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+        C.shape, dtype="int32", name="c_buffer", offset_factor=1, strides=[te.var("sc"), 1]

Review comment:
       What is the reason for changing this?







[GitHub] [incubator-tvm] mbaret commented on a change in pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
mbaret commented on a change in pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#discussion_r498880238



##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,287 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Reinterpret back to the original vector type (e.g., int8x16)
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def gemm_acc_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication and accumulation using sdot/udot
+    instructions. This function takes two arrays of int8 datatype
+    -- A[4][4] and B[4][4] and produces a 4x4 matrix
+    which is equal to A*B.
+
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; j < 4; j++){
+                    for (int k = 0; k < 4; k++){
+                        C[i][j] += A[i][k] * B[j][k]
+                    }
+                }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    # This needs to be a variable number of "rows" since TVM
+    # "thinks" I only need to compute one row because of
+    # padding
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+            # Load all the elements of tile A.
+            # vec_a = [a, b, c, d,
+            #          e, f, g, h,
+            #          i, l, m, n,
+            #          o, p, q, r,];
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+
+            # Replicate 4 times the i-th row of A. For instance,
+            # vec_a[0] = [a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,];
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+
+            # Load all the elements of B. Remember that B
+            # is transposed:
+            # vec_b = [0, 4, 8, 12,
+            #          1, 5, 9, 13,
+            #          2, 6, 10, 14,
+            #          3, 7, 11, 15,];
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                # Compute the product between the i-th row of A
+                # and all the rows of B. Remember that sdot/udot
+                # subdivide the 16-element input vectors into groups of 4
+                # and then take the dot product within each group.
+                # The result is stored in a int32x4 register
+                #
+                # For instance, for i=0, we have:
+                # sdot(vec_aa[0], vec_b) = [a*0+b*4+c*8+d*12,
+                #                           a*1+b*5+c*9+d*13,
+                #                           a*2+b*6+c*10+d*14,
+                #                           a*3+b*7+c*11+d*15]
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4",
+                    llvm_intrin,
+                    tvm.tir.const(3, "uint32"),
+                    vec_c,
+                    vec_b,
+                    vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def gemm_acc_nx16_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication and accumulation using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[rows][4] and

Review comment:
       Might be useful to include a note that n = rows?

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
+def gemm_acc_nx16_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication and accumulation using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[rows][4] and
+    B[4][16] and produces a rowsx16 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_16x4_int8_int8_int32(int8 A[rows][16], int8 B[4][16][4], int32 output[rows][16]){
+            for (int i = 0; i < rows; i++){
+                for (int j = 0; j < 16; j++){
+                    for (int k = 0; k < 16; k++){
+                        out[i][j] += A[i][k] * B[k//4][j][k%4]
+                    }
+                }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * The tile size of B is 16x4. Since the reduction variable k moves between 0 and 16
+          we need 4 tiles of B to compute a single row of the output. The first 4 values of
+          k will be fetched from B[0][j][k], the second batch of 4 from B[1][j][k] and so on
+
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    A = te.placeholder((rows, 16), dtype, name="data")
+    B = te.placeholder((4, 16, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+    idxm = tvm.tir.indexmod
+    k = te.reduce_axis((0, 16), name="k")
+    C = te.compute(
+        (rows, 16),
+        lambda i, j: te.sum(
+            A[i, k].astype("int32") * B[k // 4, j, idxm(k, 4)].astype("int32"), axis=k
+        ),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape,
+        dtype,
+        name="bb_buffer",
+        offset_factor=1,
+        strides=[te.var("sb0"), te.var("sb1"), 1],
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, rows):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x16")))
+                return ib.get()
+            # Iterate on the number of rows of the output
+            for k in range(0, rows):
+                # Load 16 elements of A
+                # vec_a = [a, b, c, e, f, g, h, i, l, m, n, o, p, q, r,];

Review comment:
       Avoid 'i' as well I think

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
+            # Iterate on the number of rows of the output
+            for k in range(0, rows):
+                # Load 16 elements of A
+                # vec_a = [a, b, c, e, f, g, h, i, l, m, n, o, p, q, r,];
+                vec_a = ins[0].vload([k, 0], dtype_vec)
+
+                # Iterate over each column of the output

Review comment:
       Revise this comment (output columns = 16)

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
+            # Iterate on the number of rows of the output
+            for k in range(0, rows):
+                # Load 16 elements of A
+                # vec_a = [a, b, c, e, f, g, h, i, l, m, n, o, p, q, r,];
+                vec_a = ins[0].vload([k, 0], dtype_vec)
+
+                # Iterate over each column of the output
+                for j in range(0, 4):
+                    # Accumulate over each of the 4 (16x4) tiles contained in B
+                    for i in range(0, 4):
+                        # As before, replicate a single 4-element group of A
+                        vec_aa = select_word(vec_a, i, dtype_vec)
+                        # Load 4 rows (each rows with 4 elements) from B
+                        # vec_b = [0, 16, 32, 48,

Review comment:
       Show the multiplication between vec_a and vec_b

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
+    .. code-block:: c
+
+        void mmla_16x4_int8_int8_int32(int8 A[rows][16], int8 B[4][16][4], int32 output[rows][16]){
+            for (int i = 0; i < rows; i++){
+                for (int j = 0; j < 16; j++){
+                    for (int k = 0; k < 16; k++){
+                        out[i][j] += A[i][k] * B[k//4][j][k%4]

Review comment:
       Motivate this with reference to the registers.

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
+    # This needs to be a variable number of "rows" since TVM
+    # "thinks" I only need to compute one row because of
+    # padding
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")

Review comment:
       Rename 'name' too

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,287 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def gemm_acc_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication and accumulation using sdot/udot
+    instructions. This function takes two arrays of int8 datatype
+    -- A[4][4] and B[4][4] and produces a 4x4 matrix
+    which is equal to A*B.
+
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    for (int k = 0; k < 4; k++){
+                        C[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    # This needs to be a variable number of "rows" since TVM
+    # "thinks" I only need to compute one row because of
+    # padding
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+            # Load all the elements of tile A.
+            # vec_a = [a, b, c, d,
+            #          e, f, g, h,
+            #          i, l, m, n,
+            #          o, p, q, r,];
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+
+            # Replicate 4 times the i-th row of A. For instance,
+            # vec_a[0] = [a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,];
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+
+            # Load all the elements of B. Remember that B
+            # is transposed:
+            # vec_b = [0, 4, 8, 12,
+            #          1, 5, 9, 13,
+            #          2, 6, 10, 14,
+            #          3, 7, 11, 15,];
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                # Compute the product between the i-th row of A
+                # and all the rows of B. Remember that sdot/udot
+                # subdive the input vectors in 16 elements
+                # and then take the dot product among each group.
+                # The result is stored in a int32x4 register
+                #
+                # For instance, for i=0, we have:
+                # sdot(vec_aa[0], vec_b) = [a*0+b*4+c*8+d*12,
+                #                           a*1+b*5+c*9+d*13,
+                #                           a*2+b*6+c*10+d*14,
+                #                           a*3+b*7+c*11+d*15]
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4",
+                    llvm_intrin,
+                    tvm.tir.const(3, "uint32"),
+                    vec_c,
+                    vec_b,
+                    vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def gemm_acc_nx16_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication and accumulation using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[rows][16] and
+    B[4][16][4] -- and produces a rows x 16 matrix which is equal to A*B.
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_16x4_int8_int8_int32(int8 A[rows][16], int8 B[4][16][4], int32 output[rows][16]){
+            for (int i = 0; i < rows; i++){
+                for (int j = 0; j < 16; j++){
+                    for (int k = 0; k < 16; k++){
+                        out[i][j] += A[i][k] * B[k//4][j][k%4]
+                    }
+                }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * The tile size of B is 16x4. Since the reduction variable k moves between 0 and 16
+          we need 4 tiles of B to compute a single row of the output. The first 4 values of
+          k will be fetched from B[0][j][k], the second batch of 4 from B[1][j][k] and so on
+
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    A = te.placeholder((rows, 16), dtype, name="data")
+    B = te.placeholder((4, 16, 4), dtype, name="kernel")

Review comment:
       Rename 'name' too

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,287 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+    Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def gemm_acc_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication and accumulation using sdot/udot
+    instructions. This function takes two arrays of int8 datatype
+    -- A[4][4] and B[4][4] and produces a 4x4 matrix
+    which is equal to A*B.
+
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; j < 4; j++){
+                    for (int k = 0; k < 4; k++){
+                        C[i][j] += A[i][k] * B[j][k]
+                    }
+                }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    # This needs to be a variable number of "rows" since TVM
+    # "thinks" I only need to compute one row because of
+    # padding
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+            # Load all the elements of tile A.
+            # vec_a = [a, b, c, d,
+            #          e, f, g, h,
+            #          i, l, m, n,
+            #          o, p, q, r,];
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+
+            # Replicate 4 times the i-th row of A. For instance,
+            # vec_a[0] = [a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,];
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+
+            # Load all the elements of B. Remember that B
+            # is transposed:
+            # vec_b = [0, 4, 8, 12,
+            #          1, 5, 9, 13,
+            #          2, 6, 10, 14,
+            #          3, 7, 11, 15,];
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                # Compute the product between the i-th row of A
+                # and all the rows of B. Remember that sdot/udot
+                # subdivide the 16-element input vectors into groups of
+                # 4 elements and then take the dot product within each group.
+                # The result is stored in an int32x4 register.
+                #
+                # For instance, for i=0, we have:
+                # sdot(vec_aa[0], vec_b) = [a*0+b*4+c*8+d*12,
+                #                           a*1+b*5+c*9+d*13,
+                #                           a*2+b*6+c*10+d*14,
+                #                           a*3+b*7+c*11+d*15]
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4",
+                    llvm_intrin,
+                    tvm.tir.const(3, "uint32"),
+                    vec_c,
+                    vec_b,
+                    vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def gemm_acc_nx16_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication and accumulation using sdot/udot instructions

Review comment:
       16x4 -> nx16

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,287 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+    Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def gemm_acc_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication and accumulation using sdot/udot
+    instructions. This function takes two arrays of int8 datatype
+    -- A[4][4] and B[4][4] and produces a 4x4 matrix
+    which is equal to A*B.
+
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; j < 4; j++){
+                    for (int k = 0; k < 4; k++){
+                        C[i][j] += A[i][k] * B[j][k]
+                    }
+                }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    # This needs to be a variable number of "rows" since TVM
+    # "thinks" I only need to compute one row because of
+    # padding
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+            # Load all the elements of tile A.
+            # vec_a = [a, b, c, d,
+            #          e, f, g, h,
+            #          i, l, m, n,
+            #          o, p, q, r,];

Review comment:
       ijk

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,287 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+    Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def gemm_acc_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication and accumulation using sdot/udot
+    instructions. This function takes two arrays of int8 datatype
+    -- A[4][4] and B[4][4] and produces a 4x4 matrix
+    which is equal to A*B.
+
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; j < 4; j++){
+                    for (int k = 0; k < 4; k++){
+                        C[i][j] += A[i][k] * B[j][k]
+                    }
+                }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    # This needs to be a variable number of "rows" since TVM
+    # "thinks" I only need to compute one row because of
+    # padding
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+            # Load all the elements of tile A.
+            # vec_a = [a, b, c, d,
+            #          e, f, g, h,
+            #          i, l, m, n,
+            #          o, p, q, r,];
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+
+            # Replicate 4 times the i-th row of A. For instance,
+            # vec_a[0] = [a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,];
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+
+            # Load all the elements of B. Remember that B
+            # is transposed:
+            # vec_b = [0, 4, 8, 12,
+            #          1, 5, 9, 13,
+            #          2, 6, 10, 14,
+            #          3, 7, 11, 15,];
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                # Compute the product between the i-th row of A
+                # and all the rows of B. Remember that sdot/udot
+                # subdivide the 16-element input vectors into groups of
+                # 4 elements and then take the dot product within each group.
+                # The result is stored in an int32x4 register.
+                #
+                # For instance, for i=0, we have:
+                # sdot(vec_aa[0], vec_b) = [a*0+b*4+c*8+d*12,
+                #                           a*1+b*5+c*9+d*13,
+                #                           a*2+b*6+c*10+d*14,
+                #                           a*3+b*7+c*11+d*15]
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4",
+                    llvm_intrin,
+                    tvm.tir.const(3, "uint32"),
+                    vec_c,
+                    vec_b,
+                    vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def gemm_acc_nx16_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication and accumulation using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[rows][16] and
+    B[4][16][4] -- and produces a rows x 16 matrix which is equal to A*B.
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_16x4_int8_int8_int32(int8 A[rows][16], int8 B[4][16][4], int32 output[rows][16]){
+            for (int i = 0; i < rows; i++){
+                for (int j = 0; j < 16; j++){
+                    for (int k = 0; k < 16; k++){
+                        out[i][j] += A[i][k] * B[k//4][j][k%4]
+                    }
+                }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * The tile size of B is 16x4. Since the reduction variable k moves between 0 and 16
+          we need 4 tiles of B to compute a single row of the output. The first 4 values of
+          k will be fetched from B[0][j][k], the second batch of 4 from B[1][j][k] and so on
+
+    This function returns a TensorIntrin that can be used to tensorize a schedule.

Review comment:
       Remove
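
   For reference, a minimal NumPy sketch of the B[k//4][j][k%4] access pattern described
   in the nx16 docstring above; the function and variable names here are illustrative
   only and are not part of the patch:

       import numpy as np

       def gemm_acc_nx16_ref(A, B_interleaved):
           """A: (rows, 16) int8, B_interleaved: (4, 16, 4) int8 -> (rows, 16) int32."""
           rows = A.shape[0]
           out = np.zeros((rows, 16), dtype="int32")
           for i in range(rows):
               for j in range(16):
                   for k in range(16):
                       # The first 4 values of k come from B_interleaved[0][j][:],
                       # the next 4 from B_interleaved[1][j][:], and so on.
                       out[i, j] += int(A[i, k]) * int(B_interleaved[k // 4, j, k % 4])
           return out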




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] u99127 commented on a change in pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
u99127 commented on a change in pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#discussion_r495302617



##########
File path: python/tvm/topi/arm_cpu/arm_utils.py
##########
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Arm target utility functions"""
+
+import tvm
+
+
+def is_fast_int8_on_arm():
+    """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
+    target = tvm.target.Target.current(allow_none=False)
+    return "+v8.2a" in target.mattr and "+dotprod" in target.mattr

Review comment:
       So, architecturally, IIRC the dot-product instructions are an optional feature in Armv8.2-A and Armv8.3-A and only become mandatory from Armv8.4-A onwards. In theory this needs more careful hierarchical modelling of the architectural features and some cleanup: we would need a notion of mandatory and optional features for each architecture level, and the tests would need to be recast in that form rather than this.
   
   Alternatively, it might be worth checking whether LLVM exposes an API we could use to query architecture features, given the architecture string passed in via mattr.
   
   It may also be worth renaming this function to is_int8_dotproduct_on_arm() rather than a generic name like is_fast_int8_on_arm, especially as there are other instructions, such as SMMLA and UMMLA, that can be pretty quick for int8 arithmetic on AArch64.
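
   A minimal sketch of that renaming, assuming the body stays the same as the existing
   check in python/tvm/topi/arm_cpu/arm_utils.py (the new name is only a suggestion):

       import tvm

       def is_int8_dotproduct_on_arm():
           """True when the current target advertises the Armv8.2-A dot-product
           (+dotprod) extension in its attribute list."""
           target = tvm.target.Target.current(allow_none=False)
           return "+v8.2a" in target.mattr and "+dotprod" in target.mattr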
   
   
   
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] FrozenGene commented on a change in pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on a change in pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#discussion_r500299165



##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -445,7 +443,7 @@ def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type):
     )
 
     c_buffer = tvm.tir.decl_buffer(
-        C.shape, dtype=out_type, name="c_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+        C.shape, dtype="int32", name="c_buffer", offset_factor=1, strides=[te.var("sc"), 1]

Review comment:
       get it




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] giuseros commented on pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
giuseros commented on pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#issuecomment-694884318


   @u99127, I am on it. It is strange that the command I ran locally didn't catch this.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] FrozenGene commented on pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
FrozenGene commented on pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#issuecomment-704285023


   Thanks everyone. Merged now.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] FrozenGene merged pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
FrozenGene merged pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445


   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] giuseros commented on pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
giuseros commented on pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#issuecomment-698874588


   Hi @mbaret,
   Thanks for the review! 
   
   I addressed the comments and added compilation tests to verify the compilation flow with dot-product. 
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] u99127 commented on pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
u99127 commented on pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#issuecomment-694536822


   Can you see why CI is failing, @giuseros?


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] mbaret commented on a change in pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
mbaret commented on a change in pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#discussion_r496745509



##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,236 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def mmla_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[4][4] and
+    B[4][4] and produces a 4x4 matrix which is equal to A*B
+    The pseudo code is as follows.

Review comment:
       This needs to explicitly reference the accumulation step, also in the pseudo-code
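
   Something along these lines, written as a NumPy reference just to illustrate the
   accumulating semantics (C is updated in place rather than zero-initialised by the
   intrinsic body); names are illustrative:

       import numpy as np

       def gemm_acc_4x4_ref(A, B, C):
           """A: (4, 4) int8, B: (4, 4) int8 with rows transposed, C: (4, 4) int32."""
           for i in range(4):
               for j in range(4):
                   for k in range(4):
                       C[i, j] += int(A[i, k]) * int(B[j, k])
           return C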

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,236 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def mmla_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[4][4] and
+    B[4][4] and produces a 4x4 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 output[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    out[i][j] = 0;
+                    for (int k = 0; k < 4; k++){
+                        out[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * Matrix A is interleaved

Review comment:
       Matrix A is not interleaved at this stage

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,236 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def mmla_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[4][4] and
+    B[4][4] and produces a 4x4 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 output[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    out[i][j] = 0;
+                    for (int k = 0; k < 4; k++){
+                        out[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * Matrix A is interleaved
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    data = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    kernel = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(data[i, k].astype("int32") * kernel[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        data.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        kernel.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+

Review comment:
       Could we include some comments here, maybe with an example?
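
   A rough sketch of what such comments could walk through, using a NumPy analogue of
   select_word (illustrative names, not part of the patch): each call picks the lane-th
   group of four int8 values out of the 16-byte vector and replicates it four times, so
   that a single sdot/udot against vec_b yields one output row.

       import numpy as np

       def select_word_ref(vec, lane):
           """vec: (16,) int8 array, 0 <= lane < 4 -> (16,) int8 array."""
           word = vec[4 * lane : 4 * lane + 4]  # the int8x4 "word"
           return np.tile(word, 4)              # replicated four times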

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,236 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def mmla_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[4][4] and
+    B[4][4] and produces a 4x4 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 output[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    out[i][j] = 0;
+                    for (int k = 0; k < 4; k++){
+                        out[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * Matrix A is interleaved
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    data = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    kernel = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(data[i, k].astype("int32") * kernel[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        data.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        kernel.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4", llvm_intrin, tvm.tir.const(3, "uint32"), vec_c, vec_b, vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={data: aa_buffer, kernel: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def mmla_16x4_int8_int8_int32(dtype, rows):

Review comment:
       This needs at least changing to nx16 with a description of how that interacts with the row parameters.
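
   As a hypothetical usage sketch of how the row count interacts with the intrinsic
   (the schedule variable names below are illustrative): the tile height is fixed when
   the TensorIntrin is declared and has to match the extent of the row axis being
   tensorized, e.g.

       gemm_acc = mmla_16x4_int8_int8_int32("int8", rows=4)  # one (4, 16) output tile
       # s[C_interleaved].tensorize(inner_row_axis, gemm_acc)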

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,236 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def mmla_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[4][4] and
+    B[4][4] and produces a 4x4 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 output[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    out[i][j] = 0;
+                    for (int k = 0; k < 4; k++){
+                        out[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * Matrix A is interleaved
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    data = te.placeholder((te.var("rows"), 4), dtype, name="data")

Review comment:
       Using te.var("rows") here instead of just 4 is strange, could you explain?

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,236 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def mmla_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[4][4] and
+    B[4][4] and produces a 4x4 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 output[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    out[i][j] = 0;
+                    for (int k = 0; k < 4; k++){
+                        out[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * Matrix A is interleaved
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    data = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    kernel = te.placeholder((4, 4), dtype, name="kernel")

Review comment:
       I'd advise either a comment explaining how A/B map to data/kernel, or changing these to A/B.

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,236 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def mmla_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[4][4] and
+    B[4][4] and produces a 4x4 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 output[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    out[i][j] = 0;
+                    for (int k = 0; k < 4; k++){
+                        out[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * Matrix A is interleaved
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    data = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    kernel = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(data[i, k].astype("int32") * kernel[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        data.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        kernel.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4", llvm_intrin, tvm.tir.const(3, "uint32"), vec_c, vec_b, vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={data: aa_buffer, kernel: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def mmla_16x4_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[rows][4] and
+    B[4][16] and produces a rowsx16 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_16x4_int8_int8_int32(int8 A[rows][4], int8 B[4][16], int32 output[rows][16]){
+            for (int i = 0; i < rows; i++){
+                for (int j = 0; i < 16; i++){
+                    out[i][j] = 0;
+                    for (int k = 0; k < 4; k++){
+                        out[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }

Review comment:
       indent

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,236 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def mmla_4x4_int8_int8_int32(dtype):

Review comment:
       suggestion mmla -> gemm_acc? This avoids overloading the term mmla which also refers to a hardware intrinsic.

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,236 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def mmla_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[4][4] and
+    B[4][4] and produces a 4x4 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 output[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    out[i][j] = 0;
+                    for (int k = 0; k < 4; k++){
+                        out[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * Matrix A is interleaved
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    data = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    kernel = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(data[i, k].astype("int32") * kernel[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        data.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        kernel.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4", llvm_intrin, tvm.tir.const(3, "uint32"), vec_c, vec_b, vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={data: aa_buffer, kernel: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def mmla_16x4_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[rows][4] and
+    B[4][16] and produces a rowsx16 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_16x4_int8_int8_int32(int8 A[rows][4], int8 B[4][16], int32 output[rows][16]){
+            for (int i = 0; i < rows; i++){
+                for (int j = 0; i < 16; i++){
+                    out[i][j] = 0;
+                    for (int k = 0; k < 4; k++){
+                        out[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * A is not interleaved, but used in its native form

Review comment:
       I think it's better to describe the properties this matrix has now rather than how it was produced.

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,236 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def mmla_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[4][4] and
+    B[4][4] and produces a 4x4 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 output[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; j < 4; j++){
+                    out[i][j] = 0;
+                    for (int k = 0; k < 4; k++){
+                        out[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * Matrix A is interleaved
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    data = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    kernel = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(data[i, k].astype("int32") * kernel[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        data.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        kernel.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4", llvm_intrin, tvm.tir.const(3, "uint32"), vec_c, vec_b, vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={data: aa_buffer, kernel: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def mmla_16x4_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[rows][4] and

Review comment:
       [4] should be [16], change in other places too
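
       For reference, a minimal NumPy sketch of the semantics the docstring describes once the
       inner dimension is corrected to 16 (purely illustrative; the names, shapes and random data
       below are assumptions, not the PR's actual buffer layout):

           import numpy as np

           def mmla_16x4_reference(A, B_t):
               # A: (rows, 16) int8, B_t: (16, 16) int8 with the rows of B already
               # transposed; output: (rows, 16) int32 accumulator.
               rows = A.shape[0]
               out = np.zeros((rows, 16), dtype=np.int32)
               for i in range(rows):
                   for j in range(16):
                       for k in range(16):
                           out[i, j] += np.int32(A[i, k]) * np.int32(B_t[j, k])
               return out

           A = np.random.randint(-128, 128, size=(4, 16), dtype=np.int8)
           B_t = np.random.randint(-128, 128, size=(16, 16), dtype=np.int8)
           # The loop nest is just an int32-accumulating GEMM against B transposed.
           assert np.array_equal(mmla_16x4_reference(A, B_t),
                                 A.astype(np.int32) @ B_t.astype(np.int32).T)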

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,236 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select an int8x4 word within an int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def mmla_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[4][4] and
+    B[4][4] and produces a 4x4 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 output[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; j < 4; j++){
+                    out[i][j] = 0;
+                    for (int k = 0; k < 4; k++){
+                        out[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * Matrix A is interleaved
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    data = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    kernel = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(data[i, k].astype("int32") * kernel[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        data.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        kernel.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4", llvm_intrin, tvm.tir.const(3, "uint32"), vec_c, vec_b, vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={data: aa_buffer, kernel: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def mmla_16x4_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[rows][4] and
+    B[4][16] and produces a rowsx16 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_16x4_int8_int8_int32(int8 A[rows][4], int8 B[4][16], int32 output[rows][16]){
+            for (int i = 0; i < rows; i++){
+                for (int j = 0; j < 16; j++){
+                    out[i][j] = 0;
+                    for (int k = 0; k < 4; k++){
+                        out[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * A is not interleaved, but used in its native form
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    data = te.placeholder((rows, 16), dtype, name="data")
+    kernel = te.placeholder((4, 16, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+    idxm = tvm.tir.indexmod
+    k = te.reduce_axis((0, 16), name="k")
+    C = te.compute(
+        (rows, 16),
+        lambda i, j: te.sum(
+            data[i, k].astype("int32") * kernel[k // 4, j, idxm(k, 4)].astype("int32"), axis=k
+        ),
+        name="C",
+    )

Review comment:
       This needs further explanation.
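
       One possible explanation, inferred from the shapes above rather than taken from the PR:
       the flattened reduction axis k (length 16) is split into k // 4, which picks one of the
       four 4-deep slices of the transposed, interleaved kernel, and k % 4, which picks the
       element inside that slice. A small NumPy sketch of the equivalent indexing (the data and
       kernel values are illustrative):

           import numpy as np

           rows = 4
           data = np.random.randint(-128, 128, size=(rows, 16), dtype=np.int8)    # A[rows][16]
           kernel = np.random.randint(-128, 128, size=(4, 16, 4), dtype=np.int8)  # transposed + interleaved B

           C = np.zeros((rows, 16), dtype=np.int32)
           for i in range(rows):
               for j in range(16):
                   for k in range(16):
                       # kernel[k // 4, j, k % 4] walks one full column of the original B
                       C[i, j] += np.int32(data[i, k]) * np.int32(kernel[k // 4, j, k % 4])

           # Equivalent dense form: fold the interleaved kernel back into B^T of shape (16, 16)
           B_t = kernel.transpose(1, 0, 2).reshape(16, 16)
           assert np.array_equal(C, data.astype(np.int32) @ B_t.astype(np.int32).T)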




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] ZihengJiang commented on pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
ZihengJiang commented on pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#issuecomment-693624497


   @anijain2305 @FrozenGene Would you please have a look at this?


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] mbaret commented on a change in pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
mbaret commented on a change in pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#discussion_r493346458



##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -135,20 +135,29 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     name="conv2d_direct_simd.micro_dev",
                 )
             elif kernel_layout == "HWIO":
-                is_aarch64 = "aarch64" in str(isa.target)
-
+                is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
+                has_dot_prod = topi.arm_cpu.arm_utils.is_fast_int8_on_arm()
+                if has_dot_prod and data.dtype in ["int8", "uint8"]:
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_hybrid),

Review comment:
       hybrid typically refers to something implemented using hybrid script; you might want a different name?

##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -135,20 +135,29 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     name="conv2d_direct_simd.micro_dev",
                 )
             elif kernel_layout == "HWIO":
-                is_aarch64 = "aarch64" in str(isa.target)
-
+                is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
+                has_dot_prod = topi.arm_cpu.arm_utils.is_fast_int8_on_arm()
+                if has_dot_prod and data.dtype in ["int8", "uint8"]:
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_hybrid),
+                        wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_hybrid),
+                        name="conv2d_NHWC_quantized_hybrid.arm_cpu",
+                    )
                 if is_aarch64 and data.dtype in ["int8", "uint8"]:
                     strategy.add_implementation(
                         wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized),
                         wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
                         name="conv2d_NHWC_quantized.arm_cpu",
                     )
-
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack),
-                    wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
-                    name="conv2d_nhwc_spatial_pack.arm_cpu",
-                )
+                if (not is_aarch64) or (data.dtype not in ["int8", "uint8"]):
+                    # TODO

Review comment:
       assign the TODO

##########
File path: python/tvm/topi/arm_cpu/arm_utils.py
##########
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Arm target utility functions"""
+
+import tvm
+
+
+def is_fast_int8_on_arm():
+    """ Checks whether the hardware has support for fast Int8 arithmetic operations. """
+    target = tvm.target.Target.current(allow_none=False)
+    return "+v8.2a" in target.mattr and "+dotprod" in target.mattr

Review comment:
       @u99127 Could you take a look at whether this check will be sufficient in the general (and future) case?
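
       For context, a hedged example of a target that should satisfy this check (the exact
       -mattr spelling is an assumption on my part, not something stated in the PR):

           import tvm

           # Hypothetical AArch64 target with the Armv8.2-A dot-product extension enabled.
           target = tvm.target.Target(
               "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+neon,+dotprod"
           )
           print(target.mattr)  # expected to contain "+v8.2a" and "+dotprod"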

##########
File path: python/tvm/topi/arm_cpu/conv2d_alter_op.py
##########
@@ -27,10 +27,52 @@
 from ..nn import conv2d_alter_layout
 from ..util import get_const_tuple
 from ..x86.conv2d import _get_default_config as _get_x86_default_config
+from .arm_utils import is_fast_int8_on_arm
 
 logger = logging.getLogger("topi")
 
 
+def interleave_transpose_B(inputs, data, kernel, interleave_A):

Review comment:
       Document the parameters here
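
       A possible docstring sketch; the parameter descriptions are inferred from how the
       function is used elsewhere in this diff and should be checked by the author:

           def interleave_transpose_B(inputs, data, kernel, interleave_A):
               """Return the new placeholder and the expression that represent
               the matrix B transposed and interleaved.

               Parameters
               ----------
               inputs : list of relay.Expr
                   Inputs of the convolution; inputs[1] is the weight to transform.
               data : tvm.te.Tensor
                   Data tensor, used only to validate the int8/uint8 dtype pairing.
               kernel : tvm.te.Tensor
                   HWIO weight tensor whose shape determines K and N.
               interleave_A : bool
                   Whether the GEMM consuming B will also interleave A; selects the
                   tile shape used when transforming B.

               Returns
               -------
               new_kernel : tvm.te.Tensor
                   Placeholder describing the transformed (tiled) weight.
               new_kernel_expr : relay.Expr
                   Relay expression that performs the weight transform.
               """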

##########
File path: python/tvm/topi/arm_cpu/conv2d_alter_op.py
##########
@@ -280,43 +322,36 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         dispatch_ctx.update(target, new_workload, cfg)
         return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs)
     if topi_tmpl == "conv2d_NHWC_quantized.arm_cpu":
-        assert (
-            data.dtype == "int8"
-            and kernel.dtype == "int8"
-            or data.dtype == "uint8"
-            and kernel.dtype == "uint8"
-        )
         assert data_layout == "NHWC" and kernel_layout == "HWIO"
         KH, KW, IC, OC = get_const_tuple(kernel.shape)
-        K = KH * KW * IC
         N = OC
-
-        tile_rows = 4
-        tile_cols = 16
-        pad_K = 0
-        pad_N = 0
-
-        if N % tile_rows != 0:
-            pad_N = tile_rows - (N % tile_rows)
-        if K % tile_cols != 0:
-            pad_K = tile_cols - (K % tile_cols)
-
-        N_padded = N + pad_N
-        K_padded = K + pad_K
-        kernel_expr = relay.nn.contrib_conv2d_gemm_weight_transform(inputs[1], tile_rows, tile_cols)
-        new_kernel = te.placeholder(
-            (N_padded // tile_rows, K_padded // tile_cols, tile_rows, tile_cols), kernel.dtype
-        )
-
         new_workload_name = "conv2d_NHWC_quantized_without_transform.arm_cpu"
+        new_kernel, new_kernel_expr = interleave_transpose_B(
+            inputs, data, kernel, interleave_A=True
+        )
         new_workload = autotvm.task.args_to_workload(
             [data, new_kernel, strides, padding, dilation, out_dtype, (KH, KW), OC],
             new_workload_name,
         )
         dispatch_ctx.update(target, new_workload, cfg)
 
         return relay.nn.contrib_conv2d_gemm_without_weight_transform(
-            inputs[0], kernel_expr, **new_attrs
+            inputs[0], new_kernel_expr, **new_attrs
+        )
+    if topi_tmpl == "conv2d_NHWC_quantized_hybrid.arm_cpu":
+        assert data_layout == "NHWC" and kernel_layout == "HWIO"
+        KH, KW, IC, OC = get_const_tuple(kernel.shape)
+        N = OC

Review comment:
       Not needed

##########
File path: python/tvm/topi/arm_cpu/conv2d_alter_op.py
##########
@@ -27,10 +27,52 @@
 from ..nn import conv2d_alter_layout
 from ..util import get_const_tuple
 from ..x86.conv2d import _get_default_config as _get_x86_default_config
+from .arm_utils import is_fast_int8_on_arm
 
 logger = logging.getLogger("topi")
 
 
+def interleave_transpose_B(inputs, data, kernel, interleave_A):
+    """Return the new placeholder and the expression that represent
+    the matrix B transposed and interleaved"""
+
+    assert (
+        data.dtype == "int8"
+        and kernel.dtype == "int8"
+        or data.dtype == "uint8"
+        and kernel.dtype == "uint8"
+    )
+
+    KH, KW, IC, OC = get_const_tuple(kernel.shape)
+    K = KH * KW * IC
+    N = OC
+
+    if is_fast_int8_on_arm():
+        tile_rows_B = 12 if interleave_A else 16
+        tile_cols_B = 4
+    else:
+        tile_rows_B = 4
+        tile_cols_B = 16

Review comment:
       Some additional documentation here to explain the difference would be helpful

##########
File path: python/tvm/topi/arm_cpu/arm_utils.py
##########
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Arm target utility functions"""
+
+import tvm
+
+
+def is_fast_int8_on_arm():

Review comment:
       Recommend making this more explicitly related to the dot product.

##########
File path: python/tvm/topi/arm_cpu/conv2d_gemm.py
##########
@@ -74,11 +108,7 @@ def compute_conv2d_gemm_without_weight_transform(
 
     A_shape = (batches, M, K)
     if K_AREA == 1:

Review comment:
       Would clash less with later definitions if this were kernel_area

##########
File path: python/tvm/topi/arm_cpu/conv2d_gemm.py
##########
@@ -90,82 +120,100 @@ def compute_conv2d_gemm_without_weight_transform(
             ],
             name="data_im2col",
         )
-    N_transformed = B_interleaved_t.shape[0]
 
     # --- Pad if necessary
-    idxm = tvm.tir.indexmod
+    N_transformed = B_interleaved_t.shape[0]
+    tile_rows_B = B_interleaved_t.shape[2]
+    tile_cols_B = B_interleaved_t.shape[3]
+
+    if is_fast_int8_on_arm() and interleave_A:
+        tile_rows_A = 8
+        tile_cols_A = 4
+    else:
+        tile_rows_A = 4
+        tile_cols_A = 16
 
     pad_m = 0
     pad_k = 0
 
-    if M % 4 != 0:
-        pad_m = 4 - (M % 4)
+    if M % tile_rows_A != 0:
+        pad_m = tile_rows_A - (M % tile_rows_A)
 
-    if K % 16 != 0:
-        pad_k = 16 - (K % 16)
+    if K % tile_cols_A != 0:
+        pad_k = tile_cols_A - (K % tile_cols_A)
 
     M_padded = M + pad_m
     K_padded = K + pad_k
+    N_padded = N_transformed * tile_rows_B
 
     pad_before = (0, 0, 0)
     pad_after = (0, pad_m, pad_k)
 
     if pad_m != 0 or pad_k != 0:
         A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")
 
-    # --- GEMM: A*B'
+    idxm = tvm.tir.indexmod
     k = te.reduce_axis((0, K_padded), "k")
 
-    A_interleaved = te.compute(
-        (batches, M_padded // 4, K_padded // 16, 4, 16),
-        lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y],
-        name="A_interleaved",
-    )
-
-    C_interleaved = te.compute(
-        (batches, M_padded // 4, N_transformed, 4, 4),
-        lambda b, x, y, w, z: te.sum(
-            A_interleaved[b, x, k // 16, w, idxm(k, 16)].astype(out_dtype)
-            * B_interleaved_t[y, k // 16, z, idxm(k, 16)].astype(out_dtype),
-            axis=k,
-        ),
-        name="C_interleaved",
-    )
+    if interleave_A:
+        # Configuration space
+        configure_knobs(cfg, M_padded, K_padded)
 
-    # --- Unpack C
-    C = te.compute(
-        (batches, M, N),
-        lambda b, x, y: C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
-        name="C",
-    )
+        # Pack A
+        A_interleaved = te.compute(
+            (batches, M_padded // tile_rows_A, K_padded // tile_cols_A, tile_rows_A, tile_cols_A),
+            lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
+            name="A_interleaved",
+        )
+        # Compute C

Review comment:
       Mention this is doing the GEMM

##########
File path: python/tvm/topi/arm_cpu/conv2d_gemm.py
##########
@@ -90,82 +120,100 @@ def compute_conv2d_gemm_without_weight_transform(
             ],
             name="data_im2col",
         )
-    N_transformed = B_interleaved_t.shape[0]
 
     # --- Pad if necessary
-    idxm = tvm.tir.indexmod
+    N_transformed = B_interleaved_t.shape[0]
+    tile_rows_B = B_interleaved_t.shape[2]
+    tile_cols_B = B_interleaved_t.shape[3]
+
+    if is_fast_int8_on_arm() and interleave_A:
+        tile_rows_A = 8
+        tile_cols_A = 4
+    else:
+        tile_rows_A = 4
+        tile_cols_A = 16
 
     pad_m = 0
     pad_k = 0
 
-    if M % 4 != 0:
-        pad_m = 4 - (M % 4)
+    if M % tile_rows_A != 0:
+        pad_m = tile_rows_A - (M % tile_rows_A)
 
-    if K % 16 != 0:
-        pad_k = 16 - (K % 16)
+    if K % tile_cols_A != 0:
+        pad_k = tile_cols_A - (K % tile_cols_A)
 
     M_padded = M + pad_m
     K_padded = K + pad_k
+    N_padded = N_transformed * tile_rows_B
 
     pad_before = (0, 0, 0)
     pad_after = (0, pad_m, pad_k)
 
     if pad_m != 0 or pad_k != 0:
         A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")
 
-    # --- GEMM: A*B'
+    idxm = tvm.tir.indexmod
     k = te.reduce_axis((0, K_padded), "k")
 
-    A_interleaved = te.compute(
-        (batches, M_padded // 4, K_padded // 16, 4, 16),
-        lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y],
-        name="A_interleaved",
-    )
-
-    C_interleaved = te.compute(
-        (batches, M_padded // 4, N_transformed, 4, 4),
-        lambda b, x, y, w, z: te.sum(
-            A_interleaved[b, x, k // 16, w, idxm(k, 16)].astype(out_dtype)
-            * B_interleaved_t[y, k // 16, z, idxm(k, 16)].astype(out_dtype),
-            axis=k,
-        ),
-        name="C_interleaved",
-    )
+    if interleave_A:
+        # Configuration space
+        configure_knobs(cfg, M_padded, K_padded)
 
-    # --- Unpack C
-    C = te.compute(
-        (batches, M, N),
-        lambda b, x, y: C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
-        name="C",
-    )
+        # Pack A
+        A_interleaved = te.compute(
+            (batches, M_padded // tile_rows_A, K_padded // tile_cols_A, tile_rows_A, tile_cols_A),
+            lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
+            name="A_interleaved",
+        )
+        # Compute C
+        C_interleaved = te.compute(
+            (batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B),
+            lambda b, x, y, w, z: te.sum(
+                A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32")
+                * B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"),
+                axis=k,
+            ),
+            name="C_interleaved",
+        )
+        # Unpack C
+        C = te.compute(
+            (batches, M, N),
+            lambda b, x, y: C_interleaved[
+                b, x // tile_rows_A, y // tile_rows_B, idxm(x, tile_rows_A), idxm(y, tile_rows_B)
+            ].astype(out_dtype),
+            name="C",
+        )
+        zero = tvm.tir.const(0)
+    else:
+        # No need to pack/unpack
+        C = te.compute(
+            (batches, M_padded, N_padded),
+            lambda b, x, y: te.sum(
+                A[b, x, k].astype("int32")
+                * B_interleaved_t[
+                    y // tile_rows_B, k // tile_cols_B, idxm(y, tile_rows_B), idxm(k, tile_cols_B)
+                ].astype("int32"),
+                axis=k,
+            ),
+            name="C",
+        )
+        zero = (
+            tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
+            - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
+        )
 
     # --- Produce the conv output
     out_shape = (batches, OH, OW, OC)
-    out = te.compute(out_shape, lambda b, x, y, z: C(b, y + OW * x, z), name="conv2d_gemm_output")
-
-    # Configuration space
-    x, y = cfg.axis(M_padded // 4), cfg.axis(K_padded // 16)
-    cfg.define_reorder("reorder_gemm", [x, y], policy="candidate", candidate=[[x, y], [y, x]])
-
-    outer_loop, inner_loop = cfg.axis(4), cfg.axis(16)
-    cfg.define_annotate(
-        "A_interleaved_unroll_vec", [outer_loop, inner_loop], policy="try_unroll_vec"
+    out = te.compute(
+        out_shape,
+        lambda b, x, y, z: (C(b, y + OW * x, z) + zero).astype(out_dtype),
+        name="conv2d_gemm_output",
     )
-    cfg.define_knob("gemm_quantized_unroll", [True, False])
-    cfg.define_knob("gemm_quantized_interleave", [True, False])
-
-    # Fallback configuration
-    if cfg.is_fallback:
-        cfg["reorder_gemm"] = ReorderEntity([0, 1])
-        cfg["A_interleaved_unroll_vec"] = AnnotateEntity(["unroll", "vec"])
-        cfg["gemm_quantized_unroll"] = OtherOptionEntity(False)
-        cfg["gemm_quantized_interleave"] = OtherOptionEntity(True)
     return out
 
 
-# Schedules
 def schedule_conv2d_gemm(cfg, s, out, final_out):

Review comment:
       Rename to schedule_conv2d_gemm_interleaved

##########
File path: python/tvm/topi/arm_cpu/conv2d_gemm.py
##########
@@ -90,82 +120,100 @@ def compute_conv2d_gemm_without_weight_transform(
             ],
             name="data_im2col",
         )
-    N_transformed = B_interleaved_t.shape[0]
 
     # --- Pad if necessary
-    idxm = tvm.tir.indexmod
+    N_transformed = B_interleaved_t.shape[0]
+    tile_rows_B = B_interleaved_t.shape[2]
+    tile_cols_B = B_interleaved_t.shape[3]
+
+    if is_fast_int8_on_arm() and interleave_A:
+        tile_rows_A = 8
+        tile_cols_A = 4
+    else:
+        tile_rows_A = 4
+        tile_cols_A = 16
 
     pad_m = 0
     pad_k = 0
 
-    if M % 4 != 0:
-        pad_m = 4 - (M % 4)
+    if M % tile_rows_A != 0:
+        pad_m = tile_rows_A - (M % tile_rows_A)
 
-    if K % 16 != 0:
-        pad_k = 16 - (K % 16)
+    if K % tile_cols_A != 0:
+        pad_k = tile_cols_A - (K % tile_cols_A)
 
     M_padded = M + pad_m
     K_padded = K + pad_k
+    N_padded = N_transformed * tile_rows_B
 
     pad_before = (0, 0, 0)
     pad_after = (0, pad_m, pad_k)
 
     if pad_m != 0 or pad_k != 0:
         A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")
 
-    # --- GEMM: A*B'
+    idxm = tvm.tir.indexmod
     k = te.reduce_axis((0, K_padded), "k")
 
-    A_interleaved = te.compute(
-        (batches, M_padded // 4, K_padded // 16, 4, 16),
-        lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y],
-        name="A_interleaved",
-    )
-
-    C_interleaved = te.compute(
-        (batches, M_padded // 4, N_transformed, 4, 4),
-        lambda b, x, y, w, z: te.sum(
-            A_interleaved[b, x, k // 16, w, idxm(k, 16)].astype(out_dtype)
-            * B_interleaved_t[y, k // 16, z, idxm(k, 16)].astype(out_dtype),
-            axis=k,
-        ),
-        name="C_interleaved",
-    )
+    if interleave_A:
+        # Configuration space
+        configure_knobs(cfg, M_padded, K_padded)
 
-    # --- Unpack C
-    C = te.compute(
-        (batches, M, N),
-        lambda b, x, y: C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
-        name="C",
-    )
+        # Pack A
+        A_interleaved = te.compute(
+            (batches, M_padded // tile_rows_A, K_padded // tile_cols_A, tile_rows_A, tile_cols_A),
+            lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
+            name="A_interleaved",
+        )
+        # Compute C
+        C_interleaved = te.compute(
+            (batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B),
+            lambda b, x, y, w, z: te.sum(
+                A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32")
+                * B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"),
+                axis=k,
+            ),
+            name="C_interleaved",
+        )
+        # Unpack C
+        C = te.compute(
+            (batches, M, N),
+            lambda b, x, y: C_interleaved[
+                b, x // tile_rows_A, y // tile_rows_B, idxm(x, tile_rows_A), idxm(y, tile_rows_B)
+            ].astype(out_dtype),
+            name="C",
+        )
+        zero = tvm.tir.const(0)
+    else:
+        # No need to pack/unpack
+        C = te.compute(
+            (batches, M_padded, N_padded),
+            lambda b, x, y: te.sum(
+                A[b, x, k].astype("int32")
+                * B_interleaved_t[
+                    y // tile_rows_B, k // tile_cols_B, idxm(y, tile_rows_B), idxm(k, tile_cols_B)
+                ].astype("int32"),
+                axis=k,
+            ),
+            name="C",
+        )
+        zero = (
+            tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
+            - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
+        )
 
     # --- Produce the conv output

Review comment:
       Document why this is different to the previous unpacking step.
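
       One possible wording, inferred from the two branches above rather than taken from the
       author: in the interleaved branch C has already been unpacked to (batches, M, N) and cast,
       so this step only rearranges rows into (OH, OW); in the non-interleaved branch C is still
       the padded (M_padded, N_padded) int32 accumulator, so the cast to out_dtype (and the
       `zero` reference discussed further down in this review) happen here instead. Sketched as
       comments on the existing lines:

           # --- Produce the conv output
           out_shape = (batches, OH, OW, OC)
           # Interleaved path: C is the unpacked (batches, M, N) tensor and zero == 0.
           # Non-interleaved path: C is still the padded (batches, M_padded, N_padded)
           # accumulator; the cast happens here and `zero` keeps the padded area referenced.
           out = te.compute(
               out_shape,
               lambda b, x, y, z: (C(b, y + OW * x, z) + zero).astype(out_dtype),
               name="conv2d_gemm_output",
           )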

##########
File path: python/tvm/relay/op/strategy/arm_cpu.py
##########
@@ -135,20 +135,29 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                     name="conv2d_direct_simd.micro_dev",
                 )
             elif kernel_layout == "HWIO":
-                is_aarch64 = "aarch64" in str(isa.target)
-
+                is_aarch64 = topi.arm_cpu.arm_utils.is_aarch64_arm()
+                has_dot_prod = topi.arm_cpu.arm_utils.is_fast_int8_on_arm()
+                if has_dot_prod and data.dtype in ["int8", "uint8"]:
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized_hybrid),
+                        wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_hybrid),
+                        name="conv2d_NHWC_quantized_hybrid.arm_cpu",
+                    )
                 if is_aarch64 and data.dtype in ["int8", "uint8"]:
                     strategy.add_implementation(
                         wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized),

Review comment:
       Maybe change to quantized_interleaved to align with later definitions.

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,223 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select an int8x4 word within an int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(i) = v[i:i+3]
+    replicated_v(i) = [vsub(i), vsub(i), vsub(i), vsub(i)]
+
+    Note that i can vary between 0 and 3

Review comment:
       0 <= i <= 12

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,223 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):

Review comment:
       Document params
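
       As a reading aid while documenting the params, a minimal NumPy sketch of what select_word
       computes (illustrative only; the real function builds TIR reinterpret/shuffle expressions
       rather than operating on arrays):

           import numpy as np

           def select_word_reference(vec, lane):
               # vec: 16 int8 values; returns the 4-byte word at position `lane`
               # (0 <= lane < 4) replicated four times, again as 16 int8 values.
               assert vec.shape == (16,) and 0 <= lane < 4
               word = vec[4 * lane : 4 * lane + 4]
               return np.tile(word, 4)

           v = np.arange(16, dtype=np.int8)
           print(select_word_reference(v, 1))
           # [4 5 6 7 4 5 6 7 4 5 6 7 4 5 6 7]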

##########
File path: python/tvm/topi/arm_cpu/conv2d_gemm.py
##########
@@ -90,82 +120,100 @@ def compute_conv2d_gemm_without_weight_transform(
             ],
             name="data_im2col",
         )
-    N_transformed = B_interleaved_t.shape[0]
 
     # --- Pad if necessary
-    idxm = tvm.tir.indexmod
+    N_transformed = B_interleaved_t.shape[0]
+    tile_rows_B = B_interleaved_t.shape[2]
+    tile_cols_B = B_interleaved_t.shape[3]
+
+    if is_fast_int8_on_arm() and interleave_A:
+        tile_rows_A = 8
+        tile_cols_A = 4
+    else:
+        tile_rows_A = 4
+        tile_cols_A = 16

Review comment:
       Comment on magic numbers
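
       A possible shape for that comment; the rationale is inferred from the 4-deep sdot/udot
       accumulate described earlier in this PR and should be confirmed by the author:

           if is_fast_int8_on_arm() and interleave_A:
               # Dot-product path: sdot/udot reduce over groups of 4 int8 values,
               # so the K (reduction) dimension of A is tiled by 4; 8 rows per tile
               # is the block consumed by the interleaved GEMM schedule.
               tile_rows_A = 8
               tile_cols_A = 4
           else:
               # Default path: 4 x 16 tiles of A (the tiling used before this change).
               tile_rows_A = 4
               tile_cols_A = 16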

##########
File path: python/tvm/topi/arm_cpu/conv2d_int8.py
##########
@@ -41,6 +46,20 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype):
         )
 
 
+def get_tiling_B(interleave_A):

Review comment:
       Move to utils and use in alter_op_layout

##########
File path: python/tvm/topi/arm_cpu/conv2d_gemm.py
##########
@@ -90,82 +120,100 @@ def compute_conv2d_gemm_without_weight_transform(
             ],
             name="data_im2col",
         )
-    N_transformed = B_interleaved_t.shape[0]
 
     # --- Pad if necessary
-    idxm = tvm.tir.indexmod
+    N_transformed = B_interleaved_t.shape[0]
+    tile_rows_B = B_interleaved_t.shape[2]
+    tile_cols_B = B_interleaved_t.shape[3]
+
+    if is_fast_int8_on_arm() and interleave_A:
+        tile_rows_A = 8
+        tile_cols_A = 4
+    else:
+        tile_rows_A = 4
+        tile_cols_A = 16
 
     pad_m = 0
     pad_k = 0

Review comment:
       Maybe capitalize m and k here to match with later convention

##########
File path: python/tvm/topi/arm_cpu/conv2d_gemm.py
##########
@@ -90,82 +120,100 @@ def compute_conv2d_gemm_without_weight_transform(
             ],
             name="data_im2col",
         )
-    N_transformed = B_interleaved_t.shape[0]
 
     # --- Pad if necessary
-    idxm = tvm.tir.indexmod
+    N_transformed = B_interleaved_t.shape[0]
+    tile_rows_B = B_interleaved_t.shape[2]
+    tile_cols_B = B_interleaved_t.shape[3]
+
+    if is_fast_int8_on_arm() and interleave_A:
+        tile_rows_A = 8
+        tile_cols_A = 4
+    else:
+        tile_rows_A = 4
+        tile_cols_A = 16
 
     pad_m = 0
     pad_k = 0
 
-    if M % 4 != 0:
-        pad_m = 4 - (M % 4)
+    if M % tile_rows_A != 0:
+        pad_m = tile_rows_A - (M % tile_rows_A)
 
-    if K % 16 != 0:
-        pad_k = 16 - (K % 16)
+    if K % tile_cols_A != 0:
+        pad_k = tile_cols_A - (K % tile_cols_A)
 
     M_padded = M + pad_m
     K_padded = K + pad_k
+    N_padded = N_transformed * tile_rows_B
 
     pad_before = (0, 0, 0)
     pad_after = (0, pad_m, pad_k)
 
     if pad_m != 0 or pad_k != 0:
         A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")
 
-    # --- GEMM: A*B'
+    idxm = tvm.tir.indexmod
     k = te.reduce_axis((0, K_padded), "k")
 
-    A_interleaved = te.compute(
-        (batches, M_padded // 4, K_padded // 16, 4, 16),
-        lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y],
-        name="A_interleaved",
-    )
-
-    C_interleaved = te.compute(
-        (batches, M_padded // 4, N_transformed, 4, 4),
-        lambda b, x, y, w, z: te.sum(
-            A_interleaved[b, x, k // 16, w, idxm(k, 16)].astype(out_dtype)
-            * B_interleaved_t[y, k // 16, z, idxm(k, 16)].astype(out_dtype),
-            axis=k,
-        ),
-        name="C_interleaved",
-    )
+    if interleave_A:
+        # Configuration space
+        configure_knobs(cfg, M_padded, K_padded)
 
-    # --- Unpack C
-    C = te.compute(
-        (batches, M, N),
-        lambda b, x, y: C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
-        name="C",
-    )
+        # Pack A
+        A_interleaved = te.compute(
+            (batches, M_padded // tile_rows_A, K_padded // tile_cols_A, tile_rows_A, tile_cols_A),
+            lambda b, x, y, z, w: A[b, z + tile_rows_A * x, w + tile_cols_A * y],
+            name="A_interleaved",
+        )
+        # Compute C
+        C_interleaved = te.compute(
+            (batches, M_padded // tile_rows_A, N_transformed, tile_rows_A, tile_rows_B),
+            lambda b, x, y, w, z: te.sum(
+                A_interleaved[b, x, k // tile_cols_A, w, idxm(k, tile_cols_A)].astype("int32")
+                * B_interleaved_t[y, k // tile_cols_B, z, idxm(k, tile_cols_B)].astype("int32"),
+                axis=k,
+            ),
+            name="C_interleaved",
+        )
+        # Unpack C
+        C = te.compute(
+            (batches, M, N),
+            lambda b, x, y: C_interleaved[
+                b, x // tile_rows_A, y // tile_rows_B, idxm(x, tile_rows_A), idxm(y, tile_rows_B)
+            ].astype(out_dtype),
+            name="C",
+        )
+        zero = tvm.tir.const(0)
+    else:
+        # No need to pack/unpack
+        C = te.compute(
+            (batches, M_padded, N_padded),
+            lambda b, x, y: te.sum(
+                A[b, x, k].astype("int32")
+                * B_interleaved_t[
+                    y // tile_rows_B, k // tile_cols_B, idxm(y, tile_rows_B), idxm(k, tile_cols_B)
+                ].astype("int32"),
+                axis=k,
+            ),
+            name="C",
+        )
+        zero = (

Review comment:
       Document this trick
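
       A hedged reading of the trick, to seed that documentation (this is an inference from the
       code, not the author's statement): the expression always evaluates to 0, but because it
       reads C[0, M_padded - 1, N_padded - 1] it keeps the conv output dependent on the full
       padded accumulator, so later simplification passes cannot shrink the computed region to
       the unpadded (M, N) extent that the tensorized tiles still need. As an in-code comment it
       might look like:

           # `zero` is semantically 0, but referencing the last element of the
           # padded accumulator keeps the whole (M_padded, N_padded) region live,
           # so the padding required by the tensorized tiles is not optimised away.
           zero = (
               tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
               - tvm.tir.const(1, C.dtype) * C[0, M_padded - 1, N_padded - 1]
           )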




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] giuseros commented on pull request #6445: Add dot product support for quantized convolution.

Posted by GitBox <gi...@apache.org>.
giuseros commented on pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#issuecomment-694944064


   Mmmm the last failure seems like a tolerance issue. Let me retrigger the CI


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org