You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lm...@apache.org on 2021/04/24 21:27:07 UTC

[tvm] branch main updated: [TE] Fix bug if find a loop in compute_at attach path (#7898)

This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ed16f07  [TE] Fix bug if find a loop in compute_at attach path (#7898)
ed16f07 is described below

commit ed16f0799df0eaebb9bbadab6d0429e3d1b6cdb4
Author: Y <cy...@live.com>
AuthorDate: Sun Apr 25 05:26:48 2021 +0800

    [TE] Fix bug if find a loop in compute_at attach path (#7898)
---
 src/te/schedule/schedule_dataflow_rewrite.cc |  3 +++
 tests/python/unittest/test_te_schedule.py    | 28 ++++++++++++++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc
index d1f3a89..fae826b 100644
--- a/src/te/schedule/schedule_dataflow_rewrite.cc
+++ b/src/te/schedule/schedule_dataflow_rewrite.cc
@@ -647,9 +647,12 @@ void LegalizeInvalidAttach(ScheduleNode* sch) {
   std::unordered_map<IterVar, IterVar> replace_map;
 
   for (Stage stage : sch->stages) {
+    std::unordered_set<const Object*> visited;
     for (Stage s = stage; s.defined();) {
       // The following logic is simiar to the `CreateAttachPath` in `src/te/schedule/graph.h`,
       // because we follow the validation check in that function to legalize the attach.
+      ICHECK(!visited.count(s.get())) << "Find loop in compute_at attach group";
+      visited.insert(s.get());
       Stage spec = s.GetAttachSpec();
       if (spec->attach_type != kScope) {
         break;
diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py
index 316aa6f..60a3247 100644
--- a/tests/python/unittest/test_te_schedule.py
+++ b/tests/python/unittest/test_te_schedule.py
@@ -321,6 +321,33 @@ def test_legalize_invalid_attach():
     assert isinstance(stmt, tvm.tir.stmt.For)
 
 
+def test_compute_at():
+    def add():
+        shape = (16, 16)
+        A = tvm.te.compute(shape, lambda *i: 1.0, name="A")
+        B = tvm.te.compute(shape, lambda *i: 2.0, name="B")
+        C = tvm.te.compute(shape, lambda *i: A(*i) + B(*i), name="C")
+        return A, B, C
+
+    def invalid_compute_at_self():
+        A, B, C = add()
+        s = tvm.te.create_schedule(C.op)
+        s[C].compute_at(s[C], C.op.axis[0])
+        with pytest.raises(RuntimeError):
+            tvm.lower(s, [A, B], simple_mode=True)
+
+    def invalid_compute_at_loop():
+        A, B, C = add()
+        s = tvm.te.create_schedule(C.op)
+        s[A].compute_at(s[C], C.op.axis[0])
+        s[C].compute_at(s[A], A.op.axis[0])
+        with pytest.raises(RuntimeError):
+            tvm.lower(s, [C], simple_mode=True)
+
+    invalid_compute_at_self()
+    invalid_compute_at_loop()
+
+
 if __name__ == "__main__":
     test_singleton()
     test_pragma()
@@ -338,3 +365,4 @@ if __name__ == "__main__":
     test_vectorize()
     test_vectorize_commreduce()
     test_legalize_invalid_attach()
+    test_compute_at()