You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/01/07 04:21:49 UTC
[tvm] branch main updated: [TIR] Fix dtype mismatch error due to LetStmt (#13710)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 088bc118c7 [TIR] Fix dtype mismatch error due to LetStmt (#13710)
088bc118c7 is described below
commit 088bc118c7a0abd263b634dc88be59813652251c
Author: masahi <ma...@gmail.com>
AuthorDate: Sat Jan 7 13:21:42 2023 +0900
[TIR] Fix dtype mismatch error due to LetStmt (#13710)
* [TIR] Fix dtype mismatch error due to LetStmt
* add comment
* improve LetStmt visitor
* remove SubstituteWithDataTypeLegalization
* consolidate vmap lookup logic in the base class
---
include/tvm/tir/data_type_rewriter.h | 5 ++-
src/tir/ir/data_type_rewriter.cc | 43 ++++++++++++++----
src/tir/transforms/narrow_datatype.cc | 11 +----
tests/python/unittest/test_te_create_primfunc.py | 55 +++++++++++++++++++++++-
4 files changed, 95 insertions(+), 19 deletions(-)
diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h
index bf90aaedfe..5f72f75ede 100644
--- a/include/tvm/tir/data_type_rewriter.h
+++ b/include/tvm/tir/data_type_rewriter.h
@@ -53,6 +53,8 @@ class DataTypeLegalizer : public StmtExprMutator {
Stmt VisitStmt_(const AttrStmtNode* op) override;
Stmt VisitStmt_(const BlockRealizeNode* op) override;
Stmt VisitStmt_(const BlockNode* op) override;
+ Stmt VisitStmt_(const LetStmtNode* op) override;
+ PrimExpr VisitExpr_(const VarNode* op) override;
PrimExpr VisitExpr_(const SelectNode* op) override;
PrimExpr VisitExpr_(const RampNode* op) override;
PrimExpr VisitExpr_(const AddNode* op) override;
@@ -79,6 +81,8 @@ class DataTypeLegalizer : public StmtExprMutator {
// a map from IterVar before rewrite to that after rewrite,
// ensures one old IterVar maps to exactly one new IterVar
std::unordered_map<const IterVarNode*, IterVar> ivmap_;
+ // a map from original vars to ones with new dtype
+ std::unordered_map<const VarNode*, Var> var_remap_;
};
/*!
@@ -123,7 +127,6 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
// indicator of condition
bool is_condition_{false};
- Map<Var, Var> var_remap_;
Map<Buffer, Buffer> buffer_remap_;
};
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index 1c61f0bf15..f0f0d84644 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -107,6 +107,35 @@ Stmt DataTypeLegalizer::VisitStmt_(const AttrStmtNode* op) {
return StmtExprMutator::VisitStmt_(op);
}
+Stmt DataTypeLegalizer::VisitStmt_(const LetStmtNode* op) {
+ PrimExpr value = this->VisitExpr(op->value);
+ auto new_var = op->var.copy_with_dtype(value.dtype());
+
+ if (value.dtype() != op->var->dtype) {
+ var_remap_[op->var.get()] = new_var;
+ }
+
+ Stmt new_body = this->VisitStmt(op->body);
+
+ if (value.same_as(op->value) && new_body.same_as(op->body)) {
+ return GetRef<Stmt>(op);
+ } else if (value.dtype() == op->var->dtype) {
+ auto n = CopyOnWrite(op);
+ n->value = std::move(value);
+ n->body = std::move(new_body);
+ return Stmt(n);
+ } else {
+ return LetStmt(new_var, value, new_body, op->span);
+ }
+}
+
+PrimExpr DataTypeLegalizer::VisitExpr_(const VarNode* op) {
+ if (auto it = var_remap_.find(op); it != var_remap_.end()) {
+ return it->second;
+ }
+ return GetRef<Var>(op);
+}
+
PrimExpr DataTypeLegalizer::VisitExpr_(const SelectNode* op) {
PrimExpr condition = this->VisitExpr(op->condition);
PrimExpr true_value = this->VisitExpr(op->true_value);
@@ -397,6 +426,9 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) {
Buffer new_buffer = GetRemappedBuffer(op->buffer);
auto value = this->VisitExpr(op->value);
+ if (new_buffer->dtype != value->dtype && value->dtype.lanes() == 1) {
+ value = cast(new_buffer->dtype, value);
+ }
auto indices = VisitIndices(op->indices);
if (!new_buffer.same_as(op->buffer) || !value.same_as(op->value) ||
@@ -535,15 +567,10 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const IntImmNode* op) {
}
PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) {
- if (auto it = var_remap_.find(GetRef<Var>(op)); it != var_remap_.end()) {
- return (*it).second;
- }
- if (is_enabled_ && op->dtype != target_data_type_) {
- Var new_var = GetRef<Var>(op).copy_with_dtype(target_data_type_);
- var_remap_.Set(GetRef<Var>(op), new_var);
- return std::move(new_var);
+ if (is_enabled_ && op->dtype != target_data_type_ && !var_remap_.count(op)) {
+ var_remap_[op] = GetRef<Var>(op).copy_with_dtype(target_data_type_);
}
- return GetRef<PrimExpr>(op);
+ return DataTypeLegalizer::VisitExpr_(op);
}
PrimExpr IndexDataTypeNormalizer::VisitExpr_(const CastNode* op) {
diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc
index 2f116a0229..e1dc2f5bf1 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -233,12 +233,8 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter {
}
PrimExpr VisitExpr_(const VarNode* op) final {
- if (auto it = var_remap_.find(GetRef<Var>(op)); it != var_remap_.end()) {
- return (*it).second;
- } else if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
- Var v = Var(op->name_hint, visitor_.vmap[op]);
- var_remap_.Set(GetRef<Var>(op), v);
- return v;
+ if (auto it = visitor_.vmap.find(op); !var_remap_.count(op) && it != visitor_.vmap.end()) {
+ var_remap_[op] = Var(op->name_hint, it->second);
}
return Parent::VisitExpr_(op);
}
@@ -266,9 +262,6 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter {
private:
// the internal visitor to deduce the narrowed dtype
DataTypeVisitor visitor_;
- // a map from Var before rewrite to that after rewrite,
- // ensures one old Var maps to exactly one new Var
- std::unordered_map<const VarNode*, Var> vmap_;
};
Stmt NarrowDataType(Stmt stmt, int target_bits) {
diff --git a/tests/python/unittest/test_te_create_primfunc.py b/tests/python/unittest/test_te_create_primfunc.py
index 271e0a339c..c13ede0831 100644
--- a/tests/python/unittest/test_te_create_primfunc.py
+++ b/tests/python/unittest/test_te_create_primfunc.py
@@ -18,7 +18,7 @@
import numpy as np
import tvm
import tvm.testing
-from tvm import te, tir, topi
+from tvm import te, tir, topi, relay
from tvm.script import tir as T
import pytest
@@ -636,5 +636,58 @@ def test_reshape():
_check_workload(te_reshape, tir_reshape, index_dtype_override="int64")
+@T.prim_func
+def argmax_expected(
+ p0: T.Buffer[(T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "uint8"],
+ p0_red: T.Buffer[(T.int64(1), T.int64(56), T.int64(56)), "int32"],
+):
+ T.func_attr({"global_symbol": "main", "tir.noalias": True})
+ p0_red_temp_v0 = T.alloc_buffer([T.int64(1), T.int64(56), T.int64(56)], dtype="int32")
+ p0_red_temp_v1 = T.alloc_buffer([T.int64(1), T.int64(56), T.int64(56)], dtype="uint8")
+ for ax0, ax1, ax2, k1 in T.grid(T.int64(1), T.int64(56), T.int64(56), T.int64(64)):
+ with T.block("p0_red_temp"):
+ v_ax0, v_ax1, v_ax2, v_k1 = T.axis.remap("SSSR", [ax0, ax1, ax2, k1])
+ T.reads(p0[v_ax0, v_k1, v_ax1, v_ax2])
+ T.writes(p0_red_temp_v0[v_ax0, v_ax1, v_ax2], p0_red_temp_v1[v_ax0, v_ax1, v_ax2])
+ with T.init():
+ p0_red_temp_v0[v_ax0, v_ax1, v_ax2] = -1
+ p0_red_temp_v1[v_ax0, v_ax1, v_ax2] = T.uint8(0)
+ v_p0_red_temp_v0: T.int64 = T.Select(
+ p0_red_temp_v1[v_ax0, v_ax1, v_ax2] > p0[v_ax0, v_k1, v_ax1, v_ax2]
+ or (
+ p0_red_temp_v1[v_ax0, v_ax1, v_ax2] == p0[v_ax0, v_k1, v_ax1, v_ax2]
+ and T.Cast("int64", p0_red_temp_v0[v_ax0, v_ax1, v_ax2]) < v_k1
+ ),
+ T.Cast("int64", p0_red_temp_v0[v_ax0, v_ax1, v_ax2]),
+ v_k1,
+ )
+ v_p0_red_temp_v1: T.uint8 = T.Select(
+ p0_red_temp_v1[v_ax0, v_ax1, v_ax2] > p0[v_ax0, v_k1, v_ax1, v_ax2],
+ p0_red_temp_v1[v_ax0, v_ax1, v_ax2],
+ p0[v_ax0, v_k1, v_ax1, v_ax2],
+ )
+ p0_red_temp_v0[v_ax0, v_ax1, v_ax2] = T.Cast("int32", v_p0_red_temp_v0)
+ p0_red_temp_v1[v_ax0, v_ax1, v_ax2] = v_p0_red_temp_v1
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(56), T.int64(56)):
+ with T.block("p0_red"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(p0_red_temp_v0[v_ax0, v_ax1, v_ax2])
+ T.writes(p0_red[v_ax0, v_ax1, v_ax2])
+ p0_red[v_ax0, v_ax1, v_ax2] = p0_red_temp_v0[v_ax0, v_ax1, v_ax2]
+
+
+def test_argmax():
+ data = relay.var("data", shape=(1, 64, 56, 56), dtype="uint8")
+ mod = tvm.IRModule.from_expr(relay.argmax(data, axis=1))
+
+ target = tvm.target.Target("llvm")
+
+ opt_mod, _ = relay.optimize(mod, params={}, target=target)
+
+ prim_func = relay.backend.te_compiler.lower_to_primfunc(opt_mod["main"].body.op, target)
+
+ tvm.ir.assert_structural_equal(prim_func, argmax_expected)
+
+
if __name__ == "__main__":
tvm.testing.main()