You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/01/07 04:21:49 UTC

[tvm] branch main updated: [TIR] Fix dtype mismatch error due to LetStmt (#13710)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 088bc118c7 [TIR] Fix dtype mismatch error due to LetStmt (#13710)
088bc118c7 is described below

commit 088bc118c7a0abd263b634dc88be59813652251c
Author: masahi <ma...@gmail.com>
AuthorDate: Sat Jan 7 13:21:42 2023 +0900

    [TIR] Fix dtype mismatch error due to LetStmt (#13710)
    
    * [TIR] Fix dtype mismatch error due to LetStmt
    
    * add comment
    
    * improve LetStmt visitor
    
    * remove SubstituteWithDataTypeLegalization
    
    * consolidate var_remap_ lookup logic in the DataTypeLegalizer base class
---
 include/tvm/tir/data_type_rewriter.h             |  5 ++-
 src/tir/ir/data_type_rewriter.cc                 | 43 ++++++++++++++----
 src/tir/transforms/narrow_datatype.cc            | 11 +----
 tests/python/unittest/test_te_create_primfunc.py | 55 +++++++++++++++++++++++-
 4 files changed, 95 insertions(+), 19 deletions(-)

diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h
index bf90aaedfe..5f72f75ede 100644
--- a/include/tvm/tir/data_type_rewriter.h
+++ b/include/tvm/tir/data_type_rewriter.h
@@ -53,6 +53,8 @@ class DataTypeLegalizer : public StmtExprMutator {
   Stmt VisitStmt_(const AttrStmtNode* op) override;
   Stmt VisitStmt_(const BlockRealizeNode* op) override;
   Stmt VisitStmt_(const BlockNode* op) override;
+  Stmt VisitStmt_(const LetStmtNode* op) override;
+  PrimExpr VisitExpr_(const VarNode* op) override;
   PrimExpr VisitExpr_(const SelectNode* op) override;
   PrimExpr VisitExpr_(const RampNode* op) override;
   PrimExpr VisitExpr_(const AddNode* op) override;
@@ -79,6 +81,8 @@ class DataTypeLegalizer : public StmtExprMutator {
   // a map from IterVar before rewrite to that after rewrite,
   // ensures one old IterVar maps to exactly one new IterVar
   std::unordered_map<const IterVarNode*, IterVar> ivmap_;
+  // a map from original vars to ones with new dtype
+  std::unordered_map<const VarNode*, Var> var_remap_;
 };
 
 /*!
@@ -123,7 +127,6 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
   // indicator of condition
   bool is_condition_{false};
 
-  Map<Var, Var> var_remap_;
   Map<Buffer, Buffer> buffer_remap_;
 };
 
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index 1c61f0bf15..f0f0d84644 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -107,6 +107,35 @@ Stmt DataTypeLegalizer::VisitStmt_(const AttrStmtNode* op) {
   return StmtExprMutator::VisitStmt_(op);
 }
 
+Stmt DataTypeLegalizer::VisitStmt_(const LetStmtNode* op) {
+  PrimExpr value = this->VisitExpr(op->value);
+  auto new_var = op->var.copy_with_dtype(value.dtype());
+
+  if (value.dtype() != op->var->dtype) {
+    var_remap_[op->var.get()] = new_var;
+  }
+
+  Stmt new_body = this->VisitStmt(op->body);
+
+  if (value.same_as(op->value) && new_body.same_as(op->body)) {
+    return GetRef<Stmt>(op);
+  } else if (value.dtype() == op->var->dtype) {
+    auto n = CopyOnWrite(op);
+    n->value = std::move(value);
+    n->body = std::move(new_body);
+    return Stmt(n);
+  } else {
+    return LetStmt(new_var, value, new_body, op->span);
+  }
+}
+
+PrimExpr DataTypeLegalizer::VisitExpr_(const VarNode* op) {
+  if (auto it = var_remap_.find(op); it != var_remap_.end()) {
+    return it->second;
+  }
+  return GetRef<Var>(op);
+}
+
 PrimExpr DataTypeLegalizer::VisitExpr_(const SelectNode* op) {
   PrimExpr condition = this->VisitExpr(op->condition);
   PrimExpr true_value = this->VisitExpr(op->true_value);
@@ -397,6 +426,9 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) {
 
   Buffer new_buffer = GetRemappedBuffer(op->buffer);
   auto value = this->VisitExpr(op->value);
+  if (new_buffer->dtype != value->dtype && value->dtype.lanes() == 1) {
+    value = cast(new_buffer->dtype, value);
+  }
   auto indices = VisitIndices(op->indices);
 
   if (!new_buffer.same_as(op->buffer) || !value.same_as(op->value) ||
@@ -535,15 +567,10 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const IntImmNode* op) {
 }
 
 PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) {
-  if (auto it = var_remap_.find(GetRef<Var>(op)); it != var_remap_.end()) {
-    return (*it).second;
-  }
-  if (is_enabled_ && op->dtype != target_data_type_) {
-    Var new_var = GetRef<Var>(op).copy_with_dtype(target_data_type_);
-    var_remap_.Set(GetRef<Var>(op), new_var);
-    return std::move(new_var);
+  if (is_enabled_ && op->dtype != target_data_type_ && !var_remap_.count(op)) {
+    var_remap_[op] = GetRef<Var>(op).copy_with_dtype(target_data_type_);
   }
-  return GetRef<PrimExpr>(op);
+  return DataTypeLegalizer::VisitExpr_(op);
 }
 
 PrimExpr IndexDataTypeNormalizer::VisitExpr_(const CastNode* op) {
diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc
index 2f116a0229..e1dc2f5bf1 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -233,12 +233,8 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter {
   }
 
   PrimExpr VisitExpr_(const VarNode* op) final {
-    if (auto it = var_remap_.find(GetRef<Var>(op)); it != var_remap_.end()) {
-      return (*it).second;
-    } else if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
-      Var v = Var(op->name_hint, visitor_.vmap[op]);
-      var_remap_.Set(GetRef<Var>(op), v);
-      return v;
+    if (auto it = visitor_.vmap.find(op); !var_remap_.count(op) && it != visitor_.vmap.end()) {
+      var_remap_[op] = Var(op->name_hint, it->second);
     }
     return Parent::VisitExpr_(op);
   }
@@ -266,9 +262,6 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter {
  private:
   // the internal visitor to deduce the narrowed dtype
   DataTypeVisitor visitor_;
-  // a map from Var before rewrite to that after rewrite,
-  // ensures one old Var maps to exactly one new Var
-  std::unordered_map<const VarNode*, Var> vmap_;
 };
 
 Stmt NarrowDataType(Stmt stmt, int target_bits) {
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index 271e0a339c..c13ede0831 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -18,7 +18,7 @@
 import numpy as np
 import tvm
 import tvm.testing
-from tvm import te, tir, topi
+from tvm import te, tir, topi, relay
 from tvm.script import tir as T
 import pytest
 
@@ -636,5 +636,58 @@ def test_reshape():
     _check_workload(te_reshape, tir_reshape, index_dtype_override="int64")
 
 
+@T.prim_func
+def argmax_expected(
+    p0: T.Buffer[(T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "uint8"],
+    p0_red: T.Buffer[(T.int64(1), T.int64(56), T.int64(56)), "int32"],
+):
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    p0_red_temp_v0 = T.alloc_buffer([T.int64(1), T.int64(56), T.int64(56)], dtype="int32")
+    p0_red_temp_v1 = T.alloc_buffer([T.int64(1), T.int64(56), T.int64(56)], dtype="uint8")
+    for ax0, ax1, ax2, k1 in T.grid(T.int64(1), T.int64(56), T.int64(56), T.int64(64)):
+        with T.block("p0_red_temp"):
+            v_ax0, v_ax1, v_ax2, v_k1 = T.axis.remap("SSSR", [ax0, ax1, ax2, k1])
+            T.reads(p0[v_ax0, v_k1, v_ax1, v_ax2])
+            T.writes(p0_red_temp_v0[v_ax0, v_ax1, v_ax2], p0_red_temp_v1[v_ax0, v_ax1, v_ax2])
+            with T.init():
+                p0_red_temp_v0[v_ax0, v_ax1, v_ax2] = -1
+                p0_red_temp_v1[v_ax0, v_ax1, v_ax2] = T.uint8(0)
+            v_p0_red_temp_v0: T.int64 = T.Select(
+                p0_red_temp_v1[v_ax0, v_ax1, v_ax2] > p0[v_ax0, v_k1, v_ax1, v_ax2]
+                or (
+                    p0_red_temp_v1[v_ax0, v_ax1, v_ax2] == p0[v_ax0, v_k1, v_ax1, v_ax2]
+                    and T.Cast("int64", p0_red_temp_v0[v_ax0, v_ax1, v_ax2]) < v_k1
+                ),
+                T.Cast("int64", p0_red_temp_v0[v_ax0, v_ax1, v_ax2]),
+                v_k1,
+            )
+            v_p0_red_temp_v1: T.uint8 = T.Select(
+                p0_red_temp_v1[v_ax0, v_ax1, v_ax2] > p0[v_ax0, v_k1, v_ax1, v_ax2],
+                p0_red_temp_v1[v_ax0, v_ax1, v_ax2],
+                p0[v_ax0, v_k1, v_ax1, v_ax2],
+            )
+            p0_red_temp_v0[v_ax0, v_ax1, v_ax2] = T.Cast("int32", v_p0_red_temp_v0)
+            p0_red_temp_v1[v_ax0, v_ax1, v_ax2] = v_p0_red_temp_v1
+    for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(56), T.int64(56)):
+        with T.block("p0_red"):
+            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+            T.reads(p0_red_temp_v0[v_ax0, v_ax1, v_ax2])
+            T.writes(p0_red[v_ax0, v_ax1, v_ax2])
+            p0_red[v_ax0, v_ax1, v_ax2] = p0_red_temp_v0[v_ax0, v_ax1, v_ax2]
+
+
+def test_argmax():
+    data = relay.var("data", shape=(1, 64, 56, 56), dtype="uint8")
+    mod = tvm.IRModule.from_expr(relay.argmax(data, axis=1))
+
+    target = tvm.target.Target("llvm")
+
+    opt_mod, _ = relay.optimize(mod, params={}, target=target)
+
+    prim_func = relay.backend.te_compiler.lower_to_primfunc(opt_mod["main"].body.op, target)
+
+    tvm.ir.assert_structural_equal(prim_func, argmax_expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()