You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/02/03 18:25:42 UTC

[tvm] branch simplify_full_argwhere created (now f61a8b2)

This is an automated email from the ASF dual-hosted git repository.

mbrookhart pushed a change to branch simplify_full_argwhere
in repository https://gitbox.apache.org/repos/asf/tvm.git.


      at f61a8b2  convert argwhere(full(const)) to reshape(arange())

This branch includes the following new commits:

     new f61a8b2  convert argwhere(full(const)) to reshape(arange())

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[tvm] 01/01: convert argwhere(full(const)) to reshape(arange())

Posted by mb...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

mbrookhart pushed a commit to branch simplify_full_argwhere
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit f61a8b2e77c804f378d153b6347e42bd451f6b39
Author: mbrookhart <mb...@octoml.ai>
AuthorDate: Wed Feb 3 09:57:17 2021 -0700

    convert argwhere(full(const)) to reshape(arange())
---
 src/relay/op/make_op.h                        |  6 +++
 src/relay/op/tensor/unary.cc                  |  6 ++-
 src/relay/transforms/simplify_expr.cc         | 70 ++++++++++++++++++++++-----
 tests/python/relay/test_pass_simplify_expr.py | 36 ++++++++++++++
 4 files changed, 105 insertions(+), 13 deletions(-)

diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index 2b05290..79f7e13 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -100,6 +100,12 @@ Expr MakeResize(Expr data, Array<IndexExpr> size, String layout, String method,
 
 Expr MakeSparseToDense(Expr indices, Array<Integer> output_shape, Expr values, Expr default_value);
 
+Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype);
+
+Expr MakeShapeOf(Expr data, DataType dtype);
+
+Expr MakeTake(Expr data, Expr indices, Integer axis, String mode);
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_MAKE_OP_H_
diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc
index e17bdc0..3e82b92 100644
--- a/src/relay/op/tensor/unary.cc
+++ b/src/relay/op/tensor/unary.cc
@@ -430,12 +430,14 @@ Array<te::Tensor> ShapeOfCompute(const Attrs& attrs, const Array<te::Tensor>& in
   return {topi::shape(inputs[0], param->dtype)};
 }
 
-TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed([](Expr data, DataType dtype) {
+Expr MakeShapeOf(Expr data, DataType dtype) {  // builds a "shape_of" call whose attrs carry dtype
   auto attrs = make_object<ShapeOfAttrs>();
   attrs->dtype = dtype;
   static const Op& op = Op::Get("shape_of");
   return Call(op, {data}, Attrs(attrs), {});
-});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed(MakeShapeOf);
 
 RELAY_REGISTER_OP("shape_of")
     .describe(R"code(Returns a tensor representing the shape of a tensor.
diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index 0f78c26..2185723 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -29,15 +29,28 @@
 #include <tvm/support/logging.h>
 
 #include "../op/tensor/transform.h"
+#include "pattern_utils.h"
 
 namespace tvm {
 namespace relay {
 
+/*! \brief Base class pairing a match pattern with the callback that rewrites each match. */
+class SimplifyPattern {
+ public:
+  virtual ~SimplifyPattern() = default;  // polymorphic base: keep destruction well-defined
+  virtual Expr callback(const Expr& pre, const Expr& post,
+                        const Map<DFPattern, Array<Expr>>& node_map) const = 0;
+  DFPattern pattern() const { return pattern_; }
+ protected:
+  /*! \brief Pattern for rewriting */
+  DFPattern pattern_;
+};
+
 /*!
  * \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
  *   and merges into one reshape op.
  */
-class SimplifyReshape {
+class SimplifyReshape : public SimplifyPattern {
  public:
   SimplifyReshape() {
     x_ = WildcardPattern(make_object<WildcardPatternNode>());
@@ -46,7 +59,8 @@ class SimplifyReshape {
     pattern_ = reshape1({reshape2({x_})});
   }
 
-  Expr callback(const Expr& pre, const Expr& post, const Map<DFPattern, Array<Expr>>& node_map) {
+  Expr callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
     auto x = node_map[x_][0];
     bool const_shape = true;
     Array<Integer> newshape;
@@ -63,13 +77,45 @@ class SimplifyReshape {
     return post;
   }
 
-  DFPattern pattern() const { return pattern_; }
+ private:
+  /*! \brief Pattern input */
+  DFPattern x_;
+};
+
+/*!
+ * \brief FullArgwhere turns argwhere over a 1-D full of a non-zero constant into reshape(arange)
+ */
+class FullArgwhere : public SimplifyPattern {
+ public:
+  FullArgwhere() {
+    x_ = ConstantPattern(make_object<ConstantPatternNode>());
+    full_ = IsOp("full")({x_}) ||
+            IsOp("dyn.full")({x_, WildcardPattern(make_object<WildcardPatternNode>())});
+    pattern_ = IsOp("argwhere")({full_});
+  }
+
+  Expr callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    // argwhere yields an (N, rank) tensor, so only a static second dim of 1 proves a 1-D input.
+    const auto* tt = pre->checked_type_.as<TensorTypeNode>();
+    const auto* rank = (tt && tt->shape.size() == 2) ? tt->shape[1].as<IntImmNode>() : nullptr;
+    auto x = node_map[x_][0];
+    // argwhere keeps every non-zero element, so any non-zero fill value selects all indices.
+    if (rank && rank->value == 1 && IsConstScalar(x) && ToScalar(x.as<ConstantNode>()->data) != 0) {
+      auto dtype = tt->dtype;
+      Expr start = MakeConstantScalar(dtype, 0);
+      Expr end = MakeTake(MakeShapeOf(node_map[full_][0], dtype), start, Integer(0), "clip");
+      Expr step = MakeConstantScalar(dtype, 1);
+      return MakeReshape(MakeArange(start, end, step, dtype), {-1, 1});
+    }
+    return post;
+  }
 
  private:
   /*! \brief Pattern input */
   DFPattern x_;
-  /*! \brief Pattern for consecutive reshape or reverse_reshape ops */
-  DFPattern pattern_;
+  /*! \brief Full op */
+  DFPattern full_;
 };
 
 /*!
@@ -78,22 +124,24 @@ class SimplifyReshape {
 class ExprSimplifier {
  public:
   explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    auto reshape_func = [this](TVMArgs args, TVMRetValue* rv) {
+    CreateCallback(SimplifyReshape());  // one rewrite callback is registered per pattern class
+    CreateCallback(FullArgwhere());
+  }
+  template <typename T>  // T must provide callback(pre, post, node_map) const and pattern()
+  void CreateCallback(const T& pattern) {
+    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {  // the closure keeps its own copy
       Expr pre = args[0];
       Expr post = args[1];
       Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = simplify_reshape_.callback(pre, post, node_map);
+      *rv = pattern.callback(pre, post, node_map);
     };
-    callbacks_.push_back(
-        DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func), true));
+    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
   }
 
   Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
 
  private:
   IRModule mod_;
-  /*! \brief Simplify reshape pattern */
-  SimplifyReshape simplify_reshape_;
   /*! \brief Callbacks for expr simplification */
   Array<DFPatternCallback> callbacks_;
 };
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index b57abc6..0457946 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -58,5 +58,41 @@ def test_simplify_reshape():
     assert tvm.ir.structural_equal(zz, after)
 
 
+def test_simplify_full_argwhere():
+    """argwhere over a constant-filled 1-D tensor must simplify to reshape(arange(0, n, 1))."""
+    def before():
+        x = relay.const(1)
+        y = relay.full(x, [128], dtype="int64")
+        z = relay.argwhere(y)
+        return z
+
+    def expected():
+        full = relay.full(relay.const(1), [128], dtype="int64")
+        start = relay.const(0)
+        # Mirror the C++ rewrite, which indexes the shape with the scalar `start`.
+        end = relay.take(relay.shape_of(full, "int32"), start, 0)
+        step = relay.const(1)
+        y = relay.arange(start, end, step, dtype="int32")
+        return relay.reshape(y, [-1, 1])
+
+    z = before()
+    zz = run_opt_pass(z, transform.SimplifyExpr())
+    after = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(zz, after)
+
+    mod1 = tvm.IRModule.from_expr(z)
+    mod2 = tvm.IRModule.from_expr(zz)
+
+    with relay.build_config(disabled_pass=["SimplifyExpr"]):  # pass names are given as a list
+        ex1 = relay.create_executor("vm", mod=mod1, ctx=tvm.cpu(), target="llvm")
+    ex2 = relay.create_executor("vm", mod=mod2, ctx=tvm.cpu(), target="llvm")
+
+    result1 = ex1.evaluate()()
+    result2 = ex2.evaluate()()
+
+    tvm.testing.assert_allclose(result1.asnumpy(), result2.asnumpy())
+
+
 if __name__ == "__main__":
     test_simplify_reshape()
+    test_simplify_full_argwhere()