You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/04/10 01:49:47 UTC
[incubator-tvm] branch master updated: Create loops according to
storage scope and thread hierarchies (#5190)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 3d09e64 Create loops according to storage scope and thread hierarchies (#5190)
3d09e64 is described below
commit 3d09e64daeb506ae6f75193faf512180771fd583
Author: yongfeng-nv <49...@users.noreply.github.com>
AuthorDate: Thu Apr 9 21:49:37 2020 -0400
Create loops according to storage scope and thread hierarchies (#5190)
* Set the IterVar index to 0 for thread-bound IterVars of local-scope stages.
* Lint fix
* Use the storage/thread rank instead of the scope name in the predicate. Add tests.
* Handle cases other than local/threadIdx.
* Keep the old behavior for the warp scope.
* Modify test to cover global/blockIdx.
* Fix a typo.
* Update test_te_schedule_ops.py with more testing coverage in test_local_stage_predicate; remove test_schedule_schedule_ops.py which was added by mistake.
---
src/te/operation/op_util.cc | 9 ++-
tests/python/unittest/test_te_schedule_ops.py | 88 +++++++++++++++++++++++++++
2 files changed, 96 insertions(+), 1 deletion(-)
diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc
index 3714f43..4ecfe94 100644
--- a/src/te/operation/op_util.cc
+++ b/src/te/operation/op_util.cc
@@ -29,6 +29,7 @@
#include "op_util.h"
#include "../schedule/message_passing.h"
#include "../../arith/compute_expr.h"
+#include "../../runtime/thread_storage_scope.h"
namespace tvm {
namespace te {
@@ -162,7 +163,13 @@ MakeLoopNest(const Stage& stage,
if (!debug_keep_trivial_loop && is_one(dom->extent)) {
value_map[iv] = dom->min;
} else {
- value_map[iv] = var;
+ runtime::ThreadScope ts = runtime::ThreadScope::make(bind_iv->thread_tag);
+ if (stage->scope == "" || stage->scope == "warp" ||
+ static_cast<int>(runtime::StorageScope::make(stage->scope).rank) <= ts.rank) {
+ value_map[iv] = var;
+ } else {
+ value_map[iv] = dom->min;
+ }
}
}
// annotate the extent of the IterVar
diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py
index 8d10cee..4e27ad3 100644
--- a/tests/python/unittest/test_te_schedule_ops.py
+++ b/tests/python/unittest/test_te_schedule_ops.py
@@ -482,6 +482,92 @@ def test_schedule_compute_inline():
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
+
+def test_local_stage_predicate():
+ m = 1
+ n = 3
+ p = 2
+ A = tvm.te.placeholder((m, n, p), name='A')
+ B = tvm.te.compute((m, n, p), lambda bi, bj, bk: A[bi, bj, bk], name="B")
+ C = tvm.te.compute((m, n, p), lambda ci, cj, ck: B[ci, cj, ck], name="C")
+ by = tvm.te.thread_axis("blockIdx.y")
+ tx = tvm.te.thread_axis("threadIdx.x")
+ vx = tvm.te.thread_axis("vthread")
+
+ def schedule(thread_tag, mem_scope) :
+ s = tvm.te.create_schedule(C.op)
+ s[B].compute_at(s[C], s[C].op.axis[0])
+ s[B].set_scope(mem_scope)
+ bno, bni = s[B].split(s[B].op.axis[1], n)
+ bx = tvm.te.thread_axis("blockIdx.x")
+ s[C].bind(s[C].op.axis[0], bx)
+ s[C].bind(s[C].op.axis[1], thread_tag)
+ s[B].bind(bni, thread_tag)
+ return s
+
+ def collect_visit(stmt, f):
+ ret = []
+ tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
+ return ret
+ # local vs. threadIdx
+ s = schedule(tx, "local")
+ lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+ assert (not any(
+ collect_visit(lowered_body,
+ lambda x: isinstance(x, tvm.tir.IfThenElse))))
+ # local vs. vthread
+ s = schedule(vx, "local")
+ lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+ assert (not any(
+ collect_visit(lowered_body,
+ lambda x: isinstance(x, tvm.tir.IfThenElse))))
+ # shared vs. blockIdx
+ s = schedule(by, "shared")
+ lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+ assert (not any(
+ collect_visit(lowered_body,
+ lambda x: isinstance(x, tvm.tir.IfThenElse))))
+
+def test_local_stage_predicate2():
+ A = tvm.te.placeholder((128, ), name="A")
+ B = tvm.te.compute((128, ), lambda bi: A[bi] + 1, name="B")
+ C = tvm.te.compute((128, ), lambda ci: B[ci] + 2, name="C")
+ s = tvm.te.create_schedule(C.op)
+ AA = s.cache_read(A, "local", [B])
+ s[B].set_scope("shared")
+ block_x = tvm.te.thread_axis("blockIdx.x")
+ thread_x = tvm.te.thread_axis((0, 32), "threadIdx.x")
+ oc, ic = s[C].split(s[C].op.axis[0], factor=64)
+ ooc, ioc = s[C].split(oc, factor=2)
+ oic, iic = s[C].split(ic, factor=32)
+ s[C].bind(ooc, block_x)
+ s[C].bind(iic, thread_x)
+ s[B].compute_at(s[C], ioc)
+ ob, ib = s[B].split(s[B].op.axis[0], factor=32)
+ s[B].bind(ib, thread_x)
+ s[AA].compute_root()
+ s[AA].compute_at(s[C], ooc)
+ oaa, iaa = s[AA].split(s[AA].op.axis[0], factor=32)
+ s[AA].bind(iaa, thread_x)
+ lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+
+ def collect_visit(stmt, f):
+ ret = []
+ tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x: ret.append(f(x)))
+ return ret
+
+ def visit_stmt(op):
+ print(op)
+ if (isinstance(op, tvm.tir.Allocate)):
+ return op.extents[0].value == 97
+ return False
+
+ assert (not any(
+ collect_visit(lowered_body,
+ lambda x: isinstance(x, tvm.tir.IfThenElse))))
+ assert (any(collect_visit(lowered_body, visit_stmt)))
+
+
if __name__ == "__main__":
test_loop_dep_reduce()
test_loop_dep_reduce_cache_write()
@@ -506,3 +592,5 @@ if __name__ == "__main__":
test_schedule_tensor_compute3()
test_reduction_and_dummy_fuse_split()
test_schedule_compute_inline()
+ test_local_stage_predicate()
+ test_local_stage_predicate2()
\ No newline at end of file