You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lm...@apache.org on 2020/06/26 05:52:29 UTC
[incubator-tvm] branch master updated: [TE] Add LegalizeInvalidAttach to legalize the compute_at location after split or fuse (#5917)
This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 96bf271 [TE] Add LegalizeInvalidAttach to legalize the compute_at location after split or fuse (#5917)
96bf271 is described below
commit 96bf271e4925030d4533737991dfa7161be60d93
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Thu Jun 25 22:52:19 2020 -0700
[TE] Add LegalizeInvalidAttach to legalize the compute_at location after split or fuse (#5917)
* Add LegalizeInvalidAttach
* lint & typo
* lint & typo
* address comment
* fix lint
---
.gitignore | 1 +
src/te/schedule/schedule_dataflow_rewrite.cc | 79 +++++++++++++++++++++++++++-
tests/python/unittest/test_te_schedule.py | 20 +++++++
3 files changed, 98 insertions(+), 2 deletions(-)
diff --git a/.gitignore b/.gitignore
index b935701..506e54d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -196,6 +196,7 @@ tvm_t.*
.python_history
.pytest_cache
.local
+cmake-build-debug
# Visual Studio Code
.vscode
diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc
index af72d3b..f130cb4 100644
--- a/src/te/schedule/schedule_dataflow_rewrite.cc
+++ b/src/te/schedule/schedule_dataflow_rewrite.cc
@@ -451,7 +451,7 @@ Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) {
}
}
-void RebaseNonZeroMinLoop(const Schedule& sch) {
+void RebaseNonZeroMinLoop(ScheduleNode* sch) {
std::unordered_map<IterVar, IterVar> rebase_map;
for (Stage s : sch->stages) {
if (s->attach_type == kInlinedAlready) continue;
@@ -614,10 +614,85 @@ void InjectInline(ScheduleNode* sch) {
}
}
+// Re-target compute_at attach points that are no longer leaf iterators of
+// their attach stage because the iterator was later split or fused.
+// Mutates `sch` in place: rewrites `attach_ivar` of affected stages and groups.
+void LegalizeInvalidAttach(ScheduleNode* sch) {
+  // Legalize the compute_at location if the target iterator of compute_at is split or fused.
+  // Case 1: If the target of compute_at is split,
+  //         we will move the compute_at location to the inner iterator.
+  // Case 2: If the target of compute_at is fused,
+  //         we will move the compute_at location to the newly fused iterator.
+  // Note that case 2 can only happen if the target of compute_at
+  // is the innermost operand of fuse operation.
+  // Any other non-leaf attach point is mapped to itself, i.e. left unchanged by this pass.
+
+  // Map an old invalid attach point to its new valid attach point.
+  std::unordered_map<IterVar, IterVar> replace_map;
+
+  for (Stage stage : sch->stages) {
+    // Stages already seen on this attach chain. Revisiting one means the
+    // compute_at relations form a cycle, which would otherwise make the walk
+    // below loop forever.
+    std::unordered_set<const Object*> visited;
+    for (Stage s = stage; s.defined();) {
+      // The following logic is similar to `CreateAttachPath` in `src/te/schedule/graph.h`,
+      // because we follow the validation check in that function to legalize the attach.
+      CHECK(!visited.count(s.get())) << "Find loop in compute_at attach group";
+      visited.insert(s.get());
+      Stage spec = s.GetAttachSpec();
+      if (spec->attach_type != kScope) {
+        break;
+      }
+      IterVar attach_ivar = spec->attach_ivar;
+      s = spec->attach_stage;  // advance along the attach chain
+      CHECK(attach_ivar.defined());
+      CHECK(s.defined());
+
+      // The attach point is still valid if it is a leaf iterator of the attach stage.
+      bool is_leaf = false;
+      for (IterVar iv : s->leaf_iter_vars) {
+        if (iv.same_as(attach_ivar)) {
+          is_leaf = true;
+          break;
+        }
+      }
+      if (is_leaf) continue;
+
+      // Follow split (parent -> inner) and fuse (inner -> fused) relations to a
+      // fixed point, since the replacement iterator may itself be split or fused.
+      IterVar new_attach_ivar = attach_ivar;
+      bool updated = true;
+      while (updated) {
+        updated = false;
+        for (const auto& rel : s->relations) {
+          if (const FuseNode* r = rel.as<FuseNode>()) {
+            if (new_attach_ivar.same_as(r->inner)) {
+              new_attach_ivar = r->fused;
+              updated = true;
+            }
+          } else if (const SplitNode* r = rel.as<SplitNode>()) {
+            if (new_attach_ivar.same_as(r->parent)) {
+              new_attach_ivar = r->inner;
+              updated = true;
+            }
+          }
+        }
+      }
+      replace_map[attach_ivar] = new_attach_ivar;
+    }
+  }
+
+  // Re-point every scoped stage/group whose attach iterator was replaced.
+  auto remap_attach = [&replace_map](Stage s) {
+    if (s->attach_type != kScope) return;
+    auto it = replace_map.find(s->attach_ivar);
+    if (it != replace_map.end()) {
+      s->attach_ivar = it->second;
+    }
+  };
+  for (Stage s : sch->stages) remap_attach(s);
+  for (Stage s : sch->groups) remap_attach(s);
+}
+
+// Returns a normalized copy of this schedule. InjectInline, RebaseNonZeroMinLoop
+// and LegalizeInvalidAttach run, in that order, on the result of copy() rather
+// than on *this (presumably copy() is deep enough that the caller's schedule is
+// left untouched -- TODO confirm).
Schedule Schedule::normalize() {
Schedule sn = copy();
InjectInline(sn.operator->());
-RebaseNonZeroMinLoop(sn);
+// The normalization passes operate on the copy through a raw ScheduleNode*.
+RebaseNonZeroMinLoop(sn.operator->());
+// Re-target compute_at attach points whose target iterator was split or fused;
+// this runs after the other passes above.
+LegalizeInvalidAttach(sn.operator->());
return sn;
}
diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py
index 2c851cc..c00ee70 100644
--- a/tests/python/unittest/test_te_schedule.py
+++ b/tests/python/unittest/test_te_schedule.py
@@ -289,6 +289,25 @@ def test_tensor_intrin_scalar_params():
assert str(stmt.body.body.value.args[3]) == "(i: int32*i)"
assert str(stmt.body.body.value.args[4]) == "(i: int32 + j: int32)"
+def test_legalize_invalid_attach():
+ # Regression test: the compute_at target axis is split (case 1) or fused
+ # (case 2) afterwards; lowering must still succeed once the attach point is legalized.
+ A = te.compute((10, 10), lambda i, j: 1.0, name='A')
+ B = te.compute((10, 10), lambda i, j: A[i][j], name='B')
+
+ # Case 1: Split an axis which is the target of a compute_at
+ s = te.create_schedule([B.op])
+ s[A].compute_at(s[B], B.op.axis[1])
+ s[B].split(B.op.axis[1], 2)
+
+ stmt = tvm.lower(s, [A, B], simple_mode=True)['main'].body
+ # Lowering succeeded; a For loop is expected two levels below the body root.
+ assert isinstance(stmt.body.body, tvm.tir.stmt.For)
+
+ # Case 2: Fuse an axis which is the target of a compute_at
+ s = te.create_schedule([B.op])
+ s[A].compute_at(s[B], B.op.axis[1])
+ s[B].fuse(B.op.axis[0], B.op.axis[1])
+ stmt = tvm.lower(s, [A, B], simple_mode=True)['main'].body
+ # Lowering succeeded; the outermost statement is expected to be a For
+ # (presumably the fused loop).
+ assert isinstance(stmt, tvm.tir.stmt.For)
+
if __name__ == "__main__":
test_singleton()
test_pragma()
@@ -305,3 +324,4 @@ if __name__ == "__main__":
test_fuse_with_out_of_order_axis_with_reorder()
test_vectorize()
test_vectorize_commreduce()
+ test_legalize_invalid_attach()