You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2022/07/18 16:04:03 UTC

[tvm] branch main updated: [TIR] Moved PrimExpr operator overload from op.h to expr.h (#11973)

This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9c7aaace43 [TIR] Moved PrimExpr operator overload from op.h to expr.h (#11973)
9c7aaace43 is described below

commit 9c7aaace4355c67403be563de3059d34fb8e29f5
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon Jul 18 11:03:54 2022 -0500

    [TIR] Moved PrimExpr operator overload from op.h to expr.h (#11973)
    
    * [TIR] Moved PrimExpr operator overload from op.h to expr.h
    
    If a compilation unit includes `<tvm/ir/expr.h>`, but does not include
    `<tvm/tir/op.h>`, the operator overloads for `ObjectRef` are declared,
    but the operator overloads for `PrimExpr` are not.  In this case, any
    use of `expr_a == expr_b` would use `ObjectRef`'s implementation and
    compare reference equality of the two expressions, rather than
    returning a `PrimExpr` that represents the comparison.  By placing the
    operator overloads in the `<tvm/ir/expr.h>` header file, directly
    adjacent to the `PrimExpr` declaration, the correct overload is
    guaranteed to be available wherever `PrimExpr` can be used.
    
    Even though the incorrect behavior only affects `operator==`,
    `operator!=`, and `operator<` (the three operators defined for
    `ObjectRef`), this PR moves all operator overloads to `expr.h` for
    consistency.
    
    The named versions of the operators (e.g. `tvm::add`) do not have
    overloaded variants, and so they are intentionally kept in
    `<tvm/tir/op.h>`.
    
    * Explicitly convert TVMRetValue to bool in target.cc
    
    Needed to avoid ambiguity between `TVMRetValue -> bool` conversion and
    `TVMRetValue -> int -> PrimExpr` conversion.
    
    * Used vector/unordered_set to track BufferInfoExtractor::call_order_
    
    Use of `std::set<Call>` was ambiguous between the `PrimExpr` and
    `ObjectRef` overloads of `operator<`.
    
    The comment for `call_order_` implied that the previous usage of
    `std::set<Call>` was intended to have a de-duplicated list in the
    order of occurrence.  However, the `std::set` was ordered by
    `ObjectRef::operator<`, not by insertion order.  Switching to using a
    `vector` for ordering and `unordered_set` for de-duplication resolves
    this issue, and also removes the use of `operator<`.
    
    * Remove C-style cast to fix lint error
---
 include/tvm/ir/expr.h                        | 214 +++++++++++++++++++++++++++
 include/tvm/tir/op.h                         | 195 ------------------------
 src/target/target.cc                         |   9 +-
 src/tir/usmp/analysis/extract_buffer_info.cc |  11 +-
 4 files changed, 228 insertions(+), 201 deletions(-)

diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index b2cfc295b6..5e358ed50e 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -133,6 +133,220 @@ class PrimExpr : public BaseExpr {
   TVM_DLL static PrimExpr FromObject_(ObjectRef ref);
 };
 
+/*!
+ * \brief add operator
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief subtraction operator
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief negation.
+ *
+ * \param a input.
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator-(PrimExpr a);
+
+/*!
+ * \brief multiplication operator
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief division operator
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief left shift operator
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief right shift operator
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief greater
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief greater_equal
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief less
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief less_equal
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief equal
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief not_equal
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief and
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note This operator does eager constant folding.
+ */
+TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief or
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note This operator does eager constant folding.
+ */
+TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief not
+ *
+ * \param a left operand
+ * \return The result expression.
+ * \note This operator does eager constant folding.
+ */
+TVM_DLL PrimExpr operator!(PrimExpr a);
+
+/*!
+ * \brief take bitwise and of two values
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief take bitwise or of two values
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief take bitwise xor of two values
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief take bitwise negation of two values
+ *
+ * \param a the input expression.
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ *       index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator~(PrimExpr a);
+
 /*!
  * \brief Base node of all non-primitive expressions.
  *
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 34935aec61..7236c6a611 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -42,7 +42,6 @@ namespace tvm {
 
 // Most common operators can be overloaded by argument type(PrimExpr).
 // So we put them under the root namespace.
-// It is also necessary to overload operators for PrimExpr.
 //
 // We put more developer oriented APIs -- make_const and is_const under tir
 // as they are more specific to the tir namespace.
@@ -143,16 +142,6 @@ TVM_DLL PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span = Span
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr add(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief add operator
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b);
 /*!
  * \brief subtraction operator
  *
@@ -164,16 +153,6 @@ TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b);
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr sub(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief subtraction operator
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b);
 /*!
  * \brief negation.
  *
@@ -184,15 +163,6 @@ TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b);
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr neg(PrimExpr a, Span span = Span());
-/*!
- * \brief negation.
- *
- * \param a input.
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator-(PrimExpr a);
 /*!
  * \brief multiplication operator
  *
@@ -204,26 +174,6 @@ TVM_DLL PrimExpr operator-(PrimExpr a);
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr mul(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief multiplication operator
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b);
-/*!
- * \brief division operator
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b);
 /*!
  * \brief left shift operator
  *
@@ -235,16 +185,6 @@ TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b);
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief left shift operator
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b);
 /*!
  * \brief right shift operator
  *
@@ -256,16 +196,6 @@ TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b);
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief right shift operator
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b);
 /*!
  * \brief greater
  *
@@ -277,16 +207,6 @@ TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b);
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr greater(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief greater
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b);
 /*!
  * \brief greater_equal
  *
@@ -298,16 +218,6 @@ TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b);
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief greater_equal
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b);
 /*!
  * \brief less
  *
@@ -319,16 +229,6 @@ TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b);
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr less(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief less
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b);
 /*!
  * \brief less_equal
  *
@@ -340,16 +240,6 @@ TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b);
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief less_equal
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b);
 /*!
  * \brief equal
  *
@@ -361,16 +251,6 @@ TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b);
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr equal(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief equal
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b);
 /*!
  * \brief not_equal
  *
@@ -382,16 +262,6 @@ TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b);
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief not_equal
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b);
 /*!
  * \brief and
  *
@@ -402,15 +272,6 @@ TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b);
  * \note This operator does eager constant folding.
  */
 TVM_DLL PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief and
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note This operator does eager constant folding.
- */
-TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b);
 /*!
  * \brief or
  *
@@ -421,15 +282,6 @@ TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b);
  * \note This operator does eager constant folding.
  */
 TVM_DLL PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief or
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note This operator does eager constant folding.
- */
-TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b);
 /*!
  * \brief not
  *
@@ -439,14 +291,6 @@ TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b);
  * \note This operator does eager constant folding.
  */
 TVM_DLL PrimExpr logical_not(PrimExpr a, Span span = Span());
-/*!
- * \brief not
- *
- * \param a left operand
- * \return The result expression.
- * \note This operator does eager constant folding.
- */
-TVM_DLL PrimExpr operator!(PrimExpr a);
 /*!
  * \brief compute division in C semantics.
  *
@@ -601,16 +445,6 @@ TVM_DLL PrimExpr min(PrimExpr a, PrimExpr b, Span span = Span());
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief take bitwise and of two values
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b);
 /*!
  * \brief take bitwise or of two values
  *
@@ -622,16 +456,6 @@ TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b);
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief take bitwise or of two values
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b);
 /*!
  * \brief take bitwise xor of two values
  *
@@ -643,16 +467,6 @@ TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b);
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief take bitwise xor of two values
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b);
 /*!
  * \brief take bitwise negation of two values
  *
@@ -663,15 +477,6 @@ TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b);
  *       index types(int32, int64) when possible.
  */
 TVM_DLL PrimExpr bitwise_neg(PrimExpr a, Span span = Span());
-/*!
- * \brief take bitwise negation of two values
- *
- * \param a the input expression.
- * \return The result expression.
- * \note this function does eager constant folding for
- *       index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator~(PrimExpr a);
 /*!
  * \brief Conditional expression.
  *
diff --git a/src/target/target.cc b/src/target/target.cc
index 07b347f098..01f9bfaeec 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -847,10 +847,11 @@ std::unordered_map<String, ObjectRef> TargetInternal::QueryDevice(int device_id,
 
   TVMRetValue ret;
   api->GetAttr(device, runtime::kExist, &ret);
-  if (!ret) {
-    ICHECK(ret) << "Requested reading the parameters for " << target->kind->name
-                << " from device_id " << device_id << ", but device_id " << device_id
-                << " doesn't exist.  Using default target parameters.";
+  bool device_exists = ret;
+  if (!device_exists) {
+    ICHECK(device_exists) << "Requested reading the parameters for " << target->kind->name
+                          << " from device_id " << device_id << ", but device_id " << device_id
+                          << " doesn't exist.  Using default target parameters.";
     return output;
   }
 
diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc
index ba8f6aa911..74d428f6dd 100644
--- a/src/tir/usmp/analysis/extract_buffer_info.cc
+++ b/src/tir/usmp/analysis/extract_buffer_info.cc
@@ -92,7 +92,11 @@ class BufferInfoExtractor : public StmtExprVisitor {
   /*!
    * \brief Records the order of calls in the main for stability.
    */
-  std::set<Call> call_order_;
+  std::vector<Call> call_order_;
+  /*!
+   * \brief Lookup to avoid adding duplicates to `call_order_`.
+   */
+  std::unordered_set<Call, ObjectPtrHash, ObjectPtrEqual> call_order_contents_;
   /*!
    * \brief Records first access in-terms of Stmts to each buffer per call
    *
@@ -469,7 +473,10 @@ void BufferInfoExtractor::VisitPrimFunc(const PrimFunc& func, const Call& call)
                scope_stack_.top().allocate_nodes,
                scope_stack_.top().allocate_const_nodes,
                scope_stack_.top().initial_stmt_of_the_nested_loops};
-  call_order_.insert(call);
+  if (call_order_contents_.count(call) == 0) {
+    call_order_contents_.insert(call);
+    call_order_.push_back(call);
+  }
   scope_stack_.push(si);
   this->VisitStmt(func->body);
   scope_stack_.pop();