Posted to commits@tvm.apache.org by ma...@apache.org on 2023/09/12 08:02:42 UTC
[tvm] branch unity updated: [UNITY][Pass] Optimize redundant layout transform ops (#15678)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0e800cf3f8 [UNITY][Pass] Optimize redundant layout transform ops (#15678)
0e800cf3f8 is described below
commit 0e800cf3f85a7f8df3f331c00dd2a946b4c94f19
Author: Abhikrant Sharma <qu...@quicinc.com>
AuthorDate: Tue Sep 12 13:32:34 2023 +0530
[UNITY][Pass] Optimize redundant layout transform ops (#15678)
* [UNITY][Pass] Optimize redundant layout transform ops
The Relax AlterOpImpl pass introduces layout_transform operations. When two consecutive layout_transform operations compose to the identity, i.e. the second restores the layout of the first's input, the pair is redundant and can be cancelled out.
This pass removes such redundant layout_transform operations.
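As a minimal usage sketch (not part of this commit; it assumes `mod` is an IRModule produced after running AlterOpImpl):

    from tvm.relax.transform import OptimizeLayoutTransform

    # Back-to-back layout_transform ops whose composition restores the
    # original shape are rewritten away, leaving uses pointing at the
    # untransformed tensor.
    mod = OptimizeLayoutTransform()(mod)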
* Fix LINT errors
* Use function pass instead of module pass
* Fix more LINT errors
---
 python/tvm/relax/transform/__init__.py             |   1 +
 .../relax/transform/optimize_layout_transform.py   |  75 ++++++
 .../python/relax/test_optimize_layout_transform.py | 277 +++++++++++++++++++++
 3 files changed, 353 insertions(+)
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 1a8696ca06..68128db62d 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -19,6 +19,7 @@
 from .transform import *
 from .lazy_transform_params import LazyTransformParams
+from .optimize_layout_transform import OptimizeLayoutTransform
 
 # Import to register the legalization functions.
 from . import legalize_ops
diff --git a/python/tvm/relax/transform/optimize_layout_transform.py b/python/tvm/relax/transform/optimize_layout_transform.py
new file mode 100644
index 0000000000..a61a3bc239
--- /dev/null
+++ b/python/tvm/relax/transform/optimize_layout_transform.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
+"""Relax Optimize Layout Transform pass."""
+from tvm.ir.module import IRModule
+from tvm.ir.transform import PassContext
+from tvm.relax import Expr, Function
+from tvm.relax.dpl import is_op, rewrite_call, wildcard
+from . import function_pass
+
+
+@function_pass(opt_level=0)
+class OptimizeLayoutTransform:
+ """
+ Pass to remove redundant transform layout operators
+ introduced by AlterOpImpl pass.
+ """
+
+ def __init__(self):
+ self.input = wildcard()
+ pattern_transform_layout = is_op("relax.layout_transform")(self.input)
+ pattern_1 = is_op("relax.layout_transform")(pattern_transform_layout)
+
+ self.pattern = pattern_1
+
+ def transform_function(self, func: Expr, mod: IRModule, ctx: PassContext) -> IRModule:
+ """
+ Tranformation function to pattern match layout_transform -> layout_transform
+ pattern
+
+ Parameters
+ ----------
+ func: Expr
+ The relax function to be optimized
+
+ mod: IRModule
+ The ir module
+
+ ctx: PassContext
+ Relax pass context
+ """
+
+ updated_func = func
+ for _, func in mod.functions.items():
+ # Skip non-relax functions
+ if not isinstance(func, Function):
+ continue
+ # Skip primitive functions
+ if "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0:
+ continue
+
+ def rewriter(expr, matches):
+ arg1 = matches[self.pattern]
+ arg2 = matches[self.input]
+ if list(arg1.struct_info.shape) == list(arg2.struct_info.shape):
+ return arg2
+ return expr
+
+ updated_func = rewrite_call(self.pattern, rewriter, func)
+
+ return updated_func
diff --git a/tests/python/relax/test_optimize_layout_transform.py b/tests/python/relax/test_optimize_layout_transform.py
new file mode 100644
index 0000000000..bb6db3c6ed
--- /dev/null
+++ b/tests/python/relax/test_optimize_layout_transform.py
@@ -0,0 +1,277 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests to validate relax optimize layout tranform pass."""
+
+import numpy as np
+import pytest
+import tvm.testing
+from tvm import relax
+from tvm.ir.base import assert_structural_equal
+from tvm.relax.transform import DeadCodeElimination, FuseTIR, OptimizeLayoutTransform
+from tvm.script import ir as I, tir as T, relax as R
+
+
+def _run_pass_compare_output(Before, Expected):
+    fused_mod = OptimizeLayoutTransform()(Before)
+    if not relax.analysis.well_formed(fused_mod):
+        print("IRModule is not well-formed")
+
+    fused_mod = DeadCodeElimination()(fused_mod)
+    if not relax.analysis.well_formed(fused_mod):
+        print("IRModule is not well-formed")
+
+    fused_mod = FuseTIR()(fused_mod)
+    if not relax.analysis.well_formed(fused_mod):
+        print("IRModule is not well-formed")
+
+    tvm.ir.assert_structural_equal(Expected, fused_mod)
+
+
+def test_optimize_transform_layout_pass_one_arg():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def relax_add_replacement(
+            arg0: T.Buffer((4, 4), "float32"),
+            arg1: T.Buffer((4, 4), "float32"),
+            output: T.Buffer((4, 4), "float32"),
+        ):
+            T.func_attr({"operator_name": "relax.add"})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(4, 4):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                    T.writes(output[v_ax0, v_ax1])
+                    output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+
+        @R.function
+        def main(
+            x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")
+        ) -> R.Tensor((16,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    x, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    y, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv2 = R.call_tir(
+                    Before.relax_add_replacement,
+                    (lv, lv1),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv0: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    lv2, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
+                )
+                lv3: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    lv0, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv4: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    y, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv5 = R.call_tir(
+                    Before.relax_add_replacement,
+                    (lv4, lv3),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
+                )
+                gv: R.Tensor((16,), dtype="float32") = lv2_1
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def relax_add_replacement(
+            arg0: T.Buffer((4, 4), "float32"),
+            arg1: T.Buffer((4, 4), "float32"),
+            output: T.Buffer((4, 4), "float32"),
+        ):
+            T.func_attr({"operator_name": "relax.add"})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(4, 4):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                    T.writes(output[v_ax0, v_ax1])
+                    output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+
+        @R.function
+        def main(
+            x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")
+        ) -> R.Tensor((16,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    x, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    y, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv2 = R.call_tir(
+                    Expected.relax_add_replacement,
+                    (lv, lv1),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv4: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    y, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv5 = R.call_tir(
+                    Expected.relax_add_replacement,
+                    (lv4, lv2),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
+                )
+                gv: R.Tensor((16,), dtype="float32") = lv2_1
+                R.output(gv)
+            return gv
+
+    _run_pass_compare_output(Before, Expected)
+
+
+def test_optimize_transform_layout_pass_two_args():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def relax_add_replacement(
+            arg0: T.Buffer((4, 4), "float32"),
+            arg1: T.Buffer((4, 4), "float32"),
+            output: T.Buffer((4, 4), "float32"),
+        ):
+            T.func_attr({"operator_name": "relax.add"})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(4, 4):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                    T.writes(output[v_ax0, v_ax1])
+                    output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+
+        @R.function
+        def main(
+            x: R.Tensor((16,), dtype="float32"),
+            y: R.Tensor((16,), dtype="float32"),
+            z: R.Tensor((16,), dtype="float32"),
+        ) -> R.Tensor((16,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    x, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    y, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv2: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    z, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv3 = R.call_tir(
+                    Before.relax_add_replacement,
+                    (lv, lv1),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv4 = R.call_tir(
+                    Before.relax_add_replacement,
+                    (lv, lv2),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv5: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
+                )
+                lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    lv4, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
+                )
+                lv7: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    lv5, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv8: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    lv6, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv9 = R.call_tir(
+                    Before.relax_add_replacement,
+                    (lv7, lv8),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv10: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    lv9, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
+                )
+                gv: R.Tensor((16,), dtype="float32") = lv10
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def relax_add_replacement(
+            arg0: T.Buffer((4, 4), "float32"),
+            arg1: T.Buffer((4, 4), "float32"),
+            output: T.Buffer((4, 4), "float32"),
+        ):
+            T.func_attr({"operator_name": "relax.add"})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(4, 4):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                    T.writes(output[v_ax0, v_ax1])
+                    output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+
+        @R.function
+        def main(
+            x: R.Tensor((16,), dtype="float32"),
+            y: R.Tensor((16,), dtype="float32"),
+            z: R.Tensor((16,), dtype="float32"),
+        ) -> R.Tensor((16,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    x, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    y, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv2: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    z, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                lv3 = R.call_tir(
+                    Expected.relax_add_replacement,
+                    (lv, lv1),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv4 = R.call_tir(
+                    Expected.relax_add_replacement,
+                    (lv, lv2),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv5 = R.call_tir(
+                    Expected.relax_add_replacement,
+                    (lv3, lv4),
+                    out_sinfo=R.Tensor((4, 4), dtype="float32"),
+                )
+                lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None
+                )
+                gv: R.Tensor((16,), dtype="float32") = lv6
+                R.output(gv)
+            return gv
+
+    _run_pass_compare_output(Before, Expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()