You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") in the original page and is not preserved in this copy.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/09/19 19:38:24 UTC

[tvm] 12/28: optional complete

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch aluo/rebase-09192022-autotensorization
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit bc7cf1b1247f9994225f69b1970a844ec7e1f863
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Fri Sep 2 14:42:04 2022 -0700

    optional complete
---
 src/tir/ir/buffer_common.h                  | 16 +++++++++-------
 src/tir/ir/expr.cc                          |  8 ++++----
 src/tir/ir/stmt.cc                          |  8 ++++----
 src/tir/transforms/inject_ptx_async_copy.cc |  8 ++++----
 src/tir/transforms/storage_rewrite.cc       | 21 +++++++++------------
 5 files changed, 30 insertions(+), 31 deletions(-)

diff --git a/src/tir/ir/buffer_common.h b/src/tir/ir/buffer_common.h
index 5921c54d98..8dac41a02e 100644
--- a/src/tir/ir/buffer_common.h
+++ b/src/tir/ir/buffer_common.h
@@ -26,7 +26,7 @@
 #include <tvm/ir/type.h>
 #include <tvm/runtime/data_type.h>
 
-#include <optional>
+#include <utility>
 
 namespace tvm {
 namespace tir {
@@ -36,20 +36,22 @@ namespace tir {
  *
  * \param type The type to be checked.
  *
- * \return An std::optional<DataType> object. If the type is a pointer
- * to a primitive type, the object has a value which is the pointed-to
- * type. Otherwise the object is nullopt.
+ * \return A (bool, DataType) pair.  If the type is a pointer to a
+ * primitive, the boolean is true and the DataType is the pointed-to
+ * type.  Otherwise, the boolean is false and the DataType is
+ * default-constructed.  This can be replaced with std::optional with
+ * C++17 if/when C++17 is required.
  */
-inline std::optional<runtime::DataType> GetPointerType(const Type& type) {
+inline std::pair<bool, runtime::DataType> GetPointerType(const Type& type) {
   if (type.defined()) {
     if (auto* ptr_type = type.as<PointerTypeNode>()) {
       if (auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) {
-        return prim_type->dtype;
+        return {true, prim_type->dtype};
       }
     }
   }
 
-  return std::nullopt;
+  return {false, DataType()};
 }
 
 }  // namespace tir
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 59db4ea410..f841f94b5a 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -648,7 +648,7 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, S
   // annotation tells us otherwise.
   int element_lanes = 1;
   auto pointer_type = tir::GetPointerType(buffer_var->type_annotation);
-  if (pointer_type.has_value()) {
+  if (pointer_type.first) {
     // Cannot check element type of array, as it may be different than
     // the loaded type in some cases.
     //
@@ -663,11 +663,11 @@ Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, S
     // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615
     // for discussion.
 
-    // ICHECK(dtype.element_of() == pointer_type->element_of())
+    // ICHECK(dtype.element_of() == pointer_type.second.element_of())
     //     << "Type mismatch, cannot load type " << dtype << " from buffer " <<
     //     buffer_var->name_hint
-    //     << " of type " << pointer_type.value();
-    element_lanes = pointer_type->lanes();
+    //     << " of type " << pointer_type.second;
+    element_lanes = pointer_type.second.lanes();
   }
 
   // The C-based codegens assume that all loads occur on a array with
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index e21d014fe1..524204f3d3 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -271,7 +271,7 @@ Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate,
   // annotation tells us otherwise.
   int element_lanes = 1;
   auto pointer_type = tir::GetPointerType(buffer_var->type_annotation);
-  if (pointer_type.has_value()) {
+  if (pointer_type.first) {
     // Currently cannot check element type of array, see Load::Load
     // for details.
 
@@ -279,10 +279,10 @@ Store::Store(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate,
     // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615
     // for discussion.
 
-    // ICHECK_EQ(value.dtype().element_of(), pointer_type->element_of())
+    // ICHECK_EQ(value.dtype().element_of(), pointer_type.second.element_of())
     //     << "Type mismatch, cannot store type " << value.dtype() << " into buffer "
-    //     << buffer_var->name_hint << " of type " << pointer_type.value();
-    element_lanes = pointer_type->lanes();
+    //     << buffer_var->name_hint << " of type " << pointer_type.second;
+    element_lanes = pointer_type.second.lanes();
   }
 
   ICHECK((value.dtype().lanes() == element_lanes * index.dtype().lanes()) ||
diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc
index 8ee0d054e5..c74ce9d3d2 100644
--- a/src/tir/transforms/inject_ptx_async_copy.cc
+++ b/src/tir/transforms/inject_ptx_async_copy.cc
@@ -60,21 +60,21 @@ class PTXAsyncCopyInjector : public StmtMutator {
           if (bytes == 4 || bytes == 8 || bytes == 16) {
             auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation);
             auto src_elem_type = GetPointerType(load->buffer->data->type_annotation);
-            ICHECK(dst_elem_type.has_value() && src_elem_type.has_value())
+            ICHECK(dst_elem_type.first && src_elem_type.first)
                 << "Both store and load buffer should have a pointer type annotation.";
 
             int index_factor = 1;
-            if (dst_elem_type.value() != src_elem_type.value()) {
+            if (dst_elem_type != src_elem_type) {
               // The only case where src and dst have different dtypes is when the dst shared memory
               // is a byte buffer generated by merging dynamic shared memory.
               ICHECK(store->buffer.scope() == "shared.dyn");
-              ICHECK(dst_elem_type.value() == DataType::UInt(8));
+              ICHECK(dst_elem_type.second == DataType::UInt(8));
               // BufferStore/Load have the "pointer reinterpret" semantics according to their
               // "value" dtype. Their "indices" are supposed to be applied after such pointer cast,
               // for example: ((*float16)(byte_buffer))[buffer->indices] = fp16_value;
               // To replace BufferStore/Load with cp.async, we need to multiply the store index by
               // the byte size of the "value" dtype, to get the correct offset into the byte buffer.
-              index_factor = src_elem_type->bytes();
+              index_factor = src_elem_type.second.bytes();
             }
 
             if (indices_lanes == 1) {
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 67972ce672..277d3e63c7 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -917,7 +917,7 @@ class StoragePlanRewriter : public StmtExprMutator {
                          const StorageScope& scope, size_t const_nbits) {
     ICHECK(op != nullptr);
     // Re-use not successful, allocate a new buffer.
-    auto entry = std::make_unique<StorageEntry>();
+    std::unique_ptr<StorageEntry> entry(new StorageEntry());
     entry->attach_scope_ = attach_scope;
     entry->scope = scope;
     entry->elem_type = op->dtype.element_of();
@@ -1028,11 +1028,11 @@ class StoragePlanRewriter : public StmtExprMutator {
   // symbolic free list, for non constant items.
   std::list<StorageEntry*> sym_free_list_;
   // The allocation attach map
-  std::unordered_map<const Object*, std::vector<StorageEntry*>> attach_map_;
+  std::unordered_map<const Object*, std::vector<StorageEntry*> > attach_map_;
   // The allocation assign map
   std::unordered_map<const VarNode*, StorageEntry*> alloc_map_;
   // The allocations
-  std::vector<std::unique_ptr<StorageEntry>> alloc_vec_;
+  std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
   // The buffer objects being remapped
   std::unordered_map<const BufferNode*, Buffer> buffer_remap_;
   // analyzer
@@ -1143,8 +1143,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
     // track the parameter itself.
     for (Var buffer_var : params) {
       auto pointer_type = GetPointerType(buffer_var->type_annotation);
-      if (pointer_type.has_value() && (buffer_map.count(buffer_var) == 0)) {
-        DataType dtype = pointer_type.value();
+      if (pointer_type.first && (buffer_map.count(buffer_var) == 0)) {
+        DataType dtype = pointer_type.second;
         PrimExpr extent = 0;
         OnArrayDeclaration(buffer_var, dtype, extent, BufferVarInfo::kPrimFuncBufferMap);
       }
@@ -1208,8 +1208,8 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
   void HandleLetNode(Var let_var) {
     if (let_var->dtype.is_handle()) {
       auto pointer_type = GetPointerType(let_var->type_annotation);
-      if (pointer_type.has_value()) {
-        OnArrayDeclaration(let_var, pointer_type.value(), 0, BufferVarInfo::kLetNode);
+      if (pointer_type.first) {
+        OnArrayDeclaration(let_var, pointer_type.second, 0, BufferVarInfo::kLetNode);
       } else if (allow_untyped_pointers_) {
         OnArrayDeclaration(let_var, let_var->dtype, 0, BufferVarInfo::kLetNode);
       } else {
@@ -1481,13 +1481,10 @@ class VectorTypeRewriter : public StmtExprMutator {
 
   Stmt VisitStmt_(const LetStmtNode* op) final {
     auto it = rewrite_map_.find(op->var.get());
-    PrimExpr value = this->VisitExpr(op->value);
-    Stmt body = this->VisitStmt(op->body);
-    Var var = (it == rewrite_map_.end()) ? op->var : it->second.new_buffer_var;
-    if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
+    if (it == rewrite_map_.end()) {
       return GetRef<Stmt>(op);
     }
-    return LetStmt(var, value, body);
+    return LetStmt(it->second.new_buffer_var, op->value, op->body);
   }
 
   Buffer RemapBuffer(Buffer buf) {