You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/04/13 11:44:42 UTC

[tvm] 02/02: [Unity] Fix ForceNarrowI32 with pod arguments (#14605)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 4ba82881f5fef9cd0a3efc005e09b6f6d6e6d470
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Thu Apr 13 03:11:21 2023 -0400

    [Unity] Fix ForceNarrowI32 with pod arguments (#14605)
    
    This PR fixes ForceNarrowI32 for functions that take
    POD (scalar, non-buffer) arguments.
    
    Test cases are added.
    
    Co-authored-by: Junru Shao <ju...@apache.org>
---
 include/tvm/tir/data_type_rewriter.h               |  9 ++--
 src/tir/ir/data_type_rewriter.cc                   | 63 +++++++++++++++++++---
 src/tir/transforms/force_narrow_index_to_i32.cc    | 10 ++++
 ...st_transform_legalize_ops_search_statistical.py | 11 ++--
 .../relax/test_transform_legalize_ops_unary.py     |  2 +-
 ...test_tir_transform_force_narrow_index_to_i32.py | 21 ++++++++
 6 files changed, 99 insertions(+), 17 deletions(-)

diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h
index 5f72f75ede..76b53b18ca 100644
--- a/include/tvm/tir/data_type_rewriter.h
+++ b/include/tvm/tir/data_type_rewriter.h
@@ -74,6 +74,7 @@ class DataTypeLegalizer : public StmtExprMutator {
   PrimExpr VisitExpr_(const GENode* op) override;
   PrimExpr VisitExpr_(const CallNode* op) override;
   PrimExpr VisitExpr_(const CastNode* op) override;
+  PrimExpr VisitExpr_(const LetNode* op) override;  // rebinds the let var if its value's dtype changes
 
   using StmtExprMutator::VisitExpr_;
   using StmtExprMutator::VisitStmt_;
@@ -115,6 +116,8 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
   PrimExpr VisitExpr_(const GTNode* op) override;
   PrimExpr VisitExpr_(const GENode* op) override;
   PrimExpr VisitExpr_(const CallNode* op) override;
+  PrimExpr VisitExpr_(const SelectNode* op) override;  // visits the condition in condition mode, unifies branch dtypes
+
   Stmt VisitStmt_(const ForNode* op) override;
 
   Buffer VisitBuffer(const Buffer& buffer);
@@ -146,9 +149,9 @@ class IndexDataTypeNormalizer : public IndexDataTypeRewriter {
   using Parent = IndexDataTypeRewriter;
   using Parent::VisitExpr_;
   using Parent::VisitStmt_;
-  PrimExpr VisitExpr_(const IntImmNode* op) final;
-  PrimExpr VisitExpr_(const VarNode* op) final;
-  PrimExpr VisitExpr_(const CastNode* op) final;
+  PrimExpr VisitExpr_(const IntImmNode* op) override;  // no longer final: Int32DTypeNarrower overrides it
+  PrimExpr VisitExpr_(const VarNode* op) override;  // no longer final, so subclasses may customize
+  PrimExpr VisitExpr_(const CastNode* op) override;  // no longer final, so subclasses may customize
 
   DataType target_data_type_ = DataType::Int(64);
 };
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index 8da7cfdd5b..97ad1b7cc3 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -107,25 +107,39 @@ Stmt DataTypeLegalizer::VisitStmt_(const AttrStmtNode* op) {
   return StmtExprMutator::VisitStmt_(op);
 }

+PrimExpr DataTypeLegalizer::VisitExpr_(const LetNode* op) {  // expression-level analogue of the LetStmt visitor
+  PrimExpr value = this->VisitExpr(op->value);
+  Var var = op->var;
+
+  if (value.dtype() != op->var->dtype) {  // value was retyped: the bound var must follow it
+    var = op->var.copy_with_dtype(value.dtype());
+    var_remap_[op->var.get()] = var;  // record the remap before visiting the body
+  }
+
+  PrimExpr new_body = this->VisitExpr(op->body);
+
+  if (value.same_as(op->value) && new_body.same_as(op->body)) {
+    return GetRef<PrimExpr>(op);
+  } else {
+    return Let(var, value, new_body, op->span);
+  }
+}
+
 Stmt DataTypeLegalizer::VisitStmt_(const LetStmtNode* op) {
   PrimExpr value = this->VisitExpr(op->value);
-  auto new_var = op->var.copy_with_dtype(value.dtype());
+  Var var = op->var;  // only replaced when the value's dtype changes
 
   if (value.dtype() != op->var->dtype) {
-    var_remap_[op->var.get()] = new_var;
+    var = op->var.copy_with_dtype(value.dtype());
+    var_remap_[op->var.get()] = var;  // record the remap before visiting the body
   }
 
   Stmt new_body = this->VisitStmt(op->body);
 
   if (value.same_as(op->value) && new_body.same_as(op->body)) {
     return GetRef<Stmt>(op);
-  } else if (value.dtype() == op->var->dtype) {
-    auto n = CopyOnWrite(op);
-    n->value = std::move(value);
-    n->body = std::move(new_body);
-    return Stmt(n);
   } else {
-    return LetStmt(new_var, value, new_body, op->span);
+    return LetStmt(var, value, new_body, op->span);
   }
 }
 
@@ -542,6 +556,26 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode* op) {
   return Parent::VisitExpr_(op);
 }
 
+PrimExpr IndexDataTypeRewriter::VisitExpr_(const SelectNode* op) {  // condition visited in condition mode; branches unified
+  bool is_condition = true;  // temporarily set is_condition_ while visiting op->condition
+  std::swap(is_condition_, is_condition);
+  PrimExpr condition = this->VisitExpr(op->condition);
+  std::swap(is_condition_, is_condition);  // restore the caller's is_condition_
+  PrimExpr true_value = this->VisitExpr(op->true_value);
+  PrimExpr false_value = this->VisitExpr(op->false_value);
+
+  if (condition.same_as(op->condition) && true_value.same_as(op->true_value) &&
+      false_value.same_as(op->false_value) && true_value.dtype() == false_value.dtype()) {
+    return GetRef<PrimExpr>(op);
+  } else {
+    int bits = std::max(true_value.dtype().bits(), false_value.dtype().bits());  // widen to the larger branch width
+    DataType dtype = true_value.dtype().with_bits(bits);
+    if (true_value.dtype() != dtype) true_value = cast(dtype, true_value);
+    if (false_value.dtype() != dtype) false_value = cast(dtype, false_value);
+    return Select(condition, true_value, false_value, op->span);  // keep the source span, as Let/LetStmt do
+  }
+}
+
 #undef TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH
 
 IndexDataTypeNormalizer::IndexDataTypeNormalizer(DataType target_data_type)
@@ -552,7 +586,20 @@ PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) {
   for (const auto& [var, buffer] : func->buffer_map) {
     new_buffer_map.Set(var, VisitBuffer(buffer));
   }
+  // remap params
+  bool is_enabled = true;  // force is_enabled_ on while visiting integer params; restored below
+  std::swap(is_enabled_, is_enabled);
+  Array<Var> params = func->params.Map([this](Var param) {
+    if (param.dtype().is_int()) {  // integer (POD) params are rewritten
+      return Downcast<Var>(this->VisitExpr(param));
+    } else {  // non-integer params (e.g. handles) are kept as-is
+      return param;
+    }
+  });
+  std::swap(is_enabled_, is_enabled);
+
   PrimFuncNode* new_func = func.CopyOnWrite();
+  new_func->params = std::move(params);
   new_func->buffer_map = std::move(new_buffer_map);
   new_func->body = VisitStmt(std::move(new_func->body));
   return func;
diff --git a/src/tir/transforms/force_narrow_index_to_i32.cc b/src/tir/transforms/force_narrow_index_to_i32.cc
index 70dc554e12..c559360bf5 100644
--- a/src/tir/transforms/force_narrow_index_to_i32.cc
+++ b/src/tir/transforms/force_narrow_index_to_i32.cc
@@ -24,6 +24,7 @@
  */
 
 #include <tvm/tir/data_type_rewriter.h>
+#include <tvm/tir/op.h>  // min_value / max_value
 #include <tvm/tir/transform.h>
 
 namespace tvm {
@@ -48,6 +49,15 @@ class Int32DTypeNarrower : public IndexDataTypeNormalizer {
   explicit Int32DTypeNarrower(PrimFunc func)
       : IndexDataTypeNormalizer(DataType::Int(32)), func_(std::move(func)) {}
 
+  PrimExpr VisitExpr_(const IntImmNode* op) final {  // ignore the enabled condition and always rewrite i64
+    if (op->dtype == DataType::Int(64)) {
+      ICHECK_GE(op->value, Downcast<Integer>(min_value(target_data_type_))->value);  // both i32 bounds must hold
+      ICHECK_LE(op->value, Downcast<Integer>(max_value(target_data_type_))->value);
+      return IntImm(DataType::Int(32), op->value, op->span);  // keep the literal's source span
+    }
+    return GetRef<IntImm>(op);
+  }
+
   Stmt VisitStmt_(const BlockNode* block) final {
     Block block_ = Downcast<Block>(IndexDataTypeNormalizer::VisitStmt_(block));
     // Check if the allocated integer buffers have dtype other than int32.
diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index 886e880c31..b612bbbae5 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -16,10 +16,11 @@
 # under the License.
 
 import tvm
-from tvm.relax.transform import LegalizeOps
-from tvm.script import relax as R, tir as T, ir as I
 import tvm.testing
-
+from tvm.relax.transform import LegalizeOps
+from tvm.script import ir as I  # one import per line
+from tvm.script import relax as R
+from tvm.script import tir as T
 
 ##################### Search #####################
 
@@ -48,7 +49,7 @@ def test_where():
                     ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
                     T.reads(rxplaceholder[ax0, ax1, T.int64(0)], rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)])
                     T.writes(T_where[ax0, ax1, ax2])
-                    T_where[ax0, ax1, ax2] = T.Select(0 < T.Cast("int32", rxplaceholder[ax0, ax1, T.int64(0)]), rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)])
+                    T_where[ax0, ax1, ax2] = T.Select(T.int64(0) < T.Cast("int64", rxplaceholder[ax0, ax1, T.int64(0)]), rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)])  # condition now compared in int64 (was int32)
     # fmt: on
 
     mod = LegalizeOps()(Where)
@@ -92,7 +93,7 @@ def test_where_symbolic():
                     ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
                     T.reads(rxplaceholder[ax0, ax1, T.int64(0)], rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)])
                     T.writes(T_where[ax0, ax1, ax2])
-                    T_where[ax0, ax1, ax2] = T.Select(0 < T.Cast("int32", rxplaceholder[ax0, ax1, T.int64(0)]), rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)])
+                    T_where[ax0, ax1, ax2] = T.Select(T.int64(0) < T.Cast("int64", rxplaceholder[ax0, ax1, T.int64(0)]), rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)])  # condition now compared in int64 (was int32)
     # fmt: on
 
     mod = LegalizeOps()(Where)
diff --git a/tests/python/relax/test_transform_legalize_ops_unary.py b/tests/python/relax/test_transform_legalize_ops_unary.py
index 1a5d474c3e..398fffdbdb 100644
--- a/tests/python/relax/test_transform_legalize_ops_unary.py
+++ b/tests/python/relax/test_transform_legalize_ops_unary.py
@@ -866,7 +866,7 @@ def test_sign_int():
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                     T.reads(rxplaceholder[v_ax0, v_ax1])
                     T.writes(T_sign[v_ax0, v_ax1])
-                    T_sign[v_ax0, v_ax1] = T.Select(0 < rxplaceholder[v_ax0, v_ax1], 1, T.Select(rxplaceholder[v_ax0, v_ax1] < 0, -1, 0))
+                    T_sign[v_ax0, v_ax1] = T.Select(T.int64(0) < T.Cast("int64", rxplaceholder[v_ax0, v_ax1]), 1, T.Select(T.Cast("int64", rxplaceholder[v_ax0, v_ax1]) < T.int64(0), -1, 0))  # Select conditions now compared in int64
 
         @R.function
         def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), dtype="int32"):
diff --git a/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py b/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py
index f275d438a7..6c12229b5c 100644
--- a/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py
+++ b/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py
@@ -216,5 +216,26 @@ def test_fail_on_buffer_map():
         tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"]
 
 
+def test_pod_params_and_select():  # an int64 POD param `n` and a Select over i64 indices must both narrow to int32
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def main(
+            A: T.Buffer((T.int64(4),), "float32"), B: T.Buffer((T.int64(4),), "float32"), n: T.int64  # n is a scalar (POD) param
+        ):
+            for i in T.serial(T.int64(4)):
+                B[i] = T.Select(T.int64(1) <= i, A[i + n], T.Cast("float32", i))
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32"), n: T.int32):  # n narrowed to int32
+            for i in range(4):
+                B[i] = T.Select(1 <= i, A[i + n], T.Cast("float32", i))
+
+    after = tvm.tir.transform.ForceNarrowIndexToInt32()(Before)
+    tvm.ir.assert_structural_equal(Expected, after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()