You are viewing a plain text version of this content. (The hyperlink to the canonical version was not preserved in this plain-text copy.)
Posted to commits@tvm.apache.org by an...@apache.org on 2022/09/19 19:38:24 UTC
[tvm] 12/28: optional complete
This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch aluo/rebase-09192022-autotensorization
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit bc7cf1b1247f9994225f69b1970a844ec7e1f863
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Fri Sep 2 14:42:04 2022 -0700
optional complete
---
src/tir/ir/buffer_common.h | 16 +++++++++-------
src/tir/ir/expr.cc | 8 ++++----
src/tir/ir/stmt.cc | 8 ++++----
src/tir/transforms/inject_ptx_async_copy.cc | 8 ++++----
src/tir/transforms/storage_rewrite.cc | 21 +++++++++------------
5 files changed, 30 insertions(+), 31 deletions(-)
diff --git a/src/tir/ir/buffer_common.h b/src/tir/ir/buffer_common.h
index 5921c54d98..8dac41a02e 100644
--- a/src/tir/ir/buffer_common.h
+++ b/src/tir/ir/buffer_common.h
@@ -26,7 +26,7 @@
#include <tvm/ir/type.h>
#include <tvm/runtime/data_type.h>
-#include <optional>
+#include <utility>
namespace tvm {
namespace tir {
@@ -36,20 +36,22 @@ namespace tir {
*
* \param type The type to be checked.
*
- * \return An std::optional<DataType> object. If the type is a pointer
- * to a primitive type, the object has a value which is the pointed-to
- * type. Otherwise the object is nullopt.
+ * \return A (bool, DataType) pair. If the type is a pointer to a
+ * primitive, the boolean is true and the DataType is the pointed-to
+ * type. Otherwise, the boolean is false and the DataType is
+ * default-constructed. This can be replaced with std::optional with
+ * C++17 if/when C++17 is required.
*/
-inline std::optional<runtime::DataType> GetPointerType(const Type& type) {
+inline std::pair<bool, runtime::DataType> GetPointerType(const Type& type) {
if (type.defined()) {
if (auto* ptr_type = type.as<PointerTypeNode>()) {
if (auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
- return prim_type->dtype;
+ return {true, prim_type->dtype};
}
}
}
- return std::nullopt;
+ return {false, DataType()};
}
} // namespace tir
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 59db4ea410..f841f94b5a 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -648,7 +648,7 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, S
// annotation tells us otherwise.
int element_lanes = 1;
auto pointer_type = tir::GetPointerType(buffer_var->type_annotation);
- if (pointer_type.has_value()) {
+ if (pointer_type.first) {
// Cannot check element type of array, as it may be different than
// the loaded type in some cases.
//
@@ -663,11 +663,11 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, S
// See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615
// for discussion.
- // ICHECK(dtype.element_of() == pointer_type->element_of())
+ // ICHECK(dtype.element_of() == pointer_type.second.element_of())
// << "Type mismatch, cannot load type " << dtype << " from buffer " <<
// buffer_var->name_hint
- // << " of type " << pointer_type.value();
- element_lanes = pointer_type->lanes();
+ // << " of type " << pointer_type.second;
+ element_lanes = pointer_type.second.lanes();
}
// The C-based codegens assume that all loads occur on a array with
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index e21d014fe1..524204f3d3 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -271,7 +271,7 @@ Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate,
// annotation tells us otherwise.
int element_lanes = 1;
auto pointer_type = tir::GetPointerType(buffer_var->type_annotation);
- if (pointer_type.has_value()) {
+ if (pointer_type.first) {
// Currently cannot check element type of array, see Load::Load
// for details.
@@ -279,10 +279,10 @@ Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate,
// See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615
// for discussion.
- // ICHECK_EQ(value.dtype().element_of(), pointer_type->element_of())
+ // ICHECK_EQ(value.dtype().element_of(), pointer_type.second.element_of())
// << "Type mismatch, cannot store type " << value.dtype() << " into buffer "
- // << buffer_var->name_hint << " of type " << pointer_type.value();
- element_lanes = pointer_type->lanes();
+ // << buffer_var->name_hint << " of type " << pointer_type.second;
+ element_lanes = pointer_type.second.lanes();
}
ICHECK((value.dtype().lanes() == element_lanes * index.dtype().lanes()) ||
diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc
index 8ee0d054e5..c74ce9d3d2 100644
--- a/src/tir/transforms/inject_ptx_async_copy.cc
+++ b/src/tir/transforms/inject_ptx_async_copy.cc
@@ -60,21 +60,21 @@ class PTXAsyncCopyInjector : public StmtMutator {
if (bytes == 4 || bytes == 8 || bytes == 16) {
auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation);
auto src_elem_type = GetPointerType(load->buffer->data->type_annotation);
- ICHECK(dst_elem_type.has_value() && src_elem_type.has_value())
+ ICHECK(dst_elem_type.first && src_elem_type.first)
<< "Both store and load buffer should have a pointer type annotation.";
int index_factor = 1;
- if (dst_elem_type.value() != src_elem_type.value()) {
+ if (dst_elem_type != src_elem_type) {
// The only case where src and dst have different dtypes is when the dst shared memory
// is a byte buffer generated by merging dynamic shared memory.
ICHECK(store->buffer.scope() == "shared.dyn");
- ICHECK(dst_elem_type.value() == DataType::UInt(8));
+ ICHECK(dst_elem_type.second == DataType::UInt(8));
// BufferStore/Load have the "pointer reinterpret" semantics according to their
// "value" dtype. Their "indices" are supposed to be applied after such pointer cast,
// for example: ((*float16)(byte_buffer))[buffer->indices] = fp16_value;
// To replace BufferStore/Load with cp.async, we need to multiply the store index by
// the byte size of the "value" dtype, to get the correct offset into the byte buffer.
- index_factor = src_elem_type->bytes();
+ index_factor = src_elem_type.second.bytes();
}
if (indices_lanes == 1) {
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 67972ce672..277d3e63c7 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -917,7 +917,7 @@ class StoragePlanRewriter : public StmtExprMutator {
const StorageScope& scope, size_t const_nbits) {
ICHECK(op != nullptr);
// Re-use not successful, allocate a new buffer.
- auto entry = std::make_unique<StorageEntry>();
+ std::unique_ptr<StorageEntry> entry(new StorageEntry());
entry->attach_scope_ = attach_scope;
entry->scope = scope;
entry->elem_type = op->dtype.element_of();
@@ -1028,11 +1028,11 @@ class StoragePlanRewriter : public StmtExprMutator {
// symbolic free list, for non constant items.
std::list<StorageEntry*> sym_free_list_;
// The allocation attach map
- std::unordered_map<const Object*, std::vector<StorageEntry*>> attach_map_;
+ std::unordered_map<const Object*, std::vector<StorageEntry*> > attach_map_;
// The allocation assign map
std::unordered_map<const VarNode*, StorageEntry*> alloc_map_;
// The allocations
- std::vector<std::unique_ptr<StorageEntry>> alloc_vec_;
+ std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
// The buffer objects being remapped
std::unordered_map<const BufferNode*, Buffer> buffer_remap_;
// analyzer
@@ -1143,8 +1143,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
// track the parameter itself.
for (Var buffer_var : params) {
auto pointer_type = GetPointerType(buffer_var->type_annotation);
- if (pointer_type.has_value() && (buffer_map.count(buffer_var) == 0)) {
- DataType dtype = pointer_type.value();
+ if (pointer_type.first && (buffer_map.count(buffer_var) == 0)) {
+ DataType dtype = pointer_type.second;
PrimExpr extent = 0;
OnArrayDeclaration(buffer_var, dtype, extent, BufferVarInfo::kPrimFuncBufferMap);
}
@@ -1208,8 +1208,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
void HandleLetNode(Var let_var) {
if (let_var->dtype.is_handle()) {
auto pointer_type = GetPointerType(let_var->type_annotation);
- if (pointer_type.has_value()) {
- OnArrayDeclaration(let_var, pointer_type.value(), 0, BufferVarInfo::kLetNode);
+ if (pointer_type.first) {
+ OnArrayDeclaration(let_var, pointer_type.second, 0, BufferVarInfo::kLetNode);
} else if (allow_untyped_pointers_) {
OnArrayDeclaration(let_var, let_var->dtype, 0, BufferVarInfo::kLetNode);
} else {
@@ -1481,13 +1481,10 @@ class VectorTypeRewriter : public StmtExprMutator {
Stmt VisitStmt_(const LetStmtNode* op) final {
auto it = rewrite_map_.find(op->var.get());
- PrimExpr value = this->VisitExpr(op->value);
- Stmt body = this->VisitStmt(op->body);
- Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var;
- if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
+ if (it == rewrite_map_.end()) {
return GetRef<Stmt>(op);
}
- return LetStmt(var, value, body);
+ return LetStmt(it->second.new_buffer_var, op->value, op->body);
}
Buffer RemapBuffer(Buffer buf) {