Posted to commits@tvm.apache.org by zi...@apache.org on 2021/06/09 23:03:43 UTC

[tvm] branch main updated: [RUNTIME] ShapeTuple Container (#8200)

This is an automated email from the ASF dual-hosted git repository.

ziheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4d9bc9b  [RUNTIME] ShapeTuple Container (#8200)
4d9bc9b is described below

commit 4d9bc9b4a3e9e8d3420efe60a52964fcd4c29c8d
Author: ziheng <zi...@apache.org>
AuthorDate: Thu Jun 10 07:03:15 2021 +0800

    [RUNTIME] ShapeTuple Container (#8200)
    
    * Add ShapeTuple.
    
    * Update NDArray.
    
    * Documents.
    
    * Lint.
    
    * Lint.
    
    * Lint.
    
    * Address comment.
    
    * Address comment.
    
    * Address comment.
    
    * Lint.
    
    * Lint.
---
 include/tvm/runtime/container/shape_tuple.h     | 180 ++++++++++++++++++++++++
 include/tvm/runtime/ndarray.h                   |  19 ++-
 include/tvm/runtime/object.h                    |   2 +
 include/tvm/runtime/packed_func.h               |   1 +
 python/tvm/runtime/container.py                 |  24 ++++
 src/node/container_printing.cc                  |  12 ++
 src/runtime/container.cc                        |  23 ++-
 src/runtime/ndarray.cc                          |  23 +--
 tests/python/unittest/test_runtime_container.py |   9 ++
 9 files changed, 270 insertions(+), 23 deletions(-)

diff --git a/include/tvm/runtime/container/shape_tuple.h b/include/tvm/runtime/container/shape_tuple.h
new file mode 100644
index 0000000..774077f
--- /dev/null
+++ b/include/tvm/runtime/container/shape_tuple.h
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/container/shape_tuple.h
+ * \brief Runtime ShapeTuple container types.
+ */
+#ifndef TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_
+#define TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_
+
+#include <utility>
+#include <vector>
+
+#include "./base.h"
+
+namespace tvm {
+namespace runtime {
+
+/*! \brief An object representing a shape tuple. */
+class ShapeTupleObj : public Object {
+ public:
+  /*! \brief The type of shape index element. */
+  using index_type = int64_t;
+  /*! \brief The pointer to shape tuple data. */
+  index_type* data;
+  /*! \brief The size of the shape tuple object. */
+  uint64_t size;
+
+  static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeShapeTuple;
+  static constexpr const char* _type_key = "runtime.ShapeTuple";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTupleObj, Object);
+
+ private:
+  /*! \brief ShapeTuple object that is moved from a std::vector container. */
+  class FromStd;
+
+  friend class ShapeTuple;
+};
+
+/*! \brief An object representing a shape tuple moved from a std::vector. */
+class ShapeTupleObj::FromStd : public ShapeTupleObj {
+ public:
+  /*! \brief The type of shape index element. */
+  using index_type = ShapeTupleObj::index_type;
+  /*!
+   * \brief Construct a new FromStd object
+   *
+   * \param other The moved/copied std::vector object
+   *
+   * \note If the user passes a const reference, the vector is copied; if an
+   * rvalue is passed, its contents are moved into the container.
+   */
+  explicit FromStd(std::vector<index_type> other) : data_container{std::move(other)} {}
+
+ private:
+  /*! \brief Container that holds the memory. */
+  std::vector<index_type> data_container;
+
+  friend class ShapeTuple;
+};
+
+/*!
+ * \brief Reference to shape tuple objects.
+ */
+class ShapeTuple : public ObjectRef {
+ public:
+  /*! \brief The type of shape index element. */
+  using index_type = ShapeTupleObj::index_type;
+
+  /*!
+   * \brief Construct an empty shape tuple.
+   */
+  ShapeTuple() : ShapeTuple(std::vector<index_type>()) {}
+
+  /*!
+   * \brief Constructor from iterator
+   * \param begin The begin iterator
+   * \param end The end iterator
+   * \tparam IterType The type of iterator
+   */
+  template <typename IterType>
+  ShapeTuple(IterType begin, IterType end) : ShapeTuple(std::vector<index_type>(begin, end)) {}
+
+  /*!
+   * \brief Constructor from an initializer list
+   * \param shape The initializer list
+   */
+  ShapeTuple(std::initializer_list<index_type> shape) : ShapeTuple(shape.begin(), shape.end()) {}
+
+  /*!
+   * \brief Construct a new ShapeTuple object
+   *
+   * \param shape The moved/copied std::vector object
+   *
+   * \note If the user passes a const reference, the vector is copied; if an
+   * rvalue is passed, its contents are moved into the container.
+   */
+  ShapeTuple(std::vector<index_type> shape);  // NOLINT(*)
+
+  /*!
+   * \brief Return the data pointer
+   *
+   * \return const index_type* data pointer
+   */
+  const index_type* data() const { return get()->data; }
+
+  /*!
+   * \brief Return the size of the shape tuple
+   *
+   * \return size_t shape tuple size
+   */
+  size_t size() const { return get()->size; }
+
+  /*!
+   * \brief Immutably read i-th element from the shape tuple.
+   * \param idx The index
+   * \return the i-th element.
+   */
+  index_type operator[](size_t idx) const {
+    ICHECK(idx < this->size()) << "IndexError: indexing " << idx << " on an array of size "
+                               << this->size();
+    return this->data()[idx];
+  }
+
+  /*!
+   * \brief Immutably read i-th element from the shape tuple.
+   * \param idx The index
+   * \return the i-th element.
+   */
+  index_type at(size_t idx) const { return this->operator[](idx); }
+
+  /*! \return Whether shape tuple is empty */
+  bool empty() const { return size() == 0; }
+
+  /*! \return The first element of the shape tuple */
+  index_type front() const { return this->at(0); }
+
+  /*! \return The last element of the shape tuple */
+  index_type back() const { return this->at(this->size() - 1); }
+
+  /*! \return begin iterator */
+  const index_type* begin() const { return get()->data; }
+
+  /*! \return end iterator */
+  const index_type* end() const { return (get()->data + size()); }
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeTuple, ObjectRef, ShapeTupleObj);
+};
+
+inline ShapeTuple::ShapeTuple(std::vector<index_type> shape) {
+  auto ptr = make_object<ShapeTupleObj::FromStd>(std::move(shape));
+  ptr->size = ptr->data_container.size();
+  ptr->data = ptr->data_container.data();
+  data_ = std::move(ptr);
+}
+
+}  // namespace runtime
+
+// expose the functions to the root namespace.
+using runtime::ShapeTuple;
+using runtime::ShapeTupleObj;
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_
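
For reference, a minimal usage sketch of the container added above, assuming a
standard TVM build with the new header on the include path:

    #include <tvm/runtime/container/shape_tuple.h>

    #include <iostream>

    int main() {
      using tvm::runtime::ShapeTuple;
      // Construct from an initializer list; the iterator-pair and std::vector
      // constructors declared above behave the same way.
      ShapeTuple shape = {1, 3, 224, 224};
      // Read-only access mirrors the const interface of std::vector.
      std::cout << "ndim = " << shape.size() << ", first = " << shape.front()
                << ", last = " << shape.back() << std::endl;
      for (ShapeTuple::index_type dim : shape) {  // begin()/end() walk the raw data
        std::cout << dim << " ";
      }
      std::cout << std::endl;
      return 0;
    }
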
diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h
index bfc681e..1127a9a 100644
--- a/include/tvm/runtime/ndarray.h
+++ b/include/tvm/runtime/ndarray.h
@@ -25,8 +25,8 @@
 #define TVM_RUNTIME_NDARRAY_H_
 
 #include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/container/map.h>
 #include <tvm/runtime/container/optional.h>
+#include <tvm/runtime/container/shape_tuple.h>
 #include <tvm/runtime/container/string.h>
 #include <tvm/runtime/data_type.h>
 #include <tvm/runtime/object.h>
@@ -128,7 +128,7 @@ class NDArray : public ObjectRef {
    * \param dtype The data type of the new array.
    * \note The memory size of new array must be smaller than the current one.
    */
-  TVM_DLL NDArray CreateView(std::vector<int64_t> shape, DLDataType dtype);
+  TVM_DLL NDArray CreateView(ShapeTuple shape, DLDataType dtype);
   /*!
    * \brief Create a reference view of NDArray that
    *  represents as DLManagedTensor.
@@ -143,7 +143,7 @@ class NDArray : public ObjectRef {
    * \param mem_scope The memory scope of the array.
    * \return The created Array
    */
-  TVM_DLL static NDArray Empty(std::vector<int64_t> shape, DLDataType dtype, Device dev,
+  TVM_DLL static NDArray Empty(ShapeTuple shape, DLDataType dtype, Device dev,
                                Optional<String> mem_scope = NullOpt);
   /*!
    * \brief Create a NDArray backed by a dlpack tensor.
@@ -166,7 +166,7 @@ class NDArray : public ObjectRef {
   TVM_DLL static void CopyFromTo(const DLTensor* from, DLTensor* to,
                                  TVMStreamHandle stream = nullptr);
 
-  TVM_DLL std::vector<int64_t> Shape() const;
+  TVM_DLL ShapeTuple Shape() const;
   TVM_DLL runtime::DataType DataType() const;
   // internal namespace
   struct Internal;
@@ -241,7 +241,7 @@ class NDArray::ContainerBase {
    * \brief The shape container,
   *  can be used for shape data.
    */
-  std::vector<int64_t> shape_;
+  ShapeTuple shape_;
 };
 
 /*!
@@ -261,13 +261,13 @@ class NDArray::Container : public Object, public NDArray::ContainerBase {
     dl_tensor.byte_offset = 0;
   }
 
-  Container(void* data, std::vector<int64_t> shape, DLDataType dtype, Device dev) {
+  Container(void* data, ShapeTuple shape, DLDataType dtype, Device dev) {
     // Initialize the type index.
     type_index_ = Container::RuntimeTypeIndex();
     dl_tensor.data = data;
     shape_ = std::move(shape);
     dl_tensor.ndim = static_cast<int>(shape_.size());
-    dl_tensor.shape = dmlc::BeginPtr(shape_);
+    dl_tensor.shape = const_cast<ShapeTuple::index_type*>(shape_.data());
     dl_tensor.dtype = dtype;
     dl_tensor.strides = nullptr;
     dl_tensor.byte_offset = 0;
@@ -357,8 +357,7 @@ inline void NDArray::CopyTo(const NDArray& other) const {
 inline NDArray NDArray::CopyTo(const Device& dev) const {
   ICHECK(data_ != nullptr);
   const DLTensor* dptr = operator->();
-  NDArray ret =
-      Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, dev);
+  NDArray ret = Empty(ShapeTuple(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, dev);
   this->CopyTo(ret);
   return ret;
 }
@@ -460,7 +459,7 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
   if (ndim != 0) {
     ICHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DLTensor file format";
   }
-  NDArray ret = NDArray::Empty(shape, dtype, dev);
+  NDArray ret = NDArray::Empty(ShapeTuple(shape), dtype, dev);
   int64_t num_elems = 1;
   int elem_bytes = (ret->dtype.bits + 7) / 8;
   for (int i = 0; i < ret->ndim; ++i) {
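
A hedged sketch of the updated NDArray signatures in use; the float32 dtype and
CPU device below are illustrative, not part of the patch:

    #include <tvm/runtime/ndarray.h>

    void ndarray_example() {
      using tvm::runtime::NDArray;
      using tvm::runtime::ShapeTuple;
      DLDataType dtype{kDLFloat, 32, 1};
      tvm::Device dev;
      dev.device_type = kDLCPU;
      dev.device_id = 0;
      // Empty() now accepts a ShapeTuple instead of std::vector<int64_t>.
      NDArray a = NDArray::Empty(ShapeTuple({2, 3}), dtype, dev);
      // Shape() now returns a ShapeTuple: a cheap reference copy rather than
      // a copy of the underlying shape data.
      ShapeTuple s = a.Shape();
      // CreateView() likewise takes a ShapeTuple of the same total byte size.
      NDArray view = a.CreateView(ShapeTuple({3, 2}), dtype);
    }
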
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index f13bdee..0ed6117 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -68,6 +68,8 @@ struct TypeIndex {
     kRuntimeArray = 4,
     /*! \brief runtime::Map. */
     kRuntimeMap = 5,
+    /*! \brief runtime::ShapeTuple. */
+    kRuntimeShapeTuple = 6,
     // static assignments that may subject to change.
     kRuntimeClosure,
     kRuntimeADT,
diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h
index 3e8f23b..9bfe379 100644
--- a/include/tvm/runtime/packed_func.h
+++ b/include/tvm/runtime/packed_func.h
@@ -26,6 +26,7 @@
 
 #include <tvm/runtime/c_runtime_api.h>
 #include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/map.h>
 #include <tvm/runtime/data_type.h>
 #include <tvm/runtime/logging.h>
 #include <tvm/runtime/module.h>
diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py
index 63383e7..7f83693 100644
--- a/python/tvm/runtime/container.py
+++ b/python/tvm/runtime/container.py
@@ -137,3 +137,27 @@ class String(str, PyNativeObject):
         val = str.__new__(cls, content)
         val.__tvm_object__ = obj
         return val
+
+
+@tvm._ffi.register_object("runtime.ShapeTuple")
+class ShapeTuple(Object):
+    """TVM runtime ShapeTuple object.
+    Parameters
+    ----------
+    shape : list[int]
+        The shape list used to construct the object.
+    """
+
+    def __init__(self, shape):
+        assert isinstance(shape, (list, tuple)), "Expect list or tuple, but received: {0}".format(
+            type(shape)
+        )
+        for x in shape:
+            assert isinstance(x, int), "Expect int type, but received: {0}".format(type(x))
+        self.__init_handle_by_constructor__(_ffi_api.ShapeTuple, *shape)
+
+    def __len__(self):
+        return _ffi_api.GetShapeTupleSize(self)
+
+    def __getitem__(self, idx):
+        return getitem_helper(self, _ffi_api.GetShapeTupleElem, len(self), idx)
diff --git a/src/node/container_printing.cc b/src/node/container_printing.cc
index 7b97296..1565630 100644
--- a/src/node/container_printing.cc
+++ b/src/node/container_printing.cc
@@ -60,4 +60,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << '}';
     });
 
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ShapeTupleObj>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const ShapeTupleObj*>(node.get());
+      p->stream << '[';
+      for (size_t i = 0; i < op->size; ++i) {
+        if (i != 0) {
+          p->stream << ", ";
+        }
+        p->stream << op->data[i];
+      }
+      p->stream << ']';
+    });
 }  // namespace tvm
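
Given the dispatch above, streaming a ShapeTuple prints a bracketed list. A
small sketch, assuming tvm/node/repr_printer.h is included for operator<< on
ObjectRef:

    tvm::runtime::ShapeTuple shape = {1, 2, 3};
    LOG(INFO) << shape;  // prints: [1, 2, 3]
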
diff --git a/src/runtime/container.cc b/src/runtime/container.cc
index 9d648dc..159404b 100644
--- a/src/runtime/container.cc
+++ b/src/runtime/container.cc
@@ -25,6 +25,7 @@
 #include <tvm/runtime/container/array.h>
 #include <tvm/runtime/container/closure.h>
 #include <tvm/runtime/container/map.h>
+#include <tvm/runtime/container/shape_tuple.h>
 #include <tvm/runtime/container/string.h>
 #include <tvm/runtime/memory.h>
 #include <tvm/runtime/object.h>
@@ -108,7 +109,6 @@ TVM_REGISTER_GLOBAL("runtime.ADT").set_body([](TVMArgs args, TVMRetValue* rv) {
 });
 
 // String
-
 TVM_REGISTER_OBJECT_TYPE(StringObj);
 
 TVM_REGISTER_GLOBAL("runtime.String").set_body_typed([](std::string str) {
@@ -120,7 +120,6 @@ TVM_REGISTER_GLOBAL("runtime.GetFFIString").set_body_typed([](String str) {
 });
 
 // Map
-
 TVM_REGISTER_OBJECT_TYPE(MapNode);
 
 TVM_REGISTER_GLOBAL("runtime.Map").set_body([](TVMArgs args, TVMRetValue* ret) {
@@ -185,7 +184,27 @@ TVM_REGISTER_GLOBAL("runtime.MapItems").set_body([](TVMArgs args, TVMRetValue* r
 TVM_DLL constexpr uint64_t DenseMapNode::kNextProbeLocation[];
 #endif
 
+// Closure
 TVM_REGISTER_OBJECT_TYPE(ClosureObj);
 
+// ShapeTuple
+TVM_REGISTER_OBJECT_TYPE(ShapeTupleObj);
+
+TVM_REGISTER_GLOBAL("runtime.ShapeTuple").set_body([](TVMArgs args, TVMRetValue* rv) {
+  std::vector<ShapeTuple::index_type> shape;
+  for (int i = 0; i < args.size(); i++) {
+    shape.push_back(args[i]);
+  }
+  *rv = ShapeTuple(shape);
+});
+
+TVM_REGISTER_GLOBAL("runtime.GetShapeTupleSize").set_body_typed([](ShapeTuple shape) {
+  return static_cast<int64_t>(shape.size());
+});
+
+TVM_REGISTER_GLOBAL("runtime.GetShapeTupleElem").set_body_typed([](ShapeTuple shape, int idx) {
+  ICHECK_LT(idx, shape.size());
+  return shape[idx];
+});
 }  // namespace runtime
 }  // namespace tvm
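
The globals registered above back the Python bindings, and can also be looked
up from C++ through the registry. A hedged sketch (Registry::Get returns
nullptr when a name is not registered):

    #include <tvm/runtime/registry.h>

    void ffi_example() {
      using namespace tvm::runtime;
      const PackedFunc* make = Registry::Get("runtime.ShapeTuple");
      ICHECK(make != nullptr);
      // Each positional argument becomes one element of the shape.
      ShapeTuple shape = (*make)(2, 3, 4);
      int64_t n = (*Registry::Get("runtime.GetShapeTupleSize"))(shape);     // n == 3
      int64_t d = (*Registry::Get("runtime.GetShapeTupleElem"))(shape, 1);  // d == 3
    }
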
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index 3d3466b..968a448 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -123,7 +123,7 @@ struct NDArray::Internal {
   }
   // Local create function which allocates tensor metadata
   // but does not allocate space for the data.
-  static NDArray Create(std::vector<int64_t> shape, DLDataType dtype, Device dev) {
+  static NDArray Create(ShapeTuple shape, DLDataType dtype, Device dev) {
     VerifyDataType(dtype);
 
     // critical zone: construct header
@@ -134,7 +134,7 @@ struct NDArray::Internal {
     NDArray ret(GetObjectPtr<Object>(data));
     // setup shape
     data->shape_ = std::move(shape);
-    data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
+    data->dl_tensor.shape = const_cast<ShapeTuple::index_type*>(data->shape_.data());
     data->dl_tensor.ndim = static_cast<int>(data->shape_.size());
     // setup dtype
     data->dl_tensor.dtype = dtype;
@@ -172,7 +172,7 @@ struct NDArray::Internal {
   }
 };
 
-NDArray NDArray::CreateView(std::vector<int64_t> shape, DLDataType dtype) {
+NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype) {
   ICHECK(data_ != nullptr);
   ICHECK(get_mutable()->dl_tensor.strides == nullptr) << "Can only create view for compact tensor";
   NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.device);
@@ -190,8 +190,7 @@ NDArray NDArray::CreateView(std::vector<int64_t> shape, DLDataType dtype) {
 
 DLManagedTensor* NDArray::ToDLPack() const { return Internal::ToDLPack(get_mutable()); }
 
-NDArray NDArray::Empty(std::vector<int64_t> shape, DLDataType dtype, Device dev,
-                       Optional<String> mem_scope) {
+NDArray NDArray::Empty(ShapeTuple shape, DLDataType dtype, Device dev, Optional<String> mem_scope) {
   NDArray ret = Internal::Create(shape, dtype, dev);
   ret.get_mutable()->dl_tensor.data =
       DeviceAPI::Get(ret->device)
@@ -207,9 +206,11 @@ NDArray NDArray::FromDLPack(DLManagedTensor* tensor) {
   data->manager_ctx = tensor;
   data->dl_tensor = tensor->dl_tensor;
   // update shape_
-  data->shape_.resize(data->dl_tensor.ndim);
-  data->shape_.assign(data->dl_tensor.shape, data->dl_tensor.shape + data->dl_tensor.ndim);
-  data->dl_tensor.shape = data->shape_.data();
+  std::vector<ShapeTuple::index_type> shape;
+  shape.resize(data->dl_tensor.ndim);
+  shape.assign(data->dl_tensor.shape, data->dl_tensor.shape + data->dl_tensor.ndim);
+  data->shape_ = ShapeTuple(shape);
+  data->dl_tensor.shape = const_cast<ShapeTuple::index_type*>(data->shape_.data());
   return NDArray(GetObjectPtr<Object>(data));
 }
 
@@ -242,7 +243,7 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str
   DeviceAPI::Get(dev)->CopyDataFromTo(const_cast<DLTensor*>(from), to, stream);
 }
 
-std::vector<int64_t> NDArray::Shape() const { return get_mutable()->shape_; }
+ShapeTuple NDArray::Shape() const { return get_mutable()->shape_; }
 runtime::DataType NDArray::DataType() const {
   return runtime::DataType(get_mutable()->dl_tensor.dtype);
 }
@@ -274,7 +275,7 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_
   Device dev;
   dev.device_type = static_cast<DLDeviceType>(device_type);
   dev.device_id = device_id;
-  auto ndarray = NDArray::Empty(std::vector<int64_t>(shape, shape + ndim), dtype, dev);
+  auto ndarray = NDArray::Empty(ShapeTuple(shape, shape + ndim), dtype, dev);
 
   *out = NDArray::Internal::MoveToFFIHandle(ndarray);
   API_END();
@@ -283,7 +284,7 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_
 TVM_REGISTER_GLOBAL("runtime.TVMArrayAllocWithScope").set_body([](TVMArgs args, TVMRetValue* ret) {
   int64_t* shape_ptr = static_cast<int64_t*>(static_cast<void*>(args[0]));
   int ndim = args[1];
-  std::vector<int64_t> shape(shape_ptr, shape_ptr + ndim);
+  ShapeTuple shape(shape_ptr, shape_ptr + ndim);
   DataType dtype = args[2];
   Device dev = args[3];
   Optional<String> mem_scope = args[4];
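
FromDLPack now rebuilds shape_ as a ShapeTuple, as shown in the hunk above; a
minimal round-trip sketch of that behavior (error handling omitted):

    using tvm::runtime::NDArray;
    using tvm::runtime::ShapeTuple;

    void dlpack_roundtrip(NDArray a) {
      DLManagedTensor* dl = a.ToDLPack();   // holds a reference to a's data
      NDArray b = NDArray::FromDLPack(dl);  // shares the same buffer, no copy
      ShapeTuple s = b.Shape();             // shape was copied into a ShapeTuple
    }
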
diff --git a/tests/python/unittest/test_runtime_container.py b/tests/python/unittest/test_runtime_container.py
index 39fd575..781fd7f 100644
--- a/tests/python/unittest/test_runtime_container.py
+++ b/tests/python/unittest/test_runtime_container.py
@@ -78,7 +78,16 @@ def test_string():
     assert s == z
 
 
+def test_shape_tuple():
+    shape = [random.randint(-10, 10) for _ in range(5)]
+    stuple = _container.ShapeTuple(shape)
+    assert len(stuple) == len(shape)
+    for a, b in zip(stuple, shape):
+        assert a == b
+
+
 if __name__ == "__main__":
     test_string()
     test_adt_constructor()
     test_tuple_object()
+    test_shape_tuple()