You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/02/03 18:25:43 UTC

[tvm] 01/01: convert argwhere(full(const)) to reshape(arange())

This is an automated email from the ASF dual-hosted git repository.

mbrookhart pushed a commit to branch simplify_full_argwhere
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit f61a8b2e77c804f378d153b6347e42bd451f6b39
Author: mbrookhart <mb...@octoml.ai>
AuthorDate: Wed Feb 3 09:57:17 2021 -0700

    convert argwhere(full(const)) to reshape(arange())
---
 src/relay/op/make_op.h                        |  6 +++
 src/relay/op/tensor/unary.cc                  |  6 ++-
 src/relay/transforms/simplify_expr.cc         | 70 ++++++++++++++++++++++-----
 tests/python/relay/test_pass_simplify_expr.py | 36 ++++++++++++++
 4 files changed, 105 insertions(+), 13 deletions(-)

diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index 2b05290..79f7e13 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -100,6 +100,12 @@ Expr MakeResize(Expr data, Array<IndexExpr> size, String layout, String method,
 
 Expr MakeSparseToDense(Expr indices, Array<Integer> output_shape, Expr values, Expr default_value);
 
+Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype);
+
+Expr MakeShapeOf(Expr data, DataType dtype);
+
+Expr MakeTake(Expr data, Expr indices, Integer axis, String mode);
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_OP_MAKE_OP_H_
diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc
index e17bdc0..3e82b92 100644
--- a/src/relay/op/tensor/unary.cc
+++ b/src/relay/op/tensor/unary.cc
@@ -430,12 +430,14 @@ Array<te::Tensor> ShapeOfCompute(const Attrs& attrs, const Array<te::Tensor>& in
   return {topi::shape(inputs[0], param->dtype)};
 }
 
-TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed([](Expr data, DataType dtype) {
+Expr MakeShapeOf(Expr data, DataType dtype) {
   auto attrs = make_object<ShapeOfAttrs>();
   attrs->dtype = dtype;
   static const Op& op = Op::Get("shape_of");
   return Call(op, {data}, Attrs(attrs), {});
-});
+}
+// Named (not a lambda) so C++ passes can call it via make_op.h; also registered as a packed func.
+TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed(MakeShapeOf);
 
 RELAY_REGISTER_OP("shape_of")
     .describe(R"code(Returns a tensor representing the shape of a tensor.
diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index 0f78c26..2185723 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -29,15 +29,28 @@
 #include <tvm/support/logging.h>
 
 #include "../op/tensor/transform.h"
+#include "pattern_utils.h"
 
 namespace tvm {
 namespace relay {
 
+class SimplifyPattern {
+ public:
+  virtual ~SimplifyPattern() = default;
+  /*! \brief Rewrite a match of pattern(); return post unchanged to decline the rewrite */
+  virtual Expr callback(const Expr& pre, const Expr& post,
+                        const Map<DFPattern, Array<Expr>>& node_map) const = 0;
+  DFPattern pattern() const { return pattern_; }
+ protected:
+  /*! \brief Pattern for rewriting, set by each subclass constructor */
+  DFPattern pattern_;
+};
+
 /*!
  * \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
  *   and merges into one reshape op.
  */
-class SimplifyReshape {
+class SimplifyReshape : public SimplifyPattern {
  public:
   SimplifyReshape() {
     x_ = WildcardPattern(make_object<WildcardPatternNode>());
@@ -46,7 +59,8 @@ class SimplifyReshape {
     pattern_ = reshape1({reshape2({x_})});
   }
 
-  Expr callback(const Expr& pre, const Expr& post, const Map<DFPattern, Array<Expr>>& node_map) {
+  Expr callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
     auto x = node_map[x_][0];
     bool const_shape = true;
     Array<Integer> newshape;
@@ -63,13 +77,45 @@ class SimplifyReshape {
     return post;
   }
 
-  DFPattern pattern() const { return pattern_; }
+ private:
+  /*! \brief Pattern input */
+  DFPattern x_;
+};
+
+/*!
+ * \brief FullArgwhere turns argwhere(full(c)), for 1-D full and nonzero c, into reshape(arange)
+ */
+class FullArgwhere : public SimplifyPattern {
+ public:
+  FullArgwhere() {
+    x_ = ConstantPattern(make_object<ConstantPatternNode>());
+    full_ = IsOp("full")({x_}) ||
+            IsOp("dyn.full")({x_, WildcardPattern(make_object<WildcardPatternNode>())});
+    pattern_ = IsOp("argwhere")({full_});
+  }
+
+  Expr callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    auto x = node_map[x_][0];
+    auto dtype = pre->checked_type_.as<TensorTypeNode>()->dtype;
+    auto shape = pre->checked_type_.as<TensorTypeNode>()->shape;
+    // argwhere's type is [N, rank(full)]; only rank 1 with a nonzero fill selects every index.
+    const auto* rank = shape.size() == 2 ? shape[1].as<IntImmNode>() : nullptr;
+    if (!IsConstScalar(x) || rank == nullptr || rank->value != 1 ||
+        ToScalar(x.as<ConstantNode>()->data) == 0) {
+      return post;
+    }
+    Expr start = MakeConstantScalar(dtype, 0);
+    Expr end = MakeTake(MakeShapeOf(node_map[full_][0], dtype), start, Integer(0), "clip");
+    Expr step = MakeConstantScalar(dtype, 1);
+    return MakeReshape(MakeArange(start, end, step, dtype), {-1, 1});
+  }
+
  private:
   /*! \brief Pattern input */
   DFPattern x_;
-  /*! \brief Pattern for consecutive reshape or reverse_reshape ops */
-  DFPattern pattern_;
+  /*! \brief Static or dynamic full op that feeds argwhere */
+  DFPattern full_;
 };
 
 /*!
@@ -78,22 +124,24 @@ class SimplifyReshape {
 class ExprSimplifier {
  public:
   explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    auto reshape_func = [this](TVMArgs args, TVMRetValue* rv) {
+    CreateCallback(SimplifyReshape{});
+    CreateCallback(FullArgwhere{});
+  }
+  template <typename Simplification>
+  void CreateCallback(const Simplification& simplification) {
+    auto rewrite = [simplification](TVMArgs args, TVMRetValue* rv) {
       Expr pre = args[0];
       Expr post = args[1];
       Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = simplify_reshape_.callback(pre, post, node_map);
+      *rv = simplification.callback(pre, post, node_map);
     };
-    callbacks_.push_back(
-        DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func), true));
+    callbacks_.push_back(DFPatternCallback(simplification.pattern(), PackedFunc(rewrite), true));
   }

   Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }

  private:
   IRModule mod_;
-  /*! \brief Simplify reshape pattern */
-  SimplifyReshape simplify_reshape_;
   /*! \brief Callbacks for expr simplification */
   Array<DFPatternCallback> callbacks_;
 };
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index b57abc6..0457946 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -58,5 +58,41 @@ def test_simplify_reshape():
     assert tvm.ir.structural_equal(zz, after)
 
 
+def test_simplify_full_argwhere():
+    def before():
+        x = relay.const(1)
+        y = relay.full(x, [128], dtype="int64")
+        z = relay.argwhere(y)
+        return z
+
+    def expected():
+        x = relay.const(1)
+        y = relay.full(x, [128], dtype="int64")
+        start = relay.const(0)
+        # The pass reuses the scalar `start` as the take() index into shape_of(full).
+        end = relay.take(relay.shape_of(y, "int32"), start, 0)
+        step = relay.const(1)
+        z = relay.reshape(relay.arange(start, end, step, dtype="int32"), [-1, 1])
+        return z
+
+    z = before()
+    zz = run_opt_pass(z, transform.SimplifyExpr())
+    after = run_opt_pass(expected(), transform.InferType())
+    assert tvm.ir.structural_equal(zz, after)
+
+    mod1 = tvm.IRModule.from_expr(z)
+    mod2 = tvm.IRModule.from_expr(zz)
+
+    # Pass names go in a list; evaluate() compiles, so it must run inside the context.
+    with tvm.transform.PassContext(disabled_pass=["SimplifyExpr"]):
+        ex1 = relay.create_executor("vm", mod=mod1, ctx=tvm.cpu(), target="llvm")
+        result1 = ex1.evaluate()()
+    ex2 = relay.create_executor("vm", mod=mod2, ctx=tvm.cpu(), target="llvm")
+    result2 = ex2.evaluate()()
+
+    tvm.testing.assert_allclose(result1.asnumpy(), result2.asnumpy())
+
+
 if __name__ == "__main__":
     test_simplify_reshape()
+    test_simplify_full_argwhere()