You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/03/31 20:44:55 UTC

[tvm] branch main updated: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr (#7731)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b3ab19e  [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr (#7731)
b3ab19e is described below

commit b3ab19ed63bca0481557dab095c08e24e49dda78
Author: Altan Haan <ah...@octoml.ai>
AuthorDate: Wed Mar 31 13:44:34 2021 -0700

    [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr (#7731)
    
    * factor out some common code for DF rewriting, add ConcretizeLike
    
    * slight refactoring, add EliminateIdentity pass
    
    * lint
    
    * merge ConcretizeLike and EliminateIdentity into SimplifyExpr
    
    * nits and lint
    
    * remove static stuff
    
    * document
    
    * definitely ran clang-format but ok
    
    * make ToScalar return optional, fix missing virtual destructor
    
    * lint
    
    * tweak scalar conversion API to maintain compatibility
---
 src/relay/transforms/dynamic_to_static.cc     |   1 -
 src/relay/transforms/pattern_utils.h          |  50 ++++--
 src/relay/transforms/simplify_expr.cc         | 246 ++++++++++++++++++++++----
 src/relay/transforms/simplify_expr.h          |  91 ++++++++++
 tests/python/relay/test_pass_simplify_expr.py | 146 ++++++++++++++-
 5 files changed, 472 insertions(+), 62 deletions(-)

diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index 815e4d2..0590b41 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -161,7 +161,6 @@ class DynamicToStaticMutator : public MixedModeMutator {
              ICHECK_EQ(scale_w->data->ndim, 0);
              const UpSampling3DAttrs* param = call_node->attrs.as<UpSampling3DAttrs>();
              ICHECK(param);
-
              return MakeUpSampling3D(call_node->args[0], ToScalar(scale_d->data),
                                      ToScalar(scale_h->data), ToScalar(scale_w->data),
                                      param->layout, param->method,
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index c1eebde..8d9f723 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -27,6 +27,7 @@
 #define TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_
 
 #include <builtin_fp16.h>
+#include <dmlc/optional.h>
 #include <tvm/node/structural_equal.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/attrs/nn.h>
@@ -380,43 +381,56 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
  * \brief Convert an element of a NDArray with type int or float to scalar.
  * \param array Input NDArray
  * \param i element index
- * \return Converted scalar value.
+ * \return Converted scalar value, or None if conversion failed
  */
-static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) {
+static inline dmlc::optional<long double> TryToScalar(const runtime::NDArray& array, size_t i = 0) {
   if (array->dtype.code == kDLInt) {
     if (array->dtype.bits == 8) {
-      return reinterpret_cast<int8_t*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<int8_t*>(array->data)[i]);
     } else if (array->dtype.bits == 16) {
-      return reinterpret_cast<int16_t*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<int16_t*>(array->data)[i]);
     } else if (array->dtype.bits == 32) {
-      return reinterpret_cast<int32_t*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<int32_t*>(array->data)[i]);
     } else if (array->dtype.bits == 64) {
-      return reinterpret_cast<int64_t*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<int64_t*>(array->data)[i]);
     }
   } else if (array->dtype.code == kDLUInt) {
-    if (array->dtype.bits == 8) {
-      return reinterpret_cast<uint8_t*>(array->data)[i];
+    if (array->dtype.bits == 1) {  // bool, read as one uint8_t per element
+      return dmlc::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
+    } else if (array->dtype.bits == 8) {
+      return dmlc::optional<long double>(reinterpret_cast<uint8_t*>(array->data)[i]);
     } else if (array->dtype.bits == 16) {
-      return reinterpret_cast<uint16_t*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<uint16_t*>(array->data)[i]);
     } else if (array->dtype.bits == 32) {
-      return reinterpret_cast<uint32_t*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<uint32_t*>(array->data)[i]);
     } else if (array->dtype.bits == 64) {
-      return reinterpret_cast<uint64_t*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<uint64_t*>(array->data)[i]);
     }
   } else if (array->dtype.code == kDLFloat) {
     if (array->dtype.bits == 16) {
-      return __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
-          reinterpret_cast<uint16_t*>(array->data)[i]);
+      return dmlc::optional<long double>(  // float16 bits, widened to float first
+          __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(
+              reinterpret_cast<uint16_t*>(array->data)[i]));
     }
     if (array->dtype.bits == 32) {
-      return reinterpret_cast<float*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<float*>(array->data)[i]);
     } else if (array->dtype.bits == 64) {
-      return reinterpret_cast<double*>(array->data)[i];
+      return dmlc::optional<long double>(reinterpret_cast<double*>(array->data)[i]);
     }
   }
-  LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
-  // make compiler happy
-  return -std::numeric_limits<double>::infinity();
+  return dmlc::optional<long double>();  // unsupported dtype code or bit width: no value
+}
+
+/*!
+ * \brief Convert an element of a NDArray with type int or float to scalar.
+ * \param array Input NDArray
+ * \param i element index
+ * \return Converted scalar value; fails an ICHECK if the dtype is unsupported
+ */
+static inline long double ToScalar(const runtime::NDArray& array, size_t i = 0) {
+  auto try_value = TryToScalar(array, i);
+  ICHECK(try_value) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
+  return try_value.value();
 }
 
 /*!
diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index b4f4cc1..762aa58 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -22,35 +22,28 @@
  * \brief A pass for simplifying the Relay expression.
  */
 
+#include "simplify_expr.h"
+
 #include <tvm/relay/dataflow_matcher.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
 #include <tvm/runtime/logging.h>
 
+#include <limits>
+#include <utility>
+
 #include "../op/tensor/transform.h"
 #include "pattern_utils.h"
 
 namespace tvm {
 namespace relay {
 
-class SimplifyPattern {
- public:
-  virtual Expr callback(const Expr& pre, const Expr& post,
-                        const Map<DFPattern, Array<Expr>>& node_map) const = 0;
-
-  DFPattern pattern() const { return pattern_; }
-
- protected:
-  /*! \brief Pattern for rewriting */
-  DFPattern pattern_;
-};
-
 /*!
  * \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
  *   and merges into one reshape op.
  */
-class SimplifyReshape : public SimplifyPattern {
+class SimplifyReshape : public DFPatternRewrite {
  public:
   SimplifyReshape() {
     x_ = IsWildcard();
@@ -59,7 +52,7 @@ class SimplifyReshape : public SimplifyPattern {
     pattern_ = reshape1({reshape2({x_})});
   }
 
-  Expr callback(const Expr& pre, const Expr& post,
+  Expr Callback(const Expr& pre, const Expr& post,
                 const Map<DFPattern, Array<Expr>>& node_map) const override {
     auto x = node_map[x_][0];
     bool const_shape = true;
@@ -86,7 +79,7 @@ class SimplifyReshape : public SimplifyPattern {
  * \brief SimplifyTranspose matches the pattern of consecutive transpose op,
  *   and merges or cancels them.
  */
-class SimplifyTranspose : public SimplifyPattern {
+class SimplifyTranspose : public DFPatternRewrite {
  public:
   SimplifyTranspose() {
     x_ = IsWildcard();
@@ -95,7 +88,7 @@ class SimplifyTranspose : public SimplifyPattern {
     pattern_ = trans1({trans2({x_})});
   }
 
-  Expr callback(const Expr& pre, const Expr& post,
+  Expr Callback(const Expr& pre, const Expr& post,
                 const Map<DFPattern, Array<Expr>>& node_map) const override {
     // Helper function to get the axes from call node attribute
     auto get_axes_from_call = [](const Call trans_call, int ndim) {
@@ -176,9 +169,10 @@ class SimplifyTranspose : public SimplifyPattern {
 };
 
 /*!
- * \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op
+ * \brief FullElementwise finds full like ops followed by broadcasting ops, and eliminates
+ * the full op by directly passing the fill value into the broadcasting op.
  */
-class FullElementwise : public SimplifyPattern {
+class FullElementwise : public DFPatternRewrite {
  public:
   FullElementwise() {
     x_ = IsWildcard();
@@ -196,7 +190,7 @@ class FullElementwise : public SimplifyPattern {
     pattern_ = op({full, x_}) || op({x_, full});
   }
 
-  Expr callback(const Expr& pre, const Expr& post,
+  Expr Callback(const Expr& pre, const Expr& post,
                 const Map<DFPattern, Array<Expr>>& node_map) const override {
     const CallNode* call = pre.as<CallNode>();
     ICHECK(call);
@@ -249,36 +243,210 @@ class FullElementwise : public SimplifyPattern {
 };
 
 /*!
- * \brief ExprSimplifier simplifies the Relay expression.
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(y)` to
+ * `zeros(y.shape, y.dtype)`, `reshape_like(x, y)` to `reshape(x, y.shape)`) when the output shape
+ * is static. This removes unnecessary dependencies and can enable more operator fusion.
  */
-class ExprSimplifier {
+class ConcretizeLikeRewrite : public DFPatternRewrite {
  public:
-  explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    CreateCallback(SimplifyReshape());
-    CreateCallback(SimplifyTranspose());
-    CreateCallback(FullElementwise());
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
   }
-  template <typename T>
-  void CreateCallback(const T& pattern) {
-    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
-      Expr pre = args[0];
-      Expr post = args[1];
-      Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = pattern.callback(pre, post, node_map);
-    };
-    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
+  /*! \brief Returns whether to rewrite; the default requires a tensor-typed call. */
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type().as<TensorTypeNode>()) {
+      return false;
+    }
+
+    return true;
+  }
+  /*! \brief Builds the explicit-shape replacement from the static output shape and dtype. */
+  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                          DataType dtype) const = 0;
+  /*! \brief Concretizes when Check passes and every output dimension is an IntImm. */
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    if (!Check(pre, post, node_map)) {
+      return post;
+    }
+
+    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
+    Array<Integer> cshape;
+    for (const auto& dim : like_ty->shape) {
+      if (const auto* imm = dim.as<IntImmNode>()) {
+        cshape.push_back(Integer(GetRef<IntImm>(imm)));
+      } else {
+        // shape is not static, don't concretize
+        return post;
+      }
+    }
+
+    return Concretize(node_map, cshape, like_ty->dtype);
+  }
+
+ protected:
+  DFPattern data_pat_;  // data operand (binary ops only)
+  DFPattern like_pat_;  // the `like` operand
+};
+
+class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {}
+  /*! \brief Replaces zeros_like(y) with zeros(shape, dtype). */
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeZeros(shape, dtype);
+  }
+};
+
+class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {}
+  /*! \brief Replaces ones_like(y) with ones(shape, dtype). */
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeOnes(shape, dtype);
+  }
+};
+
+class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}
+  /*! \brief Replaces reshape_like(x, y) with reshape(x, shape). */
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeReshape(node_map[data_pat_][0], shape);
+  }
+};
+
+class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {}
+  /*! \brief Replaces collapse_sum_like(x, y) with collapse_sum_to(x, shape). */
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());  // guards the cast below
+    static const Op& op = Op::Get("collapse_sum_to");
+    auto attrs = make_object<InitOpAttrs>();
+    attrs->shape = shape;
+    auto cshape =  // the shape is also passed as an int32 constant tensor operand
+        MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);
+    return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs));
+  }
+};
+
+class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {}
+  /*! \brief Replaces broadcast_to_like(x, y) with broadcast_to(x, shape). */
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeBroadCastTo(node_map[data_pat_][0], shape);
+  }
+};
+
+/*! \brief Eliminates x + 0, x - 0, x * 1 and x / 1 (and 0 + x, 1 * x) if the type is unchanged. */
+class EliminateIdentityRewrite : public DFPatternRewrite {
+ public:
+  EliminateIdentityRewrite() {
+    x_ = IsWildcard();
+    const_ = IsConstant();
+    // const_ matches any constant; Callback later requires it to be a scalar 0 or 1.
+    DFPattern add_op = IsOp("add");
+    DFPattern mul_op = IsOp("multiply");
+    DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_;
+    DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_;
+
+    // add and multiply are commutative so we don't need another pattern for reversed args
+    DFPattern add_id = add_op({x_, zeros_expr});
+    DFPattern mul_id = mul_op({x_, ones_expr});
+
+    DFPattern sub_id = IsOp("subtract")({x_, zeros_expr});
+    DFPattern div_id = IsOp("divide")({x_, ones_expr});
+
+    pattern_ = add_id || mul_id || sub_id || div_id;
   }
 
-  Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
+  bool CheckConstant(const OpNode* op, const ConstantNode* constant) const {
+    if (!IsScalar(GetRef<Expr>(constant))) {  // only scalar constants qualify
+      return false;
+    }
+    auto value = TryToScalar(constant->data, 0);  // None for unsupported dtypes
+    if (!value) {
+      // unsupported dtype
+      return false;
+    }
+    if (op->name == "add" || op->name == "subtract") {  // additive identity
+      return value.value() == 0.0;
+    } else if (op->name == "multiply" || op->name == "divide") {  // multiplicative identity
+      return value.value() == 1.0;
+    }
+    return false;
+  }
+  /*! \brief Returns x if the other operand is an identity and the call's type equals x's. */
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    const CallNode* call = pre.as<CallNode>();
+    ICHECK(call);
+    Type pre_type = pre->checked_type_;
+    ICHECK(pre_type.as<TensorTypeNode>());
+    auto x = node_map[x_][0];
+    bool is_left = post.as<CallNode>()->args[1] == x;  // identity on the left, x on the right
+    Type x_type;  // x's checked type, read from the pre-rewrite call
+    if (is_left) {
+      x_type = call->args[1]->checked_type_;
+    } else {
+      x_type = call->args[0]->checked_type_;
+    }
+
+    if (node_map.count(const_)) {
+      // the other argument is a Constant in this case
+      const ConstantNode* constant = node_map[const_][0].as<ConstantNode>();
+      const OpNode* op = call->op.as<OpNode>();
+      ICHECK(constant);
+      ICHECK(op);
+      if (!CheckConstant(op, constant)) {
+        return post;
+      }
+    }
+
+    if (StructuralEqual()(x_type, pre_type)) {  // a broadcast changes the type; keep it
+      return x;
+    }
+
+    return post;
+  }
 
  private:
-  IRModule mod_;
-  /*! \brief Callbacks for expr simplification */
-  Array<DFPatternCallback> callbacks_;
+  DFPattern x_;  // operand returned when the rewrite applies
+  DFPattern const_;  // constant operand, validated by CheckConstant
 };
 
 Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
-  return ExprSimplifier(mod).Simplify(expr);
+  // the rewrites will be applied in the given order, and repeated until fixed point
+  DFPatternRewriteComposer composer;  // owns the rewrites; alive until RewritePatterns returns
+  composer.AddRewrite<ConcretizeZerosLikeRewrite>();
+  composer.AddRewrite<ConcretizeOnesLikeRewrite>();
+  composer.AddRewrite<ConcretizeReshapeLikeRewrite>();
+  composer.AddRewrite<ConcretizeCollapseSumLikeRewrite>();
+  composer.AddRewrite<ConcretizeBroadcastToLikeRewrite>();
+  composer.AddRewrite<EliminateIdentityRewrite>();
+  composer.AddRewrite<SimplifyReshape>();
+  composer.AddRewrite<SimplifyTranspose>();
+  composer.AddRewrite<FullElementwise>();
+  return RewritePatterns(composer.MakeCallbacks(), expr, mod);
 }
 
 namespace transform {
diff --git a/src/relay/transforms/simplify_expr.h b/src/relay/transforms/simplify_expr.h
new file mode 100644
index 0000000..6b3925e
--- /dev/null
+++ b/src/relay/transforms/simplify_expr.h
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/simplify_expr.h
+ * \brief Utility data structures for simplifying Relay expressions.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+#define TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr.h>
+
+#include <memory>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*! \brief A wrapper class defining a rewrite matching a specific pattern. */
+class DFPatternRewrite {
+ public:
+  /*! \brief Returns the rewritten expression (returning \p post leaves the match unchanged). */
+  virtual Expr Callback(const Expr& pre, const Expr& post,
+                        const Map<DFPattern, Array<Expr>>& node_map) const = 0;
+
+  virtual ~DFPatternRewrite() = default;
+
+  /*! \brief Returns the pattern to be used for matching and rewriting. */
+  inline DFPattern Pattern() const { return pattern_; }
+  /*! \brief Returns whether the rewrite requires types to be inferred. */
+  inline bool RequireType() const { return require_type_; }
+  /*! \brief Wraps Callback in a DFPatternCallback; captures `this`, which must outlive it. */
+  inline DFPatternCallback MakeCallback() const {
+    auto func = [this](TVMArgs args, TVMRetValue* rv) {  // non-owning capture
+      Expr pre = args[0];
+      Expr post = args[1];
+      Map<DFPattern, Array<Expr>> node_map = args[2];
+      *rv = this->Callback(pre, post, node_map);
+    };
+    return DFPatternCallback(pattern_, PackedFunc(func), require_type_);
+  }
+
+ protected:
+  /*! \brief The pattern for matching and rewriting. */
+  DFPattern pattern_;
+  /*! \brief Whether or not the rewrite requires types to be inferred. */
+  bool require_type_ = true;
+};
+
+/*! \brief Helper class for composing rewrites and getting callbacks. */
+class DFPatternRewriteComposer {
+ public:
+  template <typename T, typename... Args>
+  inline void AddRewrite(const Args&... args) {  // appends a T constructed from args
+    rewrites_.push_back(std::make_shared<T>(args...));
+  }
+  /*! \brief Returns callbacks in insertion order; this composer must outlive them. */
+  inline Array<DFPatternCallback> MakeCallbacks() const {
+    Array<DFPatternCallback> callbacks;
+    for (const auto& rewrite : rewrites_) {
+      callbacks.push_back(rewrite->MakeCallback());
+    }
+    return callbacks;
+  }
+
+ private:
+  /*! \brief the rewrites to be composed. */
+  std::vector<std::shared_ptr<DFPatternRewrite>> rewrites_;
+};
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index 897f90b..d015cdd 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -14,10 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import pytest
 import tvm
 from tvm import relay
 from tvm.relay import transform
-from tvm.relay.testing import run_opt_pass
+from tvm.relay.testing import run_opt_pass, run_infer_type
 
 import numpy as np
 
@@ -123,12 +124,22 @@ def test_simplify_full_elementwise():
             return elem_op(full, x)
 
         def after_left(x, elem_op, value):
+            if elem_op == relay.add and value == 0:  # 0 + x is eliminated
+                return x
+            elif elem_op == relay.multiply and (value == 1 or (value > 1 and dtype == "bool")):
+                return x  # 1 * x; presumably a bool value > 1 is also cast to True (1) - confirm
             return elem_op(relay.const(value, dtype), x)
 
         def before_right(x, elem_op, full):
             return elem_op(x, full)
 
         def after_right(x, elem_op, value):
+            if elem_op in [relay.add, relay.subtract] and value == 0:
+                return x  # x + 0 and x - 0 are eliminated
+            elif elem_op in [relay.multiply, relay.divide] and (
+                value == 1 or (value > 1 and dtype == "bool")
+            ):
+                return x  # x * 1 and x / 1 are eliminated
             return elem_op(x, relay.const(value, dtype))
 
         x = relay.var("x", shape=shape, dtype=dtype)
@@ -181,7 +192,134 @@ def test_simplify_full_elementwise():
                 validate(shape, value, dtype)
 
 
+def test_eliminate_identity():
+    def check(x, y=None, do_nothing=False):  # y should simplify to x, or x stay unchanged
+        expected = run_infer_type(x)
+        if do_nothing:
+            actual = run_opt_pass(x, transform.SimplifyExpr())
+            assert tvm.ir.structural_equal(actual, expected)
+        else:
+            assert y is not None
+            actual = run_opt_pass(y, transform.SimplifyExpr())
+            assert tvm.ir.structural_equal(actual, expected)
+
+    shape = [2, 3, 4]
+    dtype = "float32"
+    x = relay.var("x", shape=shape, dtype=dtype)
+    x = run_opt_pass(x, transform.InferType())
+    # add/multiply: the identity operand may be on either side.
+    for (op, op_like, id_op, const) in [
+        (relay.zeros, relay.zeros_like, relay.add, relay.const(0, dtype)),
+        (relay.ones, relay.ones_like, relay.multiply, relay.const(1, dtype)),
+    ]:
+        check(x, id_op(op_like(x), x))
+        check(x, id_op(op(shape, dtype), x))
+        check(x, id_op(const, x))
+        check(x, id_op(op(shape[1:], dtype), x))  # broadcasts to x's shape, still an identity
+        check(x, id_op(x, op_like(x)))
+        check(x, id_op(x, op(shape, dtype)))
+        check(x, id_op(x, const))
+        check(x, id_op(x, op(shape[1:], dtype)))
+        check(id_op(x, op([2] + shape, dtype)), do_nothing=True)  # broadcasting changes the type
+        check(id_op(op([2] + shape, dtype), x), do_nothing=True)
+    # subtract/divide: x - 0 and x / 1 are identities; a left 0/1 is only folded to a constant.
+    for (op, op_like, id_op, const) in [
+        (relay.zeros, relay.zeros_like, relay.subtract, relay.const(0, dtype)),
+        (relay.ones, relay.ones_like, relay.divide, relay.const(1, dtype)),
+    ]:
+        check(x, id_op(x, op_like(x)))
+        check(x, id_op(x, const))
+        check(x, id_op(x, op(shape, dtype)))
+        check(x, id_op(x, op(shape[1:], dtype)))
+        check(id_op(x, op([2] + shape, dtype)), do_nothing=True)
+        check(id_op(const, x), id_op(op(shape, dtype), x))
+        check(id_op(const, x), id_op(op_like(x), x))
+
+
+def test_concretize_reshape_like():  # reshape_like -> reshape
+    data = relay.var("data", shape=(2, 3, 4), dtype="float32")
+    shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")
+    expr = relay.reshape_like(data, shape_like)
+
+    expected = run_infer_type(relay.reshape(data, (6, 2, 2)))
+    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+    assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_reshape_like_attrs():  # lhs_begin/rhs_begin are folded into the static shape
+    data = relay.var("data", shape=(2, 3, 4), dtype="float32")
+    shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")
+    expr = relay.reshape_like(data, shape_like, lhs_begin=2, rhs_begin=1)
+
+    expected = run_infer_type(relay.reshape(data, (2, 3, 2, 2)))
+    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+    assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_zeros_like():  # zeros_like -> zeros
+    dtype = "int32"
+    shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype)
+    expr = relay.zeros_like(shape_like)
+
+    expected = run_infer_type(relay.zeros((3, 4, 5), dtype))
+    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+    assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_ones_like():  # ones_like -> ones
+    dtype = "int32"
+    shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype)
+    expr = relay.ones_like(shape_like)
+
+    expected = run_infer_type(relay.ones((3, 4, 5), dtype))
+    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+    assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_collapse_sum_like():  # collapse_sum_like -> collapse_sum_to
+    data = relay.var("data", shape=(3, 3, 3), dtype="float32")
+    shape_like = relay.var("shape_like", shape=(3,), dtype="float32")
+    expr = relay.collapse_sum_like(data, shape_like)
+
+    expected = run_infer_type(relay.collapse_sum_to(data, (3,)))
+    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+    assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_broadcast_to_like():  # broadcast_to_like -> broadcast_to
+    data = relay.var("data", shape=(3,), dtype="float32")
+    shape_like = relay.var("shape_like", shape=(3, 3, 3), dtype="float32")
+    expr = relay.broadcast_to_like(data, shape_like)
+
+    expected = run_infer_type(relay.broadcast_to(data, (3, 3, 3)))
+    actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
+    assert tvm.ir.structural_equal(actual, expected)
+
+
+def test_concretize_multiple():  # several *_like ops in one graph
+    x = relay.var("x", shape=(2, 3), dtype="float32")
+    y = relay.var("y", shape=(3,), dtype="float32")
+    l = x + y
+    # ones/zeros/collapse_sum are all expressed via *_like ops on x, y and l.
+    dl = relay.ones_like(l)
+    dx = relay.zeros_like(x)
+    dy = relay.zeros_like(y)
+    dx = dx + relay.collapse_sum_like(dl, dx)
+    dy = dy + relay.collapse_sum_like(dl, dy)
+    ret = relay.Tuple([dx, dy])
+    # Expected: each *_like op replaced by its static-shape equivalent.
+    dl_c = relay.ones((2, 3), "float32")
+    # NOTE: these are removed by EliminateIdentity
+    # dx_c = relay.zeros((2, 3), "float32")
+    # dy_c = relay.zeros((3,), "float32")
+    dx_c = relay.collapse_sum_to(dl_c, (2, 3))
+    dy_c = relay.collapse_sum_to(dl_c, (3,))
+    ret_c = relay.Tuple([dx_c, dy_c])
+
+    expected = run_infer_type(ret_c)
+    actual = run_opt_pass(ret, relay.transform.SimplifyExpr())
+    assert tvm.ir.structural_equal(actual, expected)
+
+
 if __name__ == "__main__":
-    test_simplify_reshape()
-    test_simplify_transpose()
-    test_simplify_full_elementwise()
+    pytest.main([__file__])  # collects every test_* function in this file