You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/02/03 18:25:43 UTC
[tvm] 01/01: convert argwhere(full(const)) to reshape(arange())
This is an automated email from the ASF dual-hosted git repository.
mbrookhart pushed a commit to branch simplify_full_argwhere
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit f61a8b2e77c804f378d153b6347e42bd451f6b39
Author: mbrookhart <mb...@octoml.ai>
AuthorDate: Wed Feb 3 09:57:17 2021 -0700
convert argwhere(full(const)) to reshape(arange())
---
src/relay/op/make_op.h | 6 +++
src/relay/op/tensor/unary.cc | 6 ++-
src/relay/transforms/simplify_expr.cc | 70 ++++++++++++++++++++++-----
tests/python/relay/test_pass_simplify_expr.py | 36 ++++++++++++++
4 files changed, 105 insertions(+), 13 deletions(-)
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index 2b05290..79f7e13 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -100,6 +100,12 @@ Expr MakeResize(Expr data, Array<IndexExpr> size, String layout, String method,
Expr MakeSparseToDense(Expr indices, Array<Integer> output_shape, Expr values, Expr default_value);
+Expr MakeArange(Expr start, Expr stop, Expr step, DataType dtype);  // arange(start, stop, step)
+
+Expr MakeShapeOf(Expr data, DataType dtype);  // shape_of(data), result element type dtype
+
+Expr MakeTake(Expr data, Expr indices, Integer axis, String mode);  // relay take op
+
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_MAKE_OP_H_
diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc
index e17bdc0..3e82b92 100644
--- a/src/relay/op/tensor/unary.cc
+++ b/src/relay/op/tensor/unary.cc
@@ -430,12 +430,14 @@ Array<te::Tensor> ShapeOfCompute(const Attrs& attrs, const Array<te::Tensor>& in
return {topi::shape(inputs[0], param->dtype)};
}
-TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed([](Expr data, DataType dtype) {
+Expr MakeShapeOf(Expr data, DataType dtype) {  // shared by the FFI below and make_op.h
auto attrs = make_object<ShapeOfAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("shape_of");
return Call(op, {data}, Attrs(attrs), {});
-});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.shape_of").set_body_typed(MakeShapeOf);
RELAY_REGISTER_OP("shape_of")
.describe(R"code(Returns a tensor representing the shape of a tensor.
diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index 0f78c26..2185723 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -29,15 +29,28 @@
#include <tvm/support/logging.h>
#include "../op/tensor/transform.h"
+#include "pattern_utils.h"
namespace tvm {
namespace relay {
+/*!
+ * \brief Base class for SimplifyExpr rewrites: a dataflow pattern plus the callback
+ *        that rewrites each match of it.
+ */
+class SimplifyPattern {
+ public:
+  // Polymorphic base: keep destruction through a base reference well-defined.
+  virtual ~SimplifyPattern() = default;
+
+  /*!
+   * \brief Rewrite one match of pattern().
+   * \param pre The matched expression before its inputs were rewritten.
+   * \param post The matched expression with its inputs already rewritten.
+   * \param node_map Expressions bound to each sub-pattern of the match.
+   * \return The replacement expression, or \p post to leave the match unchanged.
+   */
+  virtual Expr callback(const Expr& pre, const Expr& post,
+                        const Map<DFPattern, Array<Expr>>& node_map) const = 0;
+
+  /*! \brief The dataflow pattern this rewrite matches. */
+  DFPattern pattern() const { return pattern_; }
+
+ protected:
+  /*! \brief Pattern for rewriting; assigned by the subclass constructor. */
+  DFPattern pattern_;
+};
+
/*!
 * \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
 * and merges into one reshape op.
 */
-class SimplifyReshape {
+class SimplifyReshape : public SimplifyPattern {
 public:
 SimplifyReshape() {
x_ = WildcardPattern(make_object<WildcardPatternNode>());
@@ -46,7 +59,8 @@ class SimplifyReshape {
pattern_ = reshape1({reshape2({x_})});
}
-  Expr callback(const Expr& pre, const Expr& post, const Map<DFPattern, Array<Expr>>& node_map) {
+  Expr callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
auto x = node_map[x_][0];
bool const_shape = true;
Array<Integer> newshape;
@@ -63,13 +77,45 @@ class SimplifyReshape {
return post;
}
-  DFPattern pattern() const { return pattern_; }
+ private:
+  /*! \brief Wildcard bound to the input of the inner reshape in pattern_ */
+  DFPattern x_;
+};
+
+/*!
+ * \brief FullArgwhere rewrites argwhere(full(c)) of a 1-D full with a non-zero scalar
+ *        constant c into reshape(arange(0, shape_of(full)[0], 1), [-1, 1]).
+ */
+class FullArgwhere : public SimplifyPattern {
+ public:
+  FullArgwhere() {
+    x_ = ConstantPattern(make_object<ConstantPatternNode>());
+    full_ = IsOp("full")({x_}) ||
+            IsOp("dyn.full")({x_, WildcardPattern(make_object<WildcardPatternNode>())});
+    pattern_ = IsOp("argwhere")({full_});
+  }
+
+  Expr callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    // Read types from the (type-checked) pre graph; post nodes may be freshly built.
+    const auto* argwhere_call = pre.as<CallNode>();
+    const auto* out_type = pre->checked_type_.as<TensorTypeNode>();
+    if (argwhere_call == nullptr || out_type == nullptr || out_type->shape.size() != 2) {
+      return post;
+    }
+    // argwhere emits one row of rank(input) coordinates per non-zero element, so
+    // arange + reshape([-1, 1]) is only equivalent when the filled tensor is 1-D.
+    const auto* input_rank = out_type->shape[1].as<IntImmNode>();
+    if (input_rank == nullptr || input_rank->value != 1) {
+      return post;
+    }
+    const auto* full_type = argwhere_call->args[0]->checked_type_.as<TensorTypeNode>();
+    auto x = node_map[x_][0];
+    if (full_type == nullptr || !IsConstScalar(x)) {
+      return post;
+    }
+    const auto& fill = x.as<ConstantNode>()->data;
+    long double fill_val = ToScalar(fill);
+    // Every element equals the fill value, so every position is selected iff it is
+    // non-zero (negative values included).
+    if (fill_val == 0) {
+      return post;
+    }
+    // full converts the fill value to its output dtype (e.g. 0.5 into int64 gives 0).
+    // When the dtypes differ, only accept whole numbers in [-127, 127], which stay
+    // non-zero under any int/uint/float/bool conversion.
+    if (DataType(fill->dtype) != full_type->dtype &&
+        !(fill_val >= -127 && fill_val <= 127 &&
+          fill_val == static_cast<long double>(static_cast<int>(fill_val)))) {
+      return post;
+    }
+    DataType dtype = out_type->dtype;
+    Expr start = MakeConstantScalar(dtype, 0);
+    Expr end = MakeTake(MakeShapeOf(node_map[full_][0], dtype), start, Integer(0), "clip");
+    Expr step = MakeConstantScalar(dtype, 1);
+    return MakeReshape(MakeArange(start, end, step, dtype), {-1, 1});
+  }
 private:
  /*! \brief Pattern input */
  DFPattern x_;
-  /*! \brief Pattern for consecutive reshape or reverse_reshape ops */
-  DFPattern pattern_;
+  /*! \brief Full op */
+  DFPattern full_;
 };
/*!
@@ -78,22 +124,24 @@ class SimplifyReshape {
class ExprSimplifier {
 public:
explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    auto reshape_func = [this](TVMArgs args, TVMRetValue* rv) {
+    CreateCallback(SimplifyReshape());
+    CreateCallback(FullArgwhere());
+  }
+  template <typename T>  // T: any type with pattern() and callback(), e.g. SimplifyReshape
+  void CreateCallback(const T& pattern) {  // `pattern` is copied into the callback closure
+    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
Expr pre = args[0];
Expr post = args[1];
Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = simplify_reshape_.callback(pre, post, node_map);
+      *rv = pattern.callback(pre, post, node_map);
};
-    callbacks_.push_back(
-        DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func), true));
+    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
}
Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
 private:
IRModule mod_;
-  /*! \brief Simplify reshape pattern */
-  SimplifyReshape simplify_reshape_;
/*! \brief Callbacks for expr simplification */
Array<DFPatternCallback> callbacks_;
};
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index b57abc6..0457946 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -58,5 +58,41 @@ def test_simplify_reshape():
assert tvm.ir.structural_equal(zz, after)
+def test_simplify_full_argwhere():
+ def before():
+ x = relay.const(1)
+ y = relay.full(x, [128], dtype="int64")
+ z = relay.argwhere(y)
+ return z
+
+ def expected():
+ x = relay.const(1)
+ y = relay.full(x, [128], dtype="int64")
+ start = relay.const(0)
+ end = relay.take(relay.shape_of(full, "int32"), [0], 0)
+ step = relay.const(1)
+ y = relay.arange(start, end, step, dtype="int32")
+ z = relay.reshape(y, [-1, 1])
+ return z
+
+ z = before()
+ zz = run_opt_pass(z, transform.SimplifyExpr())
+ after = run_opt_pass(expected(), transform.InferType())
+ assert tvm.ir.structural_equal(zz, after)
+
+ mod1 = tvm.IRModule.from_expr(z)
+ mod2 = tvm.IRModule.from_expr(zz)
+
+ with relay.build_config(disabled_pass="SimplifyExpr"):
+ ex1 = relay.create_executor("vm", mod=mod1, ctx=tvm.cpu(), target="llvm")
+ ex2 = relay.create_executor("vm", mod=mod2, ctx=tvm.cpu(), target="llvm")
+
+ result1 = ex1.evaluate()()
+ result2 = ex2.evaluate()()
+
+ tvm.testing.assert_allclose(result1.asnumpy(), result2.asnumpy())
+
+
if __name__ == "__main__":
test_simplify_reshape()
+ test_simplify_full_argwhere()