You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/03 08:29:05 UTC
[tvm] branch main updated: [TVMScript] Add object path tracing to StructuralEqual (#12101)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 39ffe0a5ce [TVMScript] Add object path tracing to StructuralEqual (#12101)
39ffe0a5ce is described below
commit 39ffe0a5ce14c2105b7d48ee420e82e194787d8f
Author: Greg Bonik <gb...@octoml.ai>
AuthorDate: Wed Aug 3 01:28:59 2022 -0700
[TVMScript] Add object path tracing to StructuralEqual (#12101)
Motivation: when two IR objects fail a structural equality check, currently there is no easy way to
find out which part of the IR caused the mismatch. In this PR, we modify the `StructuralEqual`
infrastructure to also optionally return a pair of `ObjectPath` objects that point to the mismatch.
(See https://github.com/apache/tvm/pull/11977). In the upcoming PRs, we will pass these paths to the
TIR printer, so that it can highlight the mismatch location nicely.
Tracking issue: https://github.com/apache/tvm/issues/11912
---
include/tvm/node/reflection.h | 6 +
include/tvm/node/structural_equal.h | 157 ++++++++++++--
python/tvm/ir/base.py | 34 ++-
python/tvm/runtime/__init__.py | 1 +
python/tvm/runtime/object_path.py | 16 ++
src/node/reflection.cc | 44 ++++
src/node/structural_equal.cc | 237 +++++++++++++++++++--
src/node/structural_hash.cc | 162 +++++++++++++-
src/tir/analysis/deep_equal.cc | 13 +-
.../unittest/test_container_structural_equal.py | 155 ++++++++++++++
.../unittest/test_tir_structural_equal_hash.py | 188 +++++++++++++++-
11 files changed, 969 insertions(+), 44 deletions(-)
diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h
index 5d514e24d8..f547b5a707 100644
--- a/include/tvm/node/reflection.h
+++ b/include/tvm/node/reflection.h
@@ -404,5 +404,11 @@ inline bool ReflectionVTable::GetReprBytes(const Object* self, std::string* repr
}
}
+/*!
+ * \brief Given an object and an address of its attribute, return the key of the attribute.
+ * \return NullOpt if no attribute with the given address exists.
+ */
+Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address);
+
} // namespace tvm
#endif // TVM_NODE_REFLECTION_H_
diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h
index 6c25c3d2d2..b51021fe40 100644
--- a/include/tvm/node/structural_equal.h
+++ b/include/tvm/node/structural_equal.h
@@ -24,6 +24,7 @@
#define TVM_NODE_STRUCTURAL_EQUAL_H_
#include <tvm/node/functor.h>
+#include <tvm/node/object_path.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/data_type.h>
@@ -56,6 +57,27 @@ class BaseValueEqual {
}
};
+/*!
+ * \brief Pair of `ObjectPath`s, one for each object being tested for structural equality.
+ */
+class ObjectPathPairNode : public Object {
+ public:
+ ObjectPath lhs_path;
+ ObjectPath rhs_path;
+
+ ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path);
+
+ static constexpr const char* _type_key = "ObjectPathPair";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ObjectPathPairNode, Object);
+};
+
+class ObjectPathPair : public ObjectRef {
+ public:
+ ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path);
+
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPathPair, ObjectRef, ObjectPathPairNode);
+};
+
/*!
* \brief Content-aware structural equality comparator for objects.
*
@@ -99,7 +121,10 @@ class StructuralEqual : public BaseValueEqual {
* equality checking. Instead, it can store the necessary equality conditions
* and check later via an internally managed stack.
*/
-class SEqualReducer : public BaseValueEqual {
+class SEqualReducer {
+ private:
+ struct PathTracingData;
+
public:
/*! \brief Internal handler that defines custom behaviors.. */
class Handler {
@@ -110,12 +135,24 @@ class SEqualReducer : public BaseValueEqual {
* \param lhs The left operand.
* \param rhs The right operand.
* \param map_free_vars Whether do we allow remap variables if possible.
+ * \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability.
*
* \return false if there is an immediate failure, true otherwise.
* \note This function may save the equality condition of (lhs == rhs) in an internal
* stack and try to resolve later.
*/
- virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) = 0;
+ virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
+ const Optional<ObjectPathPair>& current_paths) = 0;
+
+ /*!
+ * \brief Mark the comparison as failed, but don't fail immediately.
+ *
+ * This is useful for producing better error messages when comparing containers.
+ * For example, if two array sizes mismatch, it's better to mark the comparison as failed
+ * but compare array elements anyway, so that we could find the true first mismatch.
+ */
+ virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0;
+
/*!
* \brief Lookup the graph node equal map for vars that are already mapped.
*
@@ -129,28 +166,72 @@ class SEqualReducer : public BaseValueEqual {
* \brief Mark current comparison as graph node equal comparison.
*/
virtual void MarkGraphNode() = 0;
- };
- using BaseValueEqual::operator();
+ protected:
+ using PathTracingData = SEqualReducer::PathTracingData;
+ };
/*! \brief default constructor */
SEqualReducer() = default;
/*!
* \brief Constructor with a specific handler.
* \param handler The equal handler for objects.
+ * \param tracing_data Optional pointer to the path tracing data.
* \param map_free_vars Whether or not to map free variables.
*/
- explicit SEqualReducer(Handler* handler, bool map_free_vars)
- : handler_(handler), map_free_vars_(map_free_vars) {}
+ explicit SEqualReducer(Handler* handler, const PathTracingData* tracing_data, bool map_free_vars)
+ : handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {}
+
+ /*!
+ * \brief Reduce condition to comparison of two attribute values.
+ * \param lhs The left operand.
+ * \param rhs The right operand.
+ * \return the immediate check result.
+ */
+ bool operator()(const double& lhs, const double& rhs) const;
+ bool operator()(const int64_t& lhs, const int64_t& rhs) const;
+ bool operator()(const uint64_t& lhs, const uint64_t& rhs) const;
+ bool operator()(const int& lhs, const int& rhs) const;
+ bool operator()(const bool& lhs, const bool& rhs) const;
+ bool operator()(const std::string& lhs, const std::string& rhs) const;
+ bool operator()(const DataType& lhs, const DataType& rhs) const;
+
+ template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
+ bool operator()(const ENum& lhs, const ENum& rhs) const {
+ using Underlying = typename std::underlying_type<ENum>::type;
+ static_assert(std::is_same<Underlying, int>::value,
+ "Enum must have `int` as the underlying type");
+ return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs);
+ }
+
+ /*!
+ * \brief Reduce condition to comparison of two objects.
+ * \param lhs The left operand.
+ * \param rhs The right operand.
+ * \return the immediate check result.
+ */
+ bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
+
/*!
* \brief Reduce condition to comparison of two objects.
+ *
+ * Like `operator()`, but with an additional `paths` parameter that specifies explicit object
+ * paths for `lhs` and `rhs`. This is useful for implementing SEqualReduce() methods for container
+ * objects like Array and Map, or other custom objects that store nested objects that are not
+ * simply attributes.
+ *
+ * Can only be called when `IsPathTracingEnabled()` is `true`.
+ *
* \param lhs The left operand.
* \param rhs The right operand.
+ * \param paths Object paths for `lhs` and `rhs`.
* \return the immediate check result.
*/
- bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
- return handler_->SEqualReduce(lhs, rhs, map_free_vars_);
+ bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair& paths) const {
+ ICHECK(IsPathTracingEnabled()) << "Path tracing must be enabled when calling this function";
+ return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths);
}
+
/*!
* \brief Reduce condition to comparison of two definitions,
* where free vars can be mapped.
@@ -162,9 +243,8 @@ class SEqualReducer : public BaseValueEqual {
* \param rhs The right operand.
* \return the immediate check result.
*/
- bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
- return handler_->SEqualReduce(lhs, rhs, true);
- }
+ bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs);
+
/*!
* \brief Reduce condition to comparison of two arrays.
* \param lhs The left operand.
@@ -173,13 +253,20 @@ class SEqualReducer : public BaseValueEqual {
*/
template <typename T>
bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
- // quick specialization for Array to reduce amount of recursion
- // depth as array comparison is pretty common.
- if (lhs.size() != rhs.size()) return false;
- for (size_t i = 0; i < lhs.size(); ++i) {
- if (!(operator()(lhs[i], rhs[i]))) return false;
+ if (tracing_data_ == nullptr) {
+ // quick specialization for Array to reduce amount of recursion
+ // depth as array comparison is pretty common.
+ if (lhs.size() != rhs.size()) return false;
+ for (size_t i = 0; i < lhs.size(); ++i) {
+ if (!(operator()(lhs[i], rhs[i]))) return false;
+ }
+ return true;
}
- return true;
+
+ // If tracing is enabled, fall back to the regular path
+ const ObjectRef& lhs_obj = lhs;
+ const ObjectRef& rhs_obj = rhs;
+ return (*this)(lhs_obj, rhs_obj);
}
/*!
* \brief Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var).
@@ -198,11 +285,43 @@ class SEqualReducer : public BaseValueEqual {
/*! \return Get the internal handler. */
Handler* operator->() const { return handler_; }
+ /*! \brief Check if this reducer is tracing paths to the first mismatch. */
+ bool IsPathTracingEnabled() const { return tracing_data_ != nullptr; }
+
+ /*!
+ * \brief Get the paths of the currently compared objects.
+ *
+ * Can only be called when `IsPathTracingEnabled()` is true.
+ */
+ const ObjectPathPair& GetCurrentObjectPaths() const;
+
+ /*!
+ * \brief Specify the object paths of a detected mismatch.
+ *
+ * Can only be called when `IsPathTracingEnabled()` is true.
+ */
+ void RecordMismatchPaths(const ObjectPathPair& paths) const;
+
private:
+ bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const;
+
+ bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
+ const ObjectPathPair* paths) const;
+
+ static void GetPathsFromAttrAddressesAndStoreMismatch(const void* lhs_address,
+ const void* rhs_address,
+ const PathTracingData* tracing_data);
+
+ template <typename T>
+ static bool CompareAttributeValues(const T& lhs, const T& rhs,
+ const PathTracingData* tracing_data);
+
/*! \brief Internal class pointer. */
- Handler* handler_;
+ Handler* handler_ = nullptr;
+ /*! \brief Pointer to the current path tracing context, or nullptr if path tracing is disabled. */
+ const PathTracingData* tracing_data_ = nullptr;
/*! \brief Whether or not to map free vars. */
- bool map_free_vars_;
+ bool map_free_vars_ = false;
};
} // namespace tvm
diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py
index 00514b472d..5b26d5e4fb 100644
--- a/python/tvm/ir/base.py
+++ b/python/tvm/ir/base.py
@@ -191,8 +191,8 @@ def structural_equal(lhs, rhs, map_free_vars=False):
The left operand.
map_free_vars : bool
- Whether or not shall we map free vars that does
- not bound to any definitions as equal to each other.
+ Whether free variables (i.e. variables without a definition site) should be mapped
+ as equal to each other.
Return
------
@@ -209,6 +209,36 @@ def structural_equal(lhs, rhs, map_free_vars=False):
return bool(tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars))
+def get_first_structural_mismatch(lhs, rhs, map_free_vars=False):
+ """Like structural_equal(), but returns the ObjectPaths of the first detected mismatch.
+
+ Parameters
+ ----------
+ lhs : Object
+ The left operand.
+
+ rhs : Object
+        The right operand.
+
+ map_free_vars : bool
+ Whether free variables (i.e. variables without a definition site) should be mapped
+ as equal to each other.
+
+ Returns
+ -------
+    mismatch : Optional[Tuple[ObjectPath, ObjectPath]]
+        `None` if `lhs` and `rhs` are structurally equal.
+        Otherwise, a tuple of two ObjectPath objects that point to the first detected mismatch.
+ """
+ lhs = tvm.runtime.convert(lhs)
+ rhs = tvm.runtime.convert(rhs)
+ mismatch = tvm.runtime._ffi_node_api.GetFirstStructuralMismatch(lhs, rhs, map_free_vars)
+ if mismatch is None:
+ return None
+ else:
+ return mismatch.lhs_path, mismatch.rhs_path
+
+
def assert_structural_equal(lhs, rhs, map_free_vars=False):
"""Assert lhs and rhs are structurally equal to each other.
diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py
index 114f01dd0e..502de73721 100644
--- a/python/tvm/runtime/__init__.py
+++ b/python/tvm/runtime/__init__.py
@@ -19,6 +19,7 @@
# class exposures
from .packed_func import PackedFunc
from .object import Object
+from .object_path import ObjectPath, ObjectPathPair
from .object_generic import ObjectGeneric, ObjectTypes
from .ndarray import NDArray, DataType, DataTypeCode, Device
from .module import Module, num_threads
diff --git a/python/tvm/runtime/object_path.py b/python/tvm/runtime/object_path.py
index 3eabce1f86..c4ec58a596 100644
--- a/python/tvm/runtime/object_path.py
+++ b/python/tvm/runtime/object_path.py
@@ -34,6 +34,7 @@ __all__ = (
"MissingArrayElementPath",
"MapValuePath",
"MissingMapEntryPath",
+ "ObjectPathPair",
)
@@ -122,3 +123,18 @@ class MapValuePath(ObjectPath):
@tvm._ffi.register_object("MissingMapEntryPath")
class MissingMapEntryPath(ObjectPath):
pass
+
+
+@tvm._ffi.register_object("ObjectPathPair")
+class ObjectPathPair(Object):
+ """
+ Pair of ObjectPaths, one for each object being tested for structural equality.
+ """
+
+ @property
+ def lhs_path(self) -> ObjectPath:
+ return _ffi_node_api.ObjectPathPairLhsPath(self)
+
+ @property
+ def rhs_path(self) -> ObjectPath:
+ return _ffi_node_api.ObjectPathPairRhsPath(self)
diff --git a/src/node/reflection.cc b/src/node/reflection.cc
index a7c3493e7f..a0f83f6cf5 100644
--- a/src/node/reflection.cc
+++ b/src/node/reflection.cc
@@ -281,4 +281,48 @@ TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body(NodeGetAttr);
TVM_REGISTER_GLOBAL("node.NodeListAttrNames").set_body(NodeListAttrNames);
TVM_REGISTER_GLOBAL("node.MakeNode").set_body(MakeNode);
+
+namespace {
+// Attribute visitor class for finding the attribute key by its address
+class GetAttrKeyByAddressVisitor : public AttrVisitor {
+ public:
+ explicit GetAttrKeyByAddressVisitor(const void* attr_address)
+ : attr_address_(attr_address), key_(nullptr) {}
+
+ void Visit(const char* key, double* value) final { DoVisit(key, value); }
+ void Visit(const char* key, int64_t* value) final { DoVisit(key, value); }
+ void Visit(const char* key, uint64_t* value) final { DoVisit(key, value); }
+ void Visit(const char* key, int* value) final { DoVisit(key, value); }
+ void Visit(const char* key, bool* value) final { DoVisit(key, value); }
+ void Visit(const char* key, std::string* value) final { DoVisit(key, value); }
+ void Visit(const char* key, void** value) final { DoVisit(key, value); }
+ void Visit(const char* key, DataType* value) final { DoVisit(key, value); }
+ void Visit(const char* key, runtime::NDArray* value) final { DoVisit(key, value); }
+ void Visit(const char* key, runtime::ObjectRef* value) final { DoVisit(key, value); }
+
+ const char* GetKey() const { return key_; }
+
+ private:
+ const void* attr_address_;
+ const char* key_;
+
+ void DoVisit(const char* key, const void* candidate) {
+ if (attr_address_ == candidate) {
+ key_ = key;
+ }
+ }
+};
+} // anonymous namespace
+
+Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address) {
+ GetAttrKeyByAddressVisitor visitor(attr_address);
+ ReflectionVTable::Global()->VisitAttrs(const_cast<Object*>(object), &visitor);
+ const char* key = visitor.GetKey();
+ if (key == nullptr) {
+ return NullOpt;
+ } else {
+ return String(key);
+ }
+}
+
} // namespace tvm
diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc
index 8e52af60d2..01874c0536 100644
--- a/src/node/structural_equal.cc
+++ b/src/node/structural_equal.cc
@@ -22,6 +22,7 @@
#include <tvm/ir/module.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
+#include <tvm/node/object_path.h>
#include <tvm/node/reflection.h>
#include <tvm/node/structural_equal.h>
#include <tvm/runtime/registry.h>
@@ -30,6 +31,25 @@
namespace tvm {
+TVM_REGISTER_OBJECT_TYPE(ObjectPathPairNode);
+
+TVM_REGISTER_GLOBAL("node.ObjectPathPairLhsPath")
+ .set_body_typed([](const ObjectPathPair& object_path_pair) {
+ return object_path_pair->lhs_path;
+ });
+
+TVM_REGISTER_GLOBAL("node.ObjectPathPairRhsPath")
+ .set_body_typed([](const ObjectPathPair& object_path_pair) {
+ return object_path_pair->rhs_path;
+ });
+
+ObjectPathPairNode::ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path)
+ : lhs_path(std::move(lhs_path)), rhs_path(std::move(rhs_path)) {}
+
+ObjectPathPair::ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path) {
+ data_ = make_object<ObjectPathPairNode>(std::move(lhs_path), std::move(rhs_path));
+}
+
// Define the dispatch function here since primary user is in this file.
bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other,
SEqualReducer equal) const {
@@ -42,6 +62,133 @@ bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other,
return fsequal_reduce_[tindex](self, other, equal);
}
+struct SEqualReducer::PathTracingData {
+ ObjectPathPair current_paths;
+ ObjectRef lhs_object;
+ ObjectRef rhs_object;
+ Optional<ObjectPathPair>* first_mismatch;
+
+ ObjectPathPair GetPathsForAttrs(const ObjectRef& lhs, const ObjectRef& rhs) const {
+ Optional<String> lhs_attr_key = GetAttrKeyByAddress(lhs_object.get(), &lhs);
+ Optional<String> rhs_attr_key = GetAttrKeyByAddress(rhs_object.get(), &rhs);
+ return ObjectPathPair(current_paths->lhs_path->Attr(lhs_attr_key),
+ current_paths->rhs_path->Attr(rhs_attr_key));
+ }
+};
+
+bool SEqualReducer::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
+ if (tracing_data_ == nullptr) {
+ // Fast path: no tracing
+ return handler_->SEqualReduce(lhs, rhs, map_free_vars_, NullOpt);
+ }
+ return ObjectAttrsEqual(lhs, rhs, map_free_vars_, nullptr);
+}
+
+bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
+ if (tracing_data_ == nullptr) {
+ // Fast path: no tracing
+ return handler_->SEqualReduce(lhs, rhs, true, NullOpt);
+ }
+ return ObjectAttrsEqual(lhs, rhs, true, nullptr);
+}
+
+/* static */ void SEqualReducer::GetPathsFromAttrAddressesAndStoreMismatch(
+ const void* lhs_address, const void* rhs_address, const PathTracingData* tracing_data) {
+ if (tracing_data != nullptr && !tracing_data->first_mismatch->defined()) {
+ Optional<String> lhs_attr_key =
+ GetAttrKeyByAddress(tracing_data->lhs_object.get(), lhs_address);
+ Optional<String> rhs_attr_key =
+ GetAttrKeyByAddress(tracing_data->rhs_object.get(), rhs_address);
+ *tracing_data->first_mismatch =
+ ObjectPathPair(tracing_data->current_paths->lhs_path->Attr(lhs_attr_key),
+ tracing_data->current_paths->rhs_path->Attr(rhs_attr_key));
+ }
+}
+
+template <typename T>
+/* static */ bool SEqualReducer::CompareAttributeValues(const T& lhs, const T& rhs,
+ const PathTracingData* tracing_data) {
+ if (BaseValueEqual()(lhs, rhs)) {
+ return true;
+ } else {
+ GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data);
+ return false;
+ }
+}
+
+bool SEqualReducer::operator()(const double& lhs, const double& rhs) const {
+ return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs) const {
+ return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs) const {
+ return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const int& lhs, const int& rhs) const {
+ return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const bool& lhs, const bool& rhs) const {
+ return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs) const {
+ return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs) const {
+ return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::EnumAttrsEqual(int lhs, int rhs, const void* lhs_address,
+ const void* rhs_address) const {
+ if (lhs == rhs) {
+ return true;
+ } else {
+ GetPathsFromAttrAddressesAndStoreMismatch(lhs_address, rhs_address, tracing_data_);
+ return false;
+ }
+}
+
+const ObjectPathPair& SEqualReducer::GetCurrentObjectPaths() const {
+ ICHECK(tracing_data_ != nullptr)
+ << "GetCurrentObjectPaths() can only be called when path tracing is enabled";
+ return tracing_data_->current_paths;
+}
+
+void SEqualReducer::RecordMismatchPaths(const ObjectPathPair& paths) const {
+ ICHECK(tracing_data_ != nullptr)
+ << "RecordMismatchPaths() can only be called when path tracing is enabled";
+ if (!tracing_data_->first_mismatch->defined()) {
+ *tracing_data_->first_mismatch = paths;
+ }
+}
+
+bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
+ const ObjectPathPair* paths) const {
+ if (tracing_data_ == nullptr) {
+ // Fast path: no tracing
+ return handler_->SEqualReduce(lhs, rhs, map_free_vars, NullOpt);
+ }
+
+ // Slow path: tracing object paths for better error reporting
+
+ ObjectPathPair new_paths = paths == nullptr ? tracing_data_->GetPathsForAttrs(lhs, rhs) : *paths;
+
+ if (handler_->SEqualReduce(lhs, rhs, map_free_vars, new_paths)) {
+ return true;
+ } else {
+ if (!tracing_data_->first_mismatch->defined()) {
+ *tracing_data_->first_mismatch = new_paths;
+ }
+ return false;
+ }
+}
+
/*!
* \brief A non recursive stack based SEqual handler that can remaps vars.
*
@@ -53,9 +200,11 @@ bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other,
*/
class RemapVarSEqualHandler : public SEqualReducer::Handler {
public:
- explicit RemapVarSEqualHandler(bool assert_mode) : assert_mode_(assert_mode) {}
+ explicit RemapVarSEqualHandler(bool assert_mode, Optional<ObjectPathPair>* first_mismatch)
+ : assert_mode_(assert_mode), first_mismatch_(first_mismatch) {}
- bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final {
+ bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
+ const Optional<ObjectPathPair>& current_paths) final {
// We cannot use check lhs.same_as(rhs) to check equality.
// if we choose to enable var remapping.
//
@@ -82,11 +231,16 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
return it->second.same_as(rhs);
}
if (equal_map_rhs_.count(rhs)) return false;
+
// need to push to pending tasks in this case
- pending_tasks_.emplace_back(Task(lhs, rhs, map_free_vars));
+ pending_tasks_.emplace_back(lhs, rhs, map_free_vars, current_paths);
return true;
};
- return CheckResult(run(), lhs, rhs);
+ return CheckResult(run(), lhs, rhs, current_paths);
+ }
+
+ void DeferFail(const ObjectPathPair& mismatch_paths) final {
+ pending_tasks_.emplace_back(Task::ForceFailTag{}, mismatch_paths);
}
void MarkGraphNode() final {
@@ -108,7 +262,16 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
pending_tasks_.clear();
equal_map_lhs_.clear();
equal_map_rhs_.clear();
- if (!SEqualReduce(lhs, rhs, map_free_vars)) return false;
+
+ Optional<ObjectPathPair> current_paths;
+ if (IsPathTracingEnabled()) {
+ auto root_path = ObjectPath::Root();
+ current_paths = ObjectPathPair(root_path, root_path);
+ }
+ if (!SEqualReduce(lhs, rhs, map_free_vars, current_paths)) {
+ return false;
+ }
+
ICHECK_EQ(pending_tasks_.size(), 1U);
ICHECK(allow_push_to_stack_);
task_stack_.emplace_back(std::move(pending_tasks_.back()));
@@ -118,7 +281,11 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
protected:
// Check the result.
- bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) {
+ bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs,
+ const Optional<ObjectPathPair>& current_paths) {
+ if (IsPathTracingEnabled() && !result && !first_mismatch_->defined()) {
+ *first_mismatch_ = current_paths;
+ }
if (assert_mode_ && !result) {
LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by lhs:" << std::endl
<< PrettyPrint(lhs) << std::endl
@@ -137,6 +304,13 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
// Caution: entry becomes invalid when the stack changes
auto& entry = task_stack_.back();
+ if (entry.force_fail) {
+ if (IsPathTracingEnabled() && !first_mismatch_->defined()) {
+ *first_mismatch_ = entry.current_paths;
+ }
+ return false;
+ }
+
if (entry.children_expanded) {
// When all the children has expanded and visited.
// This means all the condition checks for
@@ -161,7 +335,8 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
// which populates the pending tasks.
ICHECK_EQ(pending_tasks_.size(), 0U);
allow_push_to_stack_ = false;
- if (!DispatchSEqualReduce(entry.lhs, entry.rhs, entry.map_free_vars)) return false;
+ if (!DispatchSEqualReduce(entry.lhs, entry.rhs, entry.map_free_vars, entry.current_paths))
+ return false;
allow_push_to_stack_ = true;
// Push pending tasks in reverse order, so earlier tasks get to
// expand first in the stack
@@ -175,7 +350,8 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
}
// The default equal as registered in the structural equal vtable.
- bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
+ bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
+ const Optional<ObjectPathPair>& current_paths) {
auto compute = [=]() {
ICHECK(lhs.defined() && rhs.defined() && lhs->type_index() == rhs->type_index());
// skip entries that already have equality maps.
@@ -184,10 +360,18 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
return it->second.same_as(rhs);
}
if (equal_map_rhs_.count(rhs)) return false;
+
// Run reduce check for free nodes.
- return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, map_free_vars));
+ if (!IsPathTracingEnabled()) {
+ return vtable_->SEqualReduce(lhs.get(), rhs.get(),
+ SEqualReducer(this, nullptr, map_free_vars));
+ } else {
+ PathTracingData tracing_data = {current_paths.value(), lhs, rhs, first_mismatch_};
+ return vtable_->SEqualReduce(lhs.get(), rhs.get(),
+ SEqualReducer(this, &tracing_data, map_free_vars));
+ }
};
- return CheckResult(compute(), lhs, rhs);
+ return CheckResult(compute(), lhs, rhs, current_paths);
}
private:
@@ -197,17 +381,32 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
ObjectRef lhs;
/*! \brief The rhs operand to be compared. */
ObjectRef rhs;
+ /*! \brief If path tracing is enabled, paths taken so far from the root to `lhs` and `rhs`
+ * objects. */
+ Optional<ObjectPathPair> current_paths;
/*! \brief The map free var argument. */
bool map_free_vars;
/*! \brief Whether the children has been expanded via SEqualReduce */
bool children_expanded{false};
/*! \brief whether the task is about graph equality(need remap). */
bool graph_equal{false};
+ /*! \brief whether the task should return "false" without actually comparing anything */
+ bool force_fail{false};
Task() = default;
- Task(ObjectRef lhs, ObjectRef rhs, bool map_free_vars)
- : lhs(lhs), rhs(rhs), map_free_vars(map_free_vars) {}
+ Task(ObjectRef lhs, ObjectRef rhs, bool map_free_vars, Optional<ObjectPathPair> current_paths)
+ : lhs(lhs),
+ rhs(rhs),
+ current_paths(std::move(current_paths)),
+ map_free_vars(map_free_vars) {}
+
+ struct ForceFailTag {}; // dispatch tag for the constructor below
+ Task(ForceFailTag, const ObjectPathPair& current_paths)
+ : current_paths(current_paths), force_fail(true) {}
};
+
+ bool IsPathTracingEnabled() const { return first_mismatch_ != nullptr; }
+
// list of pending tasks to be pushed to the stack.
std::vector<Task> pending_tasks_;
// Internal task stack to executed the task.
@@ -216,6 +415,8 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
bool allow_push_to_stack_{true};
// If in assert mode, must return true, and will throw error otherwise.
bool assert_mode_{false};
+ // Location to store the paths to the first detected mismatch, or nullptr to disable path tracing.
+ Optional<ObjectPathPair>* first_mismatch_;
// reflection vtable
ReflectionVTable* vtable_ = ReflectionVTable::Global();
// map from lhs to rhs
@@ -227,11 +428,19 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
TVM_REGISTER_GLOBAL("node.StructuralEqual")
.set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode,
bool map_free_vars) {
- return RemapVarSEqualHandler(assert_mode).Equal(lhs, rhs, map_free_vars);
+ return RemapVarSEqualHandler(assert_mode, nullptr).Equal(lhs, rhs, map_free_vars);
+ });
+
+TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch")
+ .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
+ Optional<ObjectPathPair> first_mismatch;
+ bool equal = RemapVarSEqualHandler(false, &first_mismatch).Equal(lhs, rhs, map_free_vars);
+ ICHECK(equal == !first_mismatch.defined());
+ return first_mismatch;
});
bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
- return RemapVarSEqualHandler(false).Equal(lhs, rhs, false);
+ return RemapVarSEqualHandler(false, nullptr).Equal(lhs, rhs, false);
}
} // namespace tvm
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index 23811e2190..b40b1751fb 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -22,6 +22,7 @@
#include <dmlc/memory_io.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
+#include <tvm/node/object_path.h>
#include <tvm/node/reflection.h>
#include <tvm/node/structural_hash.h>
#include <tvm/runtime/container/adt.h>
@@ -395,12 +396,73 @@ struct ArrayNodeTrait {
}
static bool SEqualReduce(const ArrayNode* lhs, const ArrayNode* rhs, SEqualReducer equal) {
+ if (equal.IsPathTracingEnabled()) {
+ return SEqualReduceTraced(lhs, rhs, equal);
+ }
+
if (lhs->size() != rhs->size()) return false;
for (size_t i = 0; i < lhs->size(); ++i) {
if (!equal(lhs->at(i), rhs->at(i))) return false;
}
return true;
}
+
+ private:
+ static bool SEqualReduceTraced(const ArrayNode* lhs, const ArrayNode* rhs,
+ const SEqualReducer& equal) {
+ size_t min_size = std::min(lhs->size(), rhs->size());
+ const ObjectPathPair& array_paths = equal.GetCurrentObjectPaths();
+
+ for (size_t index = 0; index < min_size; ++index) {
+ ObjectPathPair element_paths = {array_paths->lhs_path->ArrayIndex(index),
+ array_paths->rhs_path->ArrayIndex(index)};
+ if (!equal(lhs->at(index), rhs->at(index), element_paths)) {
+ return false;
+ }
+ }
+
+ if (lhs->size() == rhs->size()) {
+ return true;
+ }
+
+ // If the array length is mismatched, don't report it immediately.
+ // Instead, defer the failure until we visit all children.
+ //
+ // This is for human readability. For example, say we have two sequences
+ //
+ // (1) a b c d e f g h i j k l m
+ // (2) a b c d e g h i j k l m
+ //
+ // If we directly report a mismatch at the end of the array right now,
+ // the user will see that array (1) has an element `m` at index 12 but array (2)
+ // has no index 12 because it's too short:
+ //
+ // (1) a b c d e f g h i j k l m
+ // ^error here
+ // (2) a b c d e g h i j k l m
+ // ^ error here
+ //
+    // This is not very helpful. Instead, if we defer reporting this mismatch until all elements
+    // are fully visited, we can point out the actual location of the difference:
+ //
+ // (1) a b c d e f g h i j k l m
+ // ^
+ // error here
+ //
+ // (2) a b c d e g h i j k l m
+ // ^
+ // error here
+ if (lhs->size() > min_size) {
+ equal->DeferFail({array_paths->lhs_path->ArrayIndex(min_size),
+ array_paths->rhs_path->MissingArrayElement(min_size)});
+ } else {
+ equal->DeferFail({array_paths->lhs_path->MissingArrayElement(min_size),
+ array_paths->rhs_path->ArrayIndex(min_size)});
+ }
+
+ // Can return `true` pretending that everything is good since we have deferred the failure.
+ return true;
+ }
};
TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
.set_creator([](const std::string&) -> ObjectPtr<Object> {
@@ -501,13 +563,105 @@ struct MapNodeTrait {
return true;
}
+  /*!
+   * \brief Check whether every key of `map` is a string.
+   * \return true if all keys are StringObj instances (vacuously true for an empty map).
+   */
+  static bool IsStringMap(const MapNode* map) {
+    for (const auto& entry : *map) {
+      if (!entry.first->IsInstance<StringObj>()) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Path-tracing equality for maps whose keys are not all strings.
+   *
+   * Each lhs key is translated to its rhs counterpart with MapLhsToRhs(). A key with no
+   * counterpart, or whose counterpart is absent from `rhs`, is recorded as a missing map
+   * entry on the rhs side. Afterwards, any rhs key that no lhs key matched is recorded
+   * as a missing map entry on the lhs side.
+   */
+  static bool SEqualReduceTracedForOMap(const MapNode* lhs, const MapNode* rhs,
+                                        const SEqualReducer& equal) {
+    const ObjectPathPair& paths = equal.GetCurrentObjectPaths();
+    std::vector<const Object*> matched_rhs_keys;
+
+    // Pass 1: every lhs key must map to an rhs key, and the associated values must match.
+    for (const auto& lhs_kv : *lhs) {
+      ObjectPath lhs_value_path = paths->lhs_path->MapValue(lhs_kv.first);
+      ObjectRef mapped_key = equal->MapLhsToRhs(lhs_kv.first);
+      // Only look the key up when a mapping exists; otherwise treat it as not found.
+      auto rhs_it = mapped_key.defined() ? rhs->find(mapped_key) : rhs->end();
+      if (rhs_it == rhs->end()) {
+        equal.RecordMismatchPaths({lhs_value_path, paths->rhs_path->MissingMapEntry()});
+        return false;
+      }
+      if (!equal(lhs_kv.second, rhs_it->second,
+                 {lhs_value_path, paths->rhs_path->MapValue(rhs_it->first)})) {
+        return false;
+      }
+      matched_rhs_keys.push_back(rhs_it->first.get());
+    }
+
+    // Pass 2: every rhs key must have been matched in pass 1 (sorted for binary search).
+    std::sort(matched_rhs_keys.begin(), matched_rhs_keys.end());
+    for (const auto& rhs_kv : *rhs) {
+      const Object* rhs_key = rhs_kv.first.get();
+      if (!std::binary_search(matched_rhs_keys.begin(), matched_rhs_keys.end(), rhs_key)) {
+        equal.RecordMismatchPaths(
+            {paths->lhs_path->MissingMapEntry(), paths->rhs_path->MapValue(rhs_kv.first)});
+        return false;
+      }
+    }
+
+    // NOTE(review): this relies on MapLhsToRhs never mapping two lhs keys to the same
+    // rhs key; otherwise the sizes could differ here.
+    ICHECK(lhs->size() == rhs->size());
+    return true;
+  }
+
+  /*!
+   * \brief Path-tracing equality for maps whose keys are all strings.
+   *
+   * String keys are looked up directly in the other map, with no key remapping. The
+   * function first walks `lhs`, reporting a key absent from `rhs` or a differing value.
+   * It then walks `rhs`, reporting a key absent from `lhs`.
+   */
+  static bool SEqualReduceTracedForSMap(const MapNode* lhs, const MapNode* rhs,
+                                        const SEqualReducer& equal) {
+    const ObjectPathPair& map_paths = equal.GetCurrentObjectPaths();
+
+    // First, check that every key from `lhs` is also in `rhs`, and their values are equal.
+    for (const auto& kv : *lhs) {
+      ObjectPath lhs_path = map_paths->lhs_path->MapValue(kv.first);
+      auto it = rhs->find(kv.first);
+      if (it == rhs->end()) {
+        equal.RecordMismatchPaths({lhs_path, map_paths->rhs_path->MissingMapEntry()});
+        return false;
+      }
+
+      if (!equal(kv.second, it->second, {lhs_path, map_paths->rhs_path->MapValue(it->first)})) {
+        return false;
+      }
+    }
+
+    // Second, make sure every key from `rhs` is also in `lhs`. The rhs path is built only
+    // on the failure branch, rather than once per entry, since it is only needed there.
+    for (const auto& kv : *rhs) {
+      if (!lhs->count(kv.first)) {
+        equal.RecordMismatchPaths(
+            {map_paths->lhs_path->MissingMapEntry(), map_paths->rhs_path->MapValue(kv.first)});
+        return false;
+      }
+    }
+
+    // Both key sets contain each other, so the sizes must agree.
+    ICHECK(lhs->size() == rhs->size());
+    return true;
+  }
+
+  /*!
+   * \brief Path-tracing structural equality for maps.
+   *
+   * Dispatches on the key type of `lhs`. String-keyed maps are compared by direct key
+   * lookup; all other maps go through the reducer's MapLhsToRhs() key mapping.
+   */
+  static bool SEqualReduceTraced(const MapNode* lhs, const MapNode* rhs,
+                                 const SEqualReducer& equal) {
+    return IsStringMap(lhs) ? SEqualReduceTracedForSMap(lhs, rhs, equal)
+                            : SEqualReduceTracedForOMap(lhs, rhs, equal);
+  }
+
static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
+ if (equal.IsPathTracingEnabled()) {
+ return SEqualReduceTraced(lhs, rhs, equal);
+ }
+
if (rhs->size() != lhs->size()) return false;
if (rhs->size() == 0) return true;
- bool ls = std::all_of(lhs->begin(), lhs->end(),
- [](const auto& v) { return v.first->template IsInstance<StringObj>(); });
- bool rs = std::all_of(rhs->begin(), rhs->end(),
- [](const auto& v) { return v.first->template IsInstance<StringObj>(); });
+ bool ls = IsStringMap(lhs);
+ bool rs = IsStringMap(rhs);
if (ls != rs) {
return false;
}
diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc
index 7f48cc4392..451855c8f8 100644
--- a/src/tir/analysis/deep_equal.cc
+++ b/src/tir/analysis/deep_equal.cc
@@ -21,6 +21,7 @@
* \file tir/analysis/deep_equal.cc
* \brief Deep equality checking.
*/
+#include <tvm/node/object_path.h>
#include <tvm/node/reflection.h>
#include <tvm/node/structural_equal.h>
#include <tvm/runtime/registry.h>
@@ -32,21 +33,25 @@ namespace tir {
+/*!
+ * \brief SEqualReducer handler behind ExprDeepEqual.
+ *
+ * Compares two objects by recursing directly through the reflection vtable. The
+ * incoming path argument is ignored. MapLhsToRhs() never provides a key mapping.
+ */
class DeepCmpSEqualHandler : public SEqualReducer::Handler {
public:
// use direct recursion.
- bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final {
+  // Fails if the vtable comparison fails or if a failure was deferred during it.
+  // NOTE(review): `nullptr` is assumed to mean "no path tracing" for the child reducer.
+ bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
+ const Optional<ObjectPathPair>&) final {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() && rhs.defined()) return false;
if (!rhs.defined() && lhs.defined()) return false;
if (lhs->type_index() != rhs->type_index()) return false;
- return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, false));
+ return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, nullptr, false)) &&
+ !fail_;
}
- ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return ObjectRef(nullptr); }
+  // Nothing is actually postponed here. The flag makes the enclosing SEqualReduce call,
+  // and each caller up the recursion, return false. The mismatch paths are discarded.
+ void DeferFail(const ObjectPathPair&) final { fail_ = true; }
+ ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return ObjectRef(nullptr); }
void MarkGraphNode() final {}
private:
// reflection vtable
ReflectionVTable* vtable_ = ReflectionVTable::Global();
+  // Set by DeferFail() and never reset, so one handler instance serves one comparison.
+ bool fail_ = false;
};
bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
@@ -62,7 +67,7 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
if (lhs.as<AnyNode>()) {
return false;
}
- return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false);
+ return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, NullOpt);
}
TVM_REGISTER_GLOBAL("tir.analysis.expr_deep_equal")
diff --git a/tests/python/unittest/test_container_structural_equal.py b/tests/python/unittest/test_container_structural_equal.py
new file mode 100644
index 0000000000..cdd9ffb7af
--- /dev/null
+++ b/tests/python/unittest/test_container_structural_equal.py
@@ -0,0 +1,155 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+import tvm
+import tvm.testing
+from tvm.ir.base import get_first_structural_mismatch
+from tvm.runtime import ObjectPath
+
+
+def get_first_mismatch_ensure_symmetry(a, b):
+ mismatch = get_first_structural_mismatch(a, b)
+ mismatch_swapped = get_first_structural_mismatch(b, a)
+
+ if mismatch is None and mismatch_swapped is None:
+ return None
+
+ if (
+ mismatch is None
+ or mismatch_swapped is None
+ or mismatch[0] != mismatch_swapped[1]
+ or mismatch[1] != mismatch_swapped[0]
+ ):
+ raise AssertionError(
+ "get_first_structural_mismatch(a, b) and get_first_structural_mismatch(b, a) returned"
+ " inconsistent results '{}' and '{}' for a='{}', b='{}'".format(
+ mismatch, mismatch_swapped, a, b
+ )
+ )
+
+ a_path, b_path = mismatch
+ b_path_swapped, a_path_swapped = mismatch_swapped
+ assert a_path == a_path_swapped
+ assert b_path == b_path_swapped
+
+ return mismatch
+
+
+@pytest.mark.parametrize(
+ "a, b, expected_a_path, expected_b_path",
+ [
+ (
+ [1, 2, 3],
+ [1, 4, 3],
+ ObjectPath.root().array_index(1).attr("value"),
+ ObjectPath.root().array_index(1).attr("value"),
+ ),
+ (
+ [1, 2, 3],
+ [10, 2, 30],
+ ObjectPath.root().array_index(0).attr("value"),
+ ObjectPath.root().array_index(0).attr("value"),
+ ),
+ (
+ [1, 3, 4],
+ [1, 2, 3, 4],
+ ObjectPath.root().array_index(1).attr("value"),
+ ObjectPath.root().array_index(1).attr("value"),
+ ),
+ (
+ [1, 2, 3],
+ [1, 2, 3, 4],
+ ObjectPath.root().missing_array_element(3),
+ ObjectPath.root().array_index(3),
+ ),
+ (
+ [],
+ [1],
+ ObjectPath.root().missing_array_element(0),
+ ObjectPath.root().array_index(0),
+ ),
+ ],
+)
+def test_array_structural_mismatch(a, b, expected_a_path, expected_b_path):
+ a = tvm.runtime.convert(a)
+ b = tvm.runtime.convert(b)
+ a_path, b_path = get_first_mismatch_ensure_symmetry(a, b)
+ assert a_path == expected_a_path
+ assert b_path == expected_b_path
+
+
+@pytest.mark.parametrize(
+ "contents",
+ [
+ [],
+ [1],
+ [1, 2, 3],
+ ],
+)
+def test_array_structural_equal_to_self(contents):
+ a = tvm.runtime.convert(list(contents))
+ b = tvm.runtime.convert(list(contents))
+ assert get_first_mismatch_ensure_symmetry(a, b) is None
+
+
+@pytest.mark.parametrize(
+ "a, b, expected_a_path, expected_b_path",
+ [
+ (
+ dict(a=3, b=4),
+ dict(a=3, b=5),
+ ObjectPath.root().map_value("b").attr("value"),
+ ObjectPath.root().map_value("b").attr("value"),
+ ),
+ (
+ dict(a=3, b=4),
+ dict(a=3, b=4, c=5),
+ ObjectPath.root().missing_map_entry(),
+ ObjectPath.root().map_value("c"),
+ ),
+ ],
+)
+def test_string_map_structural_mismatch(a, b, expected_a_path, expected_b_path):
+ a = tvm.runtime.convert(a)
+ b = tvm.runtime.convert(b)
+ a_path, b_path = get_first_mismatch_ensure_symmetry(a, b)
+ assert a_path == expected_a_path
+ assert b_path == expected_b_path
+
+
+@pytest.mark.parametrize(
+ "contents",
+ [
+ dict(),
+ dict(a=1),
+ dict(a=3, b=4, c=5),
+ ],
+)
+def test_string_structural_equal_to_self(contents):
+ a = tvm.runtime.convert(dict(contents))
+ b = tvm.runtime.convert(dict(contents))
+ assert get_first_mismatch_ensure_symmetry(a, b) is None
+
+
+# Structural equality for maps with non-string keys is specific to IR variables: it
+# assumes the map keys have already been "mapped" to each other via
+# `SEqualReducer::FreeVarEqualImpl()`. That case is therefore covered by the TIR tests.
+
+
+if __name__ == "__main__":
+ # Allow running this test file directly as a script.
+ tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_structural_equal_hash.py b/tests/python/unittest/test_tir_structural_equal_hash.py
index ff02f1e369..d5feb21f0d 100644
--- a/tests/python/unittest/test_tir_structural_equal_hash.py
+++ b/tests/python/unittest/test_tir_structural_equal_hash.py
@@ -18,6 +18,7 @@ import tvm
import numpy as np
import pytest
from tvm import te
+from tvm.runtime import ObjectPath
def consistent_equal(x, y, map_free_vars=False):
@@ -29,7 +30,7 @@ def consistent_equal(x, y, map_free_vars=False):
if struct_equal0 != struct_equal1:
raise ValueError(
- "Non-communicative {} vs {}, sequal0={}, sequal1={}".format(
+ "Non-commutative {} vs {}, sequal0={}, sequal1={}".format(
x, y, struct_equal0, struct_equal1
)
)
@@ -45,6 +46,28 @@ def consistent_equal(x, y, map_free_vars=False):
return struct_equal0
+def get_sequal_mismatch(x, y, map_free_vars=False):
+ mismatch_0 = tvm.ir.base.get_first_structural_mismatch(x, y, map_free_vars)
+ mismatch_1 = tvm.ir.base.get_first_structural_mismatch(y, x, map_free_vars)
+
+ if mismatch_0 is None and mismatch_1 is None:
+ return None
+
+ if (
+ mismatch_0 is None
+ or mismatch_1 is None
+ or mismatch_0[0] != mismatch_1[1]
+ or mismatch_0[1] != mismatch_1[0]
+ ):
+ raise ValueError(
+ "Non-commutative {} vs {}, mismatch_0={}, mismatch_1={}".format(
+ x, y, mismatch_0, mismatch_1
+ )
+ )
+
+ return mismatch_0
+
+
def test_exprs():
# save load json
x = tvm.tir.const(1, "int32")
@@ -107,6 +130,47 @@ def test_prim_func():
tvm.ir.assert_structural_equal(mod0, mod1)
+def test_prim_func_param_count_mismatch():
+ x = te.var("x")
+ y = te.var("y")
+ z = te.var("z")
+ # counter example of same equality
+ func0 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(x))
+ func1 = tvm.tir.PrimFunc([x, y, z], tvm.tir.Evaluate(x))
+ lhs_path, rhs_path = get_sequal_mismatch(func0, func1)
+ expected_lhs_path = ObjectPath.root().attr("params").missing_array_element(2)
+ expected_rhs_path = ObjectPath.root().attr("params").array_index(2)
+ assert lhs_path == expected_lhs_path
+ assert rhs_path == expected_rhs_path
+
+
+def test_prim_func_param_dtype_mismatch():
+ x = te.var("x")
+ y_0 = te.var("y", dtype="int32")
+ y_1 = te.var("z", dtype="float32")
+ # counter example of same equality
+ func0 = tvm.tir.PrimFunc([x, y_0], tvm.tir.Evaluate(x))
+ func1 = tvm.tir.PrimFunc([x, y_1], tvm.tir.Evaluate(x))
+ lhs_path, rhs_path = get_sequal_mismatch(func0, func1)
+ expected_path = ObjectPath.root().attr("params").array_index(1).attr("dtype")
+ assert lhs_path == expected_path
+ assert rhs_path == expected_path
+
+
+def test_prim_func_body_mismatch():
+ x_0 = te.var("x")
+ y_0 = te.var("y")
+ x_1 = te.var("x")
+ y_1 = te.var("y")
+ # counter example of same equality
+ func0 = tvm.tir.PrimFunc([x_0, y_0], tvm.tir.Evaluate(x_0 + x_0))
+ func1 = tvm.tir.PrimFunc([x_1, y_1], tvm.tir.Evaluate(x_1 + y_1))
+ lhs_path, rhs_path = get_sequal_mismatch(func0, func1)
+ expected_path = ObjectPath.root().attr("body").attr("value").attr("b")
+ assert lhs_path == expected_path
+ assert rhs_path == expected_path
+
+
def test_array():
x = np.arange(10)
nx = tvm.nd.array(x)
@@ -183,6 +247,44 @@ def test_buffer_storage_scope():
assert not consistent_equal(func0, func2)
+def test_buffer_map_mismatch():
+ x = te.var("x")
+ buffer_0 = tvm.tir.decl_buffer((10, 10))
+ buffer_0_clone = tvm.tir.decl_buffer((10, 10))
+ buffer_1 = tvm.tir.decl_buffer((10, 20))
+
+ func_0 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0})
+ func_0_clone = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0_clone})
+ func_1 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_1})
+
+ lhs_path, rhs_path = get_sequal_mismatch(func_0, func_1)
+ expected_path = (
+ ObjectPath.root().attr("buffer_map").map_value(x).attr("shape").array_index(1).attr("value")
+ )
+ assert lhs_path == expected_path
+ assert rhs_path == expected_path
+
+ assert get_sequal_mismatch(func_0, func_0_clone) is None
+
+
+def test_buffer_map_length_mismatch():
+ x = te.var("x")
+ y = te.var("x")
+
+ buffer_0 = tvm.tir.decl_buffer((10, 10))
+ buffer_1 = tvm.tir.decl_buffer((10, 20))
+
+ func_0 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0})
+ func_1 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0, y: buffer_1})
+
+ lhs_path, rhs_path = get_sequal_mismatch(func_0, func_1)
+
+ expected_lhs_path = ObjectPath.root().attr("buffer_map").missing_map_entry()
+ assert lhs_path == expected_lhs_path
+ expected_rhs_path = ObjectPath.root().attr("buffer_map").map_value(y)
+ assert rhs_path == expected_rhs_path
+
+
def test_buffer_load_store():
b = tvm.tir.decl_buffer((10, 10), "float32")
x = tvm.tir.BufferLoad(b, [0, 1])
@@ -208,6 +310,90 @@ def test_while():
assert consistent_equal(wx, wy, map_free_vars=True)
+def test_while_condition_mismatch():
+ x = tvm.tir.Var("x", "int32")
+ w_0 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x))
+ w_1 = tvm.tir.While(x < 0, tvm.tir.Evaluate(x))
+ lhs_path, rhs_path = get_sequal_mismatch(w_0, w_1)
+ expected_path = ObjectPath.root().attr("condition")
+ assert lhs_path == expected_path
+ assert rhs_path == expected_path
+
+
+def test_while_body_mismatch():
+ x = tvm.tir.Var("x", "int32")
+ w_0 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x))
+ w_1 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x + 1))
+ lhs_path, rhs_path = get_sequal_mismatch(w_0, w_1)
+ expected_path = ObjectPath.root().attr("body").attr("value")
+ assert lhs_path == expected_path
+ assert rhs_path == expected_path
+
+
+def test_seq_mismatch():
+ x = tvm.tir.Var("x", "int32")
+ seq_0 = tvm.tir.SeqStmt(
+ [
+ tvm.tir.Evaluate(x),
+ tvm.tir.Evaluate(x + 1),
+ tvm.tir.Evaluate(x + 2),
+ tvm.tir.Evaluate(x + 3),
+ ]
+ )
+ seq_1 = tvm.tir.SeqStmt(
+ [
+ tvm.tir.Evaluate(x),
+ tvm.tir.Evaluate(x + 1),
+ tvm.tir.Evaluate(x + 99),
+ tvm.tir.Evaluate(x + 3),
+ ]
+ )
+ lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1)
+ expected_path = (
+ ObjectPath.root().attr("seq").array_index(2).attr("value").attr("b").attr("value")
+ )
+ assert lhs_path == expected_path
+ assert rhs_path == expected_path
+
+
+def test_seq_mismatch_different_lengths():
+ # Make sure we report a difference inside the array first, rather than the difference in length
+ x = tvm.tir.Var("x", "int32")
+ seq_0 = tvm.tir.SeqStmt(
+ [
+ tvm.tir.Evaluate(x),
+ tvm.tir.Evaluate(x + 1),
+ tvm.tir.Evaluate(x + 2),
+ tvm.tir.Evaluate(x + 3),
+ ]
+ )
+ seq_1 = tvm.tir.SeqStmt([tvm.tir.Evaluate(x), tvm.tir.Evaluate(x + 1), tvm.tir.Evaluate(x + 3)])
+ lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1)
+ expected_path = (
+ ObjectPath.root().attr("seq").array_index(2).attr("value").attr("b").attr("value")
+ )
+ assert lhs_path == expected_path
+ assert rhs_path == expected_path
+
+
+def test_seq_length_mismatch():
+ x = tvm.tir.Var("x", "int32")
+ seq_0 = tvm.tir.SeqStmt(
+ [
+ tvm.tir.Evaluate(x),
+ tvm.tir.Evaluate(x + 1),
+ tvm.tir.Evaluate(x + 2),
+ tvm.tir.Evaluate(x + 3),
+ ]
+ )
+ seq_1 = tvm.tir.SeqStmt([tvm.tir.Evaluate(x), tvm.tir.Evaluate(x + 1), tvm.tir.Evaluate(x + 2)])
+ lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1)
+ expected_lhs_path = ObjectPath.root().attr("seq").array_index(3)
+ expected_rhs_path = ObjectPath.root().attr("seq").missing_array_element(3)
+ assert lhs_path == expected_lhs_path
+ assert rhs_path == expected_rhs_path
+
+
if __name__ == "__main__":
test_exprs()
test_prim_func()