Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/11/17 16:40:10 UTC

[GitHub] [incubator-tvm] mbaret commented on a change in pull request #6907: AArch64 base algorithm refactoring in LLVM

mbaret commented on a change in pull request #6907:
URL: https://github.com/apache/incubator-tvm/pull/6907#discussion_r525292288



##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -447,28 +88,280 @@ def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type):
         C.shape, dtype="int32", name="c_buffer", offset_factor=1, strides=[te.var("sc"), 1]
     )
 
+    # Intrinsics used in the following algorithm
+    umull_intrin = "llvm.aarch64.neon.umull" if in_type == "uint8" else "llvm.aarch64.neon.smull"
+    uaddlp_intrin = "llvm.aarch64.neon.uaddlp" if in_type == "uint8" else "llvm.aarch64.neon.saddlp"
+    addp_intrin = "llvm.aarch64.neon.addp"
+
+    def uadalp(a, b):
+        """Add pair and accumulate
+
+        Parameters:
+        ----------
+        a: int16x8 vector
+        b: int16x8 vector
+
+        Returns:
+        --------
+            return an int32x4 vector
+
+        Pseudocode:
+        ----------
+            a += (b0+b1, b2+b3, b4+b5, b6+b7)
+        """
+
+        return a + tvm.tir.call_llvm_pure_intrin(
+            "int32x4", uaddlp_intrin, tvm.tir.const(1, "uint32"), b
+        )
+
+    def umull(a, b):
+        """Multiply long (lower part)

Review comment:
       Probably call this 'higher part' just for consistency with tir.
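       e.g. (only the docstring title would change; naming is just a suggestion):
       ```python
       def umull(a, b):
           """Multiply long (higher part)
           ...
           """
       ```
       That would also match the `tir.vectorhigh` calls in the body.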

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -447,28 +88,280 @@ def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type):
         C.shape, dtype="int32", name="c_buffer", offset_factor=1, strides=[te.var("sc"), 1]
     )
 
+    # Intrinsics used in the following algorithm
+    umull_intrin = "llvm.aarch64.neon.umull" if in_type == "uint8" else "llvm.aarch64.neon.smull"
+    uaddlp_intrin = "llvm.aarch64.neon.uaddlp" if in_type == "uint8" else "llvm.aarch64.neon.saddlp"
+    addp_intrin = "llvm.aarch64.neon.addp"
+
+    def uadalp(a, b):
+        """Add pair and accumulate
+
+        Parameters:
+        ----------
+        a: int16x8 vector
+        b: int16x8 vector
+
+        Returns:
+        --------
+            return an int32x4 vector
+
+        Pseudocode:
+        ----------
+            a += (b0+b1, b2+b3, b4+b5, b6+b7)
+        """
+
+        return a + tvm.tir.call_llvm_pure_intrin(
+            "int32x4", uaddlp_intrin, tvm.tir.const(1, "uint32"), b
+        )
+
+    def umull(a, b):
+        """Multiply long (lower part)
+
+        Parameters:
+        ----------
+        a: int8x16 vector
+        b: int8x16 vector
+
+        Returns:
+        --------
+            return an int16x8 vector
+
+        Pseudocode:
+        ----------
+            c = (a0*b0, a1*b1, a2*b2, a3*b3, a4*b4, a5*b5, a6*b6, a7*b7)
+        """
+        a_low = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", a)
+        b_low = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", b)
+        c = tvm.tir.call_llvm_pure_intrin(
+            "int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_low, b_low
+        )
+        return c
+
+    def umull2(a, b):
+        """Multiply long (uppoer part)

Review comment:
       typo + see above

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -447,28 +88,280 @@ def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type):
         C.shape, dtype="int32", name="c_buffer", offset_factor=1, strides=[te.var("sc"), 1]
     )
 
+    # Intrinsics used in the following algorithm
+    umull_intrin = "llvm.aarch64.neon.umull" if in_type == "uint8" else "llvm.aarch64.neon.smull"
+    uaddlp_intrin = "llvm.aarch64.neon.uaddlp" if in_type == "uint8" else "llvm.aarch64.neon.saddlp"
+    addp_intrin = "llvm.aarch64.neon.addp"
+
+    def uadalp(a, b):
+        """Add pair and accumulate
+
+        Parameters:
+        ----------
+        a: int16x8 vector
+        b: int16x8 vector
+
+        Returns:
+        --------
+            return an int32x4 vector
+
+        Pseudocode:
+        ----------
+            a += (b0+b1, b2+b3, b4+b5, b6+b7)
+        """
+
+        return a + tvm.tir.call_llvm_pure_intrin(
+            "int32x4", uaddlp_intrin, tvm.tir.const(1, "uint32"), b
+        )
+
+    def umull(a, b):
+        """Multiply long (lower part)
+
+        Parameters:
+        ----------
+        a: int8x16 vector
+        b: int8x16 vector
+
+        Returns:
+        --------
+            return an int16x8 vector
+
+        Pseudocode:
+        ----------
+            c = (a0*b0, a1*b1, a2*b2, a3*b3, a4*b4, a5*b5, a6*b6, a7*b7)
+        """
+        a_low = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", a)
+        b_low = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", b)
+        c = tvm.tir.call_llvm_pure_intrin(
+            "int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_low, b_low
+        )
+        return c
+
+    def umull2(a, b):
+        """Multiply long (uppoer part)
+
+        Parameters:
+        ----------
+        a: int8x16 vector
+        b: int8x16 vector
+
+        Returns:
+        --------
+            return an int16x8 vector
+
+        Pseudocode:
+        ----------
+            c = (a8*b8, a9*b9, a10*b10, a11*b11, a12*b12, a13*b13, a14*b14, a15*b15)
+        """
+        a_high = tvm.tir.call_intrin("int8x8", "tir.vectorlow", a)
+        b_high = tvm.tir.call_intrin("int8x8", "tir.vectorlow", b)
+        c = tvm.tir.call_llvm_pure_intrin(
+            "int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_high, b_high
+        )
+        return c
+
+    def addp(a, b):
+        """Add two vectors in pairs
+
+        Parameters:
+        ----------
+        a: int32x4 vector
+        b: int32x4 vector
+
+        Returns:
+        --------
+            return an int32x4 vector
+
+        Pseudocode:
+        ----------
+            c = (a0+a1, a2+a3, b0+b1, b2+b3)
+        """
+        return tvm.tir.call_llvm_pure_intrin(
+            "int32x4", addp_intrin, tvm.tir.const(2, "uint32"), a, b
+        )
+
+    def accumulation_loop(M, N, ins, acc, i):
+        a0 = ins[0].vload([i, 0, 0], dtype_vec)
+        a1 = tvm.tir.const(0, "int8x16")
+        if M > 1:
+            a1 = ins[0].vload([i, 1, 0], dtype_vec)
+        a2 = tvm.tir.const(0, "int8x16")
+        if M > 2:
+            a2 = ins[0].vload([i, 2, 0], dtype_vec)
+        a3 = tvm.tir.const(0, "int8x16")
+        if M > 3:
+            a3 = ins[0].vload([i, 3, 0], dtype_vec)
+
+        b0 = ins[1].vload([i, 0, 0], dtype_vec)
+        b1 = tvm.tir.const(0, "int8x16")
+        if N > 1:
+            b1 = ins[1].vload([i, 1, 0], dtype_vec)
+        b2 = tvm.tir.const(0, "int8x16")
+        if N > 2:
+            b2 = ins[1].vload([i, 2, 0], dtype_vec)
+        b3 = tvm.tir.const(0, "int8x16")
+        if N > 3:
+            b3 = ins[1].vload([i, 3, 0], dtype_vec)
+
+        # First half
+        # Lower part of a0 * {b0,b1,b2,b3}
+        d00 = umull(a0, b0)
+        d01 = umull(a0, b1)
+        d02 = umull(a0, b2)
+        d03 = umull(a0, b3)
+
+        # Lower part of a1 * {b0,b1,b2,b3}
+        d10 = umull(a1, b0)
+        d11 = umull(a1, b1)
+        d12 = umull(a1, b2)
+        d13 = umull(a1, b3)
+
+        # Accumulate
+        acc[0] = uadalp(acc[0], d00)
+        acc[1] = uadalp(acc[1], d01)
+        acc[2] = uadalp(acc[2], d02)
+        acc[3] = uadalp(acc[3], d03)
+        acc[4] = uadalp(acc[4], d10)
+        acc[5] = uadalp(acc[5], d11)
+        acc[6] = uadalp(acc[6], d12)
+        acc[7] = uadalp(acc[7], d13)
+
+        # Higher part of a0 * {b0,b1,b2,b3}
+        d00 = umull2(a0, b0)
+        d01 = umull2(a0, b1)
+        d02 = umull2(a0, b2)
+        d03 = umull2(a0, b3)
+
+        # Higher part of a1 * {b0,b1,b2,b3}
+        d10 = umull2(a1, b0)
+        d11 = umull2(a1, b1)
+        d12 = umull2(a1, b2)
+        d13 = umull2(a1, b3)
+
+        # Accumulate again
+        acc[0] = uadalp(acc[0], d00)
+        acc[1] = uadalp(acc[1], d01)
+        acc[2] = uadalp(acc[2], d02)
+        acc[3] = uadalp(acc[3], d03)
+        acc[4] = uadalp(acc[4], d10)
+        acc[5] = uadalp(acc[5], d11)
+        acc[6] = uadalp(acc[6], d12)
+        acc[7] = uadalp(acc[7], d13)
+
+        # Second half
+        # Lower part of a2 * {b0,b1,b2,b3}
+        d00 = umull(a2, b0)
+        d01 = umull(a2, b1)
+        d02 = umull(a2, b2)
+        d03 = umull(a2, b3)
+
+        # Lower part of a3 * {b0,b1,b2,b3}
+        d10 = umull(a3, b0)
+        d11 = umull(a3, b1)
+        d12 = umull(a3, b2)
+        d13 = umull(a3, b3)
+
+        # Accumulate
+        acc[8] = uadalp(acc[8], d00)
+        acc[9] = uadalp(acc[9], d01)
+        acc[10] = uadalp(acc[10], d02)
+        acc[11] = uadalp(acc[11], d03)
+        acc[12] = uadalp(acc[12], d10)
+        acc[13] = uadalp(acc[13], d11)
+        acc[14] = uadalp(acc[14], d12)
+        acc[15] = uadalp(acc[15], d13)
+
+        # Higher part of a2 * {b0,b1,b2,b3}
+        d00 = umull2(a2, b0)
+        d01 = umull2(a2, b1)
+        d02 = umull2(a2, b2)
+        d03 = umull2(a2, b3)
+
+        # Higher part of a3 * {b0,b1,b2,b3}
+        d10 = umull2(a3, b0)
+        d11 = umull2(a3, b1)
+        d12 = umull2(a3, b2)
+        d13 = umull2(a3, b3)
+
+        # Accumulate again
+        acc[8] = uadalp(acc[8], d00)
+        acc[9] = uadalp(acc[9], d01)
+        acc[10] = uadalp(acc[10], d02)
+        acc[11] = uadalp(acc[11], d03)
+        acc[12] = uadalp(acc[12], d10)
+        acc[13] = uadalp(acc[13], d11)
+        acc[14] = uadalp(acc[14], d12)
+        acc[15] = uadalp(acc[15], d13)
+
     def _intrin_func(ins, outs):
         def _instr():
             ib = tvm.tir.ir_builder.create()
-            aa, bb = ins
-            cc = outs[0]
-            stepA = min(4, M)
-            stepB = min(4, N)
-            intrin_name = "gemm_quantized_{0}_{0}_int32_{1}_{2}".format(in_type, stepA, stepB)
+            # Allocate a local buffer (possibly translates to registers)
+            acc = ib.allocate("int32x4", 16, name="accs", scope="local")
+            m = outs[0].shape[0]
+            n = outs[0].shape[1]
+            # Initialization
+            for i in range(0, 16):
+                acc[i] = tvm.tir.const(0, "int32x4")
+
             if unroll:
-                intrin_name += "_" + str(K)
-            if interleave:
-                intrin_name += "_interleaved"
-            ib.emit(
-                tvm.tir.call_extern(
-                    "int32",
-                    intrin_name,
-                    outs[0].access_ptr("w"),
-                    a_buffer.access_ptr("r"),
-                    b_buffer.access_ptr("r"),
-                    K,
-                )
-            )
+                for i in range(0, int(K // 16)):
+                    accumulation_loop(M, N, ins, acc, i)
+            else:
+                with ib.for_range(0, K // 16, name="i") as i:
+                    accumulation_loop(M, N, ins, acc, i)
+
+            # Final accumulations
+            # acc[i] contains the partial sums of a[i, 0:K].*b[0,0:K], let's call them (a,b,c,d)

Review comment:
       Need an explicit definition of 'i' here, probably use a different letter?
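       e.g. something along these lines (using `n` for the accumulator index; the mapping is my reading of the code above):
       ```python
       # acc[n] (0 <= n < 16) holds the int32x4 partial sums of
       # a[n // 4, 0:K] .* b[n % 4, 0:K]; call its four lanes (a, b, c, d)
       ```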

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -447,28 +88,280 @@ def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type):
         C.shape, dtype="int32", name="c_buffer", offset_factor=1, strides=[te.var("sc"), 1]
     )
 
+    # Intrinsics used in the following algorithm
+    umull_intrin = "llvm.aarch64.neon.umull" if in_type == "uint8" else "llvm.aarch64.neon.smull"
+    uaddlp_intrin = "llvm.aarch64.neon.uaddlp" if in_type == "uint8" else "llvm.aarch64.neon.saddlp"
+    addp_intrin = "llvm.aarch64.neon.addp"
+
+    def uadalp(a, b):
+        """Add pair and accumulate
+
+        Parameters:
+        ----------
+        a: int16x8 vector
+        b: int16x8 vector
+
+        Returns:
+        --------
+            return an int32x4 vector
+
+        Pseudocode:
+        ----------
+            a += (b0+b1, b2+b3, b4+b5, b6+b7)
+        """
+
+        return a + tvm.tir.call_llvm_pure_intrin(
+            "int32x4", uaddlp_intrin, tvm.tir.const(1, "uint32"), b
+        )
+
+    def umull(a, b):
+        """Multiply long (lower part)
+
+        Parameters:
+        ----------
+        a: int8x16 vector
+        b: int8x16 vector
+
+        Returns:
+        --------
+            return an int16x8 vector
+
+        Pseudocode:
+        ----------
+            c = (a0*b0, a1*b1, a2*b2, a3*b3, a4*b4, a5*b5, a6*b6, a7*b7)
+        """
+        a_low = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", a)
+        b_low = tvm.tir.call_intrin("int8x8", "tir.vectorhigh", b)
+        c = tvm.tir.call_llvm_pure_intrin(
+            "int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_low, b_low
+        )
+        return c
+
+    def umull2(a, b):
+        """Multiply long (uppoer part)
+
+        Parameters:
+        ----------
+        a: int8x16 vector
+        b: int8x16 vector
+
+        Returns:
+        --------
+            return an int16x8 vector
+
+        Pseudocode:
+        ----------
+            c = (a8*b8, a9*b9, a10*b10, a11*b11, a12*b12, a13*b13, a14*b14, a15*b15)
+        """
+        a_high = tvm.tir.call_intrin("int8x8", "tir.vectorlow", a)
+        b_high = tvm.tir.call_intrin("int8x8", "tir.vectorlow", b)
+        c = tvm.tir.call_llvm_pure_intrin(
+            "int16x8", umull_intrin, tvm.tir.const(2, "uint32"), a_high, b_high
+        )
+        return c
+
+    def addp(a, b):
+        """Add two vectors in pairs
+
+        Parameters:
+        ----------
+        a: int32x4 vector
+        b: int32x4 vector
+
+        Returns:
+        --------
+            return an int32x4 vector
+
+        Pseudocode:
+        ----------
+            c = (a0+a1, a2+a3, b0+b1, b2+b3)
+        """
+        return tvm.tir.call_llvm_pure_intrin(
+            "int32x4", addp_intrin, tvm.tir.const(2, "uint32"), a, b
+        )
+
+    def accumulation_loop(M, N, ins, acc, i):

Review comment:
       This deserves a doc string.
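       Something in the style of the helpers above would do, e.g. (the parameter descriptions are my reading of the code):
       ```python
       def accumulation_loop(M, N, ins, acc, i):
           """Compute the i-th accumulation step of the 4x4 gemm.

           Loads up to 4 rows of A and up to 4 rows of B (16 int8 elements
           each), multiplies every A-row with every B-row via umull/umull2,
           and accumulates the widened products into the 16 int32x4
           accumulators in acc.

           Parameters:
           ----------
           M: rows of A actually present (at most 4)
           N: rows of B actually present (at most 4)
           ins: input buffers holding A and B
           acc: list of 16 int32x4 accumulators
           i: index of the current 16-element block along K
           """
       ```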

##########
File path: python/tvm/tir/ir_builder.py
##########
@@ -103,7 +103,7 @@ def __getitem__(self, index):
         index = self._linear_index(index)
         if t.lanes > 1:
             base = index * t.lanes
-            index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
+            index = _expr.Ramp(base, 1, t.lanes)

Review comment:
       Probably best as a separate review with a failing/fixed test case.
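       If the failure mode is indexing a vector buffer with a plain Python int (where `base.dtype` doesn't exist), the regression test could be as small as this (untested sketch):
       ```python
       import tvm

       def test_vector_buffer_int_index():
           ib = tvm.tir.ir_builder.create()
           buf = ib.allocate("int32x4", 8, name="buf", scope="local")
           load = buf[0]  # plain int index exercises the Ramp construction
           assert isinstance(load.index, tvm.tir.Ramp)
       ```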

##########
File path: python/tvm/topi/arm_cpu/tensor_intrin.py
##########
@@ -19,369 +19,9 @@
 
 import tvm
 from tvm import te
-from tvm.contrib import utils, clang
-
-
-def gemm_quantized_4_4_batched():
-    return """
-           // First half
-           // Higher part of a0 * {b0,b1,b2,b3}
-           "umull v8.8h, v0.8b, v4.8b\\n"
-           "umull v9.8h, v0.8b, v5.8b\\n"
-           "umull v10.8h, v0.8b, v6.8b\\n"
-           "umull v11.8h, v0.8b, v7.8b\\n"
-
-           // Higher part of a1 * {b0,b1,b2,b3}
-           "umull v12.8h, v1.8b, v4.8b\\n"
-           "umull v13.8h, v1.8b, v5.8b\\n"
-           "umull v14.8h, v1.8b, v6.8b\\n"
-           "umull v15.8h, v1.8b, v7.8b\\n"
-
-           // Accumulate
-           "uadalp v16.4s, v8.8h\\n"
-           "uadalp v17.4s, v9.8h\\n"
-           "uadalp v18.4s, v10.8h\\n"
-           "uadalp v19.4s, v11.8h\\n"
-           "uadalp v20.4s, v12.8h\\n"
-           "uadalp v21.4s, v13.8h\\n"
-           "uadalp v22.4s, v14.8h\\n"
-           "uadalp v23.4s, v15.8h\\n"
-
-           // Lower part of a0 * {b0,b1,b2,b3}
-           "umull2 v8.8h, v0.16b, v4.16b\\n"
-           "umull2 v9.8h, v0.16b, v5.16b\\n"
-           "umull2 v10.8h, v0.16b, v6.16b\\n"
-           "umull2 v11.8h, v0.16b, v7.16b\\n"
-
-           // Lower part of a1 * {b0,b1,b2,b3}
-           "umull2 v12.8h, v1.16b, v4.16b\\n"
-           "umull2 v13.8h, v1.16b, v5.16b\\n"
-           "umull2 v14.8h, v1.16b, v6.16b\\n"
-           "umull2 v15.8h, v1.16b, v7.16b\\n"
-
-            // Accumulate again
-           "uadalp v16.4s, v8.8h\\n"
-           "uadalp v17.4s, v9.8h\\n"
-           "uadalp v18.4s, v10.8h\\n"
-           "uadalp v19.4s, v11.8h\\n"
-           "uadalp v20.4s, v12.8h\\n"
-           "uadalp v21.4s, v13.8h\\n"
-           "uadalp v22.4s, v14.8h\\n"
-           "uadalp v23.4s, v15.8h\\n"
-
-           // Second half
-           // Lower part of a2 * {b0,b1,b2,b3}
-           "umull v8.8h, v2.8b, v4.8b\\n"
-           "umull v9.8h, v2.8b, v5.8b\\n"
-           "umull v10.8h, v2.8b, v6.8b\\n"
-           "umull v11.8h, v2.8b, v7.8b\\n"
-
-           // Lower part of a3 * {b0,b1,b2,b3}
-           "umull v12.8h, v3.8b, v4.8b\\n"
-           "umull v13.8h, v3.8b, v5.8b\\n"
-           "umull v14.8h, v3.8b, v6.8b\\n"
-           "umull v15.8h, v3.8b, v7.8b\\n"
-
-           // Accumulate
-           "uadalp v24.4s, v8.8h\\n"
-           "uadalp v25.4s, v9.8h\\n"
-           "uadalp v26.4s, v10.8h\\n"
-           "uadalp v27.4s, v11.8h\\n"
-           "uadalp v28.4s, v12.8h\\n"
-           "uadalp v29.4s, v13.8h\\n"
-           "uadalp v30.4s, v14.8h\\n"
-           "uadalp v31.4s, v15.8h\\n"
-
-           // Higher part of a2 * {b0,b1,b2,b3}
-           "umull2 v8.8h, v2.16b, v4.16b\\n"
-           "umull2 v9.8h, v2.16b, v5.16b\\n"
-           "umull2 v10.8h, v2.16b, v6.16b\\n"
-           "umull2 v11.8h, v2.16b, v7.16b\\n"
-
-           // Higher part of a3 * {b0,b1,b2,b3}
-           "umull2 v12.8h, v3.16b, v4.16b\\n"
-           "umull2 v13.8h, v3.16b, v5.16b\\n"
-           "umull2 v14.8h, v3.16b, v6.16b\\n"
-           "umull2 v15.8h, v3.16b, v7.16b\\n"
-
-           // Accumulate again
-           "uadalp v24.4s, v8.8h\\n"
-           "uadalp v25.4s, v9.8h\\n"
-           "uadalp v26.4s, v10.8h\\n"
-           "uadalp v27.4s, v11.8h\\n"
-           "uadalp v28.4s, v12.8h\\n"
-           "uadalp v29.4s, v13.8h\\n"
-           "uadalp v30.4s, v14.8h\\n"
-           "uadalp v31.4s, v15.8h\\n"
-    """
-
-
-def gemm_quantized_4_4_interleaved():
-    return """
-             // First half
-             // Higher part of a0 * {b0,b1,b2,b3} and accumulate
-             "umull v8.8h, v0.8b, v4.8b\\n"
-             "uadalp v16.4s, v8.8h\\n"
-             "umull v9.8h, v0.8b, v5.8b\\n"
-             "uadalp v17.4s, v9.8h\\n"
-             "umull v10.8h, v0.8b, v6.8b\\n"
-             "uadalp v18.4s, v10.8h\\n"
-             "umull v11.8h, v0.8b, v7.8b\\n"
-             "uadalp v19.4s, v11.8h\\n"
-
-             // Higher part of a1 * {b0,b1,b2,b3} and accumulate
-             "umull v12.8h, v1.8b, v4.8b\\n"
-             "uadalp v20.4s, v12.8h\\n"
-             "umull v13.8h, v1.8b, v5.8b\\n"
-             "uadalp v21.4s, v13.8h\\n"
-             "umull v14.8h, v1.8b, v6.8b\\n"
-             "uadalp v22.4s, v14.8h\\n"
-             "umull v15.8h, v1.8b, v7.8b\\n"
-             "uadalp v23.4s, v15.8h\\n"
-
-             // Lower part of a0 * {b0,b1,b2,b3} and accumulate
-             "umull2 v8.8h, v0.16b, v4.16b\\n"
-             "uadalp v16.4s, v8.8h\\n"
-             "umull2 v9.8h, v0.16b, v5.16b\\n"
-             "uadalp v17.4s, v9.8h\\n"
-             "umull2 v10.8h, v0.16b, v6.16b\\n"
-             "uadalp v18.4s, v10.8h\\n"
-             "umull2 v11.8h, v0.16b, v7.16b\\n"
-             "uadalp v19.4s, v11.8h\\n"
-
-             // Lower part of a1 * {b0,b1,b2,b3} and accumulate
-             "umull2 v12.8h, v1.16b, v4.16b\\n"
-             "uadalp v20.4s, v12.8h\\n"
-             "umull2 v13.8h, v1.16b, v5.16b\\n"
-             "uadalp v21.4s, v13.8h\\n"
-             "umull2 v14.8h, v1.16b, v6.16b\\n"
-             "uadalp v22.4s, v14.8h\\n"
-             "umull2 v15.8h, v1.16b, v7.16b\\n"
-             "uadalp v23.4s, v15.8h\\n"
-
-             // Second half
-             // Higher part of a2 * {b0,b1,b2,b3} and accumulate
-             "umull v8.8h, v2.8b, v4.8b\\n"
-             "uadalp v24.4s, v8.8h\\n"
-             "umull v9.8h, v2.8b, v5.8b\\n"
-             "uadalp v25.4s, v9.8h\\n"
-             "umull v10.8h, v2.8b, v6.8b\\n"
-             "uadalp v26.4s, v10.8h\\n"
-             "umull v11.8h, v2.8b, v7.8b\\n"
-             "uadalp v27.4s, v11.8h\\n"
-
-             // Higher part of a3 * {b0,b1,b2,b3} and accumulate
-             "umull v12.8h, v3.8b, v4.8b\\n"
-             "uadalp v28.4s, v12.8h\\n"
-             "umull v13.8h, v3.8b, v5.8b\\n"
-             "uadalp v29.4s, v13.8h\\n"
-             "umull v14.8h, v3.8b, v6.8b\\n"
-             "uadalp v30.4s, v14.8h\\n"
-             "umull v15.8h, v3.8b, v7.8b\\n"
-             "uadalp v31.4s, v15.8h\\n"
-
-             // Lower part of a2 * {b0,b1,b2,b3} and accumulate
-             "umull2 v8.8h, v2.16b, v4.16b\\n"
-             "uadalp v24.4s, v8.8h\\n"
-             "umull2 v9.8h, v2.16b, v5.16b\\n"
-             "uadalp v25.4s, v9.8h\\n"
-             "umull2 v10.8h, v2.16b, v6.16b\\n"
-             "uadalp v26.4s, v10.8h\\n"
-             "umull2 v11.8h, v2.16b, v7.16b\\n"
-             "uadalp v27.4s, v11.8h\\n"
-
-             // Lower part of a3 * {b0,b1,b2,b3} and accumulate
-             "umull2 v12.8h, v3.16b, v4.16b\\n"
-             "uadalp v28.4s, v12.8h\\n"
-             "umull2 v13.8h, v3.16b, v5.16b\\n"
-             "uadalp v29.4s, v13.8h\\n"
-             "umull2 v14.8h, v3.16b, v6.16b\\n"
-             "uadalp v30.4s, v14.8h\\n"
-             "umull2 v15.8h, v3.16b, v7.16b\\n"
-             "uadalp v31.4s, v15.8h\\n"
-    """
-
-
-def gemm_quantized_impl(M, N, K, unroll, interleave, data_type="uint8"):
-    """Assembly implementation of a blocked gemv. Given
-    a block a of shape (4, k) and a block b' of shape (4, k)
-    produces the output block c = a*b of shape (4,4)"""
-
-    stepA = min(4, M)
-    stepB = min(4, N)
-    assert data_type in ["uint8", "int8"], "Only uint8/int8 supported for this implementation"
-
-    signature = """extern "C" int gemm_quantized_{0}_{0}_int32_{1}_{2}""".format(
-        data_type, stepA, stepB
-    )
-    if unroll:
-        signature += "_" + str(K)
-
-    if interleave:
-        signature += "_interleaved"
-
-    signature += """(int *c_buffer,
-                      unsigned char *a_buffer,
-                      unsigned char *b_buffer,
-                      int K, int m, int n)"""
-
-    cc_code = signature
-    cc_code += """
-    {
-            unsigned char * a_ptr = a_buffer;
-            unsigned char * b_ptr = b_buffer;
-            int * c_ptr = c_buffer;
-
-            int k = K / 16;
-
-            __asm__  __volatile__ (
-                "movi v16.4s, #0\\n"
-                "movi v17.4s, #0\\n"
-                "movi v18.4s, #0\\n"
-                "movi v19.4s, #0\\n"
-                "movi v20.4s, #0\\n"
-                "movi v21.4s, #0\\n"
-                "movi v22.4s, #0\\n"
-                "movi v23.4s, #0\\n"
-                "movi v24.4s, #0\\n"
-                "movi v25.4s, #0\\n"
-                "movi v26.4s, #0\\n"
-                "movi v27.4s, #0\\n"
-                "movi v28.4s, #0\\n"
-                "movi v29.4s, #0\\n"
-                "movi v30.4s, #0\\n"
-                "movi v31.4s, #0\\n"
-            "1:"
-    """
-
-    main_loop = ' "ldr q0, [%[a_ptr]]\\n" '
-
-    if M > 1:
-        main_loop += ' "ldr q1, [%[a_ptr], #16]\\n" '
-    else:
-        main_loop += ' "movi v1.4s, #0\\n" '
-
-    if M > 2:
-        main_loop += ' "ldr q2, [%[a_ptr], #32]\\n" '
-    else:
-        main_loop += ' "movi v2.4s, #0\\n" '
-
-    if M > 3:
-        main_loop += ' "ldr q3, [%[a_ptr], #48]\\n" '
-    else:
-        main_loop += ' "movi v3.4s, #0\\n" '
-
-    main_loop += ' "ldr q4, [%[b_ptr]]\\n" '
-
-    if N > 1:
-        main_loop += ' "ldr q5, [%[b_ptr], #16]\\n" '
-
-    if N > 2:
-        main_loop += ' "ldr q6, [%[b_ptr], #32]\\n" '
 
-    if N > 3:
-        main_loop += ' "ldr q7, [%[b_ptr], #48]\\n" '
 
-    # Main computation can interleave multiply/accumulate instructions
-    # or schedule them in batches (first all multiplies then all accumulates)
-    if interleave:
-        main_loop += gemm_quantized_4_4_interleaved()
-    else:
-        main_loop += gemm_quantized_4_4_batched()
-
-    blockA = min(64, M * 16)
-    blockB = min(64, N * 16)
-    main_loop += """// Increment pointers
-                    "add %[a_ptr], %[a_ptr], #{0}\\n"
-                    "add %[b_ptr], %[b_ptr], #{1}\\n" """.format(
-        blockA, blockB
-    )
-
-    if unroll:
-        k = int(K // 16)
-        for l in range(0, k):
-            cc_code += main_loop
-    else:
-        cc_code += main_loop
-        cc_code += """
-                    "subs %w[k], %w[k], #1\\n"
-                    "cbnz %w[k], 1b\\n"
-                   """
-    cc_code += """
-                // Final additions
-
-                // v16 contains the four partial sums of a[0, 0:K].*b[0,0:K], let's call them (a,b,c,d)
-                // v17 contains the four partial sums of a[0, 0:K].*b[1,0:K], let's call them (e,f,g,h)
-                // v18 contains the four partial sums of a[0, 0:K].*b[2,0:K], let's call them (i,j,k,l)
-                // v19 contains the four partial sums of a[0, 0:K].*b[3,0:K], let's call them (m,n,o,p)
-                "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b, c+d, e+f, g+h)
-                "addp v17.4s, v18.4s, v19.4s\\n" // v17 = (i+j, k+l, m+n, o+p)
-                "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
-
-                // v20 contains the four partial sums of a[1, 0:K].*b[0,0:K], let's call them (a,b,c,d)
-                // v21 contains the four partial sums of a[1, 0:K].*b[1,0:K], let's call them (e,f,g,h)
-                // v22 contains the four partial sums of a[1, 0:K].*b[2,0:K], let's call them (i,j,k,l)
-                // v23 contains the four partial sums of a[1, 0:K].*b[3,0:K], let's call them (m,n,o,p)
-                "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b, c+d, e+f, g+h)
-                "addp v21.4s, v22.4s, v23.4s\\n" // v21 = (i+j, k+l, m+n, o+p)
-                "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
-
-                // v24 contains the four partial sums of a[2, 0:K].*b[0,0:K], let's call them (a,b,c,d)
-                // v25 contains the four partial sums of a[2, 0:K].*b[1,0:K], let's call them (e,f,g,h)
-                // v26 contains the four partial sums of a[2, 0:K].*b[2,0:K], let's call them (i,j,k,l)
-                // v27 contains the four partial sums of a[2, 0:K].*b[3,0:K], let's call them (m,n,o,p)
-                "addp v24.4s, v24.4s, v25.4s\\n"  // v24 = (a+b, c+d, e+f, g+h)
-                "addp v25.4s, v26.4s, v27.4s\\n"  // v25 = (i+j, k+l, m+n, o+p)
-                "addp v24.4s, v24.4s, v25.4s\\n"  // v24 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
-
-                // v28 contains the four partial sums of a[3, 0:K].*b[0,0:K], let's call them (a,b,c,d)
-                // v29 contains the four partial sums of a[3, 0:K].*b[1,0:K], let's call them (e,f,g,h)
-                // v30 contains the four partial sums of a[3, 0:K].*b[2,0:K], let's call them (i,j,k,l)
-                // v31 contains the four partial sums of a[3, 0:K].*b[3,0:K], let's call them (m,n,o,p)
-                "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b, c+d, e+f, g+h)
-                "addp v29.4s, v30.4s, v31.4s\\n" // v29 = (i+j, k+l, m+n, o+p)
-                "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
-
-                "str q16, [%[c_ptr]]\\n"
-            """
-
-    stepC = min(4, N)
-    if M > 1:
-        cc_code += ' "str q20, [%[c_ptr], #{0}]\\n" '.format(stepC * 4)
-
-    if M > 2:
-        cc_code += ' "str q24, [%[c_ptr], #{0}]\\n" '.format(stepC * 8)
-
-    if M > 3:
-        cc_code += ' "str q28, [%[c_ptr], #{0}]\\n" '.format(stepC * 12)
-
-    cc_code += """
-             : [c_ptr] "+r" (c_ptr), [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [k] "+r" (k)
-             :
-             : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
-                    "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
-                    "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
-                    "v27", "v28", "v29", "v30", "v31"
-             );
-        return 0;
-        }
-    """
-
-    if data_type == "int8":
-        cc_code = cc_code.replace("unsigned char", "char")
-        cc_code = cc_code.replace("umull", "smull")
-        cc_code = cc_code.replace("uadalp", "sadalp")
-
-    temp = utils.tempdir()
-    ll_path = temp.relpath("temp.ll")
-    # Create LLVM ir from c source code
-    ll_code = clang.create_llvm(
-        cc_code, options=["--target=aarch64-linux-gnu -mattr=+neon"], output=ll_path
-    )
-    return ll_code
-
-
-def gemm_quantized(M, N, K, unroll, interleave, in_type, out_type):
+def gemm_4x4_int8_int8_int32(M, N, K, unroll, in_type):

Review comment:
       We could add some pseudo-code here.
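       e.g. something like this (my reading of the algorithm, with A and B the (4, K) blocks from the old docstring):
       ```python
       # Given A[4, K] (int8) and B[4, K] (int8), compute C = A * B^T (int32):
       # for i in range(4):
       #     for j in range(4):
       #         C[i, j] = sum(int32(A[i, k]) * int32(B[j, k]) for k in range(K))
       ```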




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org