You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lm...@apache.org on 2020/06/26 05:52:29 UTC

[incubator-tvm] branch master updated: [TE] Add LegalizeInvalidAttach to legalize the compute_at location after split or fuse (#5917)

This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 96bf271  [TE] Add LegalizeInvalidAttach to legalize the compute_at location after split or fuse (#5917)
96bf271 is described below

commit 96bf271e4925030d4533737991dfa7161be60d93
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Thu Jun 25 22:52:19 2020 -0700

    [TE] Add LegalizeInvalidAttach to legalize the compute_at location after split or fuse (#5917)
    
    * Add LegalizeInvalidAttach
    
    * lint & typo
    
    * lint & typo
    
    * address comment
    
    * fix lint
---
 .gitignore                                   |  1 +
 src/te/schedule/schedule_dataflow_rewrite.cc | 79 +++++++++++++++++++++++++++-
 tests/python/unittest/test_te_schedule.py    | 20 +++++++
 3 files changed, 98 insertions(+), 2 deletions(-)

diff --git a/.gitignore b/.gitignore
index b935701..506e54d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -196,6 +196,7 @@ tvm_t.*
 .python_history
 .pytest_cache
 .local
+cmake-build-debug
 
 # Visual Studio Code
 .vscode
diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc
index af72d3b..f130cb4 100644
--- a/src/te/schedule/schedule_dataflow_rewrite.cc
+++ b/src/te/schedule/schedule_dataflow_rewrite.cc
@@ -451,7 +451,7 @@ Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) {
   }
 }
 
-void RebaseNonZeroMinLoop(const Schedule& sch) {
+void RebaseNonZeroMinLoop(ScheduleNode* sch) {
   std::unordered_map<IterVar, IterVar> rebase_map;
   for (Stage s : sch->stages) {
     if (s->attach_type == kInlinedAlready) continue;
@@ -614,10 +614,85 @@ void InjectInline(ScheduleNode* sch) {
   }
 }
 
+void LegalizeInvalidAttach(ScheduleNode* sch) {
+  // Legalize the compute_at location if the target iterator of compute_at is split or fused.
+  // Case 1: If the target of compute_at is split,
+  //         we will move the compute_at location to the inner iterator.
+  // Case 2: If the target of compute_at is fused,
+  //         we will move the compute_at location to the newly fused iterator.
+  // Note that case 2 is only handled when the target of compute_at is the inner
+  // operand of the fuse operation; a target that is the outer operand is left unchanged.
+
+  // Map an old invalid attach point to its new valid attach point
+  std::unordered_map<IterVar, IterVar> replace_map;
+  // NOTE(review): the walk below has no visited-set guard; a cyclic attach chain would hang.
+  for (Stage stage : sch->stages) {
+    for (Stage s = stage; s.defined();) {
+      // The following logic is similar to the `CreateAttachPath` in `src/te/schedule/graph.h`,
+      // because we follow the validation check in that function to legalize the attach.
+      Stage spec = s.GetAttachSpec();
+      if (spec->attach_type != kScope) {
+        break;
+      }
+      bool start_attach = false;
+      IterVar attach_ivar = spec->attach_ivar;
+      s = spec->attach_stage;  // the walk continues from the stage we attach to
+      CHECK(attach_ivar.defined());
+      CHECK(s.defined());
+      // Is attach_ivar still one of s's leaf iter vars, i.e. still a valid attach point?
+      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
+        IterVar iv = s->leaf_iter_vars[i - 1];
+        if (!start_attach && iv.same_as(attach_ivar)) {
+          start_attach = true;
+          break;
+        }
+      }
+      // Not a leaf any more: follow s's split/fuse relations to the replacement iterator.
+      if (!start_attach) {
+        IterVar new_attach_ivar = attach_ivar;
+        bool updated = true;
+        // Repeat until no relation applies, so chains of splits/fuses are followed to the end.
+        while (updated) {
+          updated = false;
+          for (const auto& rel : s->relations) {
+            if (const FuseNode* r = rel.as<FuseNode>()) {
+              if (new_attach_ivar.same_as(r->inner)) {
+                new_attach_ivar = r->fused;
+                updated = true;
+              }
+            } else if (const SplitNode* r = rel.as<SplitNode>()) {
+              if (new_attach_ivar.same_as(r->parent)) {
+                new_attach_ivar = r->inner;
+                updated = true;
+              }
+            }
+          }
+          replace_map[attach_ivar] = new_attach_ivar;  // rewritten every pass; the last one wins
+        }
+      }
+    }
+  }
+
+  // Rewrite, in place, attach_ivar of every stage and group whose compute_at target moved.
+  for (Stage s : sch->stages) {
+    if (s->attach_type != kScope) continue;
+    if (replace_map.count(s->attach_ivar)) {
+      s->attach_ivar = replace_map.at(s->attach_ivar);
+    }
+  }
+  for (Stage s : sch->groups) {
+    if (s->attach_type != kScope) continue;
+    if (replace_map.count(s->attach_ivar)) {
+      s->attach_ivar = replace_map.at(s->attach_ivar);
+    }
+  }
+}
+
 Schedule Schedule::normalize() {
   Schedule sn = copy();
   InjectInline(sn.operator->());
-  RebaseNonZeroMinLoop(sn);
+  RebaseNonZeroMinLoop(sn.operator->());
+  LegalizeInvalidAttach(sn.operator->());  // retarget compute_at points invalidated by split/fuse
   return sn;
 }
 
diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py
index 2c851cc..c00ee70 100644
--- a/tests/python/unittest/test_te_schedule.py
+++ b/tests/python/unittest/test_te_schedule.py
@@ -289,6 +289,25 @@ def test_tensor_intrin_scalar_params():
     assert str(stmt.body.body.value.args[3]) == "(i: int32*i)"
     assert str(stmt.body.body.value.args[4]) == "(i: int32 + j: int32)"
 
+def test_legalize_invalid_attach():  # regression test for LegalizeInvalidAttach
+    A = te.compute((10, 10), lambda i, j: 1.0, name='A')
+    B = te.compute((10, 10), lambda i, j: A[i][j], name='B')
+
+    # Case 1: Split an axis which is the target of a compute_at
+    s = te.create_schedule([B.op])
+    s[A].compute_at(s[B], B.op.axis[1])
+    s[B].split(B.op.axis[1], 2)  # B.op.axis[1] is no longer a leaf, so A's attach point is invalid
+
+    stmt = tvm.lower(s, [A, B], simple_mode=True)['main'].body
+    assert isinstance(stmt.body.body, tvm.tir.stmt.For)
+
+    # Case 2: Fuse an axis which is the target of a compute_at
+    s = te.create_schedule([B.op])
+    s[A].compute_at(s[B], B.op.axis[1])
+    s[B].fuse(B.op.axis[0], B.op.axis[1])  # A's target is the inner fuse operand (the handled case)
+    stmt = tvm.lower(s, [A, B], simple_mode=True)['main'].body
+    assert isinstance(stmt, tvm.tir.stmt.For)
+
 if __name__ == "__main__":
     test_singleton()
     test_pragma()
@@ -305,3 +324,4 @@ if __name__ == "__main__":
     test_fuse_with_out_of_order_axis_with_reorder()
     test_vectorize()
     test_vectorize_commreduce()
+    test_legalize_invalid_attach()