Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/11/21 23:10:49 UTC

[GitHub] [tvm] zxybazh opened a new pull request, #13459: [MetaSchedule] Enhance Database Validation Script

zxybazh opened a new pull request, #13459:
URL: https://github.com/apache/tvm/pull/13459

   Following up on #12948, this PR further enhances the validation script by:
   - [x] Reusing the same baseline TIR results for speedup
   - [x] Reusing the RPCRunner for benchmarking and validation at the same time
   - [x] Consolidating the output interface and printing the maximum difference on a wrong answer
   - [x] Supporting top-k validation (superfast for Relay compilation validation; see the sketch below)
   - [x] Skipping failed tuning records
   - [ ] Supporting tensorized records (to be tested)
   - [ ] Allowing batch build
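
   For illustration, here is a condensed sketch of the top-k flow (see the diff in the review thread below), reusing this PR's own helpers from `validate_database.py`. Deduplicating via `mod.script()` keys is a simplification of the script's `OriginalModule` set, so treat this as a sketch rather than the exact implementation:

   ```python
   # Hedged sketch: deduplicate workloads, then validate only the top-k records
   # per workload, reusing a baseline result computed once per unique module.
   # check_and_run, to_numpy, local_build_and_run, get_runtime_device,
   # is_failed_record, _apply_trace, and ARGS are defined in validate_database.py.
   from tvm import meta_schedule as ms

   def validate_top_k(work_dir: str, top_k: int) -> None:
       database = ms.database.create(work_dir=work_dir)
       records = database.get_all_tuning_records()
       # deduplicate by the original (unscheduled) IRModule
       unique_mods = {r.workload.mod.script(): r.workload.mod for r in records}
       for mod in unique_mods.values():
           inputs = to_numpy(check_and_run(ARGS.input_generator_func, mod))
           baseline_res, _ = local_build_and_run(
               mod,
               target=ARGS.baseline_target,
               inputs=inputs,
               device=get_runtime_device(ARGS.baseline_target),
           )
           for record in database.get_top_k(
               workload=database.commit_workload(mod), top_k=top_k
           ):
               if is_failed_record(record):  # run_secs == [1e9] marks a failure
                   continue
               scheduled_mod = _apply_trace(mod, record.trace)
               # build scheduled_mod, run it via the shared RPCRunner, and
               # compare its outputs against baseline_res
   ```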


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] zxybazh commented on a diff in pull request #13459: [MetaSchedule] Enhance Database Validation Script

Posted by GitBox <gi...@apache.org>.
zxybazh commented on code in PR #13459:
URL: https://github.com/apache/tvm/pull/13459#discussion_r1028650742


##########
python/tvm/meta_schedule/testing/validate_database.py:
##########
@@ -124,158 +252,495 @@ def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]:
     return inputs
 
 
-@register_func("tvm.meta_schedule.testing.default_check_metric")
-def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bool:
-    assert len(a) == len(b), "Different number of outputs from two modules"
-    for i, _ in enumerate(a):
-        if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3):
-            return False
-    return True
+def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
+    """Convert a list of TVM NDArray to a list of numpy array
 
+    Parameters
+    ----------
+    a : List[tvm.nd.NDArray]
+        The list of TVM NDArrays to be converted
 
-def validate_correctness(
-    original_mod: IRModule,  # compiled for "baseline_target"
-    scheduled_mod: IRModule,  # compiled for "target"
-    *,
-    baseline_target: Target,
-    target: Target,
-    dev_type: str,
-    rpc_config: ms.runner.RPCConfig,
-    f_input_generator: Union[
-        str, Callable[[IRModule], List[tvm.nd.NDArray]]
-    ] = default_input_generator,
-    f_check_metric: Union[
-        str, Callable[[tvm.nd.NDArray, tvm.nd.NDArray], bool]
-    ] = default_check_metric,
-) -> bool:
-    """Function to validate the correctness of a scheduled module.
+    Returns
+    -------
+    b : List[np.ndarray]
+        The list of numpy arrays
+    """
+    assert a is not None, "Empty result cannot be converted to numpy"
+    return [x.numpy() for x in a]
+
+
+def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
+    """Convert a list of numpy array to a list of TVM NDArray
 
     Parameters
     ----------
-    original_mod : IRModule
-        The original module to be compiled.
-    scheduled_mod : IRModule
-        The scheduled module to be compiled.
-    baseline_target : Target
-        The baseline target to compile the original module.
-    target : Target
-        The target to compile the scheduled module.
-    dev_type : str
-        The device type to run the module via rpc.
-    rpc_config : RPCConfig
-        The RPCConfig to run the scheduled module.
-    f_input_generator : Union[str, Callable]
-        The function to generate the input data.
-    f_check_metric : Union[str, Callable]
-        The function to check the metric.
+    a : List[np.ndarray]
+        The list of numpy arrays to be converted.
 
     Returns
     -------
-    result : bool
-        The result of the validation.
+    b : List[tvm.nd.NDArray]
+        The list of TVM NDArrays.
     """
+    assert a is not None, "Empty result cannot be converted to TVM NDArray"
+    return [tvm.nd.array(x) for x in a]
 
-    def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
-        """Convert a list of TVM NDArray to a list of numpy array"""
-        assert a is not None, "Empty result cannot be converted to numpy"
-        return [x.numpy() for x in a]
-
-    def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
-        """Convert a list of numpy array to a list of TVM NDArray"""
-        assert a is not None, "Empty result cannot be converted to TVM NDArray"
-        return [tvm.nd.array(x) for x in a]
-
-    def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray:
-        """Build and run the module on the target device."""
-        rt_mod = tvm.build(mod, target=target)
-        return run_module_via_rpc(
-            rpc_config=rpc_config,
-            lib=rt_mod,
-            dev_type=dev_type,
-            args={i: v for i, v in enumerate(inputs)},  # pylint: disable=unnecessary-comprehension
-            continuation=create_calculator(backend="tir"),
-            backend="tir",
-        )
 
-    # fetch functions & prepare inputs
-    if isinstance(f_input_generator, str):
-        f_input_generator = get_global_func(f_input_generator)
-    if isinstance(f_check_metric, str):
-        f_check_metric = get_global_func(f_check_metric)
-    inputs = to_numpy(f_input_generator(original_mod))  # type: ignore
-    # build & run original result
-    original_res = to_numpy(build_and_run(original_mod, target=baseline_target, dev_type="cpu"))
-    scheduled_res = to_numpy(build_and_run(scheduled_mod, target=target, dev_type=dev_type))
-    # check metric
-    if f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)):  # type: ignore
-        return True
-    else:
-        print(
-            ("\n\n").join(
+def is_failed_record(record: ms.database.TuningRecord) -> bool:
+    """Check if a tuning record is failed.
+
+    Parameters
+    ----------
+    record : TuningRecord
+        The tuning record to check.
+
+    Returns
+    -------
+    is_failed : bool
+    """
+    return len(record.run_secs) == 1 and record.run_secs[0] == 1e9
+
+
+def print_with_counter_func(counter: int, total: int) -> Callable:
+    """Print with counter
+
+    Parameters
+    ----------
+    counter : int
+        The current progress count.
+    total : int
+        The estimated total number of records to validate.
+
+    Returns
+    -------
+    print_result : Callable
+        The print result function.
+    """
+
+    def print_result(
+        result: str,
+        *,
+        original_mod: IRModule = None,
+        scheduled_mod: IRModule = None,
+        inputs: List[np.ndarray] = None,
+        original_res: List[np.ndarray] = None,
+        scheduled_res: List[np.ndarray] = None,
+        original_run_secs: List[float] = None,
+        scheduled_run_secs: List[float] = None,
+        exception: Exception = None,
+        trace: str = None,
+    ) -> None:
+        """Print the validation result."""
+        status = f"Progress {counter: 6d} / {total: 6d} (estimated) checked, result: {result:>10}, "
+
+        if result in ["pass", "wrong answer"]:
+            status += (
+                f"original: {mean(original_run_secs) * 1e3: 10.3f} ms, "
+                f"scheduled: {mean(scheduled_run_secs) * 1e3: 10.3f} ms"
+            )
+
+        output = [status]
+        if result not in ["pass", "skip"]:
+            output.extend(
                 [
-                    "Validation failed!",
-                    "Original Result:" + DELIMITOR + str(original_res),
-                    "Scheduled Result:" + DELIMITOR + str(scheduled_res),
-                    "Input:" + DELIMITOR + str(inputs),
                     "Original IRModule:" + DELIMITOR + original_mod.script(),
                     "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(),
+                    "Trace" + DELIMITOR + str(trace),
                 ]
             )
+            if result == "wrong answer":
+                output.extend(
+                    [
+                        "Input:" + DELIMITOR + str(inputs),
+                        "Original Result:" + DELIMITOR + str(original_res),
+                        "Scheduled Result:" + DELIMITOR + str(scheduled_res),
+                        "Max Diff:"
+                        + DELIMITOR
+                        + str(
+                            [
+                                np.max(np.abs(original_res[i] - scheduled_res[i]))
+                                for i in range(len(original_res))
+                            ]
+                        )
+                        + "\n",
+                    ]
+                )
+            elif result == "exception":
+                output.extend(["Exception:" + DELIMITOR + str(exception) + "\n"])
+            else:
+                raise ValueError(f"Unknown result: {result}")
+        print("\n\n".join(output))
+
+    return print_result
+
+
+def make_alloc_arg_and_check(
+    inputs: List[np.ndarray],
+    original_mod: IRModule,
+    scheduled_mod: IRModule,
+    trace: str,
+    original_res: List[np.ndarray],
+    original_run_secs: List[float],
+    print_result: Callable,
+) -> Tuple[Callable, Callable]:
+    """Make alloc_arg and check functions for the given inputs and collect results.
+
+    Parameters
+    ----------
+    inputs : List[np.ndarray]
+        The inputs to the two modules.
+    original_mod : IRModule
+        The original IRModule.
+    scheduled_mod : IRModule
+        The scheduled IRModule.
+    trace : str
+        The trace of the scheduled IRModule.
+    original_res : List[np.ndarray]
+        The original results.
+    original_run_secs : List[float]
+        The original run times.
+    print_result : Callable
+        The print result function.
+
+    Returns
+    -------
+    f_with_args_alloc_argument : Callable
+        The function to allocate arguments.
+
+    f_with_args_run_evaluator : Callable
+        The function to run the evaluator.
+    """
+
+    def f_with_args_alloc_argument(
+        # pylint: disable=unused-argument
+        session: tvm.rpc.RPCSession,
+        device: tvm.runtime.Device,
+        args_info: ms.runner.rpc_runner.T_ARG_INFO_JSON_OBJ_LIST,
+        alloc_repeat: int,
+        # pylint: enable=unused-argument
+    ) -> List[ms.runner.rpc_runner.T_ARGUMENT_LIST]:
+        """Allocate arguments using the given inputs.
+
+        Parameters
+        ----------
+        session : RPCSession
+            The RPC session.
+        device : Device
+            The device.
+        args_info : T_ARG_INFO_JSON_OBJ_LIST
+            The argument information.
+        alloc_repeat : int
+            The number of times to repeat the allocation.
+
+        Returns
+        -------
+        args_list : List[T_ARGUMENT_LIST]
+            The list of argument lists.
+        """
+        return [[tvm.nd.array(arg, device=device) for arg in inputs] for _ in range(alloc_repeat)]
+
+    def f_with_args_run_evaluator(
+        session: tvm.rpc.RPCSession,  # pylint: disable=unused-argument
+        rt_mod: tvm.runtime.Module,
+        device: tvm.runtime.Device,
+        evaluator_config: ms.runner.EvaluatorConfig,
+        repeated_args: List[ms.runner.rpc_runner.T_ARGUMENT_LIST],
+    ) -> List[float]:
+        """With args function to run the evaluator
+
+        Parameters
+        ----------
+        session : tvm.rpc.RPCSession
+            The RPC session
+        rt_mod: Module
+            The runtime module
+        device: Device
+            The device to run the evaluator
+        evaluator_config: EvaluatorConfig
+            The evaluator config
+        repeated_args: List[T_ARGUMENT_LIST]
+            The repeated arguments
+
+        Returns
+        -------
+        costs: List[float]
+            The evaluator results
+        """
+        evaluator = rt_mod.time_evaluator(
+            func_name=rt_mod.entry_name,
+            dev=device,
+            number=evaluator_config.number,
+            repeat=evaluator_config.repeat,
+            min_repeat_ms=evaluator_config.min_repeat_ms,
+            f_preproc="cache_flush_cpu_non_first_arg"
+            if evaluator_config.enable_cpu_cache_flush
+            else "",
+        )
+
+        repeated_costs: List[List[float]] = []
+        for args in repeated_args:
+            device.sync()
+            profile_result = evaluator(*args)
+            repeated_costs.append(profile_result.results)
+        costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
+
+        assert len(repeated_args) == 1, "Only support one set of arguments"
+        scheduled_res = [arg.numpy() for arg in repeated_args[0]]  # type: ignore
+        # fetch and run the comparison function
+        passed = check_and_run(
+            ARGS.check_metric_func,
+            to_tvm_ndarray(original_res),
+            to_tvm_ndarray(scheduled_res),
+        )
+
+        print_result(
+            result="pass" if passed else "wrong answer",
+            original_mod=original_mod,
+            scheduled_mod=scheduled_mod,
+            trace=trace,
+            inputs=inputs,
+            original_res=original_res,
+            scheduled_res=scheduled_res,
+            original_run_secs=original_run_secs,
+            scheduled_run_secs=costs,
         )
-        return False
+
+        return costs
+
+    return f_with_args_alloc_argument, f_with_args_run_evaluator
+
+
+def local_build_and_run(
+    mod: IRModule,
+    target: Target,
+    device: tvm.runtime.Device,
+    inputs: List[np.ndarray],
+) -> Tuple[List[np.ndarray], List[float]]:
+    """Build and run the module locally.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The module to build and run
+    target: Target
+        The target to build the module
+    device: Device
+        The device to run the module
+    inputs: List[np.ndarray]
+        The inputs to run the module
+
+    Returns
+    -------
+    res: List[np.ndarray]
+        The results of running the module
+    run_secs: List[float]
+        The running time of running the module
+    """
+    # potential memory leak https://github.com/apache/tvm/issues/11096
+    lib = tvm.build(mod, target=target)
+    tvm_inputs = [tvm.nd.array(inp, device=device) for inp in inputs]
+    device.sync()
+    func = lib.time_evaluator(lib.entry_name, dev=device, number=ARGS.number, repeat=ARGS.repeat)
+    benchmark_res = func(*tvm_inputs)
+    device.sync()
+    return [arg.numpy() for arg in tvm_inputs], list(benchmark_res.results)
+
+
+def _check_builder_result(builder_result: ms.builder.BuilderResult) -> None:
+    """Check if the builder result is defined.
+
+    Parameters
+    ----------
+    builder_result: BuilderResult
+        The builder result
+    """
+    assert builder_result.error_msg is None, "Builder failed: " + str(
+        builder_result.error_msg if builder_result.error_msg else "Empty error message"
+    )
+
+
+def _apply_trace(mod: IRModule, trace: Trace) -> IRModule:
+    """Apply the trace to the module.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The module to apply the trace to
+    trace: Trace
+        The trace to apply
+
+    Returns
+    -------
+    mod: IRModule
+        The module with the trace applied
+    """
+    sch = Schedule(mod)
+    trace.apply_to_schedule(sch, remove_postproc=False)
+    return sch.mod
+
+
+def _build_all_mods(
+    mods: List[IRModule], builder: ms.builder.Builder, target: Target
+) -> List[ms.builder.BuilderResult]:
+    """Build all the modules.
+
+    Parameters
+    ----------
+    mods: List[IRModule]
+        The modules to build
+    builder: Builder
+        The builder to build the modules
+    target: Target
+        The target to build the modules
+
+    Returns
+    -------
+    builder_results: List[BuilderResult]
+        The builder results
+    """
+    builder_results = builder.build([ms.builder.BuilderInput(mod, target) for mod in mods])
+    assert len(builder_results) == len(
+        mods
+    ), f"Unexpected number of build results, expected {len(mods)} got {len(builder_results)}"
+    return builder_results
+
+
+def _run_single_mod(
+    builder_result: ms.builder.BuilderResult,
+    runner: ms.runner.Runner,
+    dev_type: str,
+) -> None:
+    """Run a single module.
+
+    Parameters
+    ----------
+    builder_result: BuilderResult
+        The builder result
+    runner: Runner
+        The runner to run the module
+    dev_type: str
+        The device type
+    """
+    runner_futures = runner.run(
+        # args_info is not used in this case, so we can pass an empty list
+        [ms.runner.RunnerInput(builder_result.artifact_path, device_type=dev_type, args_info=[])]
+    )
+    assert (
+        len(runner_futures) == 1
+    ), f"Unexpected number of runner futures, expected 1 got {len(runner_futures)}"
+    (runner_future,) = runner_futures  # pylint: disable=unbalanced-tuple-unpacking
+    runner_res = runner_future.result()
+    assert runner_res.error_msg is None, "Runner failed: " + (
+        runner_res.error_msg if runner_res.error_msg else "Empty error message"
+    )
 
 
 def main():
     """Main function"""
     describe()
-    database = ms.database.create(work_dir=ARGS.work_dir)
-    target = ARGS.target
-    if target.kind.name == "llvm":
-        dev_type = "cpu"
-    elif target.kind.name == "cuda":
-        dev_type = "cuda"
-    else:
-        raise RuntimeError(f"Unsupported target kind: {target.kind.name}")
-    records = database.get_all_tuning_records()
     with ms.Profiler() as profiler:
-        for i, record in enumerate(records):
-            scope_name = f"validate #{i}"
-            with profiler.timeit(scope_name):
-                original_mod = record.workload.mod
-                sch = Schedule(original_mod)
-                record.trace.apply_to_schedule(sch=sch, remove_postproc=False)
-                scheduled_mod = sch.mod
-                is_success = False
+        # initialize
+        target = ARGS.target
+        dev_type = get_device_type(target)
+        builder = ms.builder.LocalBuilder()
+        database = ms.database.create(work_dir=ARGS.work_dir)
+
+        # collect records
+        with profiler.timeit("collect records"):
+            records = database.get_all_tuning_records()
+        total = len(records)
+        print(
+            f"Total {total} records to be validated. "
+            f"Collected in {float(profiler.get()['collect records']): 3.3f} sec."
+        )
+
+        # collect unique original TIR
+        with profiler.timeit("deduplicate records"):
+            workloads = set()
+            for record in records:
+                workloads.add(OriginalModule(record.workload.mod))
+        print(
+            f"Total {len(workloads)} unique original TIR to validate. "
+            f"Deduplicated in {float(profiler.get()['deduplicate records']): 3.3f} sec."
+        )
+        if ARGS.top_k < 10**9:
+            print(f"Top {ARGS.top_k} records for each original TIR will be validated.")
+            total = len(workloads) * ARGS.top_k
+        print()
+
+        # validate correctness
+        counter = 0
+        for item in workloads:
+            original_mod = item.mod
+            records = database.get_top_k(
+                workload=database.commit_workload(original_mod), top_k=ARGS.top_k
+            )
+            if len(records) < ARGS.top_k:
+                total -= ARGS.top_k - len(records)
+            inputs = to_numpy(check_and_run(ARGS.input_generator_func, original_mod))
+            original_res, original_run_secs = local_build_and_run(
+                original_mod,
+                target=ARGS.baseline_target,
+                inputs=inputs,
+                device=get_runtime_device(ARGS.baseline_target),
+            )
+            scheduled_mods = [_apply_trace(original_mod, record.trace) for record in records]
+            builder_results = _build_all_mods(scheduled_mods, builder, target)  # type: ignore
+            for i, record in enumerate(records):
+                counter += 1
+                print_result = print_with_counter_func(counter=counter, total=total)
+                if is_failed_record(record):
+                    # skip failed records, marked by run_secs == [1e9];
+                    # these records only serve as negative samples for the cost model
+                    print_result(result="skip")
+                    continue
                 try:
-                    is_success = validate_correctness(
-                        original_mod=original_mod,
-                        scheduled_mod=scheduled_mod,
-                        target=target,
-                        baseline_target=ARGS.baseline_target,
-                        dev_type=dev_type,
+                    # prepare scheduled module
+                    scheduled_mod = scheduled_mods[i]
+                    # check build result
+                    builder_result = builder_results[i]
+                    _check_builder_result(builder_result)
+                    # fetch functions
+                    (
+                        f_with_args_alloc_argument,
+                        f_with_args_run_evaluator,
+                    ) = make_alloc_arg_and_check(
+                        inputs,
+                        original_mod,
+                        scheduled_mod,
+                        str(record.trace),
+                        original_res=original_res,
+                        original_run_secs=original_run_secs,
+                        print_result=print_result,
+                    )
+                    # create rpc runner
+                    runner = ms.runner.RPCRunner(

Review Comment:
   I think that's a reasonable ask; what do you think would be a good user interface?
   Say, `--rpc-host="local"`?





[GitHub] [tvm] zxybazh commented on a diff in pull request #13459: [MetaSchedule] Enhance Database Validation Script

Posted by GitBox <gi...@apache.org>.
zxybazh commented on code in PR #13459:
URL: https://github.com/apache/tvm/pull/13459#discussion_r1028656409


##########
python/tvm/meta_schedule/testing/validate_database.py:
##########
@@ -124,158 +252,495 @@ def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]:
+                    runner = ms.runner.RPCRunner(

Review Comment:
   I'll throw a warning though, to make sure the user is aware of that.
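
   Continuing the `--rpc-host` sketch from the previous comment, such a warning might look like the following (wording assumed, not the merged code):

   ```python
   # Hedged sketch: warn when falling back to the local runner so the user
   # knows measurements are not going through the RPC tracker.
   import logging

   logger = logging.getLogger(__name__)

   if args.rpc_host == "local":  # `args` from the argparse sketch above
       logger.warning(
           "--rpc-host is set to 'local'; running on the local machine "
           "instead of an RPC tracker."
       )
   ```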





[GitHub] [tvm] vinx13 merged pull request #13459: [MetaSchedule] Enhance Database Validation Script

Posted by GitBox <gi...@apache.org>.
vinx13 merged PR #13459:
URL: https://github.com/apache/tvm/pull/13459




[GitHub] [tvm] yelite commented on a diff in pull request #13459: [MetaSchedule] Enhance Database Validation Script

Posted by GitBox <gi...@apache.org>.
yelite commented on code in PR #13459:
URL: https://github.com/apache/tvm/pull/13459#discussion_r1028648334


##########
python/tvm/meta_schedule/testing/validate_database.py:
##########
@@ -124,158 +252,495 @@ def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]:
     return inputs
 
 
-@register_func("tvm.meta_schedule.testing.default_check_metric")
-def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bool:
-    assert len(a) == len(b), "Different number of outputs from two modules"
-    for i, _ in enumerate(a):
-        if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3):
-            return False
-    return True
+def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
+    """Convert a list of TVM NDArray to a list of numpy array
 
+    Parameters
+    ----------
+    a : List[tvm.nd.NDArray]
+        The list of TVM NDArray to be converted
 
-def validate_correctness(
-    original_mod: IRModule,  # compiled for "baseline_target"
-    scheduled_mod: IRModule,  # compiled for "target"
-    *,
-    baseline_target: Target,
-    target: Target,
-    dev_type: str,
-    rpc_config: ms.runner.RPCConfig,
-    f_input_generator: Union[
-        str, Callable[[IRModule], List[tvm.nd.NDArray]]
-    ] = default_input_generator,
-    f_check_metric: Union[
-        str, Callable[[tvm.nd.NDArray, tvm.nd.NDArray], bool]
-    ] = default_check_metric,
-) -> bool:
-    """Function to validate the correctness of a scheduled module.
+    Returns
+    -------
+    b : List[np.ndarray]
+        The list of numpy array
+    """
+    assert a is not None, "Empty result cannot be converted to numpy"
+    return [x.numpy() for x in a]
+
+
+def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
+    """Convert a list of numpy array to a list of TVM NDArray
 
     Parameters
     ----------
-    original_mod : IRModule
-        The original module to be compiled.
-    scheduled_mod : IRModule
-        The scheduled module to be compiled.
-    baseline_target : Target
-        The baseline target to compile the original module.
-    target : Target
-        The target to compile the scheduled module.
-    dev_type : str
-        The device type to run the module via rpc.
-    rpc_config : RPCConfig
-        The RPCConfig to run the scheduled module.
-    f_input_generator : Union[str, Callable]
-        The function to generate the input data.
-    f_check_metric : Union[str, Callable]
-        The function to check the metric.
+    a : List[np.ndarray]
+        The list of numpy array to be converted.
 
     Returns
     -------
-    result : bool
-        The result of the validation.
+    b : List[tvm.nd.NDArray]
+        The list of TVM NDArray.
     """
+    assert a is not None, "Empty result cannot be converted to TVM NDArray"
+    return [tvm.nd.array(x) for x in a]
 
-    def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
-        """Convert a list of TVM NDArray to a list of numpy array"""
-        assert a is not None, "Empty result cannot be converted to numpy"
-        return [x.numpy() for x in a]
-
-    def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
-        """Convert a list of numpy array to a list of TVM NDArray"""
-        assert a is not None, "Empty result cannot be converted to TVM NDArray"
-        return [tvm.nd.array(x) for x in a]
-
-    def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray:
-        """Build and run the module on the target device."""
-        rt_mod = tvm.build(mod, target=target)
-        return run_module_via_rpc(
-            rpc_config=rpc_config,
-            lib=rt_mod,
-            dev_type=dev_type,
-            args={i: v for i, v in enumerate(inputs)},  # pylint: disable=unnecessary-comprehension
-            continuation=create_calculator(backend="tir"),
-            backend="tir",
-        )
 
-    # fetch functions & prepare inputs
-    if isinstance(f_input_generator, str):
-        f_input_generator = get_global_func(f_input_generator)
-    if isinstance(f_check_metric, str):
-        f_check_metric = get_global_func(f_check_metric)
-    inputs = to_numpy(f_input_generator(original_mod))  # type: ignore
-    # build & run original result
-    original_res = to_numpy(build_and_run(original_mod, target=baseline_target, dev_type="cpu"))
-    scheduled_res = to_numpy(build_and_run(scheduled_mod, target=target, dev_type=dev_type))
-    # check metric
-    if f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)):  # type: ignore
-        return True
-    else:
-        print(
-            ("\n\n").join(
+def is_failed_record(record: ms.database.TuningRecord) -> bool:
+    """Check if a tuning record is failed.
+
+    Parameters
+    ----------
+    record : TuningRecord
+        The tuning record to check.
+
+    Returns
+    -------
+    is_failed : bool
+    """
+    return len(record.run_secs) == 1 and record.run_secs[0] == 1e9
+
+
+def print_with_counter_func(counter: int, total: int) -> Callable:
+    """Print with counter
+
+    Parameters
+    ----------
+    counter : int
+        The counter to print with.
+    total : int
+        The total number of items to print with.
+
+    Returns
+    -------
+    print_result : Callable
+        The print result function.
+    """
+
+    def print_result(
+        result: str,
+        *,
+        original_mod: IRModule = None,
+        scheduled_mod: IRModule = None,
+        inputs: List[np.ndarray] = None,
+        original_res: List[np.ndarray] = None,
+        scheduled_res: List[np.ndarray] = None,
+        original_run_secs: List[float] = None,
+        scheduled_run_secs: List[float] = None,
+        exception: Exception = None,
+        trace: str = None,
+    ) -> None:
+        """Print the validation result."""
+        status = f"Progress {counter: 6d} / {total: 6d} (estimated) checked, result: {result:>10}, "
+
+        if result in ["pass", "wrong answer"]:
+            status += (
+                f"original: {mean(original_run_secs) * 1e3: 10.3f} ms, "
+                f"scheduled: {mean(scheduled_run_secs) * 1e3: 10.3f} ms"
+            )
+
+        output = [status]
+        if result not in ["pass", "skip"]:
+            output.extend(
                 [
-                    "Validation failed!",
-                    "Original Result:" + DELIMITOR + str(original_res),
-                    "Scheduled Result:" + DELIMITOR + str(scheduled_res),
-                    "Input:" + DELIMITOR + str(inputs),
                     "Original IRModule:" + DELIMITOR + original_mod.script(),
                     "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(),
+                    "Trace" + DELIMITOR + str(trace),
                 ]
             )
+            if result == "wrong answer":
+                output.extend(
+                    [
+                        "Input:" + DELIMITOR + str(inputs),
+                        "Original Result:" + DELIMITOR + str(original_res),
+                        "Scheduled Result:" + DELIMITOR + str(scheduled_res),
+                        "Max Diff:"
+                        + DELIMITOR
+                        + str(
+                            [
+                                np.max(np.abs(original_res[i] - scheduled_res[i]))
+                                for i in range(len(original_res))
+                            ]
+                        )
+                        + "\n",
+                    ]
+                )
+            elif result == "exception":
+                output.extend(["Exception:" + DELIMITOR + str(exception) + "\n"])
+            else:
+                raise ValueError(f"Unknown result: {result}")
+        print("\n\n".join(output))
+
+    return print_result
+
+
+def make_alloc_arg_and_check(
+    inputs: List[np.ndarray],
+    original_mod: IRModule,
+    scheduled_mod: IRModule,
+    trace: str,
+    original_res: List[np.ndarray],
+    original_run_secs: List[float],
+    print_result: Callable,
+) -> Tuple[Callable, Callable]:
+    """Make alloc_arg and check functions for the given inputs and collect results.
+
+    Parameters
+    ----------
+    inputs : List[np.ndarray]
+        The inputs to the two modules.
+    original_mod : IRModule
+        The original IRModule.
+    scheduled_mod : IRModule
+        The scheduled IRModule.
+    trace : str
+        The trace of the scheduled IRModule.
+    original_res : List[np.ndarray]
+        The original results.
+    original_run_secs : List[float]
+        The original run times.
+    print_result : Callable
+        The print result function.
+
+    Returns
+    -------
+    f_with_args_alloc_argument : Callable
+        The function to allocate arguments.
+
+    f_with_args_run_evaluator : Callable
+        The function to run evaluator.
+    """
+
+    def f_with_args_alloc_argument(
+        # pylint: disable=unused-argument
+        session: tvm.rpc.RPCSession,
+        device: tvm.runtime.Device,
+        args_info: ms.runner.rpc_runner.T_ARG_INFO_JSON_OBJ_LIST,
+        alloc_repeat: int,
+        # pylint: enable=unused-argument
+    ) -> List[ms.runner.rpc_runner.T_ARGUMENT_LIST]:
+        """Allocate arguments using the given inputs.
+
+        Parameters
+        ----------
+        session : RPCSession
+            The RPC session.
+        device : Device
+            The device.
+        args_info : T_ARG_INFO_JSON_OBJ_LIST
+            argument information.
+        alloc_repeat : int
+            The number of times to repeat the allocation.
+
+        Returns
+        -------
+        args_list : List[T_ARGUMENT_LIST]
+            The list of argument lists.
+        """
+        return [[tvm.nd.array(arg, device=device) for arg in inputs] for _ in range(alloc_repeat)]
+
+    def f_with_args_run_evaluator(
+        session: tvm.rpc.RPCSession,  # pylint: disable=unused-argument
+        rt_mod: tvm.runtime.Module,
+        device: tvm.runtime.Device,
+        evaluator_config: ms.runner.EvaluatorConfig,
+        repeated_args: List[ms.runner.rpc_runner.T_ARGUMENT_LIST],
+    ) -> List[float]:
+        """With args function to run the evaluator
+
+        Parameters
+        ----------
+        session : tvm.rpc.RPCSession
+            The RPC session
+        rt_mod: Module
+            The runtime module
+        device: Device
+            The device to run the evaluator
+        evaluator_config: EvaluatorConfig
+            The evaluator config
+        repeated_args: List[T_ARGUMENT_LIST]
+            The repeated arguments
+
+        Returns
+        -------
+        costs: List[float]
+            The evaluator results
+        """
+        evaluator = rt_mod.time_evaluator(
+            func_name=rt_mod.entry_name,
+            dev=device,
+            number=evaluator_config.number,
+            repeat=evaluator_config.repeat,
+            min_repeat_ms=evaluator_config.min_repeat_ms,
+            f_preproc="cache_flush_cpu_non_first_arg"
+            if evaluator_config.enable_cpu_cache_flush
+            else "",
+        )
+
+        repeated_costs: List[List[float]] = []
+        for args in repeated_args:
+            device.sync()
+            profile_result = evaluator(*args)
+            repeated_costs.append(profile_result.results)
+        costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
+
+        assert len(repeated_args) == 1, "Only support one set of arguments"
+        scheduled_res = [arg.numpy() for arg in repeated_args[0]]  # type: ignore
+        # fetch comparison function
+        passed = check_and_run(
+            ARGS.check_metric_func,
+            to_tvm_ndarray(original_res),
+            to_tvm_ndarray(scheduled_res),
+        )
+
+        print_result(
+            result="pass" if passed else "wrong answer",
+            original_mod=original_mod,
+            scheduled_mod=scheduled_mod,
+            trace=trace,
+            inputs=inputs,
+            original_res=original_res,
+            scheduled_res=scheduled_res,
+            original_run_secs=original_run_secs,
+            scheduled_run_secs=costs,
         )
-        return False
+
+        return costs
+
+    return f_with_args_alloc_argument, f_with_args_run_evaluator
+
+
+def local_build_and_run(
+    mod: IRModule,
+    target: Target,
+    device: tvm.runtime.Device,
+    inputs: List[np.ndarray],
+) -> Tuple[List[np.ndarray], List[float]]:
+    """Build and run the module locally.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The module to build and run
+    target: Target
+        The target to build the module
+    device: Device
+        The device to run the module
+    inputs: List[np.ndarray]
+        The inputs to run the module
+
+    Returns
+    -------
+    res: List[np.ndarray]
+        The results of running the module
+    run_secs: List[float]
+        The running time of running the module
+    """
+    # potential memory leak https://github.com/apache/tvm/issues/11096
+    lib = tvm.build(mod, target=target)
+    tvm_inputs = [tvm.nd.array(inp, device=device) for inp in inputs]
+    device.sync()
+    func = lib.time_evaluator(lib.entry_name, dev=device, number=ARGS.number, repeat=ARGS.repeat)
+    benchmark_res = func(*tvm_inputs)
+    device.sync()
+    return [arg.numpy() for arg in tvm_inputs], list(benchmark_res.results)
+
+
+def _check_builder_result(builder_result: ms.builder.BuilderResult) -> None:
+    """Check if the builder result is defined.
+
+    Parameters
+    ----------
+    builder_result: BuilderResult
+        The builder result
+    """
+    assert builder_result.error_msg is None, "Builder failed: " + str(
+        builder_result.error_msg if builder_result.error_msg else "Empty error message"
+    )
+
+
+def _apply_trace(mod: IRModule, trace: Trace) -> IRModule:
+    """Apply the trace to the module.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The module to apply the trace to
+    trace: Trace
+        The trace to apply
+
+    Returns
+    -------
+    mod: IRModule
+        The module with the trace applied
+    """
+    sch = Schedule(mod)
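+    # keep postprocessing steps (remove_postproc=False) so the rebuilt module matches the tuned one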
+    trace.apply_to_schedule(sch, remove_postproc=False)
+    return sch.mod
+
+
+def _build_all_mods(
+    mods: List[IRModule], builder: ms.builder.Builder, target: Target
+) -> List[ms.builder.BuilderResult]:
+    """Build all the modules.
+
+    Parameters
+    ----------
+    mods: List[IRModule]
+        The modules to build
+    builder: Builder
+        The builder to build the modules
+    target: Target
+        The target to build the modules
+
+    Returns
+    -------
+    builder_results: List[BuilderResult]
+        The builder results
+    """
+    builder_results = builder.build([ms.builder.BuilderInput(mod, target) for mod in mods])
+    assert len(builder_results) == len(
+        mods
+    ), f"Unexpected number of build results, expected {len(mods)} got {len(builder_results)}"
+    return builder_results
+
+
+def _run_single_mod(
+    builder_result: ms.builder.BuilderResult,
+    runner: ms.runner.Runner,
+    dev_type: str,
+) -> None:
+    """Run a single module.
+
+    Parameters
+    ----------
+    builder_result: BuilderResult
+        The builder result
+    runner: Runner
+        The runner to run the module
+    dev_type: str
+        The device type
+    """
+    runner_futures = runner.run(
+        # args_info is not used in this case, so we can pass an empty list
+        [ms.runner.RunnerInput(builder_result.artifact_path, device_type=dev_type, args_info=[])]
+    )
+    assert (
+        len(runner_futures) == 1
+    ), f"Unexpected number of runner futures, expected 1 got {len(runner_futures)}"
+    (runner_future,) = runner_futures  # pylint: disable=unbalanced-tuple-unpacking
+    runner_res = runner_future.result()
+    assert runner_res.error_msg is None, "Runner failed: " + (
+        runner_res.error_msg if runner_res.error_msg else "Empty error message"
+    )
 
 
 def main():
     """Main function"""
     describe()
-    database = ms.database.create(work_dir=ARGS.work_dir)
-    target = ARGS.target
-    if target.kind.name == "llvm":
-        dev_type = "cpu"
-    elif target.kind.name == "cuda":
-        dev_type = "cuda"
-    else:
-        raise RuntimeError(f"Unsupported target kind: {target.kind.name}")
-    records = database.get_all_tuning_records()
     with ms.Profiler() as profiler:
-        for i, record in enumerate(records):
-            scope_name = f"validate #{i}"
-            with profiler.timeit(scope_name):
-                original_mod = record.workload.mod
-                sch = Schedule(original_mod)
-                record.trace.apply_to_schedule(sch=sch, remove_postproc=False)
-                scheduled_mod = sch.mod
-                is_success = False
+        # initialize
+        target = ARGS.target
+        dev_type = get_device_type(target)
+        builder = ms.builder.LocalBuilder()
+        database = ms.database.create(work_dir=ARGS.work_dir)
+
+        # collect records
+        with profiler.timeit("collect records"):
+            records = database.get_all_tuning_records()
+        total = len(records)
+        print(
+            f"Total {total} records to be validated. "
+            f"Collected in {float(profiler.get()['collect records']): 3.3f} sec."
+        )
+
+        # collect unique original TIR
+        with profiler.timeit("deduplicate records"):
+            workloads = set()
+            for record in records:
+                workloads.add(OriginalModule(record.workload.mod))
+        print(
+            f"Total {len(workloads)} unique original TIR to validate. "
+            f"Deduplicated in {float(profiler.get()['deduplicate records']): 3.3f} sec."
+        )
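+        # a top_k of 10**9 acts as a sentinel meaning "validate every record"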
+        if ARGS.top_k < 10**9:
+            print(f"Top {ARGS.top_k} records for each original TIR will be validated.")
+            total = len(workloads) * ARGS.top_k
+        print()
+
+        # validate correctness
+        counter = 0
+        for item in workloads:
+            original_mod = item.mod
+            records = database.get_top_k(
+                workload=database.commit_workload(original_mod), top_k=ARGS.top_k
+            )
+            if len(records) < ARGS.top_k:
+                total -= ARGS.top_k - len(records)
+            inputs = to_numpy(check_and_run(ARGS.input_generator_func, original_mod))
+            original_res, original_run_secs = local_build_and_run(
+                original_mod,
+                target=ARGS.baseline_target,
+                inputs=inputs,
+                device=get_runtime_device(ARGS.baseline_target),
+            )
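+            # apply each record's trace, then batch-build all scheduled variants of this workload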
+            scheduled_mods = [_apply_trace(original_mod, record.trace) for record in records]
+            builder_results = _build_all_mods(scheduled_mods, builder, target)  # type: ignore
+            for i, record in enumerate(records):
+                counter += 1
+                print_result = print_with_counter_func(counter=counter, total=total)
+                if is_failed_record(record):
+                    # skip failed records where run_secs is 1e9
+                    # these records are only negative samples for cost model
+                    print_result(result="skip")
+                    continue
                 try:
-                    is_success = validate_correctness(
-                        original_mod=original_mod,
-                        scheduled_mod=scheduled_mod,
-                        target=target,
-                        baseline_target=ARGS.baseline_target,
-                        dev_type=dev_type,
+                    # prepare scheduled module
+                    scheduled_mod = scheduled_mods[i]
+                    # check build result
+                    builder_result = builder_results[i]
+                    _check_builder_result(builder_result)
+                    # fetch functions
+                    (
+                        f_with_args_alloc_argument,
+                        f_with_args_run_evaluator,
+                    ) = make_alloc_arg_and_check(
+                        inputs,
+                        original_mod,
+                        scheduled_mod,
+                        str(record.trace),
+                        original_res=original_res,
+                        original_run_secs=original_run_secs,
+                        print_result=print_result,
+                    )
+                    # create rpc runner
+                    runner = ms.runner.RPCRunner(

Review Comment:
   Would it be useful to make RPC optional, i.e. fall back to a LocalRunner when no RPC-related args are provided?
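
   For illustration, a minimal sketch of that fallback (flag names like `ARGS.rpc_host` are assumptions, not the script's actual arguments; `LocalRunner`'s callbacks also take no RPC session parameter, so local variants of the two `f_with_args_*` helpers would be needed):

   ```python
   def create_runner(evaluator_config, f_alloc_argument, f_run_evaluator):
       # Sketch only: ARGS.rpc_host / ARGS.rpc_port / ARGS.rpc_key are assumed flags.
       if ARGS.rpc_host is None:
           # No RPC tracker configured: measure on the local machine.
           return ms.runner.LocalRunner(
               evaluator_config=evaluator_config,
               f_alloc_argument=f_alloc_argument,
               f_run_evaluator=f_run_evaluator,
           )
       return ms.runner.RPCRunner(
           rpc_config=ms.runner.RPCConfig(
               tracker_host=ARGS.rpc_host,
               tracker_port=ARGS.rpc_port,
               tracker_key=ARGS.rpc_key,
               session_timeout_sec=600,
           ),
           evaluator_config=evaluator_config,
           f_alloc_argument=f_alloc_argument,
           f_run_evaluator=f_run_evaluator,
       )
   ```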





[GitHub] [tvm] zxybazh commented on a diff in pull request #13459: [MetaSchedule] Enhance Database Validation Script

Posted by GitBox <gi...@apache.org>.
zxybazh commented on code in PR #13459:
URL: https://github.com/apache/tvm/pull/13459#discussion_r1028656239


##########
python/tvm/meta_schedule/testing/validate_database.py:
##########
+                    # create rpc runner
+                    runner = ms.runner.RPCRunner(

Review Comment:
   Sure, can do that.





[GitHub] [tvm] tvm-bot commented on pull request #13459: [MetaSchedule] Enhance Database Validation Script

Posted by GitBox <gi...@apache.org>.
tvm-bot commented on PR #13459:
URL: https://github.com/apache/tvm/pull/13459#issuecomment-1322779362

   
   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.
   
   
   
   <sub>Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)</sub>




[GitHub] [tvm] zxybazh commented on pull request #13459: [MetaSchedule] Enhance Database Validation Script

Posted by GitBox <gi...@apache.org>.
zxybazh commented on PR #13459:
URL: https://github.com/apache/tvm/pull/13459#issuecomment-1322832239

   Sample output looks like this (note that the "estimated" total shrinks, e.g. from 200 to 196 to 191, whenever a workload has fewer than top_k valid records):
   ```
   Python Environment
     TVM version    = 0.11.dev0
     Python version = 3.8.13 (default, Mar 28 2022, 11:38:47)  [GCC 7.5.0] (64 bit)
     os.uname()     = Linux 5.18.10-76051810-generic #202207071639~1657252310~21.10~7d5e891 SMP PREEMPT_DYNAMIC Fri J x86_64
   CMake Options:
     {
       "BACKTRACE_ON_SEGFAULT": "OFF",
       "BUILD_STATIC_RUNTIME": "OFF",
       "COMPILER_RT_PATH": "3rdparty/compiler-rt",
       "CUDA_VERSION": "NOT-FOUND",
       "DLPACK_PATH": "3rdparty/dlpack/include",
       "DMLC_PATH": "3rdparty/dmlc-core/include",
       "GIT_COMMIT_HASH": "490e0e3120f304a98607770502b5700ec6ab9d55",
       "GIT_COMMIT_TIME": "2022-11-18 10:02:55 -0800",
       "HIDE_PRIVATE_SYMBOLS": "ON",
       "INDEX_DEFAULT_I64": "ON",
       "INSTALL_DEV": "OFF",
       "LLVM_VERSION": "12.0.1",
       "PICOJSON_PATH": "3rdparty/picojson",
       "RANG_PATH": "3rdparty/rang/include",
       "ROCM_PATH": "/opt/rocm",
       "SUMMARIZE": "OFF",
       "TVM_CXX_COMPILER_PATH": "/usr/bin/c++",
       "USE_ALTERNATIVE_LINKER": "AUTO",
       "USE_AOT_EXECUTOR": "ON",
       "USE_ARM_COMPUTE_LIB": "OFF",
       "USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR": "OFF",
       "USE_BLAS": "none",
       "USE_BNNS": "OFF",
       "USE_BYODT_POSIT": "OFF",
       "USE_CCACHE": "AUTO",
       "USE_CLML": "OFF",
       "USE_CLML_GRAPH_EXECUTOR": "OFF",
       "USE_CMSISNN": "OFF",
       "USE_COREML": "OFF",
       "USE_CPP_RPC": "OFF",
       "USE_CUBLAS": "OFF",
       "USE_CUDA": "/usr/local/cuda-11.8/",
       "USE_CUDNN": "ON",
       "USE_CURAND": "ON",
       "USE_CUSTOM_LOGGING": "OFF",
       "USE_CUTLASS": "OFF",
       "USE_DNNL": "OFF",
       "USE_ETHOSN": "OFF",
       "USE_FALLBACK_STL_MAP": "OFF",
       "USE_GRAPH_EXECUTOR": "ON",
       "USE_GRAPH_EXECUTOR_CUDA_GRAPH": "OFF",
       "USE_GTEST": "AUTO",
       "USE_HEXAGON": "OFF",
       "USE_HEXAGON_EXTERNAL_LIBS": "OFF",
       "USE_HEXAGON_GTEST": "/path/to/hexagon/gtest",
       "USE_HEXAGON_RPC": "OFF",
       "USE_HEXAGON_SDK": "/path/to/sdk",
       "USE_IOS_RPC": "OFF",
       "USE_KHRONOS_SPIRV": "OFF",
       "USE_LIBBACKTRACE": "ON",
       "USE_LIBTORCH": "OFF",
       "USE_LLVM": "llvm-config-12 --link-static",
       "USE_METAL": "OFF",
       "USE_MICRO": "OFF",
       "USE_MICRO_STANDALONE_RUNTIME": "OFF",
       "USE_MIOPEN": "OFF",
       "USE_MKL": "OFF",
       "USE_MSVC_MT": "OFF",
       "USE_NNPACK": "OFF",
       "USE_OPENCL": "OFF",
       "USE_OPENCL_GTEST": "/path/to/opencl/gtest",
       "USE_OPENMP": "none",
       "USE_PAPI": "OFF",
       "USE_PROFILER": "ON",
       "USE_PT_TVMDSOOP": "OFF",
       "USE_RANDOM": "ON",
       "USE_RELAY_DEBUG": "OFF",
       "USE_ROCBLAS": "OFF",
       "USE_ROCM": "OFF",
       "USE_RPC": "ON",
       "USE_RTTI": "ON",
       "USE_RUST_EXT": "OFF",
       "USE_SORT": "ON",
       "USE_SPIRV_KHR_INTEGER_DOT_PRODUCT": "OFF",
       "USE_STACKVM_RUNTIME": "OFF",
       "USE_TARGET_ONNX": "OFF",
       "USE_TENSORFLOW_PATH": "none",
       "USE_TENSORRT_CODEGEN": "OFF",
       "USE_TENSORRT_RUNTIME": "OFF",
       "USE_TFLITE": "OFF",
       "USE_TF_TVMDSOOP": "OFF",
       "USE_THREADS": "ON",
       "USE_THRUST": "OFF",
       "USE_UMA": "OFF",
       "USE_VITIS_AI": "OFF",
       "USE_VULKAN": "OFF"
     }
   2022-11-21 16:19:36.788 INFO LocalBuilder: max_workers = 24
   Total 1221 records to be validated. Collected in  0.000 sec.
   Total 20 unique original TIR to validate. Deduplicated in  0.358 sec.
   Top 10 records for each original TIR will be validated.
   
   Progress      1 /    200 (estimated) checked, result:       pass, original:      0.006 ms, scheduled:      0.003 ms
   Progress      2 /    200 (estimated) checked, result:       pass, original:      0.006 ms, scheduled:      0.003 ms
   Progress      3 /    200 (estimated) checked, result:       pass, original:      0.006 ms, scheduled:      0.003 ms
   Progress      4 /    200 (estimated) checked, result:       pass, original:      0.006 ms, scheduled:      0.003 ms
   Progress      5 /    200 (estimated) checked, result:       pass, original:      0.006 ms, scheduled:      0.003 ms
   Progress      6 /    200 (estimated) checked, result:       pass, original:      0.006 ms, scheduled:      0.003 ms
   Progress      7 /    200 (estimated) checked, result:       pass, original:      0.006 ms, scheduled:      0.003 ms
   Progress      8 /    200 (estimated) checked, result:       pass, original:      0.006 ms, scheduled:      0.003 ms
   Progress      9 /    200 (estimated) checked, result:       pass, original:      0.006 ms, scheduled:      0.003 ms
   Progress     10 /    200 (estimated) checked, result:       pass, original:      0.006 ms, scheduled:      0.003 ms
   Progress     11 /    200 (estimated) checked, result:       pass, original:     52.827 ms, scheduled:      0.030 ms
   Progress     12 /    200 (estimated) checked, result:       pass, original:     52.827 ms, scheduled:      0.035 ms
   Progress     13 /    200 (estimated) checked, result:       pass, original:     52.827 ms, scheduled:      0.035 ms
   Progress     14 /    200 (estimated) checked, result:       pass, original:     52.827 ms, scheduled:      0.037 ms
   Progress     15 /    200 (estimated) checked, result:       pass, original:     52.827 ms, scheduled:      0.038 ms
   Progress     16 /    200 (estimated) checked, result:       pass, original:     52.827 ms, scheduled:      0.043 ms
   Progress     17 /    200 (estimated) checked, result:       pass, original:     52.827 ms, scheduled:      0.044 ms
   Progress     18 /    200 (estimated) checked, result:       pass, original:     52.827 ms, scheduled:      0.045 ms
   Progress     19 /    200 (estimated) checked, result:       pass, original:     52.827 ms, scheduled:      0.047 ms
   Progress     20 /    200 (estimated) checked, result:       pass, original:     52.827 ms, scheduled:      0.047 ms
   Progress     21 /    200 (estimated) checked, result:       pass, original:    146.474 ms, scheduled:      0.104 ms
   Progress     22 /    200 (estimated) checked, result:       pass, original:    146.474 ms, scheduled:      0.105 ms
   Progress     23 /    200 (estimated) checked, result:       pass, original:    146.474 ms, scheduled:      0.116 ms
   Progress     24 /    200 (estimated) checked, result:       pass, original:    146.474 ms, scheduled:      0.118 ms
   Progress     25 /    200 (estimated) checked, result:       pass, original:    146.474 ms, scheduled:      0.122 ms
   Progress     26 /    200 (estimated) checked, result:       pass, original:    146.474 ms, scheduled:      0.152 ms
   Progress     27 /    200 (estimated) checked, result:       pass, original:    146.474 ms, scheduled:      0.151 ms
   Progress     28 /    200 (estimated) checked, result:       pass, original:    146.474 ms, scheduled:      0.190 ms
   Progress     29 /    200 (estimated) checked, result:       pass, original:    146.474 ms, scheduled:      0.218 ms
   Progress     30 /    200 (estimated) checked, result:       pass, original:    146.474 ms, scheduled:      0.230 ms
   Progress     31 /    200 (estimated) checked, result:       pass, original:      7.804 ms, scheduled:      0.013 ms
   Progress     32 /    200 (estimated) checked, result:       pass, original:      7.804 ms, scheduled:      0.016 ms
   Progress     33 /    200 (estimated) checked, result:       pass, original:      7.804 ms, scheduled:      0.016 ms
   Progress     34 /    200 (estimated) checked, result:       pass, original:      7.804 ms, scheduled:      0.017 ms
   Progress     35 /    200 (estimated) checked, result:       pass, original:      7.804 ms, scheduled:      0.018 ms
   Progress     36 /    200 (estimated) checked, result:       pass, original:      7.804 ms, scheduled:      0.019 ms
   Progress     37 /    200 (estimated) checked, result:       pass, original:      7.804 ms, scheduled:      0.021 ms
   Progress     38 /    200 (estimated) checked, result:       pass, original:      7.804 ms, scheduled:      0.021 ms
   Progress     39 /    200 (estimated) checked, result:       pass, original:      7.804 ms, scheduled:      0.024 ms
   Progress     40 /    200 (estimated) checked, result:       pass, original:      7.804 ms, scheduled:      0.024 ms
   Progress     41 /    200 (estimated) checked, result:       pass, original:     43.212 ms, scheduled:      0.086 ms
   Progress     42 /    200 (estimated) checked, result:       pass, original:     43.212 ms, scheduled:      0.083 ms
   Progress     43 /    200 (estimated) checked, result:       pass, original:     43.212 ms, scheduled:      0.084 ms
   Progress     44 /    200 (estimated) checked, result:       pass, original:     43.212 ms, scheduled:      0.084 ms
   Progress     45 /    200 (estimated) checked, result:       pass, original:     43.212 ms, scheduled:      0.084 ms
   Progress     46 /    200 (estimated) checked, result:       pass, original:     43.212 ms, scheduled:      0.086 ms
   Progress     47 /    200 (estimated) checked, result:       pass, original:     43.212 ms, scheduled:      0.086 ms
   Progress     48 /    200 (estimated) checked, result:       pass, original:     43.212 ms, scheduled:      0.086 ms
   Progress     49 /    200 (estimated) checked, result:       pass, original:     43.212 ms, scheduled:      0.087 ms
   Progress     50 /    200 (estimated) checked, result:       pass, original:     43.212 ms, scheduled:      0.089 ms
   Progress     51 /    200 (estimated) checked, result:       pass, original:     65.038 ms, scheduled:      0.057 ms
   Progress     52 /    200 (estimated) checked, result:       pass, original:     65.038 ms, scheduled:      0.071 ms
   Progress     53 /    200 (estimated) checked, result:       pass, original:     65.038 ms, scheduled:      0.074 ms
   Progress     54 /    200 (estimated) checked, result:       pass, original:     65.038 ms, scheduled:      0.076 ms
   Progress     55 /    200 (estimated) checked, result:       pass, original:     65.038 ms, scheduled:      0.082 ms
   Progress     56 /    200 (estimated) checked, result:       pass, original:     65.038 ms, scheduled:      0.088 ms
   Progress     57 /    200 (estimated) checked, result:       pass, original:     65.038 ms, scheduled:      0.090 ms
   Progress     58 /    200 (estimated) checked, result:       pass, original:     65.038 ms, scheduled:      0.097 ms
   Progress     59 /    200 (estimated) checked, result:       pass, original:     65.038 ms, scheduled:      0.098 ms
   Progress     60 /    200 (estimated) checked, result:       pass, original:     65.038 ms, scheduled:      0.101 ms
   Progress     61 /    200 (estimated) checked, result:       pass, original:     31.743 ms, scheduled:      0.074 ms
   Progress     62 /    200 (estimated) checked, result:       pass, original:     31.743 ms, scheduled:      0.067 ms
   Progress     63 /    200 (estimated) checked, result:       pass, original:     31.743 ms, scheduled:      0.066 ms
   Progress     64 /    200 (estimated) checked, result:       pass, original:     31.743 ms, scheduled:      0.067 ms
   Progress     65 /    200 (estimated) checked, result:       pass, original:     31.743 ms, scheduled:      0.075 ms
   Progress     66 /    200 (estimated) checked, result:       pass, original:     31.743 ms, scheduled:      0.078 ms
   Progress     67 /    200 (estimated) checked, result:       pass, original:     31.743 ms, scheduled:      0.079 ms
   Progress     68 /    200 (estimated) checked, result:       pass, original:     31.743 ms, scheduled:      0.089 ms
   Progress     69 /    200 (estimated) checked, result:       pass, original:     31.743 ms, scheduled:      0.096 ms
   Progress     70 /    200 (estimated) checked, result:       pass, original:     31.743 ms, scheduled:      0.110 ms
   Progress     71 /    200 (estimated) checked, result:       pass, original:      3.623 ms, scheduled:      0.006 ms
   Progress     72 /    200 (estimated) checked, result:       pass, original:      3.623 ms, scheduled:      0.005 ms
   Progress     73 /    200 (estimated) checked, result:       pass, original:      3.623 ms, scheduled:      0.005 ms
   Progress     74 /    200 (estimated) checked, result:       pass, original:      3.623 ms, scheduled:      0.005 ms
   Progress     75 /    200 (estimated) checked, result:       pass, original:      3.623 ms, scheduled:      0.006 ms
   Progress     76 /    200 (estimated) checked, result:       pass, original:      3.623 ms, scheduled:      0.006 ms
   Progress     77 /    200 (estimated) checked, result:       pass, original:      3.623 ms, scheduled:      0.007 ms
   Progress     78 /    200 (estimated) checked, result:       pass, original:      3.623 ms, scheduled:      0.007 ms
   Progress     79 /    200 (estimated) checked, result:       pass, original:      3.623 ms, scheduled:      0.008 ms
   Progress     80 /    200 (estimated) checked, result:       pass, original:      3.623 ms, scheduled:      0.009 ms
   Progress     81 /    196 (estimated) checked, result:       pass, original:      0.034 ms, scheduled:      0.002 ms
   Progress     82 /    196 (estimated) checked, result:       pass, original:      0.034 ms, scheduled:      0.002 ms
   Progress     83 /    196 (estimated) checked, result:       pass, original:      0.034 ms, scheduled:      0.002 ms
   Progress     84 /    196 (estimated) checked, result:       pass, original:      0.034 ms, scheduled:      0.003 ms
   Progress     85 /    196 (estimated) checked, result:       pass, original:      0.034 ms, scheduled:      0.003 ms
   Progress     86 /    196 (estimated) checked, result:       pass, original:      0.034 ms, scheduled:      0.005 ms
   Progress     87 /    196 (estimated) checked, result:       pass, original:      6.110 ms, scheduled:      0.007 ms
   Progress     88 /    196 (estimated) checked, result:       pass, original:      6.110 ms, scheduled:      0.007 ms
   Progress     89 /    196 (estimated) checked, result:       pass, original:      6.110 ms, scheduled:      0.009 ms
   Progress     90 /    196 (estimated) checked, result:       pass, original:      6.110 ms, scheduled:      0.009 ms
   Progress     91 /    196 (estimated) checked, result:       pass, original:      6.110 ms, scheduled:      0.009 ms
   Progress     92 /    196 (estimated) checked, result:       pass, original:      6.110 ms, scheduled:      0.009 ms
   Progress     93 /    196 (estimated) checked, result:       pass, original:      6.110 ms, scheduled:      0.010 ms
   Progress     94 /    196 (estimated) checked, result:       pass, original:      6.110 ms, scheduled:      0.010 ms
   Progress     95 /    196 (estimated) checked, result:       pass, original:      6.110 ms, scheduled:      0.012 ms
   Progress     96 /    196 (estimated) checked, result:       pass, original:      6.110 ms, scheduled:      0.013 ms
   Progress     97 /    196 (estimated) checked, result:       pass, original:     43.432 ms, scheduled:      0.082 ms
   Progress     98 /    196 (estimated) checked, result:       pass, original:     43.432 ms, scheduled:      0.098 ms
   Progress     99 /    196 (estimated) checked, result:       pass, original:     43.432 ms, scheduled:      0.101 ms
   Progress    100 /    196 (estimated) checked, result:       pass, original:     43.432 ms, scheduled:      0.106 ms
   Progress    101 /    196 (estimated) checked, result:       pass, original:     43.432 ms, scheduled:      0.117 ms
   Progress    102 /    196 (estimated) checked, result:       pass, original:     43.432 ms, scheduled:      0.120 ms
   Progress    103 /    196 (estimated) checked, result:       pass, original:     43.432 ms, scheduled:      0.125 ms
   Progress    104 /    196 (estimated) checked, result:       pass, original:     43.432 ms, scheduled:      0.143 ms
   Progress    105 /    196 (estimated) checked, result:       pass, original:     43.432 ms, scheduled:      0.147 ms
   Progress    106 /    196 (estimated) checked, result:       pass, original:     43.432 ms, scheduled:      0.150 ms
   Progress    107 /    196 (estimated) checked, result:       pass, original:     24.410 ms, scheduled:      0.051 ms
   Progress    108 /    196 (estimated) checked, result:       pass, original:     24.410 ms, scheduled:      0.048 ms
   Progress    109 /    196 (estimated) checked, result:       pass, original:     24.410 ms, scheduled:      0.051 ms
   Progress    110 /    196 (estimated) checked, result:       pass, original:     24.410 ms, scheduled:      0.053 ms
   Progress    111 /    196 (estimated) checked, result:       pass, original:     24.410 ms, scheduled:      0.061 ms
   Progress    112 /    196 (estimated) checked, result:       pass, original:     24.410 ms, scheduled:      0.068 ms
   Progress    113 /    196 (estimated) checked, result:       pass, original:     24.410 ms, scheduled:      0.082 ms
   Progress    114 /    196 (estimated) checked, result:       pass, original:     24.410 ms, scheduled:      0.086 ms
   Progress    115 /    196 (estimated) checked, result:       pass, original:     24.410 ms, scheduled:      0.078 ms
   Progress    116 /    196 (estimated) checked, result:       pass, original:     24.410 ms, scheduled:      0.102 ms
   Progress    117 /    196 (estimated) checked, result:       pass, original:     32.747 ms, scheduled:      0.067 ms
   Progress    118 /    196 (estimated) checked, result:       pass, original:     32.747 ms, scheduled:      0.081 ms
   Progress    119 /    196 (estimated) checked, result:       pass, original:     32.747 ms, scheduled:      0.095 ms
   Progress    120 /    196 (estimated) checked, result:       pass, original:     32.747 ms, scheduled:      0.112 ms
   Progress    121 /    196 (estimated) checked, result:       pass, original:     32.747 ms, scheduled:      0.113 ms
   Progress    122 /    196 (estimated) checked, result:       pass, original:     32.747 ms, scheduled:      0.125 ms
   Progress    123 /    196 (estimated) checked, result:       pass, original:     32.747 ms, scheduled:      0.129 ms
   Progress    124 /    196 (estimated) checked, result:       pass, original:     32.747 ms, scheduled:      0.130 ms
   Progress    125 /    196 (estimated) checked, result:       pass, original:     32.747 ms, scheduled:      0.142 ms
   Progress    126 /    196 (estimated) checked, result:       pass, original:     32.747 ms, scheduled:      0.142 ms
   Progress    127 /    196 (estimated) checked, result:       pass, original:     31.283 ms, scheduled:      0.051 ms
   Progress    128 /    196 (estimated) checked, result:       pass, original:     31.283 ms, scheduled:      0.059 ms
   Progress    129 /    196 (estimated) checked, result:       pass, original:     31.283 ms, scheduled:      0.059 ms
   Progress    130 /    196 (estimated) checked, result:       pass, original:     31.283 ms, scheduled:      0.059 ms
   Progress    131 /    196 (estimated) checked, result:       pass, original:     31.283 ms, scheduled:      0.060 ms
   Progress    132 /    196 (estimated) checked, result:       pass, original:     31.283 ms, scheduled:      0.060 ms
   Progress    133 /    196 (estimated) checked, result:       pass, original:     31.283 ms, scheduled:      0.063 ms
   Progress    134 /    196 (estimated) checked, result:       pass, original:     31.283 ms, scheduled:      0.065 ms
   Progress    135 /    196 (estimated) checked, result:       pass, original:     31.283 ms, scheduled:      0.071 ms
   Progress    136 /    196 (estimated) checked, result:       pass, original:     31.283 ms, scheduled:      0.077 ms
   Progress    137 /    191 (estimated) checked, result:       pass, original:      0.000 ms, scheduled:      0.002 ms
   Progress    138 /    191 (estimated) checked, result:       pass, original:      0.000 ms, scheduled:      0.002 ms
   Progress    139 /    191 (estimated) checked, result:       pass, original:      0.000 ms, scheduled:      0.002 ms
   Progress    140 /    191 (estimated) checked, result:       pass, original:      0.000 ms, scheduled:      0.002 ms
   Progress    141 /    191 (estimated) checked, result:       pass, original:      0.000 ms, scheduled:      0.002 ms
   Progress    142 /    191 (estimated) checked, result:       pass, original:      0.338 ms, scheduled:      0.010 ms
   Progress    143 /    191 (estimated) checked, result:       pass, original:      0.338 ms, scheduled:      0.016 ms
   Progress    144 /    191 (estimated) checked, result:       pass, original:      0.338 ms, scheduled:      0.021 ms
   Progress    145 /    191 (estimated) checked, result:       pass, original:      0.338 ms, scheduled:      0.025 ms
   Progress    146 /    191 (estimated) checked, result:       pass, original:      0.338 ms, scheduled:      0.029 ms
   Progress    147 /    191 (estimated) checked, result:       pass, original:      0.338 ms, scheduled:      0.030 ms
   Progress    148 /    191 (estimated) checked, result:       pass, original:      0.338 ms, scheduled:      0.030 ms
   Progress    149 /    191 (estimated) checked, result:       pass, original:      0.338 ms, scheduled:      0.033 ms
   Progress    150 /    191 (estimated) checked, result:       pass, original:      0.338 ms, scheduled:      0.033 ms
   Progress    151 /    191 (estimated) checked, result:       pass, original:      0.338 ms, scheduled:      0.033 ms
   Progress    152 /    191 (estimated) checked, result:       pass, original:      0.177 ms, scheduled:      0.005 ms
   Progress    153 /    191 (estimated) checked, result:       pass, original:      0.177 ms, scheduled:      0.005 ms
   Progress    154 /    191 (estimated) checked, result:       pass, original:      0.177 ms, scheduled:      0.005 ms
   Progress    155 /    191 (estimated) checked, result:       pass, original:      0.177 ms, scheduled:      0.005 ms
   Progress    156 /    191 (estimated) checked, result:       pass, original:      0.177 ms, scheduled:      0.005 ms
   Progress    157 /    191 (estimated) checked, result:       pass, original:      0.177 ms, scheduled:      0.005 ms
   Progress    158 /    191 (estimated) checked, result:       pass, original:      0.177 ms, scheduled:      0.005 ms
   Progress    159 /    191 (estimated) checked, result:       pass, original:      0.177 ms, scheduled:      0.005 ms
   Progress    160 /    191 (estimated) checked, result:       pass, original:      0.177 ms, scheduled:      0.005 ms
   Progress    161 /    191 (estimated) checked, result:       pass, original:      0.177 ms, scheduled:      0.005 ms
   Progress    162 /    191 (estimated) checked, result:       pass, original:     70.768 ms, scheduled:      0.049 ms
   Progress    163 /    191 (estimated) checked, result:       pass, original:     70.768 ms, scheduled:      0.048 ms
   Progress    164 /    191 (estimated) checked, result:       pass, original:     70.768 ms, scheduled:      0.053 ms
   Progress    165 /    191 (estimated) checked, result:       pass, original:     70.768 ms, scheduled:      0.055 ms
   Progress    166 /    191 (estimated) checked, result:       pass, original:     70.768 ms, scheduled:      0.056 ms
   Progress    167 /    191 (estimated) checked, result:       pass, original:     70.768 ms, scheduled:      0.056 ms
   Progress    168 /    191 (estimated) checked, result:       pass, original:     70.768 ms, scheduled:      0.057 ms
   Progress    169 /    191 (estimated) checked, result:       pass, original:     70.768 ms, scheduled:      0.059 ms
   Progress    170 /    191 (estimated) checked, result:       pass, original:     70.768 ms, scheduled:      0.062 ms
   Progress    171 /    191 (estimated) checked, result:       pass, original:     70.768 ms, scheduled:      0.064 ms
   Progress    172 /    191 (estimated) checked, result:       pass, original:     31.186 ms, scheduled:      0.058 ms
   Progress    173 /    191 (estimated) checked, result:       pass, original:     31.186 ms, scheduled:      0.055 ms
   Progress    174 /    191 (estimated) checked, result:       pass, original:     31.186 ms, scheduled:      0.059 ms
   Progress    175 /    191 (estimated) checked, result:       pass, original:     31.186 ms, scheduled:      0.063 ms
   Progress    176 /    191 (estimated) checked, result:       pass, original:     31.186 ms, scheduled:      0.063 ms
   Progress    177 /    191 (estimated) checked, result:       pass, original:     31.186 ms, scheduled:      0.072 ms
   Progress    178 /    191 (estimated) checked, result:       pass, original:     31.186 ms, scheduled:      0.074 ms
   Progress    179 /    191 (estimated) checked, result:       pass, original:     31.186 ms, scheduled:      0.080 ms
   Progress    180 /    191 (estimated) checked, result:       pass, original:     31.186 ms, scheduled:      0.074 ms
   Progress    181 /    191 (estimated) checked, result:       pass, original:     31.186 ms, scheduled:      0.085 ms
   Progress    182 /    191 (estimated) checked, result:       pass, original:     24.627 ms, scheduled:      0.048 ms
   Progress    183 /    191 (estimated) checked, result:       pass, original:     24.627 ms, scheduled:      0.044 ms
   Progress    184 /    191 (estimated) checked, result:       pass, original:     24.627 ms, scheduled:      0.050 ms
   Progress    185 /    191 (estimated) checked, result:       pass, original:     24.627 ms, scheduled:      0.052 ms
   Progress    186 /    191 (estimated) checked, result:       pass, original:     24.627 ms, scheduled:      0.065 ms
   Progress    187 /    191 (estimated) checked, result:       pass, original:     24.627 ms, scheduled:      0.067 ms
   Progress    188 /    191 (estimated) checked, result:       pass, original:     24.627 ms, scheduled:      0.068 ms
   Progress    189 /    191 (estimated) checked, result:       pass, original:     24.627 ms, scheduled:      0.077 ms
   Progress    190 /    191 (estimated) checked, result:       pass, original:     24.627 ms, scheduled:      0.078 ms
   Progress    191 /    191 (estimated) checked, result:       pass, original:     24.627 ms, scheduled:      0.082 ms
   Validation finished! Total time spent:  168.474 sec.
   ```




[GitHub] [tvm] yelite commented on a diff in pull request #13459: [MetaSchedule] Enhance Database Validation Script

Posted by GitBox <gi...@apache.org>.
yelite commented on code in PR #13459:
URL: https://github.com/apache/tvm/pull/13459#discussion_r1028655821


##########
python/tvm/meta_schedule/testing/validate_database.py:
##########
@@ -124,158 +252,495 @@ def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]:
     return inputs
 
 
-@register_func("tvm.meta_schedule.testing.default_check_metric")
-def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bool:
-    assert len(a) == len(b), "Different number of outputs from two modules"
-    for i, _ in enumerate(a):
-        if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3):
-            return False
-    return True
+def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
+    """Convert a list of TVM NDArray to a list of numpy array
 
+    Parameters
+    ----------
+    a : List[tvm.nd.NDArray]
+        The list of TVM NDArray to be converted
 
-def validate_correctness(
-    original_mod: IRModule,  # compiled for "baseline_target"
-    scheduled_mod: IRModule,  # compiled for "target"
-    *,
-    baseline_target: Target,
-    target: Target,
-    dev_type: str,
-    rpc_config: ms.runner.RPCConfig,
-    f_input_generator: Union[
-        str, Callable[[IRModule], List[tvm.nd.NDArray]]
-    ] = default_input_generator,
-    f_check_metric: Union[
-        str, Callable[[tvm.nd.NDArray, tvm.nd.NDArray], bool]
-    ] = default_check_metric,
-) -> bool:
-    """Function to validate the correctness of a scheduled module.
+    Returns
+    -------
+    b : List[np.ndarray]
+        The list of numpy array
+    """
+    assert a is not None, "Empty result cannot be converted to numpy"
+    return [x.numpy() for x in a]
+
+
+def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
+    """Convert a list of numpy array to a list of TVM NDArray
 
     Parameters
     ----------
-    original_mod : IRModule
-        The original module to be compiled.
-    scheduled_mod : IRModule
-        The scheduled module to be compiled.
-    baseline_target : Target
-        The baseline target to compile the original module.
-    target : Target
-        The target to compile the scheduled module.
-    dev_type : str
-        The device type to run the module via rpc.
-    rpc_config : RPCConfig
-        The RPCConfig to run the scheduled module.
-    f_input_generator : Union[str, Callable]
-        The function to generate the input data.
-    f_check_metric : Union[str, Callable]
-        The function to check the metric.
+    a : List[np.ndarray]
+        The list of numpy array to be converted.
 
     Returns
     -------
-    result : bool
-        The result of the validation.
+    b : List[tvm.nd.NDArray]
+        The list of TVM NDArray.
     """
+    assert a is not None, "Empty result cannot be converted to TVM NDArray"
+    return [tvm.nd.array(x) for x in a]
 
-    def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
-        """Convert a list of TVM NDArray to a list of numpy array"""
-        assert a is not None, "Empty result cannot be converted to numpy"
-        return [x.numpy() for x in a]
-
-    def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
-        """Convert a list of numpy array to a list of TVM NDArray"""
-        assert a is not None, "Empty result cannot be converted to TVM NDArray"
-        return [tvm.nd.array(x) for x in a]
-
-    def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray:
-        """Build and run the module on the target device."""
-        rt_mod = tvm.build(mod, target=target)
-        return run_module_via_rpc(
-            rpc_config=rpc_config,
-            lib=rt_mod,
-            dev_type=dev_type,
-            args={i: v for i, v in enumerate(inputs)},  # pylint: disable=unnecessary-comprehension
-            continuation=create_calculator(backend="tir"),
-            backend="tir",
-        )
 
-    # fetch functions & prepare inputs
-    if isinstance(f_input_generator, str):
-        f_input_generator = get_global_func(f_input_generator)
-    if isinstance(f_check_metric, str):
-        f_check_metric = get_global_func(f_check_metric)
-    inputs = to_numpy(f_input_generator(original_mod))  # type: ignore
-    # build & run original result
-    original_res = to_numpy(build_and_run(original_mod, target=baseline_target, dev_type="cpu"))
-    scheduled_res = to_numpy(build_and_run(scheduled_mod, target=target, dev_type=dev_type))
-    # check metric
-    if f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)):  # type: ignore
-        return True
-    else:
-        print(
-            ("\n\n").join(
+def is_failed_record(record: ms.database.TuningRecord) -> bool:
+    """Check if a tuning record is failed.
+
+    Parameters
+    ----------
+    record : TuningRecord
+        The tuning record to check.
+
+    Returns
+    -------
+    is_failed : bool
+    """
+    return len(record.run_secs) == 1 and record.run_secs[0] == 1e9
+
+
+def print_with_counter_func(counter: int, total: int) -> Callable:
+    """Print with counter
+
+    Parameters
+    ----------
+    counter : int
+        The current progress counter (index of the record being validated).
+    total : int
+        The estimated total number of records to validate.
+
+    Returns
+    -------
+    print_result : Callable
+        The function that prints a single validation result.
+    """
+
+    def print_result(
+        result: str,
+        *,
+        original_mod: IRModule = None,
+        scheduled_mod: IRModule = None,
+        inputs: List[np.ndarray] = None,
+        original_res: List[np.ndarray] = None,
+        scheduled_res: List[np.ndarray] = None,
+        original_run_secs: List[float] = None,
+        scheduled_run_secs: List[float] = None,
+        exception: Exception = None,
+        trace: str = None,
+    ) -> None:
+        """Print the validation result."""
+        status = f"Progress {counter: 6d} / {total: 6d} (estimated) checked, result: {result:>10}, "
+
+        if result in ["pass", "wrong answer"]:
+            status += (
+                f"original: {mean(original_run_secs) * 1e3: 10.3f} ms, "
+                f"scheduled: {mean(scheduled_run_secs) * 1e3: 10.3f} ms"
+            )
+
+        output = [status]
+        if result not in ["pass", "skip"]:
+            output.extend(
                 [
-                    "Validation failed!",
-                    "Original Result:" + DELIMITOR + str(original_res),
-                    "Scheduled Result:" + DELIMITOR + str(scheduled_res),
-                    "Input:" + DELIMITOR + str(inputs),
                     "Original IRModule:" + DELIMITOR + original_mod.script(),
                     "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(),
+                    "Trace" + DELIMITOR + str(trace),
                 ]
             )
+            if result == "wrong answer":
+                output.extend(
+                    [
+                        "Input:" + DELIMITOR + str(inputs),
+                        "Original Result:" + DELIMITOR + str(original_res),
+                        "Scheduled Result:" + DELIMITOR + str(scheduled_res),
+                        "Max Diff:"
+                        + DELIMITOR
+                        + str(
+                            [
+                                np.max(np.abs(original_res[i] - scheduled_res[i]))
+                                for i in range(len(original_res))
+                            ]
+                        )
+                        + "\n",
+                    ]
+                )
+            elif result == "exception":
+                output.extend(["Exception:" + DELIMITOR + str(exception) + "\n"])
+            else:
+                raise ValueError(f"Unknown result: {result}")
+        print("\n\n".join(output))
+
+    return print_result
+
+
+def make_alloc_arg_and_check(
+    inputs: List[np.ndarray],
+    original_mod: IRModule,
+    scheduled_mod: IRModule,
+    trace: str,
+    original_res: List[np.ndarray],
+    original_run_secs: List[float],
+    print_result: Callable,
+) -> Tuple[Callable, Callable]:
+    """Make alloc_arg and check functions for the given inputs and collect results.
+
+    Parameters
+    ----------
+    inputs : List[np.ndarray]
+        The inputs to the two modules.
+    original_mod : IRModule
+        The original IRModule.
+    scheduled_mod : IRModule
+        The scheduled IRModule.
+    trace : str
+        The trace of the scheduled IRModule.
+    original_res : List[np.ndarray]
+        The original results.
+    original_run_secs : List[float]
+        The original run times.
+    print_result : Callable
+        The print result function.
+
+    Returns
+    -------
+    f_with_args_alloc_argument : Callable
+        The function to allocate arguments.
+
+    f_with_args_run_evaluator : Callable
+        The function to run evaluator.
+    """
+
+    def f_with_args_alloc_argument(
+        # pylint: disable=unused-argument
+        session: tvm.rpc.RPCSession,
+        device: tvm.runtime.Device,
+        args_info: ms.runner.rpc_runner.T_ARG_INFO_JSON_OBJ_LIST,
+        alloc_repeat: int,
+        # pylint: enable=unused-argument
+    ) -> List[ms.runner.rpc_runner.T_ARGUMENT_LIST]:
+        """Allocate arguments using the given inputs.
+
+        Parameters
+        ----------
+        session : RPCSession
+            The RPC session.
+        device : Device
+            The device.
+        args_info : T_ARG_INFO_JSON_OBJ_LIST
+            argument information.
+        alloc_repeat : int
+            The number of times to repeat the allocation.
+
+        Returns
+        -------
+        args_list : List[T_ARGUMENT_LIST]
+            The list of argument lists.
+        """
+        return [[tvm.nd.array(arg, device=device) for arg in inputs] for _ in range(alloc_repeat)]
+
+    def f_with_args_run_evaluator(
+        session: tvm.rpc.RPCSession,  # pylint: disable=unused-argument
+        rt_mod: tvm.runtime.Module,
+        device: tvm.runtime.Device,
+        evaluator_config: ms.runner.EvaluatorConfig,
+        repeated_args: List[ms.runner.rpc_runner.T_ARGUMENT_LIST],
+    ) -> List[float]:
+        """With args function to run the evaluator
+
+        Parameters
+        ----------
+        session : tvm.rpc.RPCSession
+            The RPC session
+        rt_mod: Module
+            The runtime module
+        device: Device
+            The device to run the evaluator
+        evaluator_config: EvaluatorConfig
+            The evaluator config
+        repeated_args: List[T_ARGUMENT_LIST]
+            The repeated arguments
+
+        Returns
+        -------
+        costs: List[float]
+            The evaluator results
+        """
+        evaluator = rt_mod.time_evaluator(
+            func_name=rt_mod.entry_name,
+            dev=device,
+            number=evaluator_config.number,
+            repeat=evaluator_config.repeat,
+            min_repeat_ms=evaluator_config.min_repeat_ms,
+            f_preproc="cache_flush_cpu_non_first_arg"
+            if evaluator_config.enable_cpu_cache_flush
+            else "",
+        )
+
+        repeated_costs: List[List[float]] = []
+        for args in repeated_args:
+            device.sync()
+            profile_result = evaluator(*args)
+            repeated_costs.append(profile_result.results)
+        costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
+
+        assert len(repeated_args) == 1, "Only one set of arguments is supported"
+        scheduled_res = [arg.numpy() for arg in repeated_args[0]]  # type: ignore
+        # fetch comparison function
+        passed = check_and_run(
+            ARGS.check_metric_func,
+            to_tvm_ndarray(original_res),
+            to_tvm_ndarray(scheduled_res),
+        )
+
+        print_result(
+            result="pass" if passed else "wrong answer",
+            original_mod=original_mod,
+            scheduled_mod=scheduled_mod,
+            trace=trace,
+            inputs=inputs,
+            original_res=original_res,
+            scheduled_res=scheduled_res,
+            original_run_secs=original_run_secs,
+            scheduled_run_secs=costs,
         )
-        return False
+
+        return costs
+
+    return f_with_args_alloc_argument, f_with_args_run_evaluator
+
+
+def local_build_and_run(
+    mod: IRModule,
+    target: Target,
+    device: tvm.runtime.Device,
+    inputs: List[np.ndarray],
+) -> Tuple[List[np.ndarray], List[float]]:
+    """Build and run the module locally.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The module to build and run
+    target: Target
+        The target to build the module
+    device: Device
+        The device to run the module
+    inputs: List[np.ndarray]
+        The inputs to run the module
+
+    Returns
+    -------
+    res: List[np.ndarray]
+        The results of running the module
+    run_secs: List[float]
+        The running time of running the module
+    """
+    # potential memory leak https://github.com/apache/tvm/issues/11096
+    lib = tvm.build(mod, target=target)
+    tvm_inputs = [tvm.nd.array(inp, device=device) for inp in inputs]
+    device.sync()
+    func = lib.time_evaluator(lib.entry_name, dev=device, number=ARGS.number, repeat=ARGS.repeat)
+    benchmark_res = func(*tvm_inputs)
+    device.sync()
+    return [arg.numpy() for arg in tvm_inputs], list(benchmark_res.results)
+
+
+def _check_builder_result(builder_result: ms.builder.BuilderResult) -> None:
+    """Check if the builder result is defined.
+
+    Parameters
+    ----------
+    builder_result: BuilderResult
+        The builder result
+    """
+    assert builder_result.error_msg is None, "Builder failed: " + str(
+        builder_result.error_msg if builder_result.error_msg else "Empty error message"
+    )
+
+
+def _apply_trace(mod: IRModule, trace: Trace) -> IRModule:
+    """Apply the trace to the module.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The module to apply the trace to
+    trace: Trace
+        The trace to apply
+
+    Returns
+    -------
+    mod: IRModule
+        The module with the trace applied
+    """
+    sch = Schedule(mod)
+    trace.apply_to_schedule(sch, remove_postproc=False)
+    return sch.mod
+
+
+def _build_all_mods(
+    mods: List[IRModule], builder: ms.builder.Builder, target: Target
+) -> List[ms.builder.BuilderResult]:
+    """Build all the modules.
+
+    Parameters
+    ----------
+    mods: List[IRModule]
+        The modules to build
+    builder: Builder
+        The builder to build the modules
+    target: Target
+        The target to build the modules
+
+    Returns
+    -------
+    builder_results: List[BuilderResult]
+        The builder results
+    """
+    builder_results = builder.build([ms.builder.BuilderInput(mod, target) for mod in mods])
+    assert len(builder_results) == len(
+        mods
+    ), f"Unexpected number of build results, expected {len(mods)} got {len(builder_results)}"
+    return builder_results
+
+
+def _run_single_mod(
+    builder_result: ms.builder.BuilderResult,
+    runner: ms.runner.Runner,
+    dev_type: str,
+) -> None:
+    """Run a single module.
+
+    Parameters
+    ----------
+    builder_result: BuilderResult
+        The builder result
+    runner: Runner
+        The runner to run the module
+    dev_type: str
+        The device type
+    """
+    runner_futures = runner.run(
+        # args_info is not used in this case, so we can pass an empty list
+        [ms.runner.RunnerInput(builder_result.artifact_path, device_type=dev_type, args_info=[])]
+    )
+    assert (
+        len(runner_futures) == 1
+    ), f"Unexpected number of runner futures, expected 1 got {len(runner_futures)}"
+    (runner_future,) = runner_futures  # pylint: disable=unbalanced-tuple-unpacking
+    runner_res = runner_future.result()
+    assert runner_res.error_msg is None, "Runner failed: " + (
+        runner_res.error_msg if runner_res.error_msg else "Empty error message"
+    )
 
 
 def main():
     """Main function"""
     describe()
-    database = ms.database.create(work_dir=ARGS.work_dir)
-    target = ARGS.target
-    if target.kind.name == "llvm":
-        dev_type = "cpu"
-    elif target.kind.name == "cuda":
-        dev_type = "cuda"
-    else:
-        raise RuntimeError(f"Unsupported target kind: {target.kind.name}")
-    records = database.get_all_tuning_records()
     with ms.Profiler() as profiler:
-        for i, record in enumerate(records):
-            scope_name = f"validate #{i}"
-            with profiler.timeit(scope_name):
-                original_mod = record.workload.mod
-                sch = Schedule(original_mod)
-                record.trace.apply_to_schedule(sch=sch, remove_postproc=False)
-                scheduled_mod = sch.mod
-                is_success = False
+        # initialize
+        target = ARGS.target
+        dev_type = get_device_type(target)
+        builder = ms.builder.LocalBuilder()
+        database = ms.database.create(work_dir=ARGS.work_dir)
+
+        # collect records
+        with profiler.timeit("collect records"):
+            records = database.get_all_tuning_records()
+        total = len(records)
+        print(
+            f"Total {total} records to be validated. "
+            f"Collected in {float(profiler.get()['collect records']): 3.3f} sec."
+        )
+
+        # collect unique original TIR
+        with profiler.timeit("deduplicate records"):
+            workloads = set()
+            for record in records:
+                workloads.add(OriginalModule(record.workload.mod))
+        print(
+            f"Total {len(workloads)} unique original TIR to validate. "
+            f"Deduplicated in {float(profiler.get()['deduplicate records']): 3.3f} sec."
+        )
+        if ARGS.top_k < 10**9:
+            print(f"Top {ARGS.top_k} records for each original TIR will be validated.")
+            total = len(workloads) * ARGS.top_k
+        print()
+
+        # validate correctness
+        counter = 0
+        for item in workloads:
+            original_mod = item.mod
+            records = database.get_top_k(
+                workload=database.commit_workload(original_mod), top_k=ARGS.top_k
+            )
+            if len(records) < ARGS.top_k:
+                total -= ARGS.top_k - len(records)
+            inputs = to_numpy(check_and_run(ARGS.input_generator_func, original_mod))
+            original_res, original_run_secs = local_build_and_run(
+                original_mod,
+                target=ARGS.baseline_target,
+                inputs=inputs,
+                device=get_runtime_device(ARGS.baseline_target),
+            )
+            scheduled_mods = [_apply_trace(original_mod, record.trace) for record in records]
+            builder_results = _build_all_mods(scheduled_mods, builder, target)  # type: ignore
+            for i, record in enumerate(records):
+                counter += 1
+                print_result = print_with_counter_func(counter=counter, total=total)
+                if is_failed_record(record):
+                    # skip failed records, whose run_secs is the 1e9 placeholder;
+                    # these records exist only as negative samples for the cost model
+                    print_result(result="skip")
+                    continue
                 try:
-                    is_success = validate_correctness(
-                        original_mod=original_mod,
-                        scheduled_mod=scheduled_mod,
-                        target=target,
-                        baseline_target=ARGS.baseline_target,
-                        dev_type=dev_type,
+                    # prepare scheduled module
+                    scheduled_mod = scheduled_mods[i]
+                    # check build result
+                    builder_result = builder_results[i]
+                    _check_builder_result(builder_result)
+                    # fetch functions
+                    (
+                        f_with_args_alloc_argument,
+                        f_with_args_run_evaluator,
+                    ) = make_alloc_arg_and_check(
+                        inputs,
+                        original_mod,
+                        scheduled_mod,
+                        str(record.trace),
+                        original_res=original_res,
+                        original_run_secs=original_run_secs,
+                        print_result=print_result,
+                    )
+                    # create rpc runner
+                    runner = ms.runner.RPCRunner(

Review Comment:
   Maybe just make the `--rpc-host` arg and its friends optional, and fall back to a local runner when `--rpc-host` is not given by the user.

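   For readers following along, here is a minimal sketch of what that fallback could look like. It is not code from this PR: it assumes the script's `--rpc-host`/`--rpc-port`/`--rpc-key` flags, reuses the `f_alloc_argument`/`f_run_evaluator` hooks built by `make_alloc_arg_and_check`, and `create_runner` is a hypothetical helper name.

   ```python
   from typing import Callable, Optional

   from tvm import meta_schedule as ms


   def create_runner(
       rpc_host: Optional[str],
       rpc_port: Optional[int],
       rpc_key: Optional[str],
       evaluator_config: ms.runner.EvaluatorConfig,
       f_alloc_argument: Callable,
       f_run_evaluator: Callable,
   ) -> ms.runner.Runner:
       """Hypothetical helper: pick an RPC or local runner based on --rpc-host."""
       if rpc_host is None:
           # No tracker given: benchmark on the local machine instead. Note that
           # LocalRunner's hooks take no RPC session argument, so the callbacks
           # above would need a session-free variant.
           return ms.runner.LocalRunner(
               evaluator_config=evaluator_config,
               f_alloc_argument=f_alloc_argument,
               f_run_evaluator=f_run_evaluator,
           )
       return ms.runner.RPCRunner(
           rpc_config=ms.runner.RPCConfig(
               tracker_host=rpc_host,
               tracker_port=rpc_port,
               tracker_key=rpc_key,
               session_timeout_sec=600,
           ),
           evaluator_config=evaluator_config,
           f_alloc_argument=f_alloc_argument,
           f_run_evaluator=f_run_evaluator,
       )
   ```

   Either way the validation logic stays the same; only where the scheduled module runs changes.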

