You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/12/29 23:26:20 UTC
[tvm] branch main updated: [CONTAINER] Struct Hash/Equal and JSON support for ShapeTuple (#13671)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d582b7e511 [CONTAINER] Struct Hash/Equal and JSON support for ShapeTuple (#13671)
d582b7e511 is described below
commit d582b7e511e9ff548067848ca99d8ab9194cbe9d
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Thu Dec 29 18:26:09 2022 -0500
[CONTAINER] Struct Hash/Equal and JSON support for ShapeTuple (#13671)
This PR adds structural equality/hash and JSON serialization support
for ShapeTuple. Test cases are added.
---
src/node/structural_hash.cc | 44 ++++++++++++++++++++++
src/support/base64.h | 9 ++++-
.../unittest/test_container_structural_equal.py | 14 +++++++
tests/python/unittest/test_runtime_container.py | 5 +++
4 files changed, 70 insertions(+), 2 deletions(-)
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index 1d1185cddc..0426b8454d 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -484,6 +484,50 @@ TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
return ::tvm::runtime::make_object<ArrayNode>();
});
+/*!
+ * \brief Reflection trait that gives ShapeTupleObj structural hash/equality.
+ *
+ * VisitAttrs is disabled (nullptr), so both reductions operate directly on
+ * the tuple's (size, data) buffer.
+ */
+struct ShapeTupleObjTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  /*! \brief Feed the element count, then every element in order, to the hasher. */
+  static void SHashReduce(const ShapeTupleObj* self, SHashReducer hash_reduce) {
+    hash_reduce(self->size);
+    const auto* last = self->data + self->size;
+    for (const auto* it = self->data; it != last; ++it) {
+      hash_reduce(*it);
+    }
+  }
+
+  /*!
+   * \brief Tuples are equal iff they have the same length and each element
+   *  pair compares equal through the reducer; stops at the first mismatch.
+   */
+  static bool SEqualReduce(const ShapeTupleObj* lhs, const ShapeTupleObj* rhs,
+                           SEqualReducer equal) {
+    if (lhs->size != rhs->size) return false;
+    for (uint64_t k = 0; k < lhs->size; ++k) {
+      if (!equal(lhs->data[k], rhs->data[k])) return false;
+    }
+    return true;
+  }
+};
+
+// Serialized form shared by creator and repr_bytes (base64-encoded):
+//   uint64_t element count, followed by that many int64_t values.
+TVM_REGISTER_REFLECTION_VTABLE(ShapeTupleObj, ShapeTupleObjTrait)
+    .set_creator([](const std::string& blob) {
+      // Store shape tuple in blob to avoid large integer overflow in JSON.
+      dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob));
+      support::Base64InStream b64strm(&mstrm);
+      b64strm.InitPosition();
+      // Check every read: on a truncated/corrupt blob `size` would otherwise be
+      // left unset and then used to size the allocation below.
+      uint64_t size = 0;
+      ICHECK(b64strm.Read<uint64_t>(&size)) << "Invalid ShapeTuple blob: cannot read size";
+      std::vector<int64_t> data(size);
+      ICHECK(b64strm.ReadArray(data.data(), size))
+          << "Invalid ShapeTuple blob: expected " << size << " elements";
+      // Hand the buffer over instead of copying it.
+      ShapeTuple shape(std::move(data));
+      return RefToObjectPtr::Get(shape);
+    })
+    .set_repr_bytes([](const Object* n) -> std::string {
+      // Must mirror the layout read back by the creator above.
+      std::string blob;
+      dmlc::MemoryStringStream mstrm(&blob);
+      support::Base64OutStream b64strm(&mstrm);
+      const auto* shape = static_cast<const runtime::ShapeTupleObj*>(n);
+      b64strm.Write<uint64_t>(shape->size);
+      b64strm.WriteArray(shape->data, shape->size);
+      b64strm.Finish();
+      return blob;
+    });
+
struct MapNodeTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
diff --git a/src/support/base64.h b/src/support/base64.h
index 7b37afce66..aba4197bce 100644
--- a/src/support/base64.h
+++ b/src/support/base64.h
@@ -115,8 +115,10 @@ class Base64InStream : public dmlc::Stream {
}
/*! \brief whether current position is end of a base64 stream */
bool IsEOF(void) const { return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); }
+
+ using dmlc::Stream::Read;
// override read function.
- virtual size_t Read(void* ptr, size_t size) {
+ size_t Read(void* ptr, size_t size) final {
using base64::DecodeTable;
if (size == 0) return 0;
// use tlen to record left size
@@ -224,7 +226,10 @@ class Base64InStream : public dmlc::Stream {
class Base64OutStream : public dmlc::Stream {
public:
explicit Base64OutStream(dmlc::Stream* fp) : fp_(fp) {}
- virtual void Write(const void* ptr, size_t size) {
+
+ using dmlc::Stream::Write;
+
+ void Write(const void* ptr, size_t size) final {
using base64::EncodeTable;
size_t tlen = size;
const unsigned char* cptr = static_cast<const unsigned char*>(ptr);
diff --git a/tests/python/unittest/test_container_structural_equal.py b/tests/python/unittest/test_container_structural_equal.py
index cdd9ffb7af..61511c609c 100644
--- a/tests/python/unittest/test_container_structural_equal.py
+++ b/tests/python/unittest/test_container_structural_equal.py
@@ -107,6 +107,20 @@ def test_array_structural_equal_to_self(contents):
assert get_first_mismatch_ensure_symmetry(a, b) is None
+@pytest.mark.parametrize(
+ "contents",
+ [
+ [],
+ [1],
+ [1, 2, 3],
+ ],
+)
+def test_shape_tuple_structural_equal_to_self(contents):
+ a = tvm.runtime.ShapeTuple(list(contents))
+ b = tvm.runtime.ShapeTuple(list(contents))
+ assert get_first_mismatch_ensure_symmetry(a, b) is None
+
+
@pytest.mark.parametrize(
"a, b, expected_a_path, expected_b_path",
[
diff --git a/tests/python/unittest/test_runtime_container.py b/tests/python/unittest/test_runtime_container.py
index 8c302e9205..7538075ae7 100644
--- a/tests/python/unittest/test_runtime_container.py
+++ b/tests/python/unittest/test_runtime_container.py
@@ -90,6 +90,11 @@ def test_shape_tuple():
# ShapleTuple vs. ShapeTuple
assert stuple == _container.ShapeTuple(shape)
+ # test pickle
+ z = pickle.loads(pickle.dumps(stuple))
+ assert isinstance(z, tvm.runtime.ShapeTuple)
+ assert stuple == z
+
if __name__ == "__main__":
test_string()