You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lu...@apache.org on 2022/07/18 16:04:03 UTC
[tvm] branch main updated: [TIR] Moved PrimExpr operator overload from op.h to expr.h (#11973)
This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9c7aaace43 [TIR] Moved PrimExpr operator overload from op.h to expr.h (#11973)
9c7aaace43 is described below
commit 9c7aaace4355c67403be563de3059d34fb8e29f5
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon Jul 18 11:03:54 2022 -0500
[TIR] Moved PrimExpr operator overload from op.h to expr.h (#11973)
* [TIR] Moved PrimExpr operator overload from op.h to expr.h
If a compilation unit includes `<tvm/ir/expr.h>`, but does not include
`<tvm/tir/op.h>`, the operator overloads for `ObjectRef` are declared,
but the operator overloads for `PrimExpr` are not. In this case, any
use of `expr_a == expr_b` would use `ObjectRef`'s implementation and
compare reference equality of the two expressions, rather than
returning a `PrimExpr` that represents the comparison. By having the
operator overloads in the `<tvm/ir/expr.h>` header file, directly
adjacent to the `PrimExpr` declaration, the correct overload must be
available whenever the `PrimExpr` can be used.
Even though this would only impact `operator==`, `operator!=`, and
`operator<`, the three operators defined for `ObjectRef`, this PR
moves all operator overloads to `expr.h` for consistency.
The named versions of the operators (e.g. `tvm::add`) do not have
overloaded variants, and so they are intentionally kept in
`<tvm/tir/op.h>`.
* Explicitly convert TVMRetValue to bool in target.cc
Needed to avoid ambiguity between `TVMRetValue -> bool` conversion and
`TVMRetValue -> int -> PrimExpr` conversion.
* Used vector/unordered_set to track BufferInfoExtractor::call_order_
Use of `std::set<Call>` had ambiguity between `operator<` by
`PrimExpr` or by `ObjectRef`.
The comment for `call_order_` implied that the previous usage of
`std::set<Call>` was intended to have a de-duplicated list in the
order of occurrence. However, the `std::set` was ordered by
`ObjectRef::operator<`, not by insertion order. Switching to using a
`vector` for ordering and `unordered_set` for de-duplication resolves
this issue, and also removes the use of `operator<`.
* Remove C-style cast to fix lint error
---
include/tvm/ir/expr.h | 214 +++++++++++++++++++++++++++
include/tvm/tir/op.h | 195 ------------------------
src/target/target.cc | 9 +-
src/tir/usmp/analysis/extract_buffer_info.cc | 11 +-
4 files changed, 228 insertions(+), 201 deletions(-)
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index b2cfc295b6..5e358ed50e 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -133,6 +133,220 @@ class PrimExpr : public BaseExpr {
TVM_DLL static PrimExpr FromObject_(ObjectRef ref);
};
+/*!
+ * \brief add operator
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief subtraction operator
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief negation.
+ *
+ * \param a input.
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator-(PrimExpr a);
+
+/*!
+ * \brief multiplication operator
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief division operator
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief left shift operator
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief right shift operator
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief greater
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief greater_equal
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief less
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief less_equal
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief equal
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief not_equal
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief and
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note This operator does eager constant folding.
+ */
+TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief or
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note This operator does eager constant folding.
+ */
+TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief not
+ *
+ * \param a the input expression.
+ * \return The result expression.
+ * \note This operator does eager constant folding.
+ */
+TVM_DLL PrimExpr operator!(PrimExpr a);
+
+/*!
+ * \brief take bitwise and of two values
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief take bitwise or of two values
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief take bitwise xor of two values
+ *
+ * \param a left operand
+ * \param b right operand
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b);
+
+/*!
+ * \brief take bitwise negation of a value
+ *
+ * \param a the input expression.
+ * \return The result expression.
+ * \note this function does eager constant folding for
+ * index types(int32, int64) when possible.
+ */
+TVM_DLL PrimExpr operator~(PrimExpr a);
+
/*!
* \brief Base node of all non-primitive expressions.
*
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 34935aec61..7236c6a611 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -42,7 +42,6 @@ namespace tvm {
// Most common operators can be overloaded by argument type(PrimExpr).
// So we put them under the root namespace.
-// It is also necessary to overload operators for PrimExpr.
//
// We put more developer oriented APIs -- make_const and is_const under tir
// as they are more specific to the tir namespace.
@@ -143,16 +142,6 @@ TVM_DLL PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span = Span
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr add(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief add operator
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b);
/*!
* \brief subtraction operator
*
@@ -164,16 +153,6 @@ TVM_DLL PrimExpr operator+(PrimExpr a, PrimExpr b);
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr sub(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief subtraction operator
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b);
/*!
* \brief negation.
*
@@ -184,15 +163,6 @@ TVM_DLL PrimExpr operator-(PrimExpr a, PrimExpr b);
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr neg(PrimExpr a, Span span = Span());
-/*!
- * \brief negation.
- *
- * \param a input.
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator-(PrimExpr a);
/*!
* \brief multiplication operator
*
@@ -204,26 +174,6 @@ TVM_DLL PrimExpr operator-(PrimExpr a);
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr mul(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief multiplication operator
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator*(PrimExpr a, PrimExpr b);
-/*!
- * \brief division operator
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b);
/*!
* \brief left shift operator
*
@@ -235,16 +185,6 @@ TVM_DLL PrimExpr operator/(PrimExpr a, PrimExpr b);
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief left shift operator
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b);
/*!
* \brief right shift operator
*
@@ -256,16 +196,6 @@ TVM_DLL PrimExpr operator<<(PrimExpr a, PrimExpr b);
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief right shift operator
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b);
/*!
* \brief greater
*
@@ -277,16 +207,6 @@ TVM_DLL PrimExpr operator>>(PrimExpr a, PrimExpr b);
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr greater(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief greater
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b);
/*!
* \brief greater_equal
*
@@ -298,16 +218,6 @@ TVM_DLL PrimExpr operator>(PrimExpr a, PrimExpr b);
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief greater_equal
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b);
/*!
* \brief less
*
@@ -319,16 +229,6 @@ TVM_DLL PrimExpr operator>=(PrimExpr a, PrimExpr b);
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr less(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief less
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b);
/*!
* \brief less_equal
*
@@ -340,16 +240,6 @@ TVM_DLL PrimExpr operator<(PrimExpr a, PrimExpr b);
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief less_equal
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b);
/*!
* \brief equal
*
@@ -361,16 +251,6 @@ TVM_DLL PrimExpr operator<=(PrimExpr a, PrimExpr b);
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr equal(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief equal
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b);
/*!
* \brief not_equal
*
@@ -382,16 +262,6 @@ TVM_DLL PrimExpr operator==(PrimExpr a, PrimExpr b);
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief not_equal
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b);
/*!
* \brief and
*
@@ -402,15 +272,6 @@ TVM_DLL PrimExpr operator!=(PrimExpr a, PrimExpr b);
* \note This operator does eager constant folding.
*/
TVM_DLL PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief and
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note This operator does eager constant folding.
- */
-TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b);
/*!
* \brief or
*
@@ -421,15 +282,6 @@ TVM_DLL PrimExpr operator&&(PrimExpr a, PrimExpr b);
* \note This operator does eager constant folding.
*/
TVM_DLL PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief or
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note This operator does eager constant folding.
- */
-TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b);
/*!
* \brief not
*
@@ -439,14 +291,6 @@ TVM_DLL PrimExpr operator||(PrimExpr a, PrimExpr b);
* \note This operator does eager constant folding.
*/
TVM_DLL PrimExpr logical_not(PrimExpr a, Span span = Span());
-/*!
- * \brief not
- *
- * \param a left operand
- * \return The result expression.
- * \note This operator does eager constant folding.
- */
-TVM_DLL PrimExpr operator!(PrimExpr a);
/*!
* \brief compute division in C semantics.
*
@@ -601,16 +445,6 @@ TVM_DLL PrimExpr min(PrimExpr a, PrimExpr b, Span span = Span());
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief take bitwise and of two values
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b);
/*!
* \brief take bitwise or of two values
*
@@ -622,16 +456,6 @@ TVM_DLL PrimExpr operator&(PrimExpr a, PrimExpr b);
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief take bitwise or of two values
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b);
/*!
* \brief take bitwise xor of two values
*
@@ -643,16 +467,6 @@ TVM_DLL PrimExpr operator|(PrimExpr a, PrimExpr b);
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span = Span());
-/*!
- * \brief take bitwise xor of two values
- *
- * \param a left operand
- * \param b right operand
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b);
/*!
* \brief take bitwise negation of two values
*
@@ -663,15 +477,6 @@ TVM_DLL PrimExpr operator^(PrimExpr a, PrimExpr b);
* index types(int32, int64) when possible.
*/
TVM_DLL PrimExpr bitwise_neg(PrimExpr a, Span span = Span());
-/*!
- * \brief take bitwise negation of two values
- *
- * \param a the input expression.
- * \return The result expression.
- * \note this function does eager constant folding for
- * index types(int32, int64) when possible.
- */
-TVM_DLL PrimExpr operator~(PrimExpr a);
/*!
* \brief Conditional expression.
*
diff --git a/src/target/target.cc b/src/target/target.cc
index 07b347f098..01f9bfaeec 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -847,10 +847,11 @@ std::unordered_map<String, ObjectRef> TargetInternal::QueryDevice(int device_id,
TVMRetValue ret;
api->GetAttr(device, runtime::kExist, &ret);
- if (!ret) {
- ICHECK(ret) << "Requested reading the parameters for " << target->kind->name
- << " from device_id " << device_id << ", but device_id " << device_id
- << " doesn't exist. Using default target parameters.";
+ bool device_exists = ret;
+ if (!device_exists) {
+ ICHECK(device_exists) << "Requested reading the parameters for " << target->kind->name
+ << " from device_id " << device_id << ", but device_id " << device_id
+ << " doesn't exist. Using default target parameters.";
return output;
}
diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc
index ba8f6aa911..74d428f6dd 100644
--- a/src/tir/usmp/analysis/extract_buffer_info.cc
+++ b/src/tir/usmp/analysis/extract_buffer_info.cc
@@ -92,7 +92,11 @@ class BufferInfoExtractor : public StmtExprVisitor {
/*!
* \brief Records the order of calls in the main for stability.
*/
- std::set<Call> call_order_;
+ std::vector<Call> call_order_;
+ /*!
+ * \brief Lookup to avoid adding duplicates to `call_order_`.
+ */
+ std::unordered_set<Call, ObjectPtrHash, ObjectPtrEqual> call_order_contents_;
/*!
* \brief Records first access in-terms of Stmts to each buffer per call
*
@@ -469,7 +473,10 @@ void BufferInfoExtractor::VisitPrimFunc(const PrimFunc& func, const Call& call)
scope_stack_.top().allocate_nodes,
scope_stack_.top().allocate_const_nodes,
scope_stack_.top().initial_stmt_of_the_nested_loops};
- call_order_.insert(call);
+ if (call_order_contents_.count(call) == 0) {
+ call_order_contents_.insert(call);
+ call_order_.push_back(call);
+ }
scope_stack_.push(si);
this->VisitStmt(func->body);
scope_stack_.pop();