You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/02 00:46:41 UTC

[tvm] branch main updated: [UnitTest][TIR] Testing utility for before/after transform tests (#12264)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bca0385862 [UnitTest][TIR] Testing utility for before/after transform tests (#12264)
bca0385862 is described below

commit bca0385862d93918fde877ed512c834fc32d80d5
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon Aug 1 19:46:35 2022 -0500

    [UnitTest][TIR] Testing utility for before/after transform tests (#12264)
    
    This PR adds `tvm.testing.CompareBeforeAfter`, a generalization of the `BaseBeforeAfter` utility previously used in `test_tir_transform_simplify.py`, for writing unit tests that apply a transformation to a TIR function and compare the result to an expected TIR output.  This arose while minimizing the boilerplate required for unit tests in the implementation of https://github.com/apache/tvm/issues/12261.
---
 python/tvm/testing/utils.py                        | 206 +++++++++++++++++++++
 .../python/unittest/test_tir_transform_simplify.py |  38 +---
 .../unittest/test_tvm_testing_before_after.py      |  83 +++++++++
 3 files changed, 291 insertions(+), 36 deletions(-)

diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index e3148a26c2..5b7a600c78 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -74,6 +74,7 @@ import os
 import pickle
 import platform
 import sys
+import textwrap
 import time
 import shutil
 
@@ -1712,3 +1713,208 @@ def fetch_model_from_url(
 def main():
     test_file = inspect.getsourcefile(sys._getframe(1))
     sys.exit(pytest.main([test_file] + sys.argv[1:]))
+
+
+class CompareBeforeAfter:
+    """Utility for comparing before/after of TIR transforms
+
+    A standard framework for writing tests that take a TIR PrimFunc as
+    input, apply a transformation, then either compare against an
+    expected output or assert that the transformation raised an error.
+    A test should subclass CompareBeforeAfter, defining class members
+    `before`, `transform`, and `expected`.  CompareBeforeAfter will
+    then use these members to define a test method and test fixture.
+
+    `transform` may be one of the following.
+
+    - An instance of `tvm.ir.transform.Pass`
+
+    - A method that takes only ``self`` and returns a `tvm.ir.transform.Pass`
+
+    - A pytest fixture that returns a `tvm.ir.transform.Pass`
+
+    `before` may be any one of the following.
+
+    - An instance of `tvm.tir.PrimFunc`.  This is allowed, but is not
+      the preferred method, as any errors in constructing the
+      `PrimFunc` occur while collecting the test, preventing any other
+      tests in the same file from being run.
+
+    - A TVMScript function, without the ``@T.prim_func`` decoration.
+      The ``@T.prim_func`` decoration will be applied when running the
+      test, rather than at module import.
+
+    - A method that takes only ``self`` and returns a `tvm.tir.PrimFunc`
+
+    - A pytest fixture that returns a `tvm.tir.PrimFunc`
+
+    `expected` may be any one of the following.  The type of
+    `expected` defines the test being performed.  If `expected`
+    provides a `tvm.tir.PrimFunc`, the result of the transformation
+    must match `expected`.  If `expected` is an exception, then the
+    transformation must raise that exception type.
+
+    - Any option supported for `before`.
+
+    - The `Exception` class object, or a class object that inherits
+      from `Exception`.
+
+    - A method that takes only ``self`` and returns `Exception` or a
+      class object that inherits from `Exception`.
+
+    - A pytest fixture that returns `Exception` or a class object
+      that inherits from `Exception`.
+
+    Examples
+    --------
+
+    .. code-block:: python
+
+        class TestRemoveIf(tvm.testing.CompareBeforeAfter):
+            transform = tvm.tir.transform.Simplify()
+
+            def before(A: T.Buffer[1, "int32"]):
+                if True:
+                    A[0] = 42
+                else:
+                    A[0] = 5
+
+            def expected(A: T.Buffer[1, "int32"]):
+                A[0] = 42
+
+    """
+
+    def __init_subclass__(cls, **kwargs):
+        """Normalize `before`, `expected`, and `transform` into pytest fixtures."""
+        super().__init_subclass__(**kwargs)
+        if hasattr(cls, "before"):
+            cls.before = cls._normalize_before(cls.before)
+        if hasattr(cls, "expected"):
+            cls.expected = cls._normalize_expected(cls.expected)
+        if hasattr(cls, "transform"):
+            cls.transform = cls._normalize_transform(cls.transform)
+
+    @classmethod
+    def _normalize_before(cls, func):
+        """Wrap a PrimFunc-producing `before`/`expected` as a pytest fixture.
+
+        Existing fixtures pass through; bare TVMScript is parsed lazily.
+        """
+        if hasattr(func, "_pytestfixturefunction"):
+            return func
+
+        if isinstance(func, tvm.tir.PrimFunc):
+
+            def inner(self):
+                # pylint: disable=unused-argument
+                return func
+
+        elif cls._is_method(func):
+
+            def inner(self):
+                # pylint: disable=unused-argument
+                return func(self)
+
+        else:
+
+            def inner(self):
+                # pylint: disable=unused-argument
+                source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func))
+                return tvm.script.from_source(source_code)
+
+        return pytest.fixture(inner)
+
+    @classmethod
+    def _normalize_expected(cls, func):
+        """Wrap `expected` as a pytest fixture.
+
+        An `Exception` class (or subclass) is returned by the fixture
+        unchanged; every other form is handled exactly like `before`.
+        """
+        if inspect.isclass(func) and issubclass(func, Exception):
+
+            def inner(self):
+                # pylint: disable=unused-argument
+                return func
+
+            return pytest.fixture(inner)
+
+        return cls._normalize_before(func)
+
+    @classmethod
+    def _normalize_transform(cls, transform):
+        """Wrap `transform` as a pytest fixture; raise TypeError for other forms."""
+        if hasattr(transform, "_pytestfixturefunction"):
+            return transform
+
+        if isinstance(transform, tvm.ir.transform.Pass):
+
+            def inner(self):
+                # pylint: disable=unused-argument
+                return transform
+
+        elif cls._is_method(transform):
+
+            def inner(self):
+                # pylint: disable=unused-argument
+                return transform(self)
+
+        else:
+
+            raise TypeError(
+                "Expected transform to be a tvm.ir.transform.Pass, or a method returning a Pass"
+            )
+
+        return pytest.fixture(inner)
+
+    @staticmethod
+    def _is_method(func):
+        """Return True if ``func`` declares a parameter named ``self``."""
+        sig = inspect.signature(func)
+        return "self" in sig.parameters
+
+    def test_compare(self, before, expected, transform):
+        """Apply `transform` to `before` and check the outcome against `expected`.
+
+        `expected` is either an `Exception` class the transform must raise, or a
+        `tvm.tir.PrimFunc` that the transformed result must structurally equal.
+        """
+        before_mod = tvm.IRModule.from_expr(before)
+
+        if inspect.isclass(expected) and issubclass(expected, Exception):
+            with pytest.raises(expected):
+                after_mod = transform(before_mod)
+
+                # This portion through pytest.fail isn't strictly
+                # necessary, but gives a better error message that
+                # includes the before/after.
+                after = after_mod["main"]
+                script = tvm.IRModule({"after": after, "before": before}).script()
+                pytest.fail(
+                    msg=(
+                        f"Expected {expected.__name__} to be raised from transformation, "
+                        f"instead received TIR:\n{script}"
+                    )
+                )
+
+        elif isinstance(expected, tvm.tir.PrimFunc):
+            after_mod = transform(before_mod)
+            after = after_mod["main"]
+
+            try:
+                tvm.ir.assert_structural_equal(after, expected)
+            except ValueError as err:
+                script = tvm.IRModule(
+                    {"expected": expected, "after": after, "before": before}
+                ).script()
+                raise ValueError(
+                    f"TIR after transformation did not match expected:\n{script}"
+                ) from err
+
+        else:
+            raise TypeError(
+                "tvm.testing.CompareBeforeAfter requires the `expected` fixture "
+                "to return either `Exception`, an `Exception` subclass, "
+                "or an instance of `tvm.tir.PrimFunc`.  "
+                f"Instead, received {type(expected)}."
+            )
diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py
index 529b454811..4ac502b211 100644
--- a/tests/python/unittest/test_tir_transform_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -136,31 +136,16 @@ def test_complex_likely_elimination():
     assert "if" not in str(stmt)
 
 
-class BaseBeforeAfter:
-    def test_simplify(self):
-        before = self.before
-        before_mod = tvm.IRModule.from_expr(before)
-        after_mod = tvm.tir.transform.Simplify()(before_mod)
-        after = after_mod["main"]
-        expected = self.expected
-
-        try:
-            tvm.ir.assert_structural_equal(after, expected)
-        except ValueError as err:
-            script = tvm.IRModule({"expected": expected, "after": after, "before": before}).script()
-            raise ValueError(
-                f"Function after simplification did not match expected:\n{script}"
-            ) from err
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+    transform = tvm.tir.transform.Simplify()  # inherited by the subclasses below
 
 
 class TestLoadStoreNoop(BaseBeforeAfter):
     """Store of a value that was just read from the same location is a no-op."""
 
-    @T.prim_func
     def before(A: T.Buffer[(1,), "float32"]):
         A[0] = A[0]
 
-    @T.prim_func
     def expected(A: T.Buffer[(1,), "float32"]):
         T.evaluate(0)
 
@@ -174,11 +159,9 @@ class TestLoadStoreNoopAfterSimplify(BaseBeforeAfter):
     regression.
     """
 
-    @T.prim_func
     def before(A: T.Buffer[(1,), "float32"]):
         A[0] = A[0] + (5.0 - 5.0)
 
-    @T.prim_func
     def expected(A: T.Buffer[(1,), "float32"]):
         T.evaluate(0)
 
@@ -191,14 +174,12 @@ class TestNestedCondition(BaseBeforeAfter):
     constraint.
     """
 
-    @T.prim_func
     def before(A: T.Buffer[(16,), "float32"]):
         for i in T.serial(16):
             if i == 5:
                 if i == 5:
                     A[i] = 0.0
 
-    @T.prim_func
     def expected(A: T.Buffer[(16,), "float32"]):
         for i in T.serial(16):
             if i == 5:
@@ -212,14 +193,12 @@ class TestNestedProvableCondition(BaseBeforeAfter):
     conditional.
     """
 
-    @T.prim_func
     def before(A: T.Buffer[(16,), "float32"]):
         for i in T.serial(16):
             if i == 5:
                 if i < 7:
                     A[i] = 0.0
 
-    @T.prim_func
     def expected(A: T.Buffer[(16,), "float32"]):
         for i in T.serial(16):
             if i == 5:
@@ -233,14 +212,12 @@ class TestNestedVarCondition(BaseBeforeAfter):
     constraint.
     """
 
-    @T.prim_func
     def before(A: T.Buffer[(16,), "float32"], n: T.int32):
         for i in T.serial(16):
             if i == n:
                 if i == n:
                     A[i] = 0.0
 
-    @T.prim_func
     def expected(A: T.Buffer[(16,), "float32"], n: T.int32):
         for i in T.serial(16):
             if i == n:
@@ -256,7 +233,6 @@ class TestAlteredBufferContents(BaseBeforeAfter):
     may not.
     """
 
-    @T.prim_func
     def before(A: T.Buffer[(1,), "int32"], n: T.int32):
         if A[0] == n:
             A[0] = A[0] + 1
@@ -273,7 +249,6 @@ class TestNegationOfCondition(BaseBeforeAfter):
     condition is known to be false.
     """
 
-    @T.prim_func
     def before(A: T.Buffer[(16,), "int32"]):
         for i in T.serial(16):
             if i == 5:
@@ -282,7 +257,6 @@ class TestNegationOfCondition(BaseBeforeAfter):
                 else:
                     A[i] = 1
 
-    @T.prim_func
     def expected(A: T.Buffer[(16,), "int32"]):
         for i in T.serial(16):
             if i == 5:
@@ -298,7 +272,6 @@ class TestNegationOfNotEqual(BaseBeforeAfter):
     ``i==5`` as the negation of a literal constraint.
     """
 
-    @T.prim_func
     def before(A: T.Buffer[(16,), "int32"]):
         for i in T.serial(16):
             if i != 5:
@@ -307,7 +280,6 @@ class TestNegationOfNotEqual(BaseBeforeAfter):
                 else:
                     A[i] = 1
 
-    @T.prim_func
     def expected(A: T.Buffer[(16,), "int32"]):
         for i in T.serial(16):
             if i != 5:
@@ -321,7 +293,6 @@ class TestNegationOfVarCondition(BaseBeforeAfter):
     must rely on RewriteSimplifier recognizing the repeated literal.
     """
 
-    @T.prim_func
     def before(A: T.Buffer[(16,), "int32"], n: T.int32):
         for i in T.serial(16):
             if i == n:
@@ -330,7 +301,6 @@ class TestNegationOfVarCondition(BaseBeforeAfter):
                 else:
                     A[i] = 1
 
-    @T.prim_func
     def expected(A: T.Buffer[(16,), "int32"], n: T.int32):
         for i in T.serial(16):
             if i == n:
@@ -346,14 +316,12 @@ class TestLiteralConstraintSplitBooleanAnd(BaseBeforeAfter):
     the condition is to ensure we exercise RewriteSimplifier.
     """
 
-    @T.prim_func
     def before(A: T.Buffer[(16, 16), "int32"], n: T.int32):
         for i, j in T.grid(16, 16):
             if i == n and j == n:
                 if i == n:
                     A[i, j] = 0
 
-    @T.prim_func
     def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32):
         for i, j in T.grid(16, 16):
             if i == n and j == n:
@@ -371,7 +339,6 @@ class TestLiteralConstraintSplitBooleanOr(BaseBeforeAfter):
     RewriteSimplifier.
     """
 
-    @T.prim_func
     def before(A: T.Buffer[(16, 16), "int32"], n: T.int32):
         for i, j in T.grid(16, 16):
             if i == n or j == n:
@@ -382,7 +349,6 @@ class TestLiteralConstraintSplitBooleanOr(BaseBeforeAfter):
                 else:
                     A[i, j] = 2
 
-    @T.prim_func
     def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32):
         for i, j in T.grid(16, 16):
             if i == n or j == n:
diff --git a/tests/python/unittest/test_tvm_testing_before_after.py b/tests/python/unittest/test_tvm_testing_before_after.py
new file mode 100644
index 0000000000..613d66ccdb
--- /dev/null
+++ b/tests/python/unittest/test_tvm_testing_before_after.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import tvm
+import tvm.testing
+from tvm.script import tir as T
+
+
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+    def transform(self):  # identity; each subclass below sets `expected = before`
+        return lambda x: x
+
+
+class TestBeforeAfterPrimFunc(BaseBeforeAfter):
+    @T.prim_func
+    def before():
+        T.evaluate(0)
+
+    expected = before  # `before` is decorated eagerly, so this is already a PrimFunc
+
+
+class TestBeforeAfterMethod(BaseBeforeAfter):
+    def before(self):  # method form: the PrimFunc is built when the fixture runs
+        @T.prim_func
+        def func():
+            T.evaluate(0)
+
+        return func
+
+    expected = before  # the same method also supplies `expected`
+
+
+class TestBeforeAfterFixture(BaseBeforeAfter):
+    @tvm.testing.fixture  # explicit fixture form of `before`
+    def before(self):
+        @T.prim_func
+        def func():
+            T.evaluate(0)
+
+        return func
+
+    expected = before
+
+
+class TestBeforeAfterDelayedPrimFunc(BaseBeforeAfter):
+    def before():
+        T.evaluate(0)
+
+    expected = before  # `before` is undecorated; @T.prim_func is applied at test time
+
+
+class TestBeforeAfterParametrizedFixture(BaseBeforeAfter):
+    n = tvm.testing.parameter(1, 8, 16)  # presumably one test case per value of `n`
+
+    @tvm.testing.fixture
+    def before(self, n):  # buffer size and loop extent both come from `n`
+        @T.prim_func
+        def func(A: T.Buffer[n, "float32"]):
+            for i in T.serial(n):
+                A[i] = 0.0
+
+        return func
+
+    expected = before
+
+
+if __name__ == "__main__":
+    tvm.testing.main()  # runs pytest on this file, forwarding sys.argv