You are viewing a plain-text version of this content; the canonical link to it (a hyperlink in the original page) was not preserved in this copy.
Posted to commits@tvm.apache.org by ar...@apache.org on 2022/04/19 18:30:46 UTC

[tvm] branch main updated: Attempt to prevent concurrent update in Map (#9842)

This is an automated email from the ASF dual-hosted git repository.

areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5987982fae Attempt to prevent concurrent update in Map (#9842)
5987982fae is described below

commit 5987982faeb40849385b4ea99b3a92866f180e76
Author: Dmitriy Smirnov <dm...@arm.com>
AuthorDate: Tue Apr 19 19:30:40 2022 +0100

    Attempt to prevent concurrent update in Map (#9842)
    
    * Attempt to prevent concurrent update in Map
    
    Calling Map::Set invalidates existing iterators, to protect against
    using data that was already deleted by re-hashing
    
    Change-Id: Ib6b580758e74c8b77ed560932d87b643bd6c9402
    
    * Migrated to using TVM_LOG_DEBUG
    
    Now uses TVM_LOG_DEBUG
    Map state_marker made atomic
    
    Change-Id: I090c4b33e6edaa977cccba11f8d1c6ff3fbca430
    
    * Removed usage of atomics (state_marker is a plain uint64_t again)
    
    Change-Id: I7bd930cb52d58ca10fd49a5fe8f5d48b3e955d0a
---
 include/tvm/runtime/container/map.h | 37 +++++++++++++++++++++++++++++++++++--
 tests/cpp/container_test.cc         | 15 +++++++++++++++
 2 files changed, 50 insertions(+), 2 deletions(-)

diff --git a/include/tvm/runtime/container/map.h b/include/tvm/runtime/container/map.h
index 977dbfbaaa..4c76a3b0ad 100644
--- a/include/tvm/runtime/container/map.h
+++ b/include/tvm/runtime/container/map.h
@@ -38,6 +38,13 @@
 namespace tvm {
 namespace runtime {
 
+#if TVM_LOG_DEBUG
+#define TVM_MAP_FAIL_IF_CHANGED() \
+  ICHECK(state_marker == self->state_marker) << "Concurrent modification of the Map";
+#else
+#define TVM_MAP_FAIL_IF_CHANGED()
+#endif  // TVM_LOG_DEBUG
+
 #if (USE_FALLBACK_STL_MAP != 0)
 
 /*! \brief Shared content of all specializations of hash map */
@@ -233,10 +240,15 @@ class MapNode : public Object {
     using value_type = KVType;
     using pointer = KVType*;
     using reference = KVType&;
-    /*! \brief Default constructor */
+/*! \brief Default constructor */
+#if TVM_LOG_DEBUG
+    iterator() : state_marker(0), index(0), self(nullptr) {}
+#else
     iterator() : index(0), self(nullptr) {}
+#endif  // TVM_LOG_DEBUG
     /*! \brief Compare iterators */
     bool operator==(const iterator& other) const {
+      TVM_MAP_FAIL_IF_CHANGED()
       return index == other.index && self == other.self;
     }
     /*! \brief Compare iterators */
@@ -244,27 +256,39 @@ class MapNode : public Object {
     /*! \brief De-reference iterators */
     pointer operator->() const;
     /*! \brief De-reference iterators */
-    reference operator*() const { return *((*this).operator->()); }
+    reference operator*() const {
+      TVM_MAP_FAIL_IF_CHANGED()
+      return *((*this).operator->());
+    }
     /*! \brief Prefix self increment, e.g. ++iter */
     iterator& operator++();
     /*! \brief Prefix self decrement, e.g. --iter */
     iterator& operator--();
     /*! \brief Suffix self increment */
     iterator operator++(int) {
+      TVM_MAP_FAIL_IF_CHANGED()
       iterator copy = *this;
       ++(*this);
       return copy;
     }
     /*! \brief Suffix self decrement */
     iterator operator--(int) {
+      TVM_MAP_FAIL_IF_CHANGED()
       iterator copy = *this;
       --(*this);
       return copy;
     }
 
    protected:
+#if TVM_LOG_DEBUG
+    uint64_t state_marker;
     /*! \brief Construct by value */
+    iterator(uint64_t index, const MapNode* self)
+        : state_marker(self->state_marker), index(index), self(self) {}
+
+#else
     iterator(uint64_t index, const MapNode* self) : index(index), self(self) {}
+#endif  // TVM_LOG_DEBUG
     /*! \brief The position on the array */
     uint64_t index;
     /*! \brief The container it points to */
@@ -280,6 +304,9 @@ class MapNode : public Object {
   static inline ObjectPtr<MapNode> Empty();
 
  protected:
+#if TVM_LOG_DEBUG
+  uint64_t state_marker;
+#endif  // TVM_LOG_DEBUG
   /*!
    * \brief Create the map using contents from the given iterators.
    * \param first Begin of iterator
@@ -1118,10 +1145,12 @@ class DenseMapNode : public MapNode {
   }
 
 inline MapNode::iterator::pointer MapNode::iterator::operator->() const {
+  TVM_MAP_FAIL_IF_CHANGED()
   TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); });
 }
 
 inline MapNode::iterator& MapNode::iterator::operator++() {
+  TVM_MAP_FAIL_IF_CHANGED()
   TVM_DISPATCH_MAP_CONST(self, p, {
     index = p->IncItr(index);
     return *this;
@@ -1129,6 +1158,7 @@ inline MapNode::iterator& MapNode::iterator::operator++() {
 }
 
 inline MapNode::iterator& MapNode::iterator::operator--() {
+  TVM_MAP_FAIL_IF_CHANGED()
   TVM_DISPATCH_MAP_CONST(self, p, {
     index = p->DecItr(index);
     return *this;
@@ -1200,6 +1230,9 @@ inline ObjectPtr<Object> MapNode::CreateFromRange(IterType first, IterType last)
 inline void MapNode::InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {
   constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize;
   MapNode* base = static_cast<MapNode*>(map->get());
+#if TVM_LOG_DEBUG
+  base->state_marker++;
+#endif  // TVM_LOG_DEBUG
   if (base->slots_ < kSmallMapMaxSize) {
     SmallMapNode::InsertMaybeReHash(kv, map);
   } else if (base->slots_ == kSmallMapMaxSize) {
diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc
index 019fde0698..32ec346c87 100644
--- a/tests/cpp/container_test.cc
+++ b/tests/cpp/container_test.cc
@@ -380,6 +380,21 @@ TEST(Map, Erase) {
   }
 }
 
+#if TVM_LOG_DEBUG
+TEST(Map, Race) {
+  using namespace tvm::runtime;
+  Map<Integer, Integer> m;
+
+  m.Set(1, 1);
+  Map<tvm::Integer, tvm::Integer>::iterator it = m.begin();
+  EXPECT_NO_THROW({ auto& kv = *it; });
+
+  m.Set(2, 2);
+  // changed. iterator should be re-obtained
+  EXPECT_ANY_THROW({ auto& kv = *it; });
+}
+#endif  // TVM_LOG_DEBUG
+
 TEST(String, MoveFromStd) {
   using namespace std;
   string source = "this is a string";