Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/24 06:05:16 UTC

[GitHub] [tvm] altanh opened a new pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

altanh opened a new pull request #7731:
URL: https://github.com/apache/tvm/pull/7731


   This PR introduces two new passes:
   - `ConcretizeLike`: replaces `*_like` operators with their concrete-shape equivalent when the result shape is concrete.
   - `EliminateIdentity`: eliminates identity expressions like `x + zeros()`, `ones() * x`, etc. Expressions that broadcast `x` to a new shape are not removed, although we could explicitly replace them with broadcasting ops (I'm not sure of the performance difference). This pass also doesn't examine the values of constants, so `x + const(0)` will not be eliminated; if anyone has a datatype-portable solution for this, let me know. For this reason, this pass should be run before `SimplifyExpr`.
   
   I also slightly refactored the existing DFPatternCallback-based passes to lift out common machinery. I tried making these pattern rewrites statically initialized (@comaniac, please let me know what you think of this approach).
   
   Together, these passes should help optimize the generated AD code (credit to @t-vi for prototyping them in the blog post) when run as `FirstOrderGradient -> ConcretizeLike -> EliminateIdentity -> ...`.
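As a rough illustration of the two rewrites and of why ConcretizeLike runs before EliminateIdentity, here is a minimal C++ sketch over a hypothetical toy IR. `Node`, `Make`, `RewriteBottomUp`, and both rewrite functions are illustrative assumptions, not TVM's Relay API. In this toy, EliminateIdentity recognizes only concrete `zeros`/`ones`, so it depends on ConcretizeLike having already turned `zeros_like` into `zeros`.

```cpp
#include <cassert>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR (not TVM's Relay API). Each node records its result
// shape; std::nullopt marks a dimension that is not known statically.
struct Node;
using NodePtr = std::shared_ptr<Node>;
struct Node {
  std::string op;  // "var", "zeros", "ones", "zeros_like", "ones_like", "add", "multiply"
  std::vector<std::optional<int>> shape;
  std::vector<NodePtr> args;
};

NodePtr Make(std::string op, std::vector<std::optional<int>> shape,
             std::vector<NodePtr> args = {}) {
  return std::make_shared<Node>(Node{std::move(op), std::move(shape), std::move(args)});
}

// True if every dimension of the node's result shape is known.
bool IsStatic(const Node& n) {
  for (const auto& d : n.shape) {
    if (!d) return false;
  }
  return true;
}

// ConcretizeLike: zeros_like(y) -> zeros(shape), ones_like(y) -> ones(shape)
// when the result shape is static; the new node no longer depends on y.
NodePtr ConcretizeLike(const NodePtr& n) {
  if ((n->op == "zeros_like" || n->op == "ones_like") && IsStatic(*n)) {
    return Make(n->op == "zeros_like" ? "zeros" : "ones", n->shape);
  }
  return n;
}

// EliminateIdentity: x + zeros, zeros + x, x * ones, ones * x -> x, but only if
// the result has x's shape (i.e. the identity operand does not broadcast x).
NodePtr EliminateIdentity(const NodePtr& n) {
  std::string identity = n->op == "add" ? "zeros" : n->op == "multiply" ? "ones" : "";
  if (identity.empty() || n->args.size() != 2) return n;
  NodePtr x;
  if (n->args[1]->op == identity) {
    x = n->args[0];
  } else if (n->args[0]->op == identity) {
    x = n->args[1];
  }
  return (x && x->shape == n->shape) ? x : n;
}

// Applies `rule` bottom-up: children are rewritten before their parent.
NodePtr RewriteBottomUp(const NodePtr& n, NodePtr (*rule)(const NodePtr&)) {
  auto copy = std::make_shared<Node>(*n);
  for (auto& arg : copy->args) arg = RewriteBottomUp(arg, rule);
  return rule(copy);
}
```

Applying `ConcretizeLike` and then `EliminateIdentity` to `x + zeros_like(y)` reduces it to `x`; the reverse order leaves the `add` in place, and `ones * x` is kept when the `ones` operand broadcasts `x` to a larger shape.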
   
   cc @tqchen @comaniac @MarisaKirisame @yzhliu 
   
   (I'll work on a `DeadParameterElimination` pass to complement `ConcretizeLike` as we discussed in the previous PR, but will send as a follow-up.)
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] comaniac commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

comaniac commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600673972



##########
File path: src/relay/transforms/concretize_like.cc
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file concretize_like.cc
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to
+ * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies
+ * and can enable more opportunities for operator fusion.
+ */
+
+#include <tvm/relay/transform.h>
+
+#include "pattern_utils.h"
+#include "simplify_expr.h"
+
+namespace tvm {
+namespace relay {
+
+class ConcretizeLikeRewrite : public DFPatternRewrite {
+ public:
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
+    require_type_ = true;
+  }
+
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type_.defined()) {
+      // TODO(@altanh): maybe because of the input being rewritten?

Review comment:
       You could manually assign the `checked_type_` of the input to the newly created node in the rewrite callback.
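A minimal sketch of that suggestion, using a hypothetical toy node type (`Expr`, `checked_type`, and `RewriteWithType` are illustrative names, not Relay's `CallNode::checked_type_` machinery). The callback copies the type of the node being replaced onto the node it creates, which is only sound when the rewrite preserves the type.

```cpp
#include <cassert>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR: `checked_type` stays empty until type inference runs,
// mirroring how a freshly built call node has no checked type yet.
struct Expr {
  std::string op;
  std::vector<std::shared_ptr<Expr>> args;
  std::optional<std::string> checked_type;
};
using ExprPtr = std::shared_ptr<Expr>;

// Builds the replacement node and copies the type of `pre` onto it, so later
// callbacks that inspect `checked_type` do not see an undefined type.
ExprPtr RewriteWithType(const ExprPtr& pre, std::string new_op) {
  auto out = std::make_shared<Expr>(Expr{std::move(new_op), pre->args, std::nullopt});
  out->checked_type = pre->checked_type;  // valid only if the rewrite preserves the type
  return out;
}
```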

##########
File path: src/relay/transforms/concretize_like.cc
##########
@@ -0,0 +1,182 @@
+/*!
+ * \file concretize_like.cc
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to
+ * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies
+ * and can enable more opportunities for operator fusion.
+ */
+
+#include <tvm/relay/transform.h>
+
+#include "pattern_utils.h"
+#include "simplify_expr.h"
+
+namespace tvm {
+namespace relay {
+
+class ConcretizeLikeRewrite : public DFPatternRewrite {
+ public:
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
+    require_type_ = true;
+  }
+
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type_.defined()) {
+      // TODO(@altanh): maybe because of the input being rewritten?
+      return false;
+    }
+
+    const TensorTypeNode* like_ty = call_node->checked_type().as<TensorTypeNode>();
+    ICHECK(like_ty) << "got non-Tensor *_like call type " << PrettyPrint(call_node->checked_type());
+
+    return true;
+  }
+
+  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                          DataType dtype) const = 0;
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    if (!Check(pre, post, node_map)) {
+      return post;
+    }
+
+    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
+    Array<Integer> cshape;
+    for (const auto& dim : like_ty->shape) {
+      if (const auto* imm = dim.as<IntImmNode>()) {
+        cshape.push_back(Integer(GetRef<IntImm>(imm)));
+      } else {
+        return post;

Review comment:
       Add a comment saying that we do nothing here when the reference shape is not static.
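The behavior being documented can be sketched as a standalone helper. This is an assumption-laden C++ model, not TVM code: `Dim` stands in for a shape dimension that is either a known integer (like `IntImm`) or symbolic, and the hypothetical `ToConcreteShape` returns `std::nullopt` as soon as it sees a non-static dimension, which is the cue for the callback to return `post` unchanged.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for a shape dimension: a known integer (like IntImm)
// or a symbolic/dynamic dimension (like Any or a shape variable).
struct SymbolicDim {
  std::string name;
};
using Dim = std::variant<std::int64_t, SymbolicDim>;

// Returns the concrete shape if every dimension is a known integer.
// Returns std::nullopt otherwise: the reference shape is not static, so the
// caller does nothing and leaves the *_like call as-is.
std::optional<std::vector<std::int64_t>> ToConcreteShape(const std::vector<Dim>& shape) {
  std::vector<std::int64_t> out;
  out.reserve(shape.size());
  for (const Dim& d : shape) {
    if (const std::int64_t* v = std::get_if<std::int64_t>(&d)) {
      out.push_back(*v);
    } else {
      return std::nullopt;  // dynamic dimension: don't concretize
    }
  }
  return out;
}
```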

##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -22,44 +22,37 @@
  * \brief A pass for simplifying the Relay expression.
  */
 
+#include "simplify_expr.h"
+
 #include <tvm/relay/dataflow_matcher.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
 #include <tvm/runtime/logging.h>
 
+#include <utility>
+
 #include "../op/tensor/transform.h"
 #include "pattern_utils.h"
 
 namespace tvm {
 namespace relay {
 
-class SimplifyPattern {
- public:
-  virtual Expr callback(const Expr& pre, const Expr& post,
-                        const Map<DFPattern, Array<Expr>>& node_map) const = 0;
-
-  DFPattern pattern() const { return pattern_; }
-
- protected:
-  /*! \brief Pattern for rewriting */
-  DFPattern pattern_;
-};
-
 /*!
  * \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
  *   and merges into one reshape op.
  */
-class SimplifyReshape : public SimplifyPattern {
+class SimplifyReshape : public DFPatternRewrite {
  public:
   SimplifyReshape() {
     x_ = IsWildcard();
     auto reshape1 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
     auto reshape2 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
     pattern_ = reshape1({reshape2({x_})});
+    require_type_ = true;

Review comment:
       I think it should be fine to set this variable to true by default in `DFPatternRewrite`, so that you can get rid of all of these statements. Cases that don't need type checking should be rare.
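In C++ this default can be a default member initializer on the base class. A sketch with hypothetical names (`PatternRewrite`, `SimplifyReshapeSketch`, `UntypedRewriteSketch`; the string `pattern_` stands in for a `DFPattern`):

```cpp
#include <cassert>
#include <string>

// Hypothetical simplification of the shared base class: a rewrite carries a
// pattern and a flag saying whether type inference must run before matching.
class PatternRewrite {
 public:
  virtual ~PatternRewrite() = default;
  bool require_type() const { return require_type_; }
  const std::string& pattern() const { return pattern_; }

 protected:
  std::string pattern_;       // stand-in for a DFPattern
  bool require_type_ = true;  // default suggested in review
};

// Typical subclass: no `require_type_ = true;` line is needed any more.
class SimplifyReshapeSketch : public PatternRewrite {
 public:
  SimplifyReshapeSketch() { pattern_ = "reshape(reshape(x))"; }
};

// The rare subclass that does not need types opts out explicitly.
class UntypedRewriteSketch : public PatternRewrite {
 public:
  UntypedRewriteSketch() {
    pattern_ = "some_untyped_pattern";
    require_type_ = false;
  }
};
```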

##########
File path: src/relay/transforms/concretize_like.cc
##########
@@ -0,0 +1,182 @@
+/*!
+ * \file concretize_like.cc
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to
+ * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies
+ * and can enable more opportunities for operator fusion.
+ */
+
+#include <tvm/relay/transform.h>
+
+#include "pattern_utils.h"
+#include "simplify_expr.h"
+
+namespace tvm {
+namespace relay {
+
+class ConcretizeLikeRewrite : public DFPatternRewrite {
+ public:
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
+    require_type_ = true;
+  }
+
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type_.defined()) {
+      // TODO(@altanh): maybe because of the input being rewritten?
+      return false;
+    }
+
+    const TensorTypeNode* like_ty = call_node->checked_type().as<TensorTypeNode>();
+    ICHECK(like_ty) << "got non-Tensor *_like call type " << PrettyPrint(call_node->checked_type());
+
+    return true;
+  }
+
+  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                          DataType dtype) const = 0;
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    if (!Check(pre, post, node_map)) {
+      return post;
+    }

Review comment:
       Because of this use case, I'm wondering whether `Check` should always be safe to call (i.e., contain no `ICHECK`).
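One way to make `Check` always safe is to turn each `ICHECK` into an early `return false`, so that `Callback` simply falls back to `post`. A minimal sketch, with hypothetical stand-in types (`MatchedExpr`, `TypeKind`) in place of Relay's `CallNode` and `TensorTypeNode`:

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-ins for what Check inspects on the matched `pre` expression.
enum class TypeKind { kUndefined, kTensor, kTuple };
struct MatchedExpr {
  bool is_call;   // does `pre` downcast to a call node?
  TypeKind type;  // kind of its checked type; kUndefined if not inferred
};

// Every unexpected case returns false (skip this rewrite) instead of aborting,
// so a failed check never crashes the whole pass.
bool Check(const MatchedExpr& pre) {
  if (!pre.is_call) return false;                   // previously an ICHECK
  if (pre.type != TypeKind::kTensor) return false;  // undefined or non-tensor type
  return true;
}

// Mirrors the structure of Callback: fall back to `post` whenever Check fails.
std::string Callback(const MatchedExpr& pre, const std::string& post,
                     const std::string& rewritten) {
  if (!Check(pre)) return post;
  return rewritten;
}
```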







[GitHub] [tvm] comaniac commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

comaniac commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600852619



##########
File path: src/relay/transforms/simplify_expr.h
##########
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/simplify_expr.h
+ * \brief Utility data structures for simplifying Relay expressions.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+#define TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr.h>
+
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Defines a static function `RewriteType::Get()` that returns a statically initialized
+ * instance of RewriteType. */
+#define TVM_DF_PATTERN_REWRITE_GETTER(RewriteType)                    \
+  static DFPatternRewrite* Get() {                                    \
+    static RewriteType rw;                                            \
+    return &rw;                                                       \
+  }                                                                   \
+  static DFPatternCallback GetCallback() {                            \
+    static DFPatternCallback cb = RewriteType::Get()->MakeCallback(); \
+    return cb;                                                        \
+  }

Review comment:
       The previous PR had a different implementation, and my point was that the pattern table itself should be static. Given that the current implementation is based on SimplifyExpr, I agree with @mbrookhart that we don't need to make those functions static in the pattern class.







[GitHub] [tvm] altanh commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

altanh commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600912352



##########
File path: src/relay/transforms/simplify_expr.h
##########
@@ -0,0 +1,79 @@
+/*!
+ * \file src/relay/transforms/simplify_expr.h
+ * \brief Utility data structures for simplifying Relay expressions.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+#define TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr.h>
+
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Defines a static function `RewriteType::Get()` that returns a statically initialized
+ * instance of RewriteType. */
+#define TVM_DF_PATTERN_REWRITE_GETTER(RewriteType)                    \
+  static DFPatternRewrite* Get() {                                    \
+    static RewriteType rw;                                            \
+    return &rw;                                                       \
+  }                                                                   \
+  static DFPatternCallback GetCallback() {                            \
+    static DFPatternCallback cb = RewriteType::Get()->MakeCallback(); \
+    return cb;                                                        \
+  }

Review comment:
       Removed. I made a helper class for composing rewrites instead, since we need to ensure that the DFPatternCallbacks do not outlive the Rewrite objects they refer to.
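The lifetime concern can be illustrated with a small ownership sketch. Everything here is hypothetical (`Rewrite`, `RewriteComposer`, `AppendTag`, and string "expressions"), not the helper class that landed in the PR: the composer owns each rewrite through a `std::unique_ptr` and hands out callbacks that hold only non-owning pointers, so the callbacks are valid exactly as long as the composer is alive.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// A rewrite object; callbacks dispatch to it through a non-owning pointer.
struct Rewrite {
  virtual ~Rewrite() = default;
  virtual std::string Callback(const std::string& expr) const = 0;
};
using CallbackFn = std::function<std::string(const std::string&)>;

// Owns every rewrite and produces callbacks pointing into the owned objects.
// Callbacks must not be used after the composer is destroyed.
class RewriteComposer {
 public:
  template <typename T, typename... Args>
  void AddRewrite(Args&&... args) {
    rewrites_.push_back(std::make_unique<T>(std::forward<Args>(args)...));
  }

  std::vector<CallbackFn> MakeCallbacks() const {
    std::vector<CallbackFn> cbs;
    for (const auto& rw : rewrites_) {
      const Rewrite* raw = rw.get();  // non-owning; valid while *this lives
      cbs.push_back([raw](const std::string& e) { return raw->Callback(e); });
    }
    return cbs;
  }

 private:
  std::vector<std::unique_ptr<Rewrite>> rewrites_;
};

// Toy rewrite used for illustration: appends a tag to the "expression".
struct AppendTag : Rewrite {
  explicit AppendTag(std::string tag) : tag_(std::move(tag)) {}
  std::string Callback(const std::string& expr) const override { return expr + tag_; }
  std::string tag_;
};
```

This keeps ownership in one place rather than in function-local statics; the trade-off is that callers must not hold on to the returned callbacks after the composer goes away.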







[GitHub] [tvm] mbrookhart commented on pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

mbrookhart commented on pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#issuecomment-806029284


   Yeah, as long as we aren't commonly manifesting full-sized arrays of zeros or ones, that should be fine. Given the full/zeros/ones ops and their `_like` counterparts, plus auto-broadcasting, I think that's generally a reasonable assumption to make.





[GitHub] [tvm] altanh commented on pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

altanh commented on pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#issuecomment-806348360


   I ended up having to make `ToScalar` return an optional value because of a custom datatype test (as far as I can tell, we currently have no good way to support compile-time conversion of custom datatypes in C++). Let me know if this is fine, as I don't see an alternative within the scope of this PR.





[GitHub] [tvm] comaniac merged pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

comaniac merged pull request #7731:
URL: https://github.com/apache/tvm/pull/7731


   





[GitHub] [tvm] altanh commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

altanh commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600858602



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -249,36 +249,214 @@ class FullElementwise : public SimplifyPattern {
 };
 
 /*!
- * \brief ExprSimplifier simplifies the Relay expression.
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to
+ * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies
+ * and can enable more opportunities for operator fusion.
  */
-class ExprSimplifier {
+class ConcretizeLikeRewrite : public DFPatternRewrite {
  public:
-  explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    CreateCallback(SimplifyReshape());
-    CreateCallback(SimplifyTranspose());
-    CreateCallback(FullElementwise());
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
   }
-  template <typename T>
-  void CreateCallback(const T& pattern) {
-    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
-      Expr pre = args[0];
-      Expr post = args[1];
-      Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = pattern.callback(pre, post, node_map);
-    };
-    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
+
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type().as<TensorTypeNode>()) {
+      return false;
+    }
+
+    return true;
+  }
+
+  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                          DataType dtype) const = 0;
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    if (!Check(pre, post, node_map)) {
+      return post;
+    }
+
+    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
+    Array<Integer> cshape;
+    for (const auto& dim : like_ty->shape) {
+      if (const auto* imm = dim.as<IntImmNode>()) {
+        cshape.push_back(Integer(GetRef<IntImm>(imm)));
+      } else {
+        // shape is not static, don't concretize
+        return post;
+      }
+    }
+
+    return Concretize(node_map, cshape, like_ty->dtype);
+  }
+
+ protected:
+  DFPattern data_pat_;
+  DFPattern like_pat_;
+};
+
+class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeZeros(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeZerosLikeRewrite);
+};
+
+class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeOnes(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeOnesLikeRewrite);
+};
+
+class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeReshape(node_map[data_pat_][0], shape);
   }
 
-  Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeReshapeLikeRewrite);
+};
+
+class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());
+    static const Op& op = Op::Get("collapse_sum_to");
+    auto attrs = make_object<InitOpAttrs>();
+    attrs->shape = shape;
+    auto cshape =
+        MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);
+    return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs));
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeCollapseSumLikeRewrite);
+};
+
+class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeBroadCastTo(node_map[data_pat_][0], shape);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeBroadcastToLikeRewrite);
+};
+
+/*! \brief Eliminates expressions that are equivalent to identity. */
+class EliminateIdentityRewrite : public DFPatternRewrite {
+ public:
+  EliminateIdentityRewrite() {
+    x_ = IsWildcard();
+    const_ = IsConstant();
+
+    DFPattern add_op = IsOp("add");
+    DFPattern mul_op = IsOp("multiply");
+    DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_;
+    DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_;
+
+    DFPattern add_id = add_op({x_, zeros_expr}) || add_op({zeros_expr, x_});
+    DFPattern mul_id = mul_op({x_, ones_expr}) || mul_op({ones_expr, x_});

Review comment:
       Not sure what you mean by `x + 0 = x`, as I think I already have that covered. Can you confirm?







[GitHub] [tvm] mbrookhart commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

mbrookhart commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600863512



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -249,36 +249,214 @@ class FullElementwise : public SimplifyPattern {
 };
 
 /*!
- * \brief ExprSimplifier simplifies the Relay expression.
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to
+ * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies
+ * and can enable more opportunities for operator fusion.
  */
-class ExprSimplifier {
+class ConcretizeLikeRewrite : public DFPatternRewrite {
  public:
-  explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    CreateCallback(SimplifyReshape());
-    CreateCallback(SimplifyTranspose());
-    CreateCallback(FullElementwise());
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
   }
-  template <typename T>
-  void CreateCallback(const T& pattern) {
-    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
-      Expr pre = args[0];
-      Expr post = args[1];
-      Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = pattern.callback(pre, post, node_map);
-    };
-    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
+
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type().as<TensorTypeNode>()) {
+      return false;
+    }
+
+    return true;
+  }
+
+  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                          DataType dtype) const = 0;
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    if (!Check(pre, post, node_map)) {
+      return post;
+    }
+
+    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
+    Array<Integer> cshape;
+    for (const auto& dim : like_ty->shape) {
+      if (const auto* imm = dim.as<IntImmNode>()) {
+        cshape.push_back(Integer(GetRef<IntImm>(imm)));
+      } else {
+        // shape is not static, don't concretize
+        return post;
+      }
+    }
+
+    return Concretize(node_map, cshape, like_ty->dtype);
+  }
+
+ protected:
+  DFPattern data_pat_;
+  DFPattern like_pat_;
+};
+
+class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeZeros(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeZerosLikeRewrite);
+};
+
+class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeOnes(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeOnesLikeRewrite);
+};
+
+class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeReshape(node_map[data_pat_][0], shape);
   }
 
-  Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeReshapeLikeRewrite);
+};
+
+class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());
+    static const Op& op = Op::Get("collapse_sum_to");
+    auto attrs = make_object<InitOpAttrs>();
+    attrs->shape = shape;
+    auto cshape =
+        MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);
+    return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs));
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeCollapseSumLikeRewrite);
+};
+
+class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeBroadCastTo(node_map[data_pat_][0], shape);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeBroadcastToLikeRewrite);
+};
+
+/*! \brief Eliminates expressions that are equivalent to identity. */
+class EliminateIdentityRewrite : public DFPatternRewrite {
+ public:
+  EliminateIdentityRewrite() {
+    x_ = IsWildcard();
+    const_ = IsConstant();
+
+    DFPattern add_op = IsOp("add");
+    DFPattern mul_op = IsOp("multiply");
+    DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_;
+    DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_;
+
+    DFPattern add_id = add_op({x_, zeros_expr}) || add_op({zeros_expr, x_});
+    DFPattern mul_id = mul_op({x_, ones_expr}) || mul_op({ones_expr, x_});

Review comment:
       I had this roughly implemented in the pattern matcher tests but never really productized it: https://github.com/apache/tvm/blob/3ba586803ac7956813177aebf8072e7d7c0ab9b2/tests/python/relay/test_dataflow_pattern.py#L1039-L1127
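To make the rewrite concrete, here is a minimal, self-contained sketch of identity elimination on a toy expression tree. The `Expr` struct and the `Var`/`Const`/`Bin`/`Simplify` helpers are hypothetical. They are not TVM's API and not the linked test code. The sketch assumes all operands share one shape, which lets it skip the broadcasting concern that the Relay rewrite handles by comparing the type of `x` with the type of the whole call.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Hypothetical toy IR (not TVM's): a variable, a scalar constant, or a binary op.
struct Expr {
  std::string op;    // "var", "const", "add", "subtract", "multiply", "divide"
  std::string name;  // set when op == "var"
  double value = 0;  // set when op == "const"
  std::shared_ptr<Expr> lhs, rhs;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr Var(const std::string& name) {
  auto e = std::make_shared<Expr>();
  e->op = "var";
  e->name = name;
  return e;
}

ExprPtr Const(double value) {
  auto e = std::make_shared<Expr>();
  e->op = "const";
  e->value = value;
  return e;
}

ExprPtr Bin(const std::string& op, ExprPtr lhs, ExprPtr rhs) {
  auto e = std::make_shared<Expr>();
  e->op = op;
  e->lhs = std::move(lhs);
  e->rhs = std::move(rhs);
  return e;
}

bool IsConst(const ExprPtr& e, double v) { return e->op == "const" && e->value == v; }

// Bottom-up rewrite: simplify operands first, then drop an identity operand.
// Assumes every operand has the same shape, so removing one never loses a broadcast.
ExprPtr Simplify(const ExprPtr& e) {
  if (!e->lhs || !e->rhs) return e;  // leaf node
  ExprPtr a = Simplify(e->lhs);
  ExprPtr b = Simplify(e->rhs);
  if (e->op == "add") {
    if (IsConst(b, 0)) return a;  // x + 0
    if (IsConst(a, 0)) return b;  // 0 + x
  } else if (e->op == "multiply") {
    if (IsConst(b, 1)) return a;  // x * 1
    if (IsConst(a, 1)) return b;  // 1 * x
  } else if (e->op == "subtract") {
    if (IsConst(b, 0)) return a;  // x - 0 (0 - x is not an identity)
  } else if (e->op == "divide") {
    if (IsConst(b, 1)) return a;  // x / 1 (1 / x is not an identity)
  }
  return Bin(e->op, a, b);
}
```

As in the PR's pattern, only `add` and `multiply` are matched in both operand orders; `subtract` and `divide` are matched only with the identity on the right.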
   







[GitHub] [tvm] altanh commented on pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

Posted by GitBox <gi...@apache.org>.
altanh commented on pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#issuecomment-810748527


   Ping @mbrookhart @comaniac: I adjusted the change to `ToScalar` by adding a new function, `TryToScalar`, so that the existing API does not need to change (although we should keep in mind where we use `ToScalar`, for bring-your-own-datatype compatibility).
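The split described above follows a common C++ pattern: keep the strict function's signature unchanged and add a `Try*` sibling that returns an empty `std::optional` instead of failing. The sketch below is a hypothetical standalone illustration; `DType`, `Scalar`, `MakeScalar`, and both function bodies are assumptions and are not TVM's actual `ToScalar`/`TryToScalar` implementations.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical dtype tags and a flat byte buffer standing in for a scalar NDArray.
enum class DType { kInt32, kFloat32, kCustom };

struct Scalar {
  DType dtype;
  std::vector<unsigned char> bytes;
};

template <typename T>
Scalar MakeScalar(DType dtype, T value) {
  Scalar s{dtype, std::vector<unsigned char>(sizeof(T))};
  std::memcpy(s.bytes.data(), &value, sizeof(T));
  return s;
}

// Non-throwing variant: returns std::nullopt for data it cannot interpret
// (e.g. a custom datatype) instead of aborting.
std::optional<long double> TryToScalar(const Scalar& s) {
  switch (s.dtype) {
    case DType::kInt32: {
      std::int32_t v;
      if (s.bytes.size() != sizeof(v)) return std::nullopt;
      std::memcpy(&v, s.bytes.data(), sizeof(v));
      return static_cast<long double>(v);
    }
    case DType::kFloat32: {
      float v;
      if (s.bytes.size() != sizeof(v)) return std::nullopt;
      std::memcpy(&v, s.bytes.data(), sizeof(v));
      return static_cast<long double>(v);
    }
    default:
      return std::nullopt;  // unknown dtype: let the caller decide what to do
  }
}

// The strict API keeps its signature and still fails loudly on unsupported input.
long double ToScalar(const Scalar& s) {
  if (auto v = TryToScalar(s)) return *v;
  throw std::runtime_error("ToScalar: unsupported dtype");
}
```

With this shape, a caller such as an identity-elimination rewrite can treat `std::nullopt` as "not provably zero or one" and leave the expression unchanged, while existing `ToScalar` callers keep their behavior.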





[GitHub] [tvm] altanh commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

Posted by GitBox <gi...@apache.org>.
altanh commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600823616



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -249,36 +248,214 @@ class FullElementwise : public SimplifyPattern {
 };
 
 /*!
- * \brief ExprSimplifier simplifies the Relay expression.
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to
+ * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies
+ * and can enable more opportunities for operator fusion.
  */
-class ExprSimplifier {
+class ConcretizeLikeRewrite : public DFPatternRewrite {
  public:
-  explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    CreateCallback(SimplifyReshape());
-    CreateCallback(SimplifyTranspose());
-    CreateCallback(FullElementwise());
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
   }
-  template <typename T>
-  void CreateCallback(const T& pattern) {
-    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
-      Expr pre = args[0];
-      Expr post = args[1];
-      Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = pattern.callback(pre, post, node_map);
-    };
-    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
+
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type().as<TensorTypeNode>()) {
+      return false;
+    }
+
+    return true;
+  }
+
+  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                          DataType dtype) const = 0;
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    if (!Check(pre, post, node_map)) {
+      return post;
+    }
+
+    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
+    Array<Integer> cshape;
+    for (const auto& dim : like_ty->shape) {
+      if (const auto* imm = dim.as<IntImmNode>()) {
+        cshape.push_back(Integer(GetRef<IntImm>(imm)));
+      } else {
+        // shape is not static, don't concretize
+        return post;
+      }
+    }
+
+    return Concretize(node_map, cshape, like_ty->dtype);
+  }
+
+ protected:
+  DFPattern data_pat_;
+  DFPattern like_pat_;
+};
+
+class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeZeros(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeZerosLikeRewrite);
+};
+
+class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeOnes(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeOnesLikeRewrite);
+};
+
+class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeReshape(node_map[data_pat_][0], shape);
   }
 
-  Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeReshapeLikeRewrite);
+};
+
+class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());

Review comment:
       While extremely unlikely, it is in theory possible for `shape.size()` (a `size_t`) to exceed the maximum `int64_t` value, so I added a check here. I have to cast the size to `int64_t` later to pass the dimension to `MakeConstantTensor`.
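A standalone sketch of the same guard, written so both sides of the comparison are unsigned. Comparing a `size_t` directly against the positive signed limit `std::numeric_limits<int64_t>::max()` gives the right answer, but compilers may emit a signed/unsigned comparison warning. `CheckedSizeToInt64` is a hypothetical helper, not part of TVM, and the sketch assumes `size_t` is at most 64 bits wide.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical helper: narrows a container size to int64_t only if it fits.
std::optional<std::int64_t> CheckedSizeToInt64(std::size_t n) {
  // Compare in an unsigned type so both operands have the same signedness.
  constexpr std::uint64_t kMax =
      static_cast<std::uint64_t>(std::numeric_limits<std::int64_t>::max());
  if (static_cast<std::uint64_t>(n) > kMax) {
    return std::nullopt;  // the value would overflow int64_t
  }
  return static_cast<std::int64_t>(n);
}
```

A caller would then use the returned value as the dimension instead of casting unconditionally, and report an error when the optional is empty.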







[GitHub] [tvm] mbrookhart commented on pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#issuecomment-806018953


   It looks like you have a Windows build issue?





[GitHub] [tvm] comaniac commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600800776



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -249,36 +248,214 @@ class FullElementwise : public SimplifyPattern {
 };
 
 /*!
- * \brief ExprSimplifier simplifies the Relay expression.
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to
+ * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies
+ * and can enable more opportunities for operator fusion.
  */
-class ExprSimplifier {
+class ConcretizeLikeRewrite : public DFPatternRewrite {
  public:
-  explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    CreateCallback(SimplifyReshape());
-    CreateCallback(SimplifyTranspose());
-    CreateCallback(FullElementwise());
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
   }
-  template <typename T>
-  void CreateCallback(const T& pattern) {
-    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
-      Expr pre = args[0];
-      Expr post = args[1];
-      Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = pattern.callback(pre, post, node_map);
-    };
-    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
+
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type().as<TensorTypeNode>()) {
+      return false;
+    }
+
+    return true;
+  }
+
+  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                          DataType dtype) const = 0;
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    if (!Check(pre, post, node_map)) {
+      return post;
+    }
+
+    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
+    Array<Integer> cshape;
+    for (const auto& dim : like_ty->shape) {
+      if (const auto* imm = dim.as<IntImmNode>()) {
+        cshape.push_back(Integer(GetRef<IntImm>(imm)));
+      } else {
+        // shape is not static, don't concretize
+        return post;
+      }
+    }
+
+    return Concretize(node_map, cshape, like_ty->dtype);
+  }
+
+ protected:
+  DFPattern data_pat_;
+  DFPattern like_pat_;
+};
+
+class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeZeros(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeZerosLikeRewrite);
+};
+
+class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeOnes(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeOnesLikeRewrite);
+};
+
+class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeReshape(node_map[data_pat_][0], shape);
   }
 
-  Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeReshapeLikeRewrite);
+};
+
+class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());

Review comment:
       Why do we need this check?

##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -249,36 +248,214 @@ class FullElementwise : public SimplifyPattern {
 };
 
 /*!
- * \brief ExprSimplifier simplifies the Relay expression.
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to
+ * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies
+ * and can enable more opportunities for operator fusion.
  */
-class ExprSimplifier {
+class ConcretizeLikeRewrite : public DFPatternRewrite {
  public:
-  explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    CreateCallback(SimplifyReshape());
-    CreateCallback(SimplifyTranspose());
-    CreateCallback(FullElementwise());
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
   }
-  template <typename T>
-  void CreateCallback(const T& pattern) {
-    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
-      Expr pre = args[0];
-      Expr post = args[1];
-      Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = pattern.callback(pre, post, node_map);
-    };
-    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
+
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type().as<TensorTypeNode>()) {
+      return false;
+    }
+
+    return true;
+  }
+
+  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                          DataType dtype) const = 0;
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    if (!Check(pre, post, node_map)) {
+      return post;
+    }
+
+    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
+    Array<Integer> cshape;
+    for (const auto& dim : like_ty->shape) {
+      if (const auto* imm = dim.as<IntImmNode>()) {
+        cshape.push_back(Integer(GetRef<IntImm>(imm)));
+      } else {
+        // shape is not static, don't concretize
+        return post;
+      }
+    }
+
+    return Concretize(node_map, cshape, like_ty->dtype);
+  }
+
+ protected:
+  DFPattern data_pat_;
+  DFPattern like_pat_;
+};
+
+class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeZeros(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeZerosLikeRewrite);
+};
+
+class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeOnes(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeOnesLikeRewrite);
+};
+
+class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeReshape(node_map[data_pat_][0], shape);
   }
 
-  Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeReshapeLikeRewrite);
+};
+
+class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());
+    static const Op& op = Op::Get("collapse_sum_to");
+    auto attrs = make_object<InitOpAttrs>();
+    attrs->shape = shape;
+    auto cshape =
+        MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);
+    return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs));
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeCollapseSumLikeRewrite);
+};
+
+class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeBroadCastTo(node_map[data_pat_][0], shape);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeBroadcastToLikeRewrite);
+};
+
+/*! \brief Eliminates expressions that are equivalent to identity. */
+class EliminateIdentityRewrite : public DFPatternRewrite {
+ public:
+  EliminateIdentityRewrite() {
+    x_ = IsWildcard();
+    const_ = IsConstant();
+
+    DFPattern add_op = IsOp("add");
+    DFPattern mul_op = IsOp("multiply");
+    DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_;
+    DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_;
+
+    DFPattern add_id = add_op({x_, zeros_expr}) || add_op({zeros_expr, x_});
+    DFPattern mul_id = mul_op({x_, ones_expr}) || mul_op({ones_expr, x_});
+    DFPattern sub_id = IsOp("subtract")({x_, zeros_expr});
+    DFPattern div_id = IsOp("divide")({x_, ones_expr});
+
+    pattern_ = add_id || mul_id || sub_id || div_id;
+  }
+
+  bool CheckConstant(const OpNode* op, const ConstantNode* constant) const {
+    if (!IsScalar(GetRef<Expr>(constant))) {
+      return false;
+    }
+    long double value = ToScalar(constant->data);
+    if (op->name == "add" || op->name == "subtract") {
+      return value == 0.0;
+    } else if (op->name == "multiply" || op->name == "divide") {
+      return value == 1.0;
+    }
+    return false;
+  }
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    const CallNode* call = pre.as<CallNode>();
+    ICHECK(call);
+    Type pre_type = pre->checked_type_;
+    ICHECK(pre_type.as<TensorTypeNode>());
+    auto x = node_map[x_][0];
+    bool is_left = post.as<CallNode>()->args[1] == x;
+    Type x_type;
+    if (is_left) {
+      x_type = call->args[1]->checked_type_;
+    } else {
+      x_type = call->args[0]->checked_type_;
+    }
+
+    if (node_map.count(const_)) {
+      // the other argument is a Constant in this case
+      const ConstantNode* constant = node_map[const_][0].as<ConstantNode>();
+      const OpNode* op = call->op.as<OpNode>();
+      ICHECK(constant);
+      ICHECK(op);
+      if (!CheckConstant(op, constant)) {
+        return post;
+      }
+    }
+
+    if (StructuralEqual()(x_type, pre_type)) {
+      return x;
+    }
+
+    return post;
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(EliminateIdentityRewrite);
 
  private:
-  IRModule mod_;
-  /*! \brief Callbacks for expr simplification */
-  Array<DFPatternCallback> callbacks_;
+  DFPattern x_;
+  DFPattern const_;
 };
 
 Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
-  return ExprSimplifier(mod).Simplify(expr);
+  static Array<DFPatternCallback> callbacks = {ConcretizeZerosLikeRewrite::GetCallback(),

Review comment:
       You may need to add a comment here if the order is enforced.

##########
File path: src/relay/transforms/simplify_expr.h
##########
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/simplify_expr.h
+ * \brief Utility data structures for simplifying Relay expressions.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+#define TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr.h>
+
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Defines a static function `RewriteType::Get()` that returns a statically initialized
+ * instance of RewriteType. */
+#define TVM_DF_PATTERN_REWRITE_GETTER(RewriteType)                    \
+  static DFPatternRewrite* Get() {                                    \
+    static RewriteType rw;                                            \
+    return &rw;                                                       \
+  }                                                                   \
+  static DFPatternCallback GetCallback() {                            \
+    static DFPatternCallback cb = RewriteType::Get()->MakeCallback(); \
+    return cb;                                                        \
+  }
+
+/*! \brief A wrapper class defining a rewrite matching a specific pattern. */
+class DFPatternRewrite {
+ public:
+  /*! \brief Returns the rewritten expression. */
+  virtual Expr Callback(const Expr& pre, const Expr& post,
+                        const Map<DFPattern, Array<Expr>>& node_map) const = 0;
+
+  /*! \brief Returns the pattern to be used for matching and rewriting. */
+  inline DFPattern Pattern() const { return pattern_; }
+
+  inline bool RequireType() const { return require_type_; }
+
+  inline DFPatternCallback MakeCallback() const {
+    auto func = [this](TVMArgs args, TVMRetValue* rv) {
+      Expr pre = args[0];
+      Expr post = args[1];
+      Map<DFPattern, Array<Expr>> node_map = args[2];
+      *rv = this->Callback(pre, post, node_map);
+    };
+    return DFPatternCallback(pattern_, PackedFunc(func), require_type_);
+  }
+
+ protected:
+  /*! \brief The pattern for matching and rewriting. */
+  DFPattern pattern_;
+  bool require_type_ = true;

Review comment:
       Please add a docstring for this member.







[GitHub] [tvm] tqchen commented on pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#issuecomment-805976301


   cc @comaniac @yzhliu, please help review this PR.





[GitHub] [tvm] mbrookhart commented on pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#issuecomment-811419094


   I'm happy with this. I'll merge this afternoon unless @comaniac objects.





[GitHub] [tvm] altanh commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

Posted by GitBox <gi...@apache.org>.
altanh commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600844682



##########
File path: src/relay/transforms/simplify_expr.h
##########
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/simplify_expr.h
+ * \brief Utility data structures for simplifying Relay expressions.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+#define TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr.h>
+
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Defines a static function `RewriteType::Get()` that returns a statically initialized
+ * instance of RewriteType. */
+#define TVM_DF_PATTERN_REWRITE_GETTER(RewriteType)                    \
+  static DFPatternRewrite* Get() {                                    \
+    static RewriteType rw;                                            \
+    return &rw;                                                       \
+  }                                                                   \
+  static DFPatternCallback GetCallback() {                            \
+    static DFPatternCallback cb = RewriteType::Get()->MakeCallback(); \
+    return cb;                                                        \
+  }

Review comment:
       Indeed, the overhead of initializing and calling is probably negligible compared to running the pass itself; I did it this way following comments on my previous PR. @comaniac, maybe you can comment? In the end, I am fine either way.
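The getter under discussion relies on function-local statics, which C++11 and later initialize exactly once, on first use, in a thread-safe way. One practical consequence beyond performance: `MakeCallback` captures `this` in its lambda, so the rewrite object must outlive every callback built from it, and static storage duration guarantees that. The sketch below reproduces the pattern outside TVM; `Rewrite`, `REWRITE_GETTER`, `ZerosLikeRewrite`, and the construction counter are hypothetical names chosen for illustration, and `std::function` stands in for TVM's `PackedFunc`.

```cpp
#include <cassert>
#include <functional>
#include <string>

// Hypothetical stand-in for DFPatternRewrite: just enough to show the getter pattern.
class Rewrite {
 public:
  virtual ~Rewrite() = default;
  virtual std::string Name() const = 0;

  // Like MakeCallback in the PR, the returned callable captures `this`,
  // so it is only valid while the rewrite object is alive.
  std::function<std::string()> MakeCallback() const {
    return [this] { return this->Name(); };
  }
};

// Mirrors the shape of TVM_DF_PATTERN_REWRITE_GETTER: the function-local static
// is constructed on the first call and the same instance is returned afterwards.
#define REWRITE_GETTER(RewriteType) \
  static Rewrite* Get() {           \
    static RewriteType rw;          \
    return &rw;                     \
  }

int g_constructions = 0;  // counts how many times the rewrite object is built

class ZerosLikeRewrite : public Rewrite {
 public:
  ZerosLikeRewrite() { ++g_constructions; }
  std::string Name() const override { return "zeros_like"; }
  REWRITE_GETTER(ZerosLikeRewrite)
};
```

Calling `ZerosLikeRewrite::Get()` repeatedly returns the same pointer and constructs the object only once; the alternative discussed in this thread (building the rewrites on every `SimplifyExpr` call) trades that for simpler lifetime reasoning at a likely negligible cost.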







[GitHub] [tvm] altanh commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

Posted by GitBox <gi...@apache.org>.
altanh commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600683869



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -22,44 +22,37 @@
  * \brief A pass for simplifying the Relay expression.
  */
 
+#include "simplify_expr.h"
+
 #include <tvm/relay/dataflow_matcher.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/transform.h>
 #include <tvm/runtime/logging.h>
 
+#include <utility>
+
 #include "../op/tensor/transform.h"
 #include "pattern_utils.h"
 
 namespace tvm {
 namespace relay {
 
-class SimplifyPattern {
- public:
-  virtual Expr callback(const Expr& pre, const Expr& post,
-                        const Map<DFPattern, Array<Expr>>& node_map) const = 0;
-
-  DFPattern pattern() const { return pattern_; }
-
- protected:
-  /*! \brief Pattern for rewriting */
-  DFPattern pattern_;
-};
-
 /*!
  * \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
  *   and merges into one reshape op.
  */
-class SimplifyReshape : public SimplifyPattern {
+class SimplifyReshape : public DFPatternRewrite {
  public:
   SimplifyReshape() {
     x_ = IsWildcard();
     auto reshape1 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
     auto reshape2 = IsOp("reshape") || IsOp("contrib_reverse_reshape");
     pattern_ = reshape1({reshape2({x_})});
+    require_type_ = true;

Review comment:
       good point, thanks







[GitHub] [tvm] altanh commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

Posted by GitBox <gi...@apache.org>.
altanh commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600826676



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -249,36 +248,214 @@ class FullElementwise : public SimplifyPattern {
 };
 
 /*!
- * \brief ExprSimplifier simplifies the Relay expression.
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to
+ * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies
+ * and can enable more opportunities for operator fusion.
  */
-class ExprSimplifier {
+class ConcretizeLikeRewrite : public DFPatternRewrite {
  public:
-  explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    CreateCallback(SimplifyReshape());
-    CreateCallback(SimplifyTranspose());
-    CreateCallback(FullElementwise());
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
   }
-  template <typename T>
-  void CreateCallback(const T& pattern) {
-    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
-      Expr pre = args[0];
-      Expr post = args[1];
-      Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = pattern.callback(pre, post, node_map);
-    };
-    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
+
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type().as<TensorTypeNode>()) {
+      return false;
+    }
+
+    return true;
+  }
+
+  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                          DataType dtype) const = 0;
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    if (!Check(pre, post, node_map)) {
+      return post;
+    }
+
+    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
+    Array<Integer> cshape;
+    for (const auto& dim : like_ty->shape) {
+      if (const auto* imm = dim.as<IntImmNode>()) {
+        cshape.push_back(Integer(GetRef<IntImm>(imm)));
+      } else {
+        // shape is not static, don't concretize
+        return post;
+      }
+    }
+
+    return Concretize(node_map, cshape, like_ty->dtype);
+  }
+
+ protected:
+  DFPattern data_pat_;
+  DFPattern like_pat_;
+};
+
+class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeZeros(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeZerosLikeRewrite);
+};
+
+class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeOnes(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeOnesLikeRewrite);
+};
+
+class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeReshape(node_map[data_pat_][0], shape);
   }
 
-  Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeReshapeLikeRewrite);
+};
+
+class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());
+    static const Op& op = Op::Get("collapse_sum_to");
+    auto attrs = make_object<InitOpAttrs>();
+    attrs->shape = shape;
+    auto cshape =
+        MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);
+    return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs));
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeCollapseSumLikeRewrite);
+};
+
+class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeBroadCastTo(node_map[data_pat_][0], shape);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeBroadcastToLikeRewrite);
+};
+
+/*! \brief Eliminates expressions that are equivalent to identity. */
+class EliminateIdentityRewrite : public DFPatternRewrite {
+ public:
+  EliminateIdentityRewrite() {
+    x_ = IsWildcard();
+    const_ = IsConstant();
+
+    DFPattern add_op = IsOp("add");
+    DFPattern mul_op = IsOp("multiply");
+    DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_;
+    DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_;
+
+    DFPattern add_id = add_op({x_, zeros_expr}) || add_op({zeros_expr, x_});
+    DFPattern mul_id = mul_op({x_, ones_expr}) || mul_op({ones_expr, x_});
+    DFPattern sub_id = IsOp("subtract")({x_, zeros_expr});
+    DFPattern div_id = IsOp("divide")({x_, ones_expr});
+
+    pattern_ = add_id || mul_id || sub_id || div_id;
+  }
+
+  bool CheckConstant(const OpNode* op, const ConstantNode* constant) const {
+    if (!IsScalar(GetRef<Expr>(constant))) {
+      return false;
+    }
+    long double value = ToScalar(constant->data);
+    if (op->name == "add" || op->name == "subtract") {
+      return value == 0.0;
+    } else if (op->name == "multiply" || op->name == "divide") {
+      return value == 1.0;
+    }
+    return false;
+  }
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    const CallNode* call = pre.as<CallNode>();
+    ICHECK(call);
+    Type pre_type = pre->checked_type_;
+    ICHECK(pre_type.as<TensorTypeNode>());
+    auto x = node_map[x_][0];
+    bool is_left = post.as<CallNode>()->args[1] == x;
+    Type x_type;
+    if (is_left) {
+      x_type = call->args[1]->checked_type_;
+    } else {
+      x_type = call->args[0]->checked_type_;
+    }
+
+    if (node_map.count(const_)) {
+      // the other argument is a Constant in this case
+      const ConstantNode* constant = node_map[const_][0].as<ConstantNode>();
+      const OpNode* op = call->op.as<OpNode>();
+      ICHECK(constant);
+      ICHECK(op);
+      if (!CheckConstant(op, constant)) {
+        return post;
+      }
+    }
+
+    if (StructuralEqual()(x_type, pre_type)) {
+      return x;
+    }
+
+    return post;
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(EliminateIdentityRewrite);
 
  private:
-  IRModule mod_;
-  /*! \brief Callbacks for expr simplification */
-  Array<DFPatternCallback> callbacks_;
+  DFPattern x_;
+  DFPattern const_;
 };
 
 Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
-  return ExprSimplifier(mod).Simplify(expr);
+  static Array<DFPatternCallback> callbacks = {ConcretizeZerosLikeRewrite::GetCallback(),

Review comment:
       (in any case it shouldn't matter now that I've added support for eliminating constants)
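The constant check mentioned here comes down to comparing a scalar constant with the operator's identity element. This is what `CheckConstant` in the diff does after calling `ToScalar`. Below is a minimal standalone sketch of that comparison. It assumes the value has already been extracted as a `long double`, and `IsIdentityScalar` is a hypothetical name. The sketch also makes explicit something the diff handles through its patterns: for `subtract` and `divide`, a constant is only an identity when it sits on the right.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: is `value` the identity element of the binary op
// `op_name`, given which side the constant is on? Only the scalar check is
// modeled; the real rewrite must also confirm x's type equals the result type
// (i.e. no broadcasting) before returning x.
bool IsIdentityScalar(const std::string& op_name, long double value, bool const_on_right) {
  if (op_name == "add") return value == 0.0L;       // x + 0 and 0 + x reduce to x
  if (op_name == "multiply") return value == 1.0L;  // x * 1 and 1 * x reduce to x
  // Non-commutative ops: only x - 0 and x / 1 are identities (0 - x, 1 / x are not).
  if (op_name == "subtract") return const_on_right && value == 0.0L;
  if (op_name == "divide") return const_on_right && value == 1.0L;
  return false;  // any other op is never treated as an identity
}

// Usage checks.
void IsIdentityScalarExample() {
  assert(IsIdentityScalar("add", 0.0L, /*const_on_right=*/false));
  assert(IsIdentityScalar("divide", 1.0L, /*const_on_right=*/true));
  assert(!IsIdentityScalar("divide", 1.0L, /*const_on_right=*/false));
  assert(!IsIdentityScalar("multiply", 0.0L, /*const_on_right=*/true));
}
```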







[GitHub] [tvm] altanh commented on pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

Posted by GitBox <gi...@apache.org>.
altanh commented on pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#issuecomment-806077018


   @comaniac @mbrookhart I've merged them and updated the unit tests





[GitHub] [tvm] altanh commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

Posted by GitBox <gi...@apache.org>.
altanh commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600785922



##########
File path: src/relay/transforms/concretize_like.cc
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file concretize_like.cc
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to
+ * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies
+ * and can enable more opportunities for operator fusion.
+ */
+
+#include <tvm/relay/transform.h>
+
+#include "pattern_utils.h"
+#include "simplify_expr.h"
+
+namespace tvm {
+namespace relay {
+
+class ConcretizeLikeRewrite : public DFPatternRewrite {
+ public:
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
+    require_type_ = true;
+  }
+
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type_.defined()) {
+      // TODO(@altanh): maybe because of the input being rewritten?

Review comment:
       I ended up removing this. I think the checked type should always be defined for the `pre` node when `require_type` is true. Previously I was getting the type from a different node, which was the wrong approach; now this works.
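Once the type of `pre` is available, concretization only fires if every dimension of that type is a static integer. A single symbolic dimension makes the rewrite return `post` unchanged. Here is a standalone sketch of that all-or-nothing check. A hypothetical `Dim` alias (`std::optional<int64_t>`, empty for a dynamic dimension) stands in for TVM's `PrimExpr`/`IntImmNode`, and `ConcreteShape` is an illustrative name.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for one shape dimension: holds the extent for a static
// dimension (like an IntImmNode) and is empty for a symbolic/dynamic one.
using Dim = std::optional<std::int64_t>;

// All-or-nothing, mirroring the loop in ConcretizeLikeRewrite::Callback:
// return every extent, or std::nullopt as soon as one dimension is dynamic.
std::optional<std::vector<std::int64_t>> ConcreteShape(const std::vector<Dim>& shape) {
  std::vector<std::int64_t> out;
  out.reserve(shape.size());
  for (const Dim& d : shape) {
    if (!d.has_value()) {
      return std::nullopt;  // shape is not static: don't concretize
    }
    out.push_back(*d);
  }
  return out;
}

// Usage: a fully static shape concretizes, a partially dynamic one does not.
void ConcreteShapeExample() {
  std::optional<std::vector<std::int64_t>> s = ConcreteShape({Dim(2), Dim(3)});
  assert(s.has_value() && s->size() == 2 && (*s)[1] == 3);
  assert(!ConcreteShape({Dim(2), std::nullopt}).has_value());
}
```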







[GitHub] [tvm] altanh commented on pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

Posted by GitBox <gi...@apache.org>.
altanh commented on pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#issuecomment-806004438


   > Overall LGTM. Meanwhile I have some questions about the design:
   > 
   >     1. It seems to me that ConcretizeLike and EliminateIdentity could also be merged into SimplifyExpr, in terms of both implementation and semantics. What's the concern with having 3 separate passes?
   > 
   >     2. You mentioned that, for a certain reason, EliminateIdentity should be run before SimplifyExpr, but I don't see what would happen if we ran them in the reverse order. Could you elaborate a bit further?
   
   1. It is definitely possible. I separated them mainly so that each could be tested on its own; otherwise, test cases for the combined pass might be a bit tricky to write (e.g. we would need to adjust the cases where we add 0 or multiply by 1). I can add test cases that run all of them in sequence (as if they were a single pass), or just merge them into SimplifyExpr and update the test cases. Let me know.
   2. SimplifyExpr has a rewrite called `FullElementwise` that rewrites, for example, `x + zeros_like(x)` to `x + const(0)`. I couldn't think of a dtype-portable way to rewrite `x + const(0)` to `x` in `EliminateIdentity`, so it won't reduce that expression. This is why `EliminateIdentity` should run first: it removes `x + zeros_like(x)` before `FullElementwise` turns it into a form it can no longer recognize. That said, if there is a good way to examine constant values for any dtype (e.g. by casting), we could eliminate this case too.
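The ordering problem can be illustrated with a toy model. The `Node` type, `Make`, and both rewrite functions below are simplified, hypothetical stand-ins, not TVM's IR or its actual rewrites, and type checks are omitted. `FullElementwise` replaces `zeros_like(x)` with a constant. The identity eliminator recognizes `zeros_like` but, like the version described here, never looks at constant values.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Toy IR node (hypothetical, not TVM's Expr): kind is "var", "const",
// "zeros_like" (uses lhs only), or "add" (uses lhs and rhs).
struct Node {
  std::string kind;
  double value = 0.0;  // meaningful only for "const"
  std::shared_ptr<Node> lhs, rhs;
};
using NodePtr = std::shared_ptr<Node>;

NodePtr Make(std::string kind, NodePtr lhs = nullptr, NodePtr rhs = nullptr, double value = 0.0) {
  auto n = std::make_shared<Node>();
  n->kind = std::move(kind);
  n->lhs = std::move(lhs);
  n->rhs = std::move(rhs);
  n->value = value;
  return n;
}

// FullElementwise-like rewrite: add(a, zeros_like(b)) -> add(a, const(0)).
NodePtr FullElementwise(const NodePtr& e) {
  if (e->kind == "add" && e->rhs->kind == "zeros_like") {
    return Make("add", e->lhs, Make("const"));
  }
  return e;
}

// Identity elimination that matches zeros_like but ignores constant values:
// add(a, zeros_like(b)) -> a.
NodePtr EliminateIdentity(const NodePtr& e) {
  if (e->kind == "add" && e->rhs->kind == "zeros_like") {
    return e->lhs;
  }
  return e;
}

void OrderExample() {
  NodePtr x = Make("var");
  NodePtr e = Make("add", x, Make("zeros_like", x));  // x + zeros_like(x)

  // EliminateIdentity first: reduces to x, and FullElementwise has nothing left to do.
  assert(FullElementwise(EliminateIdentity(e)) == x);

  // FullElementwise first: x + const(0), which EliminateIdentity cannot reduce.
  NodePtr r = EliminateIdentity(FullElementwise(e));
  assert(r->kind == "add" && r->rhs->kind == "const");
}
```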





[GitHub] [tvm] altanh commented on pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

Posted by GitBox <gi...@apache.org>.
altanh commented on pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#issuecomment-807105375


   > @altanh I'm confused why you need this change to ToScalar. What changed elsewhere in your PR that broke this unit test?
   
   Here's the offending test: https://ci.tlcpack.ai/blue/organizations/jenkins/tvm/detail/PR-7731/7/pipeline
   
   The alternative I see to changing `ToScalar` is to mirror its code exactly but return a bool indicating whether the conversion is possible. That seems less maintainable, but perhaps better for API stability.
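One way to get a non-failing conversion without duplicating the dtype dispatch is to have a single function return `std::optional`, which is empty when no conversion exists, and let a throwing or aborting variant wrap it. Below is a minimal sketch over a hypothetical tagged payload. `DTypeCode`, `ScalarBytes`, `TryToScalar`, `ToScalarOrThrow`, and `Pack` are illustrative names, not TVM's API, and only a few dtypes are handled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <optional>
#include <vector>

// Hypothetical dtype tags; a real version would dispatch on DLDataType.
enum class DTypeCode { kInt, kUInt, kFloat };

// Hypothetical scalar payload: dtype tag, bit width, and raw bytes.
struct ScalarBytes {
  DTypeCode code;
  int bits;
  std::vector<unsigned char> data;
};

// Reinterpret the leading bytes as a T and widen to long double.
template <typename T>
long double Load(const std::vector<unsigned char>& data) {
  T v;
  std::memcpy(&v, data.data(), sizeof(T));
  return static_cast<long double>(v);
}

// Non-failing conversion: std::nullopt means "no conversion for this dtype",
// so a caller such as an identity check can simply skip the rewrite.
std::optional<long double> TryToScalar(const ScalarBytes& s) {
  if (s.bits <= 0 || s.data.size() * 8 < static_cast<std::size_t>(s.bits)) {
    return std::nullopt;  // malformed payload
  }
  switch (s.code) {
    case DTypeCode::kInt:
      if (s.bits == 8) return Load<std::int8_t>(s.data);
      if (s.bits == 32) return Load<std::int32_t>(s.data);
      if (s.bits == 64) return Load<std::int64_t>(s.data);
      break;
    case DTypeCode::kUInt:
      if (s.bits == 8) return Load<std::uint8_t>(s.data);
      break;
    case DTypeCode::kFloat:
      if (s.bits == 32) return Load<float>(s.data);
      if (s.bits == 64) return Load<double>(s.data);
      break;
  }
  return std::nullopt;  // e.g. float16 or custom dtypes are not handled here
}

// A strict variant wraps it (throws std::bad_optional_access if unsupported),
// keeping the dtype dispatch in one place.
long double ToScalarOrThrow(const ScalarBytes& s) { return TryToScalar(s).value(); }

// Test helper: build a payload from a C++ value.
template <typename T>
ScalarBytes Pack(DTypeCode code, T value) {
  ScalarBytes s{code, static_cast<int>(sizeof(T) * 8), std::vector<unsigned char>(sizeof(T))};
  std::memcpy(s.data.data(), &value, sizeof(T));
  return s;
}

void TryToScalarExample() {
  assert(TryToScalar(Pack(DTypeCode::kFloat, 1.0f)) == 1.0L);
  assert(TryToScalar(Pack(DTypeCode::kInt, std::int32_t{0})) == 0.0L);
  // A 16-bit float payload is not supported by this sketch.
  assert(!TryToScalar(Pack(DTypeCode::kFloat, std::uint16_t{0x3C00})).has_value());
}
```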





[GitHub] [tvm] comaniac commented on pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

Posted by GitBox <gi...@apache.org>.
comaniac commented on pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#issuecomment-806008203


   If the purpose is just testing, then I'd prefer to have them in a single pass. You can still test the patterns one by one, as SimplifyExpr does now. Since unrelated patterns won't be matched, I don't see a problem with testing. This way you can also control the order in which patterns are rewritten, i.e., always run `EliminateIdentity` before `FullElementwise` in the SimplifyExpr pass. This can also reduce possible confusion for users.
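The following is a simplified model of a single combined pass driven by an ordered list of rewrites. It is not TVM's `RewritePatterns`: `Rewrite` and `RunOrdered` are hypothetical names, and expressions are plain strings. Each round applies the rewrites in list order, and rounds repeat until nothing changes. Under this model, placing `EliminateIdentity` ahead of `FullElementwise` in the list is enough to get the desired result.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// A rewrite either returns a new expression or declines with std::nullopt.
// Expressions are plain strings here purely for illustration.
using Rewrite = std::function<std::optional<std::string>(const std::string&)>;

// Apply every rewrite once per round, in list order, and repeat rounds until
// nothing changes (or the round budget runs out). Within a round, an earlier
// rewrite sees the expression first, so list order acts as priority order.
std::string RunOrdered(const std::vector<Rewrite>& rewrites, std::string expr,
                       int max_rounds = 100) {
  for (int iter = 0; iter < max_rounds; ++iter) {
    bool changed = false;
    for (const Rewrite& rewrite : rewrites) {
      std::optional<std::string> out = rewrite(expr);
      if (out && *out != expr) {
        expr = *out;
        changed = true;
      }
    }
    if (!changed) break;
  }
  return expr;
}

void OrderedExample() {
  Rewrite eliminate_identity = [](const std::string& e) -> std::optional<std::string> {
    if (e == "x + zeros_like(x)") return std::string("x");
    return std::nullopt;
  };
  Rewrite full_elementwise = [](const std::string& e) -> std::optional<std::string> {
    if (e == "x + zeros_like(x)") return std::string("x + const(0)");
    return std::nullopt;
  };
  // EliminateIdentity listed first: the identity is removed.
  assert(RunOrdered({eliminate_identity, full_elementwise}, "x + zeros_like(x)") == "x");
  // Reversed order: FullElementwise fires first and the identity survives.
  assert(RunOrdered({full_elementwise, eliminate_identity}, "x + zeros_like(x)") ==
         "x + const(0)");
}
```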
    





[GitHub] [tvm] mbrookhart commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600863750



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -249,36 +249,214 @@ class FullElementwise : public SimplifyPattern {
 };
 
 /*!
- * \brief ExprSimplifier simplifies the Relay expression.
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to
+ * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies
+ * and can enable more opportunities for operator fusion.
  */
-class ExprSimplifier {
+class ConcretizeLikeRewrite : public DFPatternRewrite {
  public:
-  explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    CreateCallback(SimplifyReshape());
-    CreateCallback(SimplifyTranspose());
-    CreateCallback(FullElementwise());
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
   }
-  template <typename T>
-  void CreateCallback(const T& pattern) {
-    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
-      Expr pre = args[0];
-      Expr post = args[1];
-      Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = pattern.callback(pre, post, node_map);
-    };
-    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
+
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type().as<TensorTypeNode>()) {
+      return false;
+    }
+
+    return true;
+  }
+
+  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                          DataType dtype) const = 0;
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    if (!Check(pre, post, node_map)) {
+      return post;
+    }
+
+    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
+    Array<Integer> cshape;
+    for (const auto& dim : like_ty->shape) {
+      if (const auto* imm = dim.as<IntImmNode>()) {
+        cshape.push_back(Integer(GetRef<IntImm>(imm)));
+      } else {
+        // shape is not static, don't concretize
+        return post;
+      }
+    }
+
+    return Concretize(node_map, cshape, like_ty->dtype);
+  }
+
+ protected:
+  DFPattern data_pat_;
+  DFPattern like_pat_;
+};
+
+class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeZeros(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeZerosLikeRewrite);
+};
+
+class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeOnes(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeOnesLikeRewrite);
+};
+
+class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeReshape(node_map[data_pat_][0], shape);
   }
 
-  Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeReshapeLikeRewrite);
+};
+
+class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());
+    static const Op& op = Op::Get("collapse_sum_to");
+    auto attrs = make_object<InitOpAttrs>();
+    attrs->shape = shape;
+    auto cshape =
+        MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);
+    return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs));
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeCollapseSumLikeRewrite);
+};
+
+class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeBroadCastTo(node_map[data_pat_][0], shape);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeBroadcastToLikeRewrite);
+};
+
+/*! \brief Eliminates expressions that are equivalent to identity. */
+class EliminateIdentityRewrite : public DFPatternRewrite {
+ public:
+  EliminateIdentityRewrite() {
+    x_ = IsWildcard();
+    const_ = IsConstant();
+
+    DFPattern add_op = IsOp("add");
+    DFPattern mul_op = IsOp("multiply");
+    DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_;
+    DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_;
+
+    DFPattern add_id = add_op({x_, zeros_expr}) || add_op({zeros_expr, x_});
+    DFPattern mul_id = mul_op({x_, ones_expr}) || mul_op({ones_expr, x_});

Review comment:
       You don't have to do the full thing for this PR; you can keep this as it is and we can extend it later.







[GitHub] [tvm] altanh commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

Posted by GitBox <gi...@apache.org>.
altanh commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600865480



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -249,36 +249,214 @@ class FullElementwise : public SimplifyPattern {
 };
 
+/*! \brief Eliminates expressions that are equivalent to identity. */
+class EliminateIdentityRewrite : public DFPatternRewrite {
+ public:
+  EliminateIdentityRewrite() {
+    x_ = IsWildcard();
+    const_ = IsConstant();
+
+    DFPattern add_op = IsOp("add");
+    DFPattern mul_op = IsOp("multiply");
+    DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_;
+    DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_;
+
+    DFPattern add_id = add_op({x_, zeros_expr}) || add_op({zeros_expr, x_});
+    DFPattern mul_id = mul_op({x_, ones_expr}) || mul_op({ones_expr, x_});

Review comment:
       yep sounds good







[GitHub] [tvm] mbrookhart commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600863124



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -249,36 +249,214 @@ class FullElementwise : public SimplifyPattern {
 };
 
+/*! \brief Eliminates expressions that are equivalent to identity. */
+class EliminateIdentityRewrite : public DFPatternRewrite {
+ public:
+  EliminateIdentityRewrite() {
+    x_ = IsWildcard();
+    const_ = IsConstant();
+
+    DFPattern add_op = IsOp("add");
+    DFPattern mul_op = IsOp("multiply");
+    DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_;
+    DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_;
+
+    DFPattern add_id = add_op({x_, zeros_expr}) || add_op({zeros_expr, x_});
+    DFPattern mul_id = mul_op({x_, ones_expr}) || mul_op({ones_expr, x_});

Review comment:
       Yep, you do, sorry!





[GitHub] [tvm] mbrookhart commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600836110



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -249,36 +249,214 @@ class FullElementwise : public SimplifyPattern {
 };
 
 /*!
- * \brief ExprSimplifier simplifies the Relay expression.
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to
+ * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies
+ * and can enable more opportunities for operator fusion.
  */
-class ExprSimplifier {
+class ConcretizeLikeRewrite : public DFPatternRewrite {
  public:
-  explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    CreateCallback(SimplifyReshape());
-    CreateCallback(SimplifyTranspose());
-    CreateCallback(FullElementwise());
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
   }
-  template <typename T>
-  void CreateCallback(const T& pattern) {
-    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
-      Expr pre = args[0];
-      Expr post = args[1];
-      Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = pattern.callback(pre, post, node_map);
-    };
-    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
+
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type().as<TensorTypeNode>()) {
+      return false;
+    }
+
+    return true;
+  }
+
+  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                          DataType dtype) const = 0;
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    if (!Check(pre, post, node_map)) {
+      return post;
+    }
+
+    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
+    Array<Integer> cshape;
+    for (const auto& dim : like_ty->shape) {
+      if (const auto* imm = dim.as<IntImmNode>()) {
+        cshape.push_back(Integer(GetRef<IntImm>(imm)));
+      } else {
+        // shape is not static, don't concretize
+        return post;
+      }
+    }
+
+    return Concretize(node_map, cshape, like_ty->dtype);
+  }
+
+ protected:
+  DFPattern data_pat_;
+  DFPattern like_pat_;
+};
+
+class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeZeros(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeZerosLikeRewrite);
+};
+
+class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeOnes(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeOnesLikeRewrite);
+};
+
+class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeReshape(node_map[data_pat_][0], shape);
   }
 
-  Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeReshapeLikeRewrite);
+};
+
+class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());
+    static const Op& op = Op::Get("collapse_sum_to");
+    auto attrs = make_object<InitOpAttrs>();
+    attrs->shape = shape;
+    auto cshape =
+        MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);
+    return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs));
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeCollapseSumLikeRewrite);
+};
+
+class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeBroadCastTo(node_map[data_pat_][0], shape);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeBroadcastToLikeRewrite);
+};
+
+/*! \brief Eliminates expressions that are equivalent to identity. */
+class EliminateIdentityRewrite : public DFPatternRewrite {
+ public:
+  EliminateIdentityRewrite() {
+    x_ = IsWildcard();
+    const_ = IsConstant();
+
+    DFPattern add_op = IsOp("add");
+    DFPattern mul_op = IsOp("multiply");
+    DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_;
+    DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_;
+
+    DFPattern add_id = add_op({x_, zeros_expr}) || add_op({zeros_expr, x_});
+    DFPattern mul_id = mul_op({x_, ones_expr}) || mul_op({ones_expr, x_});

Review comment:
       Can we also simplify `x * 0 = 0` and `x + 0 = x`?
   
   The pattern matcher should be able to match these ops irrespective of order:
   https://github.com/apache/tvm/blob/cfe2e288a331b10e72e10c7e465df375b44e6ae9/src/relay/ir/dataflow_matcher.cc#L275-L281
   
   You can probably get away without the AltPattern here?
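For illustration only, here is a toy sketch in plain C++ of matching a commutative op in either operand order, including the suggested `x * 0 -> 0` rule. It does not use the Relay pattern API: `Expr`, `Var`, `Const`, `Bin`, `IsConst`, and `SimplifyIdentity` are hypothetical names, and the explicit operand swap stands in for what the linked matcher code is said to do automatically. One caveat about `x * 0 -> 0`: it is not exact under IEEE-754 floating point, because NaN * 0 and infinity * 0 both produce NaN, so a real rewrite may need to restrict that rule (for example to integer dtypes).

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Toy expression node (hypothetical, not Relay's Expr): a variable, a numeric
// constant, or a binary "add"/"multiply" call.
struct Expr {
  std::string op;      // "var", "const", "add", or "multiply"
  std::string name;    // set when op == "var"
  double value = 0.0;  // set when op == "const"
  std::shared_ptr<Expr> lhs, rhs;
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr Var(const std::string& name) {
  auto e = std::make_shared<Expr>();
  e->op = "var";
  e->name = name;
  return e;
}

ExprPtr Const(double value) {
  auto e = std::make_shared<Expr>();
  e->op = "const";
  e->value = value;
  return e;
}

ExprPtr Bin(const std::string& op, ExprPtr lhs, ExprPtr rhs) {
  auto e = std::make_shared<Expr>();
  e->op = op;
  e->lhs = std::move(lhs);
  e->rhs = std::move(rhs);
  return e;
}

bool IsConst(const ExprPtr& e, double v) { return e->op == "const" && e->value == v; }

// add and multiply are commutative, so try both operand orders instead of
// spelling out `op(x, c) || op(c, x)` as two separate patterns.
ExprPtr SimplifyIdentity(const ExprPtr& e) {
  assert(e != nullptr);
  if (e->op != "add" && e->op != "multiply") return e;
  for (int swap = 0; swap < 2; ++swap) {
    const ExprPtr& x = swap ? e->rhs : e->lhs;
    const ExprPtr& c = swap ? e->lhs : e->rhs;
    if (e->op == "add" && IsConst(c, 0.0)) return x;       // x + 0 -> x
    if (e->op == "multiply" && IsConst(c, 1.0)) return x;  // x * 1 -> x
    if (e->op == "multiply" && IsConst(c, 0.0)) return c;  // x * 0 -> 0 (not IEEE-safe)
  }
  return e;
}
```

In the real pass the result type also has to match, which is what the `StructuralEqual` check in the diff guards against when a constant would broadcast `x` to a larger shape.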

##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -249,36 +248,214 @@ class FullElementwise : public SimplifyPattern {
 };
 
 /*!
- * \brief ExprSimplifier simplifies the Relay expression.
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to
+ * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies
+ * and can enable more opportunities for operator fusion.
  */
-class ExprSimplifier {
+class ConcretizeLikeRewrite : public DFPatternRewrite {
  public:
-  explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    CreateCallback(SimplifyReshape());
-    CreateCallback(SimplifyTranspose());
-    CreateCallback(FullElementwise());
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
   }
-  template <typename T>
-  void CreateCallback(const T& pattern) {
-    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
-      Expr pre = args[0];
-      Expr post = args[1];
-      Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = pattern.callback(pre, post, node_map);
-    };
-    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
+
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type().as<TensorTypeNode>()) {
+      return false;
+    }
+
+    return true;
+  }
+
+  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                          DataType dtype) const = 0;
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    if (!Check(pre, post, node_map)) {
+      return post;
+    }
+
+    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
+    Array<Integer> cshape;
+    for (const auto& dim : like_ty->shape) {
+      if (const auto* imm = dim.as<IntImmNode>()) {
+        cshape.push_back(Integer(GetRef<IntImm>(imm)));
+      } else {
+        // shape is not static, don't concretize
+        return post;
+      }
+    }
+
+    return Concretize(node_map, cshape, like_ty->dtype);
+  }
+
+ protected:
+  DFPattern data_pat_;
+  DFPattern like_pat_;
+};
+
+class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeZeros(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeZerosLikeRewrite);
+};
+
+class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeOnes(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeOnesLikeRewrite);
+};
+
+class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeReshape(node_map[data_pat_][0], shape);
   }
 
-  Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeReshapeLikeRewrite);
+};
+
+class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());
+    static const Op& op = Op::Get("collapse_sum_to");
+    auto attrs = make_object<InitOpAttrs>();
+    attrs->shape = shape;
+    auto cshape =
+        MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);
+    return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs));
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeCollapseSumLikeRewrite);
+};
+
+class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeBroadCastTo(node_map[data_pat_][0], shape);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeBroadcastToLikeRewrite);
+};
+
+/*! \brief Eliminates expressions that are equivalent to identity. */
+class EliminateIdentityRewrite : public DFPatternRewrite {
+ public:
+  EliminateIdentityRewrite() {
+    x_ = IsWildcard();
+    const_ = IsConstant();
+
+    DFPattern add_op = IsOp("add");
+    DFPattern mul_op = IsOp("multiply");
+    DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_;
+    DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_;
+
+    DFPattern add_id = add_op({x_, zeros_expr}) || add_op({zeros_expr, x_});
+    DFPattern mul_id = mul_op({x_, ones_expr}) || mul_op({ones_expr, x_});
+    DFPattern sub_id = IsOp("subtract")({x_, zeros_expr});
+    DFPattern div_id = IsOp("divide")({x_, ones_expr});
+
+    pattern_ = add_id || mul_id || sub_id || div_id;
+  }
+
+  bool CheckConstant(const OpNode* op, const ConstantNode* constant) const {
+    if (!IsScalar(GetRef<Expr>(constant))) {
+      return false;
+    }
+    long double value = ToScalar(constant->data);
+    if (op->name == "add" || op->name == "subtract") {
+      return value == 0.0;
+    } else if (op->name == "multiply" || op->name == "divide") {
+      return value == 1.0;
+    }
+    return false;
+  }
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    const CallNode* call = pre.as<CallNode>();
+    ICHECK(call);
+    Type pre_type = pre->checked_type_;
+    ICHECK(pre_type.as<TensorTypeNode>());
+    auto x = node_map[x_][0];
+    bool is_left = post.as<CallNode>()->args[1] == x;
+    Type x_type;
+    if (is_left) {
+      x_type = call->args[1]->checked_type_;
+    } else {
+      x_type = call->args[0]->checked_type_;
+    }
+
+    if (node_map.count(const_)) {
+      // the other argument is a Constant in this case
+      const ConstantNode* constant = node_map[const_][0].as<ConstantNode>();
+      const OpNode* op = call->op.as<OpNode>();
+      ICHECK(constant);
+      ICHECK(op);
+      if (!CheckConstant(op, constant)) {
+        return post;
+      }
+    }
+
+    if (StructuralEqual()(x_type, pre_type)) {
+      return x;
+    }
+
+    return post;
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(EliminateIdentityRewrite);
 
  private:
-  IRModule mod_;
-  /*! \brief Callbacks for expr simplification */
-  Array<DFPatternCallback> callbacks_;
+  DFPattern x_;
+  DFPattern const_;
 };
 
 Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
-  return ExprSimplifier(mod).Simplify(expr);
+  static Array<DFPatternCallback> callbacks = {ConcretizeZerosLikeRewrite::GetCallback(),

Review comment:
      Correct, they will do as much as they can to the graph in this order, and then loop back and try again. That being said, that's the same thing you'd get if you ran them as separate passes: each pass rewrites everything it can before the next one runs, but a later pass might open up new opportunities for an earlier pass if you ran it again.
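A toy model of the behavior described above, assuming the rewriter works as this comment says: within one sweep the rules run in list order, and sweeps repeat until nothing changes. This is plain string rewriting, not the `RewritePatterns` API; `Rule`, `ReplaceAll`, and `RewriteToFixedPoint` are hypothetical names. The example shows a later rule (concretizing `ones_like`) enabling an earlier one (`x*1 -> x`) on the next sweep.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// A toy rewrite rule (hypothetical): edits `s` in place, returns true if it fired.
using Rule = std::function<bool(std::string&)>;

// Rule that replaces every occurrence of `from` with `to`.
Rule ReplaceAll(const std::string& from, const std::string& to) {
  assert(!from.empty());
  return [from, to](std::string& s) {
    bool fired = false;
    for (auto pos = s.find(from); pos != std::string::npos; pos = s.find(from, pos + to.size())) {
      s.replace(pos, from.size(), to);
      fired = true;
    }
    return fired;
  };
}

// Each sweep applies every rule, in list order, to the whole input; sweeps
// repeat until a full sweep changes nothing (a fixed point) or max_sweeps is hit.
std::string RewriteToFixedPoint(std::string s, const std::vector<Rule>& rules,
                                int max_sweeps = 100) {
  for (int sweep = 0; sweep < max_sweeps; ++sweep) {
    bool changed = false;
    for (const Rule& rule : rules) {
      changed = rule(s) || changed;  // call rule first so it always runs
    }
    if (!changed) break;
  }
  return s;
}
```

With `rules = {ReplaceAll("x*1", "x"), ReplaceAll("ones_like(y)", "1")}`, a single sweep turns `x*ones_like(y)` into `x*1`, and only the second sweep reaches `x`; this is why the list order is respected locally but not enforced globally.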

##########
File path: src/relay/transforms/simplify_expr.h
##########
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/simplify_expr.h
+ * \brief Utility data structures for simplifying Relay expressions.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+#define TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr.h>
+
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Defines a static function `RewriteType::Get()` that returns a statically initialized
+ * instance of RewriteType. */
+#define TVM_DF_PATTERN_REWRITE_GETTER(RewriteType)                    \
+  static DFPatternRewrite* Get() {                                    \
+    static RewriteType rw;                                            \
+    return &rw;                                                       \
+  }                                                                   \
+  static DFPatternCallback GetCallback() {                            \
+    static DFPatternCallback cb = RewriteType::Get()->MakeCallback(); \
+    return cb;                                                        \
+  }

Review comment:
       1) I don't like Macros
   2) I don't like static initialization
   
   Why not just initialize the object and call the method?
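A minimal sketch of that alternative, assuming the rewrite objects are cheap to construct: each rewrite is an ordinary class, and the pass builds them at the call site instead of going through a `TVM_DF_PATTERN_REWRITE_GETTER`-style macro with function-local statics. The classes are hypothetical stand-ins that operate on strings, not the actual `DFPatternRewrite` interface, and the toy assumes `y` has the static shape `[2, 3]`.

```cpp
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for a pattern rewrite; strings play the role of
// Relay expressions.
class Rewrite {
 public:
  virtual ~Rewrite() = default;
  virtual std::string Callback(const std::string& expr) const = 0;
};

// Toy concretizations, assuming `y` is known to have static shape [2, 3].
class ConcretizeZerosLike : public Rewrite {
 public:
  std::string Callback(const std::string& expr) const override {
    if (expr == "zeros_like(y)") return "zeros([2, 3])";
    return expr;
  }
};

class ConcretizeOnesLike : public Rewrite {
 public:
  std::string Callback(const std::string& expr) const override {
    if (expr == "ones_like(y)") return "ones([2, 3])";
    return expr;
  }
};

// Build the rewrites at the call site: ownership is explicit and nothing is
// kept alive in a function-local static between calls.
std::vector<std::unique_ptr<Rewrite>> MakeRewrites() {
  std::vector<std::unique_ptr<Rewrite>> rewrites;
  rewrites.push_back(std::make_unique<ConcretizeZerosLike>());
  rewrites.push_back(std::make_unique<ConcretizeOnesLike>());
  return rewrites;
}

std::string SimplifyOnce(std::string expr) {
  for (const auto& rewrite : MakeRewrites()) {
    expr = rewrite->Callback(expr);
  }
  return expr;
}
```

The tradeoff is that the patterns are rebuilt on every call, in exchange for not having objects whose lifetime extends to program exit.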





[GitHub] [tvm] comaniac commented on pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

Posted by GitBox <gi...@apache.org>.
comaniac commented on pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#issuecomment-811453421


   Sorry, I missed the previous message. Yeah, I'm good with it, so I've merged it. Thanks @altanh @mbrookhart 



[GitHub] [tvm] altanh commented on pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

Posted by GitBox <gi...@apache.org>.
altanh commented on pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#issuecomment-806028276


   Ah, it looks like this only makes sense for constants with a single element, unless we want to loop over every element and check that each one equals 0 or 1. But if I understand correctly, the `FullElementwise` pass only rewrites to scalar constants, so those cases will be simplified correctly; only non-scalar constants already present in the input IR won't be simplified.
   
   Do you guys think this is a reasonable tradeoff?
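For comparison, a sketch of the "loop over every element" option, assuming the constant's elements can be read into a flat buffer. `AllElementsEqual` and `IsIdentityConstant` are hypothetical helpers; in TVM the data would come from `ConstantNode->data` (an NDArray), not a `std::vector`. The op-to-value mapping mirrors `CheckConstant` in the diff.

```cpp
#include <string>
#include <vector>

// True if `data` is non-empty and every element, widened to long double,
// equals `target`.
template <typename T>
bool AllElementsEqual(const std::vector<T>& data, long double target) {
  if (data.empty()) return false;  // be conservative about empty tensors
  for (const T& v : data) {
    if (static_cast<long double>(v) != target) return false;
  }
  return true;
}

// 0 is the identity for add/subtract, 1 for multiply/divide.
template <typename T>
bool IsIdentityConstant(const std::string& op, const std::vector<T>& data) {
  if (op == "add" || op == "subtract") return AllElementsEqual(data, 0.0L);
  if (op == "multiply" || op == "divide") return AllElementsEqual(data, 1.0L);
  return false;
}
```

This costs a pass over every element of every matched constant. Even when it succeeds, a non-scalar constant can still broadcast `x` to a larger shape, so the type check in the callback remains necessary.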



[GitHub] [tvm] mbrookhart commented on pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#issuecomment-806017937


   I think I agree that merging these with SimplifyExpr would be a win in terms of our ability to control the order of execution.
   
   On the simplification of things like `x * const(0)`, you can get the ndarray out of the const as ConstantNode->data, and then you can pass that to [this utility](https://github.com/apache/tvm/blob/63d8e97dfbe046e70c91c72cbbf7da8646824217/src/relay/transforms/pattern_utils.h#L385), which will return a `long double` version of the value, which shouldn't lose precision for any of the 64-bit or smaller datatypes we use. You can then do your comparison in a single dtype.
   
   I've been meaning to implement this for like 6 months, and I haven't had a strong enough forcing function to bubble it up to the top of my priority list.
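A sketch of the "widen once, compare in one type" idea, written independently of the linked utility. `DType`, `Load`, `ToLongDouble`, and the two identity checks are hypothetical; TVM's `DataType` encodes code, bits, and lanes rather than a flat enum. One caveat: on platforms where `long double` is the same size as `double` (MSVC, for example), very large 64-bit integers can round, although comparisons against 0 and 1 stay exact.

```cpp
#include <cstdint>
#include <cstring>
#include <stdexcept>

// Hypothetical dtype tag.
enum class DType { kInt8, kInt32, kInt64, kUInt8, kFloat32, kFloat64 };

// Copy one T out of a raw buffer (memcpy avoids alignment and aliasing
// issues) and widen it to long double.
template <typename T>
long double Load(const void* data) {
  T v;
  std::memcpy(&v, data, sizeof(T));
  return static_cast<long double>(v);
}

// Dispatch on dtype once, so the identity checks below work in one type.
long double ToLongDouble(const void* data, DType dtype) {
  switch (dtype) {
    case DType::kInt8: return Load<std::int8_t>(data);
    case DType::kInt32: return Load<std::int32_t>(data);
    case DType::kInt64: return Load<std::int64_t>(data);
    case DType::kUInt8: return Load<std::uint8_t>(data);
    case DType::kFloat32: return Load<float>(data);
    case DType::kFloat64: return Load<double>(data);
  }
  throw std::invalid_argument("unsupported dtype");
}

bool IsAdditiveIdentity(const void* data, DType dtype) {
  return ToLongDouble(data, dtype) == 0.0L;
}

bool IsMultiplicativeIdentity(const void* data, DType dtype) {
  return ToLongDouble(data, dtype) == 1.0L;
}
```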



[GitHub] [tvm] altanh commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

Posted by GitBox <gi...@apache.org>.
altanh commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600824893



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -249,36 +248,214 @@ class FullElementwise : public SimplifyPattern {
 };
 
 /*!
- * \brief ExprSimplifier simplifies the Relay expression.
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to
+ * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies
+ * and can enable more opportunities for operator fusion.
  */
-class ExprSimplifier {
+class ConcretizeLikeRewrite : public DFPatternRewrite {
  public:
-  explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    CreateCallback(SimplifyReshape());
-    CreateCallback(SimplifyTranspose());
-    CreateCallback(FullElementwise());
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
   }
-  template <typename T>
-  void CreateCallback(const T& pattern) {
-    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
-      Expr pre = args[0];
-      Expr post = args[1];
-      Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = pattern.callback(pre, post, node_map);
-    };
-    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
+
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type().as<TensorTypeNode>()) {
+      return false;
+    }
+
+    return true;
+  }
+
+  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                          DataType dtype) const = 0;
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    if (!Check(pre, post, node_map)) {
+      return post;
+    }
+
+    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
+    Array<Integer> cshape;
+    for (const auto& dim : like_ty->shape) {
+      if (const auto* imm = dim.as<IntImmNode>()) {
+        cshape.push_back(Integer(GetRef<IntImm>(imm)));
+      } else {
+        // shape is not static, don't concretize
+        return post;
+      }
+    }
+
+    return Concretize(node_map, cshape, like_ty->dtype);
+  }
+
+ protected:
+  DFPattern data_pat_;
+  DFPattern like_pat_;
+};
+
+class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeZeros(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeZerosLikeRewrite);
+};
+
+class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeOnes(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeOnesLikeRewrite);
+};
+
+class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeReshape(node_map[data_pat_][0], shape);
   }
 
-  Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeReshapeLikeRewrite);
+};
+
+class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());
+    static const Op& op = Op::Get("collapse_sum_to");
+    auto attrs = make_object<InitOpAttrs>();
+    attrs->shape = shape;
+    auto cshape =
+        MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);
+    return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs));
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeCollapseSumLikeRewrite);
+};
+
+class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeBroadCastTo(node_map[data_pat_][0], shape);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeBroadcastToLikeRewrite);
+};
+
+/*! \brief Eliminates expressions that are equivalent to identity. */
+class EliminateIdentityRewrite : public DFPatternRewrite {
+ public:
+  EliminateIdentityRewrite() {
+    x_ = IsWildcard();
+    const_ = IsConstant();
+
+    DFPattern add_op = IsOp("add");
+    DFPattern mul_op = IsOp("multiply");
+    DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_;
+    DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_;
+
+    DFPattern add_id = add_op({x_, zeros_expr}) || add_op({zeros_expr, x_});
+    DFPattern mul_id = mul_op({x_, ones_expr}) || mul_op({ones_expr, x_});
+    DFPattern sub_id = IsOp("subtract")({x_, zeros_expr});
+    DFPattern div_id = IsOp("divide")({x_, ones_expr});
+
+    pattern_ = add_id || mul_id || sub_id || div_id;
+  }
+
+  bool CheckConstant(const OpNode* op, const ConstantNode* constant) const {
+    if (!IsScalar(GetRef<Expr>(constant))) {
+      return false;
+    }
+    long double value = ToScalar(constant->data);
+    if (op->name == "add" || op->name == "subtract") {
+      return value == 0.0;
+    } else if (op->name == "multiply" || op->name == "divide") {
+      return value == 1.0;
+    }
+    return false;
+  }
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    const CallNode* call = pre.as<CallNode>();
+    ICHECK(call);
+    Type pre_type = pre->checked_type_;
+    ICHECK(pre_type.as<TensorTypeNode>());
+    auto x = node_map[x_][0];
+    bool is_left = post.as<CallNode>()->args[1] == x;
+    Type x_type;
+    if (is_left) {
+      x_type = call->args[1]->checked_type_;
+    } else {
+      x_type = call->args[0]->checked_type_;
+    }
+
+    if (node_map.count(const_)) {
+      // the other argument is a Constant in this case
+      const ConstantNode* constant = node_map[const_][0].as<ConstantNode>();
+      const OpNode* op = call->op.as<OpNode>();
+      ICHECK(constant);
+      ICHECK(op);
+      if (!CheckConstant(op, constant)) {
+        return post;
+      }
+    }
+
+    if (StructuralEqual()(x_type, pre_type)) {
+      return x;
+    }
+
+    return post;
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(EliminateIdentityRewrite);
 
  private:
-  IRModule mod_;
-  /*! \brief Callbacks for expr simplification */
-  Array<DFPatternCallback> callbacks_;
+  DFPattern x_;
+  DFPattern const_;
 };
 
 Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
-  return ExprSimplifier(mod).Simplify(expr);
+  static Array<DFPatternCallback> callbacks = {ConcretizeZerosLikeRewrite::GetCallback(),

Review comment:
      cc @mbrookhart, I believe the ordering is respected by the rewriter, but because the rewriter iterates until it reaches a fixed point, I don't think it would be correct to say that the order is enforced globally.





[GitHub] [tvm] comaniac commented on pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

Posted by GitBox <gi...@apache.org>.
comaniac commented on pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#issuecomment-805995681


   Also cc @mbrookhart 



[GitHub] [tvm] altanh commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

Posted by GitBox <gi...@apache.org>.
altanh commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600860158



##########
File path: src/relay/transforms/simplify_expr.h
##########
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/simplify_expr.h
+ * \brief Utility data structures for simplifying Relay expressions.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+#define TVM_RELAY_TRANSFORMS_SIMPLIFY_EXPR_H_
+
+#include <tvm/relay/dataflow_matcher.h>
+#include <tvm/relay/expr.h>
+
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Defines a static function `RewriteType::Get()` that returns a statically initialized
+ * instance of RewriteType. */
+#define TVM_DF_PATTERN_REWRITE_GETTER(RewriteType)                    \
+  static DFPatternRewrite* Get() {                                    \
+    static RewriteType rw;                                            \
+    return &rw;                                                       \
+  }                                                                   \
+  static DFPatternCallback GetCallback() {                            \
+    static DFPatternCallback cb = RewriteType::Get()->MakeCallback(); \
+    return cb;                                                        \
+  }

Review comment:
       got it, I'll remove this, thanks for the clarification
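For context, the `TVM_DF_PATTERN_REWRITE_GETTER` macro in the hunk above uses the C++ function-local-static idiom. Such an object is constructed once, on the first call. Since C++11 that initialization is thread-safe. Below is a standalone sketch of the idiom. `ExampleRewrite` is a hypothetical stand-in that counts its constructions so the "constructed once" property can be observed.

```cpp
#include <cassert>

// Hypothetical stand-in for a pattern rewrite; counts how many times it is constructed.
struct ExampleRewrite {
  static inline int constructions = 0;  // C++17 inline static data member
  ExampleRewrite() { ++constructions; }

  // Same shape as the getter the macro generates: the function-local static is
  // constructed on the first call only, and every later call returns the same object.
  static ExampleRewrite* Get() {
    static ExampleRewrite rw;
    return &rw;
  }
};
```

One trade-off of this idiom is that the object lives until program exit and is shared by every caller.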







[GitHub] [tvm] altanh commented on a change in pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

Posted by GitBox <gi...@apache.org>.
altanh commented on a change in pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#discussion_r600858326



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -249,36 +249,214 @@ class FullElementwise : public SimplifyPattern {
 };
 
 /*!
- * \brief ExprSimplifier simplifies the Relay expression.
+ * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to
+ * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies
+ * and can enable more opportunities for operator fusion.
  */
-class ExprSimplifier {
+class ConcretizeLikeRewrite : public DFPatternRewrite {
  public:
-  explicit ExprSimplifier(IRModule mod) : mod_(mod) {
-    CreateCallback(SimplifyReshape());
-    CreateCallback(SimplifyTranspose());
-    CreateCallback(FullElementwise());
+  explicit ConcretizeLikeRewrite(const Op& op) {
+    ICHECK(op->num_inputs == 1 || op->num_inputs == 2)
+        << "ConcretizeLike does not handle operators that aren't unary or binary, got: " << op;
+    like_pat_ = IsWildcard();
+    data_pat_ = IsWildcard();
+    if (op->num_inputs == 1) {
+      pattern_ = IsExpr(op)({like_pat_});
+    } else {
+      pattern_ = IsExpr(op)({data_pat_, like_pat_});
+    }
   }
-  template <typename T>
-  void CreateCallback(const T& pattern) {
-    auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
-      Expr pre = args[0];
-      Expr post = args[1];
-      Map<DFPattern, Array<Expr>> node_map = args[2];
-      *rv = pattern.callback(pre, post, node_map);
-    };
-    callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
+
+  virtual bool Check(const Expr& pre, const Expr& post,
+                     const Map<DFPattern, Array<Expr>>& node_map) const {
+    const CallNode* call_node = pre.as<CallNode>();
+    ICHECK(call_node);
+
+    if (!call_node->checked_type().as<TensorTypeNode>()) {
+      return false;
+    }
+
+    return true;
+  }
+
+  virtual Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                          DataType dtype) const = 0;
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    if (!Check(pre, post, node_map)) {
+      return post;
+    }
+
+    const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
+    Array<Integer> cshape;
+    for (const auto& dim : like_ty->shape) {
+      if (const auto* imm = dim.as<IntImmNode>()) {
+        cshape.push_back(Integer(GetRef<IntImm>(imm)));
+      } else {
+        // shape is not static, don't concretize
+        return post;
+      }
+    }
+
+    return Concretize(node_map, cshape, like_ty->dtype);
+  }
+
+ protected:
+  DFPattern data_pat_;
+  DFPattern like_pat_;
+};
+
+class ConcretizeZerosLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeZerosLikeRewrite() : ConcretizeLikeRewrite(Op::Get("zeros_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeZeros(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeZerosLikeRewrite);
+};
+
+class ConcretizeOnesLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeOnesLikeRewrite() : ConcretizeLikeRewrite(Op::Get("ones_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeOnes(shape, dtype);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeOnesLikeRewrite);
+};
+
+class ConcretizeReshapeLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeReshapeLikeRewrite() : ConcretizeLikeRewrite(Op::Get("reshape_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeReshape(node_map[data_pat_][0], shape);
   }
 
-  Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeReshapeLikeRewrite);
+};
+
+class ConcretizeCollapseSumLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeCollapseSumLikeRewrite() : ConcretizeLikeRewrite(Op::Get("collapse_sum_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());
+    static const Op& op = Op::Get("collapse_sum_to");
+    auto attrs = make_object<InitOpAttrs>();
+    attrs->shape = shape;
+    auto cshape =
+        MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);
+    return Call(op, {node_map[data_pat_][0], cshape}, Attrs(attrs));
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeCollapseSumLikeRewrite);
+};
+
+class ConcretizeBroadcastToLikeRewrite : public ConcretizeLikeRewrite {
+ public:
+  ConcretizeBroadcastToLikeRewrite() : ConcretizeLikeRewrite(Op::Get("broadcast_to_like")) {}
+
+  Expr Concretize(const Map<DFPattern, Array<Expr>>& node_map, Array<Integer> shape,
+                  DataType dtype) const override {
+    return MakeBroadCastTo(node_map[data_pat_][0], shape);
+  }
+
+  TVM_DF_PATTERN_REWRITE_GETTER(ConcretizeBroadcastToLikeRewrite);
+};
+
+/*! \brief Eliminates expressions that are equivalent to identity. */
+class EliminateIdentityRewrite : public DFPatternRewrite {
+ public:
+  EliminateIdentityRewrite() {
+    x_ = IsWildcard();
+    const_ = IsConstant();
+
+    DFPattern add_op = IsOp("add");
+    DFPattern mul_op = IsOp("multiply");
+    DFPattern zeros_expr = IsOp("zeros")({}) || IsOp("zeros_like")({IsWildcard()}) || const_;
+    DFPattern ones_expr = IsOp("ones")({}) || IsOp("ones_like")({IsWildcard()}) || const_;
+
+    DFPattern add_id = add_op({x_, zeros_expr}) || add_op({zeros_expr, x_});
+    DFPattern mul_id = mul_op({x_, ones_expr}) || mul_op({ones_expr, x_});

Review comment:
       I can add `x * 0 = 0`, although to me it doesn't semantically fit here and would need slightly different logic; perhaps a new rewrite, `ZeroMultiply` or something?
   
   And I wasn't aware of the commutative matching, that's helpful (although I wonder if it should be documented more visibly somewhere?)
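To make the distinction concrete, here is a toy sketch (not Relay) of the two rewrites. The `Node` IR and the `EliminateIdentity`/`ZeroMultiply` functions below are hypothetical. Identity elimination returns the operand `x`, tries both operand orders because add and multiply are commutative, and only fires when `x` already has the result shape. That last condition is analogous to the `StructuralEqual` type check in the callback quoted earlier, and it leaves broadcasts alone. A zero-multiply rewrite instead returns a fresh zero of the result shape.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical toy IR: a variable, a constant filled with one value, or add/multiply.
struct Node;
using NodePtr = std::shared_ptr<const Node>;

struct Node {
  enum class Kind { kVar, kFill, kAdd, kMul };
  Kind kind;
  std::vector<int> shape;   // static result shape
  double fill_value = 0.0;  // only meaningful for kFill
  NodePtr lhs, rhs;         // only meaningful for kAdd / kMul
};

NodePtr Var(std::vector<int> shape) {
  return std::make_shared<Node>(Node{Node::Kind::kVar, std::move(shape)});
}
NodePtr Fill(std::vector<int> shape, double value) {
  return std::make_shared<Node>(Node{Node::Kind::kFill, std::move(shape), value});
}
// The caller supplies the (broadcast) result shape; the toy does not infer it.
NodePtr Binary(Node::Kind kind, NodePtr a, NodePtr b, std::vector<int> out_shape) {
  return std::make_shared<Node>(Node{kind, std::move(out_shape), 0.0, std::move(a), std::move(b)});
}

bool IsFill(const NodePtr& n, double v) {
  return n->kind == Node::Kind::kFill && n->fill_value == v;
}

// x + 0 -> x and x * 1 -> x, trying both operand orders; skipped if x would be broadcast.
NodePtr EliminateIdentity(const NodePtr& call) {
  if (call->kind != Node::Kind::kAdd && call->kind != Node::Kind::kMul) return call;
  const double identity = call->kind == Node::Kind::kAdd ? 0.0 : 1.0;
  for (const auto& [x, other] : {std::pair{call->lhs, call->rhs}, std::pair{call->rhs, call->lhs}}) {
    if (IsFill(other, identity) && x->shape == call->shape) return x;
  }
  return call;
}

// Hypothetical ZeroMultiply: x * 0 -> zeros of the call's result shape (not an operand).
NodePtr ZeroMultiply(const NodePtr& call) {
  if (call->kind != Node::Kind::kMul) return call;
  if (IsFill(call->lhs, 0.0) || IsFill(call->rhs, 0.0)) return Fill(call->shape, 0.0);
  return call;
}
```

One further difference worth keeping in mind: for IEEE floating point, `x * 0` is NaN when `x` is NaN or infinite. So unlike `x + 0` and `x * 1`, the zero-multiply rewrite is not exact for every input.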

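The core guard in `ConcretizeLikeRewrite::Callback` from the hunk above works as follows. The rewrite is only applied if every dimension of the result type is a static integer; otherwise the original call is kept. Below is a toy sketch of that logic. `Dim`, `ConcretizeShape`, and the string-rendering `ConcretizeZerosLike` are hypothetical, and dtype handling is omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical dimension type: a known extent, or std::nullopt for a symbolic/dynamic dim.
using Dim = std::optional<std::int64_t>;

// Succeeds only if every dim is static (mirrors the loop over like_ty->shape).
std::optional<std::vector<std::int64_t>> ConcretizeShape(const std::vector<Dim>& shape) {
  std::vector<std::int64_t> out;
  out.reserve(shape.size());
  for (const Dim& d : shape) {
    if (!d) return std::nullopt;  // shape is not static, don't concretize
    out.push_back(*d);
  }
  return out;
}

// Hypothetical printer: renders zeros_like(arg) as zeros(shape=[...]) when the shape is
// fully static, and leaves the call untouched otherwise.
std::string ConcretizeZerosLike(const std::string& like_arg, const std::vector<Dim>& like_shape) {
  auto cshape = ConcretizeShape(like_shape);
  if (!cshape) return "zeros_like(" + like_arg + ")";
  std::string s = "zeros(shape=[";
  for (std::size_t i = 0; i < cshape->size(); ++i) {
    if (i) s += ", ";
    s += std::to_string((*cshape)[i]);
  }
  return s + "])";
}
```

The concretized form no longer refers to `y`. That is the "removes unnecessary dependencies" effect described in the doc comment.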






[GitHub] [tvm] mbrookhart commented on pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity rewrites for SimplifyExpr

Posted by GitBox <gi...@apache.org>.
mbrookhart commented on pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#issuecomment-807086802


   @altanh I'm confused why you need this change to ToScalar. What changed elsewhere in your PR that broke this unit test?





[GitHub] [tvm] altanh commented on pull request #7731: [Relay][Pass] ConcretizeLike and EliminateIdentity Passes

Posted by GitBox <gi...@apache.org>.
altanh commented on pull request #7731:
URL: https://github.com/apache/tvm/pull/7731#issuecomment-806019388


   > I think I agree that merging these with SimplifyExpr would be a win in terms of our ability to control the order of execution.
   > 
   > On the simplification of things like `x * const(0)`, you can get the ndarray out of the const as ConstantNode->data, and then you can pass that to [this utility](https://github.com/apache/tvm/blob/63d8e97dfbe046e70c91c72cbbf7da8646824217/src/relay/transforms/pattern_utils.h#L385), which will return a `long double` version of the value, which shouldn't lose precision for any of the 64 bit or smaller datatypes we use. You can then do your comparison in a single dtype.
   > 
   > I've been meaning to implement this for like 6 months, and I haven't had a strong enough forcing function to bubble it up to the top of my priority list.
   
   This is just what I needed, thanks! Since I'm only checking for 0 or 1, casting should be safe: both values are exactly representable in every floating-point type.
   
   I'll fix the Windows issue, it was because I overloaded the same name too much.
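The approach suggested above is: widen the constant to `long double` and compare once. Below is a standalone sketch of it, written in the spirit of the linked utility rather than copied from it. The `DType` enum, `ToLongDouble`, and `IsScalarEqual` are hypothetical and do not use TVM's `NDArray` or `DataType`.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <stdexcept>

// Hypothetical dtype tag standing in for a real tensor dtype.
enum class DType { kInt8, kInt32, kInt64, kUInt8, kFloat32, kFloat64 };

// Copy one element of type T out of raw bytes (memcpy avoids alignment/aliasing issues).
template <typename T>
T Load(const void* data) {
  T v;
  std::memcpy(&v, data, sizeof(T));
  return v;
}

// Widen a single element of the given dtype to long double.
long double ToLongDouble(const void* data, DType dtype) {
  switch (dtype) {
    case DType::kInt8: return Load<std::int8_t>(data);
    case DType::kInt32: return Load<std::int32_t>(data);
    case DType::kInt64: return Load<std::int64_t>(data);
    case DType::kUInt8: return Load<std::uint8_t>(data);
    case DType::kFloat32: return Load<float>(data);
    case DType::kFloat64: return Load<double>(data);
  }
  throw std::invalid_argument("unsupported dtype");
}

// One comparison in a single type, regardless of the constant's original dtype.
bool IsScalarEqual(const void* data, DType dtype, long double value) {
  return ToLongDouble(data, dtype) == value;
}
```

On some platforms, notably MSVC, `long double` has the same width as `double`. There, very large 64-bit integers may not round-trip exactly, but 0 and 1 are exact on every platform, which is all the identity checks need.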

