You are viewing a plain-text version of this content. The canonical link to it was a hyperlink on the original page and is not preserved in this text.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/09/16 22:47:42 UTC

[tvm] 18/20: final optional

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch aluo/rebase-08312022-autotensorization-fq2i-changes
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 18b8089564e3eda6caf42ede61a24a6d47efb841
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Fri Sep 2 15:22:55 2022 -0700

    final optional
---
 src/relay/transforms/fold_explicit_padding.cc  | 13 ++---
 src/relay/transforms/pattern_utils.h           | 69 +++++++-------------------
 src/tir/transforms/common_subexpr_elim_tools.h |  5 +-
 src/tir/transforms/loop_partition.cc           | 51 +++++++------------
 4 files changed, 44 insertions(+), 94 deletions(-)

diff --git a/src/relay/transforms/fold_explicit_padding.cc b/src/relay/transforms/fold_explicit_padding.cc
index 794bcfd3d0..37385f80c1 100644
--- a/src/relay/transforms/fold_explicit_padding.cc
+++ b/src/relay/transforms/fold_explicit_padding.cc
@@ -22,6 +22,7 @@
  * \brief A pass for folding explicit pads into other ops.
  */
 
+#include <dmlc/optional.h>
 #include <tvm/relay/dataflow_matcher.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
@@ -31,10 +32,6 @@
 #include <tvm/tir/op.h>
 #include <tvm/topi/nn/pooling.h>
 
-#include <optional>
-#include <set>
-#include <string>
-
 #include "../op/tensor/transform.h"
 #include "pattern_utils.h"
 
@@ -183,10 +180,10 @@ class SimplifyExplicitPad {
     return attrs;
   }
 
-  static const std::optional<Array<PrimExpr>> get_padding(const PadAttrs* param,
-                                                          std::string data_layout) {
+  static const Optional<Array<PrimExpr>> get_padding(const PadAttrs* param,
+                                                     std::string data_layout) {
     // Gets spatial axes padding from the given PadAttrs `param`. If padding
-    // is non-zero on non-spatial axes, return std::nullopt.
+    // is non-zero on non-spatial axes, return NullOpt.
     ICHECK(param);
     ICHECK(data_layout.size() == param->pad_width.size())
         << "Data Layout and padding attributes should have the same extent";
@@ -199,7 +196,7 @@ class SimplifyExplicitPad {
       if (!image_dims.count(data_layout[i])) {
         for (size_t j = 0; j < param->pad_width[i].size(); ++j) {
           if (param->pad_width[i][j] != 0) {
-            return std::nullopt;
+            return NullOpt;
           }
         }
       }
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index ffe1cc2ca2..f71d84434d 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -27,6 +27,7 @@
 #define TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_
 
 #include <builtin_fp16.h>
+#include <dmlc/optional.h>
 #include <tvm/node/structural_equal.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/nn.h>
@@ -39,7 +40,6 @@
 #include <tvm/tir/data_layout.h>
 
 #include <limits>
-#include <optional>
 #include <string>
 #include <utility>
 #include <vector>
@@ -344,40 +344,6 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s
   return Constant(arr);
 }
 
-/*!
- * \brief Create a Constant tensor of zeros.
- *
- * \param dtype The data type.
- * \param shape The shape of the output constant tensor.
- * \return A Constant.
- */
-static inline Constant MakeConstantZeros(DataType dtype, std::vector<int64_t> shape) {
-  runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0});
-  int64_t data_size = 1;
-  for (int64_t dim : shape) {
-    data_size *= dim;
-  }
-  TVM_DTYPE_DISPATCH(dtype, DType, {
-    for (int64_t i = 0; i < data_size; i++) {
-      if (dtype == DataType::Float(16)) {
-        // convert to float16
-        // storage is uint16_t
-        // Similar handling as that in MakeConstantScalar
-        *(static_cast<DType*>(arr->data) + i) =
-            __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(0));
-      } else if (dtype == DataType::BFloat(16)) {
-        // convert to bfloat16
-        // storage is uint16_t
-        *(static_cast<DType*>(arr->data) + i) =
-            __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(static_cast<float>(0));
-      } else {
-        *(static_cast<DType*>(arr->data) + i) = 0;
-      }
-    }
-  })
-  return Constant(arr);
-}
-
 /*!
  * \brief Check whether a shape is static and create corresponding Constant.
  Eventually this will be removed and replaced with CheckConstantShapeArrayInteger
@@ -439,47 +405,48 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
  * \param i element index
  * \return Converted scalar value, or None if conversion failed
  */
-static inline std::optional<long double> TryToScalar(const runtime::NDArray& array, size_t i = 0) {
+static inline dmlc::optional<long double> TryToScalar(const runtime::NDArray& array, size_t i = 0) {
   if (array->dtype.code == kDLInt) {
     if (array->dtype.bits == 8) {
-      return std::optional<long double>(reinterpret_cast<int8_t*>(array->data)[i]);
+      return dmlc::optional<long double>(reinterpret_cast<int8_t*>(array->data)[i]);
     } else if (array->dtype.bits == 16) {
-      return std::optional<long double>(reinterpret_cast<int16_t*>(array->data)[i]);
+      return dmlc::optional<long double>(reinterpret_cast<int16_t*>(array->data)[i]);
     } else if (array->dtype.bits == 32) {
-      return std::optional<long double>(reinterpret_cast<int32_t*>(array->data)[i]);
+      return dmlc::optional<long double>(reinterpret_cast<int32_t*>(array->data)[i]);
     } else if (array->dtype.bits == 64) {
-      return std::optional<long double>(reinterpret_cast<int64_t*>(array->data)[i]);
+      return dmlc::optional<long double>(reinterpret_cast<int64_t*>(array->data)[i]);
     }
   } else if (array->dtype.code == kDLUInt) {
     if (array->dtype.bits == 1) {  // bool
-      return std::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
+      return dmlc::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
     } else if (array->dtype.bits == 8) {
-      return std::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
+      return dmlc::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
     } else if (array->dtype.bits == 16) {
-      return std::optional<long double>(reinterpret_cast<uint16_t*>(array->data)[i]);
+      return dmlc::optional<long double>(reinterpret_cast<uint16_t*>(array->data)[i]);
     } else if (array->dtype.bits == 32) {
-      return std::optional<long double>(reinterpret_cast<uint32_t*>(array->data)[i]);
+      return dmlc::optional<long double>(reinterpret_cast<uint32_t*>(array->data)[i]);
     } else if (array->dtype.bits == 64) {
-      return std::optional<long double>(reinterpret_cast<uint64_t*>(array->data)[i]);
+      return dmlc::optional<long double>(reinterpret_cast<uint64_t*>(array->data)[i]);
     }
   } else if (array->dtype.code == kDLFloat) {
     if (array->dtype.bits == 16) {
-      return std::optional<long double>(
+      return dmlc::optional<long double>(
           __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
               reinterpret_cast<uint16_t*>(array->data)[i]));
     }
     if (array->dtype.bits == 32) {
-      return std::optional<long double>(reinterpret_cast<float*>(array->data)[i]);
+      return dmlc::optional<long double>(reinterpret_cast<float*>(array->data)[i]);
     } else if (array->dtype.bits == 64) {
-      return std::optional<long double>(reinterpret_cast<double*>(array->data)[i]);
+      return dmlc::optional<long double>(reinterpret_cast<double*>(array->data)[i]);
     }
   } else if (array->dtype.code == kDLBfloat) {
     if (array->dtype.bits == 16) {
-      return std::optional<long double>(__extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>(
-          reinterpret_cast<uint16_t*>(array->data)[i]));
+      return dmlc::optional<long double>(
+          __extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>(
+              reinterpret_cast<uint16_t*>(array->data)[i]));
     }
   }
-  return std::nullopt;
+  return dmlc::optional<long double>();
 }
 
 /*!
diff --git a/src/tir/transforms/common_subexpr_elim_tools.h b/src/tir/transforms/common_subexpr_elim_tools.h
index 0871fd0091..fcd29fddc0 100644
--- a/src/tir/transforms/common_subexpr_elim_tools.h
+++ b/src/tir/transforms/common_subexpr_elim_tools.h
@@ -33,11 +33,12 @@
 #include <tvm/tir/stmt.h>
 #include <tvm/tir/stmt_functor.h>  // For the class StmtExprVisitor
 
-#include <optional>
 #include <unordered_map>  // For the hashtable datatype
 #include <utility>        // For pairs datatype
 #include <vector>
 
+#include "../../../3rdparty/dmlc-core/include/dmlc/optional.h"
+
 namespace tvm {
 namespace tir {
 
@@ -176,7 +177,7 @@ class UsesVarName : public StmtExprVisitor {
  */
 void PrintComputationTable(const ComputationTable& table);
 
-using MaybeValue = std::optional<PrimExpr>;
+using MaybeValue = dmlc::optional<PrimExpr>;
 
 bool EqualTerms(const PrimExpr& a, const PrimExpr& b);
 // Used for deciding the (decidable) equivalence relation
diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc
index 6ecc6459b9..677506889e 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -29,7 +29,6 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 
-#include <optional>
 #include <unordered_map>
 #include <unordered_set>
 
@@ -554,39 +553,25 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
   if (finder.partitions.empty()) return Stmt();
 
   arith::IntervalSet for_interval(min, max);
-
-  auto [middle_interval, cond_set,
-        opt_cond_value] = [&]() -> std::tuple<IntSet, ExpressionSet, std::optional<bool>> {
-    {
-      // find an interval in which all conditions on var are true
-      auto [middle_interval, cond_set] =
-          GetIntervalAndCondset(finder.partitions, for_interval, true, has_partition_hint_);
-      if (!middle_interval.IsNothing()) {
-        return {middle_interval, cond_set, true};
-      }
-    }
-
-    {
-      // if such interval doesn't exist, find an interval in which all
-      // conditions on var are false
-      auto [middle_interval, cond_set] =
-          GetIntervalAndCondset(finder.partitions, for_interval, false, has_partition_hint_);
-
-      if (!middle_interval.IsNothing()) {
-        return {middle_interval, cond_set, false};
-      }
-    }
-
-    // we couldn't find an interval in which the conditions are
-    // provably true or false.  Therefore, we can't partition the loop
-    // based on those conds
-    return {{}, {}, std::nullopt};
-  }();
-
-  if (!opt_cond_value.has_value()) {
-    return Stmt();
+  bool cond_value;
+  IntSet middle_interval;
+  ExpressionSet cond_set;
+  // find an interval in which all conditions on var are true
+  std::tie(middle_interval, cond_set) =
+      GetIntervalAndCondset(finder.partitions, for_interval, true, has_partition_hint_);
+  if (middle_interval.IsNothing()) {
+    // if such interval doesn't exist, find an interval in which all
+    // conditions on var are false
+    std::tie(middle_interval, cond_set) =
+        GetIntervalAndCondset(finder.partitions, for_interval, false, has_partition_hint_);
+    if (middle_interval.IsNothing())
+      // we couldn't find an interval in which the conditions are provably true or false
+      // Therefore, we can't partition the loop based on those conds
+      return Stmt();
+    cond_value = false;
+  } else {
+    cond_value = true;
   }
-  bool cond_value = opt_cond_value.value();
 
   IntervalSet middle_interval_i = Downcast<IntervalSet>(middle_interval);
   // middle_interval is the subrange of the loop variable range for which a