You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/05/11 05:02:55 UTC

[tvm] branch main updated: [Vulkan][Codegen] Spir-V codegen, correct labels/blocks in WhileNode. (#8013)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7b49744  [Vulkan][Codegen] Spir-V codegen, correct labels/blocks in WhileNode. (#8013)
7b49744 is described below

commit 7b497442ec21ded9a5fc40dab233588485170ca9
Author: Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon May 10 22:02:25 2021 -0700

    [Vulkan][Codegen] Spir-V codegen, correct labels/blocks in WhileNode. (#8013)
    
    Previously, the WhileNode codegen assumed that evaluating the loop
    condition would not introduce any additional labels.  If this
    assumption was violated, such as for a WhileNode whose condition
    contains an if/else, then the OpLoopMerge instruction appeared in
    the wrong block.
    
    The added unit test exercises this code path, but doesn't yet
    trigger a failure.  Once spvValidate is enabled for all Vulkan
    codegen, this unit test will catch the failure mode.
    
    Co-authored-by: Eric Lunderberg <el...@octoml.ai>
---
 src/target/spirv/codegen_spirv.cc                  |  9 +++-
 .../python/unittest/test_target_codegen_vulkan.py  | 50 ++++++++++++++++++++++
 2 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 8188744..0c6deb2 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -549,6 +549,7 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
 
 void CodeGenSPIRV::VisitStmt_(const WhileNode* op) {
   spirv::Label head_label = builder_->NewLabel();
+  spirv::Label condition_label = builder_->NewLabel();
   spirv::Label body_label = builder_->NewLabel();
   spirv::Label continue_label = builder_->NewLabel();
   spirv::Label merge_label = builder_->NewLabel();
@@ -556,9 +557,15 @@ void CodeGenSPIRV::VisitStmt_(const WhileNode* op) {
 
   // Loop head
   builder_->StartLabel(head_label);
-  spirv::Value loop_cond = MakeValue(op->condition);
   uint32_t control = spv::LoopControlMaskNone;
   builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control);
+  builder_->MakeInst(spv::OpBranch, condition_label);
+
+  // Loop condition evaluation.  The condition could contain if/else
+  // blocks that introduce additional labels, so the condition cannot
+  // be in the loop head's block.
+  builder_->StartLabel(condition_label);
+  spirv::Value loop_cond = MakeValue(op->condition);
   builder_->MakeInst(spv::OpBranchConditional, loop_cond, body_label, merge_label,
                      weight_likely_branch_, 1);
 
diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py
index 9528741..56181db 100644
--- a/tests/python/unittest/test_target_codegen_vulkan.py
+++ b/tests/python/unittest/test_target_codegen_vulkan.py
@@ -307,6 +307,56 @@ def test_vulkan_constant_passing():
     test_scalar_params(2044)
 
 
+@tvm.testing.parametrize_targets("vulkan")
+def test_vulkan_while_if(target, dev):
+    def do_compute(A, B, n):
+        ib = tvm.tir.ir_builder.create()
+        A = ib.buffer_ptr(A)
+        B = ib.buffer_ptr(B)
+
+        ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0)
+
+        iterations = ib.allocate("int32", (1,), name="iterations", scope="local")
+        iterations[0] = 0
+        B[0] = 0
+
+        # WhileNode's condition is re-evaluated every loop.  The
+        # if_then_else block introduces additional labels/blocks that
+        # must be kept separate from the WhileNode's block.
+        loop_condition = iterations[0] < tvm.tir.if_then_else(A[0] > 0, 10, 20)
+        with ib.while_loop(loop_condition):
+            iterations[0] += 1
+            B[0] += iterations[0]
+
+        return ib.get()
+
+    n = 1
+    dtype = "int32"
+    A = te.placeholder((n,), name="A", dtype=dtype)
+
+    B = te.extern(
+        A.shape,
+        [A],
+        lambda ins, outs: do_compute(ins[0], outs[0], n),
+        dtype=dtype,
+    )
+    s = te.create_schedule(B.op)
+
+    # Point of failure would be here, at tvm.build.
+    with tvm.transform.PassContext(opt_level=3):
+        func = tvm.build(s, [A, B], target)
+
+    a = tvm.nd.array(np.array([5], dtype=A.dtype), dev)
+    b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev)
+    func(a, b)
+    tvm.testing.assert_allclose(b.asnumpy(), [55])
+
+    a = tvm.nd.array(np.array([-5], dtype=A.dtype), dev)
+    b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev)
+    func(a, b)
+    tvm.testing.assert_allclose(b.asnumpy(), [210])
+
+
 if __name__ == "__main__":
     test_vector_comparison()
     test_vulkan_copy()