You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/03 08:29:05 UTC

[tvm] branch main updated: [TVMScript] Add object path tracing to StructuralEqual (#12101)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 39ffe0a5ce [TVMScript] Add object path tracing to StructuralEqual (#12101)
39ffe0a5ce is described below

commit 39ffe0a5ce14c2105b7d48ee420e82e194787d8f
Author: Greg Bonik <gb...@octoml.ai>
AuthorDate: Wed Aug 3 01:28:59 2022 -0700

    [TVMScript] Add object path tracing to StructuralEqual (#12101)
    
    Motivation: when two IR objects fail a structural equality check, currently there is no easy way to
    find out which part of the IR caused the mismatch. In this PR, we modify the `StructuralEqual`
    infrastructure to also optionally return a pair of `ObjectPath` objects that point to the mismatch.
    (See https://github.com/apache/tvm/pull/11977). In the upcoming PRs, we will pass these paths to the
    TIR printer, so that it could highlight the mismatch location nicely.
    
    Tracking issue: https://github.com/apache/tvm/issues/11912
---
 include/tvm/node/reflection.h                      |   6 +
 include/tvm/node/structural_equal.h                | 157 ++++++++++++--
 python/tvm/ir/base.py                              |  34 ++-
 python/tvm/runtime/__init__.py                     |   1 +
 python/tvm/runtime/object_path.py                  |  16 ++
 src/node/reflection.cc                             |  44 ++++
 src/node/structural_equal.cc                       | 237 +++++++++++++++++++--
 src/node/structural_hash.cc                        | 162 +++++++++++++-
 src/tir/analysis/deep_equal.cc                     |  13 +-
 .../unittest/test_container_structural_equal.py    | 155 ++++++++++++++
 .../unittest/test_tir_structural_equal_hash.py     | 188 +++++++++++++++-
 11 files changed, 969 insertions(+), 44 deletions(-)

diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h
index 5d514e24d8..f547b5a707 100644
--- a/include/tvm/node/reflection.h
+++ b/include/tvm/node/reflection.h
@@ -404,5 +404,11 @@ inline bool ReflectionVTable::GetReprBytes(const Object* self, std::string* repr
   }
 }
 
+/*!
+ * \brief Given an object and an address of its attribute, return the key of the attribute.
+ * \return The attribute key, or NullOpt if no attribute with the given address exists.
+ */
+Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address);
+
 }  // namespace tvm
 #endif  // TVM_NODE_REFLECTION_H_
diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h
index 6c25c3d2d2..b51021fe40 100644
--- a/include/tvm/node/structural_equal.h
+++ b/include/tvm/node/structural_equal.h
@@ -24,6 +24,7 @@
 #define TVM_NODE_STRUCTURAL_EQUAL_H_
 
 #include <tvm/node/functor.h>
+#include <tvm/node/object_path.h>
 #include <tvm/runtime/container/array.h>
 #include <tvm/runtime/data_type.h>
 
@@ -56,6 +57,27 @@ class BaseValueEqual {
   }
 };
 
+/*!
+ * \brief Pair of `ObjectPath`s, one for each object being tested for structural equality.
+ */
+class ObjectPathPairNode : public Object {
+ public:
+  ObjectPath lhs_path;
+  ObjectPath rhs_path;
+
+  ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path);
+
+  static constexpr const char* _type_key = "ObjectPathPair";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ObjectPathPairNode, Object);
+};
+
+class ObjectPathPair : public ObjectRef {
+ public:
+  ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path);
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPathPair, ObjectRef, ObjectPathPairNode);
+};
+
 /*!
  * \brief Content-aware structural equality comparator for objects.
  *
@@ -99,7 +121,10 @@ class StructuralEqual : public BaseValueEqual {
  * equality checking. Instead, it can store the necessary equality conditions
  * and check later via an internally managed stack.
  */
-class SEqualReducer : public BaseValueEqual {
+class SEqualReducer {
+ private:
+  struct PathTracingData;
+
  public:
   /*! \brief Internal handler that defines custom behaviors.. */
   class Handler {
@@ -110,12 +135,24 @@ class SEqualReducer : public BaseValueEqual {
      * \param lhs The left operand.
      * \param rhs The right operand.
      * \param map_free_vars Whether do we allow remap variables if possible.
+     * \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability.
      *
      * \return false if there is an immediate failure, true otherwise.
      * \note This function may save the equality condition of (lhs == rhs) in an internal
      *       stack and try to resolve later.
      */
-    virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) = 0;
+    virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
+                              const Optional<ObjectPathPair>& current_paths) = 0;
+
+    /*!
+     * \brief Mark the comparison as failed, but don't fail immediately.
+     *
+     * This is useful for producing better error messages when comparing containers.
+     * For example, if two array sizes mismatch, it's better to mark the comparison as failed
+     * but compare array elements anyway, so that we could find the true first mismatch.
+     */
+    virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0;
+
     /*!
      * \brief Lookup the graph node equal map for vars that are already mapped.
      *
@@ -129,28 +166,72 @@ class SEqualReducer : public BaseValueEqual {
      * \brief Mark current comparison as graph node equal comparison.
      */
     virtual void MarkGraphNode() = 0;
-  };
 
-  using BaseValueEqual::operator();
+   protected:
+    using PathTracingData = SEqualReducer::PathTracingData;
+  };
 
   /*! \brief default constructor */
   SEqualReducer() = default;
   /*!
    * \brief Constructor with a specific handler.
    * \param handler The equal handler for objects.
+   * \param tracing_data Pointer to the path tracing data, or nullptr to disable path tracing.
    * \param map_free_vars Whether or not to map free variables.
    */
-  explicit SEqualReducer(Handler* handler, bool map_free_vars)
-      : handler_(handler), map_free_vars_(map_free_vars) {}
+  explicit SEqualReducer(Handler* handler, const PathTracingData* tracing_data, bool map_free_vars)
+      : handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {}
+
+  /*!
+   * \brief Reduce condition to comparison of two attribute values.
+   * \param lhs The left operand.
+   * \param rhs The right operand.
+   * \return the immediate check result.
+   */
+  bool operator()(const double& lhs, const double& rhs) const;
+  bool operator()(const int64_t& lhs, const int64_t& rhs) const;
+  bool operator()(const uint64_t& lhs, const uint64_t& rhs) const;
+  bool operator()(const int& lhs, const int& rhs) const;
+  bool operator()(const bool& lhs, const bool& rhs) const;
+  bool operator()(const std::string& lhs, const std::string& rhs) const;
+  bool operator()(const DataType& lhs, const DataType& rhs) const;
+
+  template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
+  bool operator()(const ENum& lhs, const ENum& rhs) const {
+    using Underlying = typename std::underlying_type<ENum>::type;
+    static_assert(std::is_same<Underlying, int>::value,
+                  "Enum must have `int` as the underlying type");
+    return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs);
+  }
+
+  /*!
+   * \brief Reduce condition to comparison of two objects.
+   * \param lhs The left operand.
+   * \param rhs The right operand.
+   * \return the immediate check result.
+   */
+  bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
+
   /*!
    * \brief Reduce condition to comparison of two objects.
+   *
+   * Like `operator()`, but with an additional `paths` parameter that specifies explicit object
+   * paths for `lhs` and `rhs`. This is useful for implementing SEqualReduce() methods for container
+   * objects like Array and Map, or other custom objects that store nested objects that are not
+   * simply attributes.
+   *
+   * Can only be called when `IsPathTracingEnabled()` is `true`.
+   *
    * \param lhs The left operand.
    * \param rhs The right operand.
+   * \param paths Object paths for `lhs` and `rhs`.
    * \return the immediate check result.
    */
-  bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
-    return handler_->SEqualReduce(lhs, rhs, map_free_vars_);
+  bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair& paths) const {
+    ICHECK(IsPathTracingEnabled()) << "Path tracing must be enabled when calling this function";
+    return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths);
   }
+
   /*!
    * \brief Reduce condition to comparison of two definitions,
    *        where free vars can be mapped.
@@ -162,9 +243,8 @@ class SEqualReducer : public BaseValueEqual {
    * \param rhs The right operand.
    * \return the immediate check result.
    */
-  bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
-    return handler_->SEqualReduce(lhs, rhs, true);
-  }
+  bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs);
+
   /*!
    * \brief Reduce condition to comparison of two arrays.
    * \param lhs The left operand.
@@ -173,13 +253,20 @@ class SEqualReducer : public BaseValueEqual {
    */
   template <typename T>
   bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
-    // quick specialization for Array to reduce amount of recursion
-    // depth as array comparison is pretty common.
-    if (lhs.size() != rhs.size()) return false;
-    for (size_t i = 0; i < lhs.size(); ++i) {
-      if (!(operator()(lhs[i], rhs[i]))) return false;
+    if (tracing_data_ == nullptr) {
+      // quick specialization for Array to reduce amount of recursion
+      // depth as array comparison is pretty common.
+      if (lhs.size() != rhs.size()) return false;
+      for (size_t i = 0; i < lhs.size(); ++i) {
+        if (!(operator()(lhs[i], rhs[i]))) return false;
+      }
+      return true;
     }
-    return true;
+
+    // If tracing is enabled, fall back to the regular path
+    const ObjectRef& lhs_obj = lhs;
+    const ObjectRef& rhs_obj = rhs;
+    return (*this)(lhs_obj, rhs_obj);
   }
   /*!
    * \brief Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var).
@@ -198,11 +285,43 @@ class SEqualReducer : public BaseValueEqual {
   /*! \return Get the internal handler. */
   Handler* operator->() const { return handler_; }
 
+  /*! \brief Check if this reducer is tracing paths to the first mismatch. */
+  bool IsPathTracingEnabled() const { return tracing_data_ != nullptr; }
+
+  /*!
+   * \brief Get the paths of the currently compared objects.
+   *
+   * Can only be called when `IsPathTracingEnabled()` is true.
+   */
+  const ObjectPathPair& GetCurrentObjectPaths() const;
+
+  /*!
+   * \brief Specify the object paths of a detected mismatch.
+   *
+   * Can only be called when `IsPathTracingEnabled()` is true.
+   */
+  void RecordMismatchPaths(const ObjectPathPair& paths) const;
+
  private:
+  bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const;
+
+  bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
+                        const ObjectPathPair* paths) const;
+
+  static void GetPathsFromAttrAddressesAndStoreMismatch(const void* lhs_address,
+                                                        const void* rhs_address,
+                                                        const PathTracingData* tracing_data);
+
+  template <typename T>
+  static bool CompareAttributeValues(const T& lhs, const T& rhs,
+                                     const PathTracingData* tracing_data);
+
   /*! \brief Internal class pointer. */
-  Handler* handler_;
+  Handler* handler_ = nullptr;
+  /*! \brief Pointer to the current path tracing context, or nullptr if path tracing is disabled. */
+  const PathTracingData* tracing_data_ = nullptr;
   /*! \brief Whether or not to map free vars. */
-  bool map_free_vars_;
+  bool map_free_vars_ = false;
 };
 
 }  // namespace tvm
diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py
index 00514b472d..5b26d5e4fb 100644
--- a/python/tvm/ir/base.py
+++ b/python/tvm/ir/base.py
@@ -191,8 +191,8 @@ def structural_equal(lhs, rhs, map_free_vars=False):
         The left operand.
 
     map_free_vars : bool
-        Whether or not shall we map free vars that does
-        not bound to any definitions as equal to each other.
+        Whether free variables (i.e. variables without a definition site) should be mapped
+        as equal to each other.
 
     Return
     ------
@@ -209,6 +209,36 @@ def structural_equal(lhs, rhs, map_free_vars=False):
     return bool(tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars))
 
 
+def get_first_structural_mismatch(lhs, rhs, map_free_vars=False):
+    """Like structural_equal(), but returns the ObjectPaths of the first detected mismatch.
+
+    Parameters
+    ----------
+    lhs : Object
+        The left operand.
+
+    rhs : Object
+        The right operand.
+
+    map_free_vars : bool
+        Whether free variables (i.e. variables without a definition site) should be mapped
+        as equal to each other.
+
+    Returns
+    -------
+    mismatch : Optional[Tuple[ObjectPath, ObjectPath]]
+        `None` if `lhs` and `rhs` are structurally equal.
+        Otherwise, a tuple of two ObjectPath objects that point to the first detected mismatch.
+    """
+    lhs = tvm.runtime.convert(lhs)
+    rhs = tvm.runtime.convert(rhs)
+    mismatch = tvm.runtime._ffi_node_api.GetFirstStructuralMismatch(lhs, rhs, map_free_vars)
+    if mismatch is None:
+        return None
+    else:
+        return mismatch.lhs_path, mismatch.rhs_path
+
+
 def assert_structural_equal(lhs, rhs, map_free_vars=False):
     """Assert lhs and rhs are structurally equal to each other.
 
diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py
index 114f01dd0e..502de73721 100644
--- a/python/tvm/runtime/__init__.py
+++ b/python/tvm/runtime/__init__.py
@@ -19,6 +19,7 @@
 # class exposures
 from .packed_func import PackedFunc
 from .object import Object
+from .object_path import ObjectPath, ObjectPathPair
 from .object_generic import ObjectGeneric, ObjectTypes
 from .ndarray import NDArray, DataType, DataTypeCode, Device
 from .module import Module, num_threads
diff --git a/python/tvm/runtime/object_path.py b/python/tvm/runtime/object_path.py
index 3eabce1f86..c4ec58a596 100644
--- a/python/tvm/runtime/object_path.py
+++ b/python/tvm/runtime/object_path.py
@@ -34,6 +34,7 @@ __all__ = (
     "MissingArrayElementPath",
     "MapValuePath",
     "MissingMapEntryPath",
+    "ObjectPathPair",
 )
 
 
@@ -122,3 +123,18 @@ class MapValuePath(ObjectPath):
 @tvm._ffi.register_object("MissingMapEntryPath")
 class MissingMapEntryPath(ObjectPath):
     pass
+
+
+@tvm._ffi.register_object("ObjectPathPair")
+class ObjectPathPair(Object):
+    """
+    Pair of ObjectPaths, one for each object being tested for structural equality.
+    """
+
+    @property
+    def lhs_path(self) -> ObjectPath:
+        return _ffi_node_api.ObjectPathPairLhsPath(self)
+
+    @property
+    def rhs_path(self) -> ObjectPath:
+        return _ffi_node_api.ObjectPathPairRhsPath(self)
diff --git a/src/node/reflection.cc b/src/node/reflection.cc
index a7c3493e7f..a0f83f6cf5 100644
--- a/src/node/reflection.cc
+++ b/src/node/reflection.cc
@@ -281,4 +281,48 @@ TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body(NodeGetAttr);
 TVM_REGISTER_GLOBAL("node.NodeListAttrNames").set_body(NodeListAttrNames);
 
 TVM_REGISTER_GLOBAL("node.MakeNode").set_body(MakeNode);
+
+namespace {
+// Attribute visitor class for finding the attribute key by its address
+class GetAttrKeyByAddressVisitor : public AttrVisitor {
+ public:
+  explicit GetAttrKeyByAddressVisitor(const void* attr_address)
+      : attr_address_(attr_address), key_(nullptr) {}
+
+  void Visit(const char* key, double* value) final { DoVisit(key, value); }
+  void Visit(const char* key, int64_t* value) final { DoVisit(key, value); }
+  void Visit(const char* key, uint64_t* value) final { DoVisit(key, value); }
+  void Visit(const char* key, int* value) final { DoVisit(key, value); }
+  void Visit(const char* key, bool* value) final { DoVisit(key, value); }
+  void Visit(const char* key, std::string* value) final { DoVisit(key, value); }
+  void Visit(const char* key, void** value) final { DoVisit(key, value); }
+  void Visit(const char* key, DataType* value) final { DoVisit(key, value); }
+  void Visit(const char* key, runtime::NDArray* value) final { DoVisit(key, value); }
+  void Visit(const char* key, runtime::ObjectRef* value) final { DoVisit(key, value); }
+
+  const char* GetKey() const { return key_; }
+
+ private:
+  const void* attr_address_;
+  const char* key_;
+
+  void DoVisit(const char* key, const void* candidate) {
+    if (attr_address_ == candidate) {
+      key_ = key;
+    }
+  }
+};
+}  // anonymous namespace
+
+Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address) {
+  GetAttrKeyByAddressVisitor visitor(attr_address);
+  ReflectionVTable::Global()->VisitAttrs(const_cast<Object*>(object), &visitor);
+  const char* key = visitor.GetKey();
+  if (key == nullptr) {
+    return NullOpt;
+  } else {
+    return String(key);
+  }
+}
+
 }  // namespace tvm
diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc
index 8e52af60d2..01874c0536 100644
--- a/src/node/structural_equal.cc
+++ b/src/node/structural_equal.cc
@@ -22,6 +22,7 @@
 #include <tvm/ir/module.h>
 #include <tvm/node/functor.h>
 #include <tvm/node/node.h>
+#include <tvm/node/object_path.h>
 #include <tvm/node/reflection.h>
 #include <tvm/node/structural_equal.h>
 #include <tvm/runtime/registry.h>
@@ -30,6 +31,25 @@
 
 namespace tvm {
 
+TVM_REGISTER_OBJECT_TYPE(ObjectPathPairNode);
+
+TVM_REGISTER_GLOBAL("node.ObjectPathPairLhsPath")
+    .set_body_typed([](const ObjectPathPair& object_path_pair) {
+      return object_path_pair->lhs_path;
+    });
+
+TVM_REGISTER_GLOBAL("node.ObjectPathPairRhsPath")
+    .set_body_typed([](const ObjectPathPair& object_path_pair) {
+      return object_path_pair->rhs_path;
+    });
+
+ObjectPathPairNode::ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path)
+    : lhs_path(std::move(lhs_path)), rhs_path(std::move(rhs_path)) {}
+
+ObjectPathPair::ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path) {
+  data_ = make_object<ObjectPathPairNode>(std::move(lhs_path), std::move(rhs_path));
+}
+
 // Define the dispatch function here since primary user is in this file.
 bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other,
                                     SEqualReducer equal) const {
@@ -42,6 +62,133 @@ bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other,
   return fsequal_reduce_[tindex](self, other, equal);
 }
 
+struct SEqualReducer::PathTracingData {
+  ObjectPathPair current_paths;
+  ObjectRef lhs_object;
+  ObjectRef rhs_object;
+  Optional<ObjectPathPair>* first_mismatch;
+
+  ObjectPathPair GetPathsForAttrs(const ObjectRef& lhs, const ObjectRef& rhs) const {
+    Optional<String> lhs_attr_key = GetAttrKeyByAddress(lhs_object.get(), &lhs);
+    Optional<String> rhs_attr_key = GetAttrKeyByAddress(rhs_object.get(), &rhs);
+    return ObjectPathPair(current_paths->lhs_path->Attr(lhs_attr_key),
+                          current_paths->rhs_path->Attr(rhs_attr_key));
+  }
+};
+
+bool SEqualReducer::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
+  if (tracing_data_ == nullptr) {
+    // Fast path: no tracing
+    return handler_->SEqualReduce(lhs, rhs, map_free_vars_, NullOpt);
+  }
+  return ObjectAttrsEqual(lhs, rhs, map_free_vars_, nullptr);
+}
+
+bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
+  if (tracing_data_ == nullptr) {
+    // Fast path: no tracing
+    return handler_->SEqualReduce(lhs, rhs, true, NullOpt);
+  }
+  return ObjectAttrsEqual(lhs, rhs, true, nullptr);
+}
+
+/* static */ void SEqualReducer::GetPathsFromAttrAddressesAndStoreMismatch(
+    const void* lhs_address, const void* rhs_address, const PathTracingData* tracing_data) {
+  if (tracing_data != nullptr && !tracing_data->first_mismatch->defined()) {
+    Optional<String> lhs_attr_key =
+        GetAttrKeyByAddress(tracing_data->lhs_object.get(), lhs_address);
+    Optional<String> rhs_attr_key =
+        GetAttrKeyByAddress(tracing_data->rhs_object.get(), rhs_address);
+    *tracing_data->first_mismatch =
+        ObjectPathPair(tracing_data->current_paths->lhs_path->Attr(lhs_attr_key),
+                       tracing_data->current_paths->rhs_path->Attr(rhs_attr_key));
+  }
+}
+
+template <typename T>
+/* static */ bool SEqualReducer::CompareAttributeValues(const T& lhs, const T& rhs,
+                                                        const PathTracingData* tracing_data) {
+  if (BaseValueEqual()(lhs, rhs)) {
+    return true;
+  } else {
+    GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data);
+    return false;
+  }
+}
+
+bool SEqualReducer::operator()(const double& lhs, const double& rhs) const {
+  return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs) const {
+  return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs) const {
+  return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const int& lhs, const int& rhs) const {
+  return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const bool& lhs, const bool& rhs) const {
+  return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs) const {
+  return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs) const {
+  return CompareAttributeValues(lhs, rhs, tracing_data_);
+}
+
+bool SEqualReducer::EnumAttrsEqual(int lhs, int rhs, const void* lhs_address,
+                                   const void* rhs_address) const {
+  if (lhs == rhs) {
+    return true;
+  } else {
+    GetPathsFromAttrAddressesAndStoreMismatch(lhs_address, rhs_address, tracing_data_);
+    return false;
+  }
+}
+
+const ObjectPathPair& SEqualReducer::GetCurrentObjectPaths() const {
+  ICHECK(tracing_data_ != nullptr)
+      << "GetCurrentObjectPaths() can only be called when path tracing is enabled";
+  return tracing_data_->current_paths;
+}
+
+void SEqualReducer::RecordMismatchPaths(const ObjectPathPair& paths) const {
+  ICHECK(tracing_data_ != nullptr)
+      << "RecordMismatchPaths() can only be called when path tracing is enabled";
+  if (!tracing_data_->first_mismatch->defined()) {
+    *tracing_data_->first_mismatch = paths;
+  }
+}
+
+bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
+                                     const ObjectPathPair* paths) const {
+  if (tracing_data_ == nullptr) {
+    // Fast path: no tracing
+    return handler_->SEqualReduce(lhs, rhs, map_free_vars, NullOpt);
+  }
+
+  // Slow path: tracing object paths for better error reporting
+
+  ObjectPathPair new_paths = paths == nullptr ? tracing_data_->GetPathsForAttrs(lhs, rhs) : *paths;
+
+  if (handler_->SEqualReduce(lhs, rhs, map_free_vars, new_paths)) {
+    return true;
+  } else {
+    if (!tracing_data_->first_mismatch->defined()) {
+      *tracing_data_->first_mismatch = new_paths;
+    }
+    return false;
+  }
+}
+
 /*!
  * \brief A non recursive stack based SEqual handler that can remaps vars.
  *
@@ -53,9 +200,11 @@ bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other,
  */
 class RemapVarSEqualHandler : public SEqualReducer::Handler {
  public:
-  explicit RemapVarSEqualHandler(bool assert_mode) : assert_mode_(assert_mode) {}
+  explicit RemapVarSEqualHandler(bool assert_mode, Optional<ObjectPathPair>* first_mismatch)
+      : assert_mode_(assert_mode), first_mismatch_(first_mismatch) {}
 
-  bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final {
+  bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
+                    const Optional<ObjectPathPair>& current_paths) final {
     // We cannot use check lhs.same_as(rhs) to check equality.
     // if we choose to enable var remapping.
     //
@@ -82,11 +231,16 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
         return it->second.same_as(rhs);
       }
       if (equal_map_rhs_.count(rhs)) return false;
+
       // need to push to pending tasks in this case
-      pending_tasks_.emplace_back(Task(lhs, rhs, map_free_vars));
+      pending_tasks_.emplace_back(lhs, rhs, map_free_vars, current_paths);
       return true;
     };
-    return CheckResult(run(), lhs, rhs);
+    return CheckResult(run(), lhs, rhs, current_paths);
+  }
+
+  void DeferFail(const ObjectPathPair& mismatch_paths) final {
+    pending_tasks_.emplace_back(Task::ForceFailTag{}, mismatch_paths);
   }
 
   void MarkGraphNode() final {
@@ -108,7 +262,16 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
     pending_tasks_.clear();
     equal_map_lhs_.clear();
     equal_map_rhs_.clear();
-    if (!SEqualReduce(lhs, rhs, map_free_vars)) return false;
+
+    Optional<ObjectPathPair> current_paths;
+    if (IsPathTracingEnabled()) {
+      auto root_path = ObjectPath::Root();
+      current_paths = ObjectPathPair(root_path, root_path);
+    }
+    if (!SEqualReduce(lhs, rhs, map_free_vars, current_paths)) {
+      return false;
+    }
+
     ICHECK_EQ(pending_tasks_.size(), 1U);
     ICHECK(allow_push_to_stack_);
     task_stack_.emplace_back(std::move(pending_tasks_.back()));
@@ -118,7 +281,11 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
 
  protected:
   // Check the result.
-  bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) {
+  bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs,
+                   const Optional<ObjectPathPair>& current_paths) {
+    if (IsPathTracingEnabled() && !result && !first_mismatch_->defined()) {
+      *first_mismatch_ = current_paths;
+    }
     if (assert_mode_ && !result) {
       LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by lhs:" << std::endl
                  << PrettyPrint(lhs) << std::endl
@@ -137,6 +304,13 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
       // Caution: entry becomes invalid when the stack changes
       auto& entry = task_stack_.back();
 
+      if (entry.force_fail) {
+        if (IsPathTracingEnabled() && !first_mismatch_->defined()) {
+          *first_mismatch_ = entry.current_paths;
+        }
+        return false;
+      }
+
       if (entry.children_expanded) {
         // When all the children has expanded and visited.
         // This means all the condition checks for
@@ -161,7 +335,8 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
         // which populates the pending tasks.
         ICHECK_EQ(pending_tasks_.size(), 0U);
         allow_push_to_stack_ = false;
-        if (!DispatchSEqualReduce(entry.lhs, entry.rhs, entry.map_free_vars)) return false;
+        if (!DispatchSEqualReduce(entry.lhs, entry.rhs, entry.map_free_vars, entry.current_paths))
+          return false;
         allow_push_to_stack_ = true;
         // Push pending tasks in reverse order, so earlier tasks get to
         // expand first in the stack
@@ -175,7 +350,8 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
   }
 
   // The default equal as registered in the structural equal vtable.
-  bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
+  bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
+                            const Optional<ObjectPathPair>& current_paths) {
     auto compute = [=]() {
       ICHECK(lhs.defined() && rhs.defined() && lhs->type_index() == rhs->type_index());
       // skip entries that already have equality maps.
@@ -184,10 +360,18 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
         return it->second.same_as(rhs);
       }
       if (equal_map_rhs_.count(rhs)) return false;
+
       // Run reduce check for free nodes.
-      return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, map_free_vars));
+      if (!IsPathTracingEnabled()) {
+        return vtable_->SEqualReduce(lhs.get(), rhs.get(),
+                                     SEqualReducer(this, nullptr, map_free_vars));
+      } else {
+        PathTracingData tracing_data = {current_paths.value(), lhs, rhs, first_mismatch_};
+        return vtable_->SEqualReduce(lhs.get(), rhs.get(),
+                                     SEqualReducer(this, &tracing_data, map_free_vars));
+      }
     };
-    return CheckResult(compute(), lhs, rhs);
+    return CheckResult(compute(), lhs, rhs, current_paths);
   }
 
  private:
@@ -197,17 +381,32 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
     ObjectRef lhs;
     /*! \brief The rhs operand to be compared. */
     ObjectRef rhs;
+    /*! \brief If path tracing is enabled, paths taken so far from the root to `lhs` and `rhs`
+     * objects. */
+    Optional<ObjectPathPair> current_paths;
     /*! \brief The map free var argument. */
     bool map_free_vars;
     /*! \brief Whether the children has been expanded via SEqualReduce */
     bool children_expanded{false};
     /*! \brief whether the task is about graph equality(need remap). */
     bool graph_equal{false};
+    /*! \brief whether the task should return "false" without actually comparing anything */
+    bool force_fail{false};
 
     Task() = default;
-    Task(ObjectRef lhs, ObjectRef rhs, bool map_free_vars)
-        : lhs(lhs), rhs(rhs), map_free_vars(map_free_vars) {}
+    Task(ObjectRef lhs, ObjectRef rhs, bool map_free_vars, Optional<ObjectPathPair> current_paths)
+        : lhs(lhs),
+          rhs(rhs),
+          current_paths(std::move(current_paths)),
+          map_free_vars(map_free_vars) {}
+
+    struct ForceFailTag {};  // dispatch tag for the constructor below
+    Task(ForceFailTag, const ObjectPathPair& current_paths)
+        : current_paths(current_paths), force_fail(true) {}
   };
+
+  bool IsPathTracingEnabled() const { return first_mismatch_ != nullptr; }
+
   // list of pending tasks to be pushed to the stack.
   std::vector<Task> pending_tasks_;
   // Internal task stack to executed the task.
@@ -216,6 +415,8 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
   bool allow_push_to_stack_{true};
   //  If in assert mode, must return true, and will throw error otherwise.
   bool assert_mode_{false};
+  // Location to store the paths to the first detected mismatch, or nullptr to disable path tracing.
+  Optional<ObjectPathPair>* first_mismatch_;
   // reflection vtable
   ReflectionVTable* vtable_ = ReflectionVTable::Global();
   // map from lhs to rhs
@@ -227,11 +428,19 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
 TVM_REGISTER_GLOBAL("node.StructuralEqual")
     .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode,
                        bool map_free_vars) {
-      return RemapVarSEqualHandler(assert_mode).Equal(lhs, rhs, map_free_vars);
+      return RemapVarSEqualHandler(assert_mode, nullptr).Equal(lhs, rhs, map_free_vars);
+    });
+
+TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch")
+    .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
+      Optional<ObjectPathPair> first_mismatch;
+      bool equal = RemapVarSEqualHandler(false, &first_mismatch).Equal(lhs, rhs, map_free_vars);
+      ICHECK(equal == !first_mismatch.defined()) << "equality result and mismatch paths disagree";
+      return first_mismatch;
     });
 
 bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
-  return RemapVarSEqualHandler(false).Equal(lhs, rhs, false);
+  return RemapVarSEqualHandler(/*assert_mode=*/false, nullptr).Equal(lhs, rhs, false);
 }
 
 }  // namespace tvm
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index 23811e2190..b40b1751fb 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -22,6 +22,8 @@
 #include <dmlc/memory_io.h>
+#include <functional>
 #include <tvm/node/functor.h>
 #include <tvm/node/node.h>
+#include <tvm/node/object_path.h>
 #include <tvm/node/reflection.h>
 #include <tvm/node/structural_hash.h>
 #include <tvm/runtime/container/adt.h>
@@ -395,12 +396,73 @@ struct ArrayNodeTrait {
   }
 
   static bool SEqualReduce(const ArrayNode* lhs, const ArrayNode* rhs, SEqualReducer equal) {
+    if (equal.IsPathTracingEnabled()) {
+      return SEqualReduceTraced(lhs, rhs, equal);
+    }
+    // Untraced fast path: fail as soon as the lengths or any pair of elements differ.
     if (lhs->size() != rhs->size()) return false;
     for (size_t i = 0; i < lhs->size(); ++i) {
       if (!equal(lhs->at(i), rhs->at(i))) return false;
     }
     return true;
   }
+
+ private:
+  static bool SEqualReduceTraced(const ArrayNode* lhs, const ArrayNode* rhs,
+                                 const SEqualReducer& equal) {
+    size_t min_size = std::min(lhs->size(), rhs->size());
+    const ObjectPathPair& array_paths = equal.GetCurrentObjectPaths();
+    // Compare the common prefix element by element, extending both paths with the index.
+    for (size_t index = 0; index < min_size; ++index) {
+      ObjectPathPair element_paths = {array_paths->lhs_path->ArrayIndex(index),
+                                      array_paths->rhs_path->ArrayIndex(index)};
+      if (!equal(lhs->at(index), rhs->at(index), element_paths)) {
+        return false;
+      }
+    }
+
+    if (lhs->size() == rhs->size()) {
+      return true;
+    }
+
+    // If the array length is mismatched, don't report it immediately.
+    // Instead, defer the failure until we visit all children.
+    //
+    // This is for human readability. For example, say we have two sequences
+    //
+    //    (1)     a b c d e f g h i j k l m
+    //    (2)     a b c d e g h i j k l m
+    //
+    // If we directly report a mismatch at the end of the array right now,
+    // the user will see that array (1) has an element `m` at index 12 but array (2)
+    // has no index 12 because it's too short:
+    //
+    //    (1)     a b c d e f g h i j k l m
+    //                                    ^ error here
+    //    (2)     a b c d e g h i j k l m
+    //                                    ^ error here
+    //
+    // This is not very helpful. Instead, if we defer reporting this mismatch until all elements
+    // are fully visited, we can be much more helpful with pointing out the location:
+    //
+    //    (1)     a b c d e f g h i j k l m
+    //                      ^
+    //                  error here
+    //
+    //    (2)     a b c d e g h i j k l m
+    //                      ^
+    //                  error here
+    if (lhs->size() > min_size) {
+      equal->DeferFail({array_paths->lhs_path->ArrayIndex(min_size),
+                        array_paths->rhs_path->MissingArrayElement(min_size)});
+    } else {
+      equal->DeferFail({array_paths->lhs_path->MissingArrayElement(min_size),
+                        array_paths->rhs_path->ArrayIndex(min_size)});
+    }
+
+    // Can return `true` pretending that everything is good since we have deferred the failure.
+    return true;
+  }
 };
 TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
     .set_creator([](const std::string&) -> ObjectPtr<Object> {
@@ -501,13 +563,105 @@ struct MapNodeTrait {
     return true;
   }
 
+  static bool IsStringMap(const MapNode* map) {
+    return std::none_of(map->begin(), map->end(),
+                        [](const auto& e) { return !e.first->template IsInstance<StringObj>(); });
+  }
+
+  static bool SEqualReduceTracedForOMap(const MapNode* lhs, const MapNode* rhs,
+                                        const SEqualReducer& equal) {
+    const ObjectPathPair& map_paths = equal.GetCurrentObjectPaths();
+
+    std::vector<const Object*> seen_rhs;  // `rhs` keys matched by some `lhs` key
+    const std::less<const Object*> by_address{};  // total order even for unrelated pointers
+    // First, check that every key from `lhs` is also in `rhs`,
+    // and their values are mapped to each other.
+    for (const auto& kv : *lhs) {
+      ObjectPath lhs_path = map_paths->lhs_path->MapValue(kv.first);
+
+      ObjectRef rhs_key = equal->MapLhsToRhs(kv.first);
+      if (!rhs_key.defined()) {
+        equal.RecordMismatchPaths({lhs_path, map_paths->rhs_path->MissingMapEntry()});
+        return false;
+      }
+
+      auto it = rhs->find(rhs_key);
+      if (it == rhs->end()) {
+        equal.RecordMismatchPaths({lhs_path, map_paths->rhs_path->MissingMapEntry()});
+        return false;
+      }
+
+      if (!equal(kv.second, it->second, {lhs_path, map_paths->rhs_path->MapValue(it->first)})) {
+        return false;
+      }
+
+      seen_rhs.push_back(it->first.get());
+    }
+
+    std::sort(seen_rhs.begin(), seen_rhs.end(), by_address);
+
+    // Second, check that we have visited every `rhs` key when iterating over `lhs`.
+    for (const auto& kv : *rhs) {
+      if (!std::binary_search(seen_rhs.begin(), seen_rhs.end(), kv.first.get(), by_address)) {
+        equal.RecordMismatchPaths(
+            {map_paths->lhs_path->MissingMapEntry(), map_paths->rhs_path->MapValue(kv.first)});
+        return false;
+      }
+    }
+
+    ICHECK(lhs->size() == rhs->size());
+    return true;
+  }
+
+  static bool SEqualReduceTracedForSMap(const MapNode* lhs, const MapNode* rhs,
+                                        const SEqualReducer& equal) {
+    const ObjectPathPair& map_paths = equal.GetCurrentObjectPaths();
+
+    // First, check that every key from `lhs` is also in `rhs`, and their values are equal.
+    for (const auto& kv : *lhs) {
+      ObjectPath lhs_path = map_paths->lhs_path->MapValue(kv.first);
+      auto it = rhs->find(kv.first);
+      if (it == rhs->end()) {
+        equal.RecordMismatchPaths({lhs_path, map_paths->rhs_path->MissingMapEntry()});
+        return false;
+      }
+
+      if (!equal(kv.second, it->second, {lhs_path, map_paths->rhs_path->MapValue(it->first)})) {
+        return false;
+      }
+    }
+
+    // Second, make sure every key from `rhs` is also in `lhs`; build its path only if missing.
+    for (const auto& kv : *rhs) {
+      if (!lhs->count(kv.first)) {
+        equal.RecordMismatchPaths(
+            {map_paths->lhs_path->MissingMapEntry(), map_paths->rhs_path->MapValue(kv.first)});
+        return false;
+      }
+    }
+
+    ICHECK(lhs->size() == rhs->size());
+    return true;
+  }
+
+  static bool SEqualReduceTraced(const MapNode* lhs, const MapNode* rhs,
+                                 const SEqualReducer& equal) {
+    // Only the keys of `lhs` decide the strategy: string-keyed maps look up the same key
+    // in `rhs`, while other maps translate each `lhs` key through the reducer's
+    // lhs->rhs mapping (MapLhsToRhs) before looking it up.
+    return IsStringMap(lhs) ? SEqualReduceTracedForSMap(lhs, rhs, equal)
+                            : SEqualReduceTracedForOMap(lhs, rhs, equal);
+  }
+
   static bool SEqualReduce(const MapNode* lhs, const MapNode* rhs, SEqualReducer equal) {
+    if (equal.IsPathTracingEnabled()) {
+      return SEqualReduceTraced(lhs, rhs, equal);
+    }
+
     if (rhs->size() != lhs->size()) return false;
     if (rhs->size() == 0) return true;
-    bool ls = std::all_of(lhs->begin(), lhs->end(),
-                          [](const auto& v) { return v.first->template IsInstance<StringObj>(); });
-    bool rs = std::all_of(rhs->begin(), rhs->end(),
-                          [](const auto& v) { return v.first->template IsInstance<StringObj>(); });
+    bool ls = IsStringMap(lhs);
+    bool rs = IsStringMap(rhs);
     if (ls != rs) {
       return false;
     }
diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc
index 7f48cc4392..451855c8f8 100644
--- a/src/tir/analysis/deep_equal.cc
+++ b/src/tir/analysis/deep_equal.cc
@@ -21,6 +21,7 @@
  * \file tir/analysis/deep_equal.cc
  * \brief Deep equality checking.
  */
+#include <tvm/node/object_path.h>
 #include <tvm/node/reflection.h>
 #include <tvm/node/structural_equal.h>
 #include <tvm/runtime/registry.h>
@@ -32,21 +33,25 @@ namespace tir {
 class DeepCmpSEqualHandler : public SEqualReducer::Handler {
  public:
   // use direct recursion.
-  bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final {
+  bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
+                    const Optional<ObjectPathPair>& /*unused: paths are not traced*/) final {
     if (lhs.same_as(rhs)) return true;
     if (!lhs.defined() && rhs.defined()) return false;
     if (!rhs.defined() && lhs.defined()) return false;
     if (lhs->type_index() != rhs->type_index()) return false;
-    return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, false));
+    return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, nullptr, false)) &&
+           !fail_;
   }
 
-  ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return ObjectRef(nullptr); }
+  void DeferFail(const ObjectPathPair&) final { fail_ = true; }
 
+  ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return ObjectRef(nullptr); }
   void MarkGraphNode() final {}
 
  private:
   // reflection vtable
   ReflectionVTable* vtable_ = ReflectionVTable::Global();
+  bool fail_ = false;  // set by DeferFail(); in-progress SEqualReduce calls then return false
 };
 
 bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
@@ -62,7 +67,7 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
   if (lhs.as<AnyNode>()) {
     return false;
   }
-  return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false);
+  return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false, NullOpt);
 }
 
 TVM_REGISTER_GLOBAL("tir.analysis.expr_deep_equal")
diff --git a/tests/python/unittest/test_container_structural_equal.py b/tests/python/unittest/test_container_structural_equal.py
new file mode 100644
index 0000000000..cdd9ffb7af
--- /dev/null
+++ b/tests/python/unittest/test_container_structural_equal.py
@@ -0,0 +1,155 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+import tvm
+import tvm.testing
+from tvm.ir.base import get_first_structural_mismatch
+from tvm.runtime import ObjectPath
+
+
+def get_first_mismatch_ensure_symmetry(a, b):
+    """Return ``get_first_structural_mismatch(a, b)`` after checking that it is symmetric,
+    i.e. that comparing ``b`` against ``a`` reports the same two paths with the sides swapped.
+    """
+    mismatch = get_first_structural_mismatch(a, b)
+    mismatch_swapped = get_first_structural_mismatch(b, a)
+
+    if mismatch is None and mismatch_swapped is None:
+        return None
+
+    if (
+        mismatch is None
+        or mismatch_swapped is None
+        or mismatch[0] != mismatch_swapped[1]
+        or mismatch[1] != mismatch_swapped[0]
+    ):
+        raise AssertionError(
+            "get_first_structural_mismatch(a, b) and get_first_structural_mismatch(b, a) returned"
+            " inconsistent results '{}' and '{}' for a='{}', b='{}'".format(
+                mismatch, mismatch_swapped, a, b
+            )
+        )
+
+    # The check above already guarantees that `mismatch_swapped` mirrors `mismatch`,
+    # so no further comparison is needed.
+    return mismatch
+
+
+@pytest.mark.parametrize(
+    "a, b, expected_a_path, expected_b_path",
+    [
+        (
+            [1, 2, 3],
+            [1, 4, 3],
+            ObjectPath.root().array_index(1).attr("value"),
+            ObjectPath.root().array_index(1).attr("value"),
+        ),
+        (
+            [1, 2, 3],
+            [10, 2, 30],
+            ObjectPath.root().array_index(0).attr("value"),
+            ObjectPath.root().array_index(0).attr("value"),
+        ),
+        (
+            [1, 3, 4],
+            [1, 2, 3, 4],
+            ObjectPath.root().array_index(1).attr("value"),
+            ObjectPath.root().array_index(1).attr("value"),
+        ),
+        (
+            [1, 2, 3],
+            [1, 2, 3, 4],
+            ObjectPath.root().missing_array_element(3),
+            ObjectPath.root().array_index(3),
+        ),
+        (
+            [],
+            [1],
+            ObjectPath.root().missing_array_element(0),
+            ObjectPath.root().array_index(0),
+        ),
+    ],
+)
+def test_array_structural_mismatch(a, b, expected_a_path, expected_b_path):
+    """Each pair of arrays is first reported as mismatching at the expected pair of paths."""
+    a_obj, b_obj = map(tvm.runtime.convert, (a, b))
+    actual_a_path, actual_b_path = get_first_mismatch_ensure_symmetry(a_obj, b_obj)
+    assert actual_a_path == expected_a_path
+    assert actual_b_path == expected_b_path
+
+
+@pytest.mark.parametrize(
+    "contents",
+    [
+        [],
+        [1],
+        [1, 2, 3],
+    ],
+)
+def test_array_structural_equal_to_self(contents):
+    """Two separately converted copies of the same list have no structural mismatch."""
+    first, second = (tvm.runtime.convert(list(contents)) for _ in range(2))
+    assert get_first_mismatch_ensure_symmetry(first, second) is None
+
+
+@pytest.mark.parametrize(
+    "a, b, expected_a_path, expected_b_path",
+    [
+        (
+            dict(a=3, b=4),
+            dict(a=3, b=5),
+            ObjectPath.root().map_value("b").attr("value"),
+            ObjectPath.root().map_value("b").attr("value"),
+        ),
+        (
+            dict(a=3, b=4),
+            dict(a=3, b=4, c=5),
+            ObjectPath.root().missing_map_entry(),
+            ObjectPath.root().map_value("c"),
+        ),
+    ],
+)
+def test_string_map_structural_mismatch(a, b, expected_a_path, expected_b_path):
+    """Each pair of string-keyed maps is first reported at the expected pair of paths."""
+    a_obj, b_obj = map(tvm.runtime.convert, (a, b))
+    actual_a_path, actual_b_path = get_first_mismatch_ensure_symmetry(a_obj, b_obj)
+    assert actual_a_path == expected_a_path
+    assert actual_b_path == expected_b_path
+
+
+@pytest.mark.parametrize(
+    "contents",
+    [
+        dict(),
+        dict(a=1),
+        dict(a=3, b=4, c=5),
+    ],
+)
+def test_string_structural_equal_to_self(contents):
+    """Two separately converted copies of the same string-keyed dict have no mismatch."""
+    first, second = (tvm.runtime.convert(dict(contents)) for _ in range(2))
+    assert get_first_mismatch_ensure_symmetry(first, second) is None
+
+
+# The behavior of structural equality for maps with non-string keys is fairly specific
+# to IR variables because it assumes that map keys have been "mapped" using
+# `SEqualReducer::FreeVarEqualImpl()`. So we leave this case to TIR tests.
+
+
+if __name__ == "__main__":  # only when executed as a script, not when collected by pytest
+    tvm.testing.main()
diff --git a/tests/python/unittest/test_tir_structural_equal_hash.py b/tests/python/unittest/test_tir_structural_equal_hash.py
index ff02f1e369..d5feb21f0d 100644
--- a/tests/python/unittest/test_tir_structural_equal_hash.py
+++ b/tests/python/unittest/test_tir_structural_equal_hash.py
@@ -18,6 +18,7 @@ import tvm
 import numpy as np
 import pytest
 from tvm import te
+from tvm.runtime import ObjectPath
 
 
 def consistent_equal(x, y, map_free_vars=False):
@@ -29,7 +30,7 @@ def consistent_equal(x, y, map_free_vars=False):
 
     if struct_equal0 != struct_equal1:
         raise ValueError(
-            "Non-communicative {} vs {}, sequal0={}, sequal1={}".format(
+            "Non-commutative {} vs {}, sequal0={}, sequal1={}".format(
                 x, y, struct_equal0, struct_equal1
             )
         )
@@ -45,6 +46,28 @@ def consistent_equal(x, y, map_free_vars=False):
     return struct_equal0
 
 
+def get_sequal_mismatch(x, y, map_free_vars=False):
+    """Return the first structural mismatch of (x, y), requiring (y, x) to report its mirror."""
+    mismatch_fwd = tvm.ir.base.get_first_structural_mismatch(x, y, map_free_vars)
+    mismatch_bwd = tvm.ir.base.get_first_structural_mismatch(y, x, map_free_vars)
+
+    if mismatch_fwd is None and mismatch_bwd is None:
+        return None
+
+    if (
+        mismatch_fwd is None
+        or mismatch_bwd is None
+        or mismatch_fwd[0] != mismatch_bwd[1]
+        or mismatch_fwd[1] != mismatch_bwd[0]
+    ):
+        raise ValueError(
+            f"Non-commutative {x} vs {y}, mismatch_0={mismatch_fwd}, mismatch_1={mismatch_bwd}"
+        )
+
+    # Both directions agree, so report the (x, y) orientation.
+    return mismatch_fwd
+
+
 def test_exprs():
     # save load json
     x = tvm.tir.const(1, "int32")
@@ -107,6 +130,47 @@ def test_prim_func():
     tvm.ir.assert_structural_equal(mod0, mod1)
 
 
+def test_prim_func_param_count_mismatch():
+    x = te.var("x")
+    y = te.var("y")
+    z = te.var("z")
+    # func1 has an extra parameter `z`, so func0 is reported as missing params[2].
+    func0 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(x))
+    func1 = tvm.tir.PrimFunc([x, y, z], tvm.tir.Evaluate(x))
+    lhs_path, rhs_path = get_sequal_mismatch(func0, func1)
+    expected_lhs_path = ObjectPath.root().attr("params").missing_array_element(2)
+    expected_rhs_path = ObjectPath.root().attr("params").array_index(2)
+    assert lhs_path == expected_lhs_path
+    assert rhs_path == expected_rhs_path
+
+
+def test_prim_func_param_dtype_mismatch():
+    x = te.var("x")
+    y_0 = te.var("y", dtype="int32")
+    y_1 = te.var("z", dtype="float32")
+    # The second parameters differ in dtype (int32 vs float32); expect params[1].dtype.
+    func0 = tvm.tir.PrimFunc([x, y_0], tvm.tir.Evaluate(x))
+    func1 = tvm.tir.PrimFunc([x, y_1], tvm.tir.Evaluate(x))
+    lhs_path, rhs_path = get_sequal_mismatch(func0, func1)
+    expected_path = ObjectPath.root().attr("params").array_index(1).attr("dtype")
+    assert lhs_path == expected_path
+    assert rhs_path == expected_path
+
+
+def test_prim_func_body_mismatch():
+    x_0 = te.var("x")
+    y_0 = te.var("y")
+    x_1 = te.var("x")
+    y_1 = te.var("y")
+    # Same parameters; the bodies differ only in the second operand (x vs y) of the add.
+    func0 = tvm.tir.PrimFunc([x_0, y_0], tvm.tir.Evaluate(x_0 + x_0))
+    func1 = tvm.tir.PrimFunc([x_1, y_1], tvm.tir.Evaluate(x_1 + y_1))
+    lhs_path, rhs_path = get_sequal_mismatch(func0, func1)
+    expected_path = ObjectPath.root().attr("body").attr("value").attr("b")
+    assert lhs_path == expected_path
+    assert rhs_path == expected_path
+
+
 def test_array():
     x = np.arange(10)
     nx = tvm.nd.array(x)
@@ -183,6 +247,44 @@ def test_buffer_storage_scope():
     assert not consistent_equal(func0, func2)
 
 
+def test_buffer_map_mismatch():
+    """A shape difference inside a buffer_map value is reported at that shape element."""
+    x = te.var("x")
+
+    def make_func(buf):
+        return tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buf})
+
+    func_0 = make_func(tvm.tir.decl_buffer((10, 10)))
+    func_0_clone = make_func(tvm.tir.decl_buffer((10, 10)))
+    func_1 = make_func(tvm.tir.decl_buffer((10, 20)))
+    lhs_path, rhs_path = get_sequal_mismatch(func_0, func_1)
+    expected_path = (
+        ObjectPath.root().attr("buffer_map").map_value(x).attr("shape").array_index(1).attr("value")
+    )
+    assert lhs_path == expected_path
+    assert rhs_path == expected_path
+
+    assert get_sequal_mismatch(func_0, func_0_clone) is None
+
+
+def test_buffer_map_length_mismatch():
+    """An extra buffer_map entry on one side is reported against a missing entry on the other."""
+    x = te.var("x")
+    y = te.var("x")  # same name hint as `x`, but a distinct variable
+
+    buffer_0 = tvm.tir.decl_buffer((10, 10))
+    buffer_1 = tvm.tir.decl_buffer((10, 20))
+
+    func_0 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0})
+    func_1 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0, y: buffer_1})
+
+    lhs_path, rhs_path = get_sequal_mismatch(func_0, func_1)
+
+    buffer_map_path = ObjectPath.root().attr("buffer_map")
+    assert lhs_path == buffer_map_path.missing_map_entry()
+    assert rhs_path == buffer_map_path.map_value(y)
+
+
 def test_buffer_load_store():
     b = tvm.tir.decl_buffer((10, 10), "float32")
     x = tvm.tir.BufferLoad(b, [0, 1])
@@ -208,6 +310,90 @@ def test_while():
     assert consistent_equal(wx, wy, map_free_vars=True)
 
 
+def test_while_condition_mismatch():
+    """Flipping the loop comparison is reported at `condition` on both sides."""
+    x = tvm.tir.Var("x", "int32")
+    loop_gt = tvm.tir.While(x > 0, tvm.tir.Evaluate(x))
+    loop_lt = tvm.tir.While(x < 0, tvm.tir.Evaluate(x))
+    lhs_path, rhs_path = get_sequal_mismatch(loop_gt, loop_lt)
+    for actual_path in (lhs_path, rhs_path):
+        assert actual_path == ObjectPath.root().attr("condition")
+
+
+def test_while_body_mismatch():
+    """Changing the expression evaluated in the loop body is reported at `body.value`."""
+    x = tvm.tir.Var("x", "int32")
+    loop_x = tvm.tir.While(x > 0, tvm.tir.Evaluate(x))
+    loop_x_plus_1 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x + 1))
+    lhs_path, rhs_path = get_sequal_mismatch(loop_x, loop_x_plus_1)
+    for actual_path in (lhs_path, rhs_path):
+        assert actual_path == ObjectPath.root().attr("body").attr("value")
+
+
+def test_seq_mismatch():
+    """A differing statement in the middle of a SeqStmt is reported at that statement."""
+    x = tvm.tir.Var("x", "int32")
+
+    def make_seq(third_offset):
+        return tvm.tir.SeqStmt(
+            [
+                tvm.tir.Evaluate(x),
+                tvm.tir.Evaluate(x + 1),
+                tvm.tir.Evaluate(x + third_offset),
+                tvm.tir.Evaluate(x + 3),
+            ]
+        )
+
+    # Only the third statement differs (x + 2 vs x + 99).
+    seq_0 = make_seq(2)
+    seq_1 = make_seq(99)
+
+    lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1)
+    expected_path = (
+        ObjectPath.root().attr("seq").array_index(2).attr("value").attr("b").attr("value")
+    )
+    assert lhs_path == expected_path
+    assert rhs_path == expected_path
+
+
+def test_seq_mismatch_different_lengths():
+    # seq_1 drops `x + 2`, so index 2 differs before the length does: report the element first.
+    x = tvm.tir.Var("x", "int32")
+    seq_0 = tvm.tir.SeqStmt(
+        [
+            tvm.tir.Evaluate(x),
+            tvm.tir.Evaluate(x + 1),
+            tvm.tir.Evaluate(x + 2),
+            tvm.tir.Evaluate(x + 3),
+        ]
+    )
+    seq_1 = tvm.tir.SeqStmt([tvm.tir.Evaluate(x), tvm.tir.Evaluate(x + 1), tvm.tir.Evaluate(x + 3)])
+    lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1)
+    expected_path = (
+        ObjectPath.root().attr("seq").array_index(2).attr("value").attr("b").attr("value")
+    )
+    assert lhs_path == expected_path
+    assert rhs_path == expected_path
+
+
+def test_seq_length_mismatch():
+    """When one SeqStmt is a strict prefix of the other, the extra statement is reported."""
+    x = tvm.tir.Var("x", "int32")
+
+    def prefix():
+        return [tvm.tir.Evaluate(x), tvm.tir.Evaluate(x + 1), tvm.tir.Evaluate(x + 2)]
+
+    seq_0 = tvm.tir.SeqStmt(prefix() + [tvm.tir.Evaluate(x + 3)])
+    seq_1 = tvm.tir.SeqStmt(prefix())
+    lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1)
+
+    seq_path = ObjectPath.root().attr("seq")
+    expected_lhs_path = seq_path.array_index(3)
+    expected_rhs_path = seq_path.missing_array_element(3)
+    assert lhs_path == expected_lhs_path
+    assert rhs_path == expected_rhs_path
+
+
 if __name__ == "__main__":
     test_exprs()
     test_prim_func()