Posted to commits@tvm.apache.org by ma...@apache.org on 2023/09/12 08:02:42 UTC

[tvm] branch unity updated: [UNITY][Pass] Optimize redundant layout transform ops (#15678)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0e800cf3f8 [UNITY][Pass] Optimize redundant layout transform ops (#15678)
0e800cf3f8 is described below

commit 0e800cf3f85a7f8df3f331c00dd2a946b4c94f19
Author: Abhikrant Sharma <qu...@quicinc.com>
AuthorDate: Tue Sep 12 13:32:34 2023 +0530

    [UNITY][Pass] Optimize redundant layout transform ops (#15678)
    
    * [UNITY][Pass] Optimize redundant layout transform ops
    
    The Relax AlterOpImpl pass introduces layout_transform operations. When two consecutive layout_transform operations cancel each other out (the output layout matches the original input layout), the pair is redundant.
    This pass removes such redundant layout_transform operations.
    
    * Fix LINT errors
    
    * Use function pass instead of module pass
    
    * Fix more LINT errors
---
 python/tvm/relax/transform/__init__.py             |   1 +
 .../relax/transform/optimize_layout_transform.py   |  75 ++++++
 .../python/relax/test_optimize_layout_transform.py | 277 +++++++++++++++++++++
 3 files changed, 353 insertions(+)
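
A minimal usage sketch of the new pass (assuming an IRModule "mod" that has
already been processed by AlterOpImpl; pairing it with DeadCodeElimination to
drop the now-unused bindings mirrors the test helper in this patch):

    from tvm.relax.transform import DeadCodeElimination, OptimizeLayoutTransform

    # Cancel back-to-back layout_transform pairs whose output shape matches
    # the original input, then remove the dangling bindings left behind.
    mod = OptimizeLayoutTransform()(mod)
    mod = DeadCodeElimination()(mod)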

diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 1a8696ca06..68128db62d 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -19,6 +19,7 @@
 
 from .transform import *
 from .lazy_transform_params import LazyTransformParams
+from .optimize_layout_transform import OptimizeLayoutTransform
 
 # Import to register the legalization functions.
 from . import legalize_ops
diff --git a/python/tvm/relax/transform/optimize_layout_transform.py b/python/tvm/relax/transform/optimize_layout_transform.py
new file mode 100644
index 0000000000..a61a3bc239
--- /dev/null
+++ b/python/tvm/relax/transform/optimize_layout_transform.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
+"""Relax Optimize Layout Transform pass."""
+from tvm.ir.module import IRModule
+from tvm.ir.transform import PassContext
+from tvm.relax import Expr, Function
+from tvm.relax.dpl import is_op, rewrite_call, wildcard
+from . import function_pass
+
+
+@function_pass(opt_level=0)
+class OptimizeLayoutTransform:
+    """
+    Pass to remove redundant layout_transform operators
+    introduced by the AlterOpImpl pass.
+    """
+
+    def __init__(self):
+        self.input = wildcard()
+        pattern_transform_layout = is_op("relax.layout_transform")(self.input)
+        pattern_1 = is_op("relax.layout_transform")(pattern_transform_layout)
+
+        self.pattern = pattern_1
+
+    def transform_function(self, func: Expr, mod: IRModule, ctx: PassContext) -> Function:
+        """
+        Transformation function to pattern match the layout_transform -> layout_transform
+        pattern
+
+        Parameters
+        ----------
+        func: Expr
+            The relax function to be optimized
+
+        mod: IRModule
+            The ir module
+
+        ctx: PassContext
+            Relax pass context
+        """
+
+        updated_func = func
+        for _, func in mod.functions.items():
+            # Skip non-relax functions
+            if not isinstance(func, Function):
+                continue
+            # Skip primitive functions
+            if "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0:
+                continue
+
+            def rewriter(expr, matches):
+                arg1 = matches[self.pattern]
+                arg2 = matches[self.input]
+                if list(arg1.struct_info.shape) == list(arg2.struct_info.shape):
+                    return arg2
+                return expr
+
+            updated_func = rewrite_call(self.pattern, rewriter, func)
+
+        return updated_func
diff --git a/tests/python/relax/test_optimize_layout_transform.py b/tests/python/relax/test_optimize_layout_transform.py
new file mode 100644
index 0000000000..bb6db3c6ed
--- /dev/null
+++ b/tests/python/relax/test_optimize_layout_transform.py
@@ -0,0 +1,277 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests to validate relax optimize layout tranform pass."""
+
+import numpy as np
+import pytest
+import tvm.testing
+from tvm import relax
+from tvm.ir.base import assert_structural_equal
+from tvm.relax.transform import DeadCodeElimination, FuseTIR, OptimizeLayoutTransform
+from tvm.script import ir as I, tir as T, relax as R
+
+
+def _run_pass_compare_output(Before, Expected):
+    fused_mod = OptimizeLayoutTransform()(Before)
+    if not relax.analysis.well_formed(fused_mod):
+        print("IRModule is not well-formed")
+
+    fused_mod = DeadCodeElimination()(fused_mod)
+    if not relax.analysis.well_formed(fused_mod):
+        print("IRModule is not well-formed")
+
+    fused_mod = FuseTIR()(fused_mod)
+    if not relax.analysis.well_formed(fused_mod):
+        print("IRModule is not well-formed")
+
+    tvm.ir.assert_structural_equal(Expected, fused_mod)
+
+
+def test_optimize_transform_layout_pass_one_arg():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def relax_add_replacement(
+            arg0: T.Buffer((4, 4), "float32"),
+            arg1: T.Buffer((4, 4), "float32"),
+            output: T.Buffer((4, 4), "float32"),
+        ):
+            T.func_attr({"operator_name": "relax.add"})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(4, 4):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                    T.writes(output[v_ax0, v_ax1])
+                    output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+
+        @R.function
+        def main(
+            x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")
+        ) -> R.Tensor((16,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    x, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    y, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv2 = R.call_tir(
+                    Before.relax_add_replacement,
+                    (lv, lv1),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv0: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    lv2, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
+                )
+                lv3: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    lv0, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv4: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    y, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv5 = R.call_tir(
+                    Before.relax_add_replacement,
+                    (lv4, lv3),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
+                )
+                gv: R.Tensor((16,), dtype="float32") = lv2_1
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def relax_add_replacement(
+            arg0: T.Buffer((4, 4), "float32"),
+            arg1: T.Buffer((4, 4), "float32"),
+            output: T.Buffer((4, 4), "float32"),
+        ):
+            T.func_attr({"operator_name": "relax.add"})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(4, 4):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                    T.writes(output[v_ax0, v_ax1])
+                    output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+
+        @R.function
+        def main(
+            x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")
+        ) -> R.Tensor((16,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    x, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    y, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv2 = R.call_tir(
+                    Expected.relax_add_replacement,
+                    (lv, lv1),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv4: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    y, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv5 = R.call_tir(
+                    Expected.relax_add_replacement,
+                    (lv4, lv2),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
+                )
+                gv: R.Tensor((16,), dtype="float32") = lv2_1
+                R.output(gv)
+            return gv
+
+    _run_pass_compare_output(Before, Expected)
+
+
+def test_optimize_transform_layout_pass_two_args():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def relax_add_replacement(
+            arg0: T.Buffer((4, 4), "float32"),
+            arg1: T.Buffer((4, 4), "float32"),
+            output: T.Buffer((4, 4), "float32"),
+        ):
+            T.func_attr({"operator_name": "relax.add"})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(4, 4):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                    T.writes(output[v_ax0, v_ax1])
+                    output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+
+        @R.function
+        def main(
+            x: R.Tensor((16,), dtype="float32"),
+            y: R.Tensor((16,), dtype="float32"),
+            z: R.Tensor((16,), dtype="float32"),
+        ) -> R.Tensor((16,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    x, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    y, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv2: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    z, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv3 = R.call_tir(
+                    Before.relax_add_replacement,
+                    (lv, lv1),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv4 = R.call_tir(
+                    Before.relax_add_replacement,
+                    (lv, lv2),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv5: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
+                )
+                lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    lv4, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
+                )
+                lv7: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    lv5, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv8: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    lv6, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv9 = R.call_tir(
+                    Before.relax_add_replacement,
+                    (lv7, lv8),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv10: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    lv9, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
+                )
+                gv: R.Tensor((16,), dtype="float32") = lv10
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def relax_add_replacement(
+            arg0: T.Buffer((4, 4), "float32"),
+            arg1: T.Buffer((4, 4), "float32"),
+            output: T.Buffer((4, 4), "float32"),
+        ):
+            T.func_attr({"operator_name": "relax.add"})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(4, 4):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                    T.writes(output[v_ax0, v_ax1])
+                    output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+
+        @R.function
+        def main(
+            x: R.Tensor((16,), dtype="float32"),
+            y: R.Tensor((16,), dtype="float32"),
+            z: R.Tensor((16,), dtype="float32"),
+        ) -> R.Tensor((16,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    x, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    y, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv2: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    z, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv3 = R.call_tir(
+                    Expected.relax_add_replacement,
+                    (lv, lv1),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv4 = R.call_tir(
+                    Expected.relax_add_replacement,
+                    (lv, lv2),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv5 = R.call_tir(
+                    Expected.relax_add_replacement,
+                    (lv3, lv4),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
+                )
+                gv: R.Tensor((16,), dtype="float32") = lv6
+                R.output(gv)
+            return gv
+
+    _run_pass_compare_output(Before, Expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()