Posted to commits@tvm.apache.org by co...@apache.org on 2021/05/04 08:06:07 UTC

[tvm] branch main updated: [Relay][Pass] Update SimplifyTranspose to correctly simplify rank changing layout transforms (#7807)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 396a09e  [Relay][Pass] Update SimplifyTranspose to correctly simplify rank changing layout transforms (#7807)
396a09e is described below

commit 396a09e06441024f5b95dcf6762745368cf9d8e6
Author: Chris Sullivan <cs...@octoml.ai>
AuthorDate: Tue May 4 01:05:41 2021 -0700

    [Relay][Pass] Update SimplifyTranspose to correctly simplify rank changing layout transforms (#7807)
---
 src/relay/transforms/simplify_expr.cc         | 175 +++++++++++++++++++++-----
 tests/python/relay/test_pass_simplify_expr.py | 166 ++++++++++++++++++++++++
 2 files changed, 310 insertions(+), 31 deletions(-)
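
For readers skimming the patch, the headline rewrite can be sketched in a few
lines of Relay (an illustrative sketch, not part of the commit; it mirrors the
before4/expected4 test case added below):

    import tvm
    from tvm import relay

    # A transpose followed by a rank-increasing layout_transform ...
    x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
    y = relay.transpose(x, axes=[0, 3, 1, 2])                    # NHWC -> NCHW
    y = relay.layout_transform(y, "NCHW", "NCHW4c")              # NCHW -> NCHW4c
    mod = tvm.IRModule.from_expr(relay.Function([x], y))

    # ... now folds into a single rank-changing layout transform:
    #     layout_transform(x, "NHWC", "NCHW4c")
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.SimplifyExpr()(mod)
    print(mod)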

diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index 5662ef5..fb7a76f 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -31,6 +31,8 @@
 #include <tvm/runtime/logging.h>
 
 #include <limits>
+#include <memory>
+#include <string>
 #include <utility>
 
 #include "../op/tensor/transform.h"
@@ -117,36 +119,20 @@ class SimplifyTranspose : public DFPatternRewrite {
 
   Expr Callback(const Expr& pre, const Expr& post,
                 const Map<DFPattern, Array<Expr>>& node_map) const override {
-    // Helper function to get the axes from call node attribute
-    auto get_axes_from_call = [](const Call trans_call, int ndim) {
-      std::vector<int> attr_axes;
-      if (auto attr = trans_call->attrs.as<TransposeAttrs>()) {
-        if (attr->axes.defined()) {
-          for (int i = 0; i < ndim; ++i) {
-            int64_t axis = attr->axes[i];
-            axis += (axis < 0) ? ndim : 0;
-            attr_axes.push_back(axis);
-          }
-        } else {
-          // Empty axes means reverse
-          for (int i = ndim - 1; i >= 0; --i) {
-            attr_axes.push_back(i);
-          }
-        }
-      } else if (auto attr = trans_call->attrs.as<LayoutTransformAttrs>()) {
-        Layout src_layout(attr->src_layout);
-        Layout dst_layout(attr->dst_layout);
-        for (int i = 0; i < ndim; ++i) {
-          attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
+    auto x = node_map[x_][0];
+
+    Call trans_call = Downcast<Call>(post);
+
+    // Try to fuse any rank changing layout transformations
+    if (auto layout_trans = FoldRankChangingLayoutTrans(x, trans_call)) {
+      if (auto attr = layout_trans.value()->attrs.as<LayoutTransformAttrs>()) {
+        // Prune any trivial layout transformation
+        if (attr->src_layout == attr->dst_layout) {
+          return x;
         }
-      } else {
-        CHECK(false) << "Expected transpose or layout_transform, but got "
-                     << Downcast<Op>(trans_call->op)->name;
       }
-      return std::move(attr_axes);
-    };
-
-    auto x = node_map[x_][0];
+      return layout_trans.value();
+    }
 
     // Initialize axes
     int ndim = Downcast<TensorType>(pre->checked_type())->shape.size();
@@ -157,10 +143,9 @@ class SimplifyTranspose : public DFPatternRewrite {
 
     // Collect axes changes from the matched pattern, including two consecutive transposes.
     std::vector<std::vector<int>> interm_axes;
-    Call trans_call = Downcast<Call>(post);
-    interm_axes.push_back(get_axes_from_call(trans_call, ndim));
+    interm_axes.push_back(GetTransposeAxisOrder(trans_call, ndim));
     trans_call = Downcast<Call>(trans_call->args[0]);
-    interm_axes.push_back(get_axes_from_call(trans_call, ndim));
+    interm_axes.push_back(GetTransposeAxisOrder(trans_call, ndim));
 
     // Calculate the final axes in reverse order (from root to output)
     auto it = interm_axes.rbegin();
@@ -190,6 +175,134 @@ class SimplifyTranspose : public DFPatternRewrite {
     return x;
   }
 
+  String PermuteLayout(const String& layout, std::vector<int> axes_order) const {
+    std::string new_layout{};
+    std::string old_layout{layout};
+    ICHECK_EQ(axes_order.size(), layout.size())
+        << "Number of axes must match the number of named axes in the layout to permute: length("
+        << old_layout << ") != " << axes_order.size();
+    std::stringstream order;
+    for (auto axis : axes_order) {
+      new_layout += old_layout[axis];
+      order << axis << ", ";
+    }
+    DLOG(INFO) << "Using transpose axes order {" << order.str()
+               << "} to permute layout: " << old_layout << " to " << new_layout;
+    return new_layout;
+  }
+
+  struct RankChangingLayoutDescriptor {
+    Layout src_layout;
+    Layout dst_layout;
+    // Either a rank changing layout transform or a transpose
+    Call other_transform;
+  };
+
+  std::unique_ptr<RankChangingLayoutDescriptor> GetRankChangeDescriptor(const Call& call) const {
+    std::unique_ptr<RankChangingLayoutDescriptor> desc{nullptr};
+    if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
+      if (attr->src_layout.length() != attr->dst_layout.length()) {
+        desc = std::make_unique<RankChangingLayoutDescriptor>();
+        desc->src_layout = Layout(attr->src_layout);
+        desc->dst_layout = Layout(attr->dst_layout);
+        desc->other_transform = Downcast<Call>(call->args[0]);
+      }
+    }
+    if (auto attr = Downcast<Call>(call->args[0])->attrs.as<LayoutTransformAttrs>()) {
+      if (attr->src_layout.length() != attr->dst_layout.length()) {
+        if (!desc) {
+          desc = std::make_unique<RankChangingLayoutDescriptor>();
+          desc->src_layout = Layout(attr->src_layout);
+          desc->dst_layout = Layout(attr->dst_layout);
+          desc->other_transform = call;
+        } else {
+          ICHECK(desc->src_layout->name == attr->dst_layout)
+              << "Back-to-back layout transforms must have the same intermediate layout: "
+              << desc->src_layout->name << " != " << attr->dst_layout;
+          desc->src_layout = Layout(attr->src_layout);
+        }
+      }
+    }
+    return desc;
+  }
+
+  /*!
+   * \brief Fuse a call and its argument into a single layout_transform operator
+   * when either the call or its argument is a rank changing layout_transform, e.g.,
+   *
+   *  Simplify
+   *
+   *  [N, H, W, C] -> Transpose -> [N, C, H, W] -> LayoutTrans -> [N, C, H, W, 4c]
+   *
+   *  to,
+   *
+   *  [N, H, W, C] -> LayoutTrans -> [N, C, H, W, 4c].
+   *
+   * \param data The input expression to the matched pattern
+   * \param call The pattern root; the second of two consecutive Transpose/LayoutTransform ops
+   */
+  Optional<Call> FoldRankChangingLayoutTrans(const Expr& data, const Call& call) const {
+    // Check to see if either the first or second call in the matched pattern
+    // is a rank changing layout transform. If so, return a descriptor containing
+    // the layouts and any additional transpose or layout transform op.
+    auto desc = GetRankChangeDescriptor(call);
+    if (desc == nullptr) {
+      // No rank changing layout transform
+      return Optional<Call>{nullptr};
+    }
+
+    Optional<Expr> output_layout_trans;
+    // Fuse a rank increasing layout transform and a preceding transpose
+    if (desc->src_layout->axes.size() < desc->dst_layout->axes.size()) {
+      auto axes = GetTransposeAxisOrder(desc->other_transform, desc->src_layout->axes.size());
+      // Calculate the reverse axis order and apply to the source layout
+      std::vector<int> inverse(axes.size());
+      for (size_t i = 0; i < axes.size(); i++) {
+        inverse[axes[i]] = i;
+      }
+      String new_layout = PermuteLayout(desc->src_layout->name, inverse);
+      output_layout_trans = MakeLayoutTransform(data, new_layout, desc->dst_layout->name);
+      // Fuse a rank decreasing layout transform followed by a transpose
+    } else if (desc->src_layout->axes.size() > desc->dst_layout->axes.size()) {
+      auto axes = GetTransposeAxisOrder(desc->other_transform, desc->dst_layout->axes.size());
+      String new_layout = PermuteLayout(desc->dst_layout->name, axes);
+      output_layout_trans = MakeLayoutTransform(data, desc->src_layout->name, new_layout);
+      // Fuse two back-to-back layout transformations which change rank
+    } else if (desc->other_transform->attrs.as<LayoutTransformAttrs>()) {
+      output_layout_trans =
+          MakeLayoutTransform(data, desc->src_layout->name, desc->dst_layout->name);
+    }
+    return Downcast<Call>(output_layout_trans);
+  }
+
+  std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) const {
+    std::vector<int> attr_axes;
+    if (auto attr = call->attrs.as<TransposeAttrs>()) {
+      if (attr->axes.defined()) {
+        for (int i = 0; i < ndim; ++i) {
+          int64_t axis = attr->axes[i];
+          axis += (axis < 0) ? ndim : 0;
+          attr_axes.push_back(axis);
+        }
+      } else {
+        // Empty axes means reverse
+        for (int i = ndim - 1; i >= 0; --i) {
+          attr_axes.push_back(i);
+        }
+      }
+    } else if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
+      Layout src_layout(attr->src_layout);
+      Layout dst_layout(attr->dst_layout);
+      for (int i = 0; i < ndim; ++i) {
+        attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
+      }
+    } else {
+      CHECK(false) << "Expected transpose or layout_transform, but got "
+                   << Downcast<Op>(call->op)->name;
+    }
+    return std::move(attr_axes);
+  }
+
  private:
   /*! \brief Pattern input */
   DFPattern x_;
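
To make the axis bookkeeping above concrete: in the rank-increasing branch of
FoldRankChangingLayoutTrans, the pass inverts the transpose's axis order and
uses it to permute the source layout string, as PermuteLayout does. A
standalone sketch of that arithmetic (illustrative Python, not part of the
patch):

    def permute_layout(layout, axes_order):
        # Mirror of the C++ PermuteLayout: pick layout characters in order.
        assert len(axes_order) == len(layout)
        return "".join(layout[a] for a in axes_order)

    # transpose(axes=[0, 3, 1, 2]) maps NHWC -> NCHW, then a layout_transform
    # takes NCHW -> NCHW4c. First invert the transpose's permutation ...
    axes = [0, 3, 1, 2]
    inverse = [0] * len(axes)
    for i, a in enumerate(axes):
        inverse[a] = i                      # inverse = [0, 2, 3, 1]

    # ... then permute the source layout with it to recover the fused input:
    print(permute_layout("NCHW", inverse))  # prints "NHWC"
    # The pair therefore folds to layout_transform(x, "NHWC", "NCHW4c").
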
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index d1dffa3..9f11d38 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -106,10 +106,176 @@ def test_simplify_transpose():
         y = relay.transpose(y, axes=[0, 2, 3, 1])
         return relay.Function([x], y)
 
+    # Test a series of transpose and rank changing layout_transform
+    def before4():
+        """
+        Simplify transpose->layout_transform and its inverse.
+
+        Input:
+        NHWC -> NCHW -> NCHW4c -> op -> NCHW4c -> NCHW -> NHWC
+
+        Simplified:
+        NHWC -> NCHW4c -> op -> NCHW4c -> NHWC
+        """
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")
+        y = relay.transpose(x, axes=[0, 3, 1, 2])
+        y = relay.layout_transform(y, "NCHW", "NCHW4c")
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW4c", "NCHW")
+        y = relay.transpose(y, axes=[0, 2, 3, 1])
+        return relay.Function([x], y)
+
+    def expected4():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NHWC", "NCHW4c")  # To NCHW4c
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW4c", "NHWC")  # To NHWC
+        return relay.Function([x], y)
+
+    def before5():
+        """
+        Simplify layout_transform->layout_transform and its inverse.
+
+        Input:
+        NHWC -> NCHW -> NCHW4c -> op -> NCHW4c -> NCHW -> NHWC
+
+        Simplified:
+        NHWC -> NCHW4c -> op -> NCHW4c -> NHWC
+        """
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NHWC", "NCHW")  # To NCHW
+        y = relay.layout_transform(y, "NCHW", "NCHW4c")  # To NCHW4c
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW4c", "NCHW")  # To NCHW
+        y = relay.layout_transform(y, "NCHW", "NHWC")  # To NHWC
+        return relay.Function([x], y)
+
+    def expected5():
+        x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
+        y = relay.layout_transform(x, "NHWC", "NCHW4c")  # To NCHW4c
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW4c", "NHWC")  # To NHWC
+        return relay.Function([x], y)
+
+    def before6():
+        """
+        Remove trivial layout_transform->layout_transform.
+
+        Input:
+        NCHW -> NHWC -> NCHW -> op
+
+        Simplified:
+        NCHW -> op
+        """
+
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+        y = relay.layout_transform(x, "NCHW", "NHWC")
+        y = relay.layout_transform(y, "NHWC", "NCHW")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected6():
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+        y = relay.nn.relu(x)
+        return relay.Function([x], y)
+
+    def before7():
+        """
+        Remove trivial layout_transform->layout_transform.
+
+        Input:
+        NCHW4c -> NCHW8c -> NCHW4c -> op
+
+        Simplified:
+        NCHW4c -> op
+        """
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
+        y = relay.layout_transform(x, "NCHW4c", "NCHW8c")
+        y = relay.layout_transform(y, "NCHW8c", "NCHW4c")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected7():
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
+        y = relay.nn.relu(x)
+        return relay.Function([x], y)
+
+    def before8():
+        """
+        Simplify layout_transform->layout_transform with rank contraction and expansion.
+
+        Input:
+        NCHW4c -> NCHW -> NCHW8c -> op
+
+        Simplified:
+        NCHW4c -> NCHW8c -> op
+        """
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
+        y = relay.layout_transform(x, "NCHW4c", "NCHW")
+        y = relay.layout_transform(y, "NCHW", "NCHW8c")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected8():
+        x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
+        y = relay.layout_transform(x, "NCHW4c", "NCHW8c")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def before9():
+        """
+        Remove trivial layout_transform->layout_transform.
+
+        Input:
+        NCHW -> NCHW4c -> NCHW -> op
+
+        Simplified:
+        NCHW -> op
+        """
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+        y = relay.layout_transform(x, "NCHW", "NCHW4c")
+        y = relay.layout_transform(y, "NCHW4c", "NCHW")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected9():
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+        y = relay.nn.relu(x)
+        return relay.Function([x], y)
+
+    def before10():
+        """
+        Simplify layout_transform->layout_transform without rank change into a single transpose.
+
+        Input:
+        NCHW -> NHWC -> CHWN -> op
+
+        Simplified:
+        NCHW -> CHWN -> op
+        """
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+        y = relay.layout_transform(x, "NCHW", "NHWC")
+        y = relay.layout_transform(y, "NHWC", "CHWN")
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
+    def expected10():
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+        y = relay.transpose(x, axes=[1, 2, 3, 0])
+        y = relay.nn.relu(y)
+        return relay.Function([x], y)
+
     for before, expected in [
         [before1(), expected1()],
         [before2(), expected2()],
         [before3(), expected3()],
+        [before4(), expected4()],
+        [before5(), expected5()],
+        [before6(), expected6()],
+        [before7(), expected7()],
+        [before8(), expected8()],
+        [before9(), expected9()],
+        [before10(), expected10()],
     ]:
         after = run_opt_pass(before, transform.SimplifyExpr())
         expected = run_opt_pass(expected, transform.InferType())