You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/07/10 11:35:18 UTC

[tvm] branch main updated: [TIR] Call TVMBackendFreeWorkspace inside LetStmt (#15253)

This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 30d684216c [TIR] Call TVMBackendFreeWorkspace inside LetStmt (#15253)
30d684216c is described below

commit 30d684216c8cf0692ff5f7294050ad965541c7ad
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon Jul 10 07:35:12 2023 -0400

    [TIR] Call TVMBackendFreeWorkspace inside LetStmt (#15253)
    
    * [TIR] Call TVMBackendFreeWorkspace inside LetStmt
    
    Prior to this commit, the call to `TVMBackendFreeWorkspace` occurred
    outside the `LetStmt` that defined the workspace pointer.  While this
    works with current codegen, as the code produced for `LetStmt` does
    not check for out-of-scope access, this access of an out-of-scope
    variable should be avoided.
    
    This commit updates `LowerTVMBuiltin` to produce the call to
    `TVMBackendFreeWorkspace` at the end of the `LetStmt`'s body, rather
    than just after the `LetStmt`.
    
    * [TIR] Output AttrStmt "storage_alignment" inside the var binding
    
    Prior to this commit, the `AttrStmt` providing the storage alignment
    was placed outside the `LetStmt` that defines the variable. As a
    result, the alignment assumption is never actually used, as
    `CodeGenLLVM::VisitStmt_(const AttrStmtNode*)` only creates an
    alignment assumption for in-scope variables.
    
    This commit moves the storage alignment `AttrStmt` to be inside the
    `LetStmt`, rather than outside.
---
 src/tir/transforms/lower_tvm_builtin.cc            | 54 +++++++++++++---------
 .../test_tir_transform_lower_tvm_builtin.py        | 23 ++++-----
 2 files changed, 41 insertions(+), 36 deletions(-)

diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index df7a885985..2868af0b07 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -249,24 +249,36 @@ class BuiltinLower : public StmtExprMutator {
     ICHECK(device_id_) << "Unknown device id in current IR";
     Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}));
 
-    Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}),
-                                    throw_last_error),
-                         op->body});
-    Stmt alloca = LetStmt(op->buffer_var,
-                          Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
-                               {cast(DataType::Int(32), device_type_.value()),
-                                cast(DataType::Int(32), device_id_.value()), total_bytes,
-                                IntImm(DataType::Int(32), op->dtype.code()),
-                                IntImm(DataType::Int(32), op->dtype.bits())}),
-                          body);
-
+    Stmt alloc_nullptr_check = IfThenElse(
+        Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}), throw_last_error);
     PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"),
                             {cast(DataType::Int(32), device_type_.value()),
                              cast(DataType::Int(32), device_id_.value()), op->buffer_var});
     Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
-    body = SeqStmt({alloca, free_stmt});
+
+    Stmt body = op->body;
+    std::vector<Stmt> nest;
+    while (auto opt = body.as<DeclBuffer>()) {
+      auto decl = opt.value();
+      body = decl->body;
+      decl.CopyOnWrite()->body = Evaluate(0);
+      nest.push_back(decl);
+    }
+
+    body = SeqStmt::Flatten(body, free_stmt);
+    body = MergeNest(nest, body);
+    body = SeqStmt::Flatten(alloc_nullptr_check, body);
+
     body = AttrStmt(op->buffer_var, attr::storage_alignment,
                     make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body);
+    body = LetStmt(op->buffer_var,
+                   Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
+                        {cast(DataType::Int(32), device_type_.value()),
+                         cast(DataType::Int(32), device_id_.value()), total_bytes,
+                         IntImm(DataType::Int(32), op->dtype.code()),
+                         IntImm(DataType::Int(32), op->dtype.bits())}),
+                   body);
+
     return body;
   }
 
@@ -569,9 +581,15 @@ class BuiltinLower : public StmtExprMutator {
     ICHECK(device_id_) << "Unknown device id in current IR";
     Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}));
 
+    PrimExpr storage_scope = call->args[0];
+    Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(),
+                        {GetDeviceMethodName("free_nd"), device_type_.value(), device_id_.value(),
+                         storage_scope, let->var});
+    Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
+
     Stmt body = SeqStmt(
         {IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {let->var}), throw_last_error),
-         let->body});
+         let->body, free_stmt});
 
     DataType dtype =
         let->var->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>()->dtype;
@@ -593,15 +611,7 @@ class BuiltinLower : public StmtExprMutator {
 
     Call call_packed = Call(let->var.dtype(), builtin::tvm_call_packed(), args);
     Stmt alloca = LetStmt(let->var, call_packed, body);
-
-    PrimExpr storage_scope = call->args[0];
-    Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(),
-                        {GetDeviceMethodName("free_nd"), device_type_.value(), device_id_.value(),
-                         storage_scope, let->var});
-
-    Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
-    body = SeqStmt({alloca, free_stmt});
-    return body;
+    return alloca;
   }
 
  private:
diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
index cf2e3f045b..21db36d1f9 100644
--- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
+++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
@@ -14,9 +14,13 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 import tvm
+import tvm.testing
+
 from tvm import te
 from tvm.script import tir as T
+
 import numpy as np
 
 
@@ -202,14 +206,6 @@ class TestLowerDeviceAllocate(tvm.testing.CompareBeforeAfter):
     This test validates the current behavior of LowerTVMBuiltin.  This
     unit test may be improved in the future by addressing:
 
-    - The AttrStmt for "storage_alignment" occurs outside the LetStmt
-      that defines the pointer, which is currently required by
-      CodeGenLLVM.  This fails to match when `map_free_vars=False`
-      (default), because the first occurrence is undefined.
-
-    - The call to TVMBackendFreeWorkspace uses the allocated pointer,
-      but occurs outside the LetStmt.
-
     - TVMScript always produces "handle" dtype for
       `T.tvm_throw_last_error`, while LowerTVMBuiltin outputs "int32"
       dtype.
@@ -227,13 +223,12 @@ class TestLowerDeviceAllocate(tvm.testing.CompareBeforeAfter):
 
     def expected():
         T.func_attr({"target": T.target("llvm")})
-        ptr = T.handle("float32", "global")
+        ptr: T.handle("float32") = T.TVMBackendAllocWorkspace(2, 0, T.uint64(64), 2, 32)
         T.attr(ptr, "storage_alignment", 64)
-        with T.LetStmt(T.TVMBackendAllocWorkspace(2, 0, T.uint64(64), 2, 32), var=ptr):
-            if T.isnullptr(ptr):
-                T.Call("int32", "tir.tvm_throw_last_error", [])
-            buf = T.decl_buffer((16,), data=ptr)
-            buf[0] = T.float32(0)
+        if T.isnullptr(ptr):
+            T.Call("int32", "tir.tvm_throw_last_error", [])
+        buf = T.decl_buffer((16,), data=ptr)
+        buf[0] = T.float32(0)
         if T.TVMBackendFreeWorkspace(2, 0, ptr) != 0:
             T.Call("int32", "tir.tvm_throw_last_error", [])