Posted to commits@tvm.apache.org by co...@apache.org on 2021/05/04 08:06:07 UTC
[tvm] branch main updated: [Relay][Pass] Update SimplifyTranspose to correctly simplify rank changing layout transforms (#7807)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 396a09e [Relay][Pass] Update SimplifyTranspose to correctly simplify rank changing layout transforms (#7807)
396a09e is described below
commit 396a09e06441024f5b95dcf6762745368cf9d8e6
Author: Chris Sullivan <cs...@octoml.ai>
AuthorDate: Tue May 4 01:05:41 2021 -0700
[Relay][Pass] Update SimplifyTranspose to correctly simplify rank changing layout transforms (#7807)
---
src/relay/transforms/simplify_expr.cc | 175 +++++++++++++++++++++-----
tests/python/relay/test_pass_simplify_expr.py | 166 ++++++++++++++++++++++++
2 files changed, 310 insertions(+), 31 deletions(-)
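
For illustration only (not part of the commit), a minimal sketch of the simplification this patch enables, assuming a TVM build that includes this change and the standard Relay Python API:

    import tvm
    from tvm import relay

    # A transpose followed by a rank changing layout_transform, as in before4() below.
    x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")  # NHWC
    y = relay.transpose(x, axes=[0, 3, 1, 2])                    # NHWC -> NCHW
    y = relay.layout_transform(y, "NCHW", "NCHW4c")              # NCHW -> NCHW4c
    mod = tvm.IRModule.from_expr(relay.Function([x], y))

    mod = relay.transform.InferType()(mod)
    mod = relay.transform.SimplifyExpr()(mod)
    # The two operators should fold into a single layout_transform(x, "NHWC", "NCHW4c").
    print(mod)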
diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index 5662ef5..fb7a76f 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -31,6 +31,8 @@
#include <tvm/runtime/logging.h>
#include <limits>
+#include <memory>
+#include <string>
#include <utility>
#include "../op/tensor/transform.h"
@@ -117,36 +119,20 @@ class SimplifyTranspose : public DFPatternRewrite {
Expr Callback(const Expr& pre, const Expr& post,
const Map<DFPattern, Array<Expr>>& node_map) const override {
- // Helper function to get the axes from call node attribute
- auto get_axes_from_call = [](const Call trans_call, int ndim) {
- std::vector<int> attr_axes;
- if (auto attr = trans_call->attrs.as<TransposeAttrs>()) {
- if (attr->axes.defined()) {
- for (int i = 0; i < ndim; ++i) {
- int64_t axis = attr->axes[i];
- axis += (axis < 0) ? ndim : 0;
- attr_axes.push_back(axis);
- }
- } else {
- // Empty axes means reverse
- for (int i = ndim - 1; i >= 0; --i) {
- attr_axes.push_back(i);
- }
- }
- } else if (auto attr = trans_call->attrs.as<LayoutTransformAttrs>()) {
- Layout src_layout(attr->src_layout);
- Layout dst_layout(attr->dst_layout);
- for (int i = 0; i < ndim; ++i) {
- attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
+ auto x = node_map[x_][0];
+
+ Call trans_call = Downcast<Call>(post);
+
+ // Try to fuse any rank changing layout transformations
+ if (auto layout_trans = FoldRankChangingLayoutTrans(x, trans_call)) {
+ if (auto attr = layout_trans.value()->attrs.as<LayoutTransformAttrs>()) {
+ // Prune any trivial layout transformation
+ if (attr->src_layout == attr->dst_layout) {
+ return x;
}
- } else {
- CHECK(false) << "Expected transpose or layout_transform, but got "
- << Downcast<Op>(trans_call->op)->name;
}
- return std::move(attr_axes);
- };
-
- auto x = node_map[x_][0];
+ return layout_trans.value();
+ }
// Initialize axes
int ndim = Downcast<TensorType>(pre->checked_type())->shape.size();
@@ -157,10 +143,9 @@ class SimplifyTranspose : public DFPatternRewrite {
// Collect axes changes from the matched pattern, including two consecutive transposes.
std::vector<std::vector<int>> interm_axes;
- Call trans_call = Downcast<Call>(post);
- interm_axes.push_back(get_axes_from_call(trans_call, ndim));
+ interm_axes.push_back(GetTransposeAxisOrder(trans_call, ndim));
trans_call = Downcast<Call>(trans_call->args[0]);
- interm_axes.push_back(get_axes_from_call(trans_call, ndim));
+ interm_axes.push_back(GetTransposeAxisOrder(trans_call, ndim));
// Calculate the final axes in reverse order (from root to output)
auto it = interm_axes.rbegin();
@@ -190,6 +175,134 @@ class SimplifyTranspose : public DFPatternRewrite {
return x;
}
+ String PermuteLayout(const String& layout, std::vector<int> axes_order) const {
+ std::string new_layout{};
+ std::string old_layout{layout};
+ ICHECK_EQ(axes_order.size(), layout.size())
+ << "Number of axes must match the number of named axes in the layout to permute: length("
+ << old_layout << ") != " << axes_order.size();
+ std::stringstream order;
+ for (auto axis : axes_order) {
+ new_layout += old_layout[axis];
+ order << axis << ", ";
+ }
+ DLOG(INFO) << "Using transpose axes order {" << order.str()
+ << "} to permute layout: " << old_layout << " to " << new_layout;
+ return new_layout;
+ }
+
+ struct RankChangingLayoutDescriptor {
+ Layout src_layout;
+ Layout dst_layout;
+ // Either a rank changing layout transform or a transpose
+ Call other_transform;
+ };
+
+ std::unique_ptr<RankChangingLayoutDescriptor> GetRankChangeDescriptor(const Call& call) const {
+ std::unique_ptr<RankChangingLayoutDescriptor> desc{nullptr};
+ if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
+ if (attr->src_layout.length() != attr->dst_layout.length()) {
+ desc = std::make_unique<RankChangingLayoutDescriptor>();
+ desc->src_layout = Layout(attr->src_layout);
+ desc->dst_layout = Layout(attr->dst_layout);
+ desc->other_transform = Downcast<Call>(call->args[0]);
+ }
+ }
+ if (auto attr = Downcast<Call>(call->args[0])->attrs.as<LayoutTransformAttrs>()) {
+ if (attr->src_layout.length() != attr->dst_layout.length()) {
+ if (!desc) {
+ desc = std::make_unique<RankChangingLayoutDescriptor>();
+ desc->src_layout = Layout(attr->src_layout);
+ desc->dst_layout = Layout(attr->dst_layout);
+ desc->other_transform = call;
+ } else {
+ ICHECK(desc->src_layout->name == attr->dst_layout)
+ << "Back-to-back layout transforms must have the same intermediate layout: "
+ << desc->src_layout->name << " != " << attr->dst_layout;
+ desc->src_layout = Layout(attr->src_layout);
+ }
+ }
+ }
+ return desc;
+ }
+
+ /*
+ * \brief Fuse call and its argument into a single layout_transform operator
+ * when either call or its argument is a rank changing layout_transform, e.g.,
+ *
+ * Simplify
+ *
+ * [N, H, W, C] -> Transpose -> [N, C, H, W] -> LayoutTrans -> [N, C, H, W, 4c]
+ *
+ * to,
+ *
+ * [N, H, W, C] -> LayoutTrans -> [N, C, H, W, 4c].
+ *
+ * \param data The input expression to the matched pattern
+ * \param call The pattern root; the second of two consecutive Transpose/LayoutTransform ops
+ */
+ Optional<Call> FoldRankChangingLayoutTrans(const Expr& data, const Call& call) const {
+ // Check to see if either the first or second call in matched pattern
+ // is a rank changing layout transform. If so, return a descriptor containing
+ // the layouts and any additional transpose or layout transform op.
+ auto desc = GetRankChangeDescriptor(call);
+ if (desc == nullptr) {
+ // No rank changing layout transform
+ return Optional<Call>{nullptr};
+ }
+
+ Optional<Expr> output_layout_trans;
+ // Fuse a rank increasing layout transform and a preceding transpose
+ if (desc->src_layout->axes.size() < desc->dst_layout->axes.size()) {
+ auto axes = GetTransposeAxisOrder(desc->other_transform, desc->src_layout->axes.size());
+ // Calculate the reverse axis order and apply to the source layout
+ std::vector<int> inverse(axes.size());
+ for (size_t i = 0; i < axes.size(); i++) {
+ inverse[axes[i]] = i;
+ }
+ String new_layout = PermuteLayout(desc->src_layout->name, inverse);
+ output_layout_trans = MakeLayoutTransform(data, new_layout, desc->dst_layout->name);
+ // Fuse a rank decreasing layout transform followed by a transpose
+ } else if (desc->src_layout->axes.size() > desc->dst_layout->axes.size()) {
+ auto axes = GetTransposeAxisOrder(desc->other_transform, desc->dst_layout->axes.size());
+ String new_layout = PermuteLayout(desc->dst_layout->name, axes);
+ output_layout_trans = MakeLayoutTransform(data, desc->src_layout->name, new_layout);
+ // Fuse two back-to-back layout transformations which change rank
+ } else if (desc->other_transform->attrs.as<LayoutTransformAttrs>()) {
+ output_layout_trans =
+ MakeLayoutTransform(data, desc->src_layout->name, desc->dst_layout->name);
+ }
+ return Downcast<Call>(output_layout_trans);
+ }
+
+ std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) const {
+ std::vector<int> attr_axes;
+ if (auto attr = call->attrs.as<TransposeAttrs>()) {
+ if (attr->axes.defined()) {
+ for (int i = 0; i < ndim; ++i) {
+ int64_t axis = attr->axes[i];
+ axis += (axis < 0) ? ndim : 0;
+ attr_axes.push_back(axis);
+ }
+ } else {
+ // Empty axes means reverse
+ for (int i = ndim - 1; i >= 0; --i) {
+ attr_axes.push_back(i);
+ }
+ }
+ } else if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
+ Layout src_layout(attr->src_layout);
+ Layout dst_layout(attr->dst_layout);
+ for (int i = 0; i < ndim; ++i) {
+ attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
+ }
+ } else {
+ CHECK(false) << "Expected transpose or layout_transform, but got "
+ << Downcast<Op>(call->op)->name;
+ }
+ return std::move(attr_axes);
+ }
+
private:
/*! \brief Pattern input */
DFPattern x_;
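
As an aside (not part of the patch), the axis bookkeeping in GetTransposeAxisOrder and the inverse permutation computed in FoldRankChangingLayoutTrans can be mirrored in plain Python for primal (rank preserving) layouts; transpose_axis_order and inverse_permutation below are illustrative helper names, not TVM APIs:

    def transpose_axis_order(src_layout, dst_layout):
        # Mirrors the layout_transform branch of GetTransposeAxisOrder:
        # destination axis i comes from position src_layout.IndexOf(dst_layout[i]).
        return [src_layout.index(axis) for axis in dst_layout]

    def inverse_permutation(axes):
        # Mirrors the rank increasing branch of FoldRankChangingLayoutTrans,
        # which undoes the preceding transpose before folding it into the
        # layout_transform via PermuteLayout.
        inverse = [0] * len(axes)
        for i, axis in enumerate(axes):
            inverse[axis] = i
        return inverse

    print(transpose_axis_order("NCHW", "NHWC"))  # [0, 2, 3, 1]
    print(inverse_permutation([0, 3, 1, 2]))     # [0, 2, 3, 1], i.e. NCHW permuted back to NHWC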
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index d1dffa3..9f11d38 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -106,10 +106,176 @@ def test_simplify_transpose():
y = relay.transpose(y, axes=[0, 2, 3, 1])
return relay.Function([x], y)
+ # Test a series of transpose and rank changing layout_transform
+ def before4():
+ """
+ Simplify transpose->layout_transform and its inverse.
+
+ Input:
+ NHWC -> NCHW -> NCHW4c -> op -> NCHW4c -> NCHW -> NHWC
+
+ Simplified:
+ NHWC -> NCHW4c -> op -> NCHW4c -> NHWC
+ """
+ x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32")
+ y = relay.transpose(x, axes=[0, 3, 1, 2])
+ y = relay.layout_transform(y, "NCHW", "NCHW4c")
+ y = relay.nn.relu(y)
+ y = relay.layout_transform(y, "NCHW4c", "NCHW")
+ y = relay.transpose(y, axes=[0, 2, 3, 1])
+ return relay.Function([x], y)
+
+ def expected4():
+ x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32") # NHWC
+ y = relay.layout_transform(x, "NHWC", "NCHW4c") # To NCHW4c
+ y = relay.nn.relu(y)
+ y = relay.layout_transform(y, "NCHW4c", "NHWC") # To NHWC
+ return relay.Function([x], y)
+
+ def before5():
+ """
+ Simplify layout_transform->layout_transform and its inverse.
+
+ Input:
+ NHWC -> NCHW -> NCHW4c -> op -> NCHW4c -> NCHW -> NHWC
+
+ Simplified:
+ NHWC -> NCHW4c -> op -> NCHW4c -> NHWC
+ """
+ x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32") # NHWC
+ y = relay.layout_transform(x, "NHWC", "NCHW") # To NCHW
+ y = relay.layout_transform(y, "NCHW", "NCHW4c") # To NCHW4c
+ y = relay.nn.relu(y)
+ y = relay.layout_transform(y, "NCHW4c", "NCHW") # To NCHW
+ y = relay.layout_transform(y, "NCHW", "NHWC") # To NHWC
+ return relay.Function([x], y)
+
+ def expected5():
+ x = relay.var("x", shape=(1, 56, 56, 128), dtype="float32") # NHWC
+ y = relay.layout_transform(x, "NHWC", "NCHW4c") # To NCHW4c
+ y = relay.nn.relu(y)
+ y = relay.layout_transform(y, "NCHW4c", "NHWC") # To NHWC
+ return relay.Function([x], y)
+
+ def before6():
+ """
+ Remove trivial layout_transform->layout_transform.
+
+ Input:
+ NCHW -> NHWC -> NCHW -> op
+
+ Simplified:
+ NCHW -> op
+ """
+
+ x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+ y = relay.layout_transform(x, "NCHW", "NHWC")
+ y = relay.layout_transform(y, "NHWC", "NCHW")
+ y = relay.nn.relu(y)
+ return relay.Function([x], y)
+
+ def expected6():
+ x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+ y = relay.nn.relu(x)
+ return relay.Function([x], y)
+
+ def before7():
+ """
+ Remove trivial layout_transform->layout_transform.
+
+ Input:
+ NCHW4c -> NCHW8c -> NCHW4c -> op
+
+ Simplified:
+ NCHW4c -> op
+ """
+ x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
+ y = relay.layout_transform(x, "NCHW4c", "NCHW8c")
+ y = relay.layout_transform(y, "NCHW8c", "NCHW4c")
+ y = relay.nn.relu(y)
+ return relay.Function([x], y)
+
+ def expected7():
+ x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
+ y = relay.nn.relu(x)
+ return relay.Function([x], y)
+
+ def before8():
+ """
+ Simplify layout_transform->layout_transform with rank contraction and expansion
+
+ Input:
+ NCHW4c -> NCHW -> NCHW8c -> op
+
+ Simplified:
+ NCHW4c -> NCHW8c -> op
+ """
+ x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
+ y = relay.layout_transform(x, "NCHW4c", "NCHW")
+ y = relay.layout_transform(y, "NCHW", "NCHW8c")
+ y = relay.nn.relu(y)
+ return relay.Function([x], y)
+
+ def expected8():
+ x = relay.var("x", shape=(1, 32, 56, 56, 4), dtype="float32")
+ y = relay.layout_transform(x, "NCHW4c", "NCHW8c")
+ y = relay.nn.relu(y)
+ return relay.Function([x], y)
+
+ def before9():
+ """
+ Remove trivial layout_transform->layout_transform.
+
+ Input:
+ NCHW -> NCHW4c -> NCHW -> op
+
+ Simplified:
+ NCHW -> op
+ """
+ x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+ y = relay.layout_transform(x, "NCHW", "NCHW4c")
+ y = relay.layout_transform(y, "NCHW4c", "NCHW")
+ y = relay.nn.relu(y)
+ return relay.Function([x], y)
+
+ def expected9():
+ x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+ y = relay.nn.relu(x)
+ return relay.Function([x], y)
+
+ def before10():
+ """
+ Simplify layout_transform->layout_transform without rank change into a transpose.
+
+ Input:
+ NCHW -> NHWC -> CHWN -> op
+
+ Simplified:
+ NCHW -> CHWN -> op
+ """
+ x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+ y = relay.layout_transform(x, "NCHW", "NHWC")
+ y = relay.layout_transform(y, "NHWC", "CHWN")
+ y = relay.nn.relu(y)
+ return relay.Function([x], y)
+
+ def expected10():
+ x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+ y = relay.transpose(x, axes=[1, 2, 3, 0])
+ y = relay.nn.relu(y)
+ return relay.Function([x], y)
+
for before, expected in [
[before1(), expected1()],
[before2(), expected2()],
[before3(), expected3()],
+ [before4(), expected4()],
+ [before5(), expected5()],
+ [before6(), expected6()],
+ [before7(), expected7()],
+ [before8(), expected8()],
+ [before9(), expected9()],
+ [before10(), expected10()],
]:
after = run_opt_pass(before, transform.SimplifyExpr())
expected = run_opt_pass(expected, transform.InferType())