Posted to commits@tvm.apache.org by "slyubomirsky (via GitHub)" <gi...@apache.org> on 2023/03/21 22:42:09 UTC

[GitHub] [tvm] slyubomirsky opened a new pull request, #14361: [Unity][Transform] Common Subexpression Elimination

slyubomirsky opened a new pull request, #14361:
URL: https://github.com/apache/tvm/pull/14361

   This PR implements common subexpression elimination (CSE) within dataflow blocks. Since we use A-normal form (ANF) internally, the only nesting that really occurs is in tuples, so the pass only needs to inspect the right-hand side (RHS) of each binding to determine whether it has already encountered that subexpression.
   
   Co-authored by @psrivas2 (Prakalp Srivastava <pr...@octoml.ai>)
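A minimal sketch of RHS-only CSE over one dataflow block, in plain Python. It assumes a hypothetical toy ANF encoding (variables are strings, calls are `("call", op, *args)`, tuples are `("tuple", *fields)`) and is not the pass's actual implementation or TVM's data structures. Because a binding's RHS is flat apart from nested tuples, a dictionary keyed on the RHS is enough to spot repeats, and nested tuples hash structurally like any other Python tuple.

```python
from typing import Dict, List, Tuple, Union

# Hypothetical toy ANF encoding (not TVM's data structures):
#   a variable is a str, e.g. "x"
#   a call is ("call", op_name, arg, ...), e.g. ("call", "add", "x", "y")
#   a tuple is ("tuple", field, ...); fields may themselves be tuples
Expr = Union[str, tuple]
Binding = Tuple[str, Expr]


def eliminate_common_subexpr(block: List[Binding]) -> List[Binding]:
    """CSE over one dataflow block: only the RHS of each binding is inspected."""
    first_binding: Dict[Expr, str] = {}  # RHS -> first var bound to it
    result: List[Binding] = []
    for var, rhs in block:
        if isinstance(rhs, str):
            # Var-to-var bindings are left for a binding-canonicalization pass.
            result.append((var, rhs))
        elif rhs in first_binding:
            # Same RHS seen earlier: bind to the earlier var instead of recomputing.
            result.append((var, first_binding[rhs]))
        else:
            first_binding[rhs] = var
            result.append((var, rhs))
    return result


# test_simple from this PR, encoded in the toy IR.
before = [
    ("lv0", ("call", "add", "x", "y")),
    ("lv1", ("call", "add", "x", "y")),
    ("gv", ("call", "multiply", "lv0", "lv1")),
]
after = eliminate_common_subexpr(before)

# A repeated nested-tuple RHS is caught the same way.
nested = [
    ("t0", ("tuple", ("tuple", "x", "x"), "x")),
    ("t1", ("tuple", ("tuple", "x", "x"), "x")),
]
nested_after = eliminate_common_subexpr(nested)
```

On `test_simple` this yields `lv1 = lv0` and leaves the use of `lv1` in `gv` untouched, as in the Expected module. The toy only matches whole right-hand sides; it does not hoist repeated inner tuples into new bindings the way `test_repeated_inner_tuples` expects.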


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] psrivas2 commented on pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "psrivas2 (via GitHub)" <gi...@apache.org>.
psrivas2 commented on PR #14361:
URL: https://github.com/apache/tvm/pull/14361#issuecomment-1479684755

   cc @masahi 




[GitHub] [tvm] slyubomirsky commented on a diff in pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "slyubomirsky (via GitHub)" <gi...@apache.org>.
slyubomirsky commented on code in PR #14361:
URL: https://github.com/apache/tvm/pull/14361#discussion_r1146973483


##########
tests/python/relax/test_transform_cse.py:
##########
@@ -0,0 +1,186 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test eliminate common subexpr pass"""
+import tvm
+import tvm.testing
+from tvm.relax.transform import EliminateCommonSubexpr
+from tvm.script.parser import ir as I, relax as R, tir as T
+
+import numpy as np
+
+
+def verify(input, expected):
+    tvm.ir.assert_structural_equal(EliminateCommonSubexpr()(input), expected)
+
+
+def test_simple():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv0 = R.add(x, y)
+                lv1 = R.add(x, y)
+                gv = R.multiply(lv0, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv0 = R.add(x, y)
+                # can combine with canonicalizing bindings
+                # and getting rid of unused bindings to eliminate this line too
+                lv1 = lv0
+                gv = R.multiply(lv0, lv1)
+                R.output(gv)
+            return gv
+
+    verify(Before, Expected)
+
+
+def test_constants():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")):
+            with R.dataflow():
+                # we are not going to bind the constant 1 to a var
+                lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32"))
+                # we expect to bind the repeated large constants
+                lv1 = R.add(
+                    R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
+                    R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
+                )
+                gv = (lv0, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")):
+            with R.dataflow():
+                lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32"))
+                lv1 = R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32")))
+                lv2 = R.add(lv1, lv1)
+                gv = (lv0, lv2)
+                R.output(gv)
+            return gv
+
+    verify(Before, Expected)
+
+
+def test_repeated_inner_tuples():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                # repeated units: (x, x), (x, (x, x)), ((x, x), (x, (x, x)))
+                tup = (((x, x), (x, (x, x))), ((x, x), (x, (x, x))), (x, (x, x)))
+                gv = tup[0][0][1]
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                t1 = (x, x)
+                t2 = (x, t1)
+                t3 = (t1, t2)
+                t4 = (t3, t3, t2)
+                gv = t4[0][0][1]
+                R.output(gv)
+            return gv
+
+    verify(Before, Expected)
+
+
+def test_inner_function():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                # we are going to do CSE inside the local function
+                @R.function
+                def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):

Review Comment:
   This will actually handle duplicate local functions. We could also do a module-level analysis to handle duplicate global functions.
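A minimal sketch of how duplicate functions can be detected structurally, assuming (hypothetically) that two functions count as duplicates when they are equal up to renaming of their parameters and locally bound variables. The toy encoding `(params, bindings, return_var)` and the helpers `alpha_key` and `group_duplicates` are illustrative only, not TVM APIs; `group_duplicates` shows what a module-level analysis over global functions might compute.

```python
from collections import defaultdict
from typing import Dict, List

# Hypothetical toy encoding: a function is (params, bindings, return_var), and
# bindings use ("call", op, *args) for calls and ("tuple", *fields) for tuples.


def rename(expr, names: Dict[str, str]):
    """Rename variables in expr; op names in call position are left untouched."""
    if isinstance(expr, str):
        return names.get(expr, expr)
    if expr[0] == "call":
        return ("call", expr[1], *(rename(a, names) for a in expr[2:]))
    return (expr[0], *(rename(e, names) for e in expr[1:]))


def alpha_key(params: List[str], bindings, ret: str):
    """Structural key that ignores the names of params and locally bound vars."""
    names: Dict[str, str] = {p: f"%p{i}" for i, p in enumerate(params)}
    body = []
    for i, (var, rhs) in enumerate(bindings):
        rhs = rename(rhs, names)  # uses are renamed before var is bound (ANF order)
        names[var] = f"%v{i}"
        body.append((names[var], rhs))
    return (len(params), tuple(body), rename(ret, names))


def group_duplicates(funcs: Dict[str, tuple]) -> Dict[str, str]:
    """Map each duplicate function name to one representative with the same key."""
    groups = defaultdict(list)
    for name in sorted(funcs):
        groups[alpha_key(*funcs[name])].append(name)
    return {dup: names[0] for names in groups.values() for dup in names[1:]}


# Dataflow parts of two functions that differ only in variable names, plus one that differs.
bar1 = (["y"], [("lv0", ("call", "add", "y", "y")), ("lv1", ("call", "add", "y", "y")),
                ("lv2", ("call", "add", "lv0", "lv1"))], "lv2")
bar2 = (["w"], [("a", ("call", "add", "w", "w")), ("b", ("call", "add", "w", "w")),
                ("c", ("call", "add", "a", "b"))], "c")
baz = (["y"], [("lv0", ("call", "multiply", "y", "y"))], "lv0")
```

Here `group_duplicates({"bar1": bar1, "bar2": bar2, "baz": baz})` maps `bar2` to `bar1` and leaves `baz` alone; a module-level pass could then redirect calls to `bar2` toward `bar1`.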





[GitHub] [tvm] spectrometerHBH commented on pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "spectrometerHBH (via GitHub)" <gi...@apache.org>.
spectrometerHBH commented on PR #14361:
URL: https://github.com/apache/tvm/pull/14361#issuecomment-1480130861

   
   Yeah, I agree. One issue that comes to mind is that deciding which subexpressions to eliminate might be critical. If a subexpression is lightweight and can be inlined into the surrounding ops, we should probably not eliminate the redundancy.
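One way to act on this, sketched with a hypothetical per-op cost table (the op names, costs, and the `min_cost` threshold are assumptions, not anything the PR implements): share a repeated call only when its estimated cost clears a threshold, and let cheap, fusable ops be recomputed.

```python
from typing import Dict, List, Tuple

# Hypothetical cost model; the numbers stand in for "lightweight and inlinable".
OP_COST: Dict[str, int] = {"add": 1, "multiply": 1, "matmul": 100, "conv2d": 200}


def selective_cse(block: List[Tuple[str, object]], min_cost: int = 10):
    """Deduplicate only ("call", op, *args) bindings whose op cost >= min_cost."""
    first_binding: Dict[object, str] = {}
    result = []
    for var, rhs in block:
        is_call = isinstance(rhs, tuple) and rhs[0] == "call"
        # Unknown ops default to cost 0, i.e. they are treated as cheap and left alone.
        worth_sharing = is_call and OP_COST.get(rhs[1], 0) >= min_cost
        if worth_sharing and rhs in first_binding:
            result.append((var, first_binding[rhs]))
            continue
        if worth_sharing:
            first_binding[rhs] = var
        result.append((var, rhs))
    return result


block = [
    ("a0", ("call", "add", "x", "y")),
    ("a1", ("call", "add", "x", "y")),     # cheap: recomputed, can fuse with neighbors
    ("m0", ("call", "matmul", "x", "w")),
    ("m1", ("call", "matmul", "x", "w")),  # expensive: shared as m1 = m0
]
out = selective_cse(block)
```

Defaulting unknown ops to "cheap" is the conservative choice for memory; defaulting them to "expensive" would favor eliminating more work instead.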
   
   In classical settings, the trade-off of CSE is that it extends the live ranges of some variables, which can cause performance regressions due to register spills. For DL workloads, the analogous risk is that the model simply cannot be deployed within a given amount of GPU memory.
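A toy calculation of the memory side of this trade-off (op names such as `expensive` and `blowup`, and all tensor sizes, are made up). The model sums the sizes of live variables at each binding and ignores buffer reuse and any real memory planner; it shows how sharing one result can remove a recomputation yet nearly double peak live memory.

```python
from typing import Dict, List, Tuple

# Toy model (hypothetical): bindings are ("call", op, *args); a var is live from the
# binding that defines it through its last use, or to the end if it is an output.


def peak_live_size(block: List[Tuple[str, tuple]], sizes: Dict[str, int],
                   outputs: List[str]) -> int:
    n = len(block)
    start: Dict[str, int] = {}
    end: Dict[str, int] = {}
    for i, (var, rhs) in enumerate(block):
        start[var] = i
        end[var] = i  # an unused result is dead right after it is produced
        for arg in rhs[2:]:
            if arg in start:  # function parameters such as "x" are not counted
                end[arg] = i
    for out in outputs:
        end[out] = n
    return max(sum(sizes[v] for v in start if start[v] <= i <= end[v])
               for i in range(n + 1))


sizes = {"t0": 100, "t1": 100, "u": 100, "s0": 1, "v": 1, "s1": 1}
outputs = ["s0", "v", "s1"]
before = [
    ("t0", ("call", "expensive", "x")),
    ("s0", ("call", "sum", "t0")),
    ("u", ("call", "blowup", "s0")),
    ("v", ("call", "sum", "u")),
    ("t1", ("call", "expensive", "x")),  # recomputes t0
    ("s1", ("call", "max", "t1")),
]
# After CSE plus alias folding (written out by hand): t1 is gone, so t0 must
# stay live until s1 and now overlaps with the large temporary u.
after = [
    ("t0", ("call", "expensive", "x")),
    ("s0", ("call", "sum", "t0")),
    ("u", ("call", "blowup", "s0")),
    ("v", ("call", "sum", "u")),
    ("s1", ("call", "max", "t0")),
]
peak_before = peak_live_size(before, sizes, outputs)  # 103
peak_after = peak_live_size(after, sizes, outputs)    # 202
```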




[GitHub] [tvm] slyubomirsky commented on pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "slyubomirsky (via GitHub)" <gi...@apache.org>.
slyubomirsky commented on PR #14361:
URL: https://github.com/apache/tvm/pull/14361#issuecomment-1480037566

   The ideal scenario is a model (especially an imported one) that repeatedly invokes the same expensive operation. We have already encountered that pattern in some imported models. CSE is a classic optimization.




[GitHub] [tvm] slyubomirsky commented on a diff in pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "slyubomirsky (via GitHub)" <gi...@apache.org>.
slyubomirsky commented on code in PR #14361:
URL: https://github.com/apache/tvm/pull/14361#discussion_r1144061776


##########
tests/python/relax/test_transform_cse.py:
##########
@@ -0,0 +1,186 @@
+def test_simple():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv0 = R.add(x, y)
+                lv1 = R.add(x, y)
+                gv = R.multiply(lv0, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv0 = R.add(x, y)
+                # can combine with canonicalizing bindings
+                # and getting rid of unused bindings to eliminate this line too
+                lv1 = lv0
+                gv = R.multiply(lv0, lv1)

Review Comment:
   I didn't want to reimplement the functionality of the canonicalize bindings pass here. We could fold it into this pass if we wanted to, but I don't think that's a good idea.
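The cleanup left to other passes can be sketched in a hypothetical toy encoding (variables are strings, calls are `("call", op, *args)`). `canonicalize_bindings` and `remove_unused` are illustrative stand-ins for binding canonicalization and unused-binding removal, not the actual TVM passes; dataflow bindings are assumed side-effect free.

```python
from typing import Dict, Set


def subst(rhs, alias: Dict[str, str]):
    """Replace variable references in rhs according to alias."""
    if isinstance(rhs, str):
        return alias.get(rhs, rhs)
    if rhs[0] == "call":
        return ("call", rhs[1], *(subst(a, alias) for a in rhs[2:]))
    return (rhs[0], *(subst(e, alias) for e in rhs[1:]))


def uses(rhs) -> Set[str]:
    """Variables referenced by rhs (the op name of a call is not a use)."""
    if isinstance(rhs, str):
        return {rhs}
    args = rhs[2:] if rhs[0] == "call" else rhs[1:]
    return set().union(*(uses(a) for a in args))


def canonicalize_bindings(block):
    """Point every use of a var-to-var alias (e.g. lv1 = lv0) at the original var."""
    alias: Dict[str, str] = {}
    result = []
    for var, rhs in block:
        rhs = subst(rhs, alias)
        if isinstance(rhs, str):
            alias[var] = rhs
        result.append((var, rhs))
    return result


def remove_unused(block, outputs):
    """Drop bindings that no output depends on, scanning backwards."""
    live = set(outputs)
    kept = []
    for var, rhs in reversed(block):
        if var in live:
            kept.append((var, rhs))
            live |= uses(rhs)
    return kept[::-1]


# The Expected block of test_simple after CSE.
cse_output = [
    ("lv0", ("call", "add", "x", "y")),
    ("lv1", "lv0"),
    ("gv", ("call", "multiply", "lv0", "lv1")),
]
cleaned = remove_unused(canonicalize_bindings(cse_output), outputs=["gv"])
```

The result is `lv0 = add(x, y)` followed by `gv = multiply(lv0, lv0)`. Keeping the steps as separate passes means each stays small and can be composed in whatever order a pipeline needs.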





[GitHub] [tvm] yongwww commented on a diff in pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "yongwww (via GitHub)" <gi...@apache.org>.
yongwww commented on code in PR #14361:
URL: https://github.com/apache/tvm/pull/14361#discussion_r1146979043


##########
tests/python/relax/test_transform_cse.py:
##########
@@ -0,0 +1,186 @@
+def test_inner_function():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                # we are going to do CSE inside the local function
+                @R.function
+                def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):

Review Comment:
   Wow, looks great! Module-level analysis for global functions would be helpful in my case; it could be a TODO.





[GitHub] [tvm] slyubomirsky commented on a diff in pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "slyubomirsky (via GitHub)" <gi...@apache.org>.
slyubomirsky commented on code in PR #14361:
URL: https://github.com/apache/tvm/pull/14361#discussion_r1144062334


##########
tests/python/relax/test_transform_cse.py:
##########
@@ -0,0 +1,186 @@
+def test_inner_function():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                # we are going to do CSE inside the local function
+                @R.function
+                def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+                    # not in dataflow: should not be touched
+                    z = R.add(R.add(y, y), R.add(y, y))
+                    with R.dataflow():
+                        # writing this out in ANF to illustrate why CSE behaves as it does
+                        # result of ANF transforming R.add(R.add(y, y), R.add(y, y))
+                        lv0 = R.add(y, y)
+                        lv1 = R.add(y, y)
+                        lv2 = R.add(lv0, lv1)
+                        gv = lv2
+                        R.output(gv)
+                    return R.add(z, gv)
+
+                # also making the ANF explicit to better illustrate the result of CSE
+                # result of ANF transforming R.add(R.add(bar(x), bar(x)), R.add(bar(x), bar(x)))
+                lv0 = bar(x)
+                lv1 = bar(x)
+                lv2 = R.add(lv0, lv1)
+                lv3 = bar(x)
+                lv4 = bar(x)
+                lv5 = R.add(lv3, lv4)
+                lv6 = R.add(lv2, lv5)
+                gv = lv6
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+
+                @R.function
+                def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+                    z = R.add(R.add(y, y), R.add(y, y))
+                    with R.dataflow():
+                        lv0 = R.add(y, y)
+                        lv1 = lv0
+                        lv2 = R.add(lv0, lv1)
+                        gv = lv2
+                        R.output(gv)
+                    return R.add(z, gv)
+
+                # can further clean this up
+                # using canonicalize bindings, eliminate unused bindings, and CSE again
+                lv0 = bar(x)
+                lv1 = lv0
+                lv2 = R.add(lv0, lv1)
+                lv3 = lv0
+                lv4 = lv0
+                lv5 = R.add(lv3, lv4)
+                lv6 = R.add(lv2, lv5)
+                gv = lv6
+                R.output(gv)

Review Comment:
   This is not ideal because repeated additions remain. @psrivas2 once proposed a canonicalization pass that runs until it reaches a fixpoint, and this is an example where running such a pass would help.
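A sketch of the fixpoint idea on the outer dataflow block of `test_inner_function`, using a hypothetical toy encoding and stand-in helpers (`cse`, `canonicalize`, `remove_unused`) rather than TVM's passes, and treating the local function `bar` as a pure op. One round leaves `lv2` and `lv5` as identical additions, because `R.add(lv0, lv1)` and `R.add(lv3, lv4)` only become equal once the aliases are folded; a second round merges them, and a third confirms nothing changes.

```python
from typing import Dict, List, Set, Tuple

# Hypothetical toy encoding: vars are str, calls are ("call", op, *var_args).
Binding = Tuple[str, object]


def subst(rhs, alias: Dict[str, str]):
    if isinstance(rhs, str):
        return alias.get(rhs, rhs)
    return ("call", rhs[1], *(subst(a, alias) for a in rhs[2:]))


def uses(rhs) -> Set[str]:
    return {rhs} if isinstance(rhs, str) else set(rhs[2:])


def cse(block: List[Binding]) -> List[Binding]:
    """RHS-only CSE: a repeated call is replaced by the first var bound to it."""
    seen: Dict[object, str] = {}
    out = []
    for var, rhs in block:
        if isinstance(rhs, tuple):
            if rhs in seen:
                rhs = seen[rhs]
            else:
                seen[rhs] = var
        out.append((var, rhs))
    return out


def canonicalize(block: List[Binding]) -> List[Binding]:
    """Fold var-to-var aliases into later uses."""
    alias: Dict[str, str] = {}
    out = []
    for var, rhs in block:
        rhs = subst(rhs, alias)
        if isinstance(rhs, str):
            alias[var] = rhs
        out.append((var, rhs))
    return out


def remove_unused(block: List[Binding], outputs) -> List[Binding]:
    live, kept = set(outputs), []
    for var, rhs in reversed(block):
        if var in live:
            kept.append((var, rhs))
            live |= uses(rhs)
    return kept[::-1]


def simplify_to_fixpoint(block, outputs, max_rounds=10):
    """Repeat CSE -> canonicalize -> remove unused until nothing changes."""
    for _ in range(max_rounds):
        new_block = remove_unused(canonicalize(cse(block)), outputs)
        if new_block == block:
            break
        block = new_block
    return block


block = [
    ("lv0", ("call", "bar", "x")),
    ("lv1", ("call", "bar", "x")),
    ("lv2", ("call", "add", "lv0", "lv1")),
    ("lv3", ("call", "bar", "x")),
    ("lv4", ("call", "bar", "x")),
    ("lv5", ("call", "add", "lv3", "lv4")),
    ("lv6", ("call", "add", "lv2", "lv5")),
    ("gv", "lv6"),
]
one_round = remove_unused(canonicalize(cse(block)), ["gv"])
result = simplify_to_fixpoint(block, outputs=["gv"])
```

At the fixpoint the block is `lv0 = bar(x)`, `lv2 = add(lv0, lv0)`, `lv6 = add(lv2, lv2)`, `gv = lv6`, so `bar` is called once and no addition is repeated.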





[GitHub] [tvm] yongwww commented on a diff in pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "yongwww (via GitHub)" <gi...@apache.org>.
yongwww commented on code in PR #14361:
URL: https://github.com/apache/tvm/pull/14361#discussion_r1146365497


##########
tests/python/relax/test_transform_cse.py:
##########
@@ -0,0 +1,186 @@
+def test_inner_function():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                # we are going to do CSE inside the local function
+                @R.function
+                def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):

Review Comment:
   Not sure whether we should handle this case: duplicate global/local functions. I ran into it when lowering JAX to HLO.
   
                   
                   @R.function
                   def bar1(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
                       z = R.add(R.add(y, y), R.add(y, y))
                       with R.dataflow():
                           lv0 = R.add(y, y)
                           lv1 = R.add(y, y)
                           lv2 = R.add(lv0, lv1)
                           gv = lv2
                           R.output(gv)
                       return R.add(z, gv)
                   
                   @R.function
                   def bar2(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
                       z = R.add(R.add(y, y), R.add(y, y))
                       with R.dataflow():
                           lv0 = R.add(y, y)
                           lv1 = R.add(y, y)
                           lv2 = R.add(lv0, lv1)
                           gv = lv2
                           R.output(gv)
                       return R.add(z, gv)
   
                   lv0 = bar1(x)
                   lv1 = bar2(x)
                   lv2 = R.add(lv0, lv1)
                   





[GitHub] [tvm] slyubomirsky commented on a diff in pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "slyubomirsky (via GitHub)" <gi...@apache.org>.
slyubomirsky commented on code in PR #14361:
URL: https://github.com/apache/tvm/pull/14361#discussion_r1147092182


##########
tests/python/relax/test_transform_cse.py:
##########
@@ -0,0 +1,186 @@
+def test_constants():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")):
+            with R.dataflow():
+                # we are not going to bind the constant 1 to a var
+                lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32"))
+                # we expect to bind the repeated large constants
+                lv1 = R.add(
+                    R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
+                    R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
+                )
+                gv = (lv0, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")):
+            with R.dataflow():
+                lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32"))
+                lv1 = R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32")))
+                lv2 = R.add(lv1, lv1)
+                gv = (lv0, lv2)
+                R.output(gv)
+            return gv
+
+    verify(Before, Expected)
+
+
+def test_repeated_inner_tuples():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                # repeated units: (x, x), (x, (x, x)), ((x, x), (x, (x, x)))
+                tup = (((x, x), (x, (x, x))), ((x, x), (x, (x, x))), (x, (x, x)))
+                gv = tup[0][0][1]
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                t1 = (x, x)
+                t2 = (x, t1)
+                t3 = (t1, t2)
+                t4 = (t3, t3, t2)
+                gv = t4[0][0][1]
+                R.output(gv)
+            return gv
+
+    verify(Before, Expected)
+
+
+def test_inner_function():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                # we are going to do CSE inside the local function
+                @R.function
+                def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):

Review Comment:
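   To be clear, it would only handle duplicate local functions if their bindings happen inside a `DataflowBlock` :zany_face:
   
   We might want to consider extending this pass to handle non-dataflow blocks as well. Purity tracking would help with that, since bindings outside a dataflow block may have side effects and cannot be merged blindly.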
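
To illustrate why purity matters once the pass leaves dataflow blocks, here is a toy sketch in plain Python (not Relax). Bindings are flat `(var, op, args)` triples, as in ANF, and the hypothetical `impure_ops` set stands in for real purity tracking. Pure calls with identical operands become an alias of the first binding (the same `lv1 = lv0` shape as in `test_simple`), while impure calls are always kept because each one's side effect must still happen.

```python
from typing import Dict, FrozenSet, List, Tuple, Union

# Toy ANF binding: (var, op, args). Not a Relax data structure.
Binding = Tuple[str, str, Tuple[str, ...]]
# Output: (var, (op, args)) for a kept call, or (var, earlier_var) for an alias.
Result = Tuple[str, Union[str, Tuple[str, Tuple[str, ...]]]]


def cse_flat(
    bindings: List[Binding],
    impure_ops: FrozenSet[str] = frozenset({"print"}),  # hypothetical purity table
) -> List[Result]:
    first_binding: Dict[Tuple[str, Tuple[str, ...]], str] = {}
    out: List[Result] = []
    for var, op, args in bindings:
        key = (op, args)
        if op in impure_ops:
            # Never merge impure calls: each call's side effect must still happen.
            out.append((var, key))
        elif key in first_binding:
            # Pure and already computed: rebind to the earlier variable.
            out.append((var, first_binding[key]))
        else:
            first_binding[key] = var
            out.append((var, key))
    return out


example = [
    ("a", "add", ("x", "y")),
    ("b", "add", ("x", "y")),
    ("p", "print", ("x",)),
    ("q", "print", ("x",)),
]
for var, rhs in cse_flat(example):
    print(var, "=", rhs)
```

Here `b` becomes an alias of `a`, but both `print` calls survive.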



[GitHub] [tvm] tvm-bot commented on pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "tvm-bot (via GitHub)" <gi...@apache.org>.
tvm-bot commented on PR #14361:
URL: https://github.com/apache/tvm/pull/14361#issuecomment-1478690192

   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.
   
    * cc @quic-sanirudh (see [#10317](https://github.com/apache/tvm/issues/10317) for details)
   
   Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)


[GitHub] [tvm] slyubomirsky commented on a diff in pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "slyubomirsky (via GitHub)" <gi...@apache.org>.
slyubomirsky commented on code in PR #14361:
URL: https://github.com/apache/tvm/pull/14361#discussion_r1144060178


##########
src/relax/transform/eliminate_common_subexpr.cc:
##########
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ *
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/eliminate_common_subexpr.cc
+ * \brief Eliminate common subexpression pass.
+ *
+ * Currently it removes common subexpressions within a DataflowBlock.
+ */
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+
+class SubexprCounter : public ExprVisitor {
+ public:
+  // overriding VisitExpr ensures we do this for every subexpression
+  void VisitExpr(const Expr& e) override {
+    // Cases we ignore because we will not substitute them:
+    // 1. Vars of all kinds
+    // 2. Op nodes (nothing we can do)
+    // 3. Scalar constants (not much benefit from binding to a var)
+    if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
+          e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
+          (e.as<ConstantNode>() && (e.as<ConstantNode>()->is_scalar())))) {
+      int count = 0;
+      if (count_map_.count(e)) {
+        count = count_map_.at(e);
+      }
+      count_map_[e] = count + 1;
+    }
+    ExprVisitor::VisitExpr(e);
+  }
+
+  // do not visit inner functions: we will do CSE within those
+  void VisitExpr_(const FunctionNode* func) override {}
+
+  // we are not going to do replacements inside struct info to avoid binding lots of reused shapes
+  void VisitExprDepStructInfoField(const StructInfo& struct_info) override {}
+
+  std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(
+      const DataflowBlock& df_block) {
+    for (auto binding : df_block->bindings) {
+      VisitBinding(binding);
+    }
+    return count_map_;
+  }
+
+ private:
+  std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
+};
+
+// forward declaration
+DataflowBlock EliminateCommonSubexpr(const DataflowBlock&);
+
+class CommonSubexprEliminator : public ExprMutator {
+ public:
+  explicit CommonSubexprEliminator(
+      const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map)
+      : count_map_(count_map) {}
+
+  // overriding here ensures we visit every subexpression
+  Expr VisitExpr(const Expr& e) override {
+    if (count_map_.count(e) && count_map_.at(e) > 1) {
+      // if we already have a mapping for it, get it
+      if (replacements_.count(e)) {
+        return replacements_.at(e);
+      }
+      // Otherwise, insert a new binding for the current expression.
+      // Visit before emitting to do inner replacements
+      Expr new_e = ExprMutator::VisitExpr(e);
+      Var v = builder_->Emit(new_e);
+      replacements_[e] = v;
+      return v;
+    }
+    return ExprMutator::VisitExpr(e);
+  }
+
+  // we are not going to do replacements inside struct info to avoid binding lots of reused shapes
+  StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override {
+    return struct_info;
+  }
+
+  Expr VisitExpr_(const FunctionNode* func) override {
+    // for an inner function, we will do CSE on its body
+    Expr new_body = ExprMutator::VisitExpr(func->body);
+    if (new_body.same_as(func->body)) {
+      return GetRef<Expr>(func);
+    }
+    return Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span);
+  }
+
+  // this should happen only for the inner function case
+  Expr VisitExpr_(const SeqExprNode* seq) override {
+    bool all_unchanged = true;
+    Array<BindingBlock> new_blocks;
+    // apply CSE within dataflow blocks only
+    for (auto block : seq->blocks) {
+      if (const DataflowBlockNode* df_block = block.as<DataflowBlockNode>()) {
+        auto new_df_block = EliminateCommonSubexpr(GetRef<DataflowBlock>(df_block));
+        if (!new_df_block.same_as(block)) {
+          new_blocks.push_back(new_df_block);
+          all_unchanged = false;
+          continue;
+        }
+      }
+      new_blocks.push_back(block);
+    }
+
+    if (all_unchanged) {
+      return GetRef<Expr>(seq);
+    }
+    // do not visit the body
+    return SeqExpr(new_blocks, seq->body, seq->span);
+  }

Review Comment:
   This seemed like a strange thing to have to implement in a dataflow block pass. It could be avoided if we required lambda lifting to run first (arguably, we should require that for all dataflow block passes).



[GitHub] [tvm] slyubomirsky commented on a diff in pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "slyubomirsky (via GitHub)" <gi...@apache.org>.
slyubomirsky commented on code in PR #14361:
URL: https://github.com/apache/tvm/pull/14361#discussion_r1144059745


##########
src/relax/transform/eliminate_common_subexpr.cc:
##########
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ *
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/eliminate_common_subexpr.cc
+ * \brief Eliminate common subexpression pass.
+ *
+ * Currently it removes common subexpressions within a DataflowBlock.
+ */
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+
+class SubexprCounter : public ExprVisitor {
+ public:
+  // overriding VisitExpr ensures we do this for every subexpression
+  void VisitExpr(const Expr& e) override {
+    // Cases we ignore because we will not substitute them:
+    // 1. Vars of all kinds
+    // 2. Op nodes (nothing we can do)
+    // 3. Scalar constants (not much benefit from binding to a var)
+    if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
+          e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
+          (e.as<ConstantNode>() && (e.as<ConstantNode>()->is_scalar())))) {

Review Comment:
   It's up to us whether rebinding scalar constants is worth it. I doubt it, though.
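
The count-then-rebind approach of `SubexprCounter` and `CommonSubexprEliminator` can be modeled in a few lines of plain Python. This is a toy sketch, not TVM code: frozen dataclasses and tuples stand in for Relax expressions, Python's built-in hashing and equality stand in for `StructuralHash`/`StructuralEqual`, and the `Call`/`Const` classes and the `t1`, `t2`, ... names are made up for illustration. As in the C++ pass, variables and scalar constants (plain `int`s here) are never counted and so never rebound, while repeated non-scalar constants are.

```python
import itertools
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union


@dataclass(frozen=True)
class Call:
    op: str
    args: tuple


@dataclass(frozen=True)
class Const:
    data: tuple  # a non-scalar constant, e.g. a flattened tensor


# str = variable, int = scalar constant, tuple = Relax-style tuple
Expr = Union[str, int, tuple, Call, Const]


def is_skipped(e: Expr) -> bool:
    # Like SubexprCounter: variables and scalar constants are never rebound.
    return isinstance(e, (str, int))


def children(e: Expr) -> tuple:
    if isinstance(e, Call):
        return e.args
    return e if isinstance(e, tuple) else ()


def rebuild(e: Expr, new_children: List[Expr]) -> Expr:
    if isinstance(e, Call):
        return Call(e.op, tuple(new_children))
    return tuple(new_children) if isinstance(e, tuple) else e


def eliminate(bindings: List[Tuple[str, Expr]]) -> List[Tuple[str, Expr]]:
    # Phase 1: count every occurrence of every subexpression in the block.
    counts: Dict[Expr, int] = {}

    def count(e: Expr) -> None:
        if not is_skipped(e):
            counts[e] = counts.get(e, 0) + 1
        for c in children(e):
            count(c)

    for _, rhs in bindings:
        count(rhs)

    # Phase 2: bind each repeated subexpression once, innermost first.
    replacements: Dict[Expr, str] = {}
    out: List[Tuple[str, Expr]] = []
    fresh = itertools.count(1)

    def visit(e: Expr) -> Expr:
        if is_skipped(e):
            return e
        if e in replacements:
            return replacements[e]
        new_e = rebuild(e, [visit(c) for c in children(e)])
        if counts[e] > 1:
            name = f"t{next(fresh)}"
            out.append((name, new_e))  # emitted before the binding that uses it
            replacements[e] = name
            return name
        return new_e

    for var, rhs in bindings:
        if not is_skipped(rhs) and counts[rhs] > 1:
            if rhs in replacements:
                out.append((var, replacements[rhs]))  # e.g. lv1 = lv0
                continue
            new_rhs = rebuild(rhs, [visit(c) for c in children(rhs)])
            replacements[rhs] = var  # reuse the binding's own variable
            out.append((var, new_rhs))
        else:
            out.append((var, visit(rhs)))
    return out


tup = ((("x", "x"), ("x", ("x", "x"))), (("x", "x"), ("x", ("x", "x"))), ("x", ("x", "x")))
for var, rhs in eliminate([("tup", tup)]):
    print(var, "=", rhs)
```

On the nested tuple from `test_repeated_inner_tuples`, this emits `t1 = (x, x)`, `t2 = (x, t1)`, `t3 = (t1, t2)` and `tup = (t3, t3, t2)`, which matches the `Expected` module up to variable names.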



[GitHub] [tvm] slyubomirsky commented on a diff in pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "slyubomirsky (via GitHub)" <gi...@apache.org>.
slyubomirsky commented on code in PR #14361:
URL: https://github.com/apache/tvm/pull/14361#discussion_r1146973483


##########
tests/python/relax/test_transform_cse.py:
##########
@@ -0,0 +1,186 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test eliminate common subexpr pass"""
+import tvm
+import tvm.testing
+from tvm.relax.transform import EliminateCommonSubexpr
+from tvm.script.parser import ir as I, relax as R, tir as T
+
+import numpy as np
+
+
+def verify(input, expected):
+    tvm.ir.assert_structural_equal(EliminateCommonSubexpr()(input), expected)
+
+
+def test_simple():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv0 = R.add(x, y)
+                lv1 = R.add(x, y)
+                gv = R.multiply(lv0, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv0 = R.add(x, y)
+                # can combine with canonicalizing bindings
+                # and getting rid of unused bindings to eliminate this line too
+                lv1 = lv0
+                gv = R.multiply(lv0, lv1)
+                R.output(gv)
+            return gv
+
+    verify(Before, Expected)
+
+
+def test_constants():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")):
+            with R.dataflow():
+                # we are not going to bind the constant 1 to a var
+                lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32"))
+                # we expect to bind the repeated large constants
+                lv1 = R.add(
+                    R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
+                    R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
+                )
+                gv = (lv0, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), dtype="int32")):
+            with R.dataflow():
+                lv0 = R.add(R.const(1, dtype="int32"), R.const(1, dtype="int32"))
+                lv1 = R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32")))
+                lv2 = R.add(lv1, lv1)
+                gv = (lv0, lv2)
+                R.output(gv)
+            return gv
+
+    verify(Before, Expected)
+
+
+def test_repeated_inner_tuples():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                # repeated units: (x, x), (x, (x, x)), ((x, x), (x, (x, x)))
+                tup = (((x, x), (x, (x, x))), ((x, x), (x, (x, x))), (x, (x, x)))
+                gv = tup[0][0][1]
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                t1 = (x, x)
+                t2 = (x, t1)
+                t3 = (t1, t2)
+                t4 = (t3, t3, t2)
+                gv = t4[0][0][1]
+                R.output(gv)
+            return gv
+
+    verify(Before, Expected)
+
+
+def test_inner_function():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                # we are going to do CSE inside the local function
+                @R.function
+                def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):

Review Comment:
   This will actually handle duplicate local functions. We could also do a module-level analysis for global functions.
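
A module-level analysis could, for example, group global functions that are equal up to parameter renaming. Below is a minimal plain-Python sketch under that assumption. It is not part of this PR, the `canonicalize` and `find_duplicate_functions` names are hypothetical, and a real version would presumably rely on TVM's structural equality rather than this hand-rolled positional renaming.

```python
from typing import Dict, Tuple

# Toy function: (params, body); body is a nested tuple (op, arg, ...) or a name.
ToyFunction = Tuple[Tuple[str, ...], object]


def canonicalize(params: Tuple[str, ...], body: object) -> object:
    # Rename parameters positionally so f(a) = a + a and g(b) = b + b match.
    rename = {p: f"%{i}" for i, p in enumerate(params)}

    def walk(e: object) -> object:
        if isinstance(e, str):
            return rename.get(e, e)
        if isinstance(e, tuple):
            op, *args = e
            return (op, *[walk(a) for a in args])
        return e  # constants and other leaves

    return (len(params), walk(body))


def find_duplicate_functions(module: Dict[str, ToyFunction]) -> Dict[str, str]:
    # Map each duplicate function name to the first equivalent function seen.
    first_by_key: Dict[object, str] = {}
    duplicates: Dict[str, str] = {}
    for name, (params, body) in module.items():
        key = canonicalize(params, body)
        if key in first_by_key:
            duplicates[name] = first_by_key[key]
        else:
            first_by_key[key] = name
    return duplicates


module = {
    "f": (("a",), ("add", "a", "a")),
    "g": (("b",), ("add", "b", "b")),
    "h": (("c",), ("multiply", "c", "c")),
}
print(find_duplicate_functions(module))
```

Call sites of `g` could then be redirected to `f`.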



[GitHub] [tvm] spectrometerHBH commented on pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "spectrometerHBH (via GitHub)" <gi...@apache.org>.
spectrometerHBH commented on PR #14361:
URL: https://github.com/apache/tvm/pull/14361#issuecomment-1479748491

   Nice work!
   
   In what scenarios do we want to apply this pass?


[GitHub] [tvm] slyubomirsky commented on pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "slyubomirsky (via GitHub)" <gi...@apache.org>.
slyubomirsky commented on PR #14361:
URL: https://github.com/apache/tvm/pull/14361#issuecomment-1480328892

   In this case, the new bindings are live only within a single dataflow block, so I don't expect many issues from keeping values live much longer than they otherwise would be.
   
   It would be easy to add a heuristic for deciding when we shouldn't deduplicate.
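
One possible heuristic, sketched in plain Python as an assumption rather than anything this PR implements: only deduplicate a repeated right-hand side when its first and last occurrences are at most `max_live_range` bindings apart, which bounds how much longer the shared value stays live. The binding format and the `choose_dedup` name are hypothetical, and for simplicity only top-level right-hand sides are considered.

```python
from collections import defaultdict
from typing import Dict, List, Set, Tuple

Binding = Tuple[str, tuple]  # (var, rhs) with rhs = (op, *args)


def choose_dedup(bindings: List[Binding], max_live_range: int) -> Set[tuple]:
    # Record the binding index of every occurrence of each right-hand side.
    positions: Dict[tuple, List[int]] = defaultdict(list)
    for i, (_, rhs) in enumerate(bindings):
        positions[rhs].append(i)
    # Keep repeated expressions whose occurrences are close together, so
    # reusing the first result does not extend its live range too far.
    return {
        rhs
        for rhs, idx in positions.items()
        if len(idx) > 1 and idx[-1] - idx[0] <= max_live_range
    }


block = [
    ("a", ("add", "x", "y")),
    ("b", ("multiply", "x", "x")),
    ("c", ("add", "x", "y")),
    ("d", ("negative", "x")),
    ("e", ("subtract", "a", "c")),
    ("f", ("multiply", "x", "x")),
]
print(choose_dedup(block, max_live_range=2))
```

With `max_live_range=2`, the two `add` bindings (two apart) are merged, while the two `multiply` bindings (four apart) are left alone.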


[GitHub] [tvm] slyubomirsky commented on a diff in pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "slyubomirsky (via GitHub)" <gi...@apache.org>.
slyubomirsky commented on code in PR #14361:
URL: https://github.com/apache/tvm/pull/14361#discussion_r1144153247


##########
src/relax/transform/eliminate_common_subexpr.cc:
##########
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ *
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/eliminate_common_subexpr.cc
+ * \brief Eliminate common subexpression pass.
+ *
+ * Currently it removes common subexpressions within a DataflowBlock.
+ */
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+
+class SubexprCounter : public ExprVisitor {
+ public:
+  // overriding VisitExpr ensures we do this for every subexpression
+  void VisitExpr(const Expr& e) override {
+    // Cases we ignore because we will not substitute them:
+    // 1. Vars of all kinds
+    // 2. Op nodes (nothing we can do)
+    // 3. Scalar constants (not much benefit from binding to a var)
+    if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
+          e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
+          (e.as<ConstantNode>() && (e.as<ConstantNode>()->is_scalar())))) {
+      int count = 0;
+      if (count_map_.count(e)) {
+        count = count_map_.at(e);
+      }
+      count_map_[e] = count + 1;
+    }
+    ExprVisitor::VisitExpr(e);
+  }
+
+  // do not visit inner functions: we will do CSE within those
+  void VisitExpr_(const FunctionNode* func) override {}
+
+  // we are not going to do replacements inside struct info to avoid binding lots of reused shapes
+  void VisitExprDepStructInfoField(const StructInfo& struct_info) override {}
+
+  std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(
+      const DataflowBlock& df_block) {
+    for (auto binding : df_block->bindings) {
+      VisitBinding(binding);
+    }
+    return count_map_;
+  }
+
+ private:
+  std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
+};
+
+// forward declaration
+DataflowBlock EliminateCommonSubexpr(const DataflowBlock&);
+
+class CommonSubexprEliminator : public ExprMutator {
+ public:
+  explicit CommonSubexprEliminator(
+      const std::unordered_map<Expr, int, StructuralHash, StructuralEqual>& count_map)
+      : count_map_(count_map) {}
+
+  // overriding here ensures we visit every subexpression
+  Expr VisitExpr(const Expr& e) override {
+    if (count_map_.count(e) && count_map_.at(e) > 1) {
+      // if we already have a mapping for it, get it
+      if (replacements_.count(e)) {
+        return replacements_.at(e);
+      }
+      // Otherwise, insert a new binding for the current expression.
+      // Visit before emitting to do inner replacements
+      Expr new_e = ExprMutator::VisitExpr(e);
+      Var v = builder_->Emit(new_e);
+      replacements_[e] = v;
+      return v;
+    }
+    return ExprMutator::VisitExpr(e);
+  }
+
+  // we are not going to do replacements inside struct info to avoid binding lots of reused shapes
+  StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override {
+    return struct_info;
+  }
+
+  Expr VisitExpr_(const FunctionNode* func) override {
+    // for an inner function, we will do CSE on its body
+    Expr new_body = ExprMutator::VisitExpr(func->body);
+    if (new_body.same_as(func->body)) {
+      return GetRef<Expr>(func);
+    }
+    return Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span);
+  }
+
+  // this should happen only for the inner function case
+  Expr VisitExpr_(const SeqExprNode* seq) override {
+    bool all_unchanged = true;
+    Array<BindingBlock> new_blocks;
+    // apply CSE within dataflow blocks only
+    for (auto block : seq->blocks) {
+      if (const DataflowBlockNode* df_block = block.as<DataflowBlockNode>()) {
+        auto new_df_block = EliminateCommonSubexpr(GetRef<DataflowBlock>(df_block));
+        if (!new_df_block.same_as(block)) {
+          new_blocks.push_back(new_df_block);
+          all_unchanged = false;
+          continue;
+        }
+      }
+      new_blocks.push_back(block);
+    }
+
+    if (all_unchanged) {
+      return GetRef<Expr>(seq);
+    }
+    // do not visit the body
+    return SeqExpr(new_blocks, seq->body, seq->span);
+  }

Review Comment:
   Update: Based on the Unity Community Meeting discussion, there doesn't seem to be much appetite for imposing phase orderings like this, so I would instead be interested in whether there is a clean way to deal with local functions in dataflow block passes (for example, by generalizing the approach shown here). I am sure that other dataflow block passes don't handle the local function case and might exhibit strange bugs if given a program with local functions.
   
   Edit: Possible solution: change `DataflowBlockMutator` to look for inner functions and process them separately. That way, the `pass_func` for dataflow block passes can safely ignore inner functions, and they would be handled by the pass infrastructure instead.
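
The suggested fix can be sketched with a toy IR in plain Python: a driver walks each function's blocks, recurses into local functions bound inside them first, and only then hands each dataflow block to `pass_func`, so `pass_func` never needs to look inside inner functions. All classes and names below (`BindingBlock`, `DataflowBlock`, `Function`, `apply_dataflow_block_pass`, `dedupe_block`) are hypothetical stand-ins, not TVM's `DataflowBlockMutator` API.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class BindingBlock:
    bindings: List[Tuple[str, object]]  # (var, value); value may be a Function


@dataclass
class DataflowBlock(BindingBlock):
    pass


@dataclass
class Function:
    params: Tuple[str, ...]
    blocks: List[BindingBlock]
    body: str


def apply_dataflow_block_pass(
    func: Function, pass_func: Callable[[DataflowBlock], DataflowBlock]
) -> Function:
    new_blocks: List[BindingBlock] = []
    for block in func.blocks:
        # Handle local functions first, so pass_func can treat them as opaque.
        bindings = [
            (var, apply_dataflow_block_pass(val, pass_func))
            if isinstance(val, Function)
            else (var, val)
            for var, val in block.bindings
        ]
        new_block = type(block)(bindings)
        if isinstance(new_block, DataflowBlock):
            new_block = pass_func(new_block)
        new_blocks.append(new_block)
    return Function(func.params, new_blocks, func.body)


def dedupe_block(block: DataflowBlock) -> DataflowBlock:
    # A tiny block-local CSE that simply skips local functions.
    seen = {}
    out = []
    for var, val in block.bindings:
        if isinstance(val, Function):
            out.append((var, val))
        elif val in seen:
            out.append((var, seen[val]))
        else:
            seen[val] = var
            out.append((var, val))
    return DataflowBlock(out)


bar = Function(("y",), [DataflowBlock([("a", ("add", "y", "y")), ("b", ("add", "y", "y"))])], "b")
foo = Function(
    ("x",),
    [DataflowBlock([("bar", bar), ("lv0", ("call", "bar", "x")), ("lv1", ("call", "bar", "x"))])],
    "lv1",
)
result = apply_dataflow_block_pass(foo, dedupe_block)
```

Both the outer block (`lv1` becomes `lv0`) and the block inside `bar` (`b` becomes `a`) are deduplicated, even though `dedupe_block` itself never descends into `bar`.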



[GitHub] [tvm] slyubomirsky merged pull request #14361: [Unity][Transform] Common Subexpression Elimination

Posted by "slyubomirsky (via GitHub)" <gi...@apache.org>.
slyubomirsky merged PR #14361:
URL: https://github.com/apache/tvm/pull/14361

