You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2022/05/02 17:53:57 UTC
[tvm] branch main updated: [TIR] Reduced duplication in op.h (#11129)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 169f824d69 [TIR] Reduced duplication in op.h (#11129)
169f824d69 is described below
commit 169f824d69442ac9c19f320e2ddb9baec32af68e
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon May 2 12:53:51 2022 -0500
[TIR] Reduced duplication in op.h (#11129)
* [TIR] Reduced duplication in op.h
Previously, `is_positive_const`, `is_negative_const`, `is_const_int`, and
`as_const_int` had nearly duplicate type-checking logic. This allowed
the handling of Broadcast nodes to diverge between the
implementations (e.g. `is_const_int(Broadcast(4,1), 4)` returned
true, but `is_positive_const(Broadcast(4,1))` returned false).
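This change makes `as_const_int` the single place that holds the
type-checking logic, with the other three functions implemented in
terms of `as_const_int`. (The first version of this change also moved
the handling of Broadcast nodes into `as_const_int`; the follow-up
commit below removes it again.)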
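* Test case: remove BroadcastNode handling from as_const_int
Rather than extending Broadcast handling to apply in more cases, check
whether it is safe to extract this functionality out into a separate
function. As a result, `is_const_int(x)` and `is_const_int(x, value)`
no longer return true when `x` is a Broadcast of an integer constant.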
---
include/tvm/tir/op.h | 41 +++++++++--------------------------------
1 file changed, 9 insertions(+), 32 deletions(-)
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 9c3ea135c6..5b63016d2f 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -977,9 +977,9 @@ inline const int64_t* as_const_int(const PrimExpr& x) {
if (!x.defined()) return nullptr;
if (const tir::IntImmNode* op = x.as<tir::IntImmNode>()) {
return &(op->value);
- } else {
- return nullptr;
}
+
+ return nullptr;
}
/*!
@@ -1051,17 +1051,7 @@ inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr
TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);
// Implementation details after this
-inline bool is_const_int(const PrimExpr& x) {
- if (x.as<tir::IntImmNode>()) {
- return true;
- } else if (const auto* op = x.as<tir::BroadcastNode>()) {
- const PrimExpr& val = op->value;
- if (val.as<tir::IntImmNode>()) {
- return true;
- }
- }
- return false;
-}
+inline bool is_const_int(const PrimExpr& x) { return as_const_int(x); }
inline bool is_const_number(const PrimExpr& x) {
if (x.as<tir::IntImmNode>()) {
@@ -1075,31 +1065,18 @@ inline bool is_const_number(const PrimExpr& x) {
}
inline bool is_positive_const(const PrimExpr& a) {
- if (const tir::IntImmNode* op = a.as<tir::IntImmNode>()) {
- return op->value > 0;
- } else {
- return false;
- }
+ const int64_t* as_int = as_const_int(a);
+ return as_int && (*as_int > 0);
}
inline bool is_negative_const(const PrimExpr& a) {
- if (const tir::IntImmNode* op = a.as<tir::IntImmNode>()) {
- return op->value < 0;
- } else {
- return false;
- }
+ const int64_t* as_int = as_const_int(a);
+ return as_int && (*as_int < 0);
}
inline bool is_const_int(const PrimExpr& x, int64_t value) {
- if (const auto* op = x.as<tir::IntImmNode>()) {
- return op->value == value;
- } else if (const auto* op = x.as<tir::BroadcastNode>()) {
- const PrimExpr& val = op->value;
- if (const auto* opv = val.as<tir::IntImmNode>()) {
- return opv->value == value;
- }
- }
- return false;
+ const int64_t* as_int = as_const_int(x);
+ return as_int && (*as_int == value);
}
inline bool is_no_op(const tir::Stmt& stmt) {