You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/04/07 19:13:03 UTC

[GitHub] [tvm] comaniac commented on a change in pull request #7807: [Relay][Pass] Update SimplifyTranspose to correctly simplify rank changing layout transforms

comaniac commented on a change in pull request #7807:
URL: https://github.com/apache/tvm/pull/7807#discussion_r608943282



##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -91,36 +91,15 @@ class SimplifyTranspose : public DFPatternRewrite {
   Expr Callback(const Expr& pre, const Expr& post,
                 const Map<DFPattern, Array<Expr>>& node_map) const override {
     // Helper function to get the axes from call node attribute
-    auto get_axes_from_call = [](const Call trans_call, int ndim) {
-      std::vector<int> attr_axes;
-      if (auto attr = trans_call->attrs.as<TransposeAttrs>()) {
-        if (attr->axes.defined()) {
-          for (int i = 0; i < ndim; ++i) {
-            int64_t axis = attr->axes[i];
-            axis += (axis < 0) ? ndim : 0;
-            attr_axes.push_back(axis);
-          }
-        } else {
-          // Empty axes means reverse
-          for (int i = ndim - 1; i >= 0; --i) {
-            attr_axes.push_back(i);
-          }
-        }
-      } else if (auto attr = trans_call->attrs.as<LayoutTransformAttrs>()) {
-        Layout src_layout(attr->src_layout);
-        Layout dst_layout(attr->dst_layout);
-        for (int i = 0; i < ndim; ++i) {
-          attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
-        }
-      } else {
-        CHECK(false) << "Expected transpose or layout_transform, but got "
-                     << Downcast<Op>(trans_call->op)->name;
-      }
-      return std::move(attr_axes);
-    };
 
     auto x = node_map[x_][0];
 
+    Call trans_call = Downcast<Call>(post);
+
+    if (auto layout_trans = FoldRankChangingLayoutTrans(x, trans_call)) {

Review comment:
       Add comments explaining what this branch handles and why the layout transform can be folded here.

##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -91,36 +91,15 @@ class SimplifyTranspose : public DFPatternRewrite {
   Expr Callback(const Expr& pre, const Expr& post,
                 const Map<DFPattern, Array<Expr>>& node_map) const override {
     // Helper function to get the axes from call node attribute

Review comment:
       Remove this comment as there's no helper function defined here anymore.

##########
File path: src/relay/transforms/type_infer.cc
##########
@@ -741,7 +741,6 @@ class TypeInferencer::Resolver : public MixedModeMutator, PatternMutator {
 Expr TypeInferencer::Infer(GlobalVar var, Function function) {
   // Set the current function being type checked.
   this->current_func_ = var;
-

Review comment:
       Was this blank line deleted by mistake?

##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -163,6 +141,69 @@ class SimplifyTranspose : public DFPatternRewrite {
     return x;
   }
 
+  String PermuteLayout(const String& layout, std::vector<int> axes) const {
+    std::string new_layout{};
+    std::string old_layout{layout};
+    for (auto axis : axes) {
+      new_layout += old_layout[axis];
+    }
+    return String(new_layout);
+  }
+
+  Optional<Expr> FoldRankChangingLayoutTrans(const Expr& data, const Call& call) const {
+    Optional<Expr> layout_trans;
+    if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
+      Layout src_layout(attr->src_layout);
+      Layout dst_layout(attr->dst_layout);
+      if (src_layout->axes.size() != dst_layout->axes.size()) {

Review comment:
       It would be better to lift this check to the beginning of the function, so that readers can immediately see that this function only handles the case where the source and destination layouts have different numbers of axes.

##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -163,6 +141,69 @@ class SimplifyTranspose : public DFPatternRewrite {
     return x;
   }
 
+  String PermuteLayout(const String& layout, std::vector<int> axes) const {
+    std::string new_layout{};
+    std::string old_layout{layout};
+    for (auto axis : axes) {
+      new_layout += old_layout[axis];
+    }
+    return String(new_layout);
+  }
+
+  Optional<Expr> FoldRankChangingLayoutTrans(const Expr& data, const Call& call) const {
+    Optional<Expr> layout_trans;
+    if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
+      Layout src_layout(attr->src_layout);
+      Layout dst_layout(attr->dst_layout);
+      if (src_layout->axes.size() != dst_layout->axes.size()) {
+        auto axes = GetTransposeAxisOrder(Downcast<Call>(call->args[0]), src_layout->axes.size());
+        std::vector<int> inverse(axes.size());
+        for (size_t i = 0; i < axes.size(); i++) {
+          inverse[axes[i]] = i;
+        }
+        String new_layout = PermuteLayout(attr->src_layout, inverse);
+        layout_trans = MakeLayoutTransform(data, new_layout, dst_layout->name);
+      }
+    } else if (auto attr = Downcast<Call>(call->args[0])->attrs.as<LayoutTransformAttrs>()) {

Review comment:
       I'm not sure whether `call->args[0]` is guaranteed to be a `Call` whose `attrs` are available; the `Downcast<Call>` here relies on that. Please add examples for the `if` and `else if` branches to make the handled cases easier to understand.

##########
File path: src/relay/transforms/simplify_expr.cc
##########
@@ -163,6 +141,69 @@ class SimplifyTranspose : public DFPatternRewrite {
     return x;
   }
 
+  String PermuteLayout(const String& layout, std::vector<int> axes) const {
+    std::string new_layout{};
+    std::string old_layout{layout};
+    for (auto axis : axes) {
+      new_layout += old_layout[axis];
+    }
+    return String(new_layout);
+  }
+
+  Optional<Expr> FoldRankChangingLayoutTrans(const Expr& data, const Call& call) const {

Review comment:
       Please add comments to explain each step in this function.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org