You are viewing a plain-text version of this content. (The canonical link to the original message was not preserved in this plain-text copy.)
Posted to commits@tvm.apache.org by tq...@apache.org on 2024/03/29 17:58:29 UTC

(tvm) branch main updated: [Codegen] Add check to disable invalid reinterpret (#16786)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 109804cc6a [Codegen] Add check to disable invalid reinterpret (#16786)
109804cc6a is described below

commit 109804cc6a8854953f761aa5575b02e33e8dbd9c
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Fri Mar 29 10:58:23 2024 -0700

    [Codegen] Add check to disable invalid reinterpret (#16786)
    
    * [Codegen] Add check to disable invalid reinterpret
---
 src/target/source/codegen_c.cc                   |  9 +++++++--
 tests/python/codegen/test_target_codegen_cuda.py | 10 ++++++++++
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index abb62f2faf..009fc1672a 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -672,10 +672,15 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
       this->PrintExpr(op->args[0], os);
       os << " == NULL)";
     } else if (op->op.same_as(builtin::reinterpret())) {
+      auto target_dtype = op->dtype;
+      auto source_dtype = op->args[0]->dtype;
+      CHECK_EQ(target_dtype.lanes() * target_dtype.bits(),
+               source_dtype.lanes() * source_dtype.bits())
+          << "reinterpret expects source and target to have the same number of bits";
       int ssa_scope = BeginScope();
-      std::string rhs = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype);
+      std::string rhs = SSAGetID(PrintExpr(op->args[0]), source_dtype);
       os << "(*(";
-      this->PrintType(op->dtype, os);
+      this->PrintType(target_dtype, os);
       os << " *)(&(" << rhs << ")))";
       EndScope(ssa_scope);
     } else if (op->op.same_as(builtin::isnan())) {
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index 5fb7526b21..23ba0fc3ce 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -1116,5 +1116,15 @@ def test_cuda_thread_sync_inside_condition():
     tvm.build(mod, target="cuda")
 
 
+def test_invalid_reinterpret():
+    @T.prim_func
+    def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None:
+        for tx in T.thread_binding(4, "threadIdx.x"):
+            B[tx] = T.reinterpret("uint8", A[tx])
+
+    with pytest.raises(tvm.error.TVMError):
+        tvm.build(func, target="cuda")
+
+
 if __name__ == "__main__":
     tvm.testing.main()