You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2022/07/02 11:16:35 UTC

[tvm] branch main updated: [MetaSchedule] Enhance AutoInline for Spatial Task (#11996)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0e97186957 [MetaSchedule] Enhance AutoInline for Spatial Task (#11996)
0e97186957 is described below

commit 0e971869575df7e5b12381e4566a1a8fd98a4a77
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Sat Jul 2 04:16:25 2022 -0700

    [MetaSchedule] Enhance AutoInline for Spatial Task (#11996)
    
    Previously, Auto-Inline on CPU would only inline according to strict
    conditions, for example, ordered index mapping. It is generally good
    practice to do so, but on the other hand, there is not much benefit in
    stopping inlining only due to some restrictive conditions for pure
    spatial subgraphs. By doing so, we also save some search trials on pure
    spatial subgraphs so that more can be allocated to more important ones.
---
 src/meta_schedule/schedule_rule/auto_inline.cc     | 16 +++-
 ...test_meta_schedule_schedule_rule_auto_inline.py | 93 ++++++++++++++++++++++
 2 files changed, 106 insertions(+), 3 deletions(-)

diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc
index 0cfe35298d..309f0a60ac 100644
--- a/src/meta_schedule/schedule_rule/auto_inline.cc
+++ b/src/meta_schedule/schedule_rule/auto_inline.cc
@@ -31,6 +31,15 @@ enum class InlineType : int32_t {
   kInlineIntoProducer = 2,
 };
 
+bool IsInSpatialPrimFunc(const tir::Schedule& sch, const tir::StmtSRef& block_sref) {
+  using namespace tvm::tir;
+  const StmtSRefNode* sref = block_sref.get();
+  for (; sref->parent != nullptr; sref = sref->parent) {
+  }
+  ICHECK(sref->stmt != nullptr && sref->stmt->IsInstance<BlockNode>());
+  return IsSpatialPrimFunc(GetRef<PrimFunc>(GetRootPrimFunc(sch->mod(), sref->stmt, nullptr)));
+}
+
 /*! \brief The rule that inlines spatial blocks if it satisfies some conditions. */
 class AutoInlineNode : public ScheduleRuleNode {
  public:
@@ -85,6 +94,7 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
                                               const tir::BlockRV& block_rv) {
   using namespace tvm::tir;
   StmtSRef block_sref = sch->GetSRef(block_rv);
+  bool is_pure_sptial = IsInSpatialPrimFunc(sch, block_sref);
   ScheduleState state = sch->state();
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   BlockRealize realize = GetBlockRealize(state, block_sref);
@@ -97,15 +107,15 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
     return InlineType::kInlineIntoConsumer;
   }
   // Cond 3. The block doesn't contain any disallowed operators
-  if (!disallow_op.empty() && HasOp(realize, disallow_op)) {
+  if (!is_pure_sptial && !disallow_op.empty() && HasOp(realize, disallow_op)) {
     return InlineType::kNoInline;
   }
   // Cond 4. The block doesn't have any if-then-else-like constructs
-  if (disallow_if_then_else && HasIfThenElse(realize)) {
+  if (!is_pure_sptial && disallow_if_then_else && HasIfThenElse(realize)) {
     return InlineType::kNoInline;
   }
   // Cond 5. The mapping from read indices to write indices are injective and ordered
-  if (require_injective || require_ordered) {
+  if (!is_pure_sptial && (require_injective || require_ordered)) {
     const BufferRegion& write_region = block->writes[0];
     for (const BufferRegion& read_region : block->reads) {
       bool injective, ordered;
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
index 2a8a1e5fe1..a8ffa6ff9d 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
@@ -240,6 +240,86 @@ class SoftmaxAfterInline:
                 T_softmax_norm[i0_4, i1_1] = T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32") / T_softmax_expsum[i0_4]
 
 
+@tvm.script.ir_module
+class BeforePureSpatial:
+    @T.prim_func
+    def main(
+        placeholder: T.Buffer[(1, 384), "int64"],
+        placeholder_1: T.Buffer[(30522, 768), "float32"],
+        placeholder_2: T.Buffer[(1, 384, 768), "float32"],
+        T_add: T.Buffer[(1, 384, 768), "float32"],
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        compile_engine_const = T.alloc_buffer([], dtype="int64")
+        T_less = T.alloc_buffer([1, 384], dtype="bool")
+        compile_engine_const_1 = T.alloc_buffer([], dtype="int64")
+        T_add_1 = T.alloc_buffer([1, 384], dtype="int64")
+        T_where = T.alloc_buffer([1, 384], dtype="int64")
+        T_take = T.alloc_buffer([1, 384, 768], dtype="float32")
+        with T.block("compile_engine_const"):
+            vi = T.axis.spatial(1, 0)
+            T.reads()
+            T.writes(compile_engine_const[()])
+            compile_engine_const[()] = T.int64(0)
+        for i0, i1 in T.grid(1, 384):
+            with T.block("T_less"):
+                ax0, ax1 = T.axis.remap("SS", [i0, i1])
+                T.reads(placeholder[ax0, ax1], compile_engine_const[()])
+                T.writes(T_less[ax0, ax1])
+                T_less[ax0, ax1] = placeholder[ax0, ax1] < compile_engine_const[()]
+        with T.block("compile_engine_const_1"):
+            vi = T.axis.spatial(1, 0)
+            T.reads()
+            T.writes(compile_engine_const_1[()])
+            compile_engine_const_1[()] = T.int64(30522)
+        for i0, i1 in T.grid(1, 384):
+            with T.block("T_add"):
+                ax0, ax1 = T.axis.remap("SS", [i0, i1])
+                T.reads(placeholder[ax0, ax1], compile_engine_const_1[()])
+                T.writes(T_add_1[ax0, ax1])
+                T_add_1[ax0, ax1] = placeholder[ax0, ax1] + compile_engine_const_1[()]
+        for i0, i1 in T.grid(1, 384):
+            with T.block("T_where"):
+                ax0, ax1 = T.axis.remap("SS", [i0, i1])
+                T.reads(T_less[ax0, ax1], T_add_1[ax0, ax1], placeholder[ax0, ax1])
+                T.writes(T_where[ax0, ax1])
+                T_where[ax0, ax1] = T.Select(
+                    T.cast(T_less[ax0, ax1], "int32") != 0, T_add_1[ax0, ax1], placeholder[ax0, ax1]
+                )
+        for i0, i1, i2 in T.grid(1, 384, 768):
+            with T.block("T_take"):
+                ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads(
+                    placeholder_1[T.min(T.max(T.int64(0), T_where[ax0, ax1]), T.int64(30521)), ax2],
+                    T_where[ax0, ax1],
+                )
+                T.writes(T_take[ax0, ax1, ax2])
+                T_take[ax0, ax1, ax2] = placeholder_1[
+                    T.min(T.max(T.int64(0), T_where[ax0, ax1]), T.int64(30521)), ax2
+                ]
+        for i0, i1, i2 in T.grid(1, 384, 768):
+            with T.block("T_add_1"):
+                ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads(T_take[ax0, ax1, ax2], placeholder_2[ax0, ax1, ax2])
+                T.writes(T_add[ax0, ax1, ax2])
+                T_add[ax0, ax1, ax2] = T_take[ax0, ax1, ax2] + placeholder_2[ax0, ax1, ax2]
+
+
+@tvm.script.ir_module
+class AfterPureSpatial:
+    @T.prim_func
+    def main(placeholder: T.Buffer[(1, 384), "int64"], placeholder_1: T.Buffer[(30522, 768), "float32"], placeholder_2: T.Buffer[(1, 384, 768), "float32"], T_add: T.Buffer[(1, 384, 768), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        for i0, i1, i2 in T.grid(1, 384, 768):
+            with T.block("T_add_1"):
+                ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads(placeholder[ax0, ax1], placeholder_1[T.min(T.max(T.int64(0), placeholder[ax0, ax1]), T.int64(30521)) : T.min(T.max(T.int64(0), placeholder[ax0, ax1] + T.int64(30522)), T.int64(30521)) + T.int64(1), ax2], placeholder_2[ax0, ax1, ax2])
+                T.writes(T_add[ax0, ax1, ax2])
+                T_add[ax0, ax1, ax2] = placeholder_1[T.min(T.max(T.int64(0), T.Select(T.cast(placeholder[ax0, ax1] < T.int64(0), "int32") != 0, placeholder[ax0, ax1] + T.int64(30522), placeholder[ax0, ax1])), T.int64(30521)), ax2] + placeholder_2[ax0, ax1, ax2]
+
 # pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
 # fmt: on
 
@@ -291,7 +371,20 @@ def test_inline_into_multiple_consumers():
     tvm.ir.assert_structural_equal(lhs=space.mod, rhs=SoftmaxAfterInline)
 
 
+def test_inline_pure_spatial():
+    mod = BeforePureSpatial
+    target = Target("llvm")
+    ctx = _create_context(
+        mod=mod,
+        target=target,
+        rule=auto_inline(target=target),
+    )
+    (space,) = ctx.space_generator.generate_design_space(mod=mod)
+    tvm.ir.assert_structural_equal(lhs=space.mod, rhs=AfterPureSpatial)
+
+
 if __name__ == "__main__":
     test_inline_consumer_chain()
     test_inline_into_cache()
     test_inline_into_multiple_consumers()
+    test_inline_pure_spatial()