You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/10/02 15:47:41 UTC

[GitHub] [incubator-tvm] mbaret commented on a change in pull request #6445: Add dot product support for quantized convolution.

mbaret commented on a change in pull request #6445:
URL: https://github.com/apache/incubator-tvm/pull/6445#discussion_r498880238



##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,287 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def gemm_acc_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication and accumulation using sdot/udot
+    instructions. This function takes two arrays of int8 datatype
+    -- A[4][4] and B[4][4] and produces a 4x4 matrix
+    which is equal to A*B.
+
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    for (int k = 0; k < 4; k++){
+                        C[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    # This needs to be a variable number of "rows" since TVM
+    # "thinks" I only need to compute one row because of
+    # padding
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+            # Load all the elements of tile A.
+            # vec_a = [a, b, c, d,
+            #          e, f, g, h,
+            #          i, l, m, n,
+            #          o, p, q, r,];
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+
+            # Replicate 4 times the i-th row of A. For instance,
+            # vec_a[0] = [a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,];
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+
+            # Load all the elements of B. Remember that B
+            # is transposed:
+            # vec_b = [0, 4, 8, 12,
+            #          1, 5, 9, 13,
+            #          2, 6, 10, 14,
+            #          3, 7, 11, 15,];
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                # Compute the product between the i-th row of A
+                # and all the rows of B. Remember that sdot/udot
+                # subdive the input vectors in 16 elements
+                # and then take the dot product among each group.
+                # The result is stored in a int32x4 register
+                #
+                # For instance, for i=0, we have:
+                # sdot(vec_aa[0], vec_b) = [a*0+b*4+c*8+d*12,
+                #                           a*1+b*5+c*9+d*13,
+                #                           a*2+b*6+c*10+d*14,
+                #                           a*3+b*7+c*11+d*15]
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4",
+                    llvm_intrin,
+                    tvm.tir.const(3, "uint32"),
+                    vec_c,
+                    vec_b,
+                    vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def gemm_acc_nx16_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication and accumulation using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[rows][4] and

Review comment:
       Would it be useful to include a note that the `n` in the function name is `rows`?

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,287 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def gemm_acc_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication and accumulation using sdot/udot
+    instructions. This function takes two arrays of int8 datatype
+    -- A[4][4] and B[4][4] and produces a 4x4 matrix
+    which is equal to A*B.
+
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    for (int k = 0; k < 4; k++){
+                        C[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    # This needs to be a variable number of "rows" since TVM
+    # "thinks" I only need to compute one row because of
+    # padding
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+            # Load all the elements of tile A.
+            # vec_a = [a, b, c, d,
+            #          e, f, g, h,
+            #          i, l, m, n,
+            #          o, p, q, r,];
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+
+            # Replicate 4 times the i-th row of A. For instance,
+            # vec_a[0] = [a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,];
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+
+            # Load all the elements of B. Remember that B
+            # is transposed:
+            # vec_b = [0, 4, 8, 12,
+            #          1, 5, 9, 13,
+            #          2, 6, 10, 14,
+            #          3, 7, 11, 15,];
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                # Compute the product between the i-th row of A
+                # and all the rows of B. Remember that sdot/udot
+                # subdive the input vectors in 16 elements
+                # and then take the dot product among each group.
+                # The result is stored in a int32x4 register
+                #
+                # For instance, for i=0, we have:
+                # sdot(vec_aa[0], vec_b) = [a*0+b*4+c*8+d*12,
+                #                           a*1+b*5+c*9+d*13,
+                #                           a*2+b*6+c*10+d*14,
+                #                           a*3+b*7+c*11+d*15]
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4",
+                    llvm_intrin,
+                    tvm.tir.const(3, "uint32"),
+                    vec_c,
+                    vec_b,
+                    vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def gemm_acc_nx16_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication and accumulation using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[rows][4] and
+    B[4][16] and produces a rowsx16 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_16x4_int8_int8_int32(int8 A[rows][16], int8 B[4][16][4], int32 output[rows][16]){
+            for (int i = 0; i < rows; i++){
+                for (int j = 0; i < 16; i++){
+                    for (int k = 0; k < 16; k++){
+                        out[i][j] += A[i][k] * B[k//4][j][k%4]
+                    }
+                }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * The tile size of B is 16x4. Since the reduction variable k moves between 0 and 16
+          we need 4 tiles of B to compute a single row of the output. The first 4 values of
+          k will be fetched from B[0][j][k], the second batch of 4 from B[1][j][k] and so on
+
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    A = te.placeholder((rows, 16), dtype, name="data")
+    B = te.placeholder((4, 16, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+    idxm = tvm.tir.indexmod
+    k = te.reduce_axis((0, 16), name="k")
+    C = te.compute(
+        (rows, 16),
+        lambda i, j: te.sum(
+            A[i, k].astype("int32") * B[k // 4, j, idxm(k, 4)].astype("int32"), axis=k
+        ),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape,
+        dtype,
+        name="bb_buffer",
+        offset_factor=1,
+        strides=[te.var("sb0"), te.var("sb1"), 1],
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, rows):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x16")))
+                return ib.get()
+            # Iterate on the number of rows of the output
+            for k in range(0, rows):
+                # Load 16 elements of A
+                # vec_a = [a, b, c, e, f, g, h, i, l, m, n, o, p, q, r,];

Review comment:
       I think 'i' should be avoided as an element name in this comment as well.

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,287 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def gemm_acc_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication and accumulation using sdot/udot
+    instructions. This function takes two arrays of int8 datatype
+    -- A[4][4] and B[4][4] and produces a 4x4 matrix
+    which is equal to A*B.
+
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    for (int k = 0; k < 4; k++){
+                        C[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    # This needs to be a variable number of "rows" since TVM
+    # "thinks" I only need to compute one row because of
+    # padding
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+            # Load all the elements of tile A.
+            # vec_a = [a, b, c, d,
+            #          e, f, g, h,
+            #          i, l, m, n,
+            #          o, p, q, r,];
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+
+            # Replicate 4 times the i-th row of A. For instance,
+            # vec_a[0] = [a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,];
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+
+            # Load all the elements of B. Remember that B
+            # is transposed:
+            # vec_b = [0, 4, 8, 12,
+            #          1, 5, 9, 13,
+            #          2, 6, 10, 14,
+            #          3, 7, 11, 15,];
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                # Compute the product between the i-th row of A
+                # and all the rows of B. Remember that sdot/udot
+                # subdive the input vectors in 16 elements
+                # and then take the dot product among each group.
+                # The result is stored in a int32x4 register
+                #
+                # For instance, for i=0, we have:
+                # sdot(vec_aa[0], vec_b) = [a*0+b*4+c*8+d*12,
+                #                           a*1+b*5+c*9+d*13,
+                #                           a*2+b*6+c*10+d*14,
+                #                           a*3+b*7+c*11+d*15]
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4",
+                    llvm_intrin,
+                    tvm.tir.const(3, "uint32"),
+                    vec_c,
+                    vec_b,
+                    vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def gemm_acc_nx16_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication and accumulation using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[rows][4] and
+    B[4][16] and produces a rowsx16 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_16x4_int8_int8_int32(int8 A[rows][16], int8 B[4][16][4], int32 output[rows][16]){
+            for (int i = 0; i < rows; i++){
+                for (int j = 0; i < 16; i++){
+                    for (int k = 0; k < 16; k++){
+                        out[i][j] += A[i][k] * B[k//4][j][k%4]
+                    }
+                }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * The tile size of B is 16x4. Since the reduction variable k moves between 0 and 16
+          we need 4 tiles of B to compute a single row of the output. The first 4 values of
+          k will be fetched from B[0][j][k], the second batch of 4 from B[1][j][k] and so on
+
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    A = te.placeholder((rows, 16), dtype, name="data")
+    B = te.placeholder((4, 16, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+    idxm = tvm.tir.indexmod
+    k = te.reduce_axis((0, 16), name="k")
+    C = te.compute(
+        (rows, 16),
+        lambda i, j: te.sum(
+            A[i, k].astype("int32") * B[k // 4, j, idxm(k, 4)].astype("int32"), axis=k
+        ),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape,
+        dtype,
+        name="bb_buffer",
+        offset_factor=1,
+        strides=[te.var("sb0"), te.var("sb1"), 1],
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, rows):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x16")))
+                return ib.get()
+            # Iterate on the number of rows of the output
+            for k in range(0, rows):
+                # Load 16 elements of A
+                # vec_a = [a, b, c, e, f, g, h, i, l, m, n, o, p, q, r,];
+                vec_a = ins[0].vload([k, 0], dtype_vec)
+
+                # Iterate over each column of the output

Review comment:
       Please revise this comment: the output has 16 columns.

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,287 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def gemm_acc_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication and accumulation using sdot/udot
+    instructions. This function takes two arrays of int8 datatype
+    -- A[4][4] and B[4][4] and produces a 4x4 matrix
+    which is equal to A*B.
+
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    for (int k = 0; k < 4; k++){
+                        C[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    # This needs to be a variable number of "rows" since TVM
+    # "thinks" I only need to compute one row because of
+    # padding
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+            # Load all the elements of tile A.
+            # vec_a = [a, b, c, d,
+            #          e, f, g, h,
+            #          i, l, m, n,
+            #          o, p, q, r,];
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+
+            # Replicate 4 times the i-th row of A. For instance,
+            # vec_a[0] = [a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,];
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+
+            # Load all the elements of B. Remember that B
+            # is transposed:
+            # vec_b = [0, 4, 8, 12,
+            #          1, 5, 9, 13,
+            #          2, 6, 10, 14,
+            #          3, 7, 11, 15,];
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                # Compute the product between the i-th row of A
+                # and all the rows of B. Remember that sdot/udot
+                # subdive the input vectors in 16 elements
+                # and then take the dot product among each group.
+                # The result is stored in a int32x4 register
+                #
+                # For instance, for i=0, we have:
+                # sdot(vec_aa[0], vec_b) = [a*0+b*4+c*8+d*12,
+                #                           a*1+b*5+c*9+d*13,
+                #                           a*2+b*6+c*10+d*14,
+                #                           a*3+b*7+c*11+d*15]
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4",
+                    llvm_intrin,
+                    tvm.tir.const(3, "uint32"),
+                    vec_c,
+                    vec_b,
+                    vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def gemm_acc_nx16_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication and accumulation using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[rows][4] and
+    B[4][16] and produces a rowsx16 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_16x4_int8_int8_int32(int8 A[rows][16], int8 B[4][16][4], int32 output[rows][16]){
+            for (int i = 0; i < rows; i++){
+                for (int j = 0; i < 16; i++){
+                    for (int k = 0; k < 16; k++){
+                        out[i][j] += A[i][k] * B[k//4][j][k%4]
+                    }
+                }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * The tile size of B is 16x4. Since the reduction variable k moves between 0 and 16
+          we need 4 tiles of B to compute a single row of the output. The first 4 values of
+          k will be fetched from B[0][j][k], the second batch of 4 from B[1][j][k] and so on
+
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    A = te.placeholder((rows, 16), dtype, name="data")
+    B = te.placeholder((4, 16, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+    idxm = tvm.tir.indexmod
+    k = te.reduce_axis((0, 16), name="k")
+    C = te.compute(
+        (rows, 16),
+        lambda i, j: te.sum(
+            A[i, k].astype("int32") * B[k // 4, j, idxm(k, 4)].astype("int32"), axis=k
+        ),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape,
+        dtype,
+        name="bb_buffer",
+        offset_factor=1,
+        strides=[te.var("sb0"), te.var("sb1"), 1],
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, rows):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x16")))
+                return ib.get()
+            # Iterate on the number of rows of the output
+            for k in range(0, rows):
+                # Load 16 elements of A
+                # vec_a = [a, b, c, e, f, g, h, i, l, m, n, o, p, q, r,];
+                vec_a = ins[0].vload([k, 0], dtype_vec)
+
+                # Iterate over each column of the output
+                for j in range(0, 4):
+                    # Accumulate over each of the 4 (16x4) tiles contained in B
+                    for i in range(0, 4):
+                        # As before, replicate a single 4-element group of A
+                        vec_aa = select_word(vec_a, i, dtype_vec)
+                        # Load 4 rows (each rows with 4 elements) from B
+                        # vec_b = [0, 16, 32, 48,

Review comment:
       Please show the multiplication between vec_a and vec_b in this comment.

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,287 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def gemm_acc_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication and accumulation using sdot/udot
+    instructions. This function takes two arrays of int8 datatype
+    -- A[4][4] and B[4][4] and produces a 4x4 matrix
+    which is equal to A*B.
+
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    for (int k = 0; k < 4; k++){
+                        C[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    # This needs to be a variable number of "rows" since TVM
+    # "thinks" I only need to compute one row because of
+    # padding
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+            # Load all the elements of tile A.
+            # vec_a = [a, b, c, d,
+            #          e, f, g, h,
+            #          i, l, m, n,
+            #          o, p, q, r,];
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+
+            # Replicate 4 times the i-th row of A. For instance,
+            # vec_a[0] = [a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,];
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+
+            # Load all the elements of B. Remember that B
+            # is transposed:
+            # vec_b = [0, 4, 8, 12,
+            #          1, 5, 9, 13,
+            #          2, 6, 10, 14,
+            #          3, 7, 11, 15,];
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                # Compute the product between the i-th row of A
+                # and all the rows of B. Remember that sdot/udot
+                # subdive the input vectors in 16 elements
+                # and then take the dot product among each group.
+                # The result is stored in a int32x4 register
+                #
+                # For instance, for i=0, we have:
+                # sdot(vec_aa[0], vec_b) = [a*0+b*4+c*8+d*12,
+                #                           a*1+b*5+c*9+d*13,
+                #                           a*2+b*6+c*10+d*14,
+                #                           a*3+b*7+c*11+d*15]
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4",
+                    llvm_intrin,
+                    tvm.tir.const(3, "uint32"),
+                    vec_c,
+                    vec_b,
+                    vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def gemm_acc_nx16_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication and accumulation using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[rows][4] and
+    B[4][16] and produces a rowsx16 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_16x4_int8_int8_int32(int8 A[rows][16], int8 B[4][16][4], int32 output[rows][16]){
+            for (int i = 0; i < rows; i++){
+                for (int j = 0; i < 16; i++){
+                    for (int k = 0; k < 16; k++){
+                        out[i][j] += A[i][k] * B[k//4][j][k%4]

Review comment:
       Motivate this with reference to the registers.

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,287 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def gemm_acc_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication and accumulation using sdot/udot
+    instructions. This function takes two arrays of int8 datatype
+    -- A[4][4] and B[4][4] and produces a 4x4 matrix
+    which is equal to A*B.
+
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    for (int k = 0; k < 4; k++){
+                        C[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    # This needs to be a variable number of "rows" since TVM
+    # "thinks" I only need to compute one row because of
+    # padding
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")

Review comment:
       Rename the placeholder's `name=` argument too (it is still "kernel" here).

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,287 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def gemm_acc_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication and accumulation using sdot/udot
+    instructions. This function takes two arrays of int8 datatype
+    -- A[4][4] and B[4][4] and produces a 4x4 matrix
+    which is equal to A*B.
+
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    for (int k = 0; k < 4; k++){
+                        C[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    # This needs to be a variable number of "rows" since TVM
+    # "thinks" I only need to compute one row because of
+    # padding
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+            # Load all the elements of tile A.
+            # vec_a = [a, b, c, d,
+            #          e, f, g, h,
+            #          i, l, m, n,
+            #          o, p, q, r,];
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+
+            # Replicate 4 times the i-th row of A. For instance,
+            # vec_a[0] = [a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,];
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+
+            # Load all the elements of B. Remember that B
+            # is transposed:
+            # vec_b = [0, 4, 8, 12,
+            #          1, 5, 9, 13,
+            #          2, 6, 10, 14,
+            #          3, 7, 11, 15,];
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                # Compute the product between the i-th row of A
+                # and all the rows of B. Remember that sdot/udot
+                # subdive the input vectors in 16 elements
+                # and then take the dot product among each group.
+                # The result is stored in a int32x4 register
+                #
+                # For instance, for i=0, we have:
+                # sdot(vec_aa[0], vec_b) = [a*0+b*4+c*8+d*12,
+                #                           a*1+b*5+c*9+d*13,
+                #                           a*2+b*6+c*10+d*14,
+                #                           a*3+b*7+c*11+d*15]
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4",
+                    llvm_intrin,
+                    tvm.tir.const(3, "uint32"),
+                    vec_c,
+                    vec_b,
+                    vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def gemm_acc_nx16_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication and accumulation using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[rows][4] and
+    B[4][16] and produces a rowsx16 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_16x4_int8_int8_int32(int8 A[rows][16], int8 B[4][16][4], int32 output[rows][16]){
+            for (int i = 0; i < rows; i++){
+                for (int j = 0; i < 16; i++){
+                    for (int k = 0; k < 16; k++){
+                        out[i][j] += A[i][k] * B[k//4][j][k%4]
+                    }
+                }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * The tile size of B is 16x4. Since the reduction variable k moves between 0 and 16
+          we need 4 tiles of B to compute a single row of the output. The first 4 values of
+          k will be fetched from B[0][j][k], the second batch of 4 from B[1][j][k] and so on
+
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    A = te.placeholder((rows, 16), dtype, name="data")
+    B = te.placeholder((4, 16, 4), dtype, name="kernel")

Review comment:
       Rename the placeholder's `name=` argument too (it is still "kernel" here).

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,287 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def gemm_acc_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication and accumulation using sdot/udot
+    instructions. This function takes two arrays of int8 datatype
+    -- A[4][4] and B[4][4] and produces a 4x4 matrix
+    which is equal to A*B.
+
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    for (int k = 0; k < 4; k++){
+                        C[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    # This needs to be a variable number of "rows" since TVM
+    # "thinks" I only need to compute one row because of
+    # padding
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+            # Load all the elements of tile A.
+            # vec_a = [a, b, c, d,
+            #          e, f, g, h,
+            #          i, l, m, n,
+            #          o, p, q, r,];
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+
+            # Replicate 4 times the i-th row of A. For instance,
+            # vec_a[0] = [a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,
+            #             a, b, c, d,];
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+
+            # Load all the elements of B. Remember that B
+            # is transposed:
+            # vec_b = [0, 4, 8, 12,
+            #          1, 5, 9, 13,
+            #          2, 6, 10, 14,
+            #          3, 7, 11, 15,];
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                # Compute the product between the i-th row of A
+                # and all the rows of B. Remember that sdot/udot
+                # subdive the input vectors in 16 elements
+                # and then take the dot product among each group.
+                # The result is stored in a int32x4 register
+                #
+                # For instance, for i=0, we have:
+                # sdot(vec_aa[0], vec_b) = [a*0+b*4+c*8+d*12,
+                #                           a*1+b*5+c*9+d*13,
+                #                           a*2+b*6+c*10+d*14,
+                #                           a*3+b*7+c*11+d*15]
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4",
+                    llvm_intrin,
+                    tvm.tir.const(3, "uint32"),
+                    vec_c,
+                    vec_b,
+                    vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def gemm_acc_nx16_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication and accumulation using sdot/udot instructions

Review comment:
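       16x4 -> nx16: the docstring summary should say "nx16" to match the function name (gemm_acc_nx16_int8_int8_int32).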

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,287 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+     Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec_a as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to uint8x16
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def gemm_acc_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication and accumulation using sdot/udot
+    instructions. This function takes two arrays of int8 datatype
+    -- A[4][4] and B[4][4] and produces a 4x4 matrix
+    which is equal to A*B.
+
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; i < 4; i++){
+                    for (int k = 0; k < 4; k++){
+                        C[i][j] += A[i][k] * B[j][k]
+                    }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    # This needs to be a variable number of "rows" since TVM
+    # "thinks" I only need to compute one row because of
+    # padding
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+            # Load all the elements of tile A.
+            # vec_a = [a, b, c, d,
+            #          e, f, g, h,
+            #          i, l, m, n,
+            #          o, p, q, r,];

Review comment:
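       ijk: the lettering in this comment skips j and k (it goes "i, l, m, n"); use consecutive letters (i, j, k, l, ...).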

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -589,6 +587,287 @@ def _instr(index):
     )
 
 
+def select_word(vec, lane, dtype_vec):
+    """
+    Utility function used to select a int8x4 word within a int8x16 vector
+    and replicate 4 times.
+    The pseudo-code for this operation is:
+
+    v = [x0, ..., x15]
+    vsub(lane) = v[4*lane:4*lane+3]
+    replicated_v(lane) = [vsub(lane), vsub(lane), vsub(lane), vsub(lane)]
+
+    Note that 0<=lane<4
+
+    Parameters
+    ----------
+    vec: tvm.tir.Expr
+         int8x16 vector expression
+    lane: int
+        vector lane we want to replicate
+    dtype_vec: str
+        vector data type (e.g., int8x16)
+
+    Returns
+    ----------
+    output: tvm.tir.Expr
+        replicated vector
+    """
+    # Reinterpret vec as 4 int32 words
+    vec_int32 = tvm.tir.call_intrin("int32x4", "tir.reinterpret", vec)
+    # Broadcast the lane-th word
+    vec_int32_shuffled = tvm.tir.Shuffle([vec_int32], [lane, lane, lane, lane])
+    # Convert back to the original vector type (dtype_vec, e.g. int8x16 or uint8x16)
+    vec_int8_broadcast = tvm.tir.call_intrin(dtype_vec, "tir.reinterpret", vec_int32_shuffled)
+    return vec_int8_broadcast
+
+
+def gemm_acc_4x4_int8_int8_int32(dtype):
+    """
+    Int8 4x4 matrix multiplication and accumulation using sdot/udot
+    instructions. This function takes two arrays of int8 datatype
+    -- A[4][4] and B[4][4] and produces a 4x4 matrix
+    which is equal to A*B.
+
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
+            for (int i = 0; i < 4; i++){
+                for (int j = 0; j < 4; j++){
+                    for (int k = 0; k < 4; k++){
+                        C[i][j] += A[i][k] * B[j][k];
+                    }
+                }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+    This function returns a TensorIntrin that can be used to tensorize a schedule.
+
+    Parameters
+    ----------
+    dtype: str, {"uint8", "int8"}
+        Whether it works on unsigned int or signed int
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The Arm TensorIntrin that can be used in tensorizing schedule
+    """
+    # This needs to be a variable number of "rows" since TVM
+    # "thinks" I only need to compute one row because of
+    # padding
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")
+    dtype_vec = dtype + "x16"
+
+    k = te.reduce_axis((0, 4), name="k")
+    C = te.compute(
+        (te.var("rows"), 4),
+        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
+        name="C",
+    )
+
+    aa_buffer = tvm.tir.decl_buffer(
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+    )
+    bb_buffer = tvm.tir.decl_buffer(
+        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+    )
+    cc_buffer = tvm.tir.decl_buffer(
+        C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
+    )
+
+    llvm_intrin = "llvm.aarch64.neon.sdot" if dtype == "int8" else "llvm.aarch64.neon.udot"
+
+    def _intrin_func(ins, outs):
+        def _instr(index):
+            ib = tvm.tir.ir_builder.create()
+            if index == 1:
+                for i in range(0, 4):
+                    ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
+                return ib.get()
+            # Load all the elements of tile A.
+            # vec_a = [a, b, c, d,
+            #          e, f, g, h,
+            #          i, l, m, n,
+            #          o, p, q, r,];
+            vec_a = ins[0].vload([0, 0], dtype_vec)
+
+            # Replicate 4 times the i-th row of A. For instance,
+            # vec_aa[0] = [a, b, c, d,
+            #              a, b, c, d,
+            #              a, b, c, d,
+            #              a, b, c, d,];
+            vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+
+            # Load all the elements of B. Remember that B
+            # is transposed:
+            # vec_b = [0, 4, 8, 12,
+            #          1, 5, 9, 13,
+            #          2, 6, 10, 14,
+            #          3, 7, 11, 15,];
+            vec_b = ins[1].vload([0, 0], dtype_vec)
+
+            # Execute the dot product
+            for i in range(0, 4):
+                vec_c = outs[0].vload([i, 0], "int32x4")
+                # Compute the product between the i-th row of A
+                # and all the rows of B. Remember that sdot/udot
+                # subdivide each 16-element input vector into 4 groups
+                # of 4 elements and then take the dot product within each group.
+                # The result is stored in a int32x4 register
+                #
+                # For instance, for i=0, we have:
+                # sdot(vec_aa[0], vec_b) = [a*0+b*4+c*8+d*12,
+                #                           a*1+b*5+c*9+d*13,
+                #                           a*2+b*6+c*10+d*14,
+                #                           a*3+b*7+c*11+d*15]
+                vdot = tvm.tir.call_llvm_intrin(
+                    "int32x4",
+                    llvm_intrin,
+                    tvm.tir.const(3, "uint32"),
+                    vec_c,
+                    vec_b,
+                    vec_aa[i],
+                )
+
+                # Store the result
+                ib.emit(outs[0].vstore([i, 0], vdot))
+
+            return ib.get()
+
+        # body, reset, update
+        return _instr(0), _instr(1), _instr(2)
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(
+        C.op,
+        _intrin_func,
+        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
+        default_buffer_params=buffer_params,
+    )
+
+
+def gemm_acc_nx16_int8_int8_int32(dtype, rows):
+    """
+    Int8 16x4 matrix multiplication and accumulation using sdot/udot instructions
+    This function takes two arrays of int8 datatype -- A[rows][4] and
+    B[4][16] and produces a rowsx16 matrix which is equal to A*B
+    The pseudo code is as follows.
+
+    .. code-block:: c
+
+        void mmla_16x4_int8_int8_int32(int8 A[rows][16], int8 B[4][16][4], int32 output[rows][16]){
+            for (int i = 0; i < rows; i++){
+                for (int j = 0; j < 16; j++){
+                    for (int k = 0; k < 16; k++){
+                        output[i][j] += A[i][k] * B[k//4][j][k%4]
+                    }
+                }
+            }
+        }
+
+    Notes:
+        * The rows of matrix B are transposed
+        * The tile size of B is 16x4. Since the reduction variable k moves between 0 and 16
+          we need 4 tiles of B to compute a single row of the output. The first 4 values of
+          k will be fetched from B[0][j][k], the second batch of 4 from B[1][j][k] and so on
+
+    This function returns a TensorIntrin that can be used to tensorize a schedule.

Review comment:
       Remove




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org