Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/11/22 00:58:04 UTC

[GitHub] [tvm] yelite commented on a diff in pull request #13459: [MetaSchedule] Enhance Database Validation Script

yelite commented on code in PR #13459:
URL: https://github.com/apache/tvm/pull/13459#discussion_r1028648334


##########
python/tvm/meta_schedule/testing/validate_database.py:
##########
@@ -124,158 +252,495 @@ def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]:
     return inputs
 
 
-@register_func("tvm.meta_schedule.testing.default_check_metric")
-def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bool:
-    assert len(a) == len(b), "Different number of outputs from two modules"
-    for i, _ in enumerate(a):
-        if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3):
-            return False
-    return True
+def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
+    """Convert a list of TVM NDArray to a list of numpy array
 
+    Parameters
+    ----------
+    a : List[tvm.nd.NDArray]
+        The list of TVM NDArrays to be converted.
 
-def validate_correctness(
-    original_mod: IRModule,  # compiled for "baseline_target"
-    scheduled_mod: IRModule,  # compiled for "target"
-    *,
-    baseline_target: Target,
-    target: Target,
-    dev_type: str,
-    rpc_config: ms.runner.RPCConfig,
-    f_input_generator: Union[
-        str, Callable[[IRModule], List[tvm.nd.NDArray]]
-    ] = default_input_generator,
-    f_check_metric: Union[
-        str, Callable[[tvm.nd.NDArray, tvm.nd.NDArray], bool]
-    ] = default_check_metric,
-) -> bool:
-    """Function to validate the correctness of a scheduled module.
+    Returns
+    -------
+    b : List[np.ndarray]
+        The list of numpy arrays.
+    """
+    assert a is not None, "Empty result cannot be converted to numpy"
+    return [x.numpy() for x in a]
+
+
+def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
+    """Convert a list of numpy array to a list of TVM NDArray
 
     Parameters
     ----------
-    original_mod : IRModule
-        The original module to be compiled.
-    scheduled_mod : IRModule
-        The scheduled module to be compiled.
-    baseline_target : Target
-        The baseline target to compile the original module.
-    target : Target
-        The target to compile the scheduled module.
-    dev_type : str
-        The device type to run the module via rpc.
-    rpc_config : RPCConfig
-        The RPCConfig to run the scheduled module.
-    f_input_generator : Union[str, Callable]
-        The function to generate the input data.
-    f_check_metric : Union[str, Callable]
-        The function to check the metric.
+    a : List[np.ndarray]
+        The list of numpy arrays to be converted.
 
     Returns
     -------
-    result : bool
-        The result of the validation.
+    b : List[tvm.nd.NDArray]
+        The list of TVM NDArrays.
     """
+    assert a is not None, "Empty result cannot be converted to TVM NDArray"
+    return [tvm.nd.array(x) for x in a]
 
-    def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
-        """Convert a list of TVM NDArray to a list of numpy array"""
-        assert a is not None, "Empty result cannot be converted to numpy"
-        return [x.numpy() for x in a]
-
-    def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
-        """Convert a list of numpy array to a list of TVM NDArray"""
-        assert a is not None, "Empty result cannot be converted to TVM NDArray"
-        return [tvm.nd.array(x) for x in a]
-
-    def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray:
-        """Build and run the module on the target device."""
-        rt_mod = tvm.build(mod, target=target)
-        return run_module_via_rpc(
-            rpc_config=rpc_config,
-            lib=rt_mod,
-            dev_type=dev_type,
-            args={i: v for i, v in enumerate(inputs)},  # pylint: disable=unnecessary-comprehension
-            continuation=create_calculator(backend="tir"),
-            backend="tir",
-        )
 
-    # fetch functions & prepare inputs
-    if isinstance(f_input_generator, str):
-        f_input_generator = get_global_func(f_input_generator)
-    if isinstance(f_check_metric, str):
-        f_check_metric = get_global_func(f_check_metric)
-    inputs = to_numpy(f_input_generator(original_mod))  # type: ignore
-    # build & run original result
-    original_res = to_numpy(build_and_run(original_mod, target=baseline_target, dev_type="cpu"))
-    scheduled_res = to_numpy(build_and_run(scheduled_mod, target=target, dev_type=dev_type))
-    # check metric
-    if f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)):  # type: ignore
-        return True
-    else:
-        print(
-            ("\n\n").join(
+def is_failed_record(record: ms.database.TuningRecord) -> bool:
+    """Check if a tuning record is failed.
+
+    Parameters
+    ----------
+    record : TuningRecord
+        The tuning record to check.
+
+    Returns
+    -------
+    is_failed : bool
+        Whether the record is a failed run (run_secs is the 1e9 sentinel).
+    """
+    return len(record.run_secs) == 1 and record.run_secs[0] == 1e9
+
+
+def print_with_counter_func(counter: int, total: int) -> Callable:
+    """Print with counter
+
+    Parameters
+    ----------
+    counter : int
+        The counter to print with.
+    total : int
+        The total number of items to print with.
+
+    Returns
+    -------
+    print_result : Callable
+        The print result function.
+    """
+
+    def print_result(
+        result: str,
+        *,
+        original_mod: IRModule = None,
+        scheduled_mod: IRModule = None,
+        inputs: List[np.ndarray] = None,
+        original_res: List[np.ndarray] = None,
+        scheduled_res: List[np.ndarray] = None,
+        original_run_secs: List[float] = None,
+        scheduled_run_secs: List[float] = None,
+        exception: Exception = None,
+        trace: str = None,
+    ) -> None:
+        """Print the validation result."""
+        status = f"Progress {counter: 6d} / {total: 6d} (estimated) checked, result: {result:>10}, "
+
+        if result in ["pass", "wrong answer"]:
+            status += (
+                f"original: {mean(original_run_secs) * 1e3: 10.3f} ms, "
+                f"scheduled: {mean(scheduled_run_secs) * 1e3: 10.3f} ms"
+            )
+
+        output = [status]
+        if result not in ["pass", "skip"]:
+            output.extend(
                 [
-                    "Validation failed!",
-                    "Original Result:" + DELIMITOR + str(original_res),
-                    "Scheduled Result:" + DELIMITOR + str(scheduled_res),
-                    "Input:" + DELIMITOR + str(inputs),
                     "Original IRModule:" + DELIMITOR + original_mod.script(),
                     "Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(),
+                    "Trace" + DELIMITOR + str(trace),
                 ]
             )
+            if result == "wrong answer":
+                output.extend(
+                    [
+                        "Input:" + DELIMITOR + str(inputs),
+                        "Original Result:" + DELIMITOR + str(original_res),
+                        "Scheduled Result:" + DELIMITOR + str(scheduled_res),
+                        "Max Diff:"
+                        + DELIMITOR
+                        + str(
+                            [
+                                np.max(np.abs(original_res[i] - scheduled_res[i]))
+                                for i in range(len(original_res))
+                            ]
+                        )
+                        + "\n",
+                    ]
+                )
+            elif result == "exception":
+                output.extend(["Exception:" + DELIMITOR + str(exception) + "\n"])
+            else:
+                raise ValueError(f"Unknown result: {result}")
+        print("\n\n".join(output))
+
+    return print_result
+
+
+def make_alloc_arg_and_check(
+    inputs: List[np.ndarray],
+    original_mod: IRModule,
+    scheduled_mod: IRModule,
+    trace: str,
+    original_res: List[np.ndarray],
+    original_run_secs: List[float],
+    print_result: Callable,
+) -> Tuple[Callable, Callable]:
+    """Make alloc_arg and check functions for the given inputs and collect results.
+
+    Parameters
+    ----------
+    inputs : List[np.ndarray]
+        The inputs to the two modules.
+    original_mod : IRModule
+        The original IRModule.
+    scheduled_mod : IRModule
+        The scheduled IRModule.
+    trace : str
+        The trace of the scheduled IRModule.
+    original_res : List[np.ndarray]
+        The original results.
+    original_run_secs : List[float]
+        The original run times.
+    print_result : Callable
+        The print result function.
+
+    Returns
+    -------
+    f_with_args_alloc_argument : Callable
+        The function to allocate arguments.
+
+    f_with_args_run_evaluator : Callable
+        The function to run the evaluator.
+    """
+
+    def f_with_args_alloc_argument(
+        # pylint: disable=unused-argument
+        session: tvm.rpc.RPCSession,
+        device: tvm.runtime.Device,
+        args_info: ms.runner.rpc_runner.T_ARG_INFO_JSON_OBJ_LIST,
+        alloc_repeat: int,
+        # pylint: enable=unused-argument
+    ) -> List[ms.runner.rpc_runner.T_ARGUMENT_LIST]:
+        """Allocate arguments using the given inputs.
+
+        Parameters
+        ----------
+        session : RPCSession
+            The RPC session.
+        device : Device
+            The device.
+        args_info : T_ARG_INFO_JSON_OBJ_LIST
+            The argument information.
+        alloc_repeat : int
+            The number of times to repeat the allocation.
+
+        Returns
+        -------
+        args_list : List[T_ARGUMENT_LIST]
+            The list of argument lists.
+        """
+        return [[tvm.nd.array(arg, device=device) for arg in inputs] for _ in range(alloc_repeat)]
+
+    def f_with_args_run_evaluator(
+        session: tvm.rpc.RPCSession,  # pylint: disable=unused-argument
+        rt_mod: tvm.runtime.Module,
+        device: tvm.runtime.Device,
+        evaluator_config: ms.runner.EvaluatorConfig,
+        repeated_args: List[ms.runner.rpc_runner.T_ARGUMENT_LIST],
+    ) -> List[float]:
+        """With args function to run the evaluator
+
+        Parameters
+        ----------
+        session : tvm.rpc.RPCSession
+            The RPC session
+        rt_mod: Module
+            The runtime module
+        device: Device
+            The device to run the evaluator
+        evaluator_config: EvaluatorConfig
+            The evaluator config
+        repeated_args: List[T_ARGUMENT_LIST]
+            The repeated arguments
+
+        Returns
+        -------
+        costs: List[float]
+            The evaluator results
+        """
+        evaluator = rt_mod.time_evaluator(
+            func_name=rt_mod.entry_name,
+            dev=device,
+            number=evaluator_config.number,
+            repeat=evaluator_config.repeat,
+            min_repeat_ms=evaluator_config.min_repeat_ms,
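+            # "cache_flush_cpu_non_first_arg" flushes the CPU cache (for all but the
+            # first argument) before each measurement to avoid warm-cache timings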
+            f_preproc="cache_flush_cpu_non_first_arg"
+            if evaluator_config.enable_cpu_cache_flush
+            else "",
+        )
+
+        repeated_costs: List[List[float]] = []
+        for args in repeated_args:
+            device.sync()
+            profile_result = evaluator(*args)
+            repeated_costs.append(profile_result.results)
+        costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
+
+        assert len(repeated_args) == 1, "Only one set of arguments is supported"
+        scheduled_res = [arg.numpy() for arg in repeated_args[0]]  # type: ignore
+        # fetch comparison function
+        passed = check_and_run(
+            ARGS.check_metric_func,
+            to_tvm_ndarray(original_res),
+            to_tvm_ndarray(scheduled_res),
+        )
+
+        print_result(
+            result="pass" if passed else "wrong answer",
+            original_mod=original_mod,
+            scheduled_mod=scheduled_mod,
+            trace=trace,
+            inputs=inputs,
+            original_res=original_res,
+            scheduled_res=scheduled_res,
+            original_run_secs=original_run_secs,
+            scheduled_run_secs=costs,
         )
-        return False
+
+        return costs
+
+    return f_with_args_alloc_argument, f_with_args_run_evaluator
+
+
+def local_build_and_run(
+    mod: IRModule,
+    target: Target,
+    device: tvm.runtime.Device,
+    inputs: List[np.ndarray],
+) -> Tuple[List[np.ndarray], List[float]]:
+    """Build and run the module locally.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The module to build and run
+    target: Target
+        The target to build the module
+    device: Device
+        The device to run the module
+    inputs: List[np.ndarray]
+        The inputs to run the module
+
+    Returns
+    -------
+    res: List[np.ndarray]
+        The results of running the module
+    run_secs: List[float]
+        The measured running times of the module
+    """
+    # potential memory leak https://github.com/apache/tvm/issues/11096
+    lib = tvm.build(mod, target=target)
+    tvm_inputs = [tvm.nd.array(inp, device=device) for inp in inputs]
+    device.sync()
+    func = lib.time_evaluator(lib.entry_name, dev=device, number=ARGS.number, repeat=ARGS.repeat)
+    benchmark_res = func(*tvm_inputs)
+    device.sync()
+    return [arg.numpy() for arg in tvm_inputs], list(benchmark_res.results)
+
+
+def _check_builder_result(builder_result: ms.builder.BuilderResult) -> None:
+    """Check if the builder result is defined.
+
+    Parameters
+    ----------
+    builder_result: BuilderResult
+        The builder result
+    """
+    assert builder_result.error_msg is None, "Builder failed: " + str(
+        builder_result.error_msg if builder_result.error_msg else "Empty error message"
+    )
+
+
+def _apply_trace(mod: IRModule, trace: Trace) -> IRModule:
+    """Apply the trace to the module.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The module to apply the trace to
+    trace: Trace
+        The trace to apply
+
+    Returns
+    -------
+    mod: IRModule
+        The module with the trace applied
+    """
+    sch = Schedule(mod)
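+    # Keep postprocessing steps (remove_postproc=False) so the replayed
+    # module matches what was actually measured during tuning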
+    trace.apply_to_schedule(sch, remove_postproc=False)
+    return sch.mod
+
+
+def _build_all_mods(
+    mods: List[IRModule], builder: ms.builder.Builder, target: Target
+) -> List[ms.builder.BuilderResult]:
+    """Build all the modules.
+
+    Parameters
+    ----------
+    mods: List[IRModule]
+        The modules to build
+    builder: Builder
+        The builder to build the modules
+    target: Target
+        The target to build the modules
+
+    Returns
+    -------
+    builder_results: List[BuilderResult]
+        The builder results
+    """
+    builder_results = builder.build([ms.builder.BuilderInput(mod, target) for mod in mods])
+    assert len(builder_results) == len(
+        mods
+    ), f"Unexpected number of build results, expected {len(mods)} got {len(builder_results)}"
+    return builder_results
+
+
+def _run_single_mod(
+    builder_result: ms.builder.BuilderResult,
+    runner: ms.runner.Runner,
+    dev_type: str,
+) -> None:
+    """Run a single module.
+
+    Parameters
+    ----------
+    builder_result: BuilderResult
+        The builder result
+    runner: Runner
+        The runner to run the module
+    dev_type: str
+        The device type
+    """
+    runner_futures = runner.run(
+        # args_info is not used in this case, so we can pass an empty list
+        [ms.runner.RunnerInput(builder_result.artifact_path, device_type=dev_type, args_info=[])]
+    )
+    assert (
+        len(runner_futures) == 1
+    ), f"Unexpected number of runner futures, expected 1 got {len(runner_futures)}"
+    (runner_future,) = runner_futures  # pylint: disable=unbalanced-tuple-unpacking
+    runner_res = runner_future.result()
+    assert runner_res.error_msg is None, "Runner failed: " + (
+        runner_res.error_msg if runner_res.error_msg else "Empty error message"
+    )
 
 
 def main():
     """Main function"""
     describe()
-    database = ms.database.create(work_dir=ARGS.work_dir)
-    target = ARGS.target
-    if target.kind.name == "llvm":
-        dev_type = "cpu"
-    elif target.kind.name == "cuda":
-        dev_type = "cuda"
-    else:
-        raise RuntimeError(f"Unsupported target kind: {target.kind.name}")
-    records = database.get_all_tuning_records()
     with ms.Profiler() as profiler:
-        for i, record in enumerate(records):
-            scope_name = f"validate #{i}"
-            with profiler.timeit(scope_name):
-                original_mod = record.workload.mod
-                sch = Schedule(original_mod)
-                record.trace.apply_to_schedule(sch=sch, remove_postproc=False)
-                scheduled_mod = sch.mod
-                is_success = False
+        # initialize
+        target = ARGS.target
+        dev_type = get_device_type(target)
+        builder = ms.builder.LocalBuilder()
+        database = ms.database.create(work_dir=ARGS.work_dir)
+
+        # collect records
+        with profiler.timeit("collect records"):
+            records = database.get_all_tuning_records()
+        total = len(records)
+        print(
+            f"Total {total} records to be validated. "
+            f"Collected in {float(profiler.get()['collect records']): 3.3f} sec."
+        )
+
+        # collect unique original TIR
+        with profiler.timeit("deduplicate records"):
+            workloads = set()
+            for record in records:
+                workloads.add(OriginalModule(record.workload.mod))
+        print(
+            f"Total {len(workloads)} unique original TIR to validate. "
+            f"Deduplicated in {float(profiler.get()['deduplicate records']): 3.3f} sec."
+        )
+        if ARGS.top_k < 10**9:
+            print(f"Top {ARGS.top_k} records for each original TIR will be validated.")
+            total = len(workloads) * ARGS.top_k
+        print()
+
+        # validate correctness
+        counter = 0
+        for item in workloads:
+            original_mod = item.mod
+            records = database.get_top_k(
+                workload=database.commit_workload(original_mod), top_k=ARGS.top_k
+            )
+            if len(records) < ARGS.top_k:
+                total -= ARGS.top_k - len(records)
+            inputs = to_numpy(check_and_run(ARGS.input_generator_func, original_mod))
+            original_res, original_run_secs = local_build_and_run(
+                original_mod,
+                target=ARGS.baseline_target,
+                inputs=inputs,
+                device=get_runtime_device(ARGS.baseline_target),
+            )
+            scheduled_mods = [_apply_trace(original_mod, record.trace) for record in records]
+            builder_results = _build_all_mods(scheduled_mods, builder, target)  # type: ignore
+            for i, record in enumerate(records):
+                counter += 1
+                print_result = print_with_counter_func(counter=counter, total=total)
+                if is_failed_record(record):
+                    # skip failed records where run_secs is 1e9
+                    # these records are only negative samples for cost model
+                    print_result(result="skip")
+                    continue
                 try:
-                    is_success = validate_correctness(
-                        original_mod=original_mod,
-                        scheduled_mod=scheduled_mod,
-                        target=target,
-                        baseline_target=ARGS.baseline_target,
-                        dev_type=dev_type,
+                    # prepare scheduled module
+                    scheduled_mod = scheduled_mods[i]
+                    # check build result
+                    builder_result = builder_results[i]
+                    _check_builder_result(builder_result)
+                    # fetch functions
+                    (
+                        f_with_args_alloc_argument,
+                        f_with_args_run_evaluator,
+                    ) = make_alloc_arg_and_check(
+                        inputs,
+                        original_mod,
+                        scheduled_mod,
+                        str(record.trace),
+                        original_res=original_res,
+                        original_run_secs=original_run_secs,
+                        print_result=print_result,
+                    )
+                    # create rpc runner
+                    runner = ms.runner.RPCRunner(

Review Comment:
   Would it be useful to make RPC optional, i.e. fall back to `LocalRunner` when no RPC-related args are provided?
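
   A minimal sketch of such a fallback, assuming hypothetical `--rpc-host`/`--rpc-port`/`--rpc-key` CLI args (`ARGS.rpc_host` etc. are made-up names here; `LocalRunner`, `RPCRunner`, and `RPCConfig` are the actual classes in `tvm.meta_schedule.runner`):

   ```python
   from tvm import meta_schedule as ms


   def create_runner() -> ms.runner.Runner:
       # Hypothetical fallback: run on the local machine when no RPC tracker is given.
       if ARGS.rpc_host is None:
           return ms.runner.LocalRunner()
       # Otherwise go through the RPC tracker, as the script does today.
       return ms.runner.RPCRunner(
           rpc_config=ms.runner.RPCConfig(
               tracker_host=ARGS.rpc_host,
               tracker_port=ARGS.rpc_port,
               tracker_key=ARGS.rpc_key,
               session_timeout_sec=600,
           ),
       )
   ```

   The custom allocation/evaluation hooks from `make_alloc_arg_and_check` would still need wiring into whichever runner is chosen (their expected signatures differ slightly between the two runners).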



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org