Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/16 15:45:43 UTC

[GitHub] [incubator-tvm] yzhliu opened a new pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

yzhliu opened a new pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078


   Co-authored-by: Sergei Grechanik <se...@gmail.com>
   
   This PR aims to remove the large intermediate Jacobian tensor in te.autodiff and to apply some optimizations such as inlining. We have made a series of PRs for https://discuss.tvm.ai/t/rfc-bring-in-tensor-expression-autodiff and this will be the last one.
   
   @sgrechanik-h @MarisaKirisame @junrushao1994 @xqdan @tqchen Please take a look.
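To make the motivation concrete, here is a minimal standalone sketch (plain C++, not TVM code) for the element-wise case y[i] = x[i] * x[i]. The function names GradViaJacobian and GradDirect, the choice of function, and the parameter name head (the incoming gradient) are all assumptions made for illustration. The first version materializes the full n x n Jacobian; the second uses the fact that the Jacobian is non-zero only where i == j, so the reduction collapses and no intermediate tensor is needed.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustration only: y[i] = x[i] * x[i], so dy[i]/dx[j] = 2 * x[i] when i == j, else 0.

// Naive form: build the dense Jacobian, then contract it with the incoming gradient.
std::vector<double> GradViaJacobian(const std::vector<double>& x,
                                    const std::vector<double>& head) {
  assert(x.size() == head.size());
  const std::size_t n = x.size();
  std::vector<std::vector<double>> jac(n, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < n; ++i) {
    jac[i][i] = 2.0 * x[i];  // only the diagonal is non-zero
  }
  std::vector<double> grad(n, 0.0);
  for (std::size_t j = 0; j < n; ++j) {
    for (std::size_t i = 0; i < n; ++i) {
      grad[j] += head[i] * jac[i][j];
    }
  }
  return grad;
}

// Simplified form: substituting the non-zero condition i == j removes the
// reduction over i, so the Jacobian never has to be stored.
std::vector<double> GradDirect(const std::vector<double>& x,
                               const std::vector<double>& head) {
  assert(x.size() == head.size());
  std::vector<double> grad(x.size());
  for (std::size_t j = 0; j < x.size(); ++j) {
    grad[j] = head[j] * 2.0 * x[j];
  }
  return grad;
}
```

Both functions return the same gradient; the first needs O(n^2) memory and work, the second O(n).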
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] yzhliu commented on pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
yzhliu commented on pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#issuecomment-673136830


   @tqchen I have made them static now.





[GitHub] [incubator-tvm] MarisaKirisame commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
MarisaKirisame commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r463942670



##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually depends
+ * on only a small portion of the inputs. For example, an element-wise function has a
+ * one-to-one mapping between input tensor and output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * Thereby we solve the linear equations \beta = A \alpha,
+ * as well as the linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via NonzeronessCondition,
+ * replace the compute expression with solved new axes, and create a selection node
+ * (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition
+// transform expression to cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero only if
+    // both arguments are nonzero, so we combine the conditions with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {

Review comment:
       Your approach is a pure conjunction. I was suggesting that a disjunction of conjunctions might allow more optimization. It is fine if you don't do it.
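One possible reading of this suggestion, sketched outside TVM: keep the non-zero condition as a disjunction of conjunctions rather than a single conjunction. The types Conjunction and Dnf and the functions NonzeroOfAdd, NonzeroOfMul and CommonAtoms below are hypothetical names that do not exist in TVM or in this PR, and atoms are plain strings purely for illustration. The sketch follows the rules stated in the quoted code (Or for add-like ops, And for mul-like ops) and shows which atoms a pure-conjunction form can still keep.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical types for illustration only.
using Conjunction = std::set<std::string>;  // atoms joined by &&
using Dnf = std::vector<Conjunction>;       // conjunctions joined by ||

// a + b may be non-zero when either operand is: concatenate the alternatives.
Dnf NonzeroOfAdd(const Dnf& a, const Dnf& b) {
  Dnf out = a;
  out.insert(out.end(), b.begin(), b.end());
  return out;
}

// a * b may be non-zero only when both operands are: merge every pair of alternatives.
Dnf NonzeroOfMul(const Dnf& a, const Dnf& b) {
  Dnf out;
  for (const Conjunction& ca : a) {
    for (const Conjunction& cb : b) {
      Conjunction merged = ca;
      merged.insert(cb.begin(), cb.end());
      out.push_back(merged);
    }
  }
  return out;
}

// Atoms present in every alternative, i.e. the part that a single
// conjunction could still express without losing correctness.
Conjunction CommonAtoms(const Dnf& d) {
  if (d.empty()) {
    return {};
  }
  Conjunction common = d.front();
  for (const Conjunction& c : d) {
    Conjunction kept;
    for (const std::string& atom : common) {
      if (c.count(atom) > 0) {
        kept.insert(atom);
      }
    }
    common = kept;
  }
  return common;
}
```

For a sum of two terms guarded by "i == j" and "i == k", the disjunctive form keeps both alternatives, while the common-atom conjunction is empty; that gap is the extra information a disjunction of conjunctions would preserve.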

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -63,6 +63,7 @@ namespace te {
 using arith::DivMode;
 using arith::kFloorDiv;
 using arith::kTruncDiv;
+using arith::ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE;
 
 template <class K, class V>
 Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {

Review comment:
       Same as Merge... let's move them both?







[GitHub] [incubator-tvm] yzhliu commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
yzhliu commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r467257282



##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+struct FactorOutAtomicFormulasResult {

Review comment:
       I see. I'll keep that in mind and maybe change it later.







[GitHub] [incubator-tvm] yzhliu commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
yzhliu commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r470193800



##########
File path: include/tvm/node/container.h
##########
@@ -1439,6 +1427,22 @@ class Map : public ObjectRef {
   MapNode* GetMapNode() const { return static_cast<MapNode*>(data_.get()); }
 };
 
+/*!
+ * \brief Merge two Maps.
+ * \param lhs the first Map to merge.
+ * \param rhs the second Map to merge.
+ * \return The merged Map. The original Maps are left unchanged.
+ */
+template <typename K, typename V,
+          typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
+          typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
+static Map<K, V> Merge(Map<K, V> lhs, const Map<K, V>& rhs) {

Review comment:
       @tqchen My bad, I forgot to when copy-pasting. Modified.
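
As a rough illustration of the by-value merge in the hunk above, here is a sketch that uses std::map in place of tvm::Map so that it compiles without TVM. MergeSketch is a hypothetical name, and the sketch assumes that entries from the second map overwrite entries of the first on duplicate keys.

```cpp
#include <cassert>
#include <map>
#include <string>

// `lhs` is taken by value, so the caller's map is copied and stays untouched.
// Entries from `rhs` are written into the copy, replacing duplicate keys
// (an assumption of this sketch).
template <typename K, typename V>
std::map<K, V> MergeSketch(std::map<K, V> lhs, const std::map<K, V>& rhs) {
  for (const auto& p : rhs) {
    lhs[p.first] = p.second;
  }
  return lhs;  // a by-value parameter is moved on return, no std::move needed
}
```

Taking the first argument by value lets callers pass a temporary without an extra copy, while an lvalue argument is copied once and left unchanged.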







[GitHub] [incubator-tvm] yzhliu commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
yzhliu commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r463929804



##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually depends
+ * on only a small portion of the inputs. For example, an element-wise function has a one-to-one
+ * mapping between the input tensor and the output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * Thereby we solve the linear equations \beta = A \alpha,
+ * as well as the linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via
+ * NonzeronessCondition, replace the compute expression with the solved new axes,
+ * and create a selection node (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition:
+// transforms an expression into the form `cond ? value : 0`.
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result can be nonzero only if
+    // both arguments are nonzero, so we combine the conditions with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {
+  std::vector<PrimExpr> atomic_formulas;
+  PrimExpr rest;
+
+  PrimExpr to_expr() const {
+    PrimExpr res = rest;
+    for (const PrimExpr& e : atomic_formulas) {
+      res = And(e, res);
+    }
+    return res;
+  }
+
+  Array<PrimExpr> to_array() const {
+    Array<PrimExpr> res = atomic_formulas;
+    res.push_back(rest);
+    return res;
+  }
+};
+
+// The implementation of FactorOutAtomicFormulas
+class FactorOutAtomicFormulasFunctor
+    : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> {
+ public:
+  result_type Atomic_(const PrimExpr& e) {
+    // For atomic expressions the result is the expr itself with True as the residual
+    return {{e}, make_const(e.dtype(), 1)};
+  }
+
+  // This is basically the list of expression kinds that are considered atomic
+  result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    // Select can be rewritten through other logical ops
+    PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value);
+    return VisitExpr(expr);
+  }
+
+  result_type VisitExpr_(const NotNode* op) final {
+    // Not should be moved down
+    if (const OrNode* or_expr = op->a.as<OrNode>()) {
+      PrimExpr expr = !or_expr->a && !or_expr->b;
+      return VisitExpr(expr);
+    } else if (const AndNode* and_expr = op->a.as<AndNode>()) {
+      PrimExpr expr = !and_expr->a || !and_expr->b;
+      return VisitExpr(expr);
+    } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) {
+      PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) &&
+                       (sel_expr->condition || !sel_expr->false_value));
+      return VisitExpr(expr);
+    }
+    return Atomic_(GetRef<PrimExpr>(op));
+  }
+
+  result_type VisitExpr_(const AndNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    // For the And case we return the union of the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
+    std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+    std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    // And the residuals are combined with &&
+    return {res, res_a.rest && res_b.rest};
+  }
+
+  result_type VisitExpr_(const MulNode* op) final {
+    // Since we work with bools, for multiplication we do the same thing as for And
+    PrimExpr e_and = op->a && op->b;
+    return VisitExpr(e_and);
+  }
+
+  result_type VisitExpr_(const OrNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{
+        res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()};
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{
+        res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()};
+
+    // For the Or case we intersect the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
+    for (const auto& res_b_formula : res_b_set) {
+      if (res_a_set.count(res_b_formula)) {
+        res_set.insert(res_b_formula);
+      }
+    }
+
+    // Computing the residual is more complex: we have to compute the sets of atomic formulas
+    // which are left behind, and then combine them with the residuals into the new residual.
+    std::vector<PrimExpr> new_cond_a;
+    new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_a_set) {
+      if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
+    }
+
+    std::vector<PrimExpr> new_cond_b;
+    new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_b_set) {
+      if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
+    }
+
+    res_a.atomic_formulas = std::move(new_cond_a);
+    res_b.atomic_formulas = std::move(new_cond_b);
+
+    PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    return {res, new_rest};
+  }
+};
+
+// Transform the given formula into a conjunction of atomic formulas (represented as an array)
+// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b,
+// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level.
+FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) {
+  CHECK(e.dtype().is_bool());
+  return FactorOutAtomicFormulasFunctor().VisitExpr(e);
+}
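
The And and Or cases of FactorOutAtomicFormulasFunctor above come down to set operations on the atomic formulas. The following is a minimal sketch, not the TVM implementation: atoms are plain strings, and Atoms, FactorAnd, FactorOr and LeftBehind are hypothetical names introduced for the illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <set>
#include <string>

using Atoms = std::set<std::string>;

// And: every atom required by either side is required by the conjunction,
// so the factored set is the union.
Atoms FactorAnd(const Atoms& a, const Atoms& b) {
  Atoms out;
  std::set_union(a.begin(), a.end(), b.begin(), b.end(), std::inserter(out, out.end()));
  return out;
}

// Or: only atoms required by both sides can be factored out,
// so the factored set is the intersection.
Atoms FactorOr(const Atoms& a, const Atoms& b) {
  Atoms out;
  std::set_intersection(a.begin(), a.end(), b.begin(), b.end(),
                        std::inserter(out, out.end()));
  return out;
}

// Atoms of one Or-branch that were not factored out stay in the residual.
Atoms LeftBehind(const Atoms& side, const Atoms& common) {
  Atoms out;
  std::set_difference(side.begin(), side.end(), common.begin(), common.end(),
                      std::inserter(out, out.end()));
  return out;
}
```

For example, (i == j && k < 3) || (i == j && m >= 0) factors into the atom i == j with the residual (k < 3 || m >= 0).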
+
+struct EliminateDivModResult {
+  PrimExpr expr;
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+};
+
+inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncmod(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floormod(a, b);
+  }
+}
+
+inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncdiv(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floordiv(a, b);
+  }
+}
+
+class EliminateDivModMutator : public ExprMutator {
+ public:
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+
+  explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {}
+
+  virtual PrimExpr VisitExpr_(const DivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == -(x/c) for truncated division
+        return make_zero(op->dtype) -
+               VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return truncdiv(mutated_a, op->b);
+      }
+    }
+
+    return truncdiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const ModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == x % c for truncated division
+        return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return truncmod(mutated_a, op->b);
+      }
+    }
+
+    return truncmod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorDivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == (-x) / c for flooring division
+        return VisitExpr(
+            floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return floordiv(mutated_a, op->b);
+      }
+    }
+
+    return floordiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == -(-x % c) for flooring division
+        return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a,
+                                                         make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return floormod(mutated_a, op->b);
+      }
+    }
+
+    return floormod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+ private:
+  dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
+                                                    int64_t val, DivMode mode) {
+    using tresult = dmlc::optional<std::pair<Var, Var>>;
+
+    // Try to find the variables using the mutated expressions
+    if (!e.same_as(mut)) {
+      auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val));
+      if (it != expr_to_vars_.end()) {
+        return tresult(it->second);
+      }
+    }
+
+    PrimExpr val_e = make_const(e.dtype(), val);
+    idx_ += 1;
+
+    // Convert `ranges` to IntSets
+    std::unordered_map<const VarNode*, IntSet> var_intsets;
+    for (const auto& p : ranges) {
+      var_intsets[p.first.get()] = IntSet::FromRange(p.second);
+    }
+
+    // Infer ranges for the expressions we want to replace with variables
+    Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+    Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+
+    // We don't want to add unbounded variables
+    if (!div_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+    if (!mod_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+
+    // Create new variables for the expressions
+    auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype());
+    auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype());
+
+    new_variables.push_back(div);
+    new_variables.push_back(mod);
+
+    // Note that we have to perform substitution to mut because mut may contain new variables
+    substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode));
+    substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode));
+
+    ranges.Set(div, div_range);
+    ranges.Set(mod, mod_range);
+
+    // This additional condition works as a definition for the new variables
+    conditions.push_back(mut == div * val_e + mod);
+
+    if (!analyzer_.CanProve(mod_range->extent <= val_e)) {
+      // Since we use the C/C++ definition of mod, there may be multiple values of `mod`
+      // satisfying the added condition if the expr `e` may change its sign, so we
+      // have to add another condition.
+      LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because "
+                   << ModImpl(e, val_e, mode) << "  may change its sign";
+      conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0));
+    }
+
+    auto p = std::make_pair(div, mod);
+    expr_to_vars_[std::make_tuple(mode, e, val)] = p;
+    if (!e.same_as(mut)) {
+      expr_to_vars_[std::make_tuple(mode, mut, val)] = p;
+    }
+    return tresult(p);
+  }
+
+  class TupleEqual_ {
+   public:
+    bool operator()(const std::tuple<DivMode, PrimExpr, int64_t>& lhs,
+                    const std::tuple<DivMode, PrimExpr, int64_t>& rhs) const {
+      return std::get<0>(lhs) == std::get<0>(rhs) &&
+             tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) &&
+             std::get<2>(lhs) == std::get<2>(rhs);
+    }
+  };
+
+  class TupleHasher_ {
+   public:
+    size_t operator()(const std::tuple<DivMode, PrimExpr, int64_t>& key) const {
+      return ((std::hash<int>()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >>
+              1) ^
+             (std::hash<int64_t>()(std::get<2>(key)) << 1);
+    }
+  };
+
+  // A counter for naming new variables
+  int idx_{0};
+  // A map from triples (mode, e, n) of a division mode, an expr and a number to pairs of
+  // new vars (div, mod) such that `div = e / n` and `mod = e % n` under that mode
+  std::unordered_map<std::tuple<DivMode, PrimExpr, int64_t>, std::pair<Var, Var>, TupleHasher_,
+                     TupleEqual_>
+      expr_to_vars_;
+  arith::Analyzer analyzer_;
+};
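
The negative-divisor rewrites quoted in the four visitors above, and the defining condition e == div * c + mod, can be checked numerically on plain integers. This is a standalone sketch: TruncDiv, TruncMod, FloorDiv, FloorMod and CheckIdentities are hypothetical helpers written for the illustration, not TVM functions.

```cpp
#include <cassert>
#include <cstdint>

// Truncated division and remainder: the built-in C++ semantics.
int64_t TruncDiv(int64_t a, int64_t b) { return a / b; }
int64_t TruncMod(int64_t a, int64_t b) { return a % b; }

// Floored division: round the quotient toward negative infinity.
int64_t FloorDiv(int64_t a, int64_t b) {
  int64_t q = a / b;
  int64_t r = a % b;
  return (r != 0 && ((r < 0) != (b < 0))) ? q - 1 : q;
}
int64_t FloorMod(int64_t a, int64_t b) { return a - FloorDiv(a, b) * b; }

// Returns true if the four rewrites and the defining condition hold for a
// positive constant c and every a in [-range, range].
bool CheckIdentities(int64_t c, int64_t range) {
  for (int64_t a = -range; a <= range; ++a) {
    if (TruncDiv(a, -c) != -TruncDiv(a, c)) return false;    // x / -c == -(x / c)
    if (TruncMod(a, -c) != TruncMod(a, c)) return false;     // x % -c == x % c
    if (FloorDiv(a, -c) != FloorDiv(-a, c)) return false;    // x / -c == (-x) / c
    if (FloorMod(a, -c) != -FloorMod(-a, c)) return false;   // x % -c == -(-x % c)
    if (a != TruncDiv(a, c) * c + TruncMod(a, c)) return false;
    if (a != FloorDiv(a, c) * c + FloorMod(a, c)) return false;
  }
  return true;
}
```

The two modes differ only when the operands have opposite signs: for -7 and 2, truncation gives quotient -3 and remainder -1, while flooring gives quotient -4 and remainder 1.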
+
+// Replace every subexpr of the form e/const and e % const with a new variable.
+// Syntactically equal expressions will be mapped to the same variable.
+EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map<Var, Range> ranges) {
+  EliminateDivModResult res;
+  EliminateDivModMutator mutator(ranges);
+  res.expr = mutator(expr);
+  res.conditions = std::move(mutator.conditions);
+  res.new_variables = std::move(mutator.new_variables);
+  res.substitution = std::move(mutator.substitution);
+  res.ranges = std::move(mutator.ranges);
+  return res;
+}
+
+arith::IntConstraintsTransform EliminateDivModFromDomainConditions(
+    const arith::IntConstraints& domain) {
+  auto elim_res = EliminateDivMod(All(domain->relations), domain->ranges);
+
+  Map<Var, Range> new_vranges = elim_res.ranges;
+  Array<Var> new_axis = Concat(domain->variables, elim_res.new_variables);
+  PrimExpr new_cond = elim_res.expr && All(elim_res.conditions);
+
+  arith::IntConstraints new_domain(new_axis, new_vranges,
+                                   FactorOutAtomicFormulas(new_cond).to_array());
+
+  Map<Var, PrimExpr> src_to_dst;
+  Map<Var, PrimExpr> dst_to_src = elim_res.substitution;
+  for (const Var& v : domain->variables) {
+    src_to_dst.Set(v, v);
+    dst_to_src.Set(v, v);
+  }
+
+  return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src);
+}
+
+// Build the identity transformation for the given domain.
+inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) {
+  Map<Var, PrimExpr> identity_map;
+  for (const Var& v : domain->variables) {
+    identity_map.Set(v, v);
+  }
+  return arith::IntConstraintsTransform(domain, domain, identity_map, identity_map);
+}
+
+// Simplify an iteration domain.
+arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains,
+                                              bool eliminate_div_mod) {
+  arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains);
+
+  if (eliminate_div_mod) {
+    transf = transf + EliminateDivModFromDomainConditions(transf->dst);
+  }
+
+  // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably
+  // should find a better terminating criterion (like stop when the domain volume stops decreasing)
+  // Also 2 steps seems to be slightly better than 3
+  for (size_t i = 0; i < 2; ++i) {
+    arith::IntConstraintsTransform tr = arith::SolveLinearEquations(transf->dst);
+    transf = transf + tr;
+    // TODO(sgrechanik-h): This helps for some artificial examples, however I'm not sure about
+    // enabling it in general. The problem it solves is propagating equalities of outer vars.
+    // tr = AddOuterVariablesIntoDomain(transf->dst);
+    tr = arith::SolveInequalitiesDeskewRange(transf->dst);
+    transf = transf + tr;
+  }
+
+  return transf;
+}
+
+// Use the condition of a reduction op to simplify its domain (axis)
+PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map<Var, Range>& outer_vranges) {
+  if (const ReduceNode* red = expr.as<ReduceNode>()) {
+    Array<Var> vars = IterVarsToVars(red->axis);
+    Map<Var, Range> vranges = Merge(outer_vranges, IterVarsToMap(red->axis));
+    Array<PrimExpr> relations = FactorOutAtomicFormulas(red->condition).to_array();
+
+    arith::IntConstraints domain(vars, vranges, relations);
+    auto res = SimplifyDomain(domain);
+
+    Array<PrimExpr> new_source;
+    for (const PrimExpr& src : red->source) {
+      new_source.push_back(Substitute(src, res->src_to_dst));
+    }
+
+    Array<IterVar> new_axis = IterVarsFromMap(res->dst->variables, res->dst->ranges, kCommReduce);
+
+    // Perform simplification mainly to remove a possibly empty reduction.
+    arith::Analyzer analyzer;
+    return analyzer.Simplify(
+        Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index),
+        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+  } else {
+    return expr;
+  }
+}
+
+// Extract from cond an implication of cond not containing vars
+std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
+    const PrimExpr& cond, const std::unordered_set<const VarNode*>& vars) {
+  CHECK(cond.dtype().is_bool()) << "The type of cond must be bool";
+  // TODO(sgrechanik-h): NOT

Review comment:
       In my understanding, it is not straightforward to handle the NOT node separately here, because the false branch of (!pair.a) would also contain the reduction (instead of zero). I'm not sure whether it would provide any benefit. @sergei-grechanik, could you comment?







[GitHub] [incubator-tvm] tqchen commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r469674445



##########
File path: include/tvm/node/container.h
##########
@@ -1439,6 +1427,22 @@ class Map : public ObjectRef {
   MapNode* GetMapNode() const { return static_cast<MapNode*>(data_.get()); }
 };
 
+/*!
+ * \brief Merge two Maps.
+ * \param lhs the first Map to merge.
+ * \param rhs the second Map to merge.
+ * \return The merged Map. Original Maps are kept unchanged.
+ */
+template <typename K, typename V,
+          typename = typename std::enable_if<std::is_base_of<ObjectRef, K>::value>::type,
+          typename = typename std::enable_if<std::is_base_of<ObjectRef, V>::value>::type>
+static Map<K, V> Merge(Map<K, V> lhs, const Map<K, V>& rhs) {

Review comment:
       Change `static` to `inline`, since this function is defined in a header.
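
       A minimal sketch of this suggestion, using std::map as a stand-in for tvm::Map (an assumption made so the snippet compiles on its own; the header name in the comment is hypothetical). Marking a header-defined function `inline` lets every translation unit share one definition, whereas `static` gives each including file its own internal-linkage copy.

```cpp
// merge_util.h (hypothetical header name, for illustration only)
#include <cassert>
#include <map>
#include <string>

// `inline` keeps external linkage with a single shared definition across
// translation units; `static` would create a private copy per including file.
// std::map stands in for tvm::Map here.
template <typename K, typename V>
inline std::map<K, V> Merge(std::map<K, V> lhs, const std::map<K, V>& rhs) {
  for (const auto& kv : rhs) {
    lhs[kv.first] = kv.second;  // entries from rhs override those in lhs
  }
  return lhs;  // lhs was taken by value, so the caller's maps are untouched
}

// Usage: rhs wins on key collisions and the inputs are left unchanged.
inline void MergeExample() {
  std::map<std::string, int> a{{"x", 1}, {"y", 2}};
  std::map<std::string, int> b{{"y", 20}, {"z", 3}};
  std::map<std::string, int> merged = Merge(a, b);
  assert(merged.size() == 3);
  assert(merged["y"] == 20);
  assert(a["y"] == 2);
}
```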







[GitHub] [incubator-tvm] tqchen commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r469615735



##########
File path: include/tvm/runtime/container.h
##########
@@ -847,6 +847,19 @@ class Array : public ObjectRef {
  public:
   // Array's own methods
 
+  /*!
+   * \brief Concat two Arrays.
+   * \param lhs first Array to be concatenated.
+   * \param rhs second Array to be concatenated.
+   * \return The concatenated Array. Original Arrays are kept unchanged.
+   */
+  static Array<T> Concat(Array<T> lhs, const Array<T>& rhs) {

Review comment:
       We can make it a global template function, so you can call Concat(lhs, rhs) without having to refer to the original type. Note that the type signature should restrict the function to the specific type, so that we don't risk over-generalization.
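
       A rough sketch of what such a free function could look like. `ObjectRef`, `Var` and `Array` below are simplified stand-ins written for this example, not the real TVM classes; the point is only that the `Array<T>` parameter type plus an `enable_if` guard keep the template from matching unrelated containers.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <vector>

// Simplified stand-ins for the TVM types (assumptions for illustration).
struct ObjectRef {};
struct Var : ObjectRef {
  explicit Var(int id) : id(id) {}
  int id;
};

template <typename T>
struct Array {
  std::vector<T> data;
  void push_back(const T& x) { data.push_back(x); }
  std::size_t size() const { return data.size(); }
};

// Free function template: only Array<T> arguments match, and only when T
// derives from ObjectRef, so the overload set is not over-generalized.
template <typename T,
          typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
inline Array<T> Concat(Array<T> lhs, const Array<T>& rhs) {
  for (const T& x : rhs.data) {
    lhs.push_back(x);
  }
  return lhs;  // lhs is a by-value copy; the original arrays stay unchanged
}

// Usage: the element type is deduced, so no Array<Var>:: qualification is needed.
inline void ConcatExample() {
  Array<Var> a;
  a.push_back(Var(1));
  Array<Var> b;
  b.push_back(Var(2));
  Array<Var> c = Concat(a, b);
  assert(c.size() == 2);
  assert(a.size() == 1);
}
```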







[GitHub] [incubator-tvm] yzhliu commented on pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
yzhliu commented on pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#issuecomment-674181792


   @tqchen it's ready





[GitHub] [incubator-tvm] MarisaKirisame commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
MarisaKirisame commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r456881913



##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by autodiff.
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0], 3), 0)) {

Review comment:
       refactor the magic number 3.

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by autodiff.

Review comment:
       All in all, I found this file confusing. Can you describe what compiler pass or simplification it is doing?
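
       For context, a hedged worked example of the kind of simplification at stake, for the element-wise case (the symbols $H$, $J$ and $G$ are names chosen here for illustration, not taken from the PR). With $\mathrm{Out}_i = f(\mathrm{In}_i)$ the Jacobian is diagonal, so contracting it with the head gradient $H$ collapses the reduction to a single term, and the Jacobian tensor never has to be materialized:

```latex
J_{ij} = \frac{\partial\,\mathrm{Out}_i}{\partial\,\mathrm{In}_j} = f'(\mathrm{In}_i)\,\delta_{ij},
\qquad
G_j = \sum_i H_i\, J_{ij} = \sum_i H_i\, f'(\mathrm{In}_i)\,\delta_{ij} = H_j\, f'(\mathrm{In}_j).
```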

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by autodiff.
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0], 3), 0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result = analyzer.Simplify(combiner->result[0], 3);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index], 3), 0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, 3);

Review comment:
       same everywhere - make it a constant or an argument
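
       A small sketch of the "named constant" option. The constant name and the toy function below are hypothetical; the assumption that a step count of 3 means "rewrite, canonical, rewrite" is inferred from the name `ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE` that appears in a later revision quoted in this thread.

```cpp
#include <cassert>
#include <string>

// Hypothetical named constant replacing the bare literal 3 at the call sites.
constexpr int kSimplifyRewriteCanonicalRewrite = 3;

// Toy stand-in for an analyzer's Simplify(expr, steps). It does not rewrite
// any expression; it only reports which passes a given step count would run,
// assuming the passes alternate and start with "rewrite".
inline std::string DescribeSimplifySteps(int steps) {
  std::string passes;
  for (int i = 0; i < steps; ++i) {
    if (!passes.empty()) passes += ",";
    passes += (i % 2 == 0) ? "rewrite" : "canonical";
  }
  return passes;
}

// Usage: the call site now documents its intent instead of passing a bare 3.
inline void SimplifyStepsExample() {
  assert(DescribeSimplifySteps(kSimplifyRewriteCanonicalRewrite) ==
         "rewrite,canonical,rewrite");
}
```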







[GitHub] [incubator-tvm] sergei-grechanik commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
sergei-grechanik commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r464104628



##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * Jacobian tensor is sparse because one output element usually relates
+ * to a small portion of the inputs. For example, element-wise function has a one-to-one mapping
+ * between input tensor and output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix,
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * Thereby we solve linear equations of \beta = A \alpha,
+ * as well as linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implement-wise, we extract the equations in the compute definition via NonzeronessCondition,
+ * replace the compute expression with solved new axes, and create a selection node
+ * (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+using arith::ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition
+// transform expression to cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+    return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(op_if_then_else_)) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero if
+    // both the arguments are nonzero, so we combine with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+  const Op& op_if_then_else_ = Op::Get("tir.if_then_else");
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {
+  std::vector<PrimExpr> atomic_formulas;
+  PrimExpr rest;
+
+  PrimExpr to_expr() const {
+    PrimExpr res = rest;
+    for (const PrimExpr& e : atomic_formulas) {
+      res = And(e, res);
+    }
+    return res;
+  }
+
+  Array<PrimExpr> to_array() const {
+    Array<PrimExpr> res = atomic_formulas;
+    res.push_back(rest);
+    return res;
+  }
+};
+
+// The implementation of FactorOutAtomicFormulas
+class FactorOutAtomicFormulasFunctor
+    : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> {
+ public:
+  result_type Atomic_(const PrimExpr& e) {
+    // For atomic expressions the result is the expr itself with True as the residual
+    return {{e}, make_const(e.dtype(), 1)};
+  }
+
+  // This is basically the list of expression kinds that are considered atomic
+  result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    // Select can be rewritten through other logical ops
+    PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value);
+    return VisitExpr(expr);
+  }
+
+  result_type VisitExpr_(const NotNode* op) final {
+    // Not should be moved down
+    if (const OrNode* or_expr = op->a.as<OrNode>()) {
+      PrimExpr expr = !or_expr->a && !or_expr->b;
+      return VisitExpr(expr);
+    } else if (const AndNode* and_expr = op->a.as<AndNode>()) {
+      PrimExpr expr = !and_expr->a || !and_expr->b;
+      return VisitExpr(expr);
+    } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) {
+      PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) &&
+                       (sel_expr->condition || !sel_expr->false_value));
+      return VisitExpr(expr);
+    }
+    return Atomic_(GetRef<PrimExpr>(op));
+  }
+
+  result_type VisitExpr_(const AndNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    // For the And case we return the union of the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
+    std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+    std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    // And the residuals are combined with &&
+    return {res, res_a.rest && res_b.rest};
+  }
+
+  result_type VisitExpr_(const MulNode* op) final {
+    // Since we work with bools, for multiplication we do the same thing as for And
+    PrimExpr e_and = op->a && op->b;
+    return VisitExpr(e_and);
+  }
+
+  result_type VisitExpr_(const OrNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{
+        res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()};
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{
+        res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()};
+
+    // For the Or case we intersect the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
+    for (const auto& res_b_formula : res_b_set) {
+      if (res_a_set.count(res_b_formula)) {
+        res_set.insert(res_b_formula);
+      }
+    }
+
+    // Computing the residual is more complex: we have to compute the sets of atomic formulas
+    // which are left behind, and then combine them with the residuals into the new residual.
+    std::vector<PrimExpr> new_cond_a;
+    new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_a_set) {
+      if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
+    }
+
+    std::vector<PrimExpr> new_cond_b;
+    new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_b_set) {
+      if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
+    }
+
+    res_a.atomic_formulas = std::move(new_cond_a);
+    res_b.atomic_formulas = std::move(new_cond_b);
+
+    PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    return {res, new_rest};
+  }
+};
+
+// Transform the given formula into a conjunction of atomic formulas (represented as an array)
+// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b,
+// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level.
+FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) {
+  CHECK(e.dtype().is_bool());
+  return FactorOutAtomicFormulasFunctor().VisitExpr(e);
+}
+
+struct EliminateDivModResult {
+  PrimExpr expr;
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+};
+
+inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncmod(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floormod(a, b);
+  }
+}
+
+inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncdiv(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floordiv(a, b);
+  }
+}
+
+class EliminateDivModMutator : public ExprMutator {
+ public:
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+
+  explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {}
+
+  virtual PrimExpr VisitExpr_(const DivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == -(x/c) for truncated division
+        return make_zero(op->dtype) -
+               VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return truncdiv(mutated_a, op->b);
+      }
+    }
+
+    return truncdiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const ModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == x % c for truncated division
+        return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return truncmod(mutated_a, op->b);
+      }
+    }
+
+    return truncmod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorDivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == (-x) / c for flooring division
+        return VisitExpr(
+            floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return floordiv(mutated_a, op->b);
+      }
+    }
+
+    return floordiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == -(-x % c) for flooring division
+        return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a,
+                                                         make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return floormod(mutated_a, op->b);
+      }
+    }
+
+    return floormod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+ private:
+  dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
+                                                    int64_t val, DivMode mode) {
+    using tresult = dmlc::optional<std::pair<Var, Var>>;
+
+    // Try to find the variables using the mutated expressions
+    if (!e.same_as(mut)) {
+      auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val));
+      if (it != expr_to_vars_.end()) {
+        return tresult(it->second);
+      }
+    }
+
+    PrimExpr val_e = make_const(e.dtype(), val);
+    idx_ += 1;
+
+    // Convert `ranges` to IntSets
+    std::unordered_map<const VarNode*, IntSet> var_intsets;
+    for (const auto& p : ranges) {
+      var_intsets[p.first.get()] = IntSet::FromRange(p.second);
+    }
+
+    // Infer ranges for the expressions we want to replace with variables
+    Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+    Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+
+    // We don't want to add unbounded variables
+    if (!div_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+    if (!mod_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+
+    // Create new variables for the expressions
+    auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype());
+    auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype());
+
+    new_variables.push_back(div);
+    new_variables.push_back(mod);
+
+    // Note that we have to perform substitution to mut because mut may contain new variables
+    substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode));
+    substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode));
+
+    ranges.Set(div, div_range);
+    ranges.Set(mod, mod_range);
+
+    // This additional condition works as a definition for the new variables
+    conditions.push_back(mut == div * val_e + mod);
+
+    if (!analyzer_.CanProve(mod_range->extent <= val_e)) {
+      // Since we use the C/C++ definition of mod, there may be multiple values of `mod`

Review comment:
       Replace "Since" with "If" because this procedure now works for both definitions (and the floor div shouldn't cause this problem anyway).
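
A small standalone sketch of why the wording matters (plain C++ written for illustration, not TVM code; `floordiv_i`, `floormod_i` and `count_solutions` are hypothetical helper names). It counts how many `(div, mod)` pairs satisfy the defining condition `x == div * c + mod` when `mod` is constrained only by a range. Under truncated semantics the remainder of a possibly negative `x` lies in `[-(c-1), c-1]`, so the definition alone can be ambiguous; under floor semantics it lies in `[0, c-1]` and the pair is unique.

```cpp
#include <cassert>
#include <cstdint>

// Floor division on plain integers: rounds the quotient toward negative infinity.
int64_t floordiv_i(int64_t a, int64_t b) {
  assert(b != 0);
  int64_t q = a / b;  // C++ '/' truncates toward zero
  if (a % b != 0 && ((a < 0) != (b < 0))) --q;
  return q;
}

// Floor modulo: the result takes the sign of the divisor.
int64_t floormod_i(int64_t a, int64_t b) { return a - floordiv_i(a, b) * b; }

// Count the values m in [lo, hi] for which some integer d gives x == d * c + m.
int count_solutions(int64_t x, int64_t c, int64_t lo, int64_t hi) {
  assert(c > 0);
  int n = 0;
  for (int64_t m = lo; m <= hi; ++m) {
    if ((x - m) % c == 0) ++n;  // x - m must be an exact multiple of c
  }
  return n;
}
```

For `x = -5` and `c = 3`, `count_solutions(-5, 3, -2, 2)` finds two candidates (`mod = -2` and `mod = 1`), while `count_solutions(-5, 3, 0, 2)` finds only `mod = 1`, which is also what `floormod_i(-5, 3)` returns. This appears to be the situation the `mod_range->extent <= val_e` check guards against: the extra condition is needed only when the inferred range of `mod` is wider than the divisor.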

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because each output element usually depends on
+ * only a small portion of the inputs. For example, an element-wise function has a one-to-one
+ * mapping between the input tensor and the output tensor, so its Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ), in which A is a matrix and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * We therefore solve the linear equations \beta = A \alpha,
+ * together with the linear inequalities given by their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via
+ * NonzeronessCondition, replace the compute expression with the solved new axes, and create
+ * a selection node (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
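+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>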
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+using arith::ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return original;
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges, create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition:
+// transforms an expression into the form (cond ? value : 0)
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+    return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions of both branches
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(op_if_then_else_)) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero only if
+    // both arguments are nonzero, so we combine the conditions with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+  const Op& op_if_then_else_ = Op::Get("tir.if_then_else");
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {
+  std::vector<PrimExpr> atomic_formulas;
+  PrimExpr rest;
+
+  PrimExpr to_expr() const {
+    PrimExpr res = rest;
+    for (const PrimExpr& e : atomic_formulas) {
+      res = And(e, res);
+    }
+    return res;
+  }
+
+  Array<PrimExpr> to_array() const {
+    Array<PrimExpr> res = atomic_formulas;
+    res.push_back(rest);
+    return res;
+  }
+};
+
+// The implementation of FactorOutAtomicFormulas
+class FactorOutAtomicFormulasFunctor
+    : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> {
+ public:
+  result_type Atomic_(const PrimExpr& e) {
+    // For atomic expressions the result is the expr itself with True as the residual
+    return {{e}, make_const(e.dtype(), 1)};
+  }
+
+  // This is basically the list of expression kinds that are considered atomic
+  result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    // Select can be rewritten through other logical ops
+    PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value);
+    return VisitExpr(expr);
+  }
+
+  result_type VisitExpr_(const NotNode* op) final {
+    // Not should be moved down
+    if (const OrNode* or_expr = op->a.as<OrNode>()) {
+      PrimExpr expr = !or_expr->a && !or_expr->b;
+      return VisitExpr(expr);
+    } else if (const AndNode* and_expr = op->a.as<AndNode>()) {
+      PrimExpr expr = !and_expr->a || !and_expr->b;
+      return VisitExpr(expr);
+    } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) {
+      PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) &&
+                       (sel_expr->condition || !sel_expr->false_value));
+      return VisitExpr(expr);
+    }
+    return Atomic_(GetRef<PrimExpr>(op));
+  }
+
+  result_type VisitExpr_(const AndNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    // For the And case we return the union of the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
+    std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+    std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    // And the residuals are combined with &&
+    return {res, res_a.rest && res_b.rest};
+  }
+
+  result_type VisitExpr_(const MulNode* op) final {
+    // Since we work with bools, for multiplication we do the same thing as for And
+    PrimExpr e_and = op->a && op->b;
+    return VisitExpr(e_and);
+  }
+
+  result_type VisitExpr_(const OrNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{
+        res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()};
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{
+        res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()};
+
+    // For the Or case we intersect the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
+    for (const auto& res_b_formula : res_b_set) {
+      if (res_a_set.count(res_b_formula)) {
+        res_set.insert(res_b_formula);
+      }
+    }
+
+    // Computing the residual is more complex: we have to compute the sets of atomic formulas
+    // which are left behind, and then combine them with the residuals into the new residual.
+    std::vector<PrimExpr> new_cond_a;
+    new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_a_set) {
+      if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
+    }
+
+    std::vector<PrimExpr> new_cond_b;
+    new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_b_set) {
+      if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
+    }
+
+    res_a.atomic_formulas = std::move(new_cond_a);
+    res_b.atomic_formulas = std::move(new_cond_b);
+
+    PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    return {res, new_rest};
+  }
+};
+
+// Transform the given formula into a conjunction of atomic formulas (represented as an array)
+// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b,
+// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level.
+FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) {
+  CHECK(e.dtype().is_bool());
+  return FactorOutAtomicFormulasFunctor().VisitExpr(e);
+}
+
+struct EliminateDivModResult {
+  PrimExpr expr;
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+};
+
+inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncmod(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floormod(a, b);
+  }
+}
+
+inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncdiv(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floordiv(a, b);
+  }
+}
+
+class EliminateDivModMutator : public ExprMutator {
+ public:
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+
+  explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {}
+
+  virtual PrimExpr VisitExpr_(const DivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == -(x/c) for truncated division
+        return make_zero(op->dtype) -
+               VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return truncdiv(mutated_a, op->b);
+      }
+    }
+
+    return truncdiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const ModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == x % c for truncated division
+        return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return truncmod(mutated_a, op->b);
+      }
+    }
+
+    return truncmod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorDivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == (-x) / c for flooring division
+        return VisitExpr(
+            floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return floordiv(mutated_a, op->b);
+      }
+    }
+
+    return floordiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == -(-x % c) for flooring division
+        return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a,
+                                                         make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return floormod(mutated_a, op->b);
+      }
+    }
+
+    return floormod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+ private:
+  dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
+                                                    int64_t val, DivMode mode) {
+    using tresult = dmlc::optional<std::pair<Var, Var>>;
+
+    // Try to find the variables using the mutated expressions
+    if (!e.same_as(mut)) {
+      auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val));
+      if (it != expr_to_vars_.end()) {
+        return tresult(it->second);
+      }
+    }
+
+    PrimExpr val_e = make_const(e.dtype(), val);
+    idx_ += 1;
+
+    // Convert `ranges` to IntSets
+    std::unordered_map<const VarNode*, IntSet> var_intsets;
+    for (const auto& p : ranges) {
+      var_intsets[p.first.get()] = IntSet::FromRange(p.second);
+    }
+
+    // Infer ranges for the expressions we want to replace with variables
+    Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+    Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+
+    // We don't want to add unbounded variables
+    if (!div_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+    if (!mod_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+
+    // Create new variables for the expressions
+    auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype());
+    auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype());
+
+    new_variables.push_back(div);
+    new_variables.push_back(mod);
+
+    // Note that we have to perform substitution to mut because mut may contain new variables
+    substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode));
+    substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode));
+
+    ranges.Set(div, div_range);
+    ranges.Set(mod, mod_range);
+
+    // This additional condition works as a definition for the new variables
+    conditions.push_back(mut == div * val_e + mod);
+
+    if (!analyzer_.CanProve(mod_range->extent <= val_e)) {
+      // Since we use the C/C++ definition of mod, there may be multiple values of `mod`
+      // satisfying the added condition if the expr `e` may change its sign, so we
+      // have to add another condition.
+      LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because "
+                   << ModImpl(e, val_e, mode) << "  may change its sign";
+      conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0));
+    }
+
+    auto p = std::make_pair(div, mod);
+    expr_to_vars_[std::make_tuple(mode, e, val)] = p;
+    if (!e.same_as(mut)) {
+      expr_to_vars_[std::make_tuple(mode, mut, val)] = p;
+    }
+    return tresult(p);
+  }
+
+  class TupleEqual_ {
+   public:
+    bool operator()(const std::tuple<DivMode, PrimExpr, int64_t>& lhs,
+                    const std::tuple<DivMode, PrimExpr, int64_t>& rhs) const {
+      return std::get<0>(lhs) == std::get<0>(rhs) &&
+             tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) &&
+             std::get<2>(lhs) == std::get<2>(rhs);
+    }
+  };
+
+  class TupleHasher_ {
+   public:
+    size_t operator()(const std::tuple<DivMode, PrimExpr, int64_t>& key) const {
+      return ((std::hash<int>()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >>
+              1) ^
+             (std::hash<int64_t>()(std::get<2>(key)) << 1);
+    }
+  };
+
+  // A counter for naming new variables
+  int idx_{0};
+  // A map from pairs of exprs and numbers (e, n) to pairs of new vars (div, mod)
+  // such that `div = e / n` and `mod = e % n`
+  std::unordered_map<std::tuple<DivMode, PrimExpr, int64_t>, std::pair<Var, Var>, TupleHasher_,
+                     TupleEqual_>
+      expr_to_vars_;
+  arith::Analyzer analyzer_;
+};
+
+// Replace every subexpr of the form e/const and e % const with a new variable.
+// Syntactically equal expressions will be mapped to the same variable.
+EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map<Var, Range> ranges) {
+  EliminateDivModResult res;
+  EliminateDivModMutator mutator(ranges);
+  res.expr = mutator(expr);
+  res.conditions = std::move(mutator.conditions);
+  res.new_variables = std::move(mutator.new_variables);
+  res.substitution = std::move(mutator.substitution);
+  res.ranges = std::move(mutator.ranges);
+  return res;
+}
+
+arith::IntConstraintsTransform EliminateDivModFromDomainConditions(
+    const arith::IntConstraints& domain) {
+  auto elim_res = EliminateDivMod(All(domain->relations), domain->ranges);
+
+  Map<Var, Range> new_vranges = elim_res.ranges;
+  Array<Var> new_axis = domain->variables.Concat(elim_res.new_variables);
+  PrimExpr new_cond = elim_res.expr && All(elim_res.conditions);
+
+  arith::IntConstraints new_domain(new_axis, new_vranges,
+                                   FactorOutAtomicFormulas(new_cond).to_array());
+
+  Map<Var, PrimExpr> src_to_dst;
+  Map<Var, PrimExpr> dst_to_src = elim_res.substitution;
+  for (const Var& v : domain->variables) {
+    src_to_dst.Set(v, v);
+    dst_to_src.Set(v, v);
+  }
+
+  return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src);
+}
+
+// Build the identity transformation of a domain: every variable maps to itself.
+inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) {
+  Map<Var, PrimExpr> identity_map;
+  for (const Var& v : domain->variables) {
+    identity_map.Set(v, v);
+  }
+  return arith::IntConstraintsTransform(domain, domain, identity_map, identity_map);
+}
+
+arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains,
+                                              bool eliminate_div_mod) {
+  arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains);
+
+  if (eliminate_div_mod) {
+    transf = transf + EliminateDivModFromDomainConditions(transf->dst);
+  }
+
+  // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably
+  // should find a better terminating criterion (like stop when the domain volume stops decreasing)
+  // Also 2 steps seems to be slightly better than 3
+  for (size_t i = 0; i < 2; ++i) {
+    arith::IntConstraintsTransform tr = arith::SolveLinearEquations(transf->dst);
+    transf = transf + tr;
+    // TODO(sgrechanik-h): This helps for some artificial examples, however I'm not sure about
+    // enabling it in general. The problem it solves is propagating equalities of outer vars.
+    // tr = AddOuterVariablesIntoDomain(transf->dst);

Review comment:
       Not sure this TODO should be retained, since the function AddOuterVariablesIntoDomain doesn't exist in this version.
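
As background for the div/mod visitors in the hunk above: the comments there state four rewrites for a negative constant divisor (x / -c == -(x/c) and x % -c == x % c for truncated division; x / -c == (-x) / c and x % -c == -(-x % c) for flooring division). The following is only a minimal standalone sketch that brute-force checks those identities on plain integers. The helpers floordiv_i, floormod_i and CheckNegativeDivisorRewrites are hypothetical names introduced here for illustration; they are not TVM functions, and the sketch assumes plain int64_t arithmetic rather than PrimExpr.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not a TVM API): floor-rounded division on plain
// integers, built on top of C++'s truncating `/` and `%`.
inline int64_t floordiv_i(int64_t a, int64_t b) {
  int64_t q = a / b;  // truncates toward zero
  int64_t r = a % b;  // remainder takes the sign of the dividend
  // Step down by one when the remainder is non-zero and its sign differs
  // from the sign of the divisor.
  if (r != 0 && ((r < 0) != (b < 0))) --q;
  return q;
}

// Hypothetical helper (not a TVM API): modulo matching floordiv_i.
inline int64_t floormod_i(int64_t a, int64_t b) {
  return a - floordiv_i(a, b) * b;
}

// Returns true if all four negative-divisor rewrites hold for every
// x in [-range, range] and every c in [1, max_c].
inline bool CheckNegativeDivisorRewrites(int64_t range, int64_t max_c) {
  for (int64_t c = 1; c <= max_c; ++c) {
    for (int64_t x = -range; x <= range; ++x) {
      if (x / -c != -(x / c)) return false;                       // truncdiv
      if (x % -c != x % c) return false;                          // truncmod
      if (floordiv_i(x, -c) != floordiv_i(-x, c)) return false;   // floordiv
      if (floormod_i(x, -c) != -floormod_i(-x, c)) return false;  // floormod
    }
  }
  return true;
}
```

Calling CheckNegativeDivisorRewrites(50, 7) from a test or a small main exercises every combination in that range. This says nothing about the TVM simplifier itself; it only checks the integer identities that the code comments rely on.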

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually depends
+ * on only a small portion of the inputs. For example, an element-wise function has a
+ * one-to-one mapping between input tensor and output tensor, so the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ), in which A is a matrix and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are linear combinations of the input indices.
+ * We therefore solve the linear equations \beta = A \alpha,
+ * as well as the linear inequalities given by their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via NonzeronessCondition,
+ * replace the compute expression with solved new axes, and create a selection node
+ * (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+using arith::ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition:
+// transforms an expression into the form cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+    return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(op_if_then_else_)) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero only if
+    // both arguments are nonzero, so we combine the conditions with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+  const Op& op_if_then_else_ = Op::Get("tir.if_then_else");
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {
+  std::vector<PrimExpr> atomic_formulas;
+  PrimExpr rest;
+
+  PrimExpr to_expr() const {
+    PrimExpr res = rest;
+    for (const PrimExpr& e : atomic_formulas) {
+      res = And(e, res);
+    }
+    return res;
+  }
+
+  Array<PrimExpr> to_array() const {
+    Array<PrimExpr> res = atomic_formulas;
+    res.push_back(rest);
+    return res;
+  }
+};
+
+// The implementation of FactorOutAtomicFormulas
+class FactorOutAtomicFormulasFunctor
+    : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> {
+ public:
+  result_type Atomic_(const PrimExpr& e) {
+    // For atomic expressions the result is the expr itself with True as the residual
+    return {{e}, make_const(e.dtype(), 1)};
+  }
+
+  // This is basically the list of expression kinds that are considered atomic
+  result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    // Select can be rewritten through other logical ops
+    PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value);
+    return VisitExpr(expr);
+  }
+
+  result_type VisitExpr_(const NotNode* op) final {
+    // Not should be moved down
+    if (const OrNode* or_expr = op->a.as<OrNode>()) {
+      PrimExpr expr = !or_expr->a && !or_expr->b;
+      return VisitExpr(expr);
+    } else if (const AndNode* and_expr = op->a.as<AndNode>()) {
+      PrimExpr expr = !and_expr->a || !and_expr->b;
+      return VisitExpr(expr);
+    } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) {
+      PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) &&
+                       (sel_expr->condition || !sel_expr->false_value));
+      return VisitExpr(expr);
+    }
+    return Atomic_(GetRef<PrimExpr>(op));
+  }
+
+  result_type VisitExpr_(const AndNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    // For the And case we return the union of the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
+    std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+    std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    // And the residuals are combined with &&
+    return {res, res_a.rest && res_b.rest};
+  }
+
+  result_type VisitExpr_(const MulNode* op) final {
+    // Since we work with bools, for multiplication we do the same thing as for And
+    PrimExpr e_and = op->a && op->b;
+    return VisitExpr(e_and);
+  }
+
+  result_type VisitExpr_(const OrNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{
+        res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()};
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{
+        res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()};
+
+    // For the Or case we intersect the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
+    for (const auto& res_b_formula : res_b_set) {
+      if (res_a_set.count(res_b_formula)) {
+        res_set.insert(res_b_formula);
+      }
+    }
+
+    // Computing the residual is more complex: we have to compute the sets of atomic formulas
+    // which are left behind, and then combine them with the residuals into the new residual.
+    std::vector<PrimExpr> new_cond_a;
+    new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_a_set) {
+      if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
+    }
+
+    std::vector<PrimExpr> new_cond_b;
+    new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_b_set) {
+      if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
+    }
+
+    res_a.atomic_formulas = std::move(new_cond_a);
+    res_b.atomic_formulas = std::move(new_cond_b);
+
+    PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    return {res, new_rest};
+  }
+};
+
+// Transform the given formula into a conjunction of atomic formulas (represented as an array)
+// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b,
+// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level.
+FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) {
+  CHECK(e.dtype().is_bool());
+  return FactorOutAtomicFormulasFunctor().VisitExpr(e);
+}
+
+struct EliminateDivModResult {
+  PrimExpr expr;
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+};
+
+inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncmod(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floormod(a, b);
+  }
+}
+
+inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncdiv(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floordiv(a, b);
+  }
+}
+
+class EliminateDivModMutator : public ExprMutator {
+ public:
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+
+  explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {}
+
+  virtual PrimExpr VisitExpr_(const DivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == -(x/c) for truncated division
+        return make_zero(op->dtype) -
+               VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return truncdiv(mutated_a, op->b);
+      }
+    }
+
+    return truncdiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const ModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == x % c for truncated division
+        return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return truncmod(mutated_a, op->b);
+      }
+    }
+
+    return truncmod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorDivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == (-x) / c for flooring division
+        return VisitExpr(
+            floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return floordiv(mutated_a, op->b);
+      }
+    }
+
+    return floordiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == -(-x % c) for flooring division
+        return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a,
+                                                         make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return floormod(mutated_a, op->b);
+      }
+    }
+
+    return floormod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+ private:
+  dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
+                                                    int64_t val, DivMode mode) {
+    using tresult = dmlc::optional<std::pair<Var, Var>>;
+
+    // Try to find the variables using the mutated expressions
+    if (!e.same_as(mut)) {
+      auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val));
+      if (it != expr_to_vars_.end()) {
+        return tresult(it->second);
+      }
+    }
+
+    PrimExpr val_e = make_const(e.dtype(), val);
+    idx_ += 1;
+
+    // Convert `ranges` to IntSets
+    std::unordered_map<const VarNode*, IntSet> var_intsets;
+    for (const auto& p : ranges) {
+      var_intsets[p.first.get()] = IntSet::FromRange(p.second);
+    }
+
+    // Infer ranges for the expressions we want to replace with variables
+    Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+    Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+
+    // We don't want to add unbounded variables
+    if (!div_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+    if (!mod_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+
+    // Create new variables for the expressions
+    auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype());
+    auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype());
+
+    new_variables.push_back(div);
+    new_variables.push_back(mod);
+
+    // Note that we have to perform substitution to mut because mut may contain new variables
+    substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode));
+    substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode));
+
+    ranges.Set(div, div_range);
+    ranges.Set(mod, mod_range);
+
+    // This additional condition works as a definition for the new variables
+    conditions.push_back(mut == div * val_e + mod);
+
+    if (!analyzer_.CanProve(mod_range->extent <= val_e)) {
+      // Since we use the C/C++ definition of mod, there may be multiple values of `mod`
+      // satisfying the added condition if the expr `e` may change its sign, so we
+      // have to add another condition.
+      LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because "
+                   << ModImpl(e, val_e, mode) << "  may change its sign";
+      conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0));
+    }
+
+    auto p = std::make_pair(div, mod);
+    expr_to_vars_[std::make_tuple(mode, e, val)] = p;
+    if (!e.same_as(mut)) {
+      expr_to_vars_[std::make_tuple(mode, mut, val)] = p;
+    }
+    return tresult(p);
+  }
+
+  class TupleEqual_ {
+   public:
+    bool operator()(const std::tuple<DivMode, PrimExpr, int64_t>& lhs,
+                    const std::tuple<DivMode, PrimExpr, int64_t>& rhs) const {
+      return std::get<0>(lhs) == std::get<0>(rhs) &&
+             tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) &&
+             std::get<2>(lhs) == std::get<2>(rhs);
+    }
+  };
+
+  class TupleHasher_ {
+   public:
+    size_t operator()(const std::tuple<DivMode, PrimExpr, int64_t>& key) const {
+      return ((std::hash<int>()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >>
+              1) ^
+             (std::hash<int64_t>()(std::get<2>(key)) << 1);
+    }
+  };
+
+  // A counter for naming new variables
+  int idx_{0};
+  // A map from pairs of exprs and numbers (e, n) to pairs of new vars (div, mod)
+  // such that `div = e / n` and `mod = e % n`
+  std::unordered_map<std::tuple<DivMode, PrimExpr, int64_t>, std::pair<Var, Var>, TupleHasher_,
+                     TupleEqual_>
+      expr_to_vars_;
+  arith::Analyzer analyzer_;
+};
+
+// Replace every subexpr of the form e/const and e % const with a new variable.
+// Syntactically equal expressions will be mapped to the same variable.
+EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map<Var, Range> ranges) {
+  EliminateDivModResult res;
+  EliminateDivModMutator mutator(ranges);
+  res.expr = mutator(expr);
+  res.conditions = std::move(mutator.conditions);
+  res.new_variables = std::move(mutator.new_variables);
+  res.substitution = std::move(mutator.substitution);
+  res.ranges = std::move(mutator.ranges);
+  return res;
+}
+
+arith::IntConstraintsTransform EliminateDivModFromDomainConditions(
+    const arith::IntConstraints& domain) {
+  auto elim_res = EliminateDivMod(All(domain->relations), domain->ranges);
+
+  Map<Var, Range> new_vranges = elim_res.ranges;
+  Array<Var> new_axis = domain->variables.Concat(elim_res.new_variables);
+  PrimExpr new_cond = elim_res.expr && All(elim_res.conditions);
+
+  arith::IntConstraints new_domain(new_axis, new_vranges,
+                                   FactorOutAtomicFormulas(new_cond).to_array());
+
+  Map<Var, PrimExpr> src_to_dst;
+  Map<Var, PrimExpr> dst_to_src = elim_res.substitution;
+  for (const Var& v : domain->variables) {
+    src_to_dst.Set(v, v);
+    dst_to_src.Set(v, v);
+  }
+
+  return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src);
+}
+
+// Simplify an iteration domain.

Review comment:
       This comment should probably be placed before the next function.
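
For context on the quoted EliminateDivModMutator: the sign rewrites in its comments, and the defining condition `mut == div * val_e + mod`, can be checked on plain integers. The following is a minimal standalone sketch, not part of the PR. `floordiv_i`, `floormod_i` and `CheckDivModIdentities` are hypothetical helpers introduced here for illustration and are not TVM APIs. It assumes `c > 0` and relies on C++ `/` and `%` truncating toward zero, which matches truncdiv/truncmod.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not a TVM API): floor division on int64_t.
// C++ `/` truncates toward zero, so adjust when the remainder is nonzero
// and the operands have opposite signs.
inline int64_t floordiv_i(int64_t a, int64_t b) {
  int64_t q = a / b;
  if ((a % b != 0) && ((a < 0) != (b < 0))) --q;
  return q;
}

// Hypothetical helper (not a TVM API): floor modulo, defined via floordiv_i.
inline int64_t floormod_i(int64_t a, int64_t b) { return a - floordiv_i(a, b) * b; }

// Returns true if every identity holds for x and a positive constant c.
inline bool CheckDivModIdentities(int64_t x, int64_t c) {
  bool ok = true;
  // Truncated division: x / -c == -(x / c) and x % -c == x % c.
  ok = ok && (x / -c == -(x / c));
  ok = ok && (x % -c == x % c);
  // Floor division: x / -c == (-x) / c and x % -c == -(-x % c).
  ok = ok && (floordiv_i(x, -c) == floordiv_i(-x, c));
  ok = ok && (floormod_i(x, -c) == -floormod_i(-x, c));
  // Defining condition of a (div, mod) pair, in both modes: x == div * c + mod.
  ok = ok && (x == (x / c) * c + (x % c));
  ok = ok && (x == floordiv_i(x, c) * c + floormod_i(x, c));
  return ok;
}
```

Calling `CheckDivModIdentities(-7, 2)` exercises the negative-numerator case. That case also shows why the truncated mode may need the extra sign condition: with x = -7 and c = 2, both (div, mod) = (-3, -1) and (-4, 1) satisfy x == div * c + mod with |mod| < c, so the equation alone does not pin down the pair.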

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because each output element usually depends on
+ * only a small portion of the inputs. For example, an element-wise function has a one-to-one
+ * mapping between the input tensor and the output tensor, so its Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ), in which A is a matrix and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the indices of the non-zero Jacobian entries are linear combinations of the
+ * input indices. We therefore solve the linear equations \beta = A \alpha,
+ * together with the linear inequalities describing their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via
+ * NonzeronessCondition,
+ * replace the compute expression with solved new axes, and create a selection node
+ * (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition:
+// transforms an expression into the form (cond ? value : 0)
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero only if
+    // both arguments are nonzero, so we combine the conditions with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {
+  std::vector<PrimExpr> atomic_formulas;
+  PrimExpr rest;
+
+  PrimExpr to_expr() const {
+    PrimExpr res = rest;
+    for (const PrimExpr& e : atomic_formulas) {
+      res = And(e, res);
+    }
+    return res;
+  }
+
+  Array<PrimExpr> to_array() const {
+    Array<PrimExpr> res = atomic_formulas;
+    res.push_back(rest);
+    return res;
+  }
+};
+
+// The implementation of FactorOutAtomicFormulas
+class FactorOutAtomicFormulasFunctor
+    : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> {
+ public:
+  result_type Atomic_(const PrimExpr& e) {
+    // For atomic expressions the result is the expr itself with True as the residual
+    return {{e}, make_const(e.dtype(), 1)};
+  }
+
+  // This is basically the list of expression kinds that are considered atomic
+  result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    // Select can be rewritten through other logical ops
+    PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value);
+    return VisitExpr(expr);
+  }
+
+  result_type VisitExpr_(const NotNode* op) final {
+    // Not should be moved down
+    if (const OrNode* or_expr = op->a.as<OrNode>()) {
+      PrimExpr expr = !or_expr->a && !or_expr->b;
+      return VisitExpr(expr);
+    } else if (const AndNode* and_expr = op->a.as<AndNode>()) {
+      PrimExpr expr = !and_expr->a || !and_expr->b;
+      return VisitExpr(expr);
+    } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) {
+      PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) &&
+                       (sel_expr->condition || !sel_expr->false_value));
+      return VisitExpr(expr);
+    }
+    return Atomic_(GetRef<PrimExpr>(op));
+  }
+
+  result_type VisitExpr_(const AndNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    // For the And case we return the union of the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
+    std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+    std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    // And the residuals are combined with &&
+    return {res, res_a.rest && res_b.rest};
+  }
+
+  result_type VisitExpr_(const MulNode* op) final {
+    // Since we work with bools, for multiplication we do the same thing as for And
+    PrimExpr e_and = op->a && op->b;
+    return VisitExpr(e_and);
+  }
+
+  result_type VisitExpr_(const OrNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{
+        res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()};
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{
+        res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()};
+
+    // For the Or case we intersect the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
+    for (const auto& res_b_formula : res_b_set) {
+      if (res_a_set.count(res_b_formula)) {
+        res_set.insert(res_b_formula);
+      }
+    }
+
+    // Computing the residual is more complex: we have to compute the sets of atomic formulas
+    // which are left behind, and then combine them with the residuals into the new residual.
+    std::vector<PrimExpr> new_cond_a;
+    new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_a_set) {
+      if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
+    }
+
+    std::vector<PrimExpr> new_cond_b;
+    new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_b_set) {
+      if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
+    }
+
+    res_a.atomic_formulas = std::move(new_cond_a);
+    res_b.atomic_formulas = std::move(new_cond_b);
+
+    PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    return {res, new_rest};
+  }
+};
+
+// Transform the given formula into a conjunction of atomic formulas (represented as an array)
+// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b,
+// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level.
+FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) {
+  CHECK(e.dtype().is_bool());
+  return FactorOutAtomicFormulasFunctor().VisitExpr(e);
+}
+
+struct EliminateDivModResult {
+  PrimExpr expr;
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+};
+
+inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncmod(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floormod(a, b);
+  }
+}
+
+inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncdiv(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floordiv(a, b);
+  }
+}
+
+class EliminateDivModMutator : public ExprMutator {
+ public:
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+
+  explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {}
+
+  virtual PrimExpr VisitExpr_(const DivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == -(x/c) for truncated division
+        return make_zero(op->dtype) -
+               VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return truncdiv(mutated_a, op->b);
+      }
+    }
+
+    return truncdiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const ModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == x % c for truncated division
+        return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return truncmod(mutated_a, op->b);
+      }
+    }
+
+    return truncmod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorDivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == (-x) / c for flooring division
+        return VisitExpr(
+            floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return floordiv(mutated_a, op->b);
+      }
+    }
+
+    return floordiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == -(-x % c) for flooring division
+        return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a,
+                                                         make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return floormod(mutated_a, op->b);
+      }
+    }
+
+    return floormod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+ private:
+  dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
+                                                    int64_t val, DivMode mode) {
+    using tresult = dmlc::optional<std::pair<Var, Var>>;
+
+    // Try to find the variables using the mutated expressions
+    if (!e.same_as(mut)) {
+      auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val));
+      if (it != expr_to_vars_.end()) {
+        return tresult(it->second);
+      }
+    }
+
+    PrimExpr val_e = make_const(e.dtype(), val);
+    idx_ += 1;
+
+    // Convert `ranges` to IntSets
+    std::unordered_map<const VarNode*, IntSet> var_intsets;
+    for (const auto& p : ranges) {
+      var_intsets[p.first.get()] = IntSet::FromRange(p.second);
+    }
+
+    // Infer ranges for the expressions we want to replace with variables
+    Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+    Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+
+    // We don't want to add unbounded variables
+    if (!div_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+    if (!mod_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+
+    // Create new variables for the expressions
+    auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype());
+    auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype());
+
+    new_variables.push_back(div);
+    new_variables.push_back(mod);
+
+    // Note that we have to perform substitution to mut because mut may contain new variables
+    substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode));
+    substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode));
+
+    ranges.Set(div, div_range);
+    ranges.Set(mod, mod_range);
+
+    // This additional condition works as a definition for the new variables
+    conditions.push_back(mut == div * val_e + mod);
+
+    if (!analyzer_.CanProve(mod_range->extent <= val_e)) {
+      // Since we use the C/C++ definition of mod, there may be multiple values of `mod`
+      // satisfying the added condition if the expr `e` may change its sign, so we
+      // have to add another condition.
+      LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because "
+                   << ModImpl(e, val_e, mode) << "  may change its sign";
+      conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0));
+    }
+
+    auto p = std::make_pair(div, mod);
+    expr_to_vars_[std::make_tuple(mode, e, val)] = p;
+    if (!e.same_as(mut)) {
+      expr_to_vars_[std::make_tuple(mode, mut, val)] = p;
+    }
+    return tresult(p);
+  }
+
+  class TupleEqual_ {
+   public:
+    bool operator()(const std::tuple<DivMode, PrimExpr, int64_t>& lhs,
+                    const std::tuple<DivMode, PrimExpr, int64_t>& rhs) const {
+      return std::get<0>(lhs) == std::get<0>(rhs) &&
+             tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) &&
+             std::get<2>(lhs) == std::get<2>(rhs);
+    }
+  };
+
+  class TupleHasher_ {
+   public:
+    size_t operator()(const std::tuple<DivMode, PrimExpr, int64_t>& key) const {
+      return ((std::hash<int>()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >>
+              1) ^
+             (std::hash<int64_t>()(std::get<2>(key)) << 1);
+    }
+  };
+
+  // A counter for naming new variables
+  int idx_{0};
+  // A map from pairs of exprs and numbers (e, n) to pairs of new vars (div, mod)
+  // such that `div = e / n` and `mod = e % n`
+  std::unordered_map<std::tuple<DivMode, PrimExpr, int64_t>, std::pair<Var, Var>, TupleHasher_,
+                     TupleEqual_>
+      expr_to_vars_;
+  arith::Analyzer analyzer_;
+};
+
+// Replace every subexpr of the form e/const and e % const with a new variable.
+// Syntactically equal expressions will be mapped to the same variable.
+EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map<Var, Range> ranges) {
+  EliminateDivModResult res;
+  EliminateDivModMutator mutator(ranges);
+  res.expr = mutator(expr);
+  res.conditions = std::move(mutator.conditions);
+  res.new_variables = std::move(mutator.new_variables);
+  res.substitution = std::move(mutator.substitution);
+  res.ranges = std::move(mutator.ranges);
+  return res;
+}
+
+arith::IntConstraintsTransform EliminateDivModFromDomainConditions(
+    const arith::IntConstraints& domain) {
+  auto elim_res = EliminateDivMod(All(domain->relations), domain->ranges);
+
+  Map<Var, Range> new_vranges = elim_res.ranges;
+  Array<Var> new_axis = Concat(domain->variables, elim_res.new_variables);
+  PrimExpr new_cond = elim_res.expr && All(elim_res.conditions);
+
+  arith::IntConstraints new_domain(new_axis, new_vranges,
+                                   FactorOutAtomicFormulas(new_cond).to_array());
+
+  Map<Var, PrimExpr> src_to_dst;
+  Map<Var, PrimExpr> dst_to_src = elim_res.substitution;
+  for (const Var& v : domain->variables) {
+    src_to_dst.Set(v, v);
+    dst_to_src.Set(v, v);
+  }
+
+  return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src);
+}
+
+// Create the identity transformation for the given domain.
+inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) {
+  Map<Var, PrimExpr> identity_map;
+  for (const Var& v : domain->variables) {
+    identity_map.Set(v, v);
+  }
+  return arith::IntConstraintsTransform(domain, domain, identity_map, identity_map);
+}
+
+// Simplify an iteration domain.
+arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains,
+                                              bool eliminate_div_mod) {
+  arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains);
+
+  if (eliminate_div_mod) {
+    transf = transf + EliminateDivModFromDomainConditions(transf->dst);
+  }
+
+  // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably
+  // should find a better terminating criterion (like stop when the domain volume stops decreasing)
+  // Also 2 steps seems to be slightly better than 3
+  for (size_t i = 0; i < 2; ++i) {
+    arith::IntConstraintsTransform tr = arith::SolveLinearEquations(transf->dst);
+    transf = transf + tr;
+    // TODO(sgrechanik-h): This helps for some artificial examples, however I'm not sure about
+    // enabling it in general. The problem it solves is propagating equalities of outer vars.
+    // tr = AddOuterVariablesIntoDomain(transf->dst);
+    tr = arith::SolveInequalitiesDeskewRange(transf->dst);
+    transf = transf + tr;
+  }
+
+  return transf;
+}
+
+// Use the condition of a reduction op to simplify its domain (axis)
+PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map<Var, Range>& outer_vranges) {
+  if (const ReduceNode* red = expr.as<ReduceNode>()) {
+    Array<Var> vars = IterVarsToVars(red->axis);
+    Map<Var, Range> vranges = Merge(outer_vranges, IterVarsToMap(red->axis));
+    Array<PrimExpr> relations = FactorOutAtomicFormulas(red->condition).to_array();
+
+    arith::IntConstraints domain(vars, vranges, relations);
+    auto res = SimplifyDomain(domain);
+
+    Array<PrimExpr> new_source;
+    for (const PrimExpr& src : red->source) {
+      new_source.push_back(Substitute(src, res->src_to_dst));
+    }
+
+    Array<IterVar> new_axis = IterVarsFromMap(res->dst->variables, res->dst->ranges, kCommReduce);
+
+    // Perform simplification mainly to remove a possibly empty reduction.
+    arith::Analyzer analyzer;
+    return analyzer.Simplify(
+        Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index),
+        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+  } else {
+    return expr;
+  }
+}
+
+// Extract from cond an implication of cond not containing vars
+std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
+    const PrimExpr& cond, const std::unordered_set<const VarNode*>& vars) {
+  CHECK(cond.dtype().is_bool()) << "The type of cond must be bool";
+  // TODO(sgrechanik-h): NOT

Review comment:
       `not`s could be pushed down using De Morgan's laws before running this function, but this case didn't seem important enough and I was too lazy to implement it, so for now this function treats them as opaque expressions.
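
       For reference, a minimal sketch of what such a push-down could look like. This is not code from the PR: it uses a hypothetical toy expression type (`Expr`, `Atom`, `Not`, `And`, `Or`) instead of TVM's `PrimExpr`, and `PushNotDown` and `ToString` are made-up names for illustration only.

```cpp
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy boolean expression, standing in for TVM's PrimExpr.
struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;

enum class Kind { kAtom, kNot, kAnd, kOr };

struct Expr {
  Kind kind;
  std::string name;  // used only by kAtom
  ExprPtr a;         // first operand (the only operand of kNot)
  ExprPtr b;         // second operand, null for kAtom and kNot
};

ExprPtr Atom(const std::string& n) {
  return std::make_shared<Expr>(Expr{Kind::kAtom, n, nullptr, nullptr});
}
ExprPtr Not(ExprPtr x) {
  return std::make_shared<Expr>(Expr{Kind::kNot, "", std::move(x), nullptr});
}
ExprPtr And(ExprPtr x, ExprPtr y) {
  return std::make_shared<Expr>(Expr{Kind::kAnd, "", std::move(x), std::move(y)});
}
ExprPtr Or(ExprPtr x, ExprPtr y) {
  return std::make_shared<Expr>(Expr{Kind::kOr, "", std::move(x), std::move(y)});
}

// Push negations towards the leaves using De Morgan's laws.
// `negate` is true when an odd number of enclosing Not nodes has been seen.
ExprPtr PushNotDown(const ExprPtr& e, bool negate = false) {
  switch (e->kind) {
    case Kind::kAtom:
      if (negate) return Not(e);
      return e;
    case Kind::kNot:
      // Drop the Not node and flip the pending negation instead.
      return PushNotDown(e->a, !negate);
    case Kind::kAnd:
      // !(a && b) == !a || !b
      if (negate) return Or(PushNotDown(e->a, true), PushNotDown(e->b, true));
      return And(PushNotDown(e->a, false), PushNotDown(e->b, false));
    case Kind::kOr:
      // !(a || b) == !a && !b
      if (negate) return And(PushNotDown(e->a, true), PushNotDown(e->b, true));
      return Or(PushNotDown(e->a, false), PushNotDown(e->b, false));
  }
  return e;
}

// Render an expression as text; binary nodes are always parenthesized.
std::string ToString(const ExprPtr& e) {
  switch (e->kind) {
    case Kind::kAtom:
      return e->name;
    case Kind::kNot:
      return "!" + ToString(e->a);
    case Kind::kAnd:
      return "(" + ToString(e->a) + " && " + ToString(e->b) + ")";
    case Kind::kOr:
      return "(" + ToString(e->a) + " || " + ToString(e->b) + ")";
  }
  return "";
}
```

       After such a pass every remaining `not` sits directly on an atom, so a function that treats `not` as opaque would only ever see negated atoms rather than negated conjunctions or disjunctions. For example, `PushNotDown(Not(And(Atom("a"), Or(Atom("b"), Not(Atom("c"))))))` yields `(!a || (!b && c))`.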







[GitHub] [incubator-tvm] MarisaKirisame commented on pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
MarisaKirisame commented on pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#issuecomment-665418427


   I have read the paper and roughly understand how it works. I will continue reviewing shortly.





[GitHub] [incubator-tvm] tqchen commented on pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#issuecomment-675765991


   Thanks @yzhliu @sergei-grechanik @MarisaKirisame !





[GitHub] [incubator-tvm] yzhliu commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
yzhliu commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r467258142



##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually depends on
+ * only a small portion of the inputs. For example, an element-wise function has a one-to-one
+ * mapping between the input tensor and the output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the indices of the non-zero Jacobian entries are linear combinations of the input
+ * indices. Therefore we solve the linear equations \beta = A \alpha,
+ * as well as the linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via NonzeronessCondition,
+ * replace the compute expression with solved new axes, and create a selection node
+ * (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition
+// transform expression to cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result can be nonzero only if
+    // both arguments are nonzero, so we combine the conditions with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {
+  std::vector<PrimExpr> atomic_formulas;
+  PrimExpr rest;
+
+  PrimExpr to_expr() const {
+    PrimExpr res = rest;
+    for (const PrimExpr& e : atomic_formulas) {
+      res = And(e, res);
+    }
+    return res;
+  }
+
+  Array<PrimExpr> to_array() const {
+    Array<PrimExpr> res = atomic_formulas;
+    res.push_back(rest);
+    return res;
+  }
+};
+
+// The implementation of FactorOutAtomicFormulas
+class FactorOutAtomicFormulasFunctor
+    : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> {
+ public:
+  result_type Atomic_(const PrimExpr& e) {
+    // For atomic expressions the result is the expr itself with True as the residual
+    return {{e}, make_const(e.dtype(), 1)};
+  }
+
+  // This is basically the list of expression kinds that are considered atomic
+  result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    // Select can be rewritten through other logical ops
+    PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value);
+    return VisitExpr(expr);
+  }
+
+  result_type VisitExpr_(const NotNode* op) final {
+    // Not should be moved down
+    if (const OrNode* or_expr = op->a.as<OrNode>()) {
+      PrimExpr expr = !or_expr->a && !or_expr->b;
+      return VisitExpr(expr);
+    } else if (const AndNode* and_expr = op->a.as<AndNode>()) {
+      PrimExpr expr = !and_expr->a || !and_expr->b;
+      return VisitExpr(expr);
+    } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) {
+      PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) &&
+                       (sel_expr->condition || !sel_expr->false_value));
+      return VisitExpr(expr);
+    }
+    return Atomic_(GetRef<PrimExpr>(op));
+  }
+
+  result_type VisitExpr_(const AndNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    // For the And case we return the union of the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
+    std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+    std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    // And the residuals are combined with &&
+    return {res, res_a.rest && res_b.rest};
+  }
+
+  result_type VisitExpr_(const MulNode* op) final {
+    // Since we work with bools, for multiplication we do the same thing as for And
+    PrimExpr e_and = op->a && op->b;
+    return VisitExpr(e_and);
+  }
+
+  result_type VisitExpr_(const OrNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{
+        res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()};
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{
+        res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()};
+
+    // For the Or case we intersect the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
+    for (const auto& res_b_formula : res_b_set) {
+      if (res_a_set.count(res_b_formula)) {
+        res_set.insert(res_b_formula);
+      }
+    }
+
+    // Computing the residual is more complex: we have to compute the sets of atomic formulas
+    // which are left behind, and then combine them with the residuals into the new residual.
+    std::vector<PrimExpr> new_cond_a;
+    new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_a_set) {
+      if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
+    }
+
+    std::vector<PrimExpr> new_cond_b;
+    new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_b_set) {
+      if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
+    }
+
+    res_a.atomic_formulas = std::move(new_cond_a);
+    res_b.atomic_formulas = std::move(new_cond_b);
+
+    PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    return {res, new_rest};
+  }
+};
+
+// Transform the given formula into a conjunction of atomic formulas (represented as an array)
+// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b,
+// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level.
+FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) {
+  CHECK(e.dtype().is_bool());
+  return FactorOutAtomicFormulasFunctor().VisitExpr(e);
+}
+
+struct EliminateDivModResult {
+  PrimExpr expr;
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+};
+
+inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncmod(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floormod(a, b);
+  }
+}
+
+inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncdiv(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floordiv(a, b);
+  }
+}
+
+class EliminateDivModMutator : public ExprMutator {
+ public:
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+
+  explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {}
+
+  virtual PrimExpr VisitExpr_(const DivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == -(x/c) for truncated division
+        return make_zero(op->dtype) -
+               VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return truncdiv(mutated_a, op->b);
+      }
+    }
+
+    return truncdiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const ModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == x % c for truncated division
+        return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return truncmod(mutated_a, op->b);
+      }
+    }
+
+    return truncmod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorDivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == (-x) / c for flooring division
+        return VisitExpr(
+            floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return floordiv(mutated_a, op->b);
+      }
+    }
+
+    return floordiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == -(-x % c) for flooring division
+        return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a,
+                                                         make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return floormod(mutated_a, op->b);
+      }
+    }
+
+    return floormod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+ private:
+  dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
+                                                    int64_t val, DivMode mode) {
+    using tresult = dmlc::optional<std::pair<Var, Var>>;
+
+    // Try to find the variables using the mutated expressions
+    if (!e.same_as(mut)) {
+      auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val));
+      if (it != expr_to_vars_.end()) {
+        return tresult(it->second);
+      }
+    }
+
+    PrimExpr val_e = make_const(e.dtype(), val);
+    idx_ += 1;
+
+    // Convert `ranges` to IntSets
+    std::unordered_map<const VarNode*, IntSet> var_intsets;
+    for (const auto& p : ranges) {
+      var_intsets[p.first.get()] = IntSet::FromRange(p.second);
+    }
+
+    // Infer ranges for the expressions we want to replace with variables
+    Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+    Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+
+    // We don't want to add unbounded variables
+    if (!div_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+    if (!mod_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+
+    // Create new variables for the expressions
+    auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype());
+    auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype());
+
+    new_variables.push_back(div);
+    new_variables.push_back(mod);
+
+    // Note that we have to perform substitution to mut because mut may contain new variables
+    substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode));
+    substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode));
+
+    ranges.Set(div, div_range);
+    ranges.Set(mod, mod_range);
+
+    // This additional condition works as a definition for the new variables
+    conditions.push_back(mut == div * val_e + mod);
+
+    if (!analyzer_.CanProve(mod_range->extent <= val_e)) {
+      // Since we use the C/C++ definition of mod, there may be multiple values of `mod`
+      // satisfying the added condition if the expr `e` may change its sign, so we
+      // have to add another condition.
+      LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because "
+                   << ModImpl(e, val_e, mode) << "  may change its sign";
+      conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0));
+    }
+
+    auto p = std::make_pair(div, mod);
+    expr_to_vars_[std::make_tuple(mode, e, val)] = p;
+    if (!e.same_as(mut)) {
+      expr_to_vars_[std::make_tuple(mode, mut, val)] = p;
+    }
+    return tresult(p);
+  }
+
+  class TupleEqual_ {
+   public:
+    bool operator()(const std::tuple<DivMode, PrimExpr, int64_t>& lhs,
+                    const std::tuple<DivMode, PrimExpr, int64_t>& rhs) const {
+      return std::get<0>(lhs) == std::get<0>(rhs) &&
+             tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) &&
+             std::get<2>(lhs) == std::get<2>(rhs);
+    }
+  };
+
+  class TupleHasher_ {
+   public:
+    size_t operator()(const std::tuple<DivMode, PrimExpr, int64_t>& key) const {
+      return ((std::hash<int>()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >>
+              1) ^
+             (std::hash<int64_t>()(std::get<2>(key)) << 1);
+    }
+  };
+
+  // A counter for naming new variables
+  int idx_{0};
+  // A map from pairs of exprs and numbers (e, n) to pairs of new vars (div, mod)
+  // such that `div = e / n` and `mod = e % n`
+  std::unordered_map<std::tuple<DivMode, PrimExpr, int64_t>, std::pair<Var, Var>, TupleHasher_,
+                     TupleEqual_>
+      expr_to_vars_;
+  arith::Analyzer analyzer_;
+};
+
+// Replace every subexpr of the form e/const and e % const with a new variable.
+// Syntactically equal expressions will be mapped to the same variable.
+EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map<Var, Range> ranges) {
+  EliminateDivModResult res;
+  EliminateDivModMutator mutator(ranges);
+  res.expr = mutator(expr);
+  res.conditions = std::move(mutator.conditions);
+  res.new_variables = std::move(mutator.new_variables);
+  res.substitution = std::move(mutator.substitution);
+  res.ranges = std::move(mutator.ranges);
+  return res;
+}
+
+arith::IntConstraintsTransform EliminateDivModFromDomainConditions(
+    const arith::IntConstraints& domain) {
+  auto elim_res = EliminateDivMod(All(domain->relations), domain->ranges);
+
+  Map<Var, Range> new_vranges = elim_res.ranges;
+  Array<Var> new_axis = Concat(domain->variables, elim_res.new_variables);
+  PrimExpr new_cond = elim_res.expr && All(elim_res.conditions);
+
+  arith::IntConstraints new_domain(new_axis, new_vranges,
+                                   FactorOutAtomicFormulas(new_cond).to_array());
+
+  Map<Var, PrimExpr> src_to_dst;
+  Map<Var, PrimExpr> dst_to_src = elim_res.substitution;
+  for (const Var& v : domain->variables) {
+    src_to_dst.Set(v, v);
+    dst_to_src.Set(v, v);
+  }
+
+  return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src);
+}
+
+// Build the identity transformation of a domain: every variable maps to itself.
+inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) {
+  Map<Var, PrimExpr> identity_map;
+  for (const Var& v : domain->variables) {
+    identity_map.Set(v, v);
+  }
+  return arith::IntConstraintsTransform(domain, domain, identity_map, identity_map);
+}
+
+// Simplify an iteration domain.
+arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains,
+                                              bool eliminate_div_mod) {
+  arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains);
+
+  if (eliminate_div_mod) {
+    transf = transf + EliminateDivModFromDomainConditions(transf->dst);
+  }
+
+  // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably
+  // should find a better terminating criterion (like stop when the domain volume stops decreasing)
+  // Also 2 steps seems to be slightly better than 3
+  for (size_t i = 0; i < 2; ++i) {
+    arith::IntConstraintsTransform tr = arith::SolveLinearEquations(transf->dst);
+    transf = transf + tr;
+    // TODO(sgrechanik-h): This helps for some artificial examples, however I'm not sure about
+    // enabling it in general. The problem it solves is propagating equalities of outer vars.
+    // tr = AddOuterVariablesIntoDomain(transf->dst);
+    tr = arith::SolveInequalitiesDeskewRange(transf->dst);
+    transf = transf + tr;
+  }
+
+  return transf;
+}
+
+// Use the condition of a reduction op to simplify its domain (axis)
+PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map<Var, Range>& outer_vranges) {
+  if (const ReduceNode* red = expr.as<ReduceNode>()) {
+    Array<Var> vars = IterVarsToVars(red->axis);
+    Map<Var, Range> vranges = Merge(outer_vranges, IterVarsToMap(red->axis));
+    Array<PrimExpr> relations = FactorOutAtomicFormulas(red->condition).to_array();
+
+    arith::IntConstraints domain(vars, vranges, relations);
+    auto res = SimplifyDomain(domain);
+
+    Array<PrimExpr> new_source;
+    for (const PrimExpr& src : red->source) {
+      new_source.push_back(Substitute(src, res->src_to_dst));
+    }
+
+    Array<IterVar> new_axis = IterVarsFromMap(res->dst->variables, res->dst->ranges, kCommReduce);
+
+    // Perform simplification mainly to remove a possibly empty reduction.
+    arith::Analyzer analyzer;
+    return analyzer.Simplify(
+        Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index),
+        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+  } else {
+    return expr;
+  }
+}
+
+// Extract from cond an implication of cond not containing vars
+std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
+    const PrimExpr& cond, const std::unordered_set<const VarNode*>& vars) {
+  CHECK(cond.dtype().is_bool()) << "The type of cond must be bool";
+  // TODO(sgrechanik-h): NOT
+  if (const AndNode* op = cond.as<AndNode>()) {
+    auto pair_a = ImplicationNotContainingVars(op->a, vars);
+    auto pair_b = ImplicationNotContainingVars(op->b, vars);
+    return {pair_a.first && pair_b.first, pair_a.second && pair_b.second};
+  } else if (const OrNode* op = cond.as<OrNode>()) {
+    auto pair_a = ImplicationNotContainingVars(op->a, vars);
+    auto pair_b = ImplicationNotContainingVars(op->b, vars);
+    return {pair_a.first || pair_b.first, (pair_a.first || pair_b.second) &&
+                                              (pair_b.first || pair_a.second) &&
+                                              (pair_a.second || pair_b.second)};
+  } else if (!tir::ExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) {
+    return {cond, const_true()};
+  } else {
+    return {const_true(), cond};
+  }
+}
+
+// Factor conditions out of a reduction by applying Fourier-Motzkin elimination and moving out
+// (in)equalities which do not depend on the reduction variables.
+std::pair<PrimExpr, PrimExpr> LiftConditionsThroughReduction(const PrimExpr& cond,
+                                                             const Array<IterVar>& red_axis,
+                                                             const Array<IterVar>& outer_axis) {
+  // Factor out atomics so that we can consider this as a system of inequalities
+  auto factoratomic_res = FactorOutAtomicFormulas(cond);
+  Array<PrimExpr> atomics = factoratomic_res.atomic_formulas;
+  const PrimExpr& rest = factoratomic_res.rest;
+
+  Array<Var> allvars;
+  for (const IterVar& v : red_axis) {
+    allvars.push_back(v->var);
+  }
+  for (const IterVar& v : outer_axis) {
+    allvars.push_back(v->var);
+  }
+
+  auto vranges = Merge(IterVarsToMap(red_axis), IterVarsToMap(outer_axis));
+  // start from reduction vars, so that input vars don't depend on them
+  arith::IntConstraints ineq_to_solve(allvars, vranges, atomics);
+  auto res_ineq = arith::SolveLinearInequalities(ineq_to_solve);
+  atomics = arith::AsConditions(allvars, res_ineq.first, res_ineq.second);
+
+  // Append the rest part
+  PrimExpr rewritten_cond = All(atomics) && rest;
+
+  std::unordered_set<const VarNode*> vset;
+  for (const IterVar& v : red_axis) {
+    vset.insert(v->var.get());
+  }
+
+  // The outer (first) condition does not contain reduction vars,
+  // the inner (second) condition is everything else
+  auto res = ImplicationNotContainingVars(rewritten_cond, vset);
+  return res;
+}
+
+// Convert an array of itervars to an array of inequalities
+Array<PrimExpr> IterVarsToInequalities(const Array<IterVar>& itervars) {
+  Array<PrimExpr> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(GE(v->var, v->dom->min));
+    res.push_back(LT(v->var, v->dom->min + v->dom->extent));
+  }
+  return res;
+}
+
+class RemoveRedundantInequalitiesMutator : public ExprMutator {
+ public:
+  explicit RemoveRedundantInequalitiesMutator(Array<PrimExpr> known) {
+    for (const PrimExpr& cond : known) {
+      known_.push_back(analyzer_.Simplify(cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE));
+    }
+  }
+
+  virtual PrimExpr VisitExpr_(const SelectNode* op) {
+    bool has_side_effect = (SideEffect(GetRef<PrimExpr>(op)) > CallEffectKind::kReadState);
+    PrimExpr new_cond =
+        analyzer_.Simplify(VisitExpr(op->condition), ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (is_one(new_cond) && !has_side_effect) {
+      return VisitExpr(op->true_value);
+    } else if (is_zero(new_cond) && !has_side_effect) {
+      return VisitExpr(op->false_value);
+    } else {
+      Array<PrimExpr> new_known = known_;
+      for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) {
+        new_known.push_back(atomic);
+      }
+      RemoveRedundantInequalitiesMutator new_mutator(new_known);
+      // Note that we mutate only the true value with the new mutator
+      // TODO(sgrechanik-h): Update known conditions for the false value as well
+      return Select(new_cond, new_mutator(op->true_value), VisitExpr(op->false_value));
+    }
+  }
+
+  virtual PrimExpr VisitExpr_(const CallNode* op) {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(VisitExpr(op->args[0]), ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (is_one(new_cond)) {
+        return VisitExpr(op->args[1]);
+      } else if (is_zero(new_cond)) {
+        return VisitExpr(op->args[2]);
+      } else {
+        Array<PrimExpr> new_known = known_;
+        for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) {
+          new_known.push_back(atomic);
+        }
+        RemoveRedundantInequalitiesMutator new_mutator(new_known);
+        // Note that we mutate only the true value with the new mutator
+        // TODO(sgrechanik-h): Update known conditions for the false value as well
+        return if_then_else(new_cond, new_mutator(op->args[1]), VisitExpr(op->args[2]));
+      }
+    } else {
+      return ExprMutator::VisitExpr_(op);
+    }
+  }
+
+  virtual PrimExpr VisitExpr_(const ReduceNode* op) {
+    Array<PrimExpr> known_with_axes = known_;
+    for (const PrimExpr& axis_cond : IterVarsToInequalities(op->axis)) {
+      known_with_axes.push_back(axis_cond);
+    }
+    RemoveRedundantInequalitiesMutator mutator_with_axes(known_with_axes);
+
+    PrimExpr new_cond = mutator_with_axes(op->condition);
+
+    Array<PrimExpr> new_known = known_with_axes;
+    for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) {
+      new_known.push_back(atomic);
+    }
+    RemoveRedundantInequalitiesMutator new_mutator(new_known);
+
+    Array<PrimExpr> new_source;
+    for (const PrimExpr& src : op->source) {
+      new_source.push_back(new_mutator(src));
+    }
+
+    return Reduce(op->combiner, new_source, op->axis, new_cond, op->value_index);
+  }
+
+  virtual PrimExpr VisitExpr_(const EQNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+  virtual PrimExpr VisitExpr_(const NENode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+  virtual PrimExpr VisitExpr_(const LTNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+  virtual PrimExpr VisitExpr_(const LENode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+  virtual PrimExpr VisitExpr_(const GTNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+  virtual PrimExpr VisitExpr_(const GENode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+
+  virtual PrimExpr VisitExpr_(const AndNode* op) { return VisitExpr(op->a) && VisitExpr(op->b); }
+
+ private:
+  PrimExpr MutateAtomic_(const PrimExpr& e) {
+    PrimExpr simplified = analyzer_.Simplify(e, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    for (const PrimExpr& other : known_) {
+      if (ExprDeepEqual()(simplified, other)) {
+        return const_true();
+      }
+    }
+    return simplified;
+  }
+
+  Array<PrimExpr> known_;
+  arith::Analyzer analyzer_;
+};
+
+// Propagate information from conditions and remove redundant inequalities
+inline PrimExpr RemoveRedundantInequalities(const PrimExpr& expr, const Array<PrimExpr>& known) {
+  return RemoveRedundantInequalitiesMutator(known)(expr);
+}
+
+// Extract the given expr under the given condition as a separate tensor if the volume of the
+// extracted tensor will be less than the volume of the outer_axis
+PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond,
+                            const Array<Var>& outer_axis, const Map<Var, Range>& vranges) {
+  // solve cond, e.g., (jac_i0 == i) && (jac_i1 == j)
+  arith::IntConstraints domain_to_solve(outer_axis, vranges,
+                                        FactorOutAtomicFormulas(cond).to_array());
+  auto res = SimplifyDomain(domain_to_solve);
+
+  arith::Analyzer analyzer;
+  analyzer.Bind(res->dst->ranges);
+  PrimExpr new_expr = analyzer.Simplify(Substitute(expr, res->src_to_dst),
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+  // TODO(yzhliu): This is mostly done to simplify if_then_else
+  // which is not realized by the canonical simplifier
+  new_expr = RemoveRedundantInequalities(new_expr, res->dst->relations);
+
+  // Keep only those variables of the new vars which are used in the new_expr
+  Array<Var> used_res_variables;
+  for (const Var& var : res->dst->variables) {
+    if (ExprUseVar(new_expr, var)) {
+      CHECK(res->dst->ranges.count(var)) << "Range of " << var << " cannot be inferred.";
+      used_res_variables.push_back(var);
+    }
+  }
+
+  // If the expression does not use vars then it is probably better to keep it inlined
+  if (used_res_variables.empty()) {
+    // We can return the new_expr here instead of the old expr because it doesn't use variables
+    // otherwise we would need to replace the new vars or create a let-expression
+    return new_expr;
+  }
+
+  // If it's already tensor[...] then it will probably be useless to further simplify it.
+  if (new_expr.as<ProducerLoadNode>()) {
+    return expr;
+  }
+
+  // Compute volumes before and after
+  PrimExpr old_volume = make_const(DataType::Int(64), 1);
+  for (const Var& var : outer_axis) {
+    CHECK(vranges.count(var)) << "Range of " << var << " was not provided.";
+    old_volume = old_volume * vranges[var]->extent;
+  }
+
+  PrimExpr new_volume = make_const(DataType::Int(64), 1);
+  for (const Var& var : used_res_variables) {
+    new_volume = new_volume * res->dst->ranges[var]->extent;
+  }
+
+  // if we can prove that the old volume is not greater than the new volume then
+  // prefer the old expression.
+  arith::Analyzer ana_vranges;
+  ana_vranges.Bind(vranges);
+  if (ana_vranges.CanProve(old_volume <= new_volume)) {
+    return expr;
+  }
+
+  Tensor tensor = TensorFromExpr(new_expr, IterVarsFromMap(used_res_variables, res->dst->ranges),
+                                 "extracted_tensor");
+
+  Array<PrimExpr> args;
+  for (const Var& var : used_res_variables) {
+    args.push_back(res->dst_to_src[var]);
+  }
+
+  return ProducerLoad(tensor, args);
+}
+
+class FreeVarsVisitor : public StmtExprVisitor {

Review comment:
       I removed it and now use `UndefinedVars` instead.
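
For readers following the `OrNode` case of `FactorOutAtomicFormulasFunctor` in the diff above, here is a minimal standalone sketch of the same idea: atoms common to both sides of `||` are factored out, and the leftovers are folded into a new residual. It is not the PR's code. Atoms are modelled as plain strings instead of `PrimExpr`, and the names `Factored`, `ToExpr` and `FactorOr` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <set>
#include <string>

// Hypothetical stand-in for FactorOutAtomicFormulasResult: a conjunction of
// atomic formulas plus a non-atomic residual ("true" when there is none).
struct Factored {
  std::set<std::string> atoms;
  std::string rest;
};

// Render atoms && rest as a single string expression.
std::string ToExpr(const Factored& f) {
  std::string out;
  for (const auto& a : f.atoms) out += (out.empty() ? "" : " && ") + a;
  if (f.rest != "true") out += (out.empty() ? "" : " && ") + f.rest;
  return out.empty() ? "true" : out;
}

// (A && B && r1) || (A && C && r2)  ==>  A && ((B && r1) || (C && r2))
Factored FactorOr(const Factored& a, const Factored& b) {
  Factored res;
  // Atoms shared by both operands can be lifted out of the disjunction.
  std::set_intersection(a.atoms.begin(), a.atoms.end(), b.atoms.begin(), b.atoms.end(),
                        std::inserter(res.atoms, res.atoms.begin()));
  // Whatever is not shared stays with its own side, next to that side's residual.
  Factored left_a{{}, a.rest};
  Factored left_b{{}, b.rest};
  std::set_difference(a.atoms.begin(), a.atoms.end(), res.atoms.begin(), res.atoms.end(),
                      std::inserter(left_a.atoms, left_a.atoms.begin()));
  std::set_difference(b.atoms.begin(), b.atoms.end(), res.atoms.begin(), res.atoms.end(),
                      std::inserter(left_b.atoms, left_b.atoms.begin()));
  res.rest = "(" + ToExpr(left_a) + ") || (" + ToExpr(left_b) + ")";
  return res;
}
```

With `a = {i < n, j == k}` and `b = {i < n, m > 0}`, the shared atom `i < n` is lifted out and the residual becomes the disjunction of the two leftover atoms.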


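The negative-divisor rewrites used by `EliminateDivModMutator` in the diff above (`x / -c == -(x / c)` and `x % -c == x % c` for truncated division, `x / -c == (-x) / c` and `x % -c == -(-x % c)` for floor division) can be checked by brute force. The sketch below does that on plain `int64_t` values, together with the defining condition `x == div * c + mod` that the mutator adds for each new variable pair. The helper names are hypothetical and are not TVM's `truncdiv`/`floordiv` operators.

```cpp
#include <cassert>
#include <cstdint>

// Truncated division and remainder are the built-in C++ operators.
int64_t TruncDiv(int64_t a, int64_t b) { return a / b; }
int64_t TruncMod(int64_t a, int64_t b) { return a % b; }

// Floor division rounds toward negative infinity.
int64_t FloorDiv(int64_t a, int64_t b) {
  int64_t q = a / b;
  if ((a % b != 0) && ((a < 0) != (b < 0))) --q;
  return q;
}
int64_t FloorMod(int64_t a, int64_t b) { return a - FloorDiv(a, b) * b; }

// Returns true if every rewrite holds for x in [-range, range] and c in [1, max_c].
bool CheckNegativeDivisorRewrites(int64_t range, int64_t max_c) {
  for (int64_t x = -range; x <= range; ++x) {
    for (int64_t c = 1; c <= max_c; ++c) {
      if (TruncDiv(x, -c) != -TruncDiv(x, c)) return false;
      if (TruncMod(x, -c) != TruncMod(x, c)) return false;
      if (FloorDiv(x, -c) != FloorDiv(-x, c)) return false;
      if (FloorMod(x, -c) != -FloorMod(-x, c)) return false;
      // Defining condition for the (div, mod) variable pair in both modes.
      if (x != TruncDiv(x, c) * c + TruncMod(x, c)) return false;
      if (x != FloorDiv(x, c) * c + FloorMod(x, c)) return false;
    }
  }
  return true;
}
```

The defining condition alone does not pin down `mod` when the dividend can change sign under truncated semantics, which is why the mutator adds the extra `Select(e >= 0, mod >= 0, mod <= 0)` condition in that case.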

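The `OrNode` branch of `ImplicationNotContainingVars` in the diff above relies on distributing `||` over `&&`: if each side splits as `(first && second)`, then `(a1 && a2) || (b1 && b2)` equals `(a1 || b1) && (a1 || b2) && (b1 || a2) && (a2 || b2)`, and only the first factor is guaranteed to be free of the given vars. A small truth-table check of that identity, with plain `bool` values standing in for the four sub-conditions (the function name is hypothetical):

```cpp
#include <cassert>

// Enumerate all 16 assignments of (a1, a2, b1, b2) and compare both forms.
bool OrSplitIsEquivalent() {
  for (int m = 0; m < 16; ++m) {
    bool a1 = (m & 1) != 0, a2 = (m & 2) != 0;
    bool b1 = (m & 4) != 0, b2 = (m & 8) != 0;
    bool original = (a1 && a2) || (b1 && b2);
    bool outer = a1 || b1;                                // the "first" result
    bool inner = (a1 || b2) && (b1 || a2) && (a2 || b2);  // the "second" result
    if (original != (outer && inner)) return false;
  }
  return true;
}
```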




[GitHub] [incubator-tvm] MarisaKirisame commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
MarisaKirisame commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r463942741



##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually depends
+ * on only a small portion of the inputs. For example, an element-wise function has a one-to-one
+ * mapping between the input tensor and the output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * Therefore we solve the linear equations \beta = A \alpha,
+ * as well as the linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via NonzeroCondition,
+ * replace the compute expression with solved new axes, and create a selection node
+ * (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+// Merge two maps, with entries from `update` taking precedence
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition:
+// transforms an expression into the form cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero if
+    // both the arguments are nonzero, so we combine with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {
+  std::vector<PrimExpr> atomic_formulas;
+  PrimExpr rest;
+
+  PrimExpr to_expr() const {
+    PrimExpr res = rest;
+    for (const PrimExpr& e : atomic_formulas) {
+      res = And(e, res);
+    }
+    return res;
+  }
+
+  Array<PrimExpr> to_array() const {
+    Array<PrimExpr> res = atomic_formulas;
+    res.push_back(rest);
+    return res;
+  }
+};
+
+// The implementation of FactorOutAtomicFormulas
+class FactorOutAtomicFormulasFunctor
+    : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> {
+ public:
+  result_type Atomic_(const PrimExpr& e) {
+    // For atomic expressions the result is the expr itself with True as the residual
+    return {{e}, make_const(e.dtype(), 1)};
+  }
+
+  // This is basically the list of expression kinds that are considered atomic
+  result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    // Select can be rewritten through other logical ops
+    PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value);
+    return VisitExpr(expr);
+  }
+
+  result_type VisitExpr_(const NotNode* op) final {
+    // Not should be moved down
+    if (const OrNode* or_expr = op->a.as<OrNode>()) {
+      PrimExpr expr = !or_expr->a && !or_expr->b;
+      return VisitExpr(expr);
+    } else if (const AndNode* and_expr = op->a.as<AndNode>()) {
+      PrimExpr expr = !and_expr->a || !and_expr->b;
+      return VisitExpr(expr);
+    } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) {
+      PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) &&
+                       (sel_expr->condition || !sel_expr->false_value));
+      return VisitExpr(expr);
+    }
+    return Atomic_(GetRef<PrimExpr>(op));
+  }
+
+  result_type VisitExpr_(const AndNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    // For the And case we return the union of the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
+    std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+    std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    // And the residuals are combined with &&
+    return {res, res_a.rest && res_b.rest};
+  }
+
+  result_type VisitExpr_(const MulNode* op) final {
+    // Since we work with bools, for multiplication we do the same thing as for And
+    PrimExpr e_and = op->a && op->b;
+    return VisitExpr(e_and);
+  }
+
+  result_type VisitExpr_(const OrNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{
+        res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()};
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{
+        res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()};
+
+    // For the Or case we intersect the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
+    for (const auto& res_b_formula : res_b_set) {
+      if (res_a_set.count(res_b_formula)) {
+        res_set.insert(res_b_formula);
+      }
+    }
+
+    // Computing the residual is more complex: we have to compute the sets of atomic formulas
+    // which are left behind, and then combine them with the residuals into the new residual.
+    std::vector<PrimExpr> new_cond_a;
+    new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_a_set) {
+      if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
+    }
+
+    std::vector<PrimExpr> new_cond_b;
+    new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_b_set) {
+      if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
+    }
+
+    res_a.atomic_formulas = std::move(new_cond_a);
+    res_b.atomic_formulas = std::move(new_cond_b);
+
+    PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    return {res, new_rest};
+  }
+};
+
+// Transform the given formula into a conjunction of atomic formulas (represented as an array)
+// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b,
+// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level.
+FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) {
+  CHECK(e.dtype().is_bool());
+  return FactorOutAtomicFormulasFunctor().VisitExpr(e);
+}
+
+struct EliminateDivModResult {
+  PrimExpr expr;
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+};
+
+inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncmod(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floormod(a, b);
+  }
+}
+
+inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncdiv(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floordiv(a, b);
+  }
+}
+
+class EliminateDivModMutator : public ExprMutator {
+ public:
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+
+  explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {}
+
+  virtual PrimExpr VisitExpr_(const DivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == -(x/c) for truncated division
+        return make_zero(op->dtype) -
+               VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return truncdiv(mutated_a, op->b);
+      }
+    }
+
+    return truncdiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const ModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == x % c for truncated division
+        return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return truncmod(mutated_a, op->b);
+      }
+    }
+
+    return truncmod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorDivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == (-x) / c for flooring division
+        return VisitExpr(
+            floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return floordiv(mutated_a, op->b);
+      }
+    }
+
+    return floordiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == -(-x % c) for flooring division
+        return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a,
+                                                         make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return floormod(mutated_a, op->b);
+      }
+    }
+
+    return floormod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+ private:
+  dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
+                                                    int64_t val, DivMode mode) {
+    using tresult = dmlc::optional<std::pair<Var, Var>>;
+
+    // Try to find the variables using the mutated expressions
+    if (!e.same_as(mut)) {
+      auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val));
+      if (it != expr_to_vars_.end()) {
+        return tresult(it->second);
+      }
+    }
+
+    PrimExpr val_e = make_const(e.dtype(), val);
+    idx_ += 1;
+
+    // Convert `ranges` to IntSets
+    std::unordered_map<const VarNode*, IntSet> var_intsets;
+    for (const auto& p : ranges) {
+      var_intsets[p.first.get()] = IntSet::FromRange(p.second);
+    }
+
+    // Infer ranges for the expressions we want to replace with variables
+    Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+    Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+
+    // We don't want to add unbounded variables
+    if (!div_range.get() || !mod_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+    if (!mod_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+
+    // Create new variables for the expressions
+    auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype());
+    auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype());
+
+    new_variables.push_back(div);
+    new_variables.push_back(mod);
+
+    // Note that we have to perform substitution to mut because mut may contain new variables
+    substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode));
+    substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode));
+
+    ranges.Set(div, div_range);
+    ranges.Set(mod, mod_range);
+
+    // This additional condition works as a definition for the new variables
+    conditions.push_back(mut == div * val_e + mod);
+
+    if (!analyzer_.CanProve(mod_range->extent <= val_e)) {
+      // Since we use the C/C++ definition of mod, there may be multiple values of `mod`
+      // satisfying the added condition if the expr `e` may change its sign, so we
+      // have to add another condition.
+      LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because "
+                   << ModImpl(e, val_e, mode) << "  probably may change its sign";
+      conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0));
+    }
+
+    auto p = std::make_pair(div, mod);
+    expr_to_vars_[std::make_tuple(mode, e, val)] = p;
+    if (!e.same_as(mut)) {
+      expr_to_vars_[std::make_tuple(mode, mut, val)] = p;
+    }
+    return tresult(p);
+  }
+
+  class TupleEqual_ {
+   public:
+    bool operator()(const std::tuple<DivMode, PrimExpr, int64_t>& lhs,
+                    const std::tuple<DivMode, PrimExpr, int64_t>& rhs) const {
+      return std::get<0>(lhs) == std::get<0>(rhs) &&
+             tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) &&
+             std::get<2>(lhs) == std::get<2>(rhs);
+    }
+  };
+
+  class TupleHasher_ {
+   public:
+    size_t operator()(const std::tuple<DivMode, PrimExpr, int64_t>& key) const {
+      return ((std::hash<int>()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >>
+              1) ^
+             (std::hash<int64_t>()(std::get<2>(key)) << 1);
+    }
+  };
+
+  // A counter for naming new variables
+  int idx_{0};
+  // A map from pairs of exprs and numbers (e, n) to pairs of new vars (div, mod)
+  // such that `div = e / n` and `mod = e % n`
+  std::unordered_map<std::tuple<DivMode, PrimExpr, int64_t>, std::pair<Var, Var>, TupleHasher_,
+                     TupleEqual_>
+      expr_to_vars_;
+  arith::Analyzer analyzer_;
+};
+
+// Replace every subexpr of the form e/const and e % const with a new variable.
+// Syntactically equal expressions will be mapped to the same variable.
+EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map<Var, Range> ranges) {
+  EliminateDivModResult res;
+  EliminateDivModMutator mutator(ranges);
+  res.expr = mutator(expr);
+  res.conditions = std::move(mutator.conditions);
+  res.new_variables = std::move(mutator.new_variables);
+  res.substitution = std::move(mutator.substitution);
+  res.ranges = std::move(mutator.ranges);
+  return res;
+}
+
+arith::IntConstraintsTransform EliminateDivModFromDomainConditions(
+    const arith::IntConstraints& domain) {
+  auto elim_res = EliminateDivMod(All(domain->relations), domain->ranges);
+
+  Map<Var, Range> new_vranges = elim_res.ranges;
+  Array<Var> new_axis = Concat(domain->variables, elim_res.new_variables);
+  PrimExpr new_cond = elim_res.expr && All(elim_res.conditions);
+
+  arith::IntConstraints new_domain(new_axis, new_vranges,
+                                   FactorOutAtomicFormulas(new_cond).to_array());
+
+  Map<Var, PrimExpr> src_to_dst;
+  Map<Var, PrimExpr> dst_to_src = elim_res.substitution;
+  for (const Var& v : domain->variables) {
+    src_to_dst.Set(v, v);
+    dst_to_src.Set(v, v);
+  }
+
+  return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src);
+}
+
+// Simplify an iteration domain.
+inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) {
+  Map<Var, PrimExpr> identity_map;
+  for (const Var& v : domain->variables) {
+    identity_map.Set(v, v);
+  }
+  return arith::IntConstraintsTransform(domain, domain, identity_map, identity_map);
+}
+
+arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains,
+                                              bool eliminate_div_mod) {
+  arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains);
+
+  if (eliminate_div_mod) {
+    transf = transf + EliminateDivModFromDomainConditions(transf->dst);
+  }
+
+  // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably
+  // should find a better terminating criterion (like stop when the domain volume stops decreasing)
+  // Also 2 steps seems to be slightly better than 3
+  for (size_t i = 0; i < 2; ++i) {
+    arith::IntConstraintsTransform tr = arith::SolveLinearEquations(transf->dst);
+    transf = transf + tr;
+    // TODO(sgrechanik-h): This helps for some artificial examples, however I'm not sure about
+    // enabling it in general. The problem it solves is propagating equalities of outer vars.
+    // tr = AddOuterVariablesIntoDomain(transf->dst);
+    tr = arith::SolveInequalitiesDeskewRange(transf->dst);
+    transf = transf + tr;
+  }
+
+  return transf;
+}
+
+// Use the condition of a reduction op to simplify its domain (axis)
+PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map<Var, Range>& outer_vranges) {
+  if (const ReduceNode* red = expr.as<ReduceNode>()) {
+    Array<Var> vars = IterVarsToVars(red->axis);
+    Map<Var, Range> vranges = Merge(outer_vranges, IterVarsToMap(red->axis));
+    Array<PrimExpr> relations = FactorOutAtomicFormulas(red->condition).to_array();
+
+    arith::IntConstraints domain(vars, vranges, relations);
+    auto res = SimplifyDomain(domain);
+
+    Array<PrimExpr> new_source;
+    for (const PrimExpr& src : red->source) {
+      new_source.push_back(Substitute(src, res->src_to_dst));
+    }
+
+    Array<IterVar> new_axis = IterVarsFromMap(res->dst->variables, res->dst->ranges, kCommReduce);
+
+    // Perform simplification mainly to remove a possibly empty reduction.
+    arith::Analyzer analyzer;
+    return analyzer.Simplify(
+        Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index),
+        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+  } else {
+    return expr;
+  }
+}
+
+// Extract from cond an implication of cond not containing vars
+std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
+    const PrimExpr& cond, const std::unordered_set<const VarNode*>& vars) {
+  CHECK(cond.dtype().is_bool()) << "The type of cond must be bool";
+  // TODO(sgrechanik-h): NOT

Review comment:
       I have no idea why there is a lone NOT floating around in this TODO. Please add at least a few more words.







[GitHub] [incubator-tvm] tqchen commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r462011402



##########
File path: include/tvm/arith/int_solver.h
##########
@@ -41,6 +41,11 @@ using tir::IterVar;
 using tir::Var;
 using tir::VarNode;
 
+// According to experiments, the two best simplification orders were can->rw and rw->can->rw,
+// but rw->can->rw is better for a couple of cases.
+// Also we should end with rw because it factors multipliers out.
+#define ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE 3

Review comment:
       Use constexpr int instead of a macro.
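
       A minimal sketch of what this suggestion could look like, assuming the constant keeps the value 3 from the hunk above. The name kSimplifyRewriteCanonicalRewrite, the sketch namespace, and the StepsRequested helper are hypothetical and not part of TVM; the form finally adopted in the PR may differ.

```cpp
// Sketch only: a typed, scoped constant in place of the macro
//   #define ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE 3
// All names below are hypothetical and chosen for illustration.
namespace sketch {

// Number of simplification steps for the rw->can->rw order (value taken from the hunk).
constexpr int kSimplifyRewriteCanonicalRewrite = 3;

// Stand-in for an API that accepts a step count; it simply returns its argument.
constexpr int StepsRequested(int steps) { return steps; }

// Unlike a macro, the constant has a type, obeys namespace scoping,
// and can still be used in constant expressions.
static_assert(StepsRequested(kSimplifyRewriteCanonicalRewrite) == 3,
              "constant is usable at compile time");

}  // namespace sketch
```

       Call sites would then pass the named constant where the macro is passed today, and the preprocessor no longer rewrites the identifier in unrelated translation units.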







[GitHub] [incubator-tvm] tqchen merged pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
tqchen merged pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078


   





[GitHub] [incubator-tvm] yzhliu commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
yzhliu commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r463929804



##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually relates
+ * to a small portion of the inputs. For example, an element-wise function has a one-to-one
+ * mapping between the input tensor and the output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * Thereby we solve the linear equations \beta = A \alpha,
+ * as well as the linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via
+ * NonzeronessCondition, replace the compute expression with solved new axes, and create a
+ * selection node (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition
+// transform expression to cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result can be nonzero only if
+    // both arguments are nonzero, so we combine with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {
+  std::vector<PrimExpr> atomic_formulas;
+  PrimExpr rest;
+
+  PrimExpr to_expr() const {
+    PrimExpr res = rest;
+    for (const PrimExpr& e : atomic_formulas) {
+      res = And(e, res);
+    }
+    return res;
+  }
+
+  Array<PrimExpr> to_array() const {
+    Array<PrimExpr> res = atomic_formulas;
+    res.push_back(rest);
+    return res;
+  }
+};
+
+// The implementation of FactorOutAtomicFormulas
+class FactorOutAtomicFormulasFunctor
+    : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> {
+ public:
+  result_type Atomic_(const PrimExpr& e) {
+    // For atomic expressions the result is the expr itself with True as the residual
+    return {{e}, make_const(e.dtype(), 1)};
+  }
+
+  // This is basically the list of expression kinds that are considered atomic
+  result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    // Select can be rewritten through other logical ops
+    PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value);
+    return VisitExpr(expr);
+  }
+
+  result_type VisitExpr_(const NotNode* op) final {
+    // Not should be moved down
+    if (const OrNode* or_expr = op->a.as<OrNode>()) {
+      PrimExpr expr = !or_expr->a && !or_expr->b;
+      return VisitExpr(expr);
+    } else if (const AndNode* and_expr = op->a.as<AndNode>()) {
+      PrimExpr expr = !and_expr->a || !and_expr->b;
+      return VisitExpr(expr);
+    } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) {
+      PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) &&
+                       (sel_expr->condition || !sel_expr->false_value));
+      return VisitExpr(expr);
+    }
+    return Atomic_(GetRef<PrimExpr>(op));
+  }
+
+  result_type VisitExpr_(const AndNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    // For the And case we return the union of the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
+    std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+    std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    // And the residuals are combined with &&
+    return {res, res_a.rest && res_b.rest};
+  }
+
+  result_type VisitExpr_(const MulNode* op) final {
+    // Since we work with bools, for multiplication we do the same thing as for And
+    PrimExpr e_and = op->a && op->b;
+    return VisitExpr(e_and);
+  }
+
+  result_type VisitExpr_(const OrNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{
+        res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()};
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{
+        res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()};
+
+    // For the Or case we intersect the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
+    for (const auto& res_b_formula : res_b_set) {
+      if (res_a_set.count(res_b_formula)) {
+        res_set.insert(res_b_formula);
+      }
+    }
+
+    // Computing the residual is more complex: we have to compute the sets of atomic formulas
+    // which are left behind, and then combine them with the residuals into the new residual.
+    std::vector<PrimExpr> new_cond_a;
+    new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_a_set) {
+      if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
+    }
+
+    std::vector<PrimExpr> new_cond_b;
+    new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_b_set) {
+      if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
+    }
+
+    res_a.atomic_formulas = std::move(new_cond_a);
+    res_b.atomic_formulas = std::move(new_cond_b);
+
+    PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    return {res, new_rest};
+  }
+};
+
+// Transform the given formula into a conjunction of atomic formulas (represented as an array)
+// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b,
+// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level.
+FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) {
+  CHECK(e.dtype().is_bool());
+  return FactorOutAtomicFormulasFunctor().VisitExpr(e);
+}
+
+struct EliminateDivModResult {
+  PrimExpr expr;
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+};
+
+inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncmod(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floormod(a, b);
+  }
+}
+
+inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncdiv(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floordiv(a, b);
+  }
+}
+
+class EliminateDivModMutator : public ExprMutator {
+ public:
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+
+  explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {}
+
+  virtual PrimExpr VisitExpr_(const DivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == -(x/c) for truncated division
+        return make_zero(op->dtype) -
+               VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return truncdiv(mutated_a, op->b);
+      }
+    }
+
+    return truncdiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const ModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == x % c for truncated division
+        return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return truncmod(mutated_a, op->b);
+      }
+    }
+
+    return truncmod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorDivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == (-x) / c for flooring division
+        return VisitExpr(
+            floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return floordiv(mutated_a, op->b);
+      }
+    }
+
+    return floordiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == -(-x % c) for flooring division
+        return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a,
+                                                         make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return floormod(mutated_a, op->b);
+      }
+    }
+
+    return floormod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+ private:
+  dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
+                                                    int64_t val, DivMode mode) {
+    using tresult = dmlc::optional<std::pair<Var, Var>>;
+
+    // Try to find the variables using the mutated expressions
+    if (!e.same_as(mut)) {
+      auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val));
+      if (it != expr_to_vars_.end()) {
+        return tresult(it->second);
+      }
+    }
+
+    PrimExpr val_e = make_const(e.dtype(), val);
+    idx_ += 1;
+
+    // Convert `ranges` to IntSets
+    std::unordered_map<const VarNode*, IntSet> var_intsets;
+    for (const auto& p : ranges) {
+      var_intsets[p.first.get()] = IntSet::FromRange(p.second);
+    }
+
+    // Infer ranges for the expressions we want to replace with variables
+    Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+    Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+
+    // We don't want to add unbounded variables
+    if (!div_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+    if (!mod_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+
+    // Create new variables for the expressions
+    auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype());
+    auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype());
+
+    new_variables.push_back(div);
+    new_variables.push_back(mod);
+
+    // Note that we have to perform substitution to mut because mut may contain new variables
+    substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode));
+    substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode));
+
+    ranges.Set(div, div_range);
+    ranges.Set(mod, mod_range);
+
+    // This additional condition works as a definition for the new variables
+    conditions.push_back(mut == div * val_e + mod);
+
+    if (!analyzer_.CanProve(mod_range->extent <= val_e)) {
+      // Since we use the C/C++ definition of mod, there may be multiple values of `mod`
+      // satisfying the added condition if the expr `e` may change its sign, so we
+      // have to add another condition.
+      LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because "
+                   << ModImpl(e, val_e, mode) << " may change its sign";
+      conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0));
+    }
+
+    auto p = std::make_pair(div, mod);
+    expr_to_vars_[std::make_tuple(mode, e, val)] = p;
+    if (!e.same_as(mut)) {
+      expr_to_vars_[std::make_tuple(mode, mut, val)] = p;
+    }
+    return tresult(p);
+  }
+
+  class TupleEqual_ {
+   public:
+    bool operator()(const std::tuple<DivMode, PrimExpr, int64_t>& lhs,
+                    const std::tuple<DivMode, PrimExpr, int64_t>& rhs) const {
+      return std::get<0>(lhs) == std::get<0>(rhs) &&
+             tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) &&
+             std::get<2>(lhs) == std::get<2>(rhs);
+    }
+  };
+
+  class TupleHasher_ {
+   public:
+    size_t operator()(const std::tuple<DivMode, PrimExpr, int64_t>& key) const {
+      return ((std::hash<int>()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >>
+              1) ^
+             (std::hash<int64_t>()(std::get<2>(key)) << 1);
+    }
+  };
+
+  // A counter for naming new variables
+  int idx_{0};
+  // A map from pairs of exprs and numbers (e, n) to pairs of new vars (div, mod)
+  // such that `div = e / n` and `mod = e % n`
+  std::unordered_map<std::tuple<DivMode, PrimExpr, int64_t>, std::pair<Var, Var>, TupleHasher_,
+                     TupleEqual_>
+      expr_to_vars_;
+  arith::Analyzer analyzer_;
+};
+
+// Replace every subexpr of the form e/const and e % const with a new variable.
+// Syntactically equal expressions will be mapped to the same variable.
+EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map<Var, Range> ranges) {
+  EliminateDivModResult res;
+  EliminateDivModMutator mutator(ranges);
+  res.expr = mutator(expr);
+  res.conditions = std::move(mutator.conditions);
+  res.new_variables = std::move(mutator.new_variables);
+  res.substitution = std::move(mutator.substitution);
+  res.ranges = std::move(mutator.ranges);
+  return res;
+}
+
+arith::IntConstraintsTransform EliminateDivModFromDomainConditions(
+    const arith::IntConstraints& domain) {
+  auto elim_res = EliminateDivMod(All(domain->relations), domain->ranges);
+
+  Map<Var, Range> new_vranges = elim_res.ranges;
+  Array<Var> new_axis = Concat(domain->variables, elim_res.new_variables);
+  PrimExpr new_cond = elim_res.expr && All(elim_res.conditions);
+
+  arith::IntConstraints new_domain(new_axis, new_vranges,
+                                   FactorOutAtomicFormulas(new_cond).to_array());
+
+  Map<Var, PrimExpr> src_to_dst;
+  Map<Var, PrimExpr> dst_to_src = elim_res.substitution;
+  for (const Var& v : domain->variables) {
+    src_to_dst.Set(v, v);
+    dst_to_src.Set(v, v);
+  }
+
+  return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src);
+}
+
+// Build the identity transformation for a domain: every variable maps to itself.
+inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) {
+  Map<Var, PrimExpr> identity_map;
+  for (const Var& v : domain->variables) {
+    identity_map.Set(v, v);
+  }
+  return arith::IntConstraintsTransform(domain, domain, identity_map, identity_map);
+}
+
+// Simplify an iteration domain.
+arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains,
+                                              bool eliminate_div_mod) {
+  arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains);
+
+  if (eliminate_div_mod) {
+    transf = transf + EliminateDivModFromDomainConditions(transf->dst);
+  }
+
+  // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably
+  // should find a better terminating criterion (like stop when the domain volume stops decreasing)
+  // Also 2 steps seems to be slightly better than 3
+  for (size_t i = 0; i < 2; ++i) {
+    arith::IntConstraintsTransform tr = arith::SolveLinearEquations(transf->dst);
+    transf = transf + tr;
+    // TODO(sgrechanik-h): This helps for some artificial examples, however I'm not sure about
+    // enabling it in general. The problem it solves is propagating equalities of outer vars.
+    // tr = AddOuterVariablesIntoDomain(transf->dst);
+    tr = arith::SolveInequalitiesDeskewRange(transf->dst);
+    transf = transf + tr;
+  }
+
+  return transf;
+}
+
+// Use the condition of a reduction op to simplify its domain (axis)
+PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map<Var, Range>& outer_vranges) {
+  if (const ReduceNode* red = expr.as<ReduceNode>()) {
+    Array<Var> vars = IterVarsToVars(red->axis);
+    Map<Var, Range> vranges = Merge(outer_vranges, IterVarsToMap(red->axis));
+    Array<PrimExpr> relations = FactorOutAtomicFormulas(red->condition).to_array();
+
+    arith::IntConstraints domain(vars, vranges, relations);
+    auto res = SimplifyDomain(domain);
+
+    Array<PrimExpr> new_source;
+    for (const PrimExpr& src : red->source) {
+      new_source.push_back(Substitute(src, res->src_to_dst));
+    }
+
+    Array<IterVar> new_axis = IterVarsFromMap(res->dst->variables, res->dst->ranges, kCommReduce);
+
+    // Perform simplification mainly to remove a possibly empty reduction.
+    arith::Analyzer analyzer;
+    return analyzer.Simplify(
+        Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index),
+        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+  } else {
+    return expr;
+  }
+}
+
+// Extract from cond an implication of cond not containing vars
+std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
+    const PrimExpr& cond, const std::unordered_set<const VarNode*>& vars) {
+  CHECK(cond.dtype().is_bool()) << "The type of cond must be bool";
+  // TODO(sgrechanik-h): NOT

Review comment:
       Actually, in my understanding it is not straightforward to split out a NOT node here, because the false branch of (!pair.first) would also contain the reduction (instead of zero). I'm not sure whether handling it would bring any benefit. @sergei-grechanik, could you comment?
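
As background for the NOT question: the quoted FactorOutAtomicFormulasFunctor already pushes NOT down through Or, And and Select before factoring. The following is a minimal standalone sketch, not TVM code, that checks those rewrites exhaustively on plain booleans. The helper names (SelectBool, SelectAsAndOr, NotSelectPushedDown, CheckNotPushDown) are made up for this illustration.

```cpp
#include <cassert>

// Plain-bool stand-in for a Select node: cond ? t : f.
inline bool SelectBool(bool cond, bool t, bool f) { return cond ? t : f; }

// Select rewritten through And/Or, mirroring the quoted
// VisitExpr_(const SelectNode*) of FactorOutAtomicFormulasFunctor.
inline bool SelectAsAndOr(bool cond, bool t, bool f) {
  return (cond && t) || (!cond && f);
}

// !Select with the NOT pushed down, mirroring the SelectNode branch
// of the quoted VisitExpr_(const NotNode*).
inline bool NotSelectPushedDown(bool cond, bool t, bool f) {
  return (!cond || !t) && (cond || !f);
}

// Returns true if every rewrite agrees with direct evaluation
// on all 8 combinations of (cond, t, f).
inline bool CheckNotPushDown() {
  for (int bits = 0; bits < 8; ++bits) {
    const bool c = (bits & 1) != 0;
    const bool t = (bits & 2) != 0;
    const bool f = (bits & 4) != 0;
    if (SelectBool(c, t, f) != SelectAsAndOr(c, t, f)) return false;
    if ((!SelectBool(c, t, f)) != NotSelectPushedDown(c, t, f)) return false;
    // De Morgan, as used for the OrNode and AndNode branches.
    if ((!(t || f)) != (!t && !f)) return false;
    if ((!(t && f)) != (!t || !f)) return false;
  }
  return true;
}
```

This only confirms that the push-down rewrites are logically equivalent. It says nothing about whether splitting NOT in ImplicationNotContainingVars would pay off, which is the open question above.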


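For readers following the quoted EliminateDivModMutator, the comments on negative divisors rely on four sign identities. The following is a small standalone sketch, assuming plain int64_t arithmetic rather than PrimExpr, that checks them by brute force. TruncDiv, TruncMod, FloorDiv, FloorMod and CheckNegativeDivisorIdentities are illustrative names defined here, not the TVM helpers.

```cpp
#include <cassert>
#include <cstdint>

// Truncated division and remainder: the native C++ semantics of / and %.
inline int64_t TruncDiv(int64_t a, int64_t b) { return a / b; }
inline int64_t TruncMod(int64_t a, int64_t b) { return a % b; }

// Floor division: round the quotient toward negative infinity.
inline int64_t FloorDiv(int64_t a, int64_t b) {
  const int64_t q = a / b;
  const int64_t r = a % b;
  // Adjust when there is a remainder whose sign differs from the divisor's.
  return (r != 0 && ((r < 0) != (b < 0))) ? q - 1 : q;
}

// Floor remainder, defined so that a == FloorDiv(a, b) * b + FloorMod(a, b).
inline int64_t FloorMod(int64_t a, int64_t b) { return a - FloorDiv(a, b) * b; }

// Check the identities from the quoted comments for all x in [-limit, limit]
// and all positive c in [1, limit].
inline bool CheckNegativeDivisorIdentities(int64_t limit) {
  for (int64_t x = -limit; x <= limit; ++x) {
    for (int64_t c = 1; c <= limit; ++c) {
      if (TruncDiv(x, -c) != -TruncDiv(x, c)) return false;   // x / -c == -(x / c)
      if (TruncMod(x, -c) != TruncMod(x, c)) return false;    // x % -c == x % c
      if (FloorDiv(x, -c) != FloorDiv(-x, c)) return false;   // x / -c == (-x) / c
      if (FloorMod(x, -c) != -FloorMod(-x, c)) return false;  // x % -c == -(-x % c)
      // The defining condition added for the new variables: x == div * c + mod.
      if (x != TruncDiv(x, c) * c + TruncMod(x, c)) return false;
      if (x != FloorDiv(x, c) * c + FloorMod(x, c)) return false;
    }
  }
  return true;
}
```

Note that x == div * c + mod alone does not pin down mod when x can change sign under truncated division, which is why the quoted code adds the extra Select(e >= 0, mod >= 0, mod <= 0) condition in that case.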





[GitHub] [incubator-tvm] MarisaKirisame commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
MarisaKirisame commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r463942884



##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually relates
+ * to a small portion of the inputs. For example, an element-wise function has a one-to-one
+ * mapping between input tensor and output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ), in which A is a matrix and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * We therefore solve the linear equations \beta = A \alpha,
+ * as well as the linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via
+ * NonzeronessCondition, replace the compute expression with solved new axes,
+ * and create a selection node (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition
+// transform expression to cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero if
+    // both the arguments are nonzero, so we combine with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {

Review comment:
       return ReuseNZ(nz_a, op);
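
A `ReuseNZ` helper does not appear in the part of the diff quoted here. A minimal sketch of what it could look like, assuming it factors out the "reuse the original node if the operand is unchanged, otherwise rebuild it" pattern that the visitor methods repeat. `Node`, `Expr`, `NZResult`, `MakeLeaf` and `MakeBinary` are toy stand-ins introduced for illustration, not TVM types:

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Toy stand-ins for TVM's PrimExpr and NonzeroConditionResult (assumptions,
// not the real classes).
struct Node {
  std::string op;                    // e.g. "x", "div"
  std::shared_ptr<const Node> a, b;  // operands (null for leaves)
};
using Expr = std::shared_ptr<const Node>;

Expr MakeLeaf(const std::string& name) {
  return std::make_shared<const Node>(Node{name, nullptr, nullptr});
}

Expr MakeBinary(const std::string& op, Expr a, Expr b) {
  return std::make_shared<const Node>(Node{op, std::move(a), std::move(b)});
}

struct NZResult {
  Expr cond;
  Expr value;
};

// Hypothetical helper: keep the condition of the rewritten numerator. Reuse the
// original node when the numerator is unchanged (pointer identity, in the
// spirit of same_as); otherwise rebuild the node around the new numerator.
NZResult ReuseNZ(const NZResult& nz_a, const Expr& op) {
  assert(op && op->a && op->b);
  if (nz_a.value == op->a) {
    return {nz_a.cond, op};
  }
  return {nz_a.cond, MakeBinary(op->op, nz_a.value, op->b)};
}
```

With a helper of this shape, the unchanged-operand branch returns the very same node, so no new expression is allocated when nothing was rewritten.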






[GitHub] [incubator-tvm] MarisaKirisame commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
MarisaKirisame commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r462239229



##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * Jacobian tensor is sparse because one output element usually relates
+ * to a small portion of the inputs. For example, element-wise function has a one-to-one mapping
+ * between input tensor and output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix,
+ * \alpha and \beta are vectors represent the indices of In and Out respectively.
+ * i.e., the non-zero Jacobian indices is a linear combination of the input indices.
+ * Thereby we solve linear equations of \beta = A \alpha,
+ * as well as linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implement-wise, we extract the equations in the compute definition via NonzeronessCondition,
+ * replace the compute expression with solved new axes, and create a selection node
+ * (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition
+// transform expression to cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};

Review comment:
       Why do you need this?

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * Jacobian tensor is sparse because one output element usually relates
+ * to a small portion of the inputs. For example, element-wise function has a one-to-one mapping
+ * between input tensor and output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix,
+ * \alpha and \beta are vectors represent the indices of In and Out respectively.
+ * i.e., the non-zero Jacobian indices is a linear combination of the input indices.
+ * Thereby we solve linear equations of \beta = A \alpha,
+ * as well as linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implement-wise, we extract the equations in the compute definition via NonzeronessCondition,
+ * replace the compute expression with solved new axes, and create a selection node
+ * (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>

Review comment:
       Should we move this to a common file? I have seen Concat redefined a few times (for example, in the exhaust matcher). @tqchen @jroesch
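
A rough sketch of what such shared helpers could look like, written against `std::vector` and `std::map` so that it compiles on its own. The `common_util` namespace is a hypothetical name, and TVM's `Array` and `Map` would take the place of the std containers:

```cpp
#include <map>
#include <utility>
#include <vector>

// Hypothetical shared helpers; the namespace name is an assumption.
namespace common_util {

// Concatenate two sequences. `a` is taken by value so callers can move into it.
template <class T>
std::vector<T> Concat(std::vector<T> a, const std::vector<T>& b) {
  a.insert(a.end(), b.begin(), b.end());
  return a;  // a by-value parameter is moved implicitly on return
}

// Merge two maps. Entries in `update` override those in `original`.
template <class K, class V>
std::map<K, V> Merge(std::map<K, V> original, const std::map<K, V>& update) {
  for (const auto& p : update) {
    original[p.first] = p.second;
  }
  return original;
}

}  // namespace common_util
```

Returning the by-value parameter directly is enough here: it is treated as an rvalue on return, so the explicit `std::move` in the quoted `Merge` and `Concat` is redundant.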

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * Jacobian tensor is sparse because one output element usually relates
+ * to a small portion of the inputs. For example, element-wise function has a one-to-one mapping
+ * between input tensor and output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix,
+ * \alpha and \beta are vectors represent the indices of In and Out respectively.
+ * i.e., the non-zero Jacobian indices is a linear combination of the input indices.
+ * Thereby we solve linear equations of \beta = A \alpha,
+ * as well as linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implement-wise, we extract the equations in the compute definition via NonzeronessCondition,
+ * replace the compute expression with solved new axes, and create a selection node
+ * (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {

Review comment:
       Clang and GCC have sophisticated inlining strategies. Are you sure this `inline` is needed?
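
For context, a function template may be defined in more than one translation unit without being declared `inline`, so on a template the keyword acts only as an optimization hint that the compiler is free to ignore. A minimal sketch of the same check without the keyword, using a toy `Imm` struct (an assumption standing in for `tir::IntImmNode` and `tir::FloatImmNode`):

```cpp
#include <type_traits>

// Toy stand-in for an immediate node (assumption, not a TVM type).
struct Imm {
  bool is_float;
  long long int_value;
  double float_value;
};

// No `inline` keyword: a function template is already allowed to have
// identical definitions in several translation units, and the optimizer
// makes its own inlining decision.
template <typename ValueType>
bool is_const_value(const Imm& e, ValueType value) {
  static_assert(std::is_integral<ValueType>::value,
                "Comparison to non-integer values is forbidden.");
  return e.is_float ? e.float_value == static_cast<double>(value)
                    : e.int_value == static_cast<long long>(value);
}
```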

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * Jacobian tensor is sparse because one output element usually relates
+ * to a small portion of the inputs. For example, element-wise function has a one-to-one mapping
+ * between input tensor and output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix,
+ * \alpha and \beta are vectors represent the indices of In and Out respectively.
+ * i.e., the non-zero Jacobian indices is a linear combination of the input indices.
+ * Thereby we solve linear equations of \beta = A \alpha,
+ * as well as linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implement-wise, we extract the equations in the compute definition via NonzeronessCondition,
+ * replace the compute expression with solved new axes, and create a selection node
+ * (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition
+// transform expression to cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero if
+    // both the arguments are nonzero, so we combine with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {

Review comment:
       refactor this into a function.
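
One possible reading of this comment, sketched outside TVM: the "reuse the original node if the operands are unchanged, otherwise rebuild it" branch recurs in the Cast, Select, if_then_else, and BinOp*_ visitors quoted above, so it could live in a single helper. `Expr`, `MakeNode`, and `RebuildIfChanged` below are hypothetical stand-ins and not code from this PR; pointer equality plays the role that `same_as` plays in the quoted code.

```cpp
#include <memory>
#include <string>
#include <utility>

// Toy stand-in for an expression node (hypothetical; not tvm::PrimExpr).
struct Expr {
  std::string op;
  std::shared_ptr<const Expr> a;
  std::shared_ptr<const Expr> b;
};
using ExprPtr = std::shared_ptr<const Expr>;

// Build a node; leaves simply omit the operands.
ExprPtr MakeNode(std::string op, ExprPtr a = nullptr, ExprPtr b = nullptr) {
  return std::make_shared<Expr>(Expr{std::move(op), std::move(a), std::move(b)});
}

// Hypothetical helper: return `orig` untouched when both operands are
// pointer-identical to the ones it already holds, otherwise build a new
// node with the same operator and the new operands.
ExprPtr RebuildIfChanged(const ExprPtr& orig, const ExprPtr& new_a, const ExprPtr& new_b) {
  if (new_a == orig->a && new_b == orig->b) {
    return orig;
  }
  return MakeNode(orig->op, new_a, new_b);
}
```

With a helper of this shape, each visitor would only compute the new condition and operands and delegate the reuse-or-rebuild decision, assuming the real node types can be reconstructed generically from their operands.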

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually relates
+ * to a small portion of the inputs. For example, an element-wise function has a one-to-one
+ * mapping between input tensor and output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively;
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * Thereby we solve linear equations of \beta = A \alpha,
+ * as well as linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via
+ * NonzeronessCondition, replace the compute expression with solved new axes, and create
+ * a selection node (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition
+// transform expression to cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero if
+    // both the arguments are nonzero, so we combine with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {

Review comment:
       perhaps you should store them in conjunctive normal form.
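
To make the suggestion concrete, here is a minimal sketch of what a conjunctive-normal-form container could look like, assuming formulas are kept as a list of clauses (a conjunction) where each clause is a list of atomic formulas (a disjunction). The `Literal`, `Clause`, `Cnf`, `And`, and `Or` names are hypothetical, atomic formulas are plain strings rather than `PrimExpr`, and nothing here reflects how `FactorOutAtomicFormulasResult` is actually laid out in the PR.

```cpp
#include <string>
#include <vector>

// Hypothetical toy types: an atomic formula is just a string here.
using Literal = std::string;
using Clause = std::vector<Literal>;  // disjunction of literals
using Cnf = std::vector<Clause>;      // conjunction of clauses; empty means "true"

// AND of two CNF formulas: concatenate their clause lists.
Cnf And(Cnf lhs, const Cnf& rhs) {
  lhs.insert(lhs.end(), rhs.begin(), rhs.end());
  return lhs;
}

// OR of two CNF formulas: distribute, pairing every clause of `lhs`
// with every clause of `rhs`. The result has |lhs| * |rhs| clauses.
Cnf Or(const Cnf& lhs, const Cnf& rhs) {
  Cnf out;
  for (const Clause& l : lhs) {
    for (const Clause& r : rhs) {
      Clause merged = l;
      merged.insert(merged.end(), r.begin(), r.end());
      out.push_back(merged);
    }
  }
  return out;
}
```

In this representation conjunction is cheap, while disjunction multiplies the clause counts, so whether CNF pays off would depend on how often conditions are or-ed together during simplification.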

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually relates
+ * to a small portion of the inputs. For example, an element-wise function has a one-to-one
+ * mapping between input tensor and output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively;
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * Thereby we solve linear equations of \beta = A \alpha,
+ * as well as linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via
+ * NonzeronessCondition, replace the compute expression with solved new axes, and create
+ * a selection node (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition
+// transform expression to cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select. Otherwise
+      // the `to_expr` call creates a select.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero only if
+    // both arguments are nonzero, so we combine with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {
+  std::vector<PrimExpr> atomic_formulas;
+  PrimExpr rest;
+
+  PrimExpr to_expr() const {
+    PrimExpr res = rest;
+    for (const PrimExpr& e : atomic_formulas) {
+      res = And(e, res);
+    }
+    return res;
+  }
+
+  Array<PrimExpr> to_array() const {
+    Array<PrimExpr> res = atomic_formulas;
+    res.push_back(rest);
+    return res;
+  }
+};
+
+// The implementation of FactorOutAtomicFormulas
+class FactorOutAtomicFormulasFunctor
+    : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> {
+ public:
+  result_type Atomic_(const PrimExpr& e) {
+    // For atomic expressions the result is the expr itself with True as the residual
+    return {{e}, make_const(e.dtype(), 1)};
+  }
+
+  // This is basically the list of expression kinds that are considered atomic
+  result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    // Select can be rewritten through other logical ops
+    PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value);
+    return VisitExpr(expr);
+  }
+
+  result_type VisitExpr_(const NotNode* op) final {
+    // Not should be moved down
+    if (const OrNode* or_expr = op->a.as<OrNode>()) {
+      PrimExpr expr = !or_expr->a && !or_expr->b;
+      return VisitExpr(expr);
+    } else if (const AndNode* and_expr = op->a.as<AndNode>()) {
+      PrimExpr expr = !and_expr->a || !and_expr->b;
+      return VisitExpr(expr);
+    } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) {
+      PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) &&
+                       (sel_expr->condition || !sel_expr->false_value));
+      return VisitExpr(expr);
+    }
+    return Atomic_(GetRef<PrimExpr>(op));
+  }
+
+  result_type VisitExpr_(const AndNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    // For the And case we return the union of the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
+    std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+    std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    // And the residuals are combined with &&
+    return {res, res_a.rest && res_b.rest};
+  }
+
+  result_type VisitExpr_(const MulNode* op) final {
+    // Since we work with bools, for multiplication we do the same thing as for And
+    PrimExpr e_and = op->a && op->b;
+    return VisitExpr(e_and);
+  }
+
+  result_type VisitExpr_(const OrNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{
+        res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()};
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{
+        res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()};
+
+    // For the Or case we intersect the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
+    for (const auto& res_b_formula : res_b_set) {
+      if (res_a_set.count(res_b_formula)) {
+        res_set.insert(res_b_formula);
+      }
+    }
+
+    // Computing the residual is more complex: we have to compute the sets of atomic formulas
+    // which are left behind, and then combine them with the residuals into the new residual.
+    std::vector<PrimExpr> new_cond_a;
+    new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_a_set) {
+      if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
+    }
+
+    std::vector<PrimExpr> new_cond_b;
+    new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_b_set) {
+      if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
+    }
+
+    res_a.atomic_formulas = std::move(new_cond_a);
+    res_b.atomic_formulas = std::move(new_cond_b);
+
+    PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    return {res, new_rest};
+  }
+};
+
+// Transform the given formula into a conjunction of atomic formulas (represented as an array)
+// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b,
+// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level.
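+// Illustrative example (the formula is an assumption): for atomic formulas a, b and c,
+// (a && b) || (a && c) is factored into the atomic formula a and a residual
+// equivalent to b || c, since only a occurs on both sides of the Or.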
+FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) {
+  CHECK(e.dtype().is_bool());
+  return FactorOutAtomicFormulasFunctor().VisitExpr(e);
+}
+
+struct EliminateDivModResult {
+  PrimExpr expr;
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+};
+
+inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncmod(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floormod(a, b);
+  }
+}
+
+inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncdiv(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floordiv(a, b);
+  }
+}
+
+class EliminateDivModMutator : public ExprMutator {
+ public:
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+
+  explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {}
+
+  virtual PrimExpr VisitExpr_(const DivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == -(x/c) for truncated division
+        return make_zero(op->dtype) -
+               VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return truncdiv(mutated_a, op->b);
+      }
+    }
+
+    return truncdiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const ModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == x % c for truncated division
+        return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return truncmod(mutated_a, op->b);
+      }
+    }
+
+    return truncmod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorDivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == (-x) / c for flooring division
+        return VisitExpr(
+            floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return floordiv(mutated_a, op->b);
+      }
+    }
+
+    return floordiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == -((-x) % c) for flooring division
+        return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a,
+                                                         make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return floormod(mutated_a, op->b);
+      }
+    }
+
+    return floormod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+ private:
+  dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
+                                                    int64_t val, DivMode mode) {
+    using tresult = dmlc::optional<std::pair<Var, Var>>;
+
+    // Try to find the variables using the mutated expressions
+    if (!e.same_as(mut)) {
+      auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val));
+      if (it != expr_to_vars_.end()) {
+        return tresult(it->second);
+      }
+    }
+
+    PrimExpr val_e = make_const(e.dtype(), val);
+    idx_ += 1;
+
+    // Convert `ranges` to IntSets
+    std::unordered_map<const VarNode*, IntSet> var_intsets;
+    for (const auto& p : ranges) {
+      var_intsets[p.first.get()] = IntSet::FromRange(p.second);
+    }
+
+    // Infer ranges for the expressions we want to replace with variables
+    Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+    Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+
+    // We don't want to add unbounded variables
+    if (!div_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode)
+                   << " because its bounds cannot be inferred";
+      return tresult();
+    }
+    if (!mod_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode)
+                   << " because its bounds cannot be inferred";
+      return tresult();
+    }
+
+    // Create new variables for the expressions
+    auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype());
+    auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype());
+
+    new_variables.push_back(div);
+    new_variables.push_back(mod);
+
+    // Note that we have to perform substitution to mut because mut may contain new variables
+    substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode));
+    substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode));
+
+    ranges.Set(div, div_range);
+    ranges.Set(mod, mod_range);
+
+    // This additional condition works as a definition for the new variables
+    conditions.push_back(mut == div * val_e + mod);
+
+    if (!analyzer_.CanProve(mod_range->extent <= val_e)) {
+      // Since we use the C/C++ definition of mod, there may be multiple values of `mod`
+      // satisfying the added condition if the expr `e` may change its sign, so we
+      // have to add another condition.
+      LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because "
+                   << ModImpl(e, val_e, mode) << " may change its sign";
+      conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0));
+    }
+
+    auto p = std::make_pair(div, mod);
+    expr_to_vars_[std::make_tuple(mode, e, val)] = p;
+    if (!e.same_as(mut)) {
+      expr_to_vars_[std::make_tuple(mode, mut, val)] = p;
+    }
+    return tresult(p);
+  }
+
+  class TupleEqual_ {
+   public:
+    bool operator()(const std::tuple<DivMode, PrimExpr, int64_t>& lhs,
+                    const std::tuple<DivMode, PrimExpr, int64_t>& rhs) const {
+      return std::get<0>(lhs) == std::get<0>(rhs) &&
+             tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) &&
+             std::get<2>(lhs) == std::get<2>(rhs);
+    }
+  };
+
+  class TupleHasher_ {
+   public:
+    size_t operator()(const std::tuple<DivMode, PrimExpr, int64_t>& key) const {
+      return ((std::hash<int>()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >>
+              1) ^
+             (std::hash<int64_t>()(std::get<2>(key)) << 1);
+    }
+  };
+
+  // A counter for naming new variables
+  int idx_{0};
+  // A map from pairs of exprs and numbers (e, n) to pairs of new vars (div, mod)
+  // such that `div = e / n` and `mod = e % n`
+  std::unordered_map<std::tuple<DivMode, PrimExpr, int64_t>, std::pair<Var, Var>, TupleHasher_,
+                     TupleEqual_>
+      expr_to_vars_;
+  arith::Analyzer analyzer_;
+};
+
+// Replace every subexpr of the form e/const and e % const with a new variable.
+// Syntactically equal expressions will be mapped to the same variable.
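+// Illustrative example (assumed input): with i in [0, 10), the expression
+// floordiv(i, 3) + floormod(i, 3) would become fdiv1 + fmod1, with the substitution
+// fdiv1 -> floordiv(i, 3), fmod1 -> floormod(i, 3) and the defining condition
+// i == fdiv1 * 3 + fmod1, provided the ranges of both new variables can be inferred.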
+EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map<Var, Range> ranges) {
+  EliminateDivModResult res;
+  EliminateDivModMutator mutator(ranges);
+  res.expr = mutator(expr);
+  res.conditions = std::move(mutator.conditions);
+  res.new_variables = std::move(mutator.new_variables);
+  res.substitution = std::move(mutator.substitution);
+  res.ranges = std::move(mutator.ranges);
+  return res;
+}
+
+arith::IntConstraintsTransform EliminateDivModFromDomainConditions(
+    const arith::IntConstraints& domain) {
+  auto elim_res = EliminateDivMod(All(domain->relations), domain->ranges);
+
+  Map<Var, Range> new_vranges = elim_res.ranges;
+  Array<Var> new_axis = Concat(domain->variables, elim_res.new_variables);
+  PrimExpr new_cond = elim_res.expr && All(elim_res.conditions);
+
+  arith::IntConstraints new_domain(new_axis, new_vranges,
+                                   FactorOutAtomicFormulas(new_cond).to_array());
+
+  Map<Var, PrimExpr> src_to_dst;
+  Map<Var, PrimExpr> dst_to_src = elim_res.substitution;
+  for (const Var& v : domain->variables) {
+    src_to_dst.Set(v, v);
+    dst_to_src.Set(v, v);
+  }
+
+  return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src);
+}
+
+// Build the identity transformation of a domain.
+inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) {
+  Map<Var, PrimExpr> identity_map;
+  for (const Var& v : domain->variables) {
+    identity_map.Set(v, v);
+  }
+  return arith::IntConstraintsTransform(domain, domain, identity_map, identity_map);
+}
+
+// Simplify an iteration domain.
+arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains,
+                                              bool eliminate_div_mod) {
+  arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains);
+
+  if (eliminate_div_mod) {
+    transf = transf + EliminateDivModFromDomainConditions(transf->dst);
+  }
+
+  // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably
+  // should find a better terminating criterion (like stop when the domain volume stops decreasing)
+  // Also 2 steps seems to be slightly better than 3
+  for (size_t i = 0; i < 2; ++i) {
+    arith::IntConstraintsTransform tr = arith::SolveLinearEquations(transf->dst);
+    transf = transf + tr;
+    // TODO(sgrechanik-h): This helps for some artificial examples, however I'm not sure about
+    // enabling it in general. The problem it solves is propagating equalities of outer vars.
+    // tr = AddOuterVariablesIntoDomain(transf->dst);
+    tr = arith::SolveInequalitiesDeskewRange(transf->dst);
+    transf = transf + tr;
+  }
+
+  return transf;
+}
+
+// Use the condition of a reduction op to simplify its domain (axis)
+PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map<Var, Range>& outer_vranges) {
+  if (const ReduceNode* red = expr.as<ReduceNode>()) {
+    Array<Var> vars = IterVarsToVars(red->axis);
+    Map<Var, Range> vranges = Merge(outer_vranges, IterVarsToMap(red->axis));
+    Array<PrimExpr> relations = FactorOutAtomicFormulas(red->condition).to_array();
+
+    arith::IntConstraints domain(vars, vranges, relations);
+    auto res = SimplifyDomain(domain);
+
+    Array<PrimExpr> new_source;
+    for (const PrimExpr& src : red->source) {
+      new_source.push_back(Substitute(src, res->src_to_dst));
+    }
+
+    Array<IterVar> new_axis = IterVarsFromMap(res->dst->variables, res->dst->ranges, kCommReduce);
+
+    // Perform simplification mainly to remove a possibly empty reduction.
+    arith::Analyzer analyzer;
+    return analyzer.Simplify(
+        Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index),
+        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+  } else {
+    return expr;
+  }
+}
+
+// Extract from cond an implication of cond not containing vars
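+// Illustrative example (assumed input): with vars = {k}, the condition
+// (i < n) && (k < i) is split into the pair {i < n, k < i}: the first component
+// is implied by cond and is free of k, the second holds everything else.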
+std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
+    const PrimExpr& cond, const std::unordered_set<const VarNode*>& vars) {
+  CHECK(cond.dtype().is_bool()) << "The type of cond must be bool";
+  // TODO(sgrechanik-h): NOT
+  if (const AndNode* op = cond.as<AndNode>()) {
+    auto pair_a = ImplicationNotContainingVars(op->a, vars);
+    auto pair_b = ImplicationNotContainingVars(op->b, vars);
+    return {pair_a.first && pair_b.first, pair_a.second && pair_b.second};
+  } else if (const OrNode* op = cond.as<OrNode>()) {
+    auto pair_a = ImplicationNotContainingVars(op->a, vars);
+    auto pair_b = ImplicationNotContainingVars(op->b, vars);
+    return {pair_a.first || pair_b.first, (pair_a.first || pair_b.second) &&
+                                              (pair_b.first || pair_a.second) &&
+                                              (pair_a.second || pair_b.second)};
+  } else if (!tir::ExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) {
+    return {cond, const_true()};
+  } else {
+    return {const_true(), cond};
+  }
+}
+
+// Factor conditions out of a reduction by applying Fourier-Motzkin elimination and moving out
+// (in)equalities which do not depend on the reduction variables.
+std::pair<PrimExpr, PrimExpr> LiftConditionsThroughReduction(const PrimExpr& cond,
+                                                             const Array<IterVar>& red_axis,
+                                                             const Array<IterVar>& outer_axis) {
+  // Factor out atomics so that we can consider this as a system of inequalities
+  auto factoratomic_res = FactorOutAtomicFormulas(cond);
+  Array<PrimExpr> atomics = factoratomic_res.atomic_formulas;
+  const PrimExpr& rest = factoratomic_res.rest;
+
+  Array<Var> allvars;
+  for (const IterVar& v : red_axis) {
+    allvars.push_back(v->var);
+  }
+  for (const IterVar& v : outer_axis) {
+    allvars.push_back(v->var);
+  }
+
+  auto vranges = Merge(IterVarsToMap(red_axis), IterVarsToMap(outer_axis));
+  // start from reduction vars, so that input vars don't depend on them
+  arith::IntConstraints ineq_to_solve(allvars, vranges, atomics);
+  auto res_ineq = arith::SolveLinearInequalities(ineq_to_solve);
+  atomics = arith::AsConditions(allvars, res_ineq.first, res_ineq.second);
+
+  // Append the non-atomic residual
+  PrimExpr rewritten_cond = All(atomics) && rest;
+
+  std::unordered_set<const VarNode*> vset;
+  for (const IterVar& v : red_axis) {
+    vset.insert(v->var.get());
+  }
+
+  // The outer (first) condition does not contain reduction vars,
+  // the inner (second) condition is everything else
+  auto res = ImplicationNotContainingVars(rewritten_cond, vset);
+  return res;
+}
+
+// Convert an array of itervars to an array of inequalities
+Array<PrimExpr> IterVarsToInequalities(const Array<IterVar>& itervars) {
+  Array<PrimExpr> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(GE(v->var, v->dom->min));
+    res.push_back(LT(v->var, v->dom->min + v->dom->extent));
+  }
+  return res;
+}
+
+class RemoveRedundantInequalitiesMutator : public ExprMutator {
+ public:
+  explicit RemoveRedundantInequalitiesMutator(Array<PrimExpr> known) {
+    for (const PrimExpr& cond : known) {
+      known_.push_back(analyzer_.Simplify(cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE));
+    }
+  }
+
+  virtual PrimExpr VisitExpr_(const SelectNode* op) {
+    bool has_side_effect = (SideEffect(GetRef<PrimExpr>(op)) > CallEffectKind::kReadState);
+    PrimExpr new_cond =
+        analyzer_.Simplify(VisitExpr(op->condition), ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (is_one(new_cond) && !has_side_effect) {
+      return VisitExpr(op->true_value);
+    } else if (is_zero(new_cond) && !has_side_effect) {
+      return VisitExpr(op->false_value);
+    } else {
+      Array<PrimExpr> new_known = known_;
+      for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) {
+        new_known.push_back(atomic);
+      }
+      RemoveRedundantInequalitiesMutator new_mutator(new_known);
+      // Note that we mutate only the true value with the new mutator
+      // TODO(sgrechanik-h): Update known conditions for the false value as well
+      return Select(new_cond, new_mutator(op->true_value), VisitExpr(op->false_value));
+    }
+  }
+
+  virtual PrimExpr VisitExpr_(const CallNode* op) {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(VisitExpr(op->args[0]), ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (is_one(new_cond)) {
+        return VisitExpr(op->args[1]);
+      } else if (is_zero(new_cond)) {
+        return VisitExpr(op->args[2]);
+      } else {
+        Array<PrimExpr> new_known = known_;
+        for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) {
+          new_known.push_back(atomic);
+        }
+        RemoveRedundantInequalitiesMutator new_mutator(new_known);
+        // Note that we mutate only the true value with the new mutator
+        // TODO(sgrechanik-h): Update known conditions for the false value as well
+        return if_then_else(new_cond, new_mutator(op->args[1]), VisitExpr(op->args[2]));
+      }
+    } else {
+      return ExprMutator::VisitExpr_(op);
+    }
+  }
+
+  virtual PrimExpr VisitExpr_(const ReduceNode* op) {
+    Array<PrimExpr> known_with_axes = known_;
+    for (const PrimExpr& axis_cond : IterVarsToInequalities(op->axis)) {
+      known_with_axes.push_back(axis_cond);
+    }
+    RemoveRedundantInequalitiesMutator mutator_with_axes(known_with_axes);
+
+    PrimExpr new_cond = mutator_with_axes(op->condition);
+
+    Array<PrimExpr> new_known = known_with_axes;
+    for (const PrimExpr& atomic : FactorOutAtomicFormulas(new_cond).atomic_formulas) {
+      new_known.push_back(atomic);
+    }
+    RemoveRedundantInequalitiesMutator new_mutator(new_known);
+
+    Array<PrimExpr> new_source;
+    for (const PrimExpr& src : op->source) {
+      new_source.push_back(new_mutator(src));
+    }
+
+    return Reduce(op->combiner, new_source, op->axis, new_cond, op->value_index);
+  }
+
+  virtual PrimExpr VisitExpr_(const EQNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+  virtual PrimExpr VisitExpr_(const NENode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+  virtual PrimExpr VisitExpr_(const LTNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+  virtual PrimExpr VisitExpr_(const LENode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+  virtual PrimExpr VisitExpr_(const GTNode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+  virtual PrimExpr VisitExpr_(const GENode* op) { return MutateAtomic_(GetRef<PrimExpr>(op)); }
+
+  virtual PrimExpr VisitExpr_(const AndNode* op) { return VisitExpr(op->a) && VisitExpr(op->b); }
+
+ private:
+  PrimExpr MutateAtomic_(const PrimExpr& e) {
+    PrimExpr simplified = analyzer_.Simplify(e, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    for (const PrimExpr& other : known_) {
+      if (ExprDeepEqual()(simplified, other)) {
+        return const_true();
+      }
+    }
+    return simplified;
+  }
+
+  Array<PrimExpr> known_;
+  arith::Analyzer analyzer_;
+};
+
+// Propagate information from conditions and remove redundant inequalities
+inline PrimExpr RemoveRedundantInequalities(const PrimExpr& expr, const Array<PrimExpr>& known) {
+  return RemoveRedundantInequalitiesMutator(known)(expr);
+}
+
+// Extract the given expr under the given condition as a separate tensor if the volume of the
+// extracted tensor will be less than the volume of the outer_axis
+PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond,
+                            const Array<Var>& outer_axis, const Map<Var, Range>& vranges) {
+  // solve cond, e.g., (jac_i0 == i) && (jac_i1 == j)
+  arith::IntConstraints domain_to_solve(outer_axis, vranges,
+                                        FactorOutAtomicFormulas(cond).to_array());
+  auto res = SimplifyDomain(domain_to_solve);
+
+  arith::Analyzer analyzer;
+  analyzer.Bind(res->dst->ranges);
+  PrimExpr new_expr = analyzer.Simplify(Substitute(expr, res->src_to_dst),
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+  // TODO(yzhliu): This is mostly done to simplify if_then_else
+  // which is not realized by the canonical simplifier
+  new_expr = RemoveRedundantInequalities(new_expr, res->dst->relations);
+
+  // Keep only those variables of the new vars which are used in the new_expr
+  Array<Var> used_res_variables;
+  for (const Var& var : res->dst->variables) {
+    if (ExprUseVar(new_expr, var)) {
+      CHECK(res->dst->ranges.count(var)) << "Range of " << var << " cannot be inferred.";
+      used_res_variables.push_back(var);
+    }
+  }
+
+  // If the expression does not use vars then it is probably better to keep it inlined
+  if (used_res_variables.empty()) {
+    // We can return new_expr here instead of the old expr because it doesn't use variables;
+    // otherwise we would need to replace the new vars or create a let-expression.
+    return new_expr;
+  }
+
+  // If it's already tensor[...] then it will probably be useless to further simplify it.
+  if (new_expr.as<ProducerLoadNode>()) {
+    return expr;
+  }
+
+  // Compute volumes before and after
+  PrimExpr old_volume = make_const(DataType::Int(64), 1);
+  for (const Var& var : outer_axis) {
+    CHECK(vranges.count(var)) << "Range of " << var << " was not provided.";
+    old_volume = old_volume * vranges[var]->extent;
+  }
+
+  PrimExpr new_volume = make_const(DataType::Int(64), 1);
+  for (const Var& var : used_res_variables) {
+    new_volume = new_volume * res->dst->ranges[var]->extent;
+  }
+
+  // if we can prove that the old volume is not greater than the new volume then
+  // prefer the old expression.
+  arith::Analyzer ana_vranges;
+  ana_vranges.Bind(vranges);
+  if (ana_vranges.CanProve(old_volume <= new_volume)) {
+    return expr;
+  }
+
+  Tensor tensor = TensorFromExpr(new_expr, IterVarsFromMap(used_res_variables, res->dst->ranges),
+                                 "extracted_tensor");
+
+  Array<PrimExpr> args;
+  for (const Var& var : used_res_variables) {
+    args.push_back(res->dst_to_src[var]);
+  }
+
+  return ProducerLoad(tensor, args);
+}
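
As a rough illustration of the volume check at the end of TrySimplifyCompute (this is not code from the PR): with constant extents, the decision reduces to comparing two products. `Volume` and `WorthExtracting` are hypothetical names, and the real code compares symbolic PrimExpr volumes through the analyzer's CanProve rather than plain integers.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: product of the extents of a set of iteration axes.
int64_t Volume(const std::vector<int64_t>& extents) {
  int64_t v = 1;
  for (int64_t e : extents) {
    assert(e >= 0);
    v *= e;
  }
  return v;
}

// Hypothetical helper mirroring the heuristic: extract a separate tensor only
// if the simplified expression still uses some new variables and iterating
// over them is strictly cheaper than iterating over the outer axes.
bool WorthExtracting(const std::vector<int64_t>& outer_extents,
                     const std::vector<int64_t>& new_extents) {
  if (new_extents.empty()) return false;  // no variables used: keep it inlined
  return Volume(new_extents) < Volume(outer_extents);
}
```

For example, collapsing a 4 x 4 outer domain to a single axis of extent 4 is worth extracting, while a domain whose volume does not shrink is left alone.
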
+
+class FreeVarsVisitor : public StmtExprVisitor {

Review comment:
       Move this to a common file.

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually depends
+ * on only a small portion of the inputs. For example, an element-wise function has a
+ * one-to-one mapping between input tensor and output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * We therefore solve the linear equations \beta = A \alpha,
+ * as well as the linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via
+ * NonzeronessCondition, replace the compute expression with solved new axes, and create
+ * a selection node (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
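
To make the diagonal-Jacobian remark in the file comment concrete, here is a minimal standalone sketch (not part of the PR, and not using TVM types). It computes the gradient of the element-wise square y[j] = x[j] * x[j] twice: once by materializing the full Jacobian, and once after solving the nonzero condition i == j, which removes the reduction. The function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive version: build the full n x n Jacobian
// J[j][i] = (i == j) ? 2 * x[i] : 0, then contract it with the upstream
// gradient `head`.
std::vector<double> GradViaFullJacobian(const std::vector<double>& x,
                                        const std::vector<double>& head) {
  assert(x.size() == head.size());
  const std::size_t n = x.size();
  std::vector<std::vector<double>> jac(n, std::vector<double>(n, 0.0));
  for (std::size_t j = 0; j < n; ++j) jac[j][j] = 2.0 * x[j];
  std::vector<double> grad(n, 0.0);
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j < n; ++j) grad[i] += head[j] * jac[j][i];
  }
  return grad;
}

// Simplified version: the only nonzero term of the sum over j is j == i,
// so neither the Jacobian nor the reduction is needed.
std::vector<double> GradDiagonal(const std::vector<double>& x,
                                 const std::vector<double>& head) {
  assert(x.size() == head.size());
  std::vector<double> grad(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) grad[i] = head[i] * 2.0 * x[i];
  return grad;
}
```

The first version needs O(n^2) memory and work, the second O(n); both return the same values.
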
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition:
+// transforms an expression into the form (cond ? value : 0)
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero only if
+    // both arguments are nonzero, so we combine with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {
+  std::vector<PrimExpr> atomic_formulas;
+  PrimExpr rest;
+
+  PrimExpr to_expr() const {
+    PrimExpr res = rest;
+    for (const PrimExpr& e : atomic_formulas) {
+      res = And(e, res);
+    }
+    return res;
+  }
+
+  Array<PrimExpr> to_array() const {
+    Array<PrimExpr> res = atomic_formulas;
+    res.push_back(rest);
+    return res;
+  }
+};
+
+// The implementation of FactorOutAtomicFormulas
+class FactorOutAtomicFormulasFunctor
+    : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> {
+ public:
+  result_type Atomic_(const PrimExpr& e) {
+    // For atomic expressions the result is the expr itself with True as the residual
+    return {{e}, make_const(e.dtype(), 1)};
+  }
+
+  // This is basically the list of expression kinds that are considered atomic
+  result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    // Select can be rewritten through other logical ops
+    PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value);
+    return VisitExpr(expr);
+  }
+
+  result_type VisitExpr_(const NotNode* op) final {
+    // Not should be moved down
+    if (const OrNode* or_expr = op->a.as<OrNode>()) {
+      PrimExpr expr = !or_expr->a && !or_expr->b;
+      return VisitExpr(expr);
+    } else if (const AndNode* and_expr = op->a.as<AndNode>()) {
+      PrimExpr expr = !and_expr->a || !and_expr->b;
+      return VisitExpr(expr);
+    } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) {
+      PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) &&
+                       (sel_expr->condition || !sel_expr->false_value));
+      return VisitExpr(expr);
+    }
+    return Atomic_(GetRef<PrimExpr>(op));
+  }
+
+  result_type VisitExpr_(const AndNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    // For the And case we return the union of the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
+    std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+    std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    // And the residuals are combined with &&
+    return {res, res_a.rest && res_b.rest};
+  }
+
+  result_type VisitExpr_(const MulNode* op) final {
+    // Since we work with bools, for multiplication we do the same thing as for And
+    PrimExpr e_and = op->a && op->b;
+    return VisitExpr(e_and);
+  }
+
+  result_type VisitExpr_(const OrNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{
+        res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()};
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{
+        res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()};
+
+    // For the Or case we intersect the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
+    for (const auto& res_b_formula : res_b_set) {
+      if (res_a_set.count(res_b_formula)) {
+        res_set.insert(res_b_formula);
+      }
+    }
+
+    // Computing the residual is more complex: we have to compute the sets of atomic formulas
+    // which are left behind, and then combine them with the residuals into the new residual.
+    std::vector<PrimExpr> new_cond_a;
+    new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_a_set) {
+      if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
+    }
+
+    std::vector<PrimExpr> new_cond_b;
+    new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_b_set) {
+      if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
+    }
+
+    res_a.atomic_formulas = std::move(new_cond_a);
+    res_b.atomic_formulas = std::move(new_cond_b);
+
+    PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    return {res, new_rest};
+  }
+};
+
+// Transform the given formula into a conjunction of atomic formulas (represented as an array)
+// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b,
+// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level.
+FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) {
+  CHECK(e.dtype().is_bool());
+  return FactorOutAtomicFormulasFunctor().VisitExpr(e);
+}
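
The And/Or rules used by FactorOutAtomicFormulasFunctor can be sketched on a toy formula type. This is a simplified model under stated assumptions, not the PR's code: atoms are plain strings, `Formula`, `MakeAtom`, `MakeAnd`, `MakeOr` and `FactoredAtoms` are hypothetical names, and the residual (`rest`) is not computed at all.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <memory>
#include <set>
#include <string>
#include <utility>

// Hypothetical toy model, not TVM types: a formula is an atom, an And, or an Or.
struct Formula {
  enum Kind { kAtom, kAnd, kOr } kind;
  std::string atom;               // set only for kAtom
  std::shared_ptr<Formula> a, b;  // set only for kAnd / kOr
};
using FormulaPtr = std::shared_ptr<Formula>;

FormulaPtr MakeAtom(std::string s) {
  return std::make_shared<Formula>(Formula{Formula::kAtom, std::move(s), nullptr, nullptr});
}
FormulaPtr MakeAnd(FormulaPtr a, FormulaPtr b) {
  return std::make_shared<Formula>(Formula{Formula::kAnd, "", std::move(a), std::move(b)});
}
FormulaPtr MakeOr(FormulaPtr a, FormulaPtr b) {
  return std::make_shared<Formula>(Formula{Formula::kOr, "", std::move(a), std::move(b)});
}

// Atoms that are guaranteed to hold whenever `f` holds.
std::set<std::string> FactoredAtoms(const FormulaPtr& f) {
  assert(f != nullptr);
  if (f->kind == Formula::kAtom) return {f->atom};
  const std::set<std::string> lhs = FactoredAtoms(f->a);
  const std::set<std::string> rhs = FactoredAtoms(f->b);
  std::set<std::string> out;
  if (f->kind == Formula::kAnd) {
    // And: every atom of either side must hold, so take the union.
    std::set_union(lhs.begin(), lhs.end(), rhs.begin(), rhs.end(),
                   std::inserter(out, out.end()));
  } else {
    // Or: only atoms common to both sides are guaranteed, so intersect.
    std::set_intersection(lhs.begin(), lhs.end(), rhs.begin(), rhs.end(),
                          std::inserter(out, out.end()));
  }
  return out;
}
```

For instance, (i == j && k < 4) || (i == j && k >= 4) factors out only i == j; in the real functor the leftover atoms on each side would be folded into the residual.
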
+
+struct EliminateDivModResult {
+  PrimExpr expr;
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+};
+
+inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncmod(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floormod(a, b);
+  }
+}
+
+inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncdiv(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floordiv(a, b);
+  }
+}
+
+class EliminateDivModMutator : public ExprMutator {
+ public:
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+
+  explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {}
+
+  virtual PrimExpr VisitExpr_(const DivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == -(x/c) for truncated division
+        return make_zero(op->dtype) -
+               VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return truncdiv(mutated_a, op->b);
+      }
+    }
+
+    return truncdiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const ModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == x % c for truncated division
+        return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return truncmod(mutated_a, op->b);
+      }
+    }
+
+    return truncmod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorDivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == (-x) / c for flooring division
+        return VisitExpr(
+            floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return floordiv(mutated_a, op->b);
+      }
+    }
+
+    return floordiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == -(-x % c) for flooring division
+        return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a,
+                                                         make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return floormod(mutated_a, op->b);
+      }
+    }
+
+    return floormod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+ private:
+  dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
+                                                    int64_t val, DivMode mode) {
+    using tresult = dmlc::optional<std::pair<Var, Var>>;
+
+    // Try to find the variables using the mutated expressions
+    if (!e.same_as(mut)) {
+      auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val));
+      if (it != expr_to_vars_.end()) {
+        return tresult(it->second);
+      }
+    }
+
+    PrimExpr val_e = make_const(e.dtype(), val);
+    idx_ += 1;
+
+    // Convert `ranges` to IntSets
+    std::unordered_map<const VarNode*, IntSet> var_intsets;
+    for (const auto& p : ranges) {
+      var_intsets[p.first.get()] = IntSet::FromRange(p.second);
+    }
+
+    // Infer ranges for the expressions we want to replace with variables
+    Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+    Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+
+    // We don't want to add unbounded variables
+    if (!div_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+    if (!mod_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+
+    // Create new variables for the expressions
+    auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype());
+    auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype());
+
+    new_variables.push_back(div);
+    new_variables.push_back(mod);
+
+    // Note that we have to perform substitution to mut because mut may contain new variables
+    substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode));
+    substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode));
+
+    ranges.Set(div, div_range);
+    ranges.Set(mod, mod_range);
+
+    // This additional condition works as a definition for the new variables
+    conditions.push_back(mut == div * val_e + mod);
+
+    if (!analyzer_.CanProve(mod_range->extent <= val_e)) {
+      // Since we use the C/C++ definition of mod, there may be multiple values of `mod`
+      // satisfying the added condition if the expr `e` may change its sign, so we
+      // have to add another condition.
+      LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because "
+                   << ModImpl(e, val_e, mode) << "  may change its sign";
+      conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0));
+    }
+
+    auto p = std::make_pair(div, mod);
+    expr_to_vars_[std::make_tuple(mode, e, val)] = p;
+    if (!e.same_as(mut)) {
+      expr_to_vars_[std::make_tuple(mode, mut, val)] = p;
+    }
+    return tresult(p);
+  }
+
+  class TupleEqual_ {
+   public:
+    bool operator()(const std::tuple<DivMode, PrimExpr, int64_t>& lhs,
+                    const std::tuple<DivMode, PrimExpr, int64_t>& rhs) const {
+      return std::get<0>(lhs) == std::get<0>(rhs) &&
+             tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) &&
+             std::get<2>(lhs) == std::get<2>(rhs);
+    }
+  };
+
+  class TupleHasher_ {
+   public:
+    size_t operator()(const std::tuple<DivMode, PrimExpr, int64_t>& key) const {
+      return ((std::hash<int>()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >>
+              1) ^
+             (std::hash<int64_t>()(std::get<2>(key)) << 1);
+    }
+  };
+
+  // A counter for naming new variables
+  int idx_{0};
+  // A map from triples (mode, e, n) of a division mode, an expr and a number to pairs of
+  // new vars (div, mod) such that `div = e / n` and `mod = e % n` under that division mode
+  std::unordered_map<std::tuple<DivMode, PrimExpr, int64_t>, std::pair<Var, Var>, TupleHasher_,
+                     TupleEqual_>
+      expr_to_vars_;
+  arith::Analyzer analyzer_;
+};
+
+// Replace every subexpr of the form e/const and e % const with a new variable.
+// Syntactically equal expressions will be mapped to the same variable.
+EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map<Var, Range> ranges) {
+  EliminateDivModResult res;
+  EliminateDivModMutator mutator(ranges);
+  res.expr = mutator(expr);
+  res.conditions = std::move(mutator.conditions);
+  res.new_variables = std::move(mutator.new_variables);
+  res.substitution = std::move(mutator.substitution);
+  res.ranges = std::move(mutator.ranges);
+  return res;
+}
+
+arith::IntConstraintsTransform EliminateDivModFromDomainConditions(
+    const arith::IntConstraints& domain) {
+  auto elim_res = EliminateDivMod(All(domain->relations), domain->ranges);
+
+  Map<Var, Range> new_vranges = elim_res.ranges;
+  Array<Var> new_axis = Concat(domain->variables, elim_res.new_variables);
+  PrimExpr new_cond = elim_res.expr && All(elim_res.conditions);
+
+  arith::IntConstraints new_domain(new_axis, new_vranges,
+                                   FactorOutAtomicFormulas(new_cond).to_array());
+
+  Map<Var, PrimExpr> src_to_dst;
+  Map<Var, PrimExpr> dst_to_src = elim_res.substitution;
+  for (const Var& v : domain->variables) {
+    src_to_dst.Set(v, v);
+    dst_to_src.Set(v, v);
+  }
+
+  return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src);
+}
+
+// Simplify an iteration domain.
+inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) {
+  Map<Var, PrimExpr> identity_map;
+  for (const Var& v : domain->variables) {
+    identity_map.Set(v, v);
+  }
+  return arith::IntConstraintsTransform(domain, domain, identity_map, identity_map);
+}
+
+arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains,
+                                              bool eliminate_div_mod) {
+  arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains);
+
+  if (eliminate_div_mod) {
+    transf = transf + EliminateDivModFromDomainConditions(transf->dst);
+  }
+
+  // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably
+  // should find a better terminating criterion (like stop when the domain volume stops decreasing)
+  // Also 2 steps seems to be slightly better than 3
+  for (size_t i = 0; i < 2; ++i) {
+    arith::IntConstraintsTransform tr = arith::SolveLinearEquations(transf->dst);
+    transf = transf + tr;
+    // TODO(sgrechanik-h): This helps for some artificial examples, however I'm not sure about
+    // enabling it in general. The problem it solves is propagating equalities of outer vars.
+    // tr = AddOuterVariablesIntoDomain(transf->dst);
+    tr = arith::SolveInequalitiesDeskewRange(transf->dst);
+    transf = transf + tr;
+  }
+
+  return transf;
+}
+
+// Use the condition of a reduction op to simplify its domain (axis)
+PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map<Var, Range>& outer_vranges) {
+  if (const ReduceNode* red = expr.as<ReduceNode>()) {
+    Array<Var> vars = IterVarsToVars(red->axis);
+    Map<Var, Range> vranges = Merge(outer_vranges, IterVarsToMap(red->axis));
+    Array<PrimExpr> relations = FactorOutAtomicFormulas(red->condition).to_array();
+
+    arith::IntConstraints domain(vars, vranges, relations);
+    auto res = SimplifyDomain(domain);
+
+    Array<PrimExpr> new_source;
+    for (const PrimExpr& src : red->source) {
+      new_source.push_back(Substitute(src, res->src_to_dst));
+    }
+
+    Array<IterVar> new_axis = IterVarsFromMap(res->dst->variables, res->dst->ranges, kCommReduce);
+
+    // Perform simplification mainly to remove a possibly empty reduction.
+    arith::Analyzer analyzer;
+    return analyzer.Simplify(
+        Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index),
+        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+  } else {
+    return expr;
+  }
+}
+
+// Extract from cond an implication of cond not containing vars
+std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
+    const PrimExpr& cond, const std::unordered_set<const VarNode*>& vars) {
+  CHECK(cond.dtype().is_bool()) << "The type of cond must be bool";
+  // TODO(sgrechanik-h): NOT

Review comment:
       Fix TODO

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually relates
+ * to a small portion of the inputs. For example, an element-wise function has a
+ * one-to-one mapping between the input tensor and the output tensor, thus the Jacobian
+ * is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * Thereby we solve the linear equations \beta = A \alpha,
+ * as well as the linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via
+ * NonzeronessCondition, replace the compute expression with solved new axes, and create
+ * a selection node (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition
+// transform expression to cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {

Review comment:
       Make the result of Op::Get("tir.if_then_else") a member variable of the functor.
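
       The suggestion above can be sketched roughly as follows. This is a minimal, self-contained illustration and not TVM code: FakeOp, GetFakeOp, LookupCount and CallClassifier are hypothetical stand-ins for the operator handle, the name-based lookup, and the functor. The point is only that the lookup runs once at construction time instead of on every visited call node.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for an operator handle; this is not TVM's Op class.
struct FakeOp {
  int id;
  bool same_as(const FakeOp& other) const { return id == other.id; }
};

// Number of string-keyed registry lookups performed so far.
inline int& LookupCount() {
  static int count = 0;
  return count;
}

// Hypothetical name-based registry lookup (an assumed analogue of Op::Get).
inline const FakeOp& GetFakeOp(const std::string& name) {
  static const std::unordered_map<std::string, FakeOp> registry = {
      {"tir.if_then_else", FakeOp{1}}, {"some.other_op", FakeOp{2}}};
  ++LookupCount();
  auto it = registry.find(name);
  assert(it != registry.end() && "unknown operator name");
  return it->second;
}

class CallClassifier {
 public:
  // Compares against the cached handle; no string lookup happens here.
  bool IsIfThenElse(const FakeOp& op) const { return op.same_as(if_then_else_op_); }

 private:
  // Looked up once, when the functor is constructed.
  const FakeOp& if_then_else_op_ = GetFakeOp("tir.if_then_else");
};
```

       With this shape, constructing one CallClassifier and calling IsIfThenElse many times performs a single registry lookup.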







[GitHub] [incubator-tvm] tqchen commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r468908098



##########
File path: include/tvm/node/container.h
##########
@@ -1287,6 +1287,18 @@ class Map : public ObjectRef {
     data_ = other.data_;
     return *this;
   }
+  /*!
+   * \brief Merge with another Map. It does not mutate the current one.
+   * \param other Map to be merged.
+   * \return The merged Map. The original Map is kept unchanged.
+   */
+  Map<K, V> Merge(const Map<K, V>& other) const {

Review comment:
       Shall we make it a global function instead of a member, so that it is unambiguous that the result is a new map?
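
       A free-function form could look roughly like the sketch below. It uses std::map as a stand-in for the TVM Map (an assumption made only to keep the example self-contained); the by-value first parameter is what makes it visible at the call site that the caller's map is not modified.

```cpp
#include <cassert>
#include <map>
#include <string>

// Free-function merge: `original` is taken by value, so the caller's map is
// copied (or moved from, for rvalues) and never mutated in place.
template <class K, class V>
std::map<K, V> Merge(std::map<K, V> original, const std::map<K, V>& update) {
  for (const auto& p : update) {
    original[p.first] = p.second;  // entries in `update` win on key collisions
  }
  return original;  // a by-value parameter is moved implicitly on return
}

// Usage: the inputs are left untouched and the result is a new map.
inline void MergeExample() {
  std::map<std::string, int> a = {{"i", 1}, {"j", 2}};
  std::map<std::string, int> b = {{"j", 20}, {"k", 30}};
  std::map<std::string, int> merged = Merge(a, b);
  assert(merged.size() == 3 && merged.at("j") == 20);
  assert(a.size() == 2 && a.at("j") == 2);
}
```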

##########
File path: include/tvm/runtime/container.h
##########
@@ -956,6 +956,19 @@ class Array : public ObjectRef {
     return static_cast<ArrayNode*>(data_.get());
   }
 
+  /*!
+   * \brief Concat with another Array. It does not mutate the current one.
+   * \param other Array to be concatenated.
+   * \return The concatenated Array. The original Array is kept unchanged.
+   */
+  Array<T> Concat(const Array<T>& other) const {

Review comment:
       Consider making it a global function, which also enables copy-on-write on the lhs (this).
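
       As a rough illustration of why a global function helps here, the sketch below uses std::vector as a stand-in for the TVM Array (an assumption; the real container is reference-counted, which std::vector is not). Taking the left-hand side by value lets a caller that no longer needs it pass an rvalue, so its storage can be reused instead of copied, while an lvalue argument is copied and left unchanged.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Free-function concat: `a` is taken by value, so an lvalue argument is copied
// and an rvalue argument hands over its buffer without copying the elements.
template <class T>
std::vector<T> Concat(std::vector<T> a, const std::vector<T>& b) {
  a.insert(a.end(), b.begin(), b.end());
  return a;  // a by-value parameter is moved implicitly on return
}

// Usage: lvalue input stays intact; rvalue input is consumed.
inline void ConcatExample() {
  std::vector<int> x = {1, 2};
  std::vector<int> y = {3};
  std::vector<int> z = Concat(x, y);
  assert((z == std::vector<int>{1, 2, 3}));
  assert(x.size() == 2);  // x was copied, not modified
  std::vector<int> w = Concat(std::move(z), y);  // z's storage is handed over
  assert((w == std::vector<int>{1, 2, 3, 3}));
}
```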







[GitHub] [incubator-tvm] yzhliu commented on pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
yzhliu commented on pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#issuecomment-675620287


   @tqchen kindly ping.





[GitHub] [incubator-tvm] yzhliu commented on pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
yzhliu commented on pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#issuecomment-671536959


   @sergei-grechanik I agree. Perhaps we need a performance integration test that runs periodically, as mentioned in https://discuss.tvm.ai/t/efforts-on-benchmarking-for-tvm/
   
   @tqchen could you also take a look?





[GitHub] [incubator-tvm] yzhliu commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
yzhliu commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r467257593



##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually relates
+ * to a small portion of the inputs. For example, an element-wise function has a
+ * one-to-one mapping between the input tensor and the output tensor, thus the Jacobian
+ * is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * Thereby we solve the linear equations \beta = A \alpha,
+ * as well as the linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via
+ * NonzeronessCondition, replace the compute expression with solved new axes, and create
+ * a selection node (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition
+// transform expression to cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero if
+    // both the arguments are nonzero, so we combine with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {
+  std::vector<PrimExpr> atomic_formulas;
+  PrimExpr rest;
+
+  PrimExpr to_expr() const {
+    PrimExpr res = rest;
+    for (const PrimExpr& e : atomic_formulas) {
+      res = And(e, res);
+    }
+    return res;
+  }
+
+  Array<PrimExpr> to_array() const {
+    Array<PrimExpr> res = atomic_formulas;
+    res.push_back(rest);
+    return res;
+  }
+};
+
+// The implementation of FactorOutAtomicFormulas
+class FactorOutAtomicFormulasFunctor
+    : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> {
+ public:
+  result_type Atomic_(const PrimExpr& e) {
+    // For atomic expressions the result is the expr itself with True as the residual
+    return {{e}, make_const(e.dtype(), 1)};
+  }
+
+  // This is basically the list of expression kinds that are considered atomic
+  result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    // Select can be rewritten through other logical ops
+    PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value);
+    return VisitExpr(expr);
+  }
+
+  result_type VisitExpr_(const NotNode* op) final {
+    // Not should be moved down
+    if (const OrNode* or_expr = op->a.as<OrNode>()) {
+      PrimExpr expr = !or_expr->a && !or_expr->b;
+      return VisitExpr(expr);
+    } else if (const AndNode* and_expr = op->a.as<AndNode>()) {
+      PrimExpr expr = !and_expr->a || !and_expr->b;
+      return VisitExpr(expr);
+    } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) {
+      PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) &&
+                       (sel_expr->condition || !sel_expr->false_value));
+      return VisitExpr(expr);
+    }
+    return Atomic_(GetRef<PrimExpr>(op));
+  }
+
+  result_type VisitExpr_(const AndNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    // For the And case we return the union of the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
+    std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+    std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    // And the residuals are combined with &&
+    return {res, res_a.rest && res_b.rest};
+  }
+
+  result_type VisitExpr_(const MulNode* op) final {
+    // Since we work with bools, for multiplication we do the same thing as for And
+    PrimExpr e_and = op->a && op->b;
+    return VisitExpr(e_and);
+  }
+
+  result_type VisitExpr_(const OrNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{
+        res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()};
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{
+        res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()};
+
+    // For the Or case we intersect the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
+    for (const auto& res_b_formula : res_b_set) {
+      if (res_a_set.count(res_b_formula)) {
+        res_set.insert(res_b_formula);
+      }
+    }
+
+    // Computing the residual is more complex: we have to compute the sets of atomic formulas
+    // which are left behind, and then combine them with the residuals into the new residual.
+    std::vector<PrimExpr> new_cond_a;
+    new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_a_set) {
+      if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
+    }
+
+    std::vector<PrimExpr> new_cond_b;
+    new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_b_set) {
+      if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
+    }
+
+    res_a.atomic_formulas = std::move(new_cond_a);
+    res_b.atomic_formulas = std::move(new_cond_b);
+
+    PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    return {res, new_rest};
+  }
+};
+
+// Transform the given formula into a conjunction of atomic formulas (represented as an array)
+// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b,
+// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level.
+FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) {
+  CHECK(e.dtype().is_bool());
+  return FactorOutAtomicFormulasFunctor().VisitExpr(e);
+}
+
+struct EliminateDivModResult {
+  PrimExpr expr;
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+};
+
+inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncmod(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floormod(a, b);
+  }
+}
+
+inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncdiv(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floordiv(a, b);
+  }
+}
+
+class EliminateDivModMutator : public ExprMutator {
+ public:
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+
+  explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {}
+
+  virtual PrimExpr VisitExpr_(const DivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == -(x/c) for truncated division
+        return make_zero(op->dtype) -
+               VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return truncdiv(mutated_a, op->b);
+      }
+    }
+
+    return truncdiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const ModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == x % c for truncated division
+        return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return truncmod(mutated_a, op->b);
+      }
+    }
+
+    return truncmod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorDivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == (-x) / c for flooring division
+        return VisitExpr(
+            floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return floordiv(mutated_a, op->b);
+      }
+    }
+
+    return floordiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == -(-x % c) for flooring division
+        return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a,
+                                                         make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return floormod(mutated_a, op->b);
+      }
+    }
+
+    return floormod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+ private:
+  dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
+                                                    int64_t val, DivMode mode) {
+    using tresult = dmlc::optional<std::pair<Var, Var>>;
+
+    // Try to find the variables using the mutated expressions
+    if (!e.same_as(mut)) {
+      auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val));
+      if (it != expr_to_vars_.end()) {
+        return tresult(it->second);
+      }
+    }
+
+    PrimExpr val_e = make_const(e.dtype(), val);
+    idx_ += 1;
+
+    // Convert `ranges` to IntSets
+    std::unordered_map<const VarNode*, IntSet> var_intsets;
+    for (const auto& p : ranges) {
+      var_intsets[p.first.get()] = IntSet::FromRange(p.second);
+    }
+
+    // Infer ranges for the expressions we want to replace with variables
+    Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+    Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+
+    // We don't want to add unbounded variables
+    if (!div_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+    if (!mod_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+
+    // Create new variables for the expressions
+    auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype());
+    auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype());
+
+    new_variables.push_back(div);
+    new_variables.push_back(mod);
+
+    // Note that we have to perform substitution to mut because mut may contain new variables
+    substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode));
+    substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode));
+
+    ranges.Set(div, div_range);
+    ranges.Set(mod, mod_range);
+
+    // This additional condition works as a definition for the new variables
+    conditions.push_back(mut == div * val_e + mod);
+
+    if (!analyzer_.CanProve(mod_range->extent <= val_e)) {
+      // Since we use the C/C++ definition of mod, there may be multiple values of `mod`
+      // satisfying the added condition if the expr `e` may change its sign, so we
+      // have to add another condition.
+      LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because "
+                   << ModImpl(e, val_e, mode) << "  may change its sign";
+      conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0));
+    }
+
+    auto p = std::make_pair(div, mod);
+    expr_to_vars_[std::make_tuple(mode, e, val)] = p;
+    if (!e.same_as(mut)) {
+      expr_to_vars_[std::make_tuple(mode, mut, val)] = p;
+    }
+    return tresult(p);
+  }
+
+  class TupleEqual_ {
+   public:
+    bool operator()(const std::tuple<DivMode, PrimExpr, int64_t>& lhs,
+                    const std::tuple<DivMode, PrimExpr, int64_t>& rhs) const {
+      return std::get<0>(lhs) == std::get<0>(rhs) &&
+             tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) &&
+             std::get<2>(lhs) == std::get<2>(rhs);
+    }
+  };
+
+  class TupleHasher_ {
+   public:
+    size_t operator()(const std::tuple<DivMode, PrimExpr, int64_t>& key) const {
+      return ((std::hash<int>()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >>
+              1) ^
+             (std::hash<int64_t>()(std::get<2>(key)) << 1);
+    }
+  };
+
+  // A counter for naming new variables
+  int idx_{0};
+  // A map from pairs of exprs and numbers (e, n) to pairs of new vars (div, mod)
+  // such that `div = e / n` and `mod = e % n`
+  std::unordered_map<std::tuple<DivMode, PrimExpr, int64_t>, std::pair<Var, Var>, TupleHasher_,
+                     TupleEqual_>
+      expr_to_vars_;
+  arith::Analyzer analyzer_;
+};
+
+// Replace every subexpr of the form e/const and e % const with a new variable.
+// Syntactically equal expressions will be mapped to the same variable.
+EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map<Var, Range> ranges) {
+  EliminateDivModResult res;
+  EliminateDivModMutator mutator(ranges);
+  res.expr = mutator(expr);
+  res.conditions = std::move(mutator.conditions);
+  res.new_variables = std::move(mutator.new_variables);
+  res.substitution = std::move(mutator.substitution);
+  res.ranges = std::move(mutator.ranges);
+  return res;
+}
+
+arith::IntConstraintsTransform EliminateDivModFromDomainConditions(
+    const arith::IntConstraints& domain) {
+  auto elim_res = EliminateDivMod(All(domain->relations), domain->ranges);
+
+  Map<Var, Range> new_vranges = elim_res.ranges;
+  Array<Var> new_axis = Concat(domain->variables, elim_res.new_variables);
+  PrimExpr new_cond = elim_res.expr && All(elim_res.conditions);
+
+  arith::IntConstraints new_domain(new_axis, new_vranges,
+                                   FactorOutAtomicFormulas(new_cond).to_array());
+
+  Map<Var, PrimExpr> src_to_dst;
+  Map<Var, PrimExpr> dst_to_src = elim_res.substitution;
+  for (const Var& v : domain->variables) {
+    src_to_dst.Set(v, v);
+    dst_to_src.Set(v, v);
+  }
+
+  return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src);
+}
+
+// Build the identity transformation of a domain: every variable maps to itself.
+inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) {
+  Map<Var, PrimExpr> identity_map;
+  for (const Var& v : domain->variables) {
+    identity_map.Set(v, v);
+  }
+  return arith::IntConstraintsTransform(domain, domain, identity_map, identity_map);
+}
+
+// Simplify an iteration domain.
+arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains,
+                                              bool eliminate_div_mod) {
+  arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains);
+
+  if (eliminate_div_mod) {
+    transf = transf + EliminateDivModFromDomainConditions(transf->dst);
+  }
+
+  // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably
+  // should find a better terminating criterion (like stop when the domain volume stops decreasing)
+  // Also 2 steps seems to be slightly better than 3
+  for (size_t i = 0; i < 2; ++i) {
+    arith::IntConstraintsTransform tr = arith::SolveLinearEquations(transf->dst);
+    transf = transf + tr;
+    // TODO(sgrechanik-h): This helps for some artificial examples, however I'm not sure about
+    // enabling it in general. The problem it solves is propagating equalities of outer vars.
+    // tr = AddOuterVariablesIntoDomain(transf->dst);
+    tr = arith::SolveInequalitiesDeskewRange(transf->dst);
+    transf = transf + tr;
+  }
+
+  return transf;
+}
+
+// Use the condition of a reduction op to simplify its domain (axis)
+PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map<Var, Range>& outer_vranges) {
+  if (const ReduceNode* red = expr.as<ReduceNode>()) {
+    Array<Var> vars = IterVarsToVars(red->axis);
+    Map<Var, Range> vranges = Merge(outer_vranges, IterVarsToMap(red->axis));
+    Array<PrimExpr> relations = FactorOutAtomicFormulas(red->condition).to_array();
+
+    arith::IntConstraints domain(vars, vranges, relations);
+    auto res = SimplifyDomain(domain);
+
+    Array<PrimExpr> new_source;
+    for (const PrimExpr& src : red->source) {
+      new_source.push_back(Substitute(src, res->src_to_dst));
+    }
+
+    Array<IterVar> new_axis = IterVarsFromMap(res->dst->variables, res->dst->ranges, kCommReduce);
+
+    // Perform simplification mainly to remove a possibly empty reduction.
+    arith::Analyzer analyzer;
+    return analyzer.Simplify(
+        Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index),
+        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+  } else {
+    return expr;
+  }
+}
+
+// Extract from cond an implication of cond not containing vars
+std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
+    const PrimExpr& cond, const std::unordered_set<const VarNode*>& vars) {
+  CHECK(cond.dtype().is_bool()) << "The type of cond must be bool";
+  // TODO(sgrechanik-h): NOT

Review comment:
       Thanks @sergei-grechanik. I added it in the comments.
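
For readers following the quoted diff: the comments in EliminateDivModMutator state four rewrites for negative constant divisors (x / -c == -(x/c) and x % -c == x % c for truncated division, x / -c == (-x) / c and x % -c == -(-x % c) for floored division). The sketch below checks them by brute force on small integers. It is not TVM code: the four helpers are plain int64_t re-implementations that are assumed to match the TIR operators of the same names, and check_negative_divisor_identities is a hypothetical name used only for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Truncated division and modulo: the semantics of C++ `/` and `%`.
int64_t truncdiv(int64_t a, int64_t b) { return a / b; }
int64_t truncmod(int64_t a, int64_t b) { return a % b; }

// Floored division: the quotient is rounded toward negative infinity.
int64_t floordiv(int64_t a, int64_t b) {
  int64_t q = a / b;
  int64_t r = a % b;
  // Adjust when the remainder is nonzero and its sign differs from the divisor's.
  if (r != 0 && ((r < 0) != (b < 0))) --q;
  return q;
}

// Floored modulo: defined so that a == floordiv(a, b) * b + floormod(a, b).
int64_t floormod(int64_t a, int64_t b) { return a - floordiv(a, b) * b; }

// Check the rewrites for every x in [-range, range] and every c in [1, max_c].
bool check_negative_divisor_identities(int64_t range, int64_t max_c) {
  for (int64_t c = 1; c <= max_c; ++c) {
    for (int64_t x = -range; x <= range; ++x) {
      if (truncdiv(x, -c) != -truncdiv(x, c)) return false;
      if (truncmod(x, -c) != truncmod(x, c)) return false;
      if (floordiv(x, -c) != floordiv(-x, c)) return false;
      if (floormod(x, -c) != -floormod(-x, c)) return false;
      // The defining relation that the mutator adds as a condition: x == div * c + mod.
      if (x != truncdiv(x, c) * c + truncmod(x, c)) return false;
      if (x != floordiv(x, c) * c + floormod(x, c)) return false;
    }
  }
  return true;
}
```

Calling check_negative_divisor_identities(50, 7) from a test or from main exercises all four rewrites plus the defining relation for both division modes.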







[GitHub] [incubator-tvm] yzhliu commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
yzhliu commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r463901829



##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually depends
+ * on a small portion of the inputs. For example, an element-wise function has a one-to-one
+ * mapping between the input tensor and the output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * We therefore solve the linear equations \beta = A \alpha,
+ * as well as the linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via
+ * NonzeronessCondition, replace the compute expression with solved new axes, and create a
+ * selection node (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition
+// transform expression to cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero if
+    // both the arguments are nonzero, so we combine with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {

Review comment:
       Do you mean renaming the struct and its members to mention CNF (conjunctive normal form)?
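
For context on the naming question, the struct under discussion appears to hold a formula as a conjunction of atomic formulas plus a residual (`atomic_formulas && rest`). The sketch below is a toy model of that shape, assuming plain strings instead of `PrimExpr`; `ToyFactored`, `Conj`, `FactorAnd` and `FactorOr` are hypothetical names and are not part of the PR. It only illustrates the rule described in the diff: And takes the union of the atom sets, Or keeps the intersection and folds the leftover atoms into the residual.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical toy model: a formula is (AND of all atoms) && rest.
struct ToyFactored {
  std::set<std::string> atoms;  // atomic formulas factored out to the top level
  std::string rest;             // residual; "true" when nothing is left over
};

// Render a set of atoms and a residual as one conjunction string.
inline std::string Conj(const std::set<std::string>& atoms, const std::string& rest) {
  std::string out;
  for (const auto& a : atoms) out += a + " && ";
  return out + rest;
}

// And: the atom sets are united and the residuals are combined with &&.
inline ToyFactored FactorAnd(const ToyFactored& a, const ToyFactored& b) {
  ToyFactored res;
  res.atoms = a.atoms;
  res.atoms.insert(b.atoms.begin(), b.atoms.end());
  res.rest = "(" + a.rest + " && " + b.rest + ")";
  return res;
}

// Or: only atoms common to both sides can be factored out; the atoms left
// behind on each side are folded back into that side's residual.
inline ToyFactored FactorOr(const ToyFactored& a, const ToyFactored& b) {
  ToyFactored res;
  std::set<std::string> left_a, left_b;
  for (const auto& f : a.atoms) {
    if (b.atoms.count(f)) {
      res.atoms.insert(f);
    } else {
      left_a.insert(f);
    }
  }
  for (const auto& f : b.atoms) {
    if (!res.atoms.count(f)) left_b.insert(f);
  }
  res.rest = "((" + Conj(left_a, a.rest) + ") || (" + Conj(left_b, b.rest) + "))";
  return res;
}
```

With this shape in mind, a name that mentions the conjunction (or CNF) would describe the top-level structure; the residual itself is not necessarily in any normal form.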

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually relates
+ * to a small portion of the inputs. For example, an element-wise function has a one-to-one
+ * mapping between input tensor and output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * Thereby we solve the linear equations \beta = A \alpha,
+ * as well as the linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via NonzeronessCondition,
+ * replace the compute expression with solved new axes, and create a selection node
+ * (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition
+// transform expression to cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero if
+    // both the arguments are nonzero, so we combine with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {

Review comment:
       Do you mean providing a function `bool value_equals(const PrimExpr&)` in `NonzeroConditionResult`?
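
To make the suggestion concrete, here is a minimal sketch of what such a helper could look like. It is not the PR's code: `ToyExpr`, `MakeExpr`, `ToyNonzeroResult` and `RebuildDiv` are hypothetical stand-ins so the snippet compiles without TVM, and `same_as` is modeled as plain pointer identity on a `std::shared_ptr`. The point is only that the repeated `nz_a.value.same_as(op->a)` checks could be wrapped in one named method.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Toy stand-in for a reference-counted expression handle (hypothetical, not tvm::PrimExpr).
struct ToyExpr {
  std::shared_ptr<const std::string> node;
  // Pointer identity: true only when both handles refer to the same node.
  bool same_as(const ToyExpr& other) const { return node == other.node; }
};

inline ToyExpr MakeExpr(const std::string& text) {
  return ToyExpr{std::make_shared<const std::string>(text)};
}

struct ToyNonzeroResult {
  ToyExpr cond;
  ToyExpr value;
  // Hypothetical helper: true when the visitor returned the operand unchanged.
  bool value_equals(const ToyExpr& e) const { return value.same_as(e); }
};

// Reuse-or-rebuild pattern from the Div-like case: keep the original node when
// the numerator is unchanged, otherwise build a new division node.
inline ToyExpr RebuildDiv(const ToyExpr& original, const ToyExpr& a, const ToyExpr& b,
                          const ToyNonzeroResult& nz_a) {
  if (nz_a.value_equals(a)) return original;
  return MakeExpr("(" + *nz_a.value.node + " / " + *b.node + ")");
}
```

In this toy the check is identity-based, so a structurally equal but separately allocated node still triggers a rebuild. That matches the reuse-the-original-node intent of the checks in the diff.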

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually relates
+ * to a small portion of the inputs. For example, an element-wise function has a one-to-one
+ * mapping between input tensor and output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * Thereby we solve the linear equations \beta = A \alpha,
+ * as well as the linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via NonzeronessCondition,
+ * replace the compute expression with solved new axes, and create a selection node
+ * (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition
+// transform expression to cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero if
+    // both the arguments are nonzero, so we combine with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {
+  std::vector<PrimExpr> atomic_formulas;
+  PrimExpr rest;
+
+  PrimExpr to_expr() const {
+    PrimExpr res = rest;
+    for (const PrimExpr& e : atomic_formulas) {
+      res = And(e, res);
+    }
+    return res;
+  }
+
+  Array<PrimExpr> to_array() const {
+    Array<PrimExpr> res = atomic_formulas;
+    res.push_back(rest);
+    return res;
+  }
+};
+
+// The implementation of FactorOutAtomicFormulas
+class FactorOutAtomicFormulasFunctor
+    : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> {
+ public:
+  result_type Atomic_(const PrimExpr& e) {
+    // For atomic expressions the result is the expr itself with True as the residual
+    return {{e}, make_const(e.dtype(), 1)};
+  }
+
+  // This is basically the list of expression kinds that are considered atomic
+  result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    // Select can be rewritten through other logical ops
+    PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value);
+    return VisitExpr(expr);
+  }
+
+  result_type VisitExpr_(const NotNode* op) final {
+    // Not should be moved down
+    if (const OrNode* or_expr = op->a.as<OrNode>()) {
+      PrimExpr expr = !or_expr->a && !or_expr->b;
+      return VisitExpr(expr);
+    } else if (const AndNode* and_expr = op->a.as<AndNode>()) {
+      PrimExpr expr = !and_expr->a || !and_expr->b;
+      return VisitExpr(expr);
+    } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) {
+      PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) &&
+                       (sel_expr->condition || !sel_expr->false_value));
+      return VisitExpr(expr);
+    }
+    return Atomic_(GetRef<PrimExpr>(op));
+  }
+
+  result_type VisitExpr_(const AndNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    // For the And case we return the union of the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
+    std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+    std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    // And the residuals are combined with &&
+    return {res, res_a.rest && res_b.rest};
+  }
+
+  result_type VisitExpr_(const MulNode* op) final {
+    // Since we work with bools, for multiplication we do the same thing as for And
+    PrimExpr e_and = op->a && op->b;
+    return VisitExpr(e_and);
+  }
+
+  result_type VisitExpr_(const OrNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{
+        res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()};
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{
+        res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()};
+
+    // For the Or case we intersect the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
+    for (const auto& res_b_formula : res_b_set) {
+      if (res_a_set.count(res_b_formula)) {
+        res_set.insert(res_b_formula);
+      }
+    }
+
+    // Computing the residual is more complex: we have to compute the sets of atomic formulas
+    // which are left behind, and then combine them with the residuals into the new residual.
+    std::vector<PrimExpr> new_cond_a;
+    new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_a_set) {
+      if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
+    }
+
+    std::vector<PrimExpr> new_cond_b;
+    new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_b_set) {
+      if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
+    }
+
+    res_a.atomic_formulas = std::move(new_cond_a);
+    res_b.atomic_formulas = std::move(new_cond_b);
+
+    PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    return {res, new_rest};
+  }
+};
+
+// Transform the given formula into a conjunction of atomic formulas (represented as an array)
+// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b,
+// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level.
+FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) {
+  CHECK(e.dtype().is_bool());
+  return FactorOutAtomicFormulasFunctor().VisitExpr(e);
+}
+
+struct EliminateDivModResult {
+  PrimExpr expr;
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+};
+
+inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncmod(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floormod(a, b);
+  }
+}
+
+inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncdiv(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floordiv(a, b);
+  }
+}
+
+class EliminateDivModMutator : public ExprMutator {
+ public:
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+
+  explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {}
+
+  virtual PrimExpr VisitExpr_(const DivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == -(x/c) for truncated division
+        return make_zero(op->dtype) -
+               VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return truncdiv(mutated_a, op->b);
+      }
+    }
+
+    return truncdiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const ModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == x % c for truncated division
+        return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return truncmod(mutated_a, op->b);
+      }
+    }
+
+    return truncmod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorDivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == (-x) / c for flooring division
+        return VisitExpr(
+            floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return floordiv(mutated_a, op->b);
+      }
+    }
+
+    return floordiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == -(-x % c) for flooring division
+        return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a,
+                                                         make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return floormod(mutated_a, op->b);
+      }
+    }
+
+    return floormod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+ private:
+  dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
+                                                    int64_t val, DivMode mode) {
+    using tresult = dmlc::optional<std::pair<Var, Var>>;
+
+    // Try to find the variables using the mutated expressions
+    if (!e.same_as(mut)) {
+      auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val));
+      if (it != expr_to_vars_.end()) {
+        return tresult(it->second);
+      }
+    }
+
+    PrimExpr val_e = make_const(e.dtype(), val);
+    idx_ += 1;
+
+    // Convert `ranges` to IntSets
+    std::unordered_map<const VarNode*, IntSet> var_intsets;
+    for (const auto& p : ranges) {
+      var_intsets[p.first.get()] = IntSet::FromRange(p.second);
+    }
+
+    // Infer ranges for the expressions we want to replace with variables
+    Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+    Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+
+    // We don't want to add unbounded variables
+    if (!div_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+    if (!mod_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+
+    // Create new variables for the expressions
+    auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype());
+    auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype());
+
+    new_variables.push_back(div);
+    new_variables.push_back(mod);
+
+    // Note that we have to perform substitution to mut because mut may contain new variables
+    substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode));
+    substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode));
+
+    ranges.Set(div, div_range);
+    ranges.Set(mod, mod_range);
+
+    // This additional condition works as a definition for the new variables
+    conditions.push_back(mut == div * val_e + mod);
+
+    if (!analyzer_.CanProve(mod_range->extent <= val_e)) {
+      // Since we use the C/C++ definition of mod, there may be multiple values of `mod`
+      // satisfying the added condition if the expr `e` may change its sign, so we
+      // have to add another condition.
+      LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because "
+                   << ModImpl(e, val_e, mode) << "  may change its sign";
+      conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0));
+    }
+
+    auto p = std::make_pair(div, mod);
+    expr_to_vars_[std::make_tuple(mode, e, val)] = p;
+    if (!e.same_as(mut)) {
+      expr_to_vars_[std::make_tuple(mode, mut, val)] = p;
+    }
+    return tresult(p);
+  }
+
+  class TupleEqual_ {
+   public:
+    bool operator()(const std::tuple<DivMode, PrimExpr, int64_t>& lhs,
+                    const std::tuple<DivMode, PrimExpr, int64_t>& rhs) const {
+      return std::get<0>(lhs) == std::get<0>(rhs) &&
+             tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) &&
+             std::get<2>(lhs) == std::get<2>(rhs);
+    }
+  };
+
+  class TupleHasher_ {
+   public:
+    size_t operator()(const std::tuple<DivMode, PrimExpr, int64_t>& key) const {
+      return ((std::hash<int>()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >>
+              1) ^
+             (std::hash<int64_t>()(std::get<2>(key)) << 1);
+    }
+  };
+
+  // A counter for naming new variables
+  int idx_{0};
+  // A map from tuples (mode, e, n) of a division mode, an expr and a number to pairs of
+  // new vars (div, mod) such that `div = e / n` and `mod = e % n` under that mode
+  std::unordered_map<std::tuple<DivMode, PrimExpr, int64_t>, std::pair<Var, Var>, TupleHasher_,
+                     TupleEqual_>
+      expr_to_vars_;
+  arith::Analyzer analyzer_;
+};
+
+// Replace every subexpr of the form e/const and e % const with a new variable.
+// Syntactically equal expressions will be mapped to the same variable.
+EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map<Var, Range> ranges) {
+  EliminateDivModResult res;
+  EliminateDivModMutator mutator(ranges);
+  res.expr = mutator(expr);
+  res.conditions = std::move(mutator.conditions);
+  res.new_variables = std::move(mutator.new_variables);
+  res.substitution = std::move(mutator.substitution);
+  res.ranges = std::move(mutator.ranges);
+  return res;
+}
+
+arith::IntConstraintsTransform EliminateDivModFromDomainConditions(
+    const arith::IntConstraints& domain) {
+  auto elim_res = EliminateDivMod(All(domain->relations), domain->ranges);
+
+  Map<Var, Range> new_vranges = elim_res.ranges;
+  Array<Var> new_axis = Concat(domain->variables, elim_res.new_variables);
+  PrimExpr new_cond = elim_res.expr && All(elim_res.conditions);
+
+  arith::IntConstraints new_domain(new_axis, new_vranges,
+                                   FactorOutAtomicFormulas(new_cond).to_array());
+
+  Map<Var, PrimExpr> src_to_dst;
+  Map<Var, PrimExpr> dst_to_src = elim_res.substitution;
+  for (const Var& v : domain->variables) {
+    src_to_dst.Set(v, v);
+    dst_to_src.Set(v, v);
+  }
+
+  return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src);
+}
+
+// Create the identity transformation for a domain: every variable maps to itself.
+inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) {
+  Map<Var, PrimExpr> identity_map;
+  for (const Var& v : domain->variables) {
+    identity_map.Set(v, v);
+  }
+  return arith::IntConstraintsTransform(domain, domain, identity_map, identity_map);
+}
+
+// Simplify an iteration domain.
+arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains,
+                                              bool eliminate_div_mod) {
+  arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains);
+
+  if (eliminate_div_mod) {
+    transf = transf + EliminateDivModFromDomainConditions(transf->dst);
+  }
+
+  // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably
+  // should find a better terminating criterion (like stop when the domain volume stops decreasing)
+  // Also 2 steps seems to be slightly better than 3
+  for (size_t i = 0; i < 2; ++i) {
+    arith::IntConstraintsTransform tr = arith::SolveLinearEquations(transf->dst);
+    transf = transf + tr;
+    // TODO(sgrechanik-h): This helps for some artificial examples, however I'm not sure about
+    // enabling it in general. The problem it solves is propagating equalities of outer vars.
+    // tr = AddOuterVariablesIntoDomain(transf->dst);
+    tr = arith::SolveInequalitiesDeskewRange(transf->dst);
+    transf = transf + tr;
+  }
+
+  return transf;
+}
+
+// Use the condition of a reduction op to simplify its domain (axis)
+PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map<Var, Range>& outer_vranges) {
+  if (const ReduceNode* red = expr.as<ReduceNode>()) {
+    Array<Var> vars = IterVarsToVars(red->axis);
+    Map<Var, Range> vranges = Merge(outer_vranges, IterVarsToMap(red->axis));
+    Array<PrimExpr> relations = FactorOutAtomicFormulas(red->condition).to_array();
+
+    arith::IntConstraints domain(vars, vranges, relations);
+    auto res = SimplifyDomain(domain);
+
+    Array<PrimExpr> new_source;
+    for (const PrimExpr& src : red->source) {
+      new_source.push_back(Substitute(src, res->src_to_dst));
+    }
+
+    Array<IterVar> new_axis = IterVarsFromMap(res->dst->variables, res->dst->ranges, kCommReduce);
+
+    // Perform simplification mainly to remove a possibly empty reduction.
+    arith::Analyzer analyzer;
+    return analyzer.Simplify(
+        Reduce(red->combiner, new_source, new_axis, All(res->dst->relations), red->value_index),
+        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+  } else {
+    return expr;
+  }
+}
+
+// Extract from cond an implication of cond not containing vars
+std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
+    const PrimExpr& cond, const std::unordered_set<const VarNode*>& vars) {
+  CHECK(cond.dtype().is_bool()) << "The type of cond must be bool";
+  // TODO(sgrechanik-h): NOT

Review comment:
       Actually, in my understanding it is not straightforward to separate the NOT node here, because the false branch of (!pair.a) would also contain the reduction (instead of zero). I'm not sure whether it would provide any benefit. @sergei-grechanik, could you help comment?
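
       For context on the NOT question: the quoted FactorOutAtomicFormulasFunctor::VisitExpr_(const NotNode*) pushes a negation below Or/And/Select using De Morgan's laws before factoring. Below is a minimal standalone sketch of that push-down on a toy boolean AST. It does not use TVM types; `Expr`, `PushNot` and `ToString` are hypothetical names for illustration only, and unlike the quoted visitor the sketch also cancels double negations. Whether the same rewrite would help in ImplicationNotContainingVars is the open question here, and this sketch does not answer it.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Toy boolean AST. These types are illustrative only and are NOT TVM classes.
struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;

enum class Kind { kAtom, kNot, kAnd, kOr };

struct Expr {
  Kind kind;
  std::string name;  // used only for kAtom
  ExprPtr a, b;      // operands; b is unused for kNot
};

ExprPtr Atom(std::string n) {
  return std::make_shared<Expr>(Expr{Kind::kAtom, std::move(n), nullptr, nullptr});
}
ExprPtr Not(ExprPtr x) {
  return std::make_shared<Expr>(Expr{Kind::kNot, "", std::move(x), nullptr});
}
ExprPtr And(ExprPtr x, ExprPtr y) {
  return std::make_shared<Expr>(Expr{Kind::kAnd, "", std::move(x), std::move(y)});
}
ExprPtr Or(ExprPtr x, ExprPtr y) {
  return std::make_shared<Expr>(Expr{Kind::kOr, "", std::move(x), std::move(y)});
}

// Push negations down to the atoms using De Morgan's laws:
//   !(x || y) -> !x && !y      !(x && y) -> !x || !y      !!x -> x
// `negate` records whether an odd number of NOTs is pending above `e`.
ExprPtr PushNot(const ExprPtr& e, bool negate = false) {
  switch (e->kind) {
    case Kind::kAtom:
      return negate ? Not(e) : e;
    case Kind::kNot:
      return PushNot(e->a, !negate);
    case Kind::kAnd: {
      ExprPtr l = PushNot(e->a, negate);
      ExprPtr r = PushNot(e->b, negate);
      return negate ? Or(l, r) : And(l, r);
    }
    case Kind::kOr: {
      ExprPtr l = PushNot(e->a, negate);
      ExprPtr r = PushNot(e->b, negate);
      return negate ? And(l, r) : Or(l, r);
    }
  }
  return e;  // unreachable; keeps compilers quiet
}

// Render an expression with explicit parentheses around binary operators.
std::string ToString(const ExprPtr& e) {
  switch (e->kind) {
    case Kind::kAtom:
      return e->name;
    case Kind::kNot:
      return "!" + ToString(e->a);
    case Kind::kAnd:
      return "(" + ToString(e->a) + " && " + ToString(e->b) + ")";
    case Kind::kOr:
      return "(" + ToString(e->a) + " || " + ToString(e->b) + ")";
  }
  return "";
}
```

       With these definitions, PushNot applied to !(a || (b && !c)) yields !a && (!b || c), so every remaining negation sits directly on an atom and the And/Or cases can then be factored as in the quoted visitor.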







[GitHub] [incubator-tvm] yzhliu commented on pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
yzhliu commented on pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#issuecomment-670707260


   @MarisaKirisame @sergei-grechanik @tqchen I addressed most of the comments, please take a look again.





[GitHub] [incubator-tvm] yzhliu commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
yzhliu commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r467258411



##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually relates
+ * to a small portion of the inputs. For example, an element-wise function has a one-to-one
+ * mapping between input tensor and output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * Thereby we solve the linear equations \beta = A \alpha,
+ * as well as the linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via NonzeronessCondition,
+ * replace the compute expression with solved new axes, and create a selection node
+ * (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+using arith::ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return original;
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition:
+// transforms an expression into the form `cond ? value : 0`
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+    return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into a single one
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(op_if_then_else_)) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero only if
+    // both arguments are nonzero, so we combine the conditions with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+  const Op& op_if_then_else_ = Op::Get("tir.if_then_else");
+};
+
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  return NonzeroConditionFunctor().NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {
+  std::vector<PrimExpr> atomic_formulas;
+  PrimExpr rest;
+
+  PrimExpr to_expr() const {
+    PrimExpr res = rest;
+    for (const PrimExpr& e : atomic_formulas) {
+      res = And(e, res);
+    }
+    return res;
+  }
+
+  Array<PrimExpr> to_array() const {
+    Array<PrimExpr> res = atomic_formulas;
+    res.push_back(rest);
+    return res;
+  }
+};
+
+// The implementation of FactorOutAtomicFormulas
+class FactorOutAtomicFormulasFunctor
+    : public ExprFunctor<FactorOutAtomicFormulasResult(const PrimExpr&)> {
+ public:
+  result_type Atomic_(const PrimExpr& e) {
+    // For atomic expressions the result is the expr itself with True as the residual
+    return {{e}, make_const(e.dtype(), 1)};
+  }
+
+  // This is basically the list of expression kinds that are considered atomic
+  result_type VisitExpr_(const VarNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const CallNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const EQNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const NENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const LTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GENode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const GTNode* op) final { return Atomic_(GetRef<PrimExpr>(op)); }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    // Select can be rewritten through other logical ops
+    PrimExpr expr = (op->condition && op->true_value) || (!op->condition && op->false_value);
+    return VisitExpr(expr);
+  }
+
+  result_type VisitExpr_(const NotNode* op) final {
+    // Not should be moved down
+    if (const OrNode* or_expr = op->a.as<OrNode>()) {
+      PrimExpr expr = !or_expr->a && !or_expr->b;
+      return VisitExpr(expr);
+    } else if (const AndNode* and_expr = op->a.as<AndNode>()) {
+      PrimExpr expr = !and_expr->a || !and_expr->b;
+      return VisitExpr(expr);
+    } else if (const SelectNode* sel_expr = op->a.as<SelectNode>()) {
+      PrimExpr expr = ((!sel_expr->condition || !sel_expr->true_value) &&
+                       (sel_expr->condition || !sel_expr->false_value));
+      return VisitExpr(expr);
+    }
+    return Atomic_(GetRef<PrimExpr>(op));
+  }
+
+  result_type VisitExpr_(const AndNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    // For the And case we return the union of the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size());
+    std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+    std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(),
+              std::inserter(res_set, res_set.end()));
+
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    // And the residuals are combined with &&
+    return {res, res_a.rest && res_b.rest};
+  }
+
+  result_type VisitExpr_(const MulNode* op) final {
+    // Since we work with bools, for multiplication we do the same thing as for And
+    PrimExpr e_and = op->a && op->b;
+    return VisitExpr(e_and);
+  }
+
+  result_type VisitExpr_(const OrNode* op) final {
+    auto res_a = VisitExpr(op->a);
+    auto res_b = VisitExpr(op->b);
+
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_a_set{
+        res_a.atomic_formulas.begin(), res_a.atomic_formulas.end()};
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_b_set{
+        res_b.atomic_formulas.begin(), res_b.atomic_formulas.end()};
+
+    // For the Or case we intersect the sets of atomic formulas
+    std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> res_set;
+    res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size()));
+    for (const auto& res_b_formula : res_b_set) {
+      if (res_a_set.count(res_b_formula)) {
+        res_set.insert(res_b_formula);
+      }
+    }
+
+    // Computing the residual is more complex: we have to compute the sets of atomic formulas
+    // which are left behind, and then combine them with the residuals into the new residual.
+    std::vector<PrimExpr> new_cond_a;
+    new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_a_set) {
+      if (!res_set.count(formula)) new_cond_a.emplace_back(formula);
+    }
+
+    std::vector<PrimExpr> new_cond_b;
+    new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size());
+    for (const auto& formula : res_b_set) {
+      if (!res_set.count(formula)) new_cond_b.emplace_back(formula);
+    }
+
+    res_a.atomic_formulas = std::move(new_cond_a);
+    res_b.atomic_formulas = std::move(new_cond_b);
+
+    PrimExpr new_rest = res_a.to_expr() || res_b.to_expr();
+    std::vector<PrimExpr> res{res_set.begin(), res_set.end()};
+
+    return {res, new_rest};
+  }
+};
+
+// Transform the given formula into a conjunction of atomic formulas (represented as an array)
+// and a non-atomic residual. Atomic formulas are consts, calls, variables and comparisons (a <= b,
+// etc), i.e. formulas which are not logical operators (||, &&, !) on the top level.
+FactorOutAtomicFormulasResult FactorOutAtomicFormulas(const PrimExpr& e) {
+  CHECK(e.dtype().is_bool());
+  return FactorOutAtomicFormulasFunctor().VisitExpr(e);
+}
+
+struct EliminateDivModResult {
+  PrimExpr expr;
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+};
+
+inline PrimExpr ModImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncmod(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floormod(a, b);
+  }
+}
+
+inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
+  if (mode == kTruncDiv) {
+    return truncdiv(a, b);
+  } else {
+    CHECK_EQ(mode, kFloorDiv);
+    return floordiv(a, b);
+  }
+}
+
+class EliminateDivModMutator : public ExprMutator {
+ public:
+  Map<Var, PrimExpr> substitution;
+  Array<Var> new_variables;
+  Array<PrimExpr> conditions;
+  Map<Var, Range> ranges;
+
+  explicit EliminateDivModMutator(Map<Var, Range> ranges) : ranges(std::move(ranges)) {}
+
+  virtual PrimExpr VisitExpr_(const DivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == -(x/c) for truncated division
+        return make_zero(op->dtype) -
+               VisitExpr(truncdiv(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return truncdiv(mutated_a, op->b);
+      }
+    }
+
+    return truncdiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const ModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == x % c for truncated division
+        return VisitExpr(truncmod(op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kTruncDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kTruncDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return truncmod(mutated_a, op->b);
+      }
+    }
+
+    return truncmod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorDivNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x / -c == (-x) / c for flooring division
+        return VisitExpr(
+            floordiv(make_zero(op->dtype) - op->a, make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.first;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().first;
+      } else {
+        return floordiv(mutated_a, op->b);
+      }
+    }
+
+    return floordiv(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+  virtual PrimExpr VisitExpr_(const FloorModNode* op) {
+    const IntImmNode* imm = op->b.as<IntImmNode>();
+    if (imm && imm->value != 0) {
+      if (imm->value < 0) {
+        // x % -c == -(-x % c) for flooring division
+        return VisitExpr(make_zero(op->dtype) - floormod(make_zero(op->dtype) - op->a,
+                                                         make_const(op->dtype, -imm->value)));
+      }
+
+      // Try to find the already existing variables for this expression
+      auto it = expr_to_vars_.find(std::make_tuple(kFloorDiv, op->a, imm->value));
+      if (it != expr_to_vars_.end()) {
+        return it->second.second;
+      }
+
+      // Otherwise recursively mutate the left hand side, and create new variables
+      PrimExpr mutated_a = VisitExpr(op->a);
+      if (auto var_pair_opt = AddNewVarPair(op->a, mutated_a, imm->value, kFloorDiv)) {
+        return var_pair_opt.value().second;
+      } else {
+        return floormod(mutated_a, op->b);
+      }
+    }
+
+    return floormod(VisitExpr(op->a), VisitExpr(op->b));
+  }
+
+ private:
+  dmlc::optional<std::pair<Var, Var>> AddNewVarPair(const PrimExpr& e, const PrimExpr& mut,
+                                                    int64_t val, DivMode mode) {
+    using tresult = dmlc::optional<std::pair<Var, Var>>;
+
+    // Try to find the variables using the mutated expressions
+    if (!e.same_as(mut)) {
+      auto it = expr_to_vars_.find(std::make_tuple(mode, mut, val));
+      if (it != expr_to_vars_.end()) {
+        return tresult(it->second);
+      }
+    }
+
+    PrimExpr val_e = make_const(e.dtype(), val);
+    idx_ += 1;
+
+    // Convert `ranges` to IntSets
+    std::unordered_map<const VarNode*, IntSet> var_intsets;
+    for (const auto& p : ranges) {
+      var_intsets[p.first.get()] = IntSet::FromRange(p.second);
+    }
+
+    // Infer ranges for the expressions we want to replace with variables
+    Range div_range = EvalSet(DivImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+    Range mod_range = EvalSet(ModImpl(mut, val_e, mode), var_intsets).CoverRange(Range());
+
+    // We don't want to add unbounded variables
+    if (!div_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << DivImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+    if (!mod_range.get()) {
+      LOG(WARNING) << "EliminateDivMod: won't eliminate " << ModImpl(e, val_e, mode)
+                   << "  because its bounds cannot be inferred";
+      return tresult();
+    }
+
+    // Create new variables for the expressions
+    auto div = Var((mode == kTruncDiv ? "tdiv" : "fdiv") + std::to_string(idx_), e.dtype());
+    auto mod = Var((mode == kTruncDiv ? "tmod" : "fmod") + std::to_string(idx_), e.dtype());
+
+    new_variables.push_back(div);
+    new_variables.push_back(mod);
+
+    // Note that we have to perform substitution to mut because mut may contain new variables
+    substitution.Set(div, DivImpl(Substitute(mut, substitution), val_e, mode));
+    substitution.Set(mod, ModImpl(Substitute(mut, substitution), val_e, mode));
+
+    ranges.Set(div, div_range);
+    ranges.Set(mod, mod_range);
+
+    // This additional condition works as a definition for the new variables
+    conditions.push_back(mut == div * val_e + mod);
+
+    if (!analyzer_.CanProve(mod_range->extent <= val_e)) {
+      // Since we use the C/C++ definition of mod, there may be multiple values of `mod`
+      // satisfying the added condition if the expr `e` may change its sign, so we
+      // have to add another condition.
+      LOG(WARNING) << "EliminateDivMod: cannot fully eliminate div or mod because "
+                   << ModImpl(e, val_e, mode) << "  probably may change its sign";
+      conditions.push_back(Select(e >= 0, mod >= 0, mod <= 0));
+    }
+
+    auto p = std::make_pair(div, mod);
+    expr_to_vars_[std::make_tuple(mode, e, val)] = p;
+    if (!e.same_as(mut)) {
+      expr_to_vars_[std::make_tuple(mode, mut, val)] = p;
+    }
+    return tresult(p);
+  }
+
+  class TupleEqual_ {
+   public:
+    bool operator()(const std::tuple<DivMode, PrimExpr, int64_t>& lhs,
+                    const std::tuple<DivMode, PrimExpr, int64_t>& rhs) const {
+      return std::get<0>(lhs) == std::get<0>(rhs) &&
+             tir::ExprDeepEqual()(std::get<1>(lhs), std::get<1>(rhs)) &&
+             std::get<2>(lhs) == std::get<2>(rhs);
+    }
+  };
+
+  class TupleHasher_ {
+   public:
+    size_t operator()(const std::tuple<DivMode, PrimExpr, int64_t>& key) const {
+      return ((std::hash<int>()(std::get<0>(key)) ^ (StructuralHash()(std::get<1>(key)) << 1)) >>
+              1) ^
+             (std::hash<int64_t>()(std::get<2>(key)) << 1);
+    }
+  };
+
+  // A counter for naming new variables
+  int idx_{0};
+  // A map from pairs of exprs and numbers (e, n) to pairs of new vars (div, mod)
+  // such that `div = e / n` and `mod = e % n`
+  std::unordered_map<std::tuple<DivMode, PrimExpr, int64_t>, std::pair<Var, Var>, TupleHasher_,
+                     TupleEqual_>
+      expr_to_vars_;
+  arith::Analyzer analyzer_;
+};
+
+// Replace every subexpr of the form e/const and e % const with a new variable.
+// Syntactically equal expressions will be mapped to the same variable.
+EliminateDivModResult EliminateDivMod(const PrimExpr& expr, Map<Var, Range> ranges) {
+  EliminateDivModResult res;
+  EliminateDivModMutator mutator(ranges);
+  res.expr = mutator(expr);
+  res.conditions = std::move(mutator.conditions);
+  res.new_variables = std::move(mutator.new_variables);
+  res.substitution = std::move(mutator.substitution);
+  res.ranges = std::move(mutator.ranges);
+  return res;
+}
+
+arith::IntConstraintsTransform EliminateDivModFromDomainConditions(
+    const arith::IntConstraints& domain) {
+  auto elim_res = EliminateDivMod(All(domain->relations), domain->ranges);
+
+  Map<Var, Range> new_vranges = elim_res.ranges;
+  Array<Var> new_axis = domain->variables.Concat(elim_res.new_variables);
+  PrimExpr new_cond = elim_res.expr && All(elim_res.conditions);
+
+  arith::IntConstraints new_domain(new_axis, new_vranges,
+                                   FactorOutAtomicFormulas(new_cond).to_array());
+
+  Map<Var, PrimExpr> src_to_dst;
+  Map<Var, PrimExpr> dst_to_src = elim_res.substitution;
+  for (const Var& v : domain->variables) {
+    src_to_dst.Set(v, v);
+    dst_to_src.Set(v, v);
+  }
+
+  return arith::IntConstraintsTransform(domain, new_domain, src_to_dst, dst_to_src);
+}
+
+// Build the identity transformation for a domain.
+inline arith::IntConstraintsTransform IdentityTransformation(const arith::IntConstraints& domain) {
+  Map<Var, PrimExpr> identity_map;
+  for (const Var& v : domain->variables) {
+    identity_map.Set(v, v);
+  }
+  return arith::IntConstraintsTransform(domain, domain, identity_map, identity_map);
+}
+
+// Simplify an iteration domain.
+arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_domains,
+                                              bool eliminate_div_mod) {
+  arith::IntConstraintsTransform transf = IdentityTransformation(iter_domains);
+
+  if (eliminate_div_mod) {
+    transf = transf + EliminateDivModFromDomainConditions(transf->dst);
+  }
+
+  // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably
+  // should find a better terminating criterion (like stop when the domain volume stops decreasing)
+  // Also 2 steps seems to be slightly better than 3
+  for (size_t i = 0; i < 2; ++i) {
+    arith::IntConstraintsTransform tr = arith::SolveLinearEquations(transf->dst);
+    transf = transf + tr;
+    // TODO(sgrechanik-h): This helps for some artificial examples, however I'm not sure about
+    // enabling it in general. The problem it solves is propagating equalities of outer vars.
+    // tr = AddOuterVariablesIntoDomain(transf->dst);

Review comment:
       removed.
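
A side note on the hunk above: the rewrites that EliminateDivModMutator applies for negative constant divisors can be sanity-checked outside TVM. The sketch below is only an illustration under stated assumptions. The helper names (TruncDiv, TruncMod, FloorDiv, FloorMod, NegativeDivisorIdentitiesHold) are hypothetical and are not TVM's truncdiv/floordiv API; they operate on plain int64_t values and assume a nonzero divisor.

```cpp
#include <cassert>
#include <cstdint>

// Truncated division and remainder: the built-in C++ semantics (round toward zero).
inline int64_t TruncDiv(int64_t a, int64_t b) { return a / b; }
inline int64_t TruncMod(int64_t a, int64_t b) { return a % b; }

// Floored division: round the quotient toward negative infinity.
inline int64_t FloorDiv(int64_t a, int64_t b) {
  int64_t q = a / b;
  int64_t r = a % b;
  // Adjust when the remainder is nonzero and its sign differs from the divisor's.
  if (r != 0 && ((r < 0) != (b < 0))) --q;
  return q;
}

// Floored remainder: the result is zero or has the sign of the divisor.
inline int64_t FloorMod(int64_t a, int64_t b) {
  int64_t r = a % b;
  if (r != 0 && ((r < 0) != (b < 0))) r += b;
  return r;
}

// Returns true if the four rewrites quoted in the diff comments hold for
// every x in [-limit, limit] and every positive c in [1, limit].
inline bool NegativeDivisorIdentitiesHold(int64_t limit) {
  for (int64_t x = -limit; x <= limit; ++x) {
    for (int64_t c = 1; c <= limit; ++c) {
      if (TruncDiv(x, -c) != -TruncDiv(x, c)) return false;   // x / -c == -(x / c)
      if (TruncMod(x, -c) != TruncMod(x, c)) return false;    // x % -c == x % c
      if (FloorDiv(x, -c) != FloorDiv(-x, c)) return false;   // x / -c == (-x) / c
      if (FloorMod(x, -c) != -FloorMod(-x, c)) return false;  // x % -c == -(-x % c)
    }
  }
  return true;
}

// Convenience wrapper: aborts if any identity fails on a small range.
inline void CheckNegativeDivisorIdentities() { assert(NegativeDivisorIdentitiesHold(64)); }
```

Calling CheckNegativeDivisorIdentities() from a test or from main exercises all four identities on a small range. This is a brute-force check rather than a proof.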







[GitHub] [incubator-tvm] yzhliu commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

Posted by GitBox <gi...@apache.org>.
yzhliu commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r467256980



##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually depends
+ * on a small portion of the inputs. For example, an element-wise function has a one-to-one
+ * mapping between the input tensor and the output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix, and
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * Thereby we solve the linear equations \beta = A \alpha,
+ * as well as the linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via
+ * NonzeronessCondition, replace the compute expression with the solved new axes,
+ * and create a selection node (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index],
+                                        ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE),
+                      0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+  return is_const_value(in, 0);
+}
+
+struct NonzeroConditionResult {
+  PrimExpr cond;
+  PrimExpr value;
+
+  PrimExpr to_expr() const { return Select(cond, value, make_zero(value.dtype())); }
+
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    return os << r.to_expr();
+  }
+};
+
+// The implementation of NonzeroCondition
+// transform expression to cond ? value : 0
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(Op::Get("tir.if_then_else"))) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result can be nonzero only if
+    // both arguments are nonzero, so we combine the conditions with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {

Review comment:
       If this is the only place where the function is used, do we really need to create a new function? Correct me if I misunderstood.
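
For context on the helpers in the hunk above, the combination rules they encode (Or for add-like ops, And for mul-like ops, numerator only for div-like ops) can be shown on a toy expression tree. This is a simplified sketch, not the PR's code: Expr, Make, Const, Var, Add, Mul, Div and MayBeNonzero are hypothetical names, and the symbolic condition is collapsed to a plain bool, whereas NonzeroConditionFunctor keeps a PrimExpr condition and simplifies it with the analyzer.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// A toy expression tree (hypothetical; not TVM's PrimExpr).
struct Expr;
using ExprPtr = std::shared_ptr<Expr>;

struct Expr {
  enum Kind { kConst, kVar, kAdd, kMul, kDiv };
  Kind kind;
  long value;    // meaningful for kConst only
  ExprPtr a, b;  // operands of binary nodes, null otherwise
};

inline ExprPtr Make(Expr::Kind kind, long value, ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{kind, value, std::move(a), std::move(b)});
}
inline ExprPtr Const(long v) { return Make(Expr::kConst, v, nullptr, nullptr); }
inline ExprPtr Var() { return Make(Expr::kVar, 0, nullptr, nullptr); }
inline ExprPtr Add(ExprPtr a, ExprPtr b) { return Make(Expr::kAdd, 0, std::move(a), std::move(b)); }
inline ExprPtr Mul(ExprPtr a, ExprPtr b) { return Make(Expr::kMul, 0, std::move(a), std::move(b)); }
inline ExprPtr Div(ExprPtr a, ExprPtr b) { return Make(Expr::kDiv, 0, std::move(a), std::move(b)); }

// Conservative answer to "can this expression be nonzero?".
// A result of false means the expression is certainly zero.
inline bool MayBeNonzero(const ExprPtr& e) {
  switch (e->kind) {
    case Expr::kConst:
      return e->value != 0;  // constants are decided exactly
    case Expr::kVar:
      return true;  // unknown value: "may be nonzero" is always a safe default
    case Expr::kAdd:
      return MayBeNonzero(e->a) || MayBeNonzero(e->b);  // either side suffices
    case Expr::kMul:
      return MayBeNonzero(e->a) && MayBeNonzero(e->b);  // both sides are required
    case Expr::kDiv:
      return MayBeNonzero(e->a);  // only the numerator matters
  }
  return true;  // unreachable for valid kinds; stay conservative
}
```

For example, MayBeNonzero(Mul(Var(), Const(0))) is false while MayBeNonzero(Add(Var(), Const(0))) is true, which mirrors why the mul-like helper combines conditions with And and the add-like helper uses Or.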



