Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/05/26 16:52:06 UTC

[GitHub] [tvm] tkonolige commented on a change in pull request #8010: [UnitTests] Automatic parametrization over targets, with explicit opt-out

tkonolige commented on a change in pull request #8010:
URL: https://github.com/apache/tvm/pull/8010#discussion_r639912818



##########
File path: python/tvm/testing.py
##########
@@ -366,24 +368,37 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap):
     )
 
 
-def _get_targets():
-    target_str = os.environ.get("TVM_TEST_TARGETS", "")
+def _get_targets(target_str=None):
+    if target_str is None:
+        target_str = os.environ.get("TVM_TEST_TARGETS", "")
+
     if len(target_str) == 0:
         target_str = DEFAULT_TEST_TARGETS
-    targets = set()
-    for dev in target_str.split(";"):
-        if len(dev) == 0:
-            continue
-        target_kind = dev.split()[0]
-        if tvm.runtime.enabled(target_kind) and tvm.device(target_kind, 0).exist:
-            targets.add(dev)
-    if len(targets) == 0:
+
+    target_names = set(t.strip() for t in target_str.split(";") if t.strip())
+
+    targets = []
+    for target in target_names:
+        target_kind = target.split()[0]
+        is_enabled = tvm.runtime.enabled(target_kind)
+        is_runnable = is_enabled and tvm.device(target_kind).exist
+        targets.append(
+            {
+                "target": target,
+                "target_kind": target_kind,
+                "is_enabled": is_enabled,
+                "is_runnable": is_runnable,
+            }
+        )
+
+    if all(not t["is_runnable"] for t in targets):
         logging.warning(
             "None of the following targets are supported by this build of TVM: %s."
             " Try setting TVM_TEST_TARGETS to a supported target. Defaulting to llvm.",
             target_str,
         )
-        return {"llvm"}
+        return _get_targets("llvm")

Review comment:
       Does this loop forever if llvm is not enabled?
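       A minimal sketch of the concern, assuming a hypothetical build where
       tvm.runtime.enabled("llvm") returns False (a reduction for
       illustration, not the PR's code):

           # Reduced fallback path: every call finds no runnable target and
           # re-enters with "llvm", so the recursion never terminates and
           # eventually raises RecursionError.
           def _get_targets(target_str=None):
               targets = [{"target": "llvm", "is_runnable": False}]  # llvm compiled out
               if all(not t["is_runnable"] for t in targets):
                   return _get_targets("llvm")  # same inputs, same outcome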

##########
File path: python/tvm/testing.py
##########
@@ -718,33 +802,364 @@ def parametrize_targets(*args):
 
     Example
     -------
-    >>> @tvm.testing.parametrize
+    >>> @tvm.testing.parametrize_targets
     >>> def test_mytest(target, dev):
     >>>     ...  # do something
 
     Or
 
-    >>> @tvm.testing.parametrize("llvm", "cuda")
+    >>> @tvm.testing.parametrize_targets("llvm", "cuda")
     >>> def test_mytest(target, dev):
     >>>     ...  # do something
+
     """
 
     def wrap(targets):
         def func(f):
-            params = [
-                pytest.param(target, tvm.device(target, 0), marks=_target_to_requirement(target))
-                for target in targets
-            ]
-            return pytest.mark.parametrize("target,dev", params)(f)
+            return pytest.mark.parametrize(
+                "target", _pytest_target_params(targets), scope="session"
+            )(f)
 
         return func
 
     if len(args) == 1 and callable(args[0]):
-        targets = [t for t, _ in enabled_targets()]
-        return wrap(targets)(args[0])
+        return wrap(None)(args[0])
     return wrap(args)
 
 
+def exclude_targets(*args):
+    """Exclude a test from running on a particular target.
+
+    Use this decorator when you want your test to be run over a
+    variety of targets and devices (including cpu and gpu devices),
+    but want to exclude some particular target or targets.  For
+    example, a test may wish to be run against all targets in
+    tvm.testing.enabled_targets(), except for a particular target
+    that does not support the required capabilities.
+
+    Applies pytest.mark.skipif to the targets given.
+
+    Parameters
+    ----------
+    f : function
+        Function to parametrize. Must be of the form `def test_xxxxxxxxx(target, dev):`,
+        where `xxxxxxxxx` is any name.
+    targets : list[str]
+        Set of targets to exclude.
+
+    Example
+    -------
+    >>> @tvm.testing.exclude_targets("cuda")
+    >>> def test_mytest(target, dev):
+    >>>     ...  # do something
+
+    Or
+
+    >>> @tvm.testing.exclude_targets("llvm", "cuda")
+    >>> def test_mytest(target, dev):
+    >>>     ...  # do something
+
+    """
+
+    def wraps(func):
+        func.tvm_excluded_targets = args
+        return func
+
+    return wraps
+
+
+def known_failing_targets(*args):
+    """Skip a test that is known to fail on a particular target.
+
+    Use this decorator when you want your test to be run over a
+    variety of targets and devices (including cpu and gpu devices),
+    but know that it fails for some targets.  For example, a newly
+    implemented runtime may not support all features being tested, and
+    should be excluded.
+
+    Applies pytest.mark.xfail to the targets given.
+
+    Parameters
+    ----------
+    f : function
+        Function to parametrize. Must be of the form `def test_xxxxxxxxx(target, dev):`,
+        where `xxxxxxxxx` is any name.
+    targets : list[str]
+        Set of targets to skip.
+
+    Example
+    -------
+    >>> @tvm.testing.known_failing_targets("cuda")
+    >>> def test_mytest(target, dev):
+    >>>     ...  # do something
+
+    Or
+
+    >>> @tvm.testing.known_failing_targets("llvm", "cuda")
+    >>> def test_mytest(target, dev):
+    >>>     ...  # do something
+
+    """
+
+    def wraps(func):
+        func.tvm_known_failing_targets = args
+        return func
+
+    return wraps
+
+
+def parameter(*values, ids=None):
+    """Convenience function to define pytest parametrized fixtures.
+
+    Declaring a variable using ``tvm.testing.parameter`` will define a
+    parametrized pytest fixture that can be used by test
+    functions. This is intended for cases that have no setup cost,
+    such as strings, integers, tuples, etc.  For cases that have a
+    significant setup cost, please use :py:func:`tvm.testing.fixture`
+    instead.
+
+    If a test function accepts multiple parameters defined using
+    ``tvm.testing.parameter``, then the test will be run using every
+    combination of those parameters.
+
+    The parameter definition applies to all tests in a module.  If a
+    specific test should have different values for the parameter, that
+    test should be marked with ``@pytest.mark.parametrize``.
+
+    Parameters
+    ----------
+    values
+       A list of parameter values.  A unit test that accepts this
+       parameter as an argument will be run once for each parameter
+       given.
+
+    ids : List[str], optional
+       A list of names for the parameters.  If None, pytest will
+       generate a name from the value.  These generated names may not
+       be readable/useful for composite types such as tuples.
+
+    Returns
+    -------
+    function
+       A function output from pytest.fixture.
+
+    Example
+    -------
+    >>> size = tvm.testing.parameter(1, 10, 100)
+    >>> def test_using_size(size):
+    >>>     ... # Test code here
+
+    Or
+
+    >>> shape = tvm.testing.parameter((5,10), (512,1024), ids=['small','large'])
+    >>> def test_using_shape(shape):
+    >>>     ... # Test code here
+
+    """
+
+    @pytest.fixture(params=values, ids=ids)
+    def as_fixture(request):
+        return request.param
+
+    return as_fixture
+
+
+_parametrize_group = 0
+
+
+def parameters(*value_sets):
+    """Convenience function to define pytest parametrized fixtures.
+
+    Declaring a variable using tvm.testing.parameters will define a
+    parametrized pytest fixture that can be used by test
+    functions. Like :py:func:`tvm.testing.parameter`, this is intended
+    for cases that have no setup cost, such as strings, integers,
+    tuples, etc.  For cases that have a significant setup cost, please
+    use :py:func:`tvm.testing.fixture` instead.
+
+    Unlike :py:func:`tvm.testing.parameter`, if a test function
+    accepts multiple parameters defined using a single call to
+    ``tvm.testing.parameters``, then the test will only be run once
+    for each set of parameters, not for all combinations of
+    parameters.
+
+    These parameter definitions apply to all tests in a module.  If a
+    specific test should have different values for some parameters,
+    that test should be marked with ``@pytest.mark.parametrize``.
+
+    Parameters
+    ----------
+    value_sets : List[tuple]
+       A list of parameter value sets.  Each set of values represents
+       a single combination of values to be tested.  A unit test that
+       accepts these parameters will be run once for every set of
+       parameters in the list.
+
+    Returns
+    -------
+    List[function]
+       Function outputs from pytest.fixture.  These should be unpacked
+       into individual named parameters.
+
+    Example
+    -------
+    >>> size, dtype = tvm.testing.parameters( (16,'float32'), (512,'float16') )
+    >>> def test_feature_x(size, dtype):
+    >>>     # Test code here
+    >>>     assert( (size,dtype) in [(16,'float32'), (512,'float16')])
+
+    """
+    global _parametrize_group
+    parametrize_group = _parametrize_group
+    _parametrize_group += 1
+
+    outputs = []
+    for param_values in zip(*value_sets):
+
+        def fixture_func(request):
+            return request.param
+
+        fixture_func.parametrize_group = parametrize_group
+        fixture_func.parametrize_values = param_values
+        outputs.append(pytest.fixture(fixture_func))
+
+    return outputs
+
+
+def _parametrize_correlated_parameters(metafunc):
+    parametrize_needed = collections.defaultdict(list)
+
+    for name, fixturedefs in metafunc.definition._fixtureinfo.name2fixturedefs.items():
+        fixturedef = fixturedefs[-1]
+        if hasattr(fixturedef.func, "parametrize_group") and hasattr(
+            fixturedef.func, "parametrize_values"
+        ):
+            group = fixturedef.func.parametrize_group
+            values = fixturedef.func.parametrize_values
+            parametrize_needed[group].append((name, values))
+
+    for parametrize_group in parametrize_needed.values():
+        if len(parametrize_group) == 1:
+            name, values = parametrize_group[0]
+            metafunc.parametrize(name, values, indirect=True)
+        else:
+            names = ",".join(name for name, values in parametrize_group)
+            value_sets = zip(*[values for name, values in parametrize_group])
+            metafunc.parametrize(names, value_sets, indirect=True)
+
+
+def fixture(func=None, *, cache_return_value=False):

Review comment:
       Are you allowed to have an optional parameter before the keyword-only arguments? I think lint will not be happy with this one.
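       For reference, a standalone sketch of the pattern in question
       (hypothetical names, not the PR's implementation): a positional
       parameter with a default followed by keyword-only parameters is
       valid Python, and it is what lets a decorator be used both bare
       and with arguments.

           import functools

           # Dual-use decorator sketch: `func=None` before the `*` is legal;
           # everything after the `*` must be passed by keyword.
           def fixture(func=None, *, cache_return_value=False):
               def wrap(f):
                   @functools.wraps(f)
                   def inner(*args, **kwargs):
                       return f(*args, **kwargs)
                   inner.cache_return_value = cache_return_value
                   return inner

               if func is not None:  # used bare: @fixture
                   return wrap(func)
               return wrap  # used with arguments: @fixture(cache_return_value=True)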

##########
File path: python/tvm/testing.py
##########
@@ -701,11 +717,79 @@ def _target_to_requirement(target):
     return []
 
 
+def _pytest_target_params(targets, excluded_targets=None, xfail_targets=None):
+    # Include unrunnable targets here.  They get skipped by the
+    # pytest.mark.skipif in _target_to_requirement(), showing up as
+    # skipped tests instead of being hidden entirely.
+    if targets is None:
+        if excluded_targets is None:
+            excluded_targets = set()
+
+        if xfail_targets is None:
+            xfail_targets = set()
+
+        target_marks = []
+        for t in _get_targets():
+            # Excluded targets aren't included in the params at all.
+            if t["target_kind"] not in excluded_targets:
+
+                # Known failing targets are included, but are marked
+                # as expected to fail.
+                extra_marks = []
+                if t["target_kind"] in xfail_targets:
+                    extra_marks.append(
+                        pytest.mark.xfail(
+                            reason='Known failing test for target "{}"'.format(t["target_kind"])
+                        )
+                    )
+                target_marks.append((t["target"], extra_marks))
+
+    else:
+        target_marks = [(target, []) for target in targets]
+
+    return [
+        pytest.param(target, marks=_target_to_requirement(target) + extra_marks)
+        for target, extra_marks in target_marks
+    ]
+
+
+def _auto_parametrize_target(metafunc):
+    """Automatically applies parametrize_targets
+
+    Used if a test function uses the "target" fixture, but isn't
+    already marked with @tvm.testing.parametrize_targets.  Intended
+    for use in the pytest_generate_tests() handler of a conftest.py
+    file.
+
+    """
+    if "target" in metafunc.fixturenames:
+        parametrized_args = [
+            arg.strip()
+            for mark in metafunc.definition.iter_markers("parametrize")
+            for arg in mark.args[0].split(",")
+        ]
+
+        if "target" not in parametrized_args:
+            # Check if the function is marked with either excluded or
+            # known failing targets.
+            excluded_targets = getattr(metafunc.function, "tvm_excluded_targets", [])
+            xfail_targets = getattr(metafunc.function, "tvm_known_failing_targets", [])
+            metafunc.parametrize(
+                "target",
+                _pytest_target_params(None, excluded_targets, xfail_targets),
+                scope="session",
+            )
+
+
 def parametrize_targets(*args):
     """Parametrize a test over all enabled targets.
 
-    Use this decorator when you want your test to be run over a variety of
-    targets and devices (including cpu and gpu devices).
+    Use this decorator when you want your test to be run over a
+    variety of targets and devices (including cpu and gpu devices).
+
+    Alternatively, a test that accepts the "target" and "dev" will

Review comment:
       Maybe specify that you should only use `parametrize_targets` when you have a specific set of targets you want to run over; otherwise users should not use the decorator. Also mention that `exclude_targets` may be a better option.
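       For example (usage per the docstrings in this diff; test names are
       illustrative), the explicit form would be reserved for tests that
       are only meaningful on specific targets, while exclusion handles
       the common case:

           # Only use an explicit list when the test applies solely to
           # these targets.
           @tvm.testing.parametrize_targets("llvm", "cuda")
           def test_only_llvm_and_cuda(target, dev):
               ...

           # Preferred when the test should run everywhere except a few
           # known-unsupported targets.
           @tvm.testing.exclude_targets("cuda")
           def test_everything_but_cuda(target, dev):
               ...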

##########
File path: python/tvm/testing.py
##########
@@ -701,11 +717,79 @@ def _target_to_requirement(target):
     return []
 
 
+def _pytest_target_params(targets, excluded_targets=None, xfail_targets=None):
+    # Include unrunnable targets here.  They get skipped by the
+    # pytest.mark.skipif in _target_to_requirement(), showing up as
+    # skipped tests instead of being hidden entirely.
+    if targets is None:
+        if excluded_targets is None:
+            excluded_targets = set()
+
+        if xfail_targets is None:
+            xfail_targets = set()
+
+        target_marks = []
+        for t in _get_targets():

Review comment:
       Doesn't `_get_targets` filter out all non-runnable targets? So we are not including unrunnable targets here?
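       For reference, the shape returned by the new _get_targets() in the
       first hunk of this thread, with illustrative values: unrunnable
       targets are retained and flagged rather than filtered out.

           # Illustrative return value; every requested target appears,
           # annotated with capability flags.
           [
               {"target": "llvm", "target_kind": "llvm",
                "is_enabled": True, "is_runnable": True},
               {"target": "cuda", "target_kind": "cuda",
                "is_enabled": True, "is_runnable": False},  # no device present
           ]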




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org