You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/03/31 20:44:55 UTC
[tvm] branch main updated: [Relay][Pass] ConcretizeLike and
EliminateIdentity rewrites for SimplifyExpr (#7731)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b3ab19e [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr (#7731)
b3ab19e is described below
commit b3ab19ed63bca0481557dab095c08e24e49dda78
Author: Altan Haan <ah...@octoml.ai>
AuthorDate: Wed Mar 31 13:44:34 2021 -0700
[Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr (#7731)
* factor out some common code for DF rewriting, add ConcretizeLike
* slight refactoring, add EliminateIdentity pass
* lint
* merge ConcretizeLike and EliminateIdentity into SimplifyExpr
* nits and lint
* remove static stuff
* document
* definitely ran clang-format but ok
* make ToScalar return optional, fix missing virtual destructor
* lint
* tweak scalar conversion API to maintain compatibility
---
src/relay/transforms/dynamic_to_static.cc | 1 -
src/relay/transforms/pattern_utils.h | 50 ++++--
src/relay/transforms/simplify_expr.cc | 246 ++++++++++++++++++++++----
src/relay/transforms/simplify_expr.h | 91 ++++++++++
tests/python/relay/test_pass_simplify_expr.py | 146 ++++++++++++++-
5 files changed, 472 insertions(+), 62 deletions(-)
diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index 815e4d2..0590b41 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -161,7 +161,6 @@ class DynamicToStaticMutator : public MixedModeMutator {
ICHECK_EQ(scale_w->data->ndim, 0);
const UpSampling3DAttrs* param = call_node->attrs.as<UpSampling3DAttrs>();
ICHECK(param);
-
return MakeUpSampling3D(call_node->args[0], ToScalar(scale_d->data),
ToScalar(scale_h->data), ToScalar(scale_w->data),
param->layout, param->method,
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index c1eebde..8d9f723 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -27,6 +27,7 @@
#define TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_
#include <builtin_fp16.h>
+#include <dmlc/optional.h>
#include <tvm/node/structural_equal.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
@@ -380,43 +381,56 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
* \brief Convert an element of a NDArray with type int or float to scalar.
* \param array Input NDArray
* \param i element index
- * \return Converted scalar value.
+ * \return Converted scalar value, or None if conversion failed
*/
-static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) {
+static inline dmlc::optional<long double> TryToScalar(const runtime::NDArray& array, size_t i = 0) {
if (array->dtype.code == kDLInt) {
if (array->dtype.bits == 8) {
- return reinterpret_cast<int8_t*>(array->data)[i];
+ return dmlc::optional<long double>(reinterpret_cast<int8_t*>(array->data)[i]);
} else if (array->dtype.bits == 16) {
- return reinterpret_cast<int16_t*>(array->data)[i];
+ return dmlc::optional<long double>(reinterpret_cast<int16_t*>(array->data)[i]);
} else if (array->dtype.bits == 32) {
- return reinterpret_cast<int32_t*>(array->data)[i];
+ return dmlc::optional<long double>(reinterpret_cast<int32_t*>(array->data)[i]);
} else if (array->dtype.bits == 64) {
- return reinterpret_cast<int64_t*>(array->data)[i];
+ return dmlc::optional<long double>(reinterpret_cast<int64_t*>(array->data)[i]);
}
} else if (array->dtype.code == kDLUInt) {
- if (array->dtype.bits == 8) {
- return reinterpret_cast<uint8_t*>(array->data)[i];
+ if (array->dtype.bits == 1) { // bool
+ return dmlc::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
+ } else if (array->dtype.bits == 8) {
+ return dmlc::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
} else if (array->dtype.bits == 16) {
- return reinterpret_cast<uint16_t*>(array->data)[i];
+ return dmlc::optional<long double>(reinterpret_cast<uint16_t*>(array->data)[i]);
} else if (array->dtype.bits == 32) {
- return reinterpret_cast<uint32_t*>(array->data)[i];
+ return dmlc::optional<long double>(reinterpret_cast<uint32_t*>(array->data)[i]);
} else if (array->dtype.bits == 64) {
- return reinterpret_cast<uint64_t*>(array->data)[i];
+ return dmlc::optional<long double>(reinterpret_cast<uint64_t*>(array->data)[i]);
}
} else if (array->dtype.code == kDLFloat) {
if (array->dtype.bits == 16) {
- return __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
- reinterpret_cast<uint16_t*>(array->data)[i]);
+ return dmlc::optional<long double>(
+ __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
+ reinterpret_cast<uint16_t*>(array->data)[i]));
}
if (array->dtype.bits == 32) {
- return reinterpret_cast<float*>(array->data)[i];
+ return dmlc::optional<long double>(reinterpret_cast<float*>(array->data)[i]);
} else if (array->dtype.bits == 64) {
- return reinterpret_cast<double*>(array->data)[i];
+ return dmlc::optional<long double>(reinterpret_cast<double*>(array->data)[i]);
}
}
- LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
- // make compiler happy
- return -std::numeric_limits<double>::infinity();
+ return dmlc::optional<long double>();
+}
+
+/*!
+ * \brief Convert an element of a NDArray to scalar; fails an ICHECK when TryToScalar returns None.
+ * \param array Input NDArray
+ * \param i element index
+ * \return Converted scalar value, unwrapped from the TryToScalar result
+ */
+static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) {
+ auto try_value = TryToScalar(array, i);
+ ICHECK(try_value) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
+ return try_value.value();
}
/*!
diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index b4f4cc1..762aa58 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -22,35 +22,28 @@
* \brief A pass for simplifying the Relay expression.
*/
+#include "simplify_expr.h"
+
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/logging.h>
+#include <limits>
+#include <utility>
+
#include "../op/tensor/transform.h"
#include "pattern_utils.h"
namespace tvm {
namespace relay {
-class SimplifyPattern {
- public:
- virtual Expr callback(const Expr& pre, const Expr& post,
- const Map<DFPattern, Array<Expr>>& node_map) const = 0;
-
- DFPattern pattern() const { return pattern_; }
-
- protected:
- /*! \brief Pattern for rewriting */
- DFPattern pattern_;
-};
-
/*!
* \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
* and merges into one reshape op.
*/
-class SimplifyReshape : public SimplifyPattern {
+class SimplifyReshape : public DFPatternRewrite {
public:
SimplifyReshape() {
x_ = IsWildcard();
@@ -59,7 +52,7 @@ class SimplifyReshape : public SimplifyPattern {
pattern_ = reshape1({reshape2({x_})});
}
- Expr callback(const Expr& pre, const Expr& post,
+ Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
auto x = node_map[x_][0];
bool const_shape = true;
@@ -86,7 +79,7 @@ class SimplifyReshape : public SimplifyPattern {
* \brief SimplifyTranspose matches the pattern of consecutive transpose op,
* and merges or cancels them.
*/
-class SimplifyTranspose : public SimplifyPattern {
+class SimplifyTranspose : public DFPatternRewrite {
public:
SimplifyTranspose() {
x_ = IsWildcard();
@@ -95,7 +88,7 @@ class SimplifyTranspose : public SimplifyPattern {
pattern_ = trans1({trans2({x_})});
}
- Expr callback(const Expr& pre, const Expr& post,
+ Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
// Helper function to get the axes from call node attribute
auto get_axes_from_call = [](const Call trans_call, int ndim) {
@@ -176,9 +169,10 @@ class SimplifyTranspose : public SimplifyPattern {
};
/*!
- * \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op
+ * \brief FullElementwise finds full like ops followed by broadcasting ops, and eliminates
+ * the full op by directly passing the fill value into the broadcasting op.
*/
-class FullElementwise : public SimplifyPattern {
+class FullElementwise : public DFPatternRewrite {
public:
FullElementwise() {
x_ = IsWildcard();
@@ -196,7 +190,7 @@ class FullElementwise : public SimplifyPattern {
pattern_ = op({full, x_}) || op({x_, full});
}
- Expr callback(const Expr& pre, const Expr& post,
+ Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
const CallNode* call = pre.as<CallNode>();
ICHECK(call);
@@ -249,36 +243,210 @@ class FullElementwise : public SimplifyPattern {
};
/*!
- * \brief ExprSimplifier simplifies the Relay expression.
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `reshape_like(x, y)`
+ * to `reshape(x, y.shape)`), when the target shape is concrete. This removes unnecessary
+ * dependencies and can enable more opportunities for operator fusion.
*/
-class ExprSimplifier {
+class ConcretizeLikeRewrite : public DFPatternRewrite {
public:
- explicit ExprSimplifier(IRModule mod) : mod_(mod) {
- CreateCallback(SimplifyReshape());
- CreateCallback(SimplifyTranspose());
- CreateCallback(FullElementwise());
+ explicit ConcretizeLikeRewrite(const Op& op) {
+ ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+ << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+ like_pat_ = IsWildcard();
+ data_pat_ = IsWildcard();
+ if (op->num_inputs == 1) {
+ pattern_ = IsExpr(op)({like_pat_});
+ } else {
+ pattern_ = IsExpr(op)({data_pat_, like_pat_});
+ }
}
- template <typename T>
- void CreateCallback(const T& pattern) {
- auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
- Expr pre = args[0];
- Expr post = args[1];
- Map<DFPattern, Array<Expr>> node_map = args[2];
- *rv = pattern.callback(pre, post, node_map);
- };
- callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
+ /*! \brief Guard run by Callback first; false when the call's checked type is not a tensor. */
+ virtual bool Check(const Expr& pre, const Expr& post,
+ const Map<DFPattern, Array<Expr>>& node_map) const {
+ const CallNode* call_node = pre.as<CallNode>();
+ ICHECK(call_node);
+
+ if (!call_node->checked_type().as<TensorTypeNode>()) {
+ return false;
+ }
+
+ return true;
+ }
+ /*! \brief Builds the replacement expression from the static shape and dtype Callback passes. */
+ virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+ DataType dtype) const = 0;
+ /*! \brief Returns Concretize(...) when every dim is an IntImm; otherwise returns `post`. */
+ Expr Callback(const Expr& pre, const Expr& post,
+ const Map<DFPattern, Array<Expr>>& node_map) const override {
+ if (!Check(pre, post, node_map)) {
+ return post;
+ }
+ // NOTE: despite its name, like_ty is the checked type of the matched call `pre` itself
+ const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
+ Array<Integer> cshape;
+ for (const auto& dim : like_ty->shape) {
+ if (const auto* imm = dim.as<IntImmNode>()) {
+ cshape.push_back(Integer(GetRef<IntImm>(imm)));
+ } else {
+ // shape is not static, don't concretize
+ return post;
+ }
+ }
+
+ return Concretize(node_map, cshape, like_ty->dtype);
+ }
+
+ protected:
+ DFPattern data_pat_;
+ DFPattern like_pat_;
+};
+
+class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+ ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {}
+ /*! \brief Replaces the matched `zeros_like` call with MakeZeros(shape, dtype). */
+ Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+ DataType dtype) const override {
+ return MakeZeros(shape, dtype);
+ }
+};
+
+class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+ ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {}
+ /*! \brief Replaces the matched `ones_like` call with MakeOnes(shape, dtype). */
+ Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+ DataType dtype) const override {
+ return MakeOnes(shape, dtype);
+ }
+};
+
+class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+ ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}
+ /*! \brief Replaces the matched `reshape_like` call with MakeReshape(data, shape). */
+ Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+ DataType dtype) const override {
+ return MakeReshape(node_map[data_pat_][0], shape);
+ }
+};
+
+class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+ ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {}
+ /*! \brief Emits `collapse_sum_to(data, shape)`; shape goes in the attrs and an int32 constant. */
+ Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+ DataType dtype) const override {
+ ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());
+ static const Op& op = Op::Get("collapse_sum_to");
+ auto attrs = make_object<InitOpAttrs>();
+ attrs->shape = shape;
+ auto cshape =
+ MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);
+ return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs));
+ }
+};
+
+class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+ ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {}
+ /*! \brief Replaces the matched `broadcast_to_like` call with MakeBroadCastTo(data, shape). */
+ Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+ DataType dtype) const override {
+ return MakeBroadCastTo(node_map[data_pat_][0], shape);
+ }
+};
+
+/*! \brief Eliminates x+0, x*1, x-0 and x/1 when the result has exactly the type of x. */
+class EliminateIdentityRewrite : public DFPatternRewrite {
+ public:
+ EliminateIdentityRewrite() {
+ x_ = IsWildcard();
+ const_ = IsConstant();
+ // const_ may match any Constant; its value is verified by CheckConstant() in Callback
+ DFPattern add_op = IsOp("add");
+ DFPattern mul_op = IsOp("multiply");
+ DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_;
+ DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_;
+
+ // add and multiply are commutative so we don't need another pattern for reversed args
+ DFPattern add_id = add_op({x_, zeros_expr});
+ DFPattern mul_id = mul_op({x_, ones_expr});
+
+ DFPattern sub_id = IsOp("subtract")({x_, zeros_expr});
+ DFPattern div_id = IsOp("divide")({x_, ones_expr});
+
+ pattern_ = add_id || mul_id || sub_id || div_id;
}
- Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
+ bool CheckConstant(const OpNode* op, const ConstantNode* constant) const {  // identity of op?
+ if (!IsScalar(GetRef<Expr>(constant))) {
+ return false;
+ }
+ auto value = TryToScalar(constant->data, 0);
+ if (!value) {
+ // unsupported dtype
+ return false;
+ }
+ if (op->name == "add" || op->name == "subtract") {
+ return value.value() == 0.0;
+ } else if (op->name == "multiply" || op->name == "divide") {
+ return value.value() == 1.0;
+ }
+ return false;
+ }
+ /*! \brief Returns x when the other operand is a true identity and the type is kept, else post. */
+ Expr Callback(const Expr& pre, const Expr& post,
+ const Map<DFPattern, Array<Expr>>& node_map) const override {
+ const CallNode* call = pre.as<CallNode>();
+ ICHECK(call);
+ Type pre_type = pre->checked_type_;
+ ICHECK(pre_type.as<TensorTypeNode>());
+ auto x = node_map[x_][0];
+ bool is_left = post.as<CallNode>()->args[1] == x;  // NOTE: true when x is the SECOND argument
+ Type x_type;
+ if (is_left) {
+ x_type = call->args[1]->checked_type_;
+ } else {
+ x_type = call->args[0]->checked_type_;
+ }
+
+ if (node_map.count(const_)) {
+ // the other argument is a Constant in this case
+ const ConstantNode* constant = node_map[const_][0].as<ConstantNode>();
+ const OpNode* op = call->op.as<OpNode>();
+ ICHECK(constant);
+ ICHECK(op);
+ if (!CheckConstant(op, constant)) {
+ return post;
+ }
+ }
+ // dropping the op is only done when x already has the type of the whole expression
+ if (StructuralEqual()(x_type, pre_type)) {
+ return x;
+ }
+
+ return post;
+ }
private:
- IRModule mod_;
- /*! \brief Callbacks for expr simplification */
- Array<DFPatternCallback> callbacks_;
+ DFPattern x_;
+ DFPattern const_;
};
Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
- return ExprSimplifier(mod).Simplify(expr);
+ // the rewrites will be applied in the given order, and repeated until fixed point
+ DFPatternRewriteComposer composer;
+ composer.AddRewrite<ConcretizeZerosLikeRewrite>();
+ composer.AddRewrite<ConcretizeOnesLikeRewrite>();
+ composer.AddRewrite<ConcretizeReshapeLikeRewrite>();
+ composer.AddRewrite<ConcretizeCollapseSumLikeRewrite>();
+ composer.AddRewrite<ConcretizeBroadcastToLikeRewrite>();
+ composer.AddRewrite<EliminateIdentityRewrite>();
+ composer.AddRewrite<SimplifyReshape>();
+ composer.AddRewrite<SimplifyTranspose>();
+ composer.AddRewrite<FullElementwise>();
+ return RewritePatterns(composer.MakeCallbacks(), expr, mod);
}
namespace transform {
diff --git a/src/relay/transforms/simplify_expr.h b/src/relay/transforms/simplify_expr.h
new file mode 100644
index 0000000..6b3925e
--- /dev/null
+++ b/src/relay/transforms/simplify_expr.h
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/simplify_expr.h
+ * \brief Utility data structures for simplifying Relay expressions.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+#define TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr.h>
+
+#include <memory>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*! \brief A wrapper class defining a rewrite matching a specific pattern. */
+class DFPatternRewrite {
+ public:
+ /*! \brief Returns the rewritten expression. */
+ virtual Expr Callback(const Expr& pre, const Expr& post,
+ const Map<DFPattern, Array<Expr>>& node_map) const = 0;
+ /*! \brief Virtual so that subclasses can be destroyed through a base-class pointer. */
+ virtual ~DFPatternRewrite() = default;
+
+ /*! \brief Returns the pattern to be used for matching and rewriting. */
+ inline DFPattern Pattern() const { return pattern_; }
+ /*! \brief Returns whether the rewrite requires types to be inferred. */
+ inline bool RequireType() const { return require_type_; }
+ /*! \brief Wraps Callback as a DFPatternCallback; the lambda captures `this`, keep *this alive. */
+ inline DFPatternCallback MakeCallback() const {
+ auto func = [this](TVMArgs args, TVMRetValue* rv) {
+ Expr pre = args[0];
+ Expr post = args[1];
+ Map<DFPattern, Array<Expr>> node_map = args[2];
+ *rv = this->Callback(pre, post, node_map);
+ };
+ return DFPatternCallback(pattern_, PackedFunc(func), require_type_);
+ }
+
+ protected:
+ /*! \brief The pattern for matching and rewriting. */
+ DFPattern pattern_;
+ /*! \brief Whether or not the rewrite requires types to be inferred. */
+ bool require_type_ = true;
+};
+
+/*! \brief Helper class that owns a list of rewrites and turns them into pattern callbacks. */
+class DFPatternRewriteComposer {
+ public:
+ template <typename T, typename... Args>
+ inline void AddRewrite(Args... args) {  // constructs a T from `args` and appends it
+ rewrites_.push_back(std::make_shared<T>(args...));  // pass the values, not `&args...` pointers
+ }
+ /*! \brief Returns one callback per added rewrite, in the order the rewrites were added. */
+ inline Array<DFPatternCallback> MakeCallbacks() const {
+ Array<DFPatternCallback> callbacks;
+ for (const auto& rewrite : rewrites_) {  // by reference: avoids copying each shared_ptr
+ callbacks.push_back(rewrite->MakeCallback());
+ }
+ return callbacks;
+ }
+
+ private:
+ /*! \brief the rewrites to be composed. */
+ std::vector<std::shared_ptr<DFPatternRewrite>> rewrites_;
+};
+
+} // namespace relay
+} // namespace tvm
+
+#endif // TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index 897f90b..d015cdd 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -14,10 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import pytest
import tvm
from tvm import relay
from tvm.relay import transform
-from tvm.relay.testing import run_opt_pass
+from tvm.relay.testing import run_opt_pass, run_infer_type
import numpy as np
@@ -123,12 +124,22 @@ def test_simplify_full_elementwise():
return elem_op(full, x)
def after_left(x, elem_op, value):
+ if elem_op == relay.add and value == 0:
+ return x
+ elif elem_op == relay.multiply and (value == 1 or (value > 1 and dtype == "bool")):
+ return x
return elem_op(relay.const(value, dtype), x)
def before_right(x, elem_op, full):
return elem_op(x, full)
def after_right(x, elem_op, value):
+ if elem_op in [relay.add, relay.subtract] and value == 0:
+ return x
+ elif elem_op in [relay.multiply, relay.divide] and (
+ value == 1 or (value > 1 and dtype == "bool")
+ ):
+ return x
return elem_op(x, relay.const(value, dtype))
x = relay.var("x", shape=shape, dtype=dtype)
@@ -181,7 +192,134 @@ def test_simplify_full_elementwise():
validate(shape, value, dtype)
+def test_eliminate_identity():
+ def check(x, y=None, do_nothing=False):
+ expected = run_infer_type(x)
+ if do_nothing:
+ actual = run_opt_pass(x, transform.SimplifyExpr())
+ assert tvm.ir.structural_equal(actual, expected)
+ else:
+ assert y is not None
+ actual = run_opt_pass(y, transform.SimplifyExpr())
+ assert tvm.ir.structural_equal(actual, expected)
+
+ shape = [2, 3, 4]
+ dtype = "float32"
+ x = relay.var("x", shape=shape, dtype=dtype)
+ x = run_opt_pass(x, transform.InferType())
+
+ for (op, op_like, id_op, const) in [
+ (relay.zeros, relay.zeros_like, relay.add, relay.const(0, dtype)),
+ (relay.ones, relay.ones_like, relay.multiply, relay.const(1, dtype)),
+ ]:
+ check(x, id_op(op_like(x), x))
+ check(x, id_op(op(shape, dtype), x))
+ check(x, id_op(const, x))
+ check(x, id_op(op(shape[1:], dtype), x))
+ check(x, id_op(x, op_like(x)))
+ check(x, id_op(x, op(shape, dtype)))
+ check(x, id_op(x, const))
+ check(x, id_op(x, op(shape[1:], dtype)))
+ check(id_op(x, op([2] + shape, dtype)), do_nothing=True)
+ check(id_op(op([2] + shape, dtype), x), do_nothing=True)
+
+ for (op, op_like, id_op, const) in [
+ (relay.zeros, relay.zeros_like, relay.subtract, relay.const(0, dtype)),
+ (relay.ones, relay.ones_like, relay.divide, relay.const(1, dtype)),
+ ]:
+ check(x, id_op(x, op_like(x)))
+ check(x, id_op(x, const))
+ check(x, id_op(x, op(shape, dtype)))
+ check(x, id_op(x, op(shape[1:], dtype)))
+ check(id_op(x, op([2] + shape, dtype)), do_nothing=True)
+ check(id_op(const, x), id_op(op(shape, dtype), x))
+ check(id_op(const, x), id_op(op_like(x), x))
+
+
+def test_concretize_reshape_like():
+ data = relay.var("data", shape=(2, 3, 4), dtype="float32")
+ shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")
+ expr = relay.reshape_like(data, shape_like)
+
+ expected = run_infer_type(relay.reshape(data, (6, 2, 2)))
+ actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+ assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_reshape_like_attrs():
+ data = relay.var("data", shape=(2, 3, 4), dtype="float32")
+ shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")
+ expr = relay.reshape_like(data, shape_like, lhs_begin=2, rhs_begin=1)
+
+ expected = run_infer_type(relay.reshape(data, (2, 3, 2, 2)))
+ actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+ assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_zeros_like():
+ dtype = "int32"
+ shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype)
+ expr = relay.zeros_like(shape_like)
+
+ expected = run_infer_type(relay.zeros((3, 4, 5), dtype))
+ actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+ assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_ones_like():
+ dtype = "int32"
+ shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype)
+ expr = relay.ones_like(shape_like)
+
+ expected = run_infer_type(relay.ones((3, 4, 5), dtype))
+ actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+ assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_collapse_sum_like():
+ data = relay.var("data", shape=(3, 3, 3), dtype="float32")
+ shape_like = relay.var("shape_like", shape=(3,), dtype="float32")
+ expr = relay.collapse_sum_like(data, shape_like)
+
+ expected = run_infer_type(relay.collapse_sum_to(data, (3,)))
+ actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+ assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_broadcast_to_like():
+ data = relay.var("data", shape=(3,), dtype="float32")
+ shape_like = relay.var("shape_like", shape=(3, 3, 3), dtype="float32")
+ expr = relay.broadcast_to_like(data, shape_like)
+
+ expected = run_infer_type(relay.broadcast_to(data, (3, 3, 3)))
+ actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+ assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_multiple():
+ x = relay.var("x", shape=(2, 3), dtype="float32")
+ y = relay.var("y", shape=(3,), dtype="float32")
+ l = x + y
+
+ dl = relay.ones_like(l)
+ dx = relay.zeros_like(x)
+ dy = relay.zeros_like(y)
+ dx = dx + relay.collapse_sum_like(dl, dx)
+ dy = dy + relay.collapse_sum_like(dl, dy)
+ ret = relay.Tuple([dx, dy])
+
+ dl_c = relay.ones((2, 3), "float32")
+ # NOTE: these are removed by EliminateIdentity
+ # dx_c = relay.zeros((2, 3), "float32")
+ # dy_c = relay.zeros((3,), "float32")
+ dx_c = relay.collapse_sum_to(dl_c, (2, 3))
+ dy_c = relay.collapse_sum_to(dl_c, (3,))
+ ret_c = relay.Tuple([dx_c, dy_c])
+
+ expected = run_infer_type(ret_c)
+ actual = run_opt_pass(ret, relay.transform.SimplifyExpr())
+ assert tvm.ir.structural_equal(actual, expected)
+
+
if __name__ == "__main__":
- test_simplify_reshape()
- test_simplify_transpose()
- test_simplify_full_elementwise()
+ pytest.main([__file__])