You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by "masahi (via GitHub)" <gi...@apache.org> on 2023/03/21 19:51:17 UTC

[GitHub] [tvm] masahi commented on a diff in pull request #14209: [MetaSchedule][ARM] Enable ARM CPU intrinsic for MetaSchedule

masahi commented on code in PR #14209:
URL: https://github.com/apache/tvm/pull/14209#discussion_r1143919180


##########
python/tvm/tir/tensor_intrin/arm_cpu.py:
##########
@@ -131,8 +163,68 @@ def dot_product_4x4_i8i8i32_sdot(
         )
 
 
+@T.prim_func
+def dot_product_4x4_u8u8u32_udot(
+    A: T.Buffer((4,), "uint8", offset_factor=1),
+    B: T.Buffer((4, 4), "uint8", offset_factor=1),
+    C: T.Buffer((4,), "uint32", offset_factor=1),
+) -> None:
+    with T.block("root"):
+        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+        T.writes(C[0:4])
+
+        A_i8x4 = A.vload([0], "uint8x4")
+        A_i32 = T.reinterpret(A_i8x4, dtype="uint32")
+        vec_ai32 = T.broadcast(A_i32, 4)
+        vec_a = T.reinterpret(vec_ai32, dtype="uint8x16")
+
+        vec_b = B.vload([0, 0], dtype="uint8x16")
+
+        vec_c = C.vload([0], dtype="uint32x4")
+
+        C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.udot.v4u32.v16u8"),
+            T.uint32(3),
+            vec_c,
+            vec_a,
+            vec_b,
+            dtype="uint32x4",
+        )
+
+
+@T.prim_func
+def dot_product_4x4_u8u8i32_hdot(
+    A: T.Buffer((4,), "uint8", offset_factor=1),
+    B: T.Buffer((4, 4), "uint8", offset_factor=1),
+    C: T.Buffer((4,), "int32", offset_factor=1),
+) -> None:
+    with T.block("root"):
+        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+        T.writes(C[0:4])
+
+        A_i8x4 = A.vload([0], "uint8x4")
+        A_i32 = T.reinterpret(A_i8x4, dtype="uint32")
+        vec_ai32 = T.broadcast(A_i32, 4)
+        vec_a = T.reinterpret(vec_ai32, dtype="uint8x16")
+
+        vec_b = B.vload([0, 0], dtype="uint8x16")
+
+        vec_c = C.vload([0], dtype="int32x4")
+
+        C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.udot.v4u32.v16u8"),
+            T.uint32(3),
+            vec_c,
+            vec_a,
+            vec_b,
+            dtype="int32x4",
+        )

Review Comment:
   It should be possible to clean up a lot of code duplication between different dtypes here. See `tensor_intrin/cuda.py` for examples.



##########
src/meta_schedule/schedule_rule/schedule_rule.cc:
##########
@@ -295,6 +295,118 @@ Array<ScheduleRule> ScheduleRule::DefaultMicro() {
   };
 }
 
+Array<ScheduleRule> ScheduleRule::DefaultARMNeon() {
+  return {
+      ScheduleRule::ApplyCustomRule(),
+      ScheduleRule::InlineConstantScalars(),
+      ScheduleRule::AutoInline(
+          /*into_producer=*/false,
+          /*into_consumer=*/true,
+          /*inline_const_tensor=*/true,
+          /*disallow_if_then_else=*/true,
+          /*require_injective=*/true,
+          /*require_ordered=*/true,
+          /*disallow_op=*/Array<String>{"tir.exp"}),
+      ScheduleRule::AddRFactor(
+          /*max_jobs_per_core=*/8,
+          /*max_innermost_factor=*/Integer(32)),
+      ScheduleRule::MultiLevelTilingWithIntrin(
+          /*intrin_name=*/String("dot_4x4_i8i8s32_neon"),
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(32),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("may")},
+                                 {"levels", Array<Integer>{1, 2}},
+                                 {"scope", String("global")}}),
+      ScheduleRule::MultiLevelTiling(
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(32),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("may")},
+                                 {"levels", Array<Integer>{1, 2}},
+                                 {"scope", String("global")}}),
+      ScheduleRule::ParallelizeVectorizeUnroll(
+          /*max_jobs_per_core=*/8,
+          /*max_vectorize_extent=*/32,
+          /*unroll_max_steps=*/Array<Integer>{0, 8, 32, 256},
+          /*unroll_explicit=*/true),
+      ScheduleRule::RandomComputeLocation(),
+  };
+}
+
+Array<ScheduleRule> ScheduleRule::DefaultARMDotprod() {

Review Comment:
   Please remove the duplication with `DefaultARMDotprod`, so that the difference between the two rule sets is obvious.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org