You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/19 09:05:50 UTC

[GitHub] [incubator-tvm] MarisaKirisame commented on a change in pull request #6078: [Autodiff] Optimize and eliminate the Jacobian tensor for te.autodiff

MarisaKirisame commented on a change in pull request #6078:
URL: https://github.com/apache/incubator-tvm/pull/6078#discussion_r456881913



##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by autodiff.
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0], 3), 0)) {

Review comment:
       Please refactor the magic number 3 that is passed as the second argument to analyzer.Simplify.

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by autodiff.

Review comment:
       All in all, I found this file confusing. Can you describe which compiler pass or simplification it is performing?

##########
File path: src/te/autodiff/ad_simplify.cc
##########
@@ -0,0 +1,1266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ad_simplify.cc
+ * \brief Simplify tensor compute generated by autodiff.
+ */
+#include <dmlc/optional.h>
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/autodiff.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <memory>
+#include <utility>
+
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+using arith::DivMode;
+using arith::kFloorDiv;
+using arith::kTruncDiv;
+
+template <class K, class V>
+Map<K, V> Merge(Map<K, V> original, const Map<K, V>& update) {
+  for (const auto& p : update) {
+    original.Set(p.first, p.second);
+  }
+  return std::move(original);
+}
+
+// Concatenate two arrays
+template <class T>
+Array<T> Concat(Array<T> a, const Array<T>& b) {
+  for (const auto& x : b) {
+    a.push_back(x);
+  }
+  return std::move(a);
+}
+
+// Combine all expressions from the container using &&.
+template <class container>
+PrimExpr All(const container& c) {
+  PrimExpr res;
+  for (const auto& e : c) {
+    if (res.get()) {
+      res = res && e;
+    } else {
+      res = e;
+    }
+  }
+  if (res.get()) {
+    return res;
+  } else {
+    return const_true();
+  }
+}
+
+Map<Var, Range> IterVarsToMap(const Array<IterVar>& itervars) {
+  Map<Var, Range> res;
+  for (const IterVar& v : itervars) {
+    res.Set(v->var, v->dom);
+  }
+  return res;
+}
+
+// Given a map from vars to ranges create an array of itervars
+Array<IterVar> IterVarsFromMap(const Array<Var>& vars, const Map<Var, Range>& vranges,
+                               IterVarType iter_type = kDataPar, std::string thread_tag = "") {
+  Array<IterVar> res;
+  for (const Var& v : vars) {
+    CHECK(vranges.count(v)) << "A range for the variable " << v << " was not provided in map "
+                            << vranges;
+    res.push_back(IterVar(vranges[v], v, iter_type, thread_tag));
+  }
+  return res;
+}
+
+Array<Var> IterVarsToVars(const Array<IterVar>& itervars) {
+  Array<Var> res;
+  for (const IterVar& v : itervars) {
+    res.push_back(v->var);
+  }
+  return res;
+}
+
+template <typename ValueType>
+inline bool is_const_value(const PrimExpr& e, ValueType value) {
+  static_assert(std::is_integral<ValueType>::value,
+                "Comparison to non-integer values is forbidden.");
+  if (const tir::IntImmNode* i = e.as<tir::IntImmNode>()) {
+    return i->value == value;
+  } else if (const tir::FloatImmNode* i = e.as<tir::FloatImmNode>()) {
+    return i->value == value;
+  } else if (const tir::CastNode* c = e.as<tir::CastNode>()) {
+    return is_const_value(c->value, value);
+  } else if (const tir::BroadcastNode* b = e.as<tir::BroadcastNode>()) {
+    return is_const_value(b->value, value);
+  } else {
+    return false;
+  }
+}
+
+// Return true if this combiner is just a sum.
+bool IsSumCombiner(const CommReducer& combiner, const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (combiner->result.size() != 1) {
+    return false;
+  }
+
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[0], 3), 0)) {
+    return false;
+  }
+
+  PrimExpr combiner_result = analyzer.Simplify(combiner->result[0], 3);
+
+  return tir::ExprDeepEqual()(combiner_result, combiner->lhs[0] + combiner->rhs[0]) ||
+         tir::ExprDeepEqual()(combiner_result, combiner->rhs[0] + combiner->lhs[0]);
+}
+
+bool CanFactorZeroFromCombiner(const CommReducer& combiner, int value_index,
+                               const Map<Var, Range>& vranges) {
+  arith::Analyzer analyzer;
+  analyzer.Bind(vranges);
+  if (!is_const_value(analyzer.Simplify(combiner->identity_element[value_index], 3), 0)) {
+    return false;
+  }
+
+  PrimExpr zero = make_zero(combiner->result[value_index].dtype());
+  PrimExpr in = Substitute(combiner->result[value_index], {{combiner->lhs[value_index], zero},
+                                                           {combiner->rhs[value_index], zero}});
+  in = analyzer.Simplify(in, 3);

Review comment:
       Same everywhere: make the magic number 3 passed to analyzer.Simplify a named constant or an argument.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org