You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/02 00:46:41 UTC
[tvm] branch main updated: [UnitTest][TIR] Testing utility for before/after transform tests (#12264)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bca0385862 [UnitTest][TIR] Testing utility for before/after transform tests (#12264)
bca0385862 is described below
commit bca0385862d93918fde877ed512c834fc32d80d5
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Mon Aug 1 19:46:35 2022 -0500
[UnitTest][TIR] Testing utility for before/after transform tests (#12264)
This PR adds `tvm.testing.CompareBeforeAfter`, a generalization of the `BaseBeforeAfter` utility previously used in `test_tir_transform_simplify.py`, for writing unit tests that apply a transformation to a TIR function and compare the result against an expected TIR output. This arose while minimizing the boilerplate required for unit tests in the implementation of https://github.com/apache/tvm/issues/12261.
---
python/tvm/testing/utils.py | 206 +++++++++++++++++++++
.../python/unittest/test_tir_transform_simplify.py | 38 +---
.../unittest/test_tvm_testing_before_after.py | 83 +++++++++
3 files changed, 291 insertions(+), 36 deletions(-)
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index e3148a26c2..5b7a600c78 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -74,6 +74,7 @@ import os
import pickle
import platform
import sys
+import textwrap
import time
import shutil
@@ -1712,3 +1713,208 @@ def fetch_model_from_url(
def main():  # run pytest on the source file of whichever module called main()
    test_file = inspect.getsourcefile(sys._getframe(1))  # frame 1 is main()'s caller
    sys.exit(pytest.main([test_file] + sys.argv[1:]))  # forward extra CLI args to pytest
+
+
+class CompareBeforeAfter:
+    """Utility for comparing before/after of TIR transforms
+
+    A standard framework for writing tests that take a TIR PrimFunc as
+    input, apply a transformation, then either compare against an
+    expected output or assert that the transformation raised an error.
+    A test should subclass CompareBeforeAfter, defining class members
+    `before`, `transform`, and `expected`.  CompareBeforeAfter will
+    then use these members to define a test method and test fixture.
+
+    `transform` may be one of the following.
+
+    - An instance of `tvm.ir.transform.Pass`
+
+    - A method that takes only ``self`` and returns a `tvm.ir.transform.Pass`
+
+    - A pytest fixture that returns a `tvm.ir.transform.Pass`
+
+    `before` may be any one of the following.
+
+    - An instance of `tvm.tir.PrimFunc`.  This is allowed, but is not
+      the preferred method, as any errors in constructing the
+      `PrimFunc` occur while collecting the test, preventing any other
+      tests in the same file from being run.
+
+    - A TVMScript function, without the ``@T.prim_func`` decoration.
+      The ``@T.prim_func`` decoration will be applied when running the
+      test, rather than at module import.
+
+    - A method that takes only ``self`` and returns a `tvm.tir.PrimFunc`
+
+    - A pytest fixture that returns a `tvm.tir.PrimFunc`
+
+    `expected` may be any one of the following.  The type of
+    `expected` defines the test being performed.  If `expected`
+    provides a `tvm.tir.PrimFunc`, the result of the transformation
+    must match `expected`.  If `expected` is an exception class, then the
+    transformation must raise that exception type.
+
+    - Any option supported for `before`.
+
+    - The `Exception` class object, or a class object that inherits
+      from `Exception`.
+
+    - A method that takes only ``self`` and returns `Exception` or a
+      class object that inherits from `Exception`.
+
+    - A pytest fixture that returns `Exception` or a class object
+      that inherits from `Exception`.
+
+    Examples
+    --------
+
+    .. code-block:: python
+
+        class TestRemoveIf(tvm.testing.CompareBeforeAfter):
+            transform = tvm.tir.transform.Simplify()
+
+            def before(A: T.Buffer[1, "int32"]):
+                if True:
+                    A[0] = 42
+                else:
+                    A[0] = 5
+
+            def expected(A: T.Buffer[1, "int32"]):
+                A[0] = 42
+
+    """
+
+    def __init_subclass__(cls):
+        """Normalize `before`, `expected`, and `transform` of each subclass into fixtures."""
+        if hasattr(cls, "before"):
+            cls.before = cls._normalize_before(cls.before)
+        if hasattr(cls, "expected"):
+            cls.expected = cls._normalize_expected(cls.expected)
+        if hasattr(cls, "transform"):
+            cls.transform = cls._normalize_transform(cls.transform)
+
+    @classmethod
+    def _normalize_before(cls, func):
+        """Wrap any supported form of `before` in a pytest fixture."""
+        if hasattr(func, "_pytestfixturefunction"):
+            return func
+
+        if isinstance(func, tvm.tir.PrimFunc):
+
+            def inner(self):
+                # pylint: disable=unused-argument
+                return func
+
+        elif cls._is_method(func):
+
+            def inner(self):
+                return func(self)
+
+        else:
+
+            def inner(self):
+                # pylint: disable=unused-argument
+                source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func))
+                return tvm.script.from_source(source_code)
+
+        return pytest.fixture(inner)
+
+    @classmethod
+    def _normalize_expected(cls, func):
+        """Wrap any supported form of `expected` in a pytest fixture."""
+        if hasattr(func, "_pytestfixturefunction"):
+            return func
+
+        if isinstance(func, tvm.tir.PrimFunc) or (
+            inspect.isclass(func) and issubclass(func, Exception)
+        ):
+
+            def inner(self):
+                # pylint: disable=unused-argument
+                return func
+
+        elif cls._is_method(func):
+
+            def inner(self):
+                return func(self)
+
+        else:
+
+            def inner(self):
+                # pylint: disable=unused-argument
+                source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func))
+                return tvm.script.from_source(source_code)
+
+        return pytest.fixture(inner)
+
+    @classmethod
+    def _normalize_transform(cls, transform):
+        """Wrap `transform` in a pytest fixture, raising TypeError if unsupported."""
+        if hasattr(transform, "_pytestfixturefunction"):
+            return transform
+
+        if isinstance(transform, tvm.ir.transform.Pass):
+
+            def inner(self):
+                # pylint: disable=unused-argument
+                return transform
+
+        elif cls._is_method(transform):
+
+            def inner(self):
+                return transform(self)
+
+        else:
+            raise TypeError(
+                "Expected transform to be a tvm.ir.transform.Pass, or a method returning a Pass"
+            )
+
+        return pytest.fixture(inner)
+
+    @staticmethod
+    def _is_method(func):
+        """Return True if `func` declares a parameter named ``self``."""
+        sig = inspect.signature(func)
+        return "self" in sig.parameters
+
+    def test_compare(self, before, expected, transform):
+        """Unit test to compare the expected TIR PrimFunc to actual"""
+        before_mod = tvm.IRModule.from_expr(before)
+
+        if inspect.isclass(expected) and issubclass(expected, Exception):
+            try:
+                after_mod = transform(before_mod)
+            except expected:
+                pass
+            else:
+                # Reached only if nothing was raised; errors below can't satisfy `expected`.
+                after = after_mod["main"]
+                script = tvm.IRModule({"after": after, "before": before}).script()
+                pytest.fail(
+                    msg=(
+                        f"Expected {expected.__name__} to be raised from transformation, "
+                        f"instead received TIR:\n{script}"
+                    )
+                )
+
+        elif isinstance(expected, tvm.tir.PrimFunc):
+            after_mod = transform(before_mod)
+            after = after_mod["main"]
+
+            try:
+                tvm.ir.assert_structural_equal(after, expected)
+            except ValueError as err:
+                script = tvm.IRModule(
+                    {"expected": expected, "after": after, "before": before}
+                ).script()
+                raise ValueError(
+                    f"TIR after transformation did not match expected:\n{script}"
+                ) from err
+
+        else:
+            raise TypeError(
+                "tvm.testing.CompareBeforeAfter requires the `expected` fixture "
+                "to return either `Exception`, an `Exception` subclass, "
+                "or an instance of `tvm.tir.PrimFunc`. "
+                f"Instead, received {type(expected)}."
+            )
diff --git a/tests/python/unittest/test_tir_transform_simplify.py b/tests/python/unittest/test_tir_transform_simplify.py
index 529b454811..4ac502b211 100644
--- a/tests/python/unittest/test_tir_transform_simplify.py
+++ b/tests/python/unittest/test_tir_transform_simplify.py
@@ -136,31 +136,16 @@ def test_complex_likely_elimination():
assert "if" not in str(stmt)
-class BaseBeforeAfter:
- def test_simplify(self):
- before = self.before
- before_mod = tvm.IRModule.from_expr(before)
- after_mod = tvm.tir.transform.Simplify()(before_mod)
- after = after_mod["main"]
- expected = self.expected
-
- try:
- tvm.ir.assert_structural_equal(after, expected)
- except ValueError as err:
- script = tvm.IRModule({"expected": expected, "after": after, "before": before}).script()
- raise ValueError(
- f"Function after simplification did not match expected:\n{script}"
- ) from err
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+    transform = tvm.tir.transform.Simplify()  # Pass instance shared by every test class below
class TestLoadStoreNoop(BaseBeforeAfter):
"""Store of a value that was just read from the same location is a no-op."""
- @T.prim_func
def before(A: T.Buffer[(1,), "float32"]):
A[0] = A[0]
- @T.prim_func
def expected(A: T.Buffer[(1,), "float32"]):
T.evaluate(0)
@@ -174,11 +159,9 @@ class TestLoadStoreNoopAfterSimplify(BaseBeforeAfter):
regression.
"""
- @T.prim_func
def before(A: T.Buffer[(1,), "float32"]):
A[0] = A[0] + (5.0 - 5.0)
- @T.prim_func
def expected(A: T.Buffer[(1,), "float32"]):
T.evaluate(0)
@@ -191,14 +174,12 @@ class TestNestedCondition(BaseBeforeAfter):
constraint.
"""
- @T.prim_func
def before(A: T.Buffer[(16,), "float32"]):
for i in T.serial(16):
if i == 5:
if i == 5:
A[i] = 0.0
- @T.prim_func
def expected(A: T.Buffer[(16,), "float32"]):
for i in T.serial(16):
if i == 5:
@@ -212,14 +193,12 @@ class TestNestedProvableCondition(BaseBeforeAfter):
conditional.
"""
- @T.prim_func
def before(A: T.Buffer[(16,), "float32"]):
for i in T.serial(16):
if i == 5:
if i < 7:
A[i] = 0.0
- @T.prim_func
def expected(A: T.Buffer[(16,), "float32"]):
for i in T.serial(16):
if i == 5:
@@ -233,14 +212,12 @@ class TestNestedVarCondition(BaseBeforeAfter):
constraint.
"""
- @T.prim_func
def before(A: T.Buffer[(16,), "float32"], n: T.int32):
for i in T.serial(16):
if i == n:
if i == n:
A[i] = 0.0
- @T.prim_func
def expected(A: T.Buffer[(16,), "float32"], n: T.int32):
for i in T.serial(16):
if i == n:
@@ -256,7 +233,6 @@ class TestAlteredBufferContents(BaseBeforeAfter):
may not.
"""
- @T.prim_func
def before(A: T.Buffer[(1,), "int32"], n: T.int32):
if A[0] == n:
A[0] = A[0] + 1
@@ -273,7 +249,6 @@ class TestNegationOfCondition(BaseBeforeAfter):
condition is known to be false.
"""
- @T.prim_func
def before(A: T.Buffer[(16,), "int32"]):
for i in T.serial(16):
if i == 5:
@@ -282,7 +257,6 @@ class TestNegationOfCondition(BaseBeforeAfter):
else:
A[i] = 1
- @T.prim_func
def expected(A: T.Buffer[(16,), "int32"]):
for i in T.serial(16):
if i == 5:
@@ -298,7 +272,6 @@ class TestNegationOfNotEqual(BaseBeforeAfter):
``i==5`` as the negation of a literal constraint.
"""
- @T.prim_func
def before(A: T.Buffer[(16,), "int32"]):
for i in T.serial(16):
if i != 5:
@@ -307,7 +280,6 @@ class TestNegationOfNotEqual(BaseBeforeAfter):
else:
A[i] = 1
- @T.prim_func
def expected(A: T.Buffer[(16,), "int32"]):
for i in T.serial(16):
if i != 5:
@@ -321,7 +293,6 @@ class TestNegationOfVarCondition(BaseBeforeAfter):
must rely on RewriteSimplifier recognizing the repeated literal.
"""
- @T.prim_func
def before(A: T.Buffer[(16,), "int32"], n: T.int32):
for i in T.serial(16):
if i == n:
@@ -330,7 +301,6 @@ class TestNegationOfVarCondition(BaseBeforeAfter):
else:
A[i] = 1
- @T.prim_func
def expected(A: T.Buffer[(16,), "int32"], n: T.int32):
for i in T.serial(16):
if i == n:
@@ -346,14 +316,12 @@ class TestLiteralConstraintSplitBooleanAnd(BaseBeforeAfter):
the condition is to ensure we exercise RewriteSimplifier.
"""
- @T.prim_func
def before(A: T.Buffer[(16, 16), "int32"], n: T.int32):
for i, j in T.grid(16, 16):
if i == n and j == n:
if i == n:
A[i, j] = 0
- @T.prim_func
def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32):
for i, j in T.grid(16, 16):
if i == n and j == n:
@@ -371,7 +339,6 @@ class TestLiteralConstraintSplitBooleanOr(BaseBeforeAfter):
RewriteSimplifier.
"""
- @T.prim_func
def before(A: T.Buffer[(16, 16), "int32"], n: T.int32):
for i, j in T.grid(16, 16):
if i == n or j == n:
@@ -382,7 +349,6 @@ class TestLiteralConstraintSplitBooleanOr(BaseBeforeAfter):
else:
A[i, j] = 2
- @T.prim_func
def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32):
for i, j in T.grid(16, 16):
if i == n or j == n:
diff --git a/tests/python/unittest/test_tvm_testing_before_after.py b/tests/python/unittest/test_tvm_testing_before_after.py
new file mode 100644
index 0000000000..613d66ccdb
--- /dev/null
+++ b/tests/python/unittest/test_tvm_testing_before_after.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import tvm
+import tvm.testing
+from tvm.script import tir as T
+
+
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+    def transform(self):  # method form of `transform`, inherited by the tests below
+        return lambda x: x  # identity: the "transformed" module is the input module itself
+
+
+class TestBeforeAfterPrimFunc(BaseBeforeAfter):  # `before` is a PrimFunc parsed at import time
+    @T.prim_func
+    def before():
+        T.evaluate(0)
+
+    expected = before  # identity transform, so the output must equal the input
+
+
+class TestBeforeAfterMethod(BaseBeforeAfter):  # `before` is a plain method taking only `self`
+    def before(self):
+        @T.prim_func
+        def func():
+            T.evaluate(0)
+
+        return func
+
+    expected = before  # identity transform, so the output must equal the input
+
+
+class TestBeforeAfterFixture(BaseBeforeAfter):  # `before` is an explicit tvm.testing.fixture
+    @tvm.testing.fixture
+    def before(self):
+        @T.prim_func
+        def func():
+            T.evaluate(0)
+
+        return func
+
+    expected = before  # identity transform, so the output must equal the input
+
+
+class TestBeforeAfterDelayedPrimFunc(BaseBeforeAfter):  # undecorated TVMScript, parsed at test time
+    def before():
+        T.evaluate(0)
+
+    expected = before  # identity transform, so the output must equal the input
+
+
+class TestBeforeAfterParametrizedFixture(BaseBeforeAfter):  # fixture depending on parameter `n`
+    n = tvm.testing.parameter(1, 8, 16)
+
+    @tvm.testing.fixture
+    def before(self, n):  # builds a PrimFunc whose buffer size and loop extent are `n`
+        @T.prim_func
+        def func(A: T.Buffer[n, "float32"]):
+            for i in T.serial(n):
+                A[i] = 0.0
+
+        return func
+
+    expected = before  # identity transform, so the output must equal the input
+
+
+if __name__ == "__main__":
+    tvm.testing.main()  # run this file's tests through pytest