You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zi...@apache.org on 2021/06/09 23:03:43 UTC
[tvm] branch main updated: [RUNTIME] ShapeTuple Container (#8200)
This is an automated email from the ASF dual-hosted git repository.
ziheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4d9bc9b [RUNTIME] ShapeTuple Container (#8200)
4d9bc9b is described below
commit 4d9bc9b4a3e9e8d3420efe60a52964fcd4c29c8d
Author: ziheng <zi...@apache.org>
AuthorDate: Thu Jun 10 07:03:15 2021 +0800
[RUNTIME] ShapeTuple Container (#8200)
* Add ShapeTuple.
* Update NDArray.
* Documents.
* Lint.
* Lint.
* Lint.
* Address comment.
* Address comment.
* Address comment.
* Lint.
* Lint.
---
include/tvm/runtime/container/shape_tuple.h | 180 ++++++++++++++++++++++++
include/tvm/runtime/ndarray.h | 19 ++-
include/tvm/runtime/object.h | 2 +
include/tvm/runtime/packed_func.h | 1 +
python/tvm/runtime/container.py | 24 ++++
src/node/container_printing.cc | 12 ++
src/runtime/container.cc | 23 ++-
src/runtime/ndarray.cc | 23 +--
tests/python/unittest/test_runtime_container.py | 9 ++
9 files changed, 270 insertions(+), 23 deletions(-)
diff --git a/include/tvm/runtime/container/shape_tuple.h b/include/tvm/runtime/container/shape_tuple.h
new file mode 100644
index 0000000..774077f
--- /dev/null
+++ b/include/tvm/runtime/container/shape_tuple.h
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/container/shape_tuple.h
+ * \brief Runtime ShapeTuple container types.
+ */
+#ifndef TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_
+#define TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_
+
+#include <utility>
+#include <vector>
+
+#include "./base.h"
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief An object representing a shape tuple.
+ *
+ * This base object only exposes a raw view of the elements (`data` plus
+ * `size`); it does not own the memory `data` points to. Ownership is held by
+ * a subclass such as the private FromStd, which keeps the elements in a
+ * std::vector.
+ */
+class ShapeTupleObj : public Object {
+ public:
+ /*! \brief The type of shape index element. */
+ using index_type = int64_t;
+ /*! \brief Pointer to the first element; not owned by this base object. */
+ index_type* data;
+ /*! \brief The number of elements reachable through `data`. */
+ uint64_t size;
+
+ // Uses the fixed runtime::TypeIndex::kRuntimeShapeTuple slot instead of a
+ // dynamically allocated type index.
+ static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeShapeTuple;
+ // Must match the key used by the Python-side register_object("runtime.ShapeTuple").
+ static constexpr const char* _type_key = "runtime.ShapeTuple";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTupleObj, Object);
+
+ private:
+ /*!
+ * \brief ShapeTuple object which is moved from std::vector container.
+ * \note FromStd only adds storage; it declares no type key of its own, so it
+ * does not introduce a new runtime type under this "final" object.
+ */
+ class FromStd;
+
+ // ShapeTuple needs access to FromStd to build instances.
+ friend class ShapeTuple;
+};
+
+/*!
+ * \brief A ShapeTupleObj whose elements are owned by an internal std::vector.
+ *
+ * The constructor only fills the owning vector; the base-class view fields
+ * (`data`, `size`) are expected to be pointed at `data_container` by the
+ * creator (ShapeTuple's std::vector constructor does this).
+ */
+class ShapeTupleObj::FromStd : public ShapeTupleObj {
+ public:
+  /*! \brief The type of shape index element. */
+  using index_type = ShapeTupleObj::index_type;
+  /*!
+   * \brief Construct a new FromStd object
+   *
+   * \param other The moved/copied std::vector object
+   *
+   * \note `other` is taken by value: an lvalue argument is copied into it and
+   * an rvalue argument is moved into it. `other` is then moved (never copied
+   * again) into the owning container.
+   */
+  explicit FromStd(std::vector<index_type> other) : data_container(std::move(other)) {}
+
+ private:
+  /*! \brief Container that holds the memory. */
+  std::vector<index_type> data_container;
+
+  friend class ShapeTuple;
+};
+
+/*!
+ * \brief Reference to shape tuple objects.
+ *
+ * A non-nullable handle to a ShapeTupleObj. All accessors are const and
+ * return values or const pointers, so elements are read-only through it.
+ */
+class ShapeTuple : public ObjectRef {
+ public:
+  /*! \brief The type of shape index element. */
+  using index_type = ShapeTupleObj::index_type;
+
+  /*!
+   * \brief Construct an empty shape tuple.
+   */
+  ShapeTuple() : ShapeTuple(std::vector<index_type>()) {}
+
+  /*!
+   * \brief Constructor from iterator
+   * \param begin begin of iterator
+   * \param end end of iterator
+   * \tparam IterType The type of iterator
+   */
+  template <typename IterType>
+  ShapeTuple(IterType begin, IterType end) : ShapeTuple(std::vector<index_type>(begin, end)) {}
+
+  /*!
+   * \brief constructor from initializer list
+   * \param shape The initializer list
+   */
+  ShapeTuple(std::initializer_list<index_type> shape) : ShapeTuple(shape.begin(), shape.end()) {}
+
+  /*!
+   * \brief Construct a new ShapeTuple object
+   *
+   * \param shape The moved/copied std::vector object
+   *
+   * \note `shape` is taken by value: an lvalue argument is copied into it and
+   * an rvalue argument is moved into it. Implicit conversion from
+   * std::vector is intentional.
+   */
+  ShapeTuple(std::vector<index_type> shape);  // NOLINT(*)
+
+  /*!
+   * \brief Return the data pointer
+   *
+   * \return const index_type* data pointer
+   */
+  const index_type* data() const { return get()->data; }
+
+  /*!
+   * \brief Return the size of the shape tuple
+   *
+   * \return size_t shape tuple size
+   */
+  size_t size() const { return get()->size; }
+
+  /*!
+   * \brief Immutably read i-th element from the shape tuple.
+   * \param idx The index
+   * \return the i-th element.
+   * \note Fails an ICHECK when idx >= size().
+   */
+  index_type operator[](size_t idx) const {
+    // idx is unsigned, so only the upper bound can be violated; a `0 <= idx`
+    // test would always be true (and trigger -Wtype-limits).
+    ICHECK(idx < this->size()) << "IndexError: indexing " << idx << " on an array of size "
+                               << this->size();
+    return this->data()[idx];
+  }
+
+  /*!
+   * \brief Immutably read i-th element from the shape tuple.
+   * \param idx The index
+   * \return the i-th element.
+   */
+  index_type at(size_t idx) const { return this->operator[](idx); }
+
+  /*! \return Whether shape tuple is empty */
+  bool empty() const { return size() == 0; }
+
+  /*! \return The first element of the shape tuple; fails an ICHECK if empty. */
+  index_type front() const { return this->at(0); }
+
+  /*! \return The last element of the shape tuple; fails an ICHECK if empty. */
+  index_type back() const {
+    // Without this guard, size() - 1 wraps around and the failure reports a
+    // meaningless huge index.
+    ICHECK(!this->empty()) << "IndexError: back() called on an empty shape tuple";
+    return this->at(this->size() - 1);
+  }
+
+  /*! \return begin iterator */
+  const index_type* begin() const { return this->data(); }
+
+  /*! \return end iterator */
+  const index_type* end() const { return this->data() + this->size(); }
+
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeTuple, ObjectRef, ShapeTupleObj);
+};
+
+// Builds the owning FromStd node from `shape`, then points the base-class
+// view (data/size) at the node's own vector so the object is self-contained.
+inline ShapeTuple::ShapeTuple(std::vector<index_type> shape) {
+  ObjectPtr<ShapeTupleObj::FromStd> node =
+      make_object<ShapeTupleObj::FromStd>(std::move(shape));
+  node->data = node->data_container.data();
+  node->size = node->data_container.size();
+  data_ = std::move(node);
+}
+
+} // namespace runtime
+
+// expose the functions to the root namespace.
+using runtime::ShapeTuple;
+using runtime::ShapeTupleObj;
+} // namespace tvm
+
+#endif // TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_
diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h
index bfc681e..1127a9a 100644
--- a/include/tvm/runtime/ndarray.h
+++ b/include/tvm/runtime/ndarray.h
@@ -25,8 +25,8 @@
#define TVM_RUNTIME_NDARRAY_H_
#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/optional.h>
+#include <tvm/runtime/container/shape_tuple.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/object.h>
@@ -128,7 +128,7 @@ class NDArray : public ObjectRef {
* \param dtype The data type of the new array.
* \note The memory size of new array must be smaller than the current one.
*/
- TVM_DLL NDArray CreateView(std::vector<int64_t> shape, DLDataType dtype);
+ TVM_DLL NDArray CreateView(ShapeTuple shape, DLDataType dtype);
/*!
* \brief Create a reference view of NDArray that
* represents as DLManagedTensor.
@@ -143,7 +143,7 @@ class NDArray : public ObjectRef {
* \param mem_scope The memory scope of the array.
* \return The created Array
*/
- TVM_DLL static NDArray Empty(std::vector<int64_t> shape, DLDataType dtype, Device dev,
+ TVM_DLL static NDArray Empty(ShapeTuple shape, DLDataType dtype, Device dev,
Optional<String> mem_scope = NullOpt);
/*!
* \brief Create a NDArray backed by a dlpack tensor.
@@ -166,7 +166,7 @@ class NDArray : public ObjectRef {
TVM_DLL static void CopyFromTo(const DLTensor* from, DLTensor* to,
TVMStreamHandle stream = nullptr);
- TVM_DLL std::vector<int64_t> Shape() const;
+ TVM_DLL ShapeTuple Shape() const;
TVM_DLL runtime::DataType DataType() const;
// internal namespace
struct Internal;
@@ -241,7 +241,7 @@ class NDArray::ContainerBase {
* \brief The shape container,
* can be used used for shape data.
*/
- std::vector<int64_t> shape_;
+ ShapeTuple shape_;
};
/*!
@@ -261,13 +261,13 @@ class NDArray::Container : public Object, public NDArray::ContainerBase {
dl_tensor.byte_offset = 0;
}
- Container(void* data, std::vector<int64_t> shape, DLDataType dtype, Device dev) {
+ Container(void* data, ShapeTuple shape, DLDataType dtype, Device dev) {
// Initialize the type index.
type_index_ = Container::RuntimeTypeIndex();
dl_tensor.data = data;
shape_ = std::move(shape);
dl_tensor.ndim = static_cast<int>(shape_.size());
- dl_tensor.shape = dmlc::BeginPtr(shape_);
+ dl_tensor.shape = const_cast<ShapeTuple::index_type*>(shape_.data());
dl_tensor.dtype = dtype;
dl_tensor.strides = nullptr;
dl_tensor.byte_offset = 0;
@@ -357,8 +357,7 @@ inline void NDArray::CopyTo(const NDArray& other) const {
inline NDArray NDArray::CopyTo(const Device& dev) const {
ICHECK(data_ != nullptr);
const DLTensor* dptr = operator->();
- NDArray ret =
- Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, dev);
+ NDArray ret = Empty(ShapeTuple(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, dev);
this->CopyTo(ret);
return ret;
}
@@ -460,7 +459,7 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
if (ndim != 0) {
ICHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DLTensor file format";
}
- NDArray ret = NDArray::Empty(shape, dtype, dev);
+ NDArray ret = NDArray::Empty(ShapeTuple(shape), dtype, dev);
int64_t num_elems = 1;
int elem_bytes = (ret->dtype.bits + 7) / 8;
for (int i = 0; i < ret->ndim; ++i) {
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index f13bdee..0ed6117 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -68,6 +68,8 @@ struct TypeIndex {
kRuntimeArray = 4,
/*! \brief runtime::Map. */
kRuntimeMap = 5,
+ /*! \brief runtime::ShapeTuple. */
+ kRuntimeShapeTuple = 6,
// static assignments that may subject to change.
kRuntimeClosure,
kRuntimeADT,
diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h
index 3e8f23b..9bfe379 100644
--- a/include/tvm/runtime/packed_func.h
+++ b/include/tvm/runtime/packed_func.h
@@ -26,6 +26,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/map.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/module.h>
diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py
index 63383e7..7f83693 100644
--- a/python/tvm/runtime/container.py
+++ b/python/tvm/runtime/container.py
@@ -137,3 +137,27 @@ class String(str, PyNativeObject):
val = str.__new__(cls, content)
val.__tvm_object__ = obj
return val
+
+
+@tvm._ffi.register_object("runtime.ShapeTuple")
+class ShapeTuple(Object):
+    """TVM runtime ShapeTuple object.
+
+    Parameters
+    ----------
+    shape : list[int]
+        The shape list used to construct the object.
+    """
+
+    def __init__(self, shape):
+        assert isinstance(shape, (list, tuple)), "Expect list or tuple, but received : {0}".format(
+            type(shape)
+        )
+        for x in shape:
+            assert isinstance(x, int), "Expect int type, but received : {0}".format(type(x))
+        # Elements are passed as individual packed-function arguments.
+        self.__init_handle_by_constructor__(_ffi_api.ShapeTuple, *shape)
+
+    def __len__(self):
+        """Return the number of elements in the shape tuple."""
+        return _ffi_api.GetShapeTupleSize(self)
+
+    def __getitem__(self, idx):
+        """Return the element at ``idx``; index handling is delegated to getitem_helper."""
+        return getitem_helper(self, _ffi_api.GetShapeTupleElem, len(self), idx)
diff --git a/src/node/container_printing.cc b/src/node/container_printing.cc
index 7b97296..1565630 100644
--- a/src/node/container_printing.cc
+++ b/src/node/container_printing.cc
@@ -60,4 +60,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << '}';
});
+// Print a ShapeTuple as a bracketed, comma-separated list, e.g. [1, 2, 3].
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ShapeTupleObj>([](const ObjectRef& node, ReprPrinter* p) {
+      const auto* tuple = static_cast<const ShapeTupleObj*>(node.get());
+      p->stream << '[';
+      // Emit nothing before the first element and ", " before every later one.
+      const char* sep = "";
+      for (size_t k = 0; k < tuple->size; ++k) {
+        p->stream << sep << tuple->data[k];
+        sep = ", ";
+      }
+      p->stream << ']';
+    });
} // namespace tvm
diff --git a/src/runtime/container.cc b/src/runtime/container.cc
index 9d648dc..159404b 100644
--- a/src/runtime/container.cc
+++ b/src/runtime/container.cc
@@ -25,6 +25,7 @@
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/closure.h>
#include <tvm/runtime/container/map.h>
+#include <tvm/runtime/container/shape_tuple.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
@@ -108,7 +109,6 @@ TVM_REGISTER_GLOBAL("runtime.ADT").set_body([](TVMArgs args, TVMRetValue* rv) {
});
// String
-
TVM_REGISTER_OBJECT_TYPE(StringObj);
TVM_REGISTER_GLOBAL("runtime.String").set_body_typed([](std::string str) {
@@ -120,7 +120,6 @@ TVM_REGISTER_GLOBAL("runtime.GetFFIString").set_body_typed([](String str) {
});
// Map
-
TVM_REGISTER_OBJECT_TYPE(MapNode);
TVM_REGISTER_GLOBAL("runtime.Map").set_body([](TVMArgs args, TVMRetValue* ret) {
@@ -185,7 +184,27 @@ TVM_REGISTER_GLOBAL("runtime.MapItems").set_body([](TVMArgs args, TVMRetValue* r
TVM_DLL constexpr uint64_t DenseMapNode::kNextProbeLocation[];
#endif
+// Closure
TVM_REGISTER_OBJECT_TYPE(ClosureObj);
+// ShapeTuple
+TVM_REGISTER_OBJECT_TYPE(ShapeTupleObj);
+
+// Build a ShapeTuple from a variadic list of integer packed arguments.
+TVM_REGISTER_GLOBAL("runtime.ShapeTuple").set_body([](TVMArgs args, TVMRetValue* rv) {
+  std::vector<ShapeTuple::index_type> shape;
+  shape.reserve(args.size());
+  for (int i = 0; i < args.size(); ++i) {
+    shape.push_back(args[i]);
+  }
+  // Move so the vector's buffer is handed over instead of copied.
+  *rv = ShapeTuple(std::move(shape));
+});
+
+// Return the number of elements in a ShapeTuple.
+TVM_REGISTER_GLOBAL("runtime.GetShapeTupleSize").set_body_typed([](ShapeTuple shape) {
+  return static_cast<int64_t>(shape.size());
+});
+
+// Return the idx-th element of a ShapeTuple; idx must be in [0, size).
+TVM_REGISTER_GLOBAL("runtime.GetShapeTupleElem").set_body_typed([](ShapeTuple shape, int idx) {
+  // Check the sign explicitly rather than relying on an int -> size_t conversion
+  // inside a mixed signed/unsigned comparison.
+  ICHECK(idx >= 0 && static_cast<size_t>(idx) < shape.size())
+      << "IndexError: indexing " << idx << " on a shape tuple of size " << shape.size();
+  return shape[idx];
+});
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index 3d3466b..968a448 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -123,7 +123,7 @@ struct NDArray::Internal {
}
// Local create function which allocates tensor metadata
// but does not allocate space for the data.
- static NDArray Create(std::vector<int64_t> shape, DLDataType dtype, Device dev) {
+ static NDArray Create(ShapeTuple shape, DLDataType dtype, Device dev) {
VerifyDataType(dtype);
// critical zone: construct header
@@ -134,7 +134,7 @@ struct NDArray::Internal {
NDArray ret(GetObjectPtr<Object>(data));
// setup shape
data->shape_ = std::move(shape);
- data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
+ data->dl_tensor.shape = const_cast<ShapeTuple::index_type*>(data->shape_.data());
data->dl_tensor.ndim = static_cast<int>(data->shape_.size());
// setup dtype
data->dl_tensor.dtype = dtype;
@@ -172,7 +172,7 @@ struct NDArray::Internal {
}
};
-NDArray NDArray::CreateView(std::vector<int64_t> shape, DLDataType dtype) {
+NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype) {
ICHECK(data_ != nullptr);
ICHECK(get_mutable()->dl_tensor.strides == nullptr) << "Can only create view for compact tensor";
NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.device);
@@ -190,8 +190,7 @@ NDArray NDArray::CreateView(std::vector<int64_t> shape, DLDataType dtype) {
DLManagedTensor* NDArray::ToDLPack() const { return Internal::ToDLPack(get_mutable()); }
-NDArray NDArray::Empty(std::vector<int64_t> shape, DLDataType dtype, Device dev,
- Optional<String> mem_scope) {
+NDArray NDArray::Empty(ShapeTuple shape, DLDataType dtype, Device dev, Optional<String> mem_scope) {
NDArray ret = Internal::Create(shape, dtype, dev);
ret.get_mutable()->dl_tensor.data =
DeviceAPI::Get(ret->device)
@@ -207,9 +206,11 @@ NDArray NDArray::FromDLPack(DLManagedTensor* tensor) {
data->manager_ctx = tensor;
data->dl_tensor = tensor->dl_tensor;
// update shape_
- data->shape_.resize(data->dl_tensor.ndim);
- data->shape_.assign(data->dl_tensor.shape, data->dl_tensor.shape + data->dl_tensor.ndim);
- data->dl_tensor.shape = data->shape_.data();
+ std::vector<ShapeTuple::index_type> shape;
+ shape.resize(data->dl_tensor.ndim);
+ shape.assign(data->dl_tensor.shape, data->dl_tensor.shape + data->dl_tensor.ndim);
+ data->shape_ = ShapeTuple(shape);
+ data->dl_tensor.shape = const_cast<ShapeTuple::index_type*>(data->shape_.data());
return NDArray(GetObjectPtr<Object>(data));
}
@@ -242,7 +243,7 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str
DeviceAPI::Get(dev)->CopyDataFromTo(const_cast<DLTensor*>(from), to, stream);
}
-std::vector<int64_t> NDArray::Shape() const { return get_mutable()->shape_; }
+ShapeTuple NDArray::Shape() const { return get_mutable()->shape_; }
runtime::DataType NDArray::DataType() const {
return runtime::DataType(get_mutable()->dl_tensor.dtype);
}
@@ -274,7 +275,7 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_
Device dev;
dev.device_type = static_cast<DLDeviceType>(device_type);
dev.device_id = device_id;
- auto ndarray = NDArray::Empty(std::vector<int64_t>(shape, shape + ndim), dtype, dev);
+ auto ndarray = NDArray::Empty(ShapeTuple(shape, shape + ndim), dtype, dev);
*out = NDArray::Internal::MoveToFFIHandle(ndarray);
API_END();
@@ -283,7 +284,7 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_
TVM_REGISTER_GLOBAL("runtime.TVMArrayAllocWithScope").set_body([](TVMArgs args, TVMRetValue* ret) {
int64_t* shape_ptr = static_cast<int64_t*>(static_cast<void*>(args[0]));
int ndim = args[1];
- std::vector<int64_t> shape(shape_ptr, shape_ptr + ndim);
+ ShapeTuple shape(shape_ptr, shape_ptr + ndim);
DataType dtype = args[2];
Device dev = args[3];
Optional<String> mem_scope = args[4];
diff --git a/tests/python/unittest/test_runtime_container.py b/tests/python/unittest/test_runtime_container.py
index 39fd575..781fd7f 100644
--- a/tests/python/unittest/test_runtime_container.py
+++ b/tests/python/unittest/test_runtime_container.py
@@ -78,7 +78,16 @@ def test_string():
assert s == z
+def test_shape_tuple():
+    """Round-trip a Python list through runtime.ShapeTuple and read it back."""
+    shape = [random.randint(-10, 10) for _ in range(5)]
+    stuple = _container.ShapeTuple(shape)
+    assert len(stuple) == len(shape)
+    for a, b in zip(stuple, shape):
+        assert a == b
+    # An empty shape is a valid (zero-length) tuple.
+    assert len(_container.ShapeTuple([])) == 0
+
+
if __name__ == "__main__":
test_string()
test_adt_constructor()
test_tuple_object()
+ test_shape_tuple()