You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sy...@apache.org on 2023/07/10 11:35:18 UTC
[tvm] branch main updated: [TIR] Call TVMBackendFreeWorkspace inside LetStmt (#15253)
This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 30d684216c [TIR] Call TVMBackendFreeWorkspace inside LetStmt (#15253)
30d684216c is described below
commit 30d684216c8cf0692ff5f7294050ad965541c7ad
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon Jul 10 07:35:12 2023 -0400
[TIR] Call TVMBackendFreeWorkspace inside LetStmt (#15253)
* [TIR] Call TVMBackendFreeWorkspace inside LetStmt
Prior to this commit, the call to `TVMBackendFreeWorkspace` occurred
outside the `LetStmt` that defined the workspace pointer. While this works
with current codegen, as the code produced for `LetStmt` does not
check for out-of-scope access, this access of an out-of-scope variable
should be avoided.
This commit updates `LowerTVMBuiltin` to produce the call to
`TVMBackendFreeWorkspace` at the end of the `LetStmt`'s body, rather
than just after the `LetStmt`.
* [TIR] Output AttrStmt "storage_alignment" inside the var binding
Prior to this commit, the `AttrStmt` providing the storage alignment
was placed outside the `LetStmt` that defines the variable. As a
result, the alignment assumption is never actually used, as
`CodeGenLLVM::VisitStmt_(const AttrStmtNode*)` only creates an
alignment assumption for in-scope variables.
This commit moves the storage alignment `AttrStmt` to be inside the
`LetStmt`, rather than outside.
---
src/tir/transforms/lower_tvm_builtin.cc | 54 +++++++++++++---------
.../test_tir_transform_lower_tvm_builtin.py | 23 ++++-----
2 files changed, 41 insertions(+), 36 deletions(-)
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index df7a885985..2868af0b07 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -249,24 +249,36 @@ class BuiltinLower : public StmtExprMutator {
ICHECK(device_id_) << "Unknown device id in current IR";
Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}));
- Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}),
- throw_last_error),
- op->body});
- Stmt alloca = LetStmt(op->buffer_var,
- Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
- {cast(DataType::Int(32), device_type_.value()),
- cast(DataType::Int(32), device_id_.value()), total_bytes,
- IntImm(DataType::Int(32), op->dtype.code()),
- IntImm(DataType::Int(32), op->dtype.bits())}),
- body);
-
+ Stmt alloc_nullptr_check = IfThenElse(
+ Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}), throw_last_error);
PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"),
{cast(DataType::Int(32), device_type_.value()),
cast(DataType::Int(32), device_id_.value()), op->buffer_var});
Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
- body = SeqStmt({alloca, free_stmt});
+
+ Stmt body = op->body;
+ std::vector<Stmt> nest;
+ while (auto opt = body.as<DeclBuffer>()) {
+ auto decl = opt.value();
+ body = decl->body;
+ decl.CopyOnWrite()->body = Evaluate(0);
+ nest.push_back(decl);
+ }
+
+ body = SeqStmt::Flatten(body, free_stmt);
+ body = MergeNest(nest, body);
+ body = SeqStmt::Flatten(alloc_nullptr_check, body);
+
body = AttrStmt(op->buffer_var, attr::storage_alignment,
make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body);
+ body = LetStmt(op->buffer_var,
+ Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
+ {cast(DataType::Int(32), device_type_.value()),
+ cast(DataType::Int(32), device_id_.value()), total_bytes,
+ IntImm(DataType::Int(32), op->dtype.code()),
+ IntImm(DataType::Int(32), op->dtype.bits())}),
+ body);
+
return body;
}
@@ -569,9 +581,15 @@ class BuiltinLower : public StmtExprMutator {
ICHECK(device_id_) << "Unknown device id in current IR";
Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}));
+ PrimExpr storage_scope = call->args[0];
+ Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(),
+ {GetDeviceMethodName("free_nd"), device_type_.value(), device_id_.value(),
+ storage_scope, let->var});
+ Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
+
Stmt body = SeqStmt(
{IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {let->var}), throw_last_error),
- let->body});
+ let->body, free_stmt});
DataType dtype =
let->var->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>()->dtype;
@@ -593,15 +611,7 @@ class BuiltinLower : public StmtExprMutator {
Call call_packed = Call(let->var.dtype(), builtin::tvm_call_packed(), args);
Stmt alloca = LetStmt(let->var, call_packed, body);
-
- PrimExpr storage_scope = call->args[0];
- Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(),
- {GetDeviceMethodName("free_nd"), device_type_.value(), device_id_.value(),
- storage_scope, let->var});
-
- Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
- body = SeqStmt({alloca, free_stmt});
- return body;
+ return alloca;
}
private:
diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
index cf2e3f045b..21db36d1f9 100644
--- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
+++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
@@ -14,9 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
import tvm
+import tvm.testing
+
from tvm import te
from tvm.script import tir as T
+
import numpy as np
@@ -202,14 +206,6 @@ class TestLowerDeviceAllocate(tvm.testing.CompareBeforeAfter):
This test validates the current behavior of LowerTVMBuiltin. This
unit test may be improved in the future by addressing:
- - The AttrStmt for "storage_alignment" occurs outside the LetStmt
- that defines the pointer, which is currently required by
- CodeGenLLVM. This fails to match when `map_free_vars=False`
- (default), because the first occurrence is undefined.
-
- - The call to TVMBackendFreeWorkspace uses the allocated pointer,
- but occurs outside the LetStmt.
-
- TVMScript always produces "handle" dtype for
`T.tvm_throw_last_error`, while LowerTVMBuiltin outputs "int32"
dtype.
@@ -227,13 +223,12 @@ class TestLowerDeviceAllocate(tvm.testing.CompareBeforeAfter):
def expected():
T.func_attr({"target": T.target("llvm")})
- ptr = T.handle("float32", "global")
+ ptr: T.handle("float32") = T.TVMBackendAllocWorkspace(2, 0, T.uint64(64), 2, 32)
T.attr(ptr, "storage_alignment", 64)
- with T.LetStmt(T.TVMBackendAllocWorkspace(2, 0, T.uint64(64), 2, 32), var=ptr):
- if T.isnullptr(ptr):
- T.Call("int32", "tir.tvm_throw_last_error", [])
- buf = T.decl_buffer((16,), data=ptr)
- buf[0] = T.float32(0)
+ if T.isnullptr(ptr):
+ T.Call("int32", "tir.tvm_throw_last_error", [])
+ buf = T.decl_buffer((16,), data=ptr)
+ buf[0] = T.float32(0)
if T.TVMBackendFreeWorkspace(2, 0, ptr) != 0:
T.Call("int32", "tir.tvm_throw_last_error", [])