You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2022/05/02 17:53:57 UTC

[tvm] branch main updated: [TIR] Reduced duplication in op.h (#11129)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 169f824d69 [TIR] Reduced duplication in op.h (#11129)
169f824d69 is described below

commit 169f824d69442ac9c19f320e2ddb9baec32af68e
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon May 2 12:53:51 2022 -0500

    [TIR] Reduced duplication in op.h (#11129)
    
    * [TIR] Reduced duplication in op.h
    
    Previously, `is_positive_const`, `is_negative_const`, `is_const_int`, and
    `as_const_int` had nearly duplicate type-checking logic.  This allowed
    the handling of Broadcast nodes to diverge between the
    implementations.  (e.g. `is_const_int(Broadcast(4,1), 4)` returned
    true, but `is_positive_const(Broadcast(4,1))` returned false.)
    
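    This changes `as_const_int` to contain the shared type-checking logic,
    with the other three functions implemented in terms of `as_const_int`.
    (An earlier revision also moved the handling of Broadcast nodes into
    `as_const_int`; that handling was dropped in the follow-up change
    described below.)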
    
    * Test case, removing BroadcastNode handling from as_const_int
    
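    Rather than extending BroadcastNode handling to apply in more cases,
    this removes it, to check whether it is safe to extract that
    functionality out to a separate function.  As a result,
    `is_const_int`, `is_positive_const`, and `is_negative_const` no longer
    treat a Broadcast of an integer constant as a constant integer.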
---
 include/tvm/tir/op.h | 41 +++++++++--------------------------------
 1 file changed, 9 insertions(+), 32 deletions(-)

diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 9c3ea135c6..5b63016d2f 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -977,9 +977,9 @@ inline const int64_t* as_const_int(const PrimExpr& x) {
   if (!x.defined()) return nullptr;
   if (const tir::IntImmNode* op = x.as<tir::IntImmNode>()) {
     return &(op->value);
-  } else {
-    return nullptr;
   }
+
+  return nullptr;
 }
 
 /*!
@@ -1051,17 +1051,7 @@ inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr
 TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);
 
 // Implementation details after this
-inline bool is_const_int(const PrimExpr& x) {
-  if (x.as<tir::IntImmNode>()) {
-    return true;
-  } else if (const auto* op = x.as<tir::BroadcastNode>()) {
-    const PrimExpr& val = op->value;
-    if (val.as<tir::IntImmNode>()) {
-      return true;
-    }
-  }
-  return false;
-}
+inline bool is_const_int(const PrimExpr& x) { return as_const_int(x); }
 
 inline bool is_const_number(const PrimExpr& x) {
   if (x.as<tir::IntImmNode>()) {
@@ -1075,31 +1065,18 @@ inline bool is_const_number(const PrimExpr& x) {
 }
 
 inline bool is_positive_const(const PrimExpr& a) {
-  if (const tir::IntImmNode* op = a.as<tir::IntImmNode>()) {
-    return op->value > 0;
-  } else {
-    return false;
-  }
+  const int64_t* as_int = as_const_int(a);
+  return as_int && (*as_int > 0);
 }
 
 inline bool is_negative_const(const PrimExpr& a) {
-  if (const tir::IntImmNode* op = a.as<tir::IntImmNode>()) {
-    return op->value < 0;
-  } else {
-    return false;
-  }
+  const int64_t* as_int = as_const_int(a);
+  return as_int && (*as_int < 0);
 }
 
 inline bool is_const_int(const PrimExpr& x, int64_t value) {
-  if (const auto* op = x.as<tir::IntImmNode>()) {
-    return op->value == value;
-  } else if (const auto* op = x.as<tir::BroadcastNode>()) {
-    const PrimExpr& val = op->value;
-    if (const auto* opv = val.as<tir::IntImmNode>()) {
-      return opv->value == value;
-    }
-  }
-  return false;
+  const int64_t* as_int = as_const_int(x);
+  return as_int && (*as_int == value);
 }
 
 inline bool is_no_op(const tir::Stmt& stmt) {