You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/04/06 13:30:02 UTC

[tvm] branch unity updated: [Fix] symbolic thread extent program compilation (#14516)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 8474255a04 [Fix] symbolic thread extent program compilation (#14516)
8474255a04 is described below

commit 8474255a04e241c7edb3faaafff3a9fbb7a83052
Author: Bohan Hou <32...@users.noreply.github.com>
AuthorDate: Thu Apr 6 06:29:55 2023 -0700

    [Fix] symbolic thread extent program compilation (#14516)
    
    Fix a failure when compiling a program with a symbolic thread extent, which is important for dynamic kernels.
---
 src/tir/transforms/memhammer_lower_auto_copy.cc | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/tir/transforms/memhammer_lower_auto_copy.cc b/src/tir/transforms/memhammer_lower_auto_copy.cc
index 1446dca308..67badf4521 100644
--- a/src/tir/transforms/memhammer_lower_auto_copy.cc
+++ b/src/tir/transforms/memhammer_lower_auto_copy.cc
@@ -750,7 +750,9 @@ class ThreadExtentCollector : public StmtVisitor {
   }
   void VisitStmt_(const ForNode* op) final {
     if (op->thread_binding.defined() && op->thread_binding.value()->iter_type == kThreadIndex) {
-      thread_extent_.Set(op->thread_binding.value()->thread_tag, Downcast<Integer>(op->extent));
+      if (const auto* extent = op->extent.as<IntImmNode>()) {
+        thread_extent_.Set(op->thread_binding.value()->thread_tag, GetRef<Integer>(extent));
+      }
     }
     StmtVisitor::VisitStmt_(op);
   }