You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by "Lunderberg (via GitHub)" <gi...@apache.org> on 2023/10/06 21:07:40 UTC

[PR] [Unity][Transform] Allow static Relax arguments to dynamic PrimFunc [tvm]

Lunderberg opened a new pull request, #15883:
URL: https://github.com/apache/tvm/pull/15883

   Prior to this commit, the `relax.transform.FuseTIR` transform required that the shapes of the arguments passed into a `PrimFunc` be structurally equivalent to the shapes of its parameters, and that any symbolic `tir.Var` be replaced only by another symbolic `tir.Var` in the fused function.
   
   This commit updates the `SymbolicMatcher` to instead extract a `Map<tir::Var, PrimExpr>`.  As a result, a Relax tensor with statically-known shape can be passed into a TIR PrimFunc with dynamic shape.  The resulting fused TIR function is in terms of the statically-known shape, and no longer contains the symbolic variable.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Re: [PR] [Unity][Transform] Allow static Relax arguments to dynamic PrimFunc [tvm]

Posted by "Lunderberg (via GitHub)" <gi...@apache.org>.
Lunderberg merged PR #15883:
URL: https://github.com/apache/tvm/pull/15883


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Re: [PR] [Unity][Transform] Allow static Relax arguments to dynamic PrimFunc [tvm]

Posted by "Lunderberg (via GitHub)" <gi...@apache.org>.
Lunderberg commented on code in PR #15883:
URL: https://github.com/apache/tvm/pull/15883#discussion_r1349586374


##########
src/relax/transform/fuse_tir.cc:
##########
@@ -39,31 +39,41 @@ namespace tir {
  */
 class SymbolicMatcher : ExprFunctor<bool(const PrimExpr& n, const PrimExpr& other)> {
  public:
-  explicit SymbolicMatcher(Map<tir::Var, tir::Var>* var_remap) : var_remap_(var_remap) {}
+  explicit SymbolicMatcher(Map<tir::Var, PrimExpr>* var_remap) : var_remap_(var_remap) {}
 
-  void Match(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
-    CHECK_EQ(lhs.size(), rhs.size());
-    for (size_t i = 0; i < lhs.size(); ++i) {
-      Match(lhs[i], rhs[i]);
+  void Match(const Array<PrimExpr>& params, const Array<PrimExpr>& args) {
+    CHECK_EQ(params.size(), args.size());
+    for (size_t i = 0; i < params.size(); ++i) {
+      Match(params[i], args[i]);
     }
   }
-  void Match(const PrimExpr& lhs, const PrimExpr& rhs) {
-    if (!VisitExpr(lhs, rhs)) {
-      LOG(FATAL) << "Failed to match PrimExpr " << lhs << " with " << rhs;
+  void Match(const PrimExpr& param, const PrimExpr& arg) {
+    if (!VisitExpr(param, arg)) {
+      LOG(FATAL) << "Failed to match PrimExpr " << param << " with " << arg;
     }
   }
 
  private:
-  bool VisitExpr(const PrimExpr& n, const PrimExpr& other) {
-    bool matched = n.same_as(other) || ((n->type_index() == other->type_index()) &&
-                                        n.dtype().code() == other.dtype().code());
-    return matched && ExprFunctor::VisitExpr(n, other);
+  bool VisitExpr(const PrimExpr& node, const PrimExpr& other) {
+    if (node.same_as(other)) {
+      return true;
+    } else if (node.dtype().code() != other.dtype().code()) {
+      return false;
+    } else {
+      return ExprFunctor::VisitExpr(node, other);
+    }
+    // bool matched = node.same_as(other) || ((node->type_index() == other->type_index()) &&

Review Comment:
   Thank you for the catch; I've removed them.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Re: [PR] [Unity][Transform] Allow static Relax arguments to dynamic PrimFunc [tvm]

Posted by "masahi (via GitHub)" <gi...@apache.org>.
masahi commented on PR #15883:
URL: https://github.com/apache/tvm/pull/15883#issuecomment-1751490974

   cc @Hzfengsy 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Re: [PR] [Unity][Transform] Allow static Relax arguments to dynamic PrimFunc [tvm]

Posted by "Hzfengsy (via GitHub)" <gi...@apache.org>.
Hzfengsy commented on code in PR #15883:
URL: https://github.com/apache/tvm/pull/15883#discussion_r1349462821


##########
src/relax/transform/fuse_tir.cc:
##########
@@ -39,31 +39,41 @@ namespace tir {
  */
 class SymbolicMatcher : ExprFunctor<bool(const PrimExpr& n, const PrimExpr& other)> {
  public:
-  explicit SymbolicMatcher(Map<tir::Var, tir::Var>* var_remap) : var_remap_(var_remap) {}
+  explicit SymbolicMatcher(Map<tir::Var, PrimExpr>* var_remap) : var_remap_(var_remap) {}
 
-  void Match(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
-    CHECK_EQ(lhs.size(), rhs.size());
-    for (size_t i = 0; i < lhs.size(); ++i) {
-      Match(lhs[i], rhs[i]);
+  void Match(const Array<PrimExpr>& params, const Array<PrimExpr>& args) {
+    CHECK_EQ(params.size(), args.size());
+    for (size_t i = 0; i < params.size(); ++i) {
+      Match(params[i], args[i]);
     }
   }
-  void Match(const PrimExpr& lhs, const PrimExpr& rhs) {
-    if (!VisitExpr(lhs, rhs)) {
-      LOG(FATAL) << "Failed to match PrimExpr " << lhs << " with " << rhs;
+  void Match(const PrimExpr& param, const PrimExpr& arg) {
+    if (!VisitExpr(param, arg)) {
+      LOG(FATAL) << "Failed to match PrimExpr " << param << " with " << arg;
     }
   }
 
  private:
-  bool VisitExpr(const PrimExpr& n, const PrimExpr& other) {
-    bool matched = n.same_as(other) || ((n->type_index() == other->type_index()) &&
-                                        n.dtype().code() == other.dtype().code());
-    return matched && ExprFunctor::VisitExpr(n, other);
+  bool VisitExpr(const PrimExpr& node, const PrimExpr& other) {
+    if (node.same_as(other)) {
+      return true;
+    } else if (node.dtype().code() != other.dtype().code()) {
+      return false;
+    } else {
+      return ExprFunctor::VisitExpr(node, other);
+    }
+    // bool matched = node.same_as(other) || ((node->type_index() == other->type_index()) &&

Review Comment:
   Please remove the leftover commented-out lines.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org