You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by sa...@apache.org on 2023/12/04 04:59:18 UTC

(tvm) branch unity updated: [Unity] [Transform] Remove iteration over functions in function pass (#16173)

This is an automated email from the ASF dual-hosted git repository.

sanirudh pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 8f95f6147a [Unity] [Transform] Remove iteration over functions in function pass (#16173)
8f95f6147a is described below

commit 8f95f6147a0a5a633a6759ab2e70d8ab0785c65a
Author: Anirudh Sundar Subramaniam <qu...@quicinc.com>
AuthorDate: Mon Dec 4 10:29:13 2023 +0530

    [Unity] [Transform] Remove iteration over functions in function pass (#16173)
    
    [Unity] [Transform] Remove iteration over functions in function pass
    
    There was a small redundancy in a couple of function passes: they
    iterated over all functions in the module inside transform_function,
    even though the function pass itself already performs that iteration
    and calls transform_function for each function.
    
    As a result, every call unnecessarily rewrote every function in the
    module. It then returned only the rewrite of the last eligible function
    it visited, rather than the rewrite of the function it was given.
---
 .../relax/transform/optimize_layout_transform.py   | 37 ++++++++++------------
 .../relax/transform/remove_redundant_reshape.py    | 37 ++++++++++------------
 2 files changed, 33 insertions(+), 41 deletions(-)

diff --git a/python/tvm/relax/transform/optimize_layout_transform.py b/python/tvm/relax/transform/optimize_layout_transform.py
index 4fe9d86555..b743e98e5c 100644
--- a/python/tvm/relax/transform/optimize_layout_transform.py
+++ b/python/tvm/relax/transform/optimize_layout_transform.py
@@ -19,7 +19,7 @@
 from tvm.ir import structural_equal
 from tvm.ir.module import IRModule
 from tvm.ir.transform import PassContext
-from tvm.relax import Expr, Function
+from tvm.relax import Expr
 from tvm.relax.dpl import TuplePattern, is_op, rewrite_call, wildcard
 
 from . import function_pass
@@ -63,27 +63,24 @@ class OptimizeLayoutTransform:
 
         self.mod = mod
         updated_func = func
-        for _, func in mod.functions_items():
-            # Skip non-relax functions
-            if not isinstance(func, Function):
-                continue
-            # Skip primitive functions
-            if "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0:
-                continue
 
-            def rewriter(expr, matches):
-                arg1 = matches[self.pattern]
-                if self.pattern_2 not in matches.keys():
+        # Skip primitive functions
+        if "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0:
+            return updated_func
+
+        def rewriter(expr, matches):
+            arg1 = matches[self.pattern]
+            if self.pattern_2 not in matches.keys():
+                arg2 = matches[self.input]
+            else:
+                arg2 = matches[self.gv_]
+                if "remove_pad" == self.mod[arg2].attrs["operator_name"]:
                     arg2 = matches[self.input]
-                else:
-                    arg2 = matches[self.gv_]
-                    if "remove_pad" == self.mod[arg2].attrs["operator_name"]:
-                        arg2 = matches[self.input]
-                if hasattr(arg1.struct_info, "shape") and hasattr(arg2.struct_info, "shape"):
-                    if structural_equal(arg1.struct_info.shape, arg2.struct_info.shape):
-                        return arg2
-                return expr
+            if hasattr(arg1.struct_info, "shape") and hasattr(arg2.struct_info, "shape"):
+                if structural_equal(arg1.struct_info.shape, arg2.struct_info.shape):
+                    return arg2
+            return expr
 
-            updated_func = rewrite_call(self.pattern, rewriter, func)
+        updated_func = rewrite_call(self.pattern, rewriter, func)
 
         return updated_func
diff --git a/python/tvm/relax/transform/remove_redundant_reshape.py b/python/tvm/relax/transform/remove_redundant_reshape.py
index bdd89e2761..07e0963caf 100644
--- a/python/tvm/relax/transform/remove_redundant_reshape.py
+++ b/python/tvm/relax/transform/remove_redundant_reshape.py
@@ -19,7 +19,7 @@
 from tvm import IRModule, relax
 from tvm.ir import structural_equal
 from tvm.ir.transform import PassContext
-from tvm.relax import Expr, Function
+from tvm.relax import Expr
 from tvm.relax.dpl import is_op, rewrite_call, wildcard
 
 from . import function_pass
@@ -58,29 +58,24 @@ class RemoveRedundantReshape:
         """
 
         updated_func = func
-        for _, funct in mod.functions_items():
-            # Skip non-relax functions
-            if not isinstance(funct, Function):
-                continue
-            # Skip primitive functions
-            if "Primitive" in funct.attrs.keys() and funct.attrs["Primitive"] != 0:
-                continue
 
-            def rewriter(expr, matches):
-                arg = matches[self.input1]
+        # Skip primitive functions
+        if "Primitive" in func.attrs.keys() and func.attrs["Primitive"] != 0:
+            return updated_func
 
-                if self.repeated_reshape in matches:
-                    output_shape = matches[self.repeated_reshape].args[1]
-                    return relax.op.reshape(arg, output_shape)
+        def rewriter(expr, matches):
+            arg = matches[self.input1]
 
-                elif self.no_op_reshape in matches:
-                    output_shape = matches[self.no_op_reshape].args[1]
-                    if arg.struct_info.shape and structural_equal(
-                        arg.struct_info.shape, output_shape
-                    ):
-                        return arg
-                return expr
+            if self.repeated_reshape in matches:
+                output_shape = matches[self.repeated_reshape].args[1]
+                return relax.op.reshape(arg, output_shape)
 
-            updated_func = rewrite_call(self.pattern, rewriter, funct)
+            elif self.no_op_reshape in matches:
+                output_shape = matches[self.no_op_reshape].args[1]
+                if arg.struct_info.shape and structural_equal(arg.struct_info.shape, output_shape):
+                    return arg
+            return expr
+
+        updated_func = rewrite_call(self.pattern, rewriter, func)
 
         return updated_func