You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ar...@apache.org on 2022/04/19 18:30:46 UTC
[tvm] branch main updated: Attempt to prevent concurrent update in Map (#9842)
This is an automated email from the ASF dual-hosted git repository.
areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5987982fae Attempt to prevent concurrent update in Map (#9842)
5987982fae is described below
commit 5987982faeb40849385b4ea99b3a92866f180e76
Author: Dmitriy Smirnov <dm...@arm.com>
AuthorDate: Tue Apr 19 19:30:40 2022 +0100
Attempt to prevent concurrent update in Map (#9842)
* Attempt to prevent concurrent update in Map
Calling Map::Set invalidates existing iterators to protect against
using already-deleted data after re-hashing
Change-Id: Ib6b580758e74c8b77ed560932d87b643bd6c9402
* Migrated to using TVM_LOG_DEBUG
The concurrent-modification check is now compiled in only when TVM_LOG_DEBUG is set
Map state_marker made atomic
Change-Id: I090c4b33e6edaa977cccba11f8d1c6ff3fbca430
* Removed usage of atomics; state_marker is a plain uint64_t
Change-Id: I7bd930cb52d58ca10fd49a5fe8f5d48b3e955d0a
---
include/tvm/runtime/container/map.h | 37 +++++++++++++++++++++++++++++++++++--
tests/cpp/container_test.cc | 15 +++++++++++++++
2 files changed, 50 insertions(+), 2 deletions(-)
diff --git a/include/tvm/runtime/container/map.h b/include/tvm/runtime/container/map.h
index 977dbfbaaa..4c76a3b0ad 100644
--- a/include/tvm/runtime/container/map.h
+++ b/include/tvm/runtime/container/map.h
@@ -38,6 +38,13 @@
namespace tvm {
namespace runtime {
+#if TVM_LOG_DEBUG
+#define TVM_MAP_FAIL_IF_CHANGED() \
+ ICHECK(state_marker == self->state_marker) << "Concurrent modification of the Map";
+#else
+#define TVM_MAP_FAIL_IF_CHANGED()
+#endif // TVM_LOG_DEBUG
+
#if (USE_FALLBACK_STL_MAP != 0)
/*! \brief Shared content of all specializations of hash map */
@@ -233,10 +240,15 @@ class MapNode : public Object {
using value_type = KVType;
using pointer = KVType*;
using reference = KVType&;
- /*! \brief Default constructor */
+/*! \brief Default constructor */
+#if TVM_LOG_DEBUG
+ iterator() : state_marker(0), index(0), self(nullptr) {}
+#else
iterator() : index(0), self(nullptr) {}
+#endif // TVM_LOG_DEBUG
/*! \brief Compare iterators */
bool operator==(const iterator& other) const {
+ TVM_MAP_FAIL_IF_CHANGED()
return index == other.index && self == other.self;
}
/*! \brief Compare iterators */
@@ -244,27 +256,39 @@ class MapNode : public Object {
/*! \brief De-reference iterators */
pointer operator->() const;
/*! \brief De-reference iterators */
- reference operator*() const { return *((*this).operator->()); }
+ reference operator*() const {
+ TVM_MAP_FAIL_IF_CHANGED()
+ return *((*this).operator->());
+ }
/*! \brief Prefix self increment, e.g. ++iter */
iterator& operator++();
/*! \brief Prefix self decrement, e.g. --iter */
iterator& operator--();
/*! \brief Suffix self increment */
iterator operator++(int) {
+ TVM_MAP_FAIL_IF_CHANGED()
iterator copy = *this;
++(*this);
return copy;
}
/*! \brief Suffix self decrement */
iterator operator--(int) {
+ TVM_MAP_FAIL_IF_CHANGED()
iterator copy = *this;
--(*this);
return copy;
}
protected:
+#if TVM_LOG_DEBUG
+ uint64_t state_marker;
/*! \brief Construct by value */
+ iterator(uint64_t index, const MapNode* self)
+ : state_marker(self->state_marker), index(index), self(self) {}
+
+#else
iterator(uint64_t index, const MapNode* self) : index(index), self(self) {}
+#endif // TVM_LOG_DEBUG
/*! \brief The position on the array */
uint64_t index;
/*! \brief The container it points to */
@@ -280,6 +304,9 @@ class MapNode : public Object {
static inline ObjectPtr<MapNode> Empty();
protected:
+#if TVM_LOG_DEBUG
+ uint64_t state_marker;
+#endif // TVM_LOG_DEBUG
/*!
* \brief Create the map using contents from the given iterators.
* \param first Begin of iterator
@@ -1118,10 +1145,12 @@ class DenseMapNode : public MapNode {
}
inline MapNode::iterator::pointer MapNode::iterator::operator->() const {
+ TVM_MAP_FAIL_IF_CHANGED()
TVM_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); });
}
inline MapNode::iterator& MapNode::iterator::operator++() {
+ TVM_MAP_FAIL_IF_CHANGED()
TVM_DISPATCH_MAP_CONST(self, p, {
index = p->IncItr(index);
return *this;
@@ -1129,6 +1158,7 @@ inline MapNode::iterator& MapNode::iterator::operator++() {
}
inline MapNode::iterator& MapNode::iterator::operator--() {
+ TVM_MAP_FAIL_IF_CHANGED()
TVM_DISPATCH_MAP_CONST(self, p, {
index = p->DecItr(index);
return *this;
@@ -1200,6 +1230,9 @@ inline ObjectPtr<Object> MapNode::CreateFromRange(IterType first, IterType last)
inline void MapNode::InsertMaybeReHash(const KVType& kv, ObjectPtr<Object>* map) {
constexpr uint64_t kSmallMapMaxSize = SmallMapNode::kMaxSize;
MapNode* base = static_cast<MapNode*>(map->get());
+#if TVM_LOG_DEBUG
+ base->state_marker++;
+#endif // TVM_LOG_DEBUG
if (base->slots_ < kSmallMapMaxSize) {
SmallMapNode::InsertMaybeReHash(kv, map);
} else if (base->slots_ == kSmallMapMaxSize) {
diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc
index 019fde0698..32ec346c87 100644
--- a/tests/cpp/container_test.cc
+++ b/tests/cpp/container_test.cc
@@ -380,6 +380,21 @@ TEST(Map, Erase) {
}
}
+#if TVM_LOG_DEBUG
+TEST(Map, Race) {
+ using namespace tvm::runtime;
+ Map<Integer, Integer> m;
+
+ m.Set(1, 1);
+ Map<tvm::Integer, tvm::Integer>::iterator it = m.begin();
+ EXPECT_NO_THROW({ auto& kv = *it; });
+
+ m.Set(2, 2);
+ // changed. iterator should be re-obtained
+ EXPECT_ANY_THROW({ auto& kv = *it; });
+}
+#endif // TVM_LOG_DEBUG
+
TEST(String, MoveFromStd) {
using namespace std;
string source = "this is a string";