You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2024/03/29 17:58:29 UTC
(tvm) branch main updated: [Codegen] Add check to disable invalid reinterpret (#16786)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 109804cc6a [Codegen] Add check to disable invalid reinterpret (#16786)
109804cc6a is described below
commit 109804cc6a8854953f761aa5575b02e33e8dbd9c
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Fri Mar 29 10:58:23 2024 -0700
[Codegen] Add check to disable invalid reinterpret (#16786)
* [Codegen] Add check to disable invalid reinterpret
---
src/target/source/codegen_c.cc | 9 +++++++--
tests/python/codegen/test_target_codegen_cuda.py | 10 ++++++++++
2 files changed, 17 insertions(+), 2 deletions(-)
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index abb62f2faf..009fc1672a 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -672,10 +672,15 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
this->PrintExpr(op->args[0], os);
os << " == NULL)";
} else if (op->op.same_as(builtin::reinterpret())) {
+ auto target_dtype = op->dtype;
+ auto source_dtype = op->args[0]->dtype;
+ CHECK_EQ(target_dtype.lanes() * target_dtype.bits(),
+ source_dtype.lanes() * source_dtype.bits())
+ << "reinterpret expects source and target to have the same number of bits";
int ssa_scope = BeginScope();
- std::string rhs = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype);
+ std::string rhs = SSAGetID(PrintExpr(op->args[0]), source_dtype);
os << "(*(";
- this->PrintType(op->dtype, os);
+ this->PrintType(target_dtype, os);
os << " *)(&(" << rhs << ")))";
EndScope(ssa_scope);
} else if (op->op.same_as(builtin::isnan())) {
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index 5fb7526b21..23ba0fc3ce 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -1116,5 +1116,15 @@ def test_cuda_thread_sync_inside_condition():
tvm.build(mod, target="cuda")
+def test_invalid_reinterpret():
+ @T.prim_func
+ def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None:
+ for tx in T.thread_binding(4, "threadIdx.x"):
+ B[tx] = T.reinterpret("uint8", A[tx])
+
+ with pytest.raises(tvm.error.TVMError):
+ tvm.build(func, target="cuda")
+
+
if __name__ == "__main__":
tvm.testing.main()