You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/05/22 01:57:58 UTC

[tvm] branch main updated: [TIR] Expand unit tests for ConvertSSA (#14892)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5a094bce74 [TIR] Expand unit tests for ConvertSSA (#14892)
5a094bce74 is described below

commit 5a094bce746283967f35d11edd22e4803cdefe73
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Sun May 21 20:57:48 2023 -0500

    [TIR] Expand unit tests for ConvertSSA (#14892)
    
    * [TIR] Expand unit tests for ConvertSSA
    
    Prior to this PR, there was a single test which invoked ConvertSSA and
    checked only that no error was thrown.  This PR adds test cases for
    nested variable definitions, and for variables that must be
    de-duplicated across separate functions within an `IRModule`.
    
    Of the new tests, `TestDedupAutoBroadcastBuffer` and
    `TestReusedBufferParameter` fail on main; the behavior they test is
    fixed by this PR.
    
    * Update tests to avoid relying on TVMScript output of non-SSA
    
    The "before" cases must contain SSA violations, which are not valid
    TIR.  As such, future versions of the TVMScript parser may remove the
    SSA violations in the process of parsing.  The updated unit tests
    therefore introduce the SSA violations through the Python API, to
    avoid this potential breakage.
---
 python/tvm/tir/transform/transform.py              |  18 ++
 src/tir/transforms/ir_utils.cc                     |  22 +-
 .../unittest/test_tir_transform_convert_ssa.py     | 253 +++++++++++++++++++++
 .../python/unittest/test_tir_transform_ir_utils.py |  40 ----
 4 files changed, 290 insertions(+), 43 deletions(-)

diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index f3aae306be..f2ce437814 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -341,6 +341,24 @@ def Simplify():
     return _ffi_api.Simplify()  # type: ignore
 
 
+def ConvertSSA():
+    """Convert an IRModule to be SSA form.
+
+    This pass handles cases where the same `tir.Var` appears in
+    multiple functions within the same module.  For example, after
+    extracting a fragment from one function into another, where the
+    same `tir.Var` may be defined both as within the body of the
+    original function, and as a parameter within the hoisted function.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+
+    """
+    return _ffi_api.ConvertSSA()  # type: ignore
+
+
 def InstrumentBoundCheckers():
     """Instruments bound checkers.
 
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index b3829529ee..9b47d84e6a 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -146,6 +146,11 @@ class IRConvertSSA final : public StmtExprMutator {
       bool made_change = false;
       for (const auto& [var, buffer] : func->buffer_map) {
         auto new_var = GetRemappedVar(var);
+        if (defined_.count(buffer->data.get())) {
+          redefines.emplace_back(this, buffer->data);
+        } else {
+          defined_.insert(buffer->data.get());
+        }
         auto new_buf = GetRemappedBuffer(buffer);
 
         made_change = made_change || !var.same_as(new_var) || !buffer.same_as(new_buf);
@@ -159,6 +164,10 @@ class IRConvertSSA final : public StmtExprMutator {
     }();
 
     auto attrs = [&]() -> DictAttrs {
+      if (!func->attrs.defined()) {
+        return DictAttrs();
+      }
+
       Map<String, ObjectRef> dict;
       bool made_change = false;
 
@@ -278,9 +287,14 @@ class IRConvertSSA final : public StmtExprMutator {
     // new buffer, pushing it onto the scoped stack of existing
     // buffers.  This will be popped when the new_buffer_var
     // redefinition is popped.
-    Buffer new_buf(new_buffer_var, buf->dtype, shape, strides, elem_offset, buf->name,
-                   buf->data_alignment, buf->offset_factor, buf->buffer_type, buf->axis_separators,
-                   buf->span);
+    Buffer new_buf = buf;
+    {
+      auto write_ptr = new_buf.CopyOnWrite();
+      write_ptr->data = new_buffer_var;
+      write_ptr->shape = shape;
+      write_ptr->strides = strides;
+      write_ptr->elem_offset = elem_offset;
+    }
     buffers.push_back(new_buf);
     return new_buf;
   }
@@ -702,6 +716,8 @@ Pass ConvertSSA() {
   return tvm::transform::CreateModulePass(pass_func, 0, "tir.ConvertSSA", {});
 }
 
+TVM_REGISTER_GLOBAL("tir.transform.ConvertSSA").set_body_typed(ConvertSSA);
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_tir_transform_convert_ssa.py b/tests/python/unittest/test_tir_transform_convert_ssa.py
new file mode 100644
index 0000000000..918fe6b907
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_convert_ssa.py
@@ -0,0 +1,253 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T, ir as I
+
+
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+    transform = tvm.tir.transform.ConvertSSA()
+
+
+class TestReuseInSequentialLetStmt(BaseBeforeAfter):
+    """De-dup sequential variable bindings"""
+
+    def before(self):
+        # Manually construct the PrimFunc body, as SSA violations are
+        # not valid TIR, and may not be expressible in future versions
+        # of TVMSCript.
+        var = tir.Var("var", "int32")
+        sequential_bindings = tir.SeqStmt(
+            [
+                tir.LetStmt(var, 16, tir.Evaluate(var)),
+                tir.LetStmt(var, 32, tir.Evaluate(var)),
+            ]
+        )
+        func = tir.PrimFunc([], sequential_bindings)
+
+        return func
+
+    def expected(self):
+        @T.prim_func
+        def func():
+            with T.LetStmt(T.int32(16)) as var1:
+                T.evaluate(var1)
+            with T.LetStmt(T.int32(32)) as var2:
+                T.evaluate(var2)
+
+        return func
+
+
+class TestReuseInNestedLetStmt(BaseBeforeAfter):
+    """De-dup nested bindings
+
+    Use of a variable with nested bindings is de-duplicated to refer
+    to the inner-most binding that contains the use site.
+    """
+
+    def before(self):
+        # Manually construct the PrimFunc body, as SSA violations are
+        # not valid TIR, and may not be expressible in future versions
+        # of TVMSCript.
+        var = tir.Var("var", "int32")
+        inner_let = tir.LetStmt(var, 16, tir.Evaluate(var))
+        outer_let = tir.LetStmt(
+            var,
+            32,
+            tir.SeqStmt(
+                [
+                    tir.Evaluate(var),
+                    inner_let,
+                    tir.Evaluate(var),
+                ]
+            ),
+        )
+        func = tir.PrimFunc([], outer_let)
+
+        return func
+
+    def expected(self):
+        @T.prim_func
+        def func():
+            with T.LetStmt(T.int32(32)) as outer:
+                T.evaluate(outer)
+                with T.LetStmt(T.int32(16)) as inner:
+                    T.evaluate(inner)
+                T.evaluate(outer)
+
+        return func
+
+
+class TestReusedVarAcrossModule(BaseBeforeAfter):
+    """De-duplicate Var bindings across entire module"""
+
+    def before(self):
+        @T.prim_func
+        def func():
+            with T.LetStmt(10) as var:
+                T.evaluate(var)
+
+        return tvm.IRModule({"func_a": func, "func_b": func})
+
+    def expected(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def func_a():
+                var = T.int32(10)
+                T.evaluate(var)
+
+            @T.prim_func
+            def func_b():
+                var = T.int32(10)
+                T.evaluate(var)
+
+        return mod
+
+
+class TestReusedParameter(BaseBeforeAfter):
+    """De-duplicate Var usage in parameters
+
+    In this test, the same `tir.Var` instance is used for the
+    parameter `n` in both functions.
+    """
+
+    def before(self):
+        @T.prim_func
+        def func(n: T.int32):
+            T.evaluate(n)
+
+        return tvm.IRModule({"func_a": func, "func_b": func})
+
+    def expected(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def func_a(n: T.int32):
+                T.evaluate(n)
+
+            @T.prim_func
+            def func_b(n: T.int32):
+                T.evaluate(n)
+
+        return mod
+
+
+class TestReusedBufferObj(BaseBeforeAfter):
+    """De-duplicate buffer usage across entire module"""
+
+    def before(self):
+        @T.prim_func
+        def func(a: T.handle("float32")):
+            A = T.Buffer(shape=1, dtype="float32", data=a)
+            T.evaluate(A[0])
+
+        return tvm.IRModule({"func_a": func, "func_b": func})
+
+    def expected(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def func_a(a: T.handle("float32")):
+                A = T.Buffer(shape=1, dtype="float32", data=a)
+                T.evaluate(A[0])
+
+            @T.prim_func
+            def func_b(a: T.handle("float32")):
+                A = T.Buffer(shape=1, dtype="float32", data=a)
+                T.evaluate(A[0])
+
+        return mod
+
+
+class TestReusedBufferParameter(BaseBeforeAfter):
+    """De-duplicate buffer_map across entire module"""
+
+    def before(self):
+        @T.prim_func
+        def func(A: T.Buffer(1, "float32")):
+            T.evaluate(A[0])
+
+        return tvm.IRModule({"func_a": func, "func_b": func})
+
+    def expected(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def func_a(A: T.Buffer(1, "float32")):
+                T.evaluate(A[0])
+
+            @T.prim_func
+            def func_b(A: T.Buffer(1, "float32")):
+                T.evaluate(A[0])
+
+        return mod
+
+
+def test_no_change_if_already_ssa():
+    """A module that is already SSA should be unchanged"""
+
+    @I.ir_module
+    class before:
+        @T.prim_func
+        def func(A: T.Buffer(1, "float32")):
+            T.evaluate(A[0])
+
+    after = tvm.tir.transform.ConvertSSA()(before)
+    tvm.ir.assert_structural_equal(before, after)
+    assert before.same_as(after)
+
+
+class TestDedupAutoBroadcastBuffer(BaseBeforeAfter):
+    """De-dup auto-broadcast buffers
+
+    Auto-broadcast buffers can define additional variables during the
+    `Buffer::Buffer` constructor for the strides.  This is intended to
+    be used for match buffers, where these variables are defined based
+    on the argument being passed in.
+
+    These additional variables can cause errors when copying a buffer
+    with the `Buffer::Buffer` constructor.  If a buffer has non-empty
+    shape, empty strides, and kAutoBroadcast type, then the resulting
+    buffer will have additional strides defined.  Such a buffer can
+    result from lowering of a scalar buffer, which will be flattened
+    to a shape of [1].
+
+    Previous implementations of ConvertSSA incorrectly handled this
+    case, resulting in undefined stride variables.
+    """
+
+    def _make_func(self):
+        @T.prim_func
+        def func(a: T.handle):
+            A = T.match_buffer(a, shape=(), dtype="float32", buffer_type="auto")
+            A[()] = 1.0
+
+        return tvm.lower(func)["main"]
+
+    def before(self):
+        func = self._make_func()
+        return tvm.IRModule({"func_a": func, "func_b": func})
+
+    def expected(self):
+        return tvm.IRModule({"func_a": self._make_func(), "func_b": self._make_func()})
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_transform_ir_utils.py b/tests/python/unittest/test_tir_transform_ir_utils.py
deleted file mode 100644
index d2cae35161..0000000000
--- a/tests/python/unittest/test_tir_transform_ir_utils.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import pytest
-import tvm
-from tvm import tir, ir
-
-
-def test_convert_ssa():
-    dtype = "int32"
-    zero = tir.const(0)
-    nop = tir.Evaluate(zero)
-    var_type = ir.PointerType(ir.PrimType(dtype))
-    v = tir.Var("i1", var_type)
-    buf = tir.decl_buffer([16], dtype=dtype, data=v)
-    let = tir.LetStmt(v, v, nop)
-    load = tir.Evaluate(tir.BufferLoad(buf, [zero]))
-    seq = tir.SeqStmt([let, let, load])
-    func = tir.PrimFunc([], seq)
-    mod = tvm.IRModule({"main": func})
-    mod = tir.transform.InjectVirtualThread()(
-        mod
-    )  # Use pass InjectVirtualThread to invoke ConvertSSA
-
-
-if __name__ == "__main__":
-    tvm.testing.main()