You are viewing a plain-text version of this content. The canonical link to it was provided as a hyperlink in the original page.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/09/19 19:38:27 UTC

[tvm] 15/28: ad simplify optional

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch aluo/rebase-09192022-autotensorization
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit b16922f704d95d20403159a7389ce49918505d8f
Author: Andrew Zhao Luo <an...@gmail.com>
AuthorDate: Fri Sep 2 15:12:20 2022 -0700

    ad simplify optional
---
 src/te/autodiff/ad_simplify.cc | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/src/te/autodiff/ad_simplify.cc b/src/te/autodiff/ad_simplify.cc
index 26047e879e..240adf14b3 100644
--- a/src/te/autodiff/ad_simplify.cc
+++ b/src/te/autodiff/ad_simplify.cc
@@ -44,6 +44,7 @@
  * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
  *
  */
+#include <dmlc/optional.h>
 #include <tvm/arith/analyzer.h>
 #include <tvm/arith/int_solver.h>
 #include <tvm/runtime/registry.h>
@@ -53,7 +54,6 @@
 
 #include <iterator>
 #include <memory>
-#include <optional>
 #include <utility>
 
 #include "ad_utils.h"
@@ -629,9 +629,9 @@ class EliminateDivModMutator : public ExprMutator {
   }
 
  private:
-  std::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
-                                                   int64_t val, DivMode mode) {
-    using tresult = std::optional<std::pair<Var, Var>>;
+  dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
+                                                    int64_t val, DivMode mode) {
+    using tresult = dmlc::optional<std::pair<Var, Var>>;
 
     // Try to find the variables using the mutated expressions
     if (!e.same_as(mut)) {
@@ -1183,19 +1183,21 @@ PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const A
         return RemoveJacobianAndLiftNonzeroCondImpl(new_red, axis, vranges);
       }
 
+      PrimExpr new_outer_cond, new_reduce_cond;
       Array<PrimExpr> new_source = red->source;
 
       // Partially lift conditions from the reduce condition
-      auto [new_outer_cond, new_reduce_cond] =
+      std::tie(new_outer_cond, new_reduce_cond) =
           LiftConditionsThroughReduction(red->condition, red->axis, axis);
 
       // If it's not sum then we haven't yet lifted nonzeroness cond from the source
       if (!is_sum) {
+        PrimExpr outer_nz_cond, nz_cond, nz_source;
         auto nz = NonzeronessCondition(red->source[red->value_index]);
         // Append conditions from the reduction
-        PrimExpr nz_source = nz.value;
-        auto [outer_nz_cond, nz_cond] =
-            LiftConditionsThroughReduction(new_reduce_cond && nz.cond, red->axis, axis);
+        nz_cond = new_reduce_cond && nz.cond;
+        nz_source = nz.value;
+        std::tie(outer_nz_cond, nz_cond) = LiftConditionsThroughReduction(nz_cond, red->axis, axis);
         new_outer_cond = new_outer_cond && outer_nz_cond;
         new_source.Set(red->value_index, Select(nz_cond, nz_source, make_zero(nz_source.dtype())));
       }