You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/06/13 18:35:31 UTC

[GitHub] [tvm] areusch commented on a diff in pull request #11632: [AutoTVM][Autoscheduler] Default build funcs inherit PassContext

areusch commented on code in PR #11632:
URL: https://github.com/apache/tvm/pull/11632#discussion_r896020908


##########
tests/python/integration/test_tuning.py:
##########
@@ -180,6 +184,114 @@ def runner(target, dev):
     run_test_with_all_multiprocessing(runner, target, dev)
 
 
+@tvm.testing.parametrize_targets("cuda", "opencl")
+def test_tuning_gpu_inherits_pass_context(target, dev):
+    """Autotvm tuner inherits PassContexts but also adds a gpu verification pass by default.
+
+    Test that using PassContext inherits passes properly but also runs gpu verification pass.
+    """
+    from tvm.tir.analysis import _ffi_api as _analysis_ffi_api
+
+    @pass_instrument
+    class PassInstrumentChecker:
+        """Pass Instrument that simply sees if it's been run."""
+
+        def __init__(self):
+            self.has_been_run = False
+
+        def run_after_pass(self, mod, info):
+            self.has_been_run = True
+
+    class GPUVerifyPassMocked:
+        """Context manager that mocks tir.analysis.verify_gpu_code meant
+        to verify the pass has been run. This is done by patching the ffi func handles."""
+
+        FFI_FUNC_HANDLE = "tir.analysis.verify_gpu_code"
+        FUNC_NAME = "verify_gpu_code"
+
+        def __init__(self) -> None:
+            self.old_impl = tvm._ffi.get_global_func(self.FFI_FUNC_HANDLE)
+            self.has_been_run = False
+
+        def gpu_verify_pass_mocked(self):
+            """Get the replacement for the gpu verification pass."""
+
+            def _gpu_verify_pass_mocked(*args, **kwargs):
+                self.has_been_run = True
+                return self.old_impl(*args, **kwargs)
+
+            return _gpu_verify_pass_mocked
+
+        def __enter__(self):
+            tvm._ffi.register_func(
+                self.FFI_FUNC_HANDLE, self.gpu_verify_pass_mocked(), override=True
+            )
+
+            # Also overwrite the python bindings
+            setattr(
+                _analysis_ffi_api, self.FUNC_NAME, tvm._ffi.get_global_func(self.FFI_FUNC_HANDLE)
+            )
+
+        def __exit__(self, *args, **kwargs):
+            # Restore FFI status back to normal
+            tvm._ffi.register_func(self.FFI_FUNC_HANDLE, self.old_impl, override=True)
+            setattr(_analysis_ffi_api, self.FUNC_NAME, self.old_impl)
+
+    class OverwrittenBuildFunc(measure_methods._WrappedBuildFunc):
+        """BuildFunc that mocks and patches as necessary to test proper passes are run."""
+
+        def __call__(self, measure_input, tmp_dir, **kwargs):
+            instrument = PassInstrumentChecker()
+            mocked_pass_checker = GPUVerifyPassMocked()
+            with mocked_pass_checker:
+                with PassContext(instruments=[instrument]):
+                    regular_result = super().__call__(measure_input, tmp_dir, **kwargs)
+
+                    # Check instrument has been run, meaning context was inherited by builder
+                    assert instrument.has_been_run
+
+                    # But also check the gpu verification pass has been run
+                    # (which was not in the inherited ctx)
+                    assert mocked_pass_checker.has_been_run
+
+                    return regular_result
+
+    class MockedLocalBuilder(measure_methods.LocalBuilder):
+        """As measure_methods.LocalBuilder but overwrites the PassContext for testing."""
+
+        def __init__(
+            self,
+            timeout=10,
+            n_parallel=None,
+            build_kwargs=None,
+            build_func="default",
+            do_fork=False,
+            runtime=None,
+        ):
+            super().__init__(timeout, n_parallel, build_kwargs, build_func, do_fork, runtime)
+            self.build_func = OverwrittenBuildFunc(tar.tar, runtime)
+
+    def runner(target, dev):
+        task, target = get_sample_task(target, None)
+        logging.info("task config space: %s", task.config_space)
+
+        # Note: we use the MockedLocalBuilder here instead of autotvm.LocalBuilder()
+        measure_option = autotvm.measure_option(MockedLocalBuilder(), autotvm.LocalRunner())
+
+        results = []
+
+        tuner = RandomTuner(task)
+        tuner.tune(
+            n_trial=1,
+            measure_option=measure_option,
+            callbacks=(lambda _tuner, _inputs, rs: results.extend(rs),),
+        )
+
+        assert len(results) == 1

Review Comment:
   Do you want to check that the pass also succeeded? I think if one of those asserts fails, we just get a measure error here.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org