You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2023/04/13 11:44:42 UTC
[tvm] 02/02: [Unity] Fix ForceNarrowI32 with pod arguments (#14605)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 4ba82881f5fef9cd0a3efc005e09b6f6d6e6d470
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Thu Apr 13 03:11:21 2023 -0400
[Unity] Fix ForceNarrowI32 with pod arguments (#14605)
This PR fixes ForceNarrowI32 for functions that come
with POD arguments.
Test cases are added.
Co-authored-by: Junru Shao <ju...@apache.org>
---
include/tvm/tir/data_type_rewriter.h | 9 ++--
src/tir/ir/data_type_rewriter.cc | 63 +++++++++++++++++++---
src/tir/transforms/force_narrow_index_to_i32.cc | 10 ++++
...st_transform_legalize_ops_search_statistical.py | 11 ++--
.../relax/test_transform_legalize_ops_unary.py | 2 +-
...test_tir_transform_force_narrow_index_to_i32.py | 21 ++++++++
6 files changed, 99 insertions(+), 17 deletions(-)
diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h
index 5f72f75ede..76b53b18ca 100644
--- a/include/tvm/tir/data_type_rewriter.h
+++ b/include/tvm/tir/data_type_rewriter.h
@@ -74,6 +74,7 @@ class DataTypeLegalizer : public StmtExprMutator {
PrimExpr VisitExpr_(const GENode* op) override;
PrimExpr VisitExpr_(const CallNode* op) override;
PrimExpr VisitExpr_(const CastNode* op) override;
+ PrimExpr VisitExpr_(const LetNode* op) override;
using StmtExprMutator::VisitExpr_;
using StmtExprMutator::VisitStmt_;
@@ -115,6 +116,8 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
PrimExpr VisitExpr_(const GTNode* op) override;
PrimExpr VisitExpr_(const GENode* op) override;
PrimExpr VisitExpr_(const CallNode* op) override;
+ PrimExpr VisitExpr_(const SelectNode* op) override;
+
Stmt VisitStmt_(const ForNode* op) override;
Buffer VisitBuffer(const Buffer& buffer);
@@ -146,9 +149,9 @@ class IndexDataTypeNormalizer : public IndexDataTypeRewriter {
using Parent = IndexDataTypeRewriter;
using Parent::VisitExpr_;
using Parent::VisitStmt_;
- PrimExpr VisitExpr_(const IntImmNode* op) final;
- PrimExpr VisitExpr_(const VarNode* op) final;
- PrimExpr VisitExpr_(const CastNode* op) final;
+ PrimExpr VisitExpr_(const IntImmNode* op) override;
+ PrimExpr VisitExpr_(const VarNode* op) override;
+ PrimExpr VisitExpr_(const CastNode* op) override;
DataType target_data_type_ = DataType::Int(64);
};
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index 8da7cfdd5b..97ad1b7cc3 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -107,25 +107,39 @@ Stmt DataTypeLegalizer::VisitStmt_(const AttrStmtNode* op) {
return StmtExprMutator::VisitStmt_(op);
}
+PrimExpr DataTypeLegalizer::VisitExpr_(const LetNode* op) {
+ PrimExpr value = this->VisitExpr(op->value);
+ Var var = op->var;
+
+ if (value.dtype() != op->var->dtype) {
+ var = op->var.copy_with_dtype(value.dtype());
+ var_remap_[op->var.get()] = var;
+ }
+
+ PrimExpr new_body = this->VisitExpr(op->body);
+
+ if (value.same_as(op->value) && new_body.same_as(op->body)) {
+ return GetRef<PrimExpr>(op);
+ } else {
+ return Let(var, value, new_body, op->span);
+ }
+}
+
Stmt DataTypeLegalizer::VisitStmt_(const LetStmtNode* op) {
PrimExpr value = this->VisitExpr(op->value);
- auto new_var = op->var.copy_with_dtype(value.dtype());
+ Var var = op->var;
if (value.dtype() != op->var->dtype) {
- var_remap_[op->var.get()] = new_var;
+ var = op->var.copy_with_dtype(value.dtype());
+ var_remap_[op->var.get()] = var;
}
Stmt new_body = this->VisitStmt(op->body);
if (value.same_as(op->value) && new_body.same_as(op->body)) {
return GetRef<Stmt>(op);
- } else if (value.dtype() == op->var->dtype) {
- auto n = CopyOnWrite(op);
- n->value = std::move(value);
- n->body = std::move(new_body);
- return Stmt(n);
} else {
- return LetStmt(new_var, value, new_body, op->span);
+ return LetStmt(var, value, new_body, op->span);
}
}
@@ -542,6 +556,26 @@ PrimExpr IndexDataTypeRewriter::VisitExpr_(const CallNode* op) {
return Parent::VisitExpr_(op);
}
+PrimExpr IndexDataTypeRewriter::VisitExpr_(const SelectNode* op) {
+ bool is_condition = true;
+ std::swap(is_condition_, is_condition);
+ PrimExpr condition = this->VisitExpr(op->condition);
+ std::swap(is_condition_, is_condition);
+ PrimExpr true_value = this->VisitExpr(op->true_value);
+ PrimExpr false_value = this->VisitExpr(op->false_value);
+
+ if (condition.same_as(op->condition) && true_value.same_as(op->true_value) &&
+ false_value.same_as(op->false_value) && true_value.dtype() == false_value.dtype()) {
+ return GetRef<PrimExpr>(op);
+ } else {
+ int bits = std::max(true_value.dtype().bits(), false_value.dtype().bits());
+ DataType dtype = true_value.dtype().with_bits(bits);
+ if (true_value.dtype() != dtype) true_value = cast(dtype, true_value);
+ if (false_value.dtype() != dtype) false_value = cast(dtype, false_value);
+ return Select(condition, true_value, false_value);
+ }
+}
+
#undef TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH
IndexDataTypeNormalizer::IndexDataTypeNormalizer(DataType target_data_type)
@@ -552,7 +586,20 @@ PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) {
for (const auto& [var, buffer] : func->buffer_map) {
new_buffer_map.Set(var, VisitBuffer(buffer));
}
+ // remap params
+ bool is_enabled = true;
+ std::swap(is_enabled_, is_enabled);
+ Array<Var> params = func->params.Map([this](Var param) {
+ if (param.dtype().is_int()) {
+ return Downcast<Var>(this->VisitExpr(param));
+ } else {
+ return param;
+ }
+ });
+ std::swap(is_enabled_, is_enabled);
+
PrimFuncNode* new_func = func.CopyOnWrite();
+ new_func->params = std::move(params);
new_func->buffer_map = std::move(new_buffer_map);
new_func->body = VisitStmt(std::move(new_func->body));
return func;
diff --git a/src/tir/transforms/force_narrow_index_to_i32.cc b/src/tir/transforms/force_narrow_index_to_i32.cc
index 70dc554e12..c559360bf5 100644
--- a/src/tir/transforms/force_narrow_index_to_i32.cc
+++ b/src/tir/transforms/force_narrow_index_to_i32.cc
@@ -24,6 +24,7 @@
*/
#include <tvm/tir/data_type_rewriter.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
namespace tvm {
@@ -48,6 +49,15 @@ class Int32DTypeNarrower : public IndexDataTypeNormalizer {
explicit Int32DTypeNarrower(PrimFunc func)
: IndexDataTypeNormalizer(DataType::Int(32)), func_(std::move(func)) {}
+ PrimExpr VisitExpr_(const IntImmNode* op) final {
+ // ignore the enabled condition and always rewrite i64
+ if (op->dtype == DataType::Int(64)) {
+ ICHECK_LE(op->value, Downcast<Integer>(max_value(target_data_type_))->value);
+ return IntImm(DataType::Int(32), op->value);
+ }
+ return GetRef<IntImm>(op);
+ }
+
Stmt VisitStmt_(const BlockNode* block) final {
Block block_ = Downcast<Block>(IndexDataTypeNormalizer::VisitStmt_(block));
// Check if the allocated integer buffers have dtype other than int32.
diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index 886e880c31..b612bbbae5 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -16,10 +16,11 @@
# under the License.
import tvm
-from tvm.relax.transform import LegalizeOps
-from tvm.script import relax as R, tir as T, ir as I
import tvm.testing
-
+from tvm.relax.transform import LegalizeOps
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
##################### Search #####################
@@ -48,7 +49,7 @@ def test_where():
ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
T.reads(rxplaceholder[ax0, ax1, T.int64(0)], rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)])
T.writes(T_where[ax0, ax1, ax2])
- T_where[ax0, ax1, ax2] = T.Select(0 < T.Cast("int32", rxplaceholder[ax0, ax1, T.int64(0)]), rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)])
+ T_where[ax0, ax1, ax2] = T.Select(T.int64(0) < T.Cast("int64", rxplaceholder[ax0, ax1, T.int64(0)]), rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)])
# fmt: on
mod = LegalizeOps()(Where)
@@ -92,7 +93,7 @@ def test_where_symbolic():
ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
T.reads(rxplaceholder[ax0, ax1, T.int64(0)], rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)])
T.writes(T_where[ax0, ax1, ax2])
- T_where[ax0, ax1, ax2] = T.Select(0 < T.Cast("int32", rxplaceholder[ax0, ax1, T.int64(0)]), rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)])
+ T_where[ax0, ax1, ax2] = T.Select(T.int64(0) < T.Cast("int64", rxplaceholder[ax0, ax1, T.int64(0)]), rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, T.int64(0)])
# fmt: on
mod = LegalizeOps()(Where)
diff --git a/tests/python/relax/test_transform_legalize_ops_unary.py b/tests/python/relax/test_transform_legalize_ops_unary.py
index 1a5d474c3e..398fffdbdb 100644
--- a/tests/python/relax/test_transform_legalize_ops_unary.py
+++ b/tests/python/relax/test_transform_legalize_ops_unary.py
@@ -866,7 +866,7 @@ def test_sign_int():
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(rxplaceholder[v_ax0, v_ax1])
T.writes(T_sign[v_ax0, v_ax1])
- T_sign[v_ax0, v_ax1] = T.Select(0 < rxplaceholder[v_ax0, v_ax1], 1, T.Select(rxplaceholder[v_ax0, v_ax1] < 0, -1, 0))
+ T_sign[v_ax0, v_ax1] = T.Select(T.int64(0) < T.Cast("int64", rxplaceholder[v_ax0, v_ax1]), 1, T.Select(T.Cast("int64", rxplaceholder[v_ax0, v_ax1]) < T.int64(0), -1, 0))
@R.function
def main(x: R.Tensor((2, 3), dtype="int32")) -> R.Tensor((2, 3), dtype="int32"):
diff --git a/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py b/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py
index f275d438a7..6c12229b5c 100644
--- a/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py
+++ b/tests/python/unittest/test_tir_transform_force_narrow_index_to_i32.py
@@ -216,5 +216,26 @@ def test_fail_on_buffer_map():
tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"]
+def test_pod_params_and_select():
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def main(
+ A: T.Buffer((T.int64(4),), "float32"), B: T.Buffer((T.int64(4),), "float32"), n: T.int64
+ ):
+ for i in T.serial(T.int64(4)):
+ B[i] = T.Select(T.int64(1) <= i, A[i + n], T.Cast("float32", i))
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32"), n: T.int32):
+ for i in range(4):
+ B[i] = T.Select(1 <= i, A[i + n], T.Cast("float32", i))
+
+ after = tvm.tir.transform.ForceNarrowIndexToInt32()(Before)
+ tvm.ir.assert_structural_equal(Expected, after)
+
+
if __name__ == "__main__":
tvm.testing.main()