You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/08/01 09:12:43 UTC

[GitHub] [incubator-tvm] MarisaKirisame commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

MarisaKirisame commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r463942670



##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by tensor-level autodiff.
+ *
+ * The major simplification we do in this file is to eliminate
+ * the Jacobian tensor created by autodiff.
+ *
+ * The Jacobian tensor is sparse because one output element usually relates
+ * to a small portion of the inputs. For example, an element-wise function has a one-to-one
+ * mapping between the input tensor and the output tensor, thus the Jacobian is diagonal.
+ *
+ * Generally, we have Out_{\beta} = f( In_{A \alpha} ) in which A is a matrix,
+ * \alpha and \beta are vectors representing the indices of In and Out respectively,
+ * i.e., the non-zero Jacobian indices are a linear combination of the input indices.
+ * Thereby we solve linear equations of \beta = A \alpha,
+ * as well as linear inequalities of their domain ranges.
+ *
+ * Refer to Urban S, van der Smagt P. Automatic differentiation for tensor algebras[J].
+ * arXiv preprint arXiv:1711.01348, 2017. for more details.
+ *
+ * Implementation-wise, we extract the equations in the compute definition via
+ * NonzeronessCondition, replace the compute expression with solved new axes, and create a
+ * selection node (non-zero-condition ? new_compute_expression : 0).
+ *
+ * Due to TVM's restriction, we also lift the reduction to the top of the compute stage.
+ *
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+// Insert every (key, value) pair of `update` into `original`, overwriting entries whose
+// keys are already present, and return the resulting map.
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  // `original` is a by-value parameter of the return type, so returning it by name is
+  // already an implicit move; an explicit std::move is redundant (-Wredundant-move).
+  return original;
+}
+
+// Concatenate two arrays: returns the elements of `a` followed by the elements of `b`,
+// in their original order.
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  // `a` is a by-value parameter of the return type, so returning it by name is already
+  // an implicit move; an explicit std::move is redundant (-Wredundant-move).
+  return a;
+}
+
+// Combine all expressions from the container using &&.
+// A single-element container yields that element unchanged (no `true && e` wrapper);
+// an empty container yields const_true().
+template <class container>
+PrimExpr All(const container& c) {
+  // Starts out undefined so the first term can be taken as-is.
+  PrimExpr conjunction;
+  for (const auto& term : c) {
+    if (!conjunction.get()) {
+      conjunction = term;
+      continue;
+    }
+    conjunction = conjunction && term;
+  }
+  // Nothing was combined: the empty conjunction is vacuously true.
+  if (!conjunction.get()) {
+    return const_true();
+  }
+  return conjunction;
+}
+
+// Build a map from each itervar's variable to that itervar's domain.
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> var_to_dom;
+  for (const auto& iv : itervars) {
+    var_to_dom.Set(iv->var, iv->dom);
+  }
+  return var_to_dom;
+}
+
+// Given a map from vars to ranges create an array of itervars.
+// The output follows the order of `vars`; every itervar gets the same `iter_type` and
+// `thread_tag`. Each var must have an entry in `vranges`, otherwise a CHECK fails.
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> itervars;
+  for (const Var& var : vars) {
+    const bool has_range = vranges.count(var) != 0;
+    CHECK(has_range) << "A range for the variable " << var << " was not provided in map "
+                     << vranges;
+    const Range dom = vranges[var];
+    itervars.push_back(IterVar(dom, var, iter_type, thread_tag));
+  }
+  return itervars;
+}
+
+// Extract the variable of each itervar, preserving order.
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> vars;
+  for (const auto& iv : itervars) {
+    vars.push_back(iv->var);
+  }
+  return vars;
+}
+
+// Return true if `e` is an integer or float immediate equal to `value`, looking through
+// Cast and Broadcast wrappers. Any other expression kind yields false.
+// NOTE(review): the cast's target dtype is ignored, so the look-through is only exact for
+// values every cast preserves (such as 0, the only value used in this file).
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const auto* int_imm = e.as<tir::IntImmNode>()) {
+    return int_imm->value == value;
+  }
+  if (const auto* float_imm = e.as<tir::FloatImmNode>()) {
+    return float_imm->value == value;
+  }
+  // Wrappers are transparent: inspect the wrapped operand instead.
+  if (const auto* cast = e.as<tir::CastNode>()) {
+    return is_const_value(cast->value, value);
+  }
+  if (const auto* broadcast = e.as<tir::BroadcastNode>()) {
+    return is_const_value(broadcast->value, value);
+  }
+  return false;
+}
+
+// Return true if this combiner is just a sum: it has exactly one component, its identity
+// simplifies to the constant 0, and its result simplifies to `lhs + rhs` or `rhs + lhs`.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  // Multi-component (tuple) reductions are never treated as a plain sum.
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  const PrimExpr simplified_identity =
+      analyzer.Simplify(combiner->identity_element[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+  if (!is_const_value(simplified_identity, 0)) {
+    return false;
+  }
+
+  const PrimExpr simplified_result =
+      analyzer.Simplify(combiner->result[0], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+  const Var lhs = combiner->lhs[0];
+  const Var rhs = combiner->rhs[0];
+  tir::ExprDeepEqual deep_equal;
+  // Deep equality is structural, so both operand orders have to be tried.
+  return deep_equal(simplified_result, lhs + rhs) || deep_equal(simplified_result, rhs + lhs);
+}
+
+// Return true if component `value_index` of the combiner has an identity that simplifies
+// to 0 and combining 0 with 0 simplifies to 0 again, i.e. zero is preserved by this
+// component of the reduction.
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  const PrimExpr simplified_identity = analyzer.Simplify(
+      combiner->identity_element[value_index], ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+  if (!is_const_value(simplified_identity, 0)) {
+    return false;
+  }
+
+  // Symbolically evaluate combine(0, 0) by substituting zero for both operands.
+  const PrimExpr combined = combiner->result[value_index];
+  const PrimExpr zero = make_zero(combined.dtype());
+  PrimExpr zero_with_zero = Substitute(combined, {{combiner->lhs[value_index], zero},
+                                                  {combiner->rhs[value_index], zero}});
+  zero_with_zero = analyzer.Simplify(zero_with_zero, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+  return is_const_value(zero_with_zero, 0);
+}
+
+// Describes an expression as `cond ? value : 0`.
+struct NonzeroConditionResult {
+  // Condition that must hold for the described expression to be nonzero.
+  PrimExpr cond;
+  // Expression that stands for the described expression wherever `cond` holds.
+  PrimExpr value;
+
+  // Rebuild the guarded expression Select(cond, value, 0).
+  PrimExpr to_expr() const {
+    const PrimExpr zero = make_zero(value.dtype());
+    return Select(cond, value, zero);
+  }
+
+  // Print the guarded form produced by to_expr().
+  friend std::ostream& operator<<(std::ostream& os, const NonzeroConditionResult& r) {
+    os << r.to_expr();
+    return os;
+  }
+};
+
+// The implementation of NonzeroCondition:
+// transforms an expression e into a pair (cond, value) so that e == (cond ? value : 0).
+// Node kinds without a dedicated handler get the always-correct answer (true, e).
+class NonzeroConditionFunctor : public ExprFunctor<NonzeroConditionResult(const PrimExpr&)> {
+ public:
+  // Entry point: boolean expressions are answered directly, everything else is
+  // dispatched to the VisitExpr_ overloads below.
+  NonzeroConditionResult NonzeroCondition(const PrimExpr& e) {
+    if (e.dtype().is_bool()) {
+      // Boolean expressions are non-zero whenever they are true themselves
+      return {e, const_true()};
+    } else {
+      return VisitExpr(e);
+    }
+  }
+
+  // Most of the cases are implemented using helpers below
+  result_type VisitExpr_(const VarNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const IntImmNode* op) final { return Const_(GetRef<IntImm>(op)); }
+  result_type VisitExpr_(const FloatImmNode* op) final { return Const_(GetRef<FloatImm>(op)); }
+  result_type VisitExpr_(const StringImmNode* op) final { return Default_(GetRef<PrimExpr>(op)); }
+  result_type VisitExpr_(const AddNode* op) final { return BinOpAddLike_(GetRef<Add>(op)); }
+  result_type VisitExpr_(const SubNode* op) final { return BinOpAddLike_(GetRef<Sub>(op)); }
+  result_type VisitExpr_(const MulNode* op) final { return BinOpMulLike_(GetRef<Mul>(op)); }
+  result_type VisitExpr_(const DivNode* op) final { return BinOpDivLike_(GetRef<Div>(op)); }
+  result_type VisitExpr_(const ModNode* op) final { return BinOpDivLike_(GetRef<Mod>(op)); }
+  result_type VisitExpr_(const FloorDivNode* op) final {
+    return BinOpDivLike_(GetRef<FloorDiv>(op));
+  }
+  result_type VisitExpr_(const FloorModNode* op) final {
+    return BinOpDivLike_(GetRef<FloorMod>(op));
+  }
+  // min(a, b) / max(a, b) can only be nonzero if a or b is nonzero, just like addition.
+  result_type VisitExpr_(const MinNode* op) final { return BinOpAddLike_(GetRef<Min>(op)); }
+  result_type VisitExpr_(const MaxNode* op) final { return BinOpAddLike_(GetRef<Max>(op)); }
+
+  // A cast of zero is zero, so the operand's condition carries over unchanged.
+  result_type VisitExpr_(const CastNode* op) final {
+    auto nz_a = NonzeroCondition(op->value);
+
+    if (nz_a.value.same_as(op->value)) {
+      return {nz_a.cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {nz_a.cond, Cast(op->dtype, nz_a.value)};
+    }
+  }
+
+  result_type VisitExpr_(const SelectNode* op) final {
+    PrimExpr cond = op->condition, true_val = op->true_value, false_val = op->false_value;
+    auto nz_a = NonzeroCondition(true_val);
+    auto nz_b = NonzeroCondition(false_val);
+
+    // If the false part is zero, we can get rid of the select
+    if (is_const_value(nz_b.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond && cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_a.value};
+    }
+
+    // If the true part is zero, we can also get rid of the select
+    if (is_const_value(nz_a.value, 0)) {
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_b.cond && !cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      return {new_cond, nz_b.value};
+    }
+
+    // Otherwise we retain the select and combine the conditions into this
+    PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                           ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+    if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+      return {new_cond, GetRef<PrimExpr>(op)};
+    } else {
+      return {new_cond, Select(cond, nz_a.value, nz_b.value)};
+    }
+  }
+
+  result_type VisitExpr_(const CallNode* op) final {
+    // Resolve the op once rather than querying the op registry on every visited call.
+    static const Op& if_then_else_op = Op::Get("tir.if_then_else");
+    if (op->op.same_as(if_then_else_op)) {
+      PrimExpr cond = op->args[0], true_val = op->args[1], false_val = op->args[2];
+      auto nz_a = NonzeroCondition(true_val);
+      auto nz_b = NonzeroCondition(false_val);
+
+      // We don't have as much freedom here as in the select case
+      // since the `if` must be preserved in any case
+      PrimExpr new_cond = analyzer_.Simplify((cond && nz_a.cond) || (!cond && nz_b.cond),
+                                             ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      if (nz_a.value.same_as(true_val) && nz_b.value.same_as(false_val)) {
+        return {new_cond, GetRef<PrimExpr>(op)};
+      } else {
+        return {new_cond, if_then_else(cond, nz_a.value, nz_b.value)};
+      }
+    } else {
+      return Default_(GetRef<PrimExpr>(op));
+    }
+  }
+
+  result_type VisitExpr_(const ProducerLoadNode* op) final {
+    return Default_(GetRef<PrimExpr>(op));
+  }
+
+  // Any node kind without a handler above (e.g. Let, Ramp, Broadcast, Shuffle, Load, Reduce)
+  // gets the conservative answer instead of reaching the base ExprFunctor default, which
+  // reports a fatal error for unhandled node types.
+  result_type VisitExprDefault_(const Object* op) final {
+    return Default_(GetRef<PrimExpr>(static_cast<const PrimExprNode*>(op)));
+  }
+
+  NonzeroConditionResult Default_(const PrimExpr& e) {
+    // This is always correct, so it's the default
+    return {const_true(), e};
+  }
+
+  // Immediates: the condition is the constant truth value of `op != 0`.
+  template <class T>
+  NonzeroConditionResult Const_(const T& op) {
+    if (op->value == 0) {
+      return {const_false(), op};
+    } else {
+      return {const_true(), op};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpAddLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For addition and similar ops the result may be nonzero if either of the arguments is
+    // nonzero, so we combine the conditions with Or.
+    if (tir::ExprDeepEqual()(nz_a.cond, nz_b.cond)) {
+      // If the conditions are the same, we don't need Or
+      if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+        return {nz_a.cond, op};
+      } else {
+        return {nz_a.cond, T(nz_a.value, nz_b.value)};
+      }
+    } else {
+      // Otherwise use Or
+      PrimExpr new_cond =
+          analyzer_.Simplify(nz_a.cond || nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+      // A little optimization: if the combined condition is the same as one of the inner
+      // conditions, we don't need to guard the inner value with a select, otherwise
+      // we create a select in the `to_expr` call.
+      PrimExpr new_a = tir::ExprDeepEqual()(nz_a.cond, new_cond) ? nz_a.value : nz_a.to_expr();
+      PrimExpr new_b = tir::ExprDeepEqual()(nz_b.cond, new_cond) ? nz_b.value : nz_b.to_expr();
+      PrimExpr new_expr = T(new_a, new_b);
+      return {new_cond, new_expr};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpMulLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+    auto nz_b = NonzeroCondition(op->b);
+
+    // For multiplication and similar ops the result may be nonzero if
+    // both the arguments are nonzero, so we combine with And.
+    PrimExpr new_cond =
+        analyzer_.Simplify(nz_a.cond && nz_b.cond, ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE);
+
+    if (nz_a.value.same_as(op->a) && nz_b.value.same_as(op->b)) {
+      return {new_cond, op};
+    } else {
+      return {new_cond, T(nz_a.value, nz_b.value)};
+    }
+  }
+
+  template <class T>
+  NonzeroConditionResult BinOpDivLike_(const T& op) {
+    auto nz_a = NonzeroCondition(op->a);
+
+    // For Div we simply use the condition of the numerator.
+
+    if (nz_a.value.same_as(op->a)) {
+      return {nz_a.cond, op};
+    } else {
+      return {nz_a.cond, T(nz_a.value, op->b)};
+    }
+  }
+
+ private:
+  arith::Analyzer analyzer_;
+};
+
+// Describe `expr` as a (cond, value) pair with expr == (cond ? value : 0).
+inline NonzeroConditionResult NonzeronessCondition(const PrimExpr& expr) {
+  NonzeroConditionFunctor functor;
+  return functor.NonzeroCondition(expr);
+}
+
+struct FactorOutAtomicFormulasResult {

Review comment:
       Your approach uses a pure conjunction. I was suggesting that a disjunction of conjunctions might allow more optimization opportunities. It is fine if you don't do it.

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -63,6 +63,7 @@ namespace te {
 using arith::DivMode;
 using arith::kFloorDiv;
 using arith::kTruncDiv;
+using arith::ARITH_SIMPLIFY_REWRITE_CANONICAL_REWRITE;
 
 template <class K, class V>
 Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {

Review comment:
       Same as Merge... let's move them both?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org