You are viewing a plain-text version of this content. (The link to the canonical version was lost in this plain-text rendering.)
Posted to commits@tvm.apache.org by lm...@apache.org on 2021/04/24 21:27:07 UTC
[tvm] branch main updated: [TE] Fix bug if find a loop in compute_at attach path (#7898)
This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ed16f07 [TE] Fix bug if find a loop in compute_at attach path (#7898)
ed16f07 is described below
commit ed16f0799df0eaebb9bbadab6d0429e3d1b6cdb4
Author: Y <cy...@live.com>
AuthorDate: Sun Apr 25 05:26:48 2021 +0800
[TE] Fix bug if find a loop in compute_at attach path (#7898)
---
src/te/schedule/schedule_dataflow_rewrite.cc | 3 +++
tests/python/unittest/test_te_schedule.py | 28 ++++++++++++++++++++++++++++
2 files changed, 31 insertions(+)
diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc
index d1f3a89..fae826b 100644
--- a/src/te/schedule/schedule_dataflow_rewrite.cc
+++ b/src/te/schedule/schedule_dataflow_rewrite.cc
@@ -647,9 +647,12 @@ void LegalizeInvalidAttach(ScheduleNode* sch) {
std::unordered_map<IterVar, IterVar> replace_map;
for (Stage stage : sch->stages) {
+ std::unordered_set<const Object*> visited;
for (Stage s = stage; s.defined();) {
// The following logic is simiar to the `CreateAttachPath` in `src/te/schedule/graph.h`,
// because we follow the validation check in that function to legalize the attach.
+ ICHECK(!visited.count(s.get())) << "Find loop in compute_at attach group";
+ visited.insert(s.get());
Stage spec = s.GetAttachSpec();
if (spec->attach_type != kScope) {
break;
diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py
index 316aa6f..60a3247 100644
--- a/tests/python/unittest/test_te_schedule.py
+++ b/tests/python/unittest/test_te_schedule.py
@@ -321,6 +321,33 @@ def test_legalize_invalid_attach():
assert isinstance(stmt, tvm.tir.stmt.For)
+def test_compute_at():
+ def add():
+ shape = (16, 16)
+ A = tvm.te.compute(shape, lambda *i: 1.0, name="A")
+ B = tvm.te.compute(shape, lambda *i: 2.0, name="B")
+ C = tvm.te.compute(shape, lambda *i: A(*i) + B(*i), name="C")
+ return A, B, C
+
+ def invalid_compute_at_self():
+ A, B, C = add()
+ s = tvm.te.create_schedule(C.op)
+ s[C].compute_at(s[C], C.op.axis[0])
+ with pytest.raises(RuntimeError):
+ tvm.lower(s, [A, B], simple_mode=True)
+
+ def invalid_compute_at_loop():
+ A, B, C = add()
+ s = tvm.te.create_schedule(C.op)
+ s[A].compute_at(s[C], C.op.axis[0])
+ s[C].compute_at(s[A], A.op.axis[0])
+ with pytest.raises(RuntimeError):
+ tvm.lower(s, [C], simple_mode=True)
+
+ invalid_compute_at_self()
+ invalid_compute_at_loop()
+
+
if __name__ == "__main__":
test_singleton()
test_pragma()
@@ -338,3 +365,4 @@ if __name__ == "__main__":
test_vectorize()
test_vectorize_commreduce()
test_legalize_invalid_attach()
+ test_compute_at()