You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/12/29 23:26:20 UTC

[tvm] branch main updated: [CONTAINER] Struct Hash/Equal and JSON support for ShapeTuple (#13671)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d582b7e511 [CONTAINER] Struct Hash/Equal and JSON support for ShapeTuple (#13671)
d582b7e511 is described below

commit d582b7e511e9ff548067848ca99d8ab9194cbe9d
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Thu Dec 29 18:26:09 2022 -0500

    [CONTAINER] Struct Hash/Equal and JSON support for ShapeTuple (#13671)
    
    This PR adds structural equality/hashing and JSON serialization support
    for ShapeTuple. Test cases are added.
---
 src/node/structural_hash.cc                        | 44 ++++++++++++++++++++++
 src/support/base64.h                               |  9 ++++-
 .../unittest/test_container_structural_equal.py    | 14 +++++++
 tests/python/unittest/test_runtime_container.py    |  5 +++
 4 files changed, 70 insertions(+), 2 deletions(-)

diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index 1d1185cddc..0426b8454d 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -484,6 +484,50 @@ TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
       return ::tvm::runtime::make_object<ArrayNode>();
     });
 
+struct ShapeTupleObjTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  static void SHashReduce(const ShapeTupleObj* self, SHashReducer hash_reduce) {
+    hash_reduce(self->size);
+    for (size_t i = 0; i < self->size; ++i) {
+      hash_reduce(self->data[i]);
+    }
+  }
+
+  static bool SEqualReduce(const ShapeTupleObj* lhs, const ShapeTupleObj* rhs,
+                           SEqualReducer equal) {
+    if (lhs->size != rhs->size) return false;
+    for (size_t i = 0; i < lhs->size; ++i) {
+      if (!equal(lhs->data[i], rhs->data[i])) return false;
+    }
+    return true;
+  }
+};
+
+TVM_REGISTER_REFLECTION_VTABLE(ShapeTupleObj, ShapeTupleObjTrait)
+    .set_creator([](const std::string& blob) {
+      // Store shape tuple in blob to avoid large integer overflow in JSON.
+      dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob));
+      support::Base64InStream b64strm(&mstrm);
+      b64strm.InitPosition();
+      uint64_t size;
+      b64strm.Read<uint64_t>(&size);
+      std::vector<int64_t> data(size);
+      b64strm.ReadArray(data.data(), size);
+      ShapeTuple shape(data);
+      return RefToObjectPtr::Get(shape);
+    })
+    .set_repr_bytes([](const Object* n) -> std::string {
+      std::string blob;
+      dmlc::MemoryStringStream mstrm(&blob);
+      support::Base64OutStream b64strm(&mstrm);
+      const auto* shape = static_cast<const runtime::ShapeTupleObj*>(n);
+      b64strm.Write<uint64_t>(shape->size);
+      b64strm.WriteArray(shape->data, shape->size);
+      b64strm.Finish();
+      return blob;
+    });
+
 struct MapNodeTrait {
   static constexpr const std::nullptr_t VisitAttrs = nullptr;
 
diff --git a/src/support/base64.h b/src/support/base64.h
index 7b37afce66..aba4197bce 100644
--- a/src/support/base64.h
+++ b/src/support/base64.h
@@ -115,8 +115,10 @@ class Base64InStream : public dmlc::Stream {
   }
   /*! \brief whether current position is end of a base64 stream */
   bool IsEOF(void) const { return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); }
+
+  using dmlc::Stream::Read;
   // override read function.
-  virtual size_t Read(void* ptr, size_t size) {
+  size_t Read(void* ptr, size_t size) final {
     using base64::DecodeTable;
     if (size == 0) return 0;
     // use tlen to record left size
@@ -224,7 +226,10 @@ class Base64InStream : public dmlc::Stream {
 class Base64OutStream : public dmlc::Stream {
  public:
   explicit Base64OutStream(dmlc::Stream* fp) : fp_(fp) {}
-  virtual void Write(const void* ptr, size_t size) {
+
+  using dmlc::Stream::Write;
+
+  void Write(const void* ptr, size_t size) final {
     using base64::EncodeTable;
     size_t tlen = size;
     const unsigned char* cptr = static_cast<const unsigned char*>(ptr);
diff --git a/tests/python/unittest/test_container_structural_equal.py b/tests/python/unittest/test_container_structural_equal.py
index cdd9ffb7af..61511c609c 100644
--- a/tests/python/unittest/test_container_structural_equal.py
+++ b/tests/python/unittest/test_container_structural_equal.py
@@ -107,6 +107,20 @@ def test_array_structural_equal_to_self(contents):
     assert get_first_mismatch_ensure_symmetry(a, b) is None
 
 
+@pytest.mark.parametrize(
+    "contents",
+    [
+        [],
+        [1],
+        [1, 2, 3],
+    ],
+)
+def test_shape_tuple_structural_equal_to_self(contents):
+    a = tvm.runtime.ShapeTuple(list(contents))
+    b = tvm.runtime.ShapeTuple(list(contents))
+    assert get_first_mismatch_ensure_symmetry(a, b) is None
+
+
 @pytest.mark.parametrize(
     "a, b, expected_a_path, expected_b_path",
     [
diff --git a/tests/python/unittest/test_runtime_container.py b/tests/python/unittest/test_runtime_container.py
index 8c302e9205..7538075ae7 100644
--- a/tests/python/unittest/test_runtime_container.py
+++ b/tests/python/unittest/test_runtime_container.py
@@ -90,6 +90,11 @@ def test_shape_tuple():
     # ShapleTuple vs. ShapeTuple
     assert stuple == _container.ShapeTuple(shape)
 
+    # test pickle
+    z = pickle.loads(pickle.dumps(stuple))
+    assert isinstance(z, tvm.runtime.ShapeTuple)
+    assert stuple == z
+
 
 if __name__ == "__main__":
     test_string()