You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/09/19 19:38:29 UTC
[tvm] 17/28: final optional
This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch aluo/rebase-09192022-autotensorization
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 556c4c27f102439dfcd15c0b9a7b5d04cf533b3c
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Fri Sep 2 15:22:55 2022 -0700
final optional
---
src/relay/transforms/fold_explicit_padding.cc | 13 ++---
src/relay/transforms/pattern_utils.h | 69 +++++++-------------------
src/tir/transforms/common_subexpr_elim_tools.h | 5 +-
src/tir/transforms/loop_partition.cc | 55 +++++++-------------
4 files changed, 44 insertions(+), 98 deletions(-)
diff --git a/src/relay/transforms/fold_explicit_padding.cc b/src/relay/transforms/fold_explicit_padding.cc
index 794bcfd3d0..37385f80c1 100644
--- a/src/relay/transforms/fold_explicit_padding.cc
+++ b/src/relay/transforms/fold_explicit_padding.cc
@@ -22,6 +22,7 @@
* \brief A pass for folding explicit pads into other ops.
*/
+#include <dmlc/optional.h>
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
@@ -31,10 +32,6 @@
#include <tvm/tir/op.h>
#include <tvm/topi/nn/pooling.h>
-#include <optional>
-#include <set>
-#include <string>
-
#include "../op/tensor/transform.h"
#include "pattern_utils.h"
@@ -183,10 +180,10 @@ class SimplifyExplicitPad {
return attrs;
}
- static const std::optional<Array<PrimExpr>> get_padding(const PadAttrs* param,
- std::string data_layout) {
+ static const Optional<Array<PrimExpr>> get_padding(const PadAttrs* param,
+ std::string data_layout) {
// Gets spatial axes padding from the given PadAttrs `param`. If padding
- // is non-zero on non-spatial axes, return std::nullopt.
+ // is non-zero on non-spatial axes, return NullOpt.
ICHECK(param);
ICHECK(data_layout.size() == param->pad_width.size())
<< "Data Layout and padding attributes should have the same extent";
@@ -199,7 +196,7 @@ class SimplifyExplicitPad {
if (!image_dims.count(data_layout[i])) {
for (size_t j = 0; j < param->pad_width[i].size(); ++j) {
if (param->pad_width[i][j] != 0) {
- return std::nullopt;
+ return NullOpt;
}
}
}
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index ffe1cc2ca2..f71d84434d 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -27,6 +27,7 @@
#define TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_
#include <builtin_fp16.h>
+#include <dmlc/optional.h>
#include <tvm/node/structural_equal.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
@@ -39,7 +40,6 @@
#include <tvm/tir/data_layout.h>
#include <limits>
-#include <optional>
#include <string>
#include <utility>
#include <vector>
@@ -344,40 +344,6 @@ static inline Constant MakeConstantTensor(DataType dtype, std::vector<int64_t> s
return Constant(arr);
}
-/*!
- * \brief Create a Constant tensor of zeros.
- *
- * \param dtype The data type.
- * \param shape The shape of the output constant tensor.
- * \return A Constant.
- */
-static inline Constant MakeConstantZeros(DataType dtype, std::vector<int64_t> shape) {
- runtime::NDArray arr = runtime::NDArray::Empty(shape, dtype, {kDLCPU, 0});
- int64_t data_size = 1;
- for (int64_t dim : shape) {
- data_size *= dim;
- }
- TVM_DTYPE_DISPATCH(dtype, DType, {
- for (int64_t i = 0; i < data_size; i++) {
- if (dtype == DataType::Float(16)) {
- // convert to float16
- // storage is uint16_t
- // Similar handling as that in MakeConstantScalar
- *(static_cast<DType*>(arr->data) + i) =
- __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(0));
- } else if (dtype == DataType::BFloat(16)) {
- // convert to bfloat16
- // storage is uint16_t
- *(static_cast<DType*>(arr->data) + i) =
- __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 7>(static_cast<float>(0));
- } else {
- *(static_cast<DType*>(arr->data) + i) = 0;
- }
- }
- })
- return Constant(arr);
-}
-
/*!
* \brief Check whether a shape is static and create corresponding Constant.
Eventually this will be removed and replaced with CheckConstantShapeArrayInteger
@@ -439,47 +405,48 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
* \param i element index
* \return Converted scalar value, or None if conversion failed
*/
-static inline std::optional<long double> TryToScalar(const runtime::NDArray& array, size_t i = 0) {
+static inline dmlc::optional<long double> TryToScalar(const runtime::NDArray& array, size_t i = 0) {
if (array->dtype.code == kDLInt) {
if (array->dtype.bits == 8) {
- return std::optional<long double>(reinterpret_cast<int8_t*>(array->data)[i]);
+ return dmlc::optional<long double>(reinterpret_cast<int8_t*>(array->data)[i]);
} else if (array->dtype.bits == 16) {
- return std::optional<long double>(reinterpret_cast<int16_t*>(array->data)[i]);
+ return dmlc::optional<long double>(reinterpret_cast<int16_t*>(array->data)[i]);
} else if (array->dtype.bits == 32) {
- return std::optional<long double>(reinterpret_cast<int32_t*>(array->data)[i]);
+ return dmlc::optional<long double>(reinterpret_cast<int32_t*>(array->data)[i]);
} else if (array->dtype.bits == 64) {
- return std::optional<long double>(reinterpret_cast<int64_t*>(array->data)[i]);
+ return dmlc::optional<long double>(reinterpret_cast<int64_t*>(array->data)[i]);
}
} else if (array->dtype.code == kDLUInt) {
if (array->dtype.bits == 1) { // bool
- return std::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
+ return dmlc::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
} else if (array->dtype.bits == 8) {
- return std::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
+ return dmlc::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
} else if (array->dtype.bits == 16) {
- return std::optional<long double>(reinterpret_cast<uint16_t*>(array->data)[i]);
+ return dmlc::optional<long double>(reinterpret_cast<uint16_t*>(array->data)[i]);
} else if (array->dtype.bits == 32) {
- return std::optional<long double>(reinterpret_cast<uint32_t*>(array->data)[i]);
+ return dmlc::optional<long double>(reinterpret_cast<uint32_t*>(array->data)[i]);
} else if (array->dtype.bits == 64) {
- return std::optional<long double>(reinterpret_cast<uint64_t*>(array->data)[i]);
+ return dmlc::optional<long double>(reinterpret_cast<uint64_t*>(array->data)[i]);
}
} else if (array->dtype.code == kDLFloat) {
if (array->dtype.bits == 16) {
- return std::optional<long double>(
+ return dmlc::optional<long double>(
__extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
reinterpret_cast<uint16_t*>(array->data)[i]));
}
if (array->dtype.bits == 32) {
- return std::optional<long double>(reinterpret_cast<float*>(array->data)[i]);
+ return dmlc::optional<long double>(reinterpret_cast<float*>(array->data)[i]);
} else if (array->dtype.bits == 64) {
- return std::optional<long double>(reinterpret_cast<double*>(array->data)[i]);
+ return dmlc::optional<long double>(reinterpret_cast<double*>(array->data)[i]);
}
} else if (array->dtype.code == kDLBfloat) {
if (array->dtype.bits == 16) {
- return std::optional<long double>(__extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>(
- reinterpret_cast<uint16_t*>(array->data)[i]));
+ return dmlc::optional<long double>(
+ __extendXfYf2__<uint16_t, uint16_t, 7, float, uint32_t, 23>(
+ reinterpret_cast<uint16_t*>(array->data)[i]));
}
}
- return std::nullopt;
+ return dmlc::optional<long double>();
}
/*!
diff --git a/src/tir/transforms/common_subexpr_elim_tools.h b/src/tir/transforms/common_subexpr_elim_tools.h
index 0871fd0091..fcd29fddc0 100644
--- a/src/tir/transforms/common_subexpr_elim_tools.h
+++ b/src/tir/transforms/common_subexpr_elim_tools.h
@@ -33,11 +33,12 @@
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h> // For the class StmtExprVisitor
-#include <optional>
#include <unordered_map> // For the hashtable datatype
#include <utility> // For pairs datatype
#include <vector>
+#include "../../../3rdparty/dmlc-core/include/dmlc/optional.h"
+
namespace tvm {
namespace tir {
@@ -176,7 +177,7 @@ class UsesVarName : public StmtExprVisitor {
*/
void PrintComputationTable(const ComputationTable& table);
-using MaybeValue = std::optional<PrimExpr>;
+using MaybeValue = dmlc::optional<PrimExpr>;
bool EqualTerms(const PrimExpr& a, const PrimExpr& b);
// Used for deciding the (decidable) equivalence relation
diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc
index e1445d29da..3833be6faf 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -29,7 +29,6 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
-#include <optional>
#include <unordered_map>
#include <unordered_set>
@@ -564,43 +563,25 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
if (finder.partitions.empty()) return Stmt();
arith::IntervalSet for_interval(min, max);
-
- auto [middle_interval, cond_set,
- opt_cond_value] = [&]() -> std::tuple<IntSet, ExpressionSet, std::optional<bool>> {
- {
- // find an interval in which all conditions on var are true
- auto [middle_interval, cond_set] =
- GetIntervalAndCondset(finder.partitions, for_interval, true, has_partition_hint_);
- if (!middle_interval.IsNothing()) {
- return {middle_interval, cond_set, true};
- }
- }
-
- {
- // if such interval doesn't exist, find an interval in which all
- // conditions on var are false
- auto [middle_interval, cond_set] =
- GetIntervalAndCondset(finder.partitions, for_interval, false, has_partition_hint_);
-
- if (!middle_interval.IsNothing()) {
- return {middle_interval, cond_set, false};
- }
- }
-
- // we couldn't find an interval in which the conditions are
- // provably true or false. Therefore, we can't partition the loop
- // based on those conds
- return {{}, {}, std::nullopt};
- }();
-
- if (!opt_cond_value.has_value()) {
- if (has_partition_hint_ && unroll_loop_with_partition_hint_no_interval_ &&
- analyzer_.CanProve(max - min > 0)) {
- return For(var, min, max - min + 1, ForKind::kUnrolled, body);
- }
- return Stmt();
+ bool cond_value;
+ IntSet middle_interval;
+ ExpressionSet cond_set;
+ // find an interval in which all conditions on var are true
+ std::tie(middle_interval, cond_set) =
+ GetIntervalAndCondset(finder.partitions, for_interval, true, has_partition_hint_);
+ if (middle_interval.IsNothing()) {
+ // if such interval doesn't exist, find an interval in which all
+ // conditions on var are false
+ std::tie(middle_interval, cond_set) =
+ GetIntervalAndCondset(finder.partitions, for_interval, false, has_partition_hint_);
+ if (middle_interval.IsNothing())
+ // we couldn't find an interval in which the conditions are provably true or false
+ // Therefore, we can't partition the loop based on those conds
+ return Stmt();
+ cond_value = false;
+ } else {
+ cond_value = true;
}
- bool cond_value = opt_cond_value.value();
IntervalSet middle_interval_i = Downcast<IntervalSet>(middle_interval);
// middle_interval is the subrange of the loop variable range for which a