You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/09/19 19:38:27 UTC
[tvm] 15/28: ad simplify optional
This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch aluo/rebase-09192022-autotensorization
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit b16922f704d95d20403159a7389ce49918505d8f
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Fri Sep 2 15:12:20 2022 -0700
ad_simplify: replace std::optional with dmlc::optional and C++17 structured bindings with std::tie
---
src/te/autodiff/ad_simplify.cc | 18 ++++++++++--------
1 file changed, 10 insertions(+), 8 deletions(-)
diff --git a/src/te/autodiff/ad_simplify.cc b/src/te/autodiff/ad_simplify.cc
index 26047e879e..240adf14b3 100644
--- a/src/te/autodiff/ad_simplify.cc
+++ b/src/te/autodiff/ad_simplify.cc
@@ -44,6 +44,7 @@
* Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
*
*/
+#include <dmlc/optional.h>
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_solver.h>
#include <tvm/runtime/registry.h>
@@ -53,7 +54,6 @@
#include <iterator>
#include <memory>
-#include <optional>
#include <utility>
#include "ad_utils.h"
@@ -629,9 +629,9 @@ class EliminateDivModMutator : public ExprMutator {
}
private:
- std::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
- int64_t val, DivMode mode) {
- using tresult = std::optional<std::pair<Var, Var>>;
+ dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
+ int64_t val, DivMode mode) {
+ using tresult = dmlc::optional<std::pair<Var, Var>>;
// Try to find the variables using the mutated expressions
if (!e.same_as(mut)) {
@@ -1183,19 +1183,21 @@ PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const A
return RemoveJacobianAndLiftNonzeroCondImpl(new_red, axis, vranges);
}
+ PrimExpr new_outer_cond, new_reduce_cond;
Array<PrimExpr> new_source = red->source;
// Partially lift conditions from the reduce condition
- auto [new_outer_cond, new_reduce_cond] =
+ std::tie(new_outer_cond, new_reduce_cond) =
LiftConditionsThroughReduction(red->condition, red->axis, axis);
// If it's not sum then we haven't yet lifted nonzeroness cond from the source
if (!is_sum) {
+ PrimExpr outer_nz_cond, nz_cond, nz_source;
auto nz = NonzeronessCondition(red->source[red->value_index]);
// Append conditions from the reduction
- PrimExpr nz_source = nz.value;
- auto [outer_nz_cond, nz_cond] =
- LiftConditionsThroughReduction(new_reduce_cond && nz.cond, red->axis, axis);
+ nz_cond = new_reduce_cond && nz.cond;
+ nz_source = nz.value;
+ std::tie(outer_nz_cond, nz_cond) = LiftConditionsThroughReduction(nz_cond, red->axis, axis);
new_outer_cond = new_outer_cond && outer_nz_cond;
new_source.Set(red->value_index, Select(nz_cond, nz_source, make_zero(nz_source.dtype())));
}