You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/05/11 05:02:55 UTC
[tvm] branch main updated: [Vulkan][Codegen] Spir-V codegen,
correct labels/blocks in WhileNode. (#8013)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7b49744 [Vulkan][Codegen] Spir-V codegen, correct labels/blocks in WhileNode. (#8013)
7b49744 is described below
commit 7b497442ec21ded9a5fc40dab233588485170ca9
Author: Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon May 10 22:02:25 2021 -0700
[Vulkan][Codegen] Spir-V codegen, correct labels/blocks in WhileNode. (#8013)
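Previously, the WhileNode codegen assumed that evaluating the loop
condition would not introduce any additional labels. If this assumption
was violated, such as for a WhileNode whose condition contains an
if/else expression, then the OpLoopMerge instruction was emitted in the
wrong block. SPIR-V requires OpLoopMerge to be the second-to-last
instruction of the loop header block. The loop header now contains only
OpLoopMerge and a branch to a new block, and the condition is evaluated
in that new block.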
The added unit test exercises this code path, but doesn't yet trigger a
failure. Once spvValidate is enabled for all Vulkan codegen, this unit
test will catch the failure mode.
Co-authored-by: Eric Lunderberg <el...@octoml.ai>
---
src/target/spirv/codegen_spirv.cc | 9 +++-
.../python/unittest/test_target_codegen_vulkan.py | 50 ++++++++++++++++++++++
2 files changed, 58 insertions(+), 1 deletion(-)
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 8188744..0c6deb2 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -549,6 +549,7 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
void CodeGenSPIRV::VisitStmt_(const WhileNode* op) {
spirv::Label head_label = builder_->NewLabel();
+ spirv::Label condition_label = builder_->NewLabel();
spirv::Label body_label = builder_->NewLabel();
spirv::Label continue_label = builder_->NewLabel();
spirv::Label merge_label = builder_->NewLabel();
@@ -556,9 +557,15 @@ void CodeGenSPIRV::VisitStmt_(const WhileNode* op) {
// Loop head
builder_->StartLabel(head_label);
- spirv::Value loop_cond = MakeValue(op->condition);
uint32_t control = spv::LoopControlMaskNone;
builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control);
+ builder_->MakeInst(spv::OpBranch, condition_label);
+
+ // Loop condition evaluation. The condition could contain if/else
+ // blocks that introduce additional labels, so the condition cannot
+ // be in the loop head's block.
+ builder_->StartLabel(condition_label);
+ spirv::Value loop_cond = MakeValue(op->condition);
builder_->MakeInst(spv::OpBranchConditional, loop_cond, body_label, merge_label,
weight_likely_branch_, 1);
diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py
index 9528741..56181db 100644
--- a/tests/python/unittest/test_target_codegen_vulkan.py
+++ b/tests/python/unittest/test_target_codegen_vulkan.py
@@ -307,6 +307,56 @@ def test_vulkan_constant_passing():
test_scalar_params(2044)
+@tvm.testing.parametrize_targets("vulkan")
+def test_vulkan_while_if(target, dev):
+ def do_compute(A, B, n):
+ ib = tvm.tir.ir_builder.create()
+ A = ib.buffer_ptr(A)
+ B = ib.buffer_ptr(B)
+
+ ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0)
+
+ iterations = ib.allocate("int32", (1,), name="iterations", scope="local")
+ iterations[0] = 0
+ B[0] = 0
+
+ # WhileNode's condition is re-evaluated every loop. The
+ # if_then_else block introduces additional labels/blocks that
+ # must be kept separate from the WhileNode's block.
+ loop_condition = iterations[0] < tvm.tir.if_then_else(A[0] > 0, 10, 20)
+ with ib.while_loop(loop_condition):
+ iterations[0] += 1
+ B[0] += iterations[0]
+
+ return ib.get()
+
+ n = 1
+ dtype = "int32"
+ A = te.placeholder((n,), name="A", dtype=dtype)
+
+ B = te.extern(
+ A.shape,
+ [A],
+ lambda ins, outs: do_compute(ins[0], outs[0], n),
+ dtype=dtype,
+ )
+ s = te.create_schedule(B.op)
+
+ # Point of failure would be here, at tvm.build.
+ with tvm.transform.PassContext(opt_level=3):
+ func = tvm.build(s, [A, B], target)
+
+ a = tvm.nd.array(np.array([5], dtype=A.dtype), dev)
+ b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev)
+ func(a, b)
+ tvm.testing.assert_allclose(b.asnumpy(), [55])
+
+ a = tvm.nd.array(np.array([-5], dtype=A.dtype), dev)
+ b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev)
+ func(a, b)
+ tvm.testing.assert_allclose(b.asnumpy(), [210])
+
+
if __name__ == "__main__":
test_vector_comparison()
test_vulkan_copy()