You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/11/21 23:10:49 UTC
[GitHub] [tvm] zxybazh opened a new pull request, #13459: [MetaSchedule] Enhance Database Validation Script
zxybazh opened a new pull request, #13459:
URL: https://github.com/apache/tvm/pull/13459
Following up on #12948, this PR further enhances the validation script by:
- [x] Reusing the same TIR results to speed up validation
- [x] Reusing the RPCRunner for benchmarking and validation at the same time
- [x] Consolidating the output interface and printing the maximum difference on a wrong answer
- [x] Supporting Top-K validation (very fast for Relay compilation validation)
- [x] Skipping failed tuning records
- [ ] Supporting tensorized records (to be tested)
- [ ] Allowing batch builds
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
[GitHub] [tvm] zxybazh commented on a diff in pull request #13459: [MetaSchedule] Enhance Database Validation Script
Posted by GitBox <gi...@apache.org>.
zxybazh commented on code in PR #13459:
URL: https://github.com/apache/tvm/pull/13459#discussion_r1028650742
##########
python/tvm/meta_schedule/testing/validate_database.py:
##########
@@ -124,158 +252,495 @@ def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]:
return inputs
-@register_func("tvm.meta_schedule.testing.default_check_metric")
-def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bool:
- assert len(a) == len(b), "Different number of outputs from two modules"
- for i, _ in enumerate(a):
- if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3):
- return False
- return True
+def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
+ """Convert a list of TVM NDArray to a list of numpy array
+ Parameters
+ ----------
+ a : List[tvm.nd.NDArray]
+ The list of TVM NDArray to be converted
-def validate_correctness(
- original_mod: IRModule, # compiled for "baseline_target"
- scheduled_mod: IRModule, # compiled for "target"
- *,
- baseline_target: Target,
- target: Target,
- dev_type: str,
- rpc_config: ms.runner.RPCConfig,
- f_input_generator: Union[
- str, Callable[[IRModule], List[tvm.nd.NDArray]]
- ] = default_input_generator,
- f_check_metric: Union[
- str, Callable[[tvm.nd.NDArray, tvm.nd.NDArray], bool]
- ] = default_check_metric,
-) -> bool:
- """Function to validate the correctness of a scheduled module.
+ Returns
+ -------
+ b : List[np.ndarray]
+ The list of numpy array
+ """
+ assert a is not None, "Empty result cannot be converted to numpy"
+ return [x.numpy() for x in a]
+
+
+def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
+ """Convert a list of numpy array to a list of TVM NDArray
Parameters
----------
- original_mod : IRModule
- The original module to be compiled.
- scheduled_mod : IRModule
- The scheduled module to be compiled.
- baseline_target : Target
- The baseline target to compile the original module.
- target : Target
- The target to compile the scheduled module.
- dev_type : str
- The device type to run the module via rpc.
- rpc_config : RPCConfig
- The RPCConfig to run the scheduled module.
- f_input_generator : Union[str, Callable]
- The function to generate the input data.
- f_check_metric : Union[str, Callable]
- The function to check the metric.
+ a : List[np.ndarray]
+ The list of numpy array to be converted.
Returns
-------
- result : bool
- The result of the validation.
+ b : List[tvm.nd.NDArray]
+ The list of TVM NDArray.
"""
+ assert a is not None, "Empty result cannot be converted to TVM NDArray"
+ return [tvm.nd.array(x) for x in a]
- def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
- """Convert a list of TVM NDArray to a list of numpy array"""
- assert a is not None, "Empty result cannot be converted to numpy"
- return [x.numpy() for x in a]
-
- def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
- """Convert a list of numpy array to a list of TVM NDArray"""
- assert a is not None, "Empty result cannot be converted to TVM NDArray"
- return [tvm.nd.array(x) for x in a]
-
- def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray:
- """Build and run the module on the target device."""
- rt_mod = tvm.build(mod, target=target)
- return run_module_via_rpc(
- rpc_config=rpc_config,
- lib=rt_mod,
- dev_type=dev_type,
- args={i: v for i, v in enumerate(inputs)}, # pylint: disable=unnecessary-comprehension
- continuation=create_calculator(backend="tir"),
- backend="tir",
- )
- # fetch functions & prepare inputs
- if isinstance(f_input_generator, str):
- f_input_generator = get_global_func(f_input_generator)
- if isinstance(f_check_metric, str):
- f_check_metric = get_global_func(f_check_metric)
- inputs = to_numpy(f_input_generator(original_mod)) # type: ignore
- # build & run original result
- original_res = to_numpy(build_and_run(original_mod, target=baseline_target, dev_type="cpu"))
- scheduled_res = to_numpy(build_and_run(scheduled_mod, target=target, dev_type=dev_type))
- # check metric
- if f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)): # type: ignore
- return True
- else:
- print(
- ("\n\n").join(
+def is_failed_record(record: ms.database.TuningRecord) -> bool:
+ """Check if a tuning record is failed.
+
+ Parameters
+ ----------
+ record : TuningRecord
+ The tuning record to check.
+
+ Returns
+ -------
+ is_failed : bool
+ """
+ return len(record.run_secs) == 1 and record.run_secs[0] == 1e9
+
+
+def print_with_counter_func(counter: int, total: int) -> Callable:
+ """Print with counter
+
+ Parameters
+ ----------
+ counter : int
+ The counter to print with.
+ total : int
+ The total number of items to print with.
+
+ Returns
+ -------
+ print_result : Callable
+ The print result function.
+ """
+
+ def print_result(
+ result: str,
+ *,
+ original_mod: IRModule = None,
+ scheduled_mod: IRModule = None,
+ inputs: List[np.ndarray] = None,
+ original_res: List[np.ndarray] = None,
+ scheduled_res: List[np.ndarray] = None,
+ original_run_secs: List[float] = None,
+ scheduled_run_secs: List[float] = None,
+ exception: Exception = None,
+ trace: str = None,
+ ) -> None:
+ """Print the validation result."""
+ status = f"Progress {counter: 6d} / {total: 6d} (estimated) checked, result: {result:>10}, "
+
+ if result in ["pass", "wrong answer"]:
+ status += (
+ f"original: {mean(original_run_secs) * 1e3: 10.3f} ms, "
+ f"scheduled: {mean(scheduled_run_secs) * 1e3: 10.3f} ms"
+ )
+
+ output = [status]
+ if result not in ["pass", "skip"]:
+ output.extend(
[
- "Validation failed!",
- "Original Result:" + DELIMITOR + str(original_res),
- "Scheduled Result:" + DELIMITOR + str(scheduled_res),
- "Input:" + DELIMITOR + str(inputs),
"Original IRModule:" + DELIMITOR + original_mod.script(),
"Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(),
+ "Trace" + DELIMITOR + str(trace),
]
)
+ if result == "wrong answer":
+ output.extend(
+ [
+ "Input:" + DELIMITOR + str(inputs),
+ "Original Result:" + DELIMITOR + str(original_res),
+ "Scheduled Result:" + DELIMITOR + str(scheduled_res),
+ "Max Diff:"
+ + DELIMITOR
+ + str(
+ [
+ np.max(np.abs(original_res[i] - scheduled_res[i]))
+ for i in range(len(original_res))
+ ]
+ )
+ + "\n",
+ ]
+ )
+ elif result == "exception":
+ output.extend(["Exception:" + DELIMITOR + str(exception) + "\n"])
+ else:
+ raise ValueError(f"Unknown result: {result}")
+ print("\n\n".join(output))
+
+ return print_result
+
+
+def make_alloc_arg_and_check(
+ inputs: List[np.ndarray],
+ original_mod: IRModule,
+ scheduled_mod: IRModule,
+ trace: str,
+ original_res: List[np.ndarray],
+ original_run_secs: List[float],
+ print_result: Callable,
+) -> Tuple[Callable, Callable]:
+ """Make alloc_arg and check functions for the given inputs and collect results.
+
+ Parameters
+ ----------
+ inputs : List[np.ndarray]
+ The inputs to the two modules.
+ original_mod : IRModule
+ The original IRModule.
+ scheduled_mod : IRModule
+ The scheduled IRModule.
+ trace : str
+ The trace of the scheduled IRModule.
+ original_res : List[np.ndarray]
+ The original results.
+ original_run_secs : List[float]
+ The original run times.
+ print_result : Callable
+ The print result function.
+
+ Returns
+ -------
+ f_with_args_alloc_argument : Callable
+ The function to allocate arguments.
+
+ f_with_args_run_evaluator : Callable
+ The function to run evaluator.
+ """
+
+ def f_with_args_alloc_argument(
+ # pylint: disable=unused-argument
+ session: tvm.rpc.RPCSession,
+ device: tvm.runtime.Device,
+ args_info: ms.runner.rpc_runner.T_ARG_INFO_JSON_OBJ_LIST,
+ alloc_repeat: int,
+ # pylint: enable=unused-argument
+ ) -> List[ms.runner.rpc_runner.T_ARGUMENT_LIST]:
+ """Allocate arguments using the given inputs.
+
+ Parameters
+ ----------
+ session : RPCSession
+ The RPC session.
+ device : Device
+ The device.
+ args_info : T_ARG_INFO_JSON_OBJ_LIST
+ argument information.
+ alloc_repeat : int
+ The number of times to repeat the allocation.
+
+ Returns
+ -------
+ args_list : List[T_ARGUMENT_LIST]
+ The list of argument lists.
+ """
+ return [[tvm.nd.array(arg, device=device) for arg in inputs] for _ in range(alloc_repeat)]
+
+ def f_with_args_run_evaluator(
+ session: tvm.rpc.RPCSession, # pylint: disable=unused-argument
+ rt_mod: tvm.runtime.Module,
+ device: tvm.runtime.Device,
+ evaluator_config: ms.runner.EvaluatorConfig,
+ repeated_args: List[ms.runner.rpc_runner.T_ARGUMENT_LIST],
+ ) -> List[float]:
+ """With args function to run the evaluator
+
+ Parameters
+ ----------
+ session : tvm.rpc.RPCSession
+ The RPC session
+ rt_mod: Module
+ The runtime module
+ device: Device
+ The device to run the evaluator
+ evaluator_config: EvaluatorConfig
+ The evaluator config
+ repeated_args: List[T_ARGUMENT_LIST]
+ The repeated arguments
+
+ Returns
+ -------
+ costs: List[float]
+ The evaluator results
+ """
+ evaluator = rt_mod.time_evaluator(
+ func_name=rt_mod.entry_name,
+ dev=device,
+ number=evaluator_config.number,
+ repeat=evaluator_config.repeat,
+ min_repeat_ms=evaluator_config.min_repeat_ms,
+ f_preproc="cache_flush_cpu_non_first_arg"
+ if evaluator_config.enable_cpu_cache_flush
+ else "",
+ )
+
+ repeated_costs: List[List[float]] = []
+ for args in repeated_args:
+ device.sync()
+ profile_result = evaluator(*args)
+ repeated_costs.append(profile_result.results)
+ costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
+
+ assert len(repeated_args) == 1, "Only support one set of arguments"
+ scheduled_res = [arg.numpy() for arg in repeated_args[0]] # type: ignore
+ # fetch comparison function
+ passed = check_and_run(
+ ARGS.check_metric_func,
+ to_tvm_ndarray(original_res),
+ to_tvm_ndarray(scheduled_res),
+ )
+
+ print_result(
+ result="pass" if passed else "wrong answer",
+ original_mod=original_mod,
+ scheduled_mod=scheduled_mod,
+ trace=trace,
+ inputs=inputs,
+ original_res=original_res,
+ scheduled_res=scheduled_res,
+ original_run_secs=original_run_secs,
+ scheduled_run_secs=costs,
)
- return False
+
+ return costs
+
+ return f_with_args_alloc_argument, f_with_args_run_evaluator
+
+
+def local_build_and_run(
+ mod: IRModule,
+ target: Target,
+ device: tvm.runtime.Device,
+ inputs: List[np.ndarray],
+) -> Tuple[List[np.ndarray], List[float]]:
+ """Build and run the module locally.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The module to build and run
+ target: Target
+ The target to build the module
+ device: Device
+ The device to run the module
+ inputs: List[np.ndarray]
+ The inputs to run the module
+
+ Returns
+ -------
+ res: List[np.ndarray]
+ The results of running the module
+ run_secs: List[float]
+ The running time of running the module
+ """
+ # potential memory leak https://github.com/apache/tvm/issues/11096
+ lib = tvm.build(mod, target=target)
+ tvm_inputs = [tvm.nd.array(inp, device=device) for inp in inputs]
+ device.sync()
+ func = lib.time_evaluator(lib.entry_name, dev=device, number=ARGS.number, repeat=ARGS.repeat)
+ benchmark_res = func(*tvm_inputs)
+ device.sync()
+ return [arg.numpy() for arg in tvm_inputs], list(benchmark_res.results)
+
+
+def _check_builder_result(builder_result: ms.builder.BuilderResult) -> None:
+ """Check if the builder result is defined.
+
+ Parameters
+ ----------
+ builder_result: BuilderResult
+ The builder result
+ """
+ assert builder_result.error_msg is None, "Builder failed: " + str(
+ builder_result.error_msg if builder_result.error_msg else "Empty error message"
+ )
+
+
+def _apply_trace(mod: IRModule, trace: Trace) -> IRModule:
+ """Apply the trace to the module.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The module to apply the trace to
+ trace: Trace
+ The trace to apply
+
+ Returns
+ -------
+ mod: IRModule
+ The module with the trace applied
+ """
+ sch = Schedule(mod)
+ trace.apply_to_schedule(sch, remove_postproc=False)
+ return sch.mod
+
+
+def _build_all_mods(
+ mods: List[IRModule], builder: ms.builder.Builder, target: Target
+) -> List[ms.builder.BuilderResult]:
+ """Build all the modules.
+
+ Parameters
+ ----------
+ mods: List[IRModule]
+ The modules to build
+ builder: Builder
+ The builder to build the modules
+ target: Target
+ The target to build the modules
+
+ Returns
+ -------
+ builder_results: List[BuilderResult]
+ The builder results
+ """
+ builder_results = builder.build([ms.builder.BuilderInput(mod, target) for mod in mods])
+ assert len(builder_results) == len(
+ mods
+ ), f"Unexpected number of build results, expected {len(mods)} got {len(builder_results)}"
+ return builder_results
+
+
+def _run_single_mod(
+ builder_result: ms.builder.BuilderResult,
+ runner: ms.runner.Runner,
+ dev_type: str,
+) -> None:
+ """Run a single module.
+
+ Parameters
+ ----------
+ builder_result: BuilderResult
+ The builder result
+ runner: Runner
+ The runner to run the module
+ dev_type: str
+ The device type
+ """
+ runner_futures = runner.run(
+ # arginfo is not used in this case so we can pass an empty list
+ [ms.runner.RunnerInput(builder_result.artifact_path, device_type=dev_type, args_info=[])]
+ )
+ assert (
+ len(runner_futures) == 1
+ ), f"Unexpected number of runner futures, expected 1 got {len(runner_futures)}"
+ (runner_future,) = runner_futures # pylint: disable=unbalanced-tuple-unpacking
+ runner_res = runner_future.result()
+ assert runner_res.error_msg is None, "Runner failed: " + (
+ runner_res.error_msg if runner_res.error_msg else "Empty error message"
+ )
def main():
"""Main function"""
describe()
- database = ms.database.create(work_dir=ARGS.work_dir)
- target = ARGS.target
- if target.kind.name == "llvm":
- dev_type = "cpu"
- elif target.kind.name == "cuda":
- dev_type = "cuda"
- else:
- raise RuntimeError(f"Unsupported target kind: {target.kind.name}")
- records = database.get_all_tuning_records()
with ms.Profiler() as profiler:
- for i, record in enumerate(records):
- scope_name = f"validate #{i}"
- with profiler.timeit(scope_name):
- original_mod = record.workload.mod
- sch = Schedule(original_mod)
- record.trace.apply_to_schedule(sch=sch, remove_postproc=False)
- scheduled_mod = sch.mod
- is_success = False
+ # initialize
+ target = ARGS.target
+ dev_type = get_device_type(target)
+ builder = ms.builder.LocalBuilder()
+ database = ms.database.create(work_dir=ARGS.work_dir)
+
+ # collect records
+ with profiler.timeit("collect records"):
+ records = database.get_all_tuning_records()
+ total = len(records)
+ print(
+ f"Total {total} records to be validated. "
+ f"Collected in {float(profiler.get()['collect records']): 3.3f} sec."
+ )
+
+ # collect unique original TIR
+ with profiler.timeit("deduplicate records"):
+ workloads = set()
+ for record in records:
+ workloads.add(OriginalModule(record.workload.mod))
+ print(
+ f"Total {len(workloads)} unique original TIR to validate. "
+ f"Deduplicated in {float(profiler.get()['deduplicate records']): 3.3f} sec."
+ )
+ if ARGS.top_k < 10**9:
+ print(f"Top {ARGS.top_k} records for each original TIR will be validated.")
+ total = len(workloads) * ARGS.top_k
+ print()
+
+ # validate correctness
+ counter = 0
+ for item in workloads:
+ original_mod = item.mod
+ records = database.get_top_k(
+ workload=database.commit_workload(original_mod), top_k=ARGS.top_k
+ )
+ if len(records) < ARGS.top_k:
+ total -= ARGS.top_k - len(records)
+ inputs = to_numpy(check_and_run(ARGS.input_generator_func, original_mod))
+ original_res, original_run_secs = local_build_and_run(
+ original_mod,
+ target=ARGS.baseline_target,
+ inputs=inputs,
+ device=get_runtime_device(ARGS.baseline_target),
+ )
+ scheduled_mods = [_apply_trace(original_mod, record.trace) for record in records]
+ builder_results = _build_all_mods(scheduled_mods, builder, target) # type: ignore
+ for i, record in enumerate(records):
+ counter += 1
+ print_result = print_with_counter_func(counter=counter, total=total)
+ if is_failed_record(record):
+ # skip failed records where run_secs is 1e9
+ # these records are only negative samples for cost model
+ print_result(result="skip")
+ continue
try:
- is_success = validate_correctness(
- original_mod=original_mod,
- scheduled_mod=scheduled_mod,
- target=target,
- baseline_target=ARGS.baseline_target,
- dev_type=dev_type,
+ # prepare scheduled module
+ scheduled_mod = scheduled_mods[i]
+ # check build result
+ builder_result = builder_results[i]
+ _check_builder_result(builder_result)
+ # fetch functions
+ (
+ f_with_args_alloc_argument,
+ f_with_args_run_evaluator,
+ ) = make_alloc_arg_and_check(
+ inputs,
+ original_mod,
+ scheduled_mod,
+ str(record.trace),
+ original_res=original_res,
+ original_run_secs=original_run_secs,
+ print_result=print_result,
+ )
+ # create rpc runner
+ runner = ms.runner.RPCRunner(
Review Comment:
I think that's a reasonable ask. What do you think would be a good user interface?
Something like `--rpc-host="local"`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
[GitHub] [tvm] zxybazh commented on a diff in pull request #13459: [MetaSchedule] Enhance Database Validation Script
Posted by GitBox <gi...@apache.org>.
zxybazh commented on code in PR #13459:
URL: https://github.com/apache/tvm/pull/13459#discussion_r1028656409
##########
python/tvm/meta_schedule/testing/validate_database.py:
##########
@@ -124,158 +252,495 @@ def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]:
return inputs
-@register_func("tvm.meta_schedule.testing.default_check_metric")
-def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bool:
- assert len(a) == len(b), "Different number of outputs from two modules"
- for i, _ in enumerate(a):
- if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3):
- return False
- return True
+def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
+ """Convert a list of TVM NDArray to a list of numpy array
+ Parameters
+ ----------
+ a : List[tvm.nd.NDArray]
+ The list of TVM NDArray to be converted
-def validate_correctness(
- original_mod: IRModule, # compiled for "baseline_target"
- scheduled_mod: IRModule, # compiled for "target"
- *,
- baseline_target: Target,
- target: Target,
- dev_type: str,
- rpc_config: ms.runner.RPCConfig,
- f_input_generator: Union[
- str, Callable[[IRModule], List[tvm.nd.NDArray]]
- ] = default_input_generator,
- f_check_metric: Union[
- str, Callable[[tvm.nd.NDArray, tvm.nd.NDArray], bool]
- ] = default_check_metric,
-) -> bool:
- """Function to validate the correctness of a scheduled module.
+ Returns
+ -------
+ b : List[np.ndarray]
+ The list of numpy array
+ """
+ assert a is not None, "Empty result cannot be converted to numpy"
+ return [x.numpy() for x in a]
+
+
+def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
+ """Convert a list of numpy array to a list of TVM NDArray
Parameters
----------
- original_mod : IRModule
- The original module to be compiled.
- scheduled_mod : IRModule
- The scheduled module to be compiled.
- baseline_target : Target
- The baseline target to compile the original module.
- target : Target
- The target to compile the scheduled module.
- dev_type : str
- The device type to run the module via rpc.
- rpc_config : RPCConfig
- The RPCConfig to run the scheduled module.
- f_input_generator : Union[str, Callable]
- The function to generate the input data.
- f_check_metric : Union[str, Callable]
- The function to check the metric.
+ a : List[np.ndarray]
+ The list of numpy array to be converted.
Returns
-------
- result : bool
- The result of the validation.
+ b : List[tvm.nd.NDArray]
+ The list of TVM NDArray.
"""
+ assert a is not None, "Empty result cannot be converted to TVM NDArray"
+ return [tvm.nd.array(x) for x in a]
- def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
- """Convert a list of TVM NDArray to a list of numpy array"""
- assert a is not None, "Empty result cannot be converted to numpy"
- return [x.numpy() for x in a]
-
- def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
- """Convert a list of numpy array to a list of TVM NDArray"""
- assert a is not None, "Empty result cannot be converted to TVM NDArray"
- return [tvm.nd.array(x) for x in a]
-
- def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray:
- """Build and run the module on the target device."""
- rt_mod = tvm.build(mod, target=target)
- return run_module_via_rpc(
- rpc_config=rpc_config,
- lib=rt_mod,
- dev_type=dev_type,
- args={i: v for i, v in enumerate(inputs)}, # pylint: disable=unnecessary-comprehension
- continuation=create_calculator(backend="tir"),
- backend="tir",
- )
- # fetch functions & prepare inputs
- if isinstance(f_input_generator, str):
- f_input_generator = get_global_func(f_input_generator)
- if isinstance(f_check_metric, str):
- f_check_metric = get_global_func(f_check_metric)
- inputs = to_numpy(f_input_generator(original_mod)) # type: ignore
- # build & run original result
- original_res = to_numpy(build_and_run(original_mod, target=baseline_target, dev_type="cpu"))
- scheduled_res = to_numpy(build_and_run(scheduled_mod, target=target, dev_type=dev_type))
- # check metric
- if f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)): # type: ignore
- return True
- else:
- print(
- ("\n\n").join(
+def is_failed_record(record: ms.database.TuningRecord) -> bool:
+ """Check if a tuning record is failed.
+
+ Parameters
+ ----------
+ record : TuningRecord
+ The tuning record to check.
+
+ Returns
+ -------
+ is_failed : bool
+ """
+ return len(record.run_secs) == 1 and record.run_secs[0] == 1e9
+
+
+def print_with_counter_func(counter: int, total: int) -> Callable:
+ """Print with counter
+
+ Parameters
+ ----------
+ counter : int
+ The counter to print with.
+ total : int
+ The total number of items to print with.
+
+ Returns
+ -------
+ print_result : Callable
+ The print result function.
+ """
+
+ def print_result(
+ result: str,
+ *,
+ original_mod: IRModule = None,
+ scheduled_mod: IRModule = None,
+ inputs: List[np.ndarray] = None,
+ original_res: List[np.ndarray] = None,
+ scheduled_res: List[np.ndarray] = None,
+ original_run_secs: List[float] = None,
+ scheduled_run_secs: List[float] = None,
+ exception: Exception = None,
+ trace: str = None,
+ ) -> None:
+ """Print the validation result."""
+ status = f"Progress {counter: 6d} / {total: 6d} (estimated) checked, result: {result:>10}, "
+
+ if result in ["pass", "wrong answer"]:
+ status += (
+ f"original: {mean(original_run_secs) * 1e3: 10.3f} ms, "
+ f"scheduled: {mean(scheduled_run_secs) * 1e3: 10.3f} ms"
+ )
+
+ output = [status]
+ if result not in ["pass", "skip"]:
+ output.extend(
[
- "Validation failed!",
- "Original Result:" + DELIMITOR + str(original_res),
- "Scheduled Result:" + DELIMITOR + str(scheduled_res),
- "Input:" + DELIMITOR + str(inputs),
"Original IRModule:" + DELIMITOR + original_mod.script(),
"Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(),
+ "Trace" + DELIMITOR + str(trace),
]
)
+ if result == "wrong answer":
+ output.extend(
+ [
+ "Input:" + DELIMITOR + str(inputs),
+ "Original Result:" + DELIMITOR + str(original_res),
+ "Scheduled Result:" + DELIMITOR + str(scheduled_res),
+ "Max Diff:"
+ + DELIMITOR
+ + str(
+ [
+ np.max(np.abs(original_res[i] - scheduled_res[i]))
+ for i in range(len(original_res))
+ ]
+ )
+ + "\n",
+ ]
+ )
+ elif result == "exception":
+ output.extend(["Exception:" + DELIMITOR + str(exception) + "\n"])
+ else:
+ raise ValueError(f"Unknown result: {result}")
+ print("\n\n".join(output))
+
+ return print_result
+
+
+def make_alloc_arg_and_check(
+ inputs: List[np.ndarray],
+ original_mod: IRModule,
+ scheduled_mod: IRModule,
+ trace: str,
+ original_res: List[np.ndarray],
+ original_run_secs: List[float],
+ print_result: Callable,
+) -> Tuple[Callable, Callable]:
+ """Make alloc_arg and check functions for the given inputs and collect results.
+
+ Parameters
+ ----------
+ inputs : List[np.ndarray]
+ The inputs to the two modules.
+ original_mod : IRModule
+ The original IRModule.
+ scheduled_mod : IRModule
+ The scheduled IRModule.
+ trace : str
+ The trace of the scheduled IRModule.
+ original_res : List[np.ndarray]
+ The original results.
+ original_run_secs : List[float]
+ The original run times.
+ print_result : Callable
+ The print result function.
+
+ Returns
+ -------
+ f_with_args_alloc_argument : Callable
+ The function to allocate arguments.
+
+ f_with_args_run_evaluator : Callable
+ The function to run evaluator.
+ """
+
+ def f_with_args_alloc_argument(
+ # pylint: disable=unused-argument
+ session: tvm.rpc.RPCSession,
+ device: tvm.runtime.Device,
+ args_info: ms.runner.rpc_runner.T_ARG_INFO_JSON_OBJ_LIST,
+ alloc_repeat: int,
+ # pylint: enable=unused-argument
+ ) -> List[ms.runner.rpc_runner.T_ARGUMENT_LIST]:
+ """Allocate arguments using the given inputs.
+
+ Parameters
+ ----------
+ session : RPCSession
+ The RPC session.
+ device : Device
+ The device.
+ args_info : T_ARG_INFO_JSON_OBJ_LIST
+ argument information.
+ alloc_repeat : int
+ The number of times to repeat the allocation.
+
+ Returns
+ -------
+ args_list : List[T_ARGUMENT_LIST]
+ The list of argument lists.
+ """
+ return [[tvm.nd.array(arg, device=device) for arg in inputs] for _ in range(alloc_repeat)]
+
+ def f_with_args_run_evaluator(
+ session: tvm.rpc.RPCSession, # pylint: disable=unused-argument
+ rt_mod: tvm.runtime.Module,
+ device: tvm.runtime.Device,
+ evaluator_config: ms.runner.EvaluatorConfig,
+ repeated_args: List[ms.runner.rpc_runner.T_ARGUMENT_LIST],
+ ) -> List[float]:
+ """With args function to run the evaluator
+
+ Parameters
+ ----------
+ session : tvm.rpc.RPCSession
+ The RPC session
+ rt_mod: Module
+ The runtime module
+ device: Device
+ The device to run the evaluator
+ evaluator_config: EvaluatorConfig
+ The evaluator config
+ repeated_args: List[T_ARGUMENT_LIST]
+ The repeated arguments
+
+ Returns
+ -------
+ costs: List[float]
+ The evaluator results
+ """
+ evaluator = rt_mod.time_evaluator(
+ func_name=rt_mod.entry_name,
+ dev=device,
+ number=evaluator_config.number,
+ repeat=evaluator_config.repeat,
+ min_repeat_ms=evaluator_config.min_repeat_ms,
+ f_preproc="cache_flush_cpu_non_first_arg"
+ if evaluator_config.enable_cpu_cache_flush
+ else "",
+ )
+
+ repeated_costs: List[List[float]] = []
+ for args in repeated_args:
+ device.sync()
+ profile_result = evaluator(*args)
+ repeated_costs.append(profile_result.results)
+ costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
+
+ assert len(repeated_args) == 1, "Only support one set of arguments"
+ scheduled_res = [arg.numpy() for arg in repeated_args[0]] # type: ignore
+ # fetch comparison function
+ passed = check_and_run(
+ ARGS.check_metric_func,
+ to_tvm_ndarray(original_res),
+ to_tvm_ndarray(scheduled_res),
+ )
+
+ print_result(
+ result="pass" if passed else "wrong answer",
+ original_mod=original_mod,
+ scheduled_mod=scheduled_mod,
+ trace=trace,
+ inputs=inputs,
+ original_res=original_res,
+ scheduled_res=scheduled_res,
+ original_run_secs=original_run_secs,
+ scheduled_run_secs=costs,
)
- return False
+
+ return costs
+
+ return f_with_args_alloc_argument, f_with_args_run_evaluator
+
+
+def local_build_and_run(
+ mod: IRModule,
+ target: Target,
+ device: tvm.runtime.Device,
+ inputs: List[np.ndarray],
+) -> Tuple[List[np.ndarray], List[float]]:
+ """Build and run the module locally.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The module to build and run
+ target: Target
+ The target to build the module
+ device: Device
+ The device to run the module
+ inputs: List[np.ndarray]
+ The inputs to run the module
+
+ Returns
+ -------
+ res: List[np.ndarray]
+ The results of running the module
+ run_secs: List[float]
+ The running time of running the module
+ """
+ # potential memory leak https://github.com/apache/tvm/issues/11096
+ lib = tvm.build(mod, target=target)
+ tvm_inputs = [tvm.nd.array(inp, device=device) for inp in inputs]
+ device.sync()
+ func = lib.time_evaluator(lib.entry_name, dev=device, number=ARGS.number, repeat=ARGS.repeat)
+ benchmark_res = func(*tvm_inputs)
+ device.sync()
+ return [arg.numpy() for arg in tvm_inputs], list(benchmark_res.results)
+
+
+def _check_builder_result(builder_result: ms.builder.BuilderResult) -> None:
+ """Check if the builder result is defined.
+
+ Parameters
+ ----------
+ builder_result: BuilderResult
+ The builder result
+ """
+ assert builder_result.error_msg is None, "Builder failed: " + str(
+ builder_result.error_msg if builder_result.error_msg else "Empty error message"
+ )
+
+
+def _apply_trace(mod: IRModule, trace: Trace) -> IRModule:
+ """Apply the trace to the module.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The module to apply the trace to
+ trace: Trace
+ The trace to apply
+
+ Returns
+ -------
+ mod: IRModule
+ The module with the trace applied
+ """
+ sch = Schedule(mod)
+ trace.apply_to_schedule(sch, remove_postproc=False)
+ return sch.mod
+
+
+def _build_all_mods(
+ mods: List[IRModule], builder: ms.builder.Builder, target: Target
+) -> List[ms.builder.BuilderResult]:
+ """Build all the modules.
+
+ Parameters
+ ----------
+ mods: List[IRModule]
+ The modules to build
+ builder: Builder
+ The builder to build the modules
+ target: Target
+ The target to build the modules
+
+ Returns
+ -------
+ builder_results: List[BuilderResult]
+ The builder results
+ """
+ builder_results = builder.build([ms.builder.BuilderInput(mod, target) for mod in mods])
+ assert len(builder_results) == len(
+ mods
+ ), f"Unexpected number of build results, expected {len(mods)} got {len(builder_results)}"
+ return builder_results
+
+
+def _run_single_mod(
+ builder_result: ms.builder.BuilderResult,
+ runner: ms.runner.Runner,
+ dev_type: str,
+) -> None:
+ """Run a single module.
+
+ Parameters
+ ----------
+ builder_result: BuilderResult
+ The builder result
+ runner: Runner
+ The runner to run the module
+ dev_type: str
+ The device type
+ """
+ runner_futures = runner.run(
+ # arginfo is not used in this case so we can pass an empty list
+ [ms.runner.RunnerInput(builder_result.artifact_path, device_type=dev_type, args_info=[])]
+ )
+ assert (
+ len(runner_futures) == 1
+ ), f"Unexpected number of runner futures, expected 1 got {len(runner_futures)}"
+ (runner_future,) = runner_futures # pylint: disable=unbalanced-tuple-unpacking
+ runner_res = runner_future.result()
+ assert runner_res.error_msg is None, "Runner failed: " + (
+ runner_res.error_msg if runner_res.error_msg else "Empty error message"
+ )
def main():
"""Main function"""
describe()
- database = ms.database.create(work_dir=ARGS.work_dir)
- target = ARGS.target
- if target.kind.name == "llvm":
- dev_type = "cpu"
- elif target.kind.name == "cuda":
- dev_type = "cuda"
- else:
- raise RuntimeError(f"Unsupported target kind: {target.kind.name}")
- records = database.get_all_tuning_records()
with ms.Profiler() as profiler:
- for i, record in enumerate(records):
- scope_name = f"validate #{i}"
- with profiler.timeit(scope_name):
- original_mod = record.workload.mod
- sch = Schedule(original_mod)
- record.trace.apply_to_schedule(sch=sch, remove_postproc=False)
- scheduled_mod = sch.mod
- is_success = False
+ # initialize
+ target = ARGS.target
+ dev_type = get_device_type(target)
+ builder = ms.builder.LocalBuilder()
+ database = ms.database.create(work_dir=ARGS.work_dir)
+
+ # collect records
+ with profiler.timeit("collect records"):
+ records = database.get_all_tuning_records()
+ total = len(records)
+ print(
+ f"Total {total} records to be validated. "
+ f"Collected in {float(profiler.get()['collect records']): 3.3f} sec."
+ )
+
+ # collect unique original TIR
+ with profiler.timeit("deduplicate records"):
+ workloads = set()
+ for record in records:
+ workloads.add(OriginalModule(record.workload.mod))
+ print(
+ f"Total {len(workloads)} unique original TIR to validate. "
+ f"Deduplicated in {float(profiler.get()['deduplicate records']): 3.3f} sec."
+ )
+ if ARGS.top_k < 10**9:
+ print(f"Top {ARGS.top_k} records for each original TIR will be validated.")
+ total = len(workloads) * ARGS.top_k
+ print()
+
+ # validate correctness
+ counter = 0
+ for item in workloads:
+ original_mod = item.mod
+ records = database.get_top_k(
+ workload=database.commit_workload(original_mod), top_k=ARGS.top_k
+ )
+ if len(records) < ARGS.top_k:
+ total -= ARGS.top_k - len(records)
+ inputs = to_numpy(check_and_run(ARGS.input_generator_func, original_mod))
+ original_res, original_run_secs = local_build_and_run(
+ original_mod,
+ target=ARGS.baseline_target,
+ inputs=inputs,
+ device=get_runtime_device(ARGS.baseline_target),
+ )
+ scheduled_mods = [_apply_trace(original_mod, record.trace) for record in records]
+ builder_results = _build_all_mods(scheduled_mods, builder, target) # type: ignore
+ for i, record in enumerate(records):
+ counter += 1
+ print_result = print_with_counter_func(counter=counter, total=total)
+ if is_failed_record(record):
+ # skip failed records where run_secs is 1e9
+ # these records are only negative samples for cost model
+ print_result(result="skip")
+ continue
try:
- is_success = validate_correctness(
- original_mod=original_mod,
- scheduled_mod=scheduled_mod,
- target=target,
- baseline_target=ARGS.baseline_target,
- dev_type=dev_type,
+ # prepare scheduled module
+ scheduled_mod = scheduled_mods[i]
+ # check build result
+ builder_result = builder_results[i]
+ _check_builder_result(builder_result)
+ # fetch functions
+ (
+ f_with_args_alloc_argument,
+ f_with_args_run_evaluator,
+ ) = make_alloc_arg_and_check(
+ inputs,
+ original_mod,
+ scheduled_mod,
+ str(record.trace),
+ original_res=original_res,
+ original_run_secs=original_run_secs,
+ print_result=print_result,
+ )
+ # create rpc runner
+ runner = ms.runner.RPCRunner(
Review Comment:
I'll emit a warning, though, to make sure the user is aware of that.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
[GitHub] [tvm] vinx13 merged pull request #13459: [MetaSchedule] Enhance Database Validation Script
Posted by GitBox <gi...@apache.org>.
vinx13 merged PR #13459:
URL: https://github.com/apache/tvm/pull/13459
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
[GitHub] [tvm] yelite commented on a diff in pull request #13459: [MetaSchedule] Enhance Database Validation Script
Posted by GitBox <gi...@apache.org>.
yelite commented on code in PR #13459:
URL: https://github.com/apache/tvm/pull/13459#discussion_r1028648334
##########
python/tvm/meta_schedule/testing/validate_database.py:
##########
@@ -124,158 +252,495 @@ def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]:
return inputs
-@register_func("tvm.meta_schedule.testing.default_check_metric")
-def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bool:
- assert len(a) == len(b), "Different number of outputs from two modules"
- for i, _ in enumerate(a):
- if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3):
- return False
- return True
+def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
+ """Convert a list of TVM NDArray to a list of numpy array
+ Parameters
+ ----------
+ a : List[tvm.nd.NDArray]
+ The list of TVM NDArray to be converted
-def validate_correctness(
- original_mod: IRModule, # compiled for "baseline_target"
- scheduled_mod: IRModule, # compiled for "target"
- *,
- baseline_target: Target,
- target: Target,
- dev_type: str,
- rpc_config: ms.runner.RPCConfig,
- f_input_generator: Union[
- str, Callable[[IRModule], List[tvm.nd.NDArray]]
- ] = default_input_generator,
- f_check_metric: Union[
- str, Callable[[tvm.nd.NDArray, tvm.nd.NDArray], bool]
- ] = default_check_metric,
-) -> bool:
- """Function to validate the correctness of a scheduled module.
+ Returns
+ -------
+ b : List[np.ndarray]
+ The list of numpy array
+ """
+ assert a is not None, "Empty result cannot be converted to numpy"
+ return [x.numpy() for x in a]
+
+
+def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
+ """Convert a list of numpy array to a list of TVM NDArray
Parameters
----------
- original_mod : IRModule
- The original module to be compiled.
- scheduled_mod : IRModule
- The scheduled module to be compiled.
- baseline_target : Target
- The baseline target to compile the original module.
- target : Target
- The target to compile the scheduled module.
- dev_type : str
- The device type to run the module via rpc.
- rpc_config : RPCConfig
- The RPCConfig to run the scheduled module.
- f_input_generator : Union[str, Callable]
- The function to generate the input data.
- f_check_metric : Union[str, Callable]
- The function to check the metric.
+ a : List[np.ndarray]
+ The list of numpy array to be converted.
Returns
-------
- result : bool
- The result of the validation.
+ b : List[tvm.nd.NDArray]
+ The list of TVM NDArray.
"""
+ assert a is not None, "Empty result cannot be converted to TVM NDArray"
+ return [tvm.nd.array(x) for x in a]
- def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
- """Convert a list of TVM NDArray to a list of numpy array"""
- assert a is not None, "Empty result cannot be converted to numpy"
- return [x.numpy() for x in a]
-
- def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
- """Convert a list of numpy array to a list of TVM NDArray"""
- assert a is not None, "Empty result cannot be converted to TVM NDArray"
- return [tvm.nd.array(x) for x in a]
-
- def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray:
- """Build and run the module on the target device."""
- rt_mod = tvm.build(mod, target=target)
- return run_module_via_rpc(
- rpc_config=rpc_config,
- lib=rt_mod,
- dev_type=dev_type,
- args={i: v for i, v in enumerate(inputs)}, # pylint: disable=unnecessary-comprehension
- continuation=create_calculator(backend="tir"),
- backend="tir",
- )
- # fetch functions & prepare inputs
- if isinstance(f_input_generator, str):
- f_input_generator = get_global_func(f_input_generator)
- if isinstance(f_check_metric, str):
- f_check_metric = get_global_func(f_check_metric)
- inputs = to_numpy(f_input_generator(original_mod)) # type: ignore
- # build & run original result
- original_res = to_numpy(build_and_run(original_mod, target=baseline_target, dev_type="cpu"))
- scheduled_res = to_numpy(build_and_run(scheduled_mod, target=target, dev_type=dev_type))
- # check metric
- if f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)): # type: ignore
- return True
- else:
- print(
- ("\n\n").join(
+def is_failed_record(record: ms.database.TuningRecord) -> bool:
+ """Check if a tuning record is failed.
+
+ Parameters
+ ----------
+ record : TuningRecord
+ The tuning record to check.
+
+ Returns
+ -------
+ is_failed : bool
+ """
+ return len(record.run_secs) == 1 and record.run_secs[0] == 1e9
+
+
+def print_with_counter_func(counter: int, total: int) -> Callable:
+ """Print with counter
+
+ Parameters
+ ----------
+ counter : int
+ The counter to print with.
+ total : int
+ The total number of items to print with.
+
+ Returns
+ -------
+ print_result : Callable
+ The print result function.
+ """
+
+ def print_result(
+ result: str,
+ *,
+ original_mod: IRModule = None,
+ scheduled_mod: IRModule = None,
+ inputs: List[np.ndarray] = None,
+ original_res: List[np.ndarray] = None,
+ scheduled_res: List[np.ndarray] = None,
+ original_run_secs: List[float] = None,
+ scheduled_run_secs: List[float] = None,
+ exception: Exception = None,
+ trace: str = None,
+ ) -> None:
+ """Print the validation result."""
+ status = f"Progress {counter: 6d} / {total: 6d} (estimated) checked, result: {result:>10}, "
+
+ if result in ["pass", "wrong answer"]:
+ status += (
+ f"original: {mean(original_run_secs) * 1e3: 10.3f} ms, "
+ f"scheduled: {mean(scheduled_run_secs) * 1e3: 10.3f} ms"
+ )
+
+ output = [status]
+ if result not in ["pass", "skip"]:
+ output.extend(
[
- "Validation failed!",
- "Original Result:" + DELIMITOR + str(original_res),
- "Scheduled Result:" + DELIMITOR + str(scheduled_res),
- "Input:" + DELIMITOR + str(inputs),
"Original IRModule:" + DELIMITOR + original_mod.script(),
"Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(),
+ "Trace" + DELIMITOR + str(trace),
]
)
+ if result == "wrong answer":
+ output.extend(
+ [
+ "Input:" + DELIMITOR + str(inputs),
+ "Original Result:" + DELIMITOR + str(original_res),
+ "Scheduled Result:" + DELIMITOR + str(scheduled_res),
+ "Max Diff:"
+ + DELIMITOR
+ + str(
+ [
+ np.max(np.abs(original_res[i] - scheduled_res[i]))
+ for i in range(len(original_res))
+ ]
+ )
+ + "\n",
+ ]
+ )
+ elif result == "exception":
+ output.extend(["Exception:" + DELIMITOR + str(exception) + "\n"])
+ else:
+ raise ValueError(f"Unknown result: {result}")
+ print("\n\n".join(output))
+
+ return print_result
+
+
+def make_alloc_arg_and_check(
+ inputs: List[np.ndarray],
+ original_mod: IRModule,
+ scheduled_mod: IRModule,
+ trace: str,
+ original_res: List[np.ndarray],
+ original_run_secs: List[float],
+ print_result: Callable,
+) -> Tuple[Callable, Callable]:
+ """Make alloc_arg and check functions for the given inputs and collect results.
+
+ Parameters
+ ----------
+ inputs : List[np.ndarray]
+ The inputs to the two modules.
+ original_mod : IRModule
+ The original IRModule.
+ scheduled_mod : IRModule
+ The scheduled IRModule.
+ trace : str
+ The trace of the scheduled IRModule.
+ original_res : List[np.ndarray]
+ The original results.
+ original_run_secs : List[float]
+ The original run times.
+ print_result : Callable
+ The print result function.
+
+ Returns
+ -------
+ f_with_args_alloc_argument : Callable
+ The function to allocate arguments.
+
+ f_with_args_run_evaluator : Callable
+ The function to run evaluator.
+ """
+
+ def f_with_args_alloc_argument(
+ # pylint: disable=unused-argument
+ session: tvm.rpc.RPCSession,
+ device: tvm.runtime.Device,
+ args_info: ms.runner.rpc_runner.T_ARG_INFO_JSON_OBJ_LIST,
+ alloc_repeat: int,
+ # pylint: enable=unused-argument
+ ) -> List[ms.runner.rpc_runner.T_ARGUMENT_LIST]:
+ """Allocate arguments using the given inputs.
+
+ Parameters
+ ----------
+ session : RPCSession
+ The RPC session.
+ device : Device
+ The device.
+ args_info : T_ARG_INFO_JSON_OBJ_LIST
+ argument information.
+ alloc_repeat : int
+ The number of times to repeat the allocation.
+
+ Returns
+ -------
+ args_list : List[T_ARGUMENT_LIST]
+ The list of argument lists.
+ """
+ return [[tvm.nd.array(arg, device=device) for arg in inputs] for _ in range(alloc_repeat)]
+
+ def f_with_args_run_evaluator(
+ session: tvm.rpc.RPCSession, # pylint: disable=unused-argument
+ rt_mod: tvm.runtime.Module,
+ device: tvm.runtime.Device,
+ evaluator_config: ms.runner.EvaluatorConfig,
+ repeated_args: List[ms.runner.rpc_runner.T_ARGUMENT_LIST],
+ ) -> List[float]:
+ """With args function to run the evaluator
+
+ Parameters
+ ----------
+ session : tvm.rpc.RPCSession
+ The RPC session
+ rt_mod: Module
+ The runtime module
+ device: Device
+ The device to run the evaluator
+ evaluator_config: EvaluatorConfig
+ The evaluator config
+ repeated_args: List[T_ARGUMENT_LIST]
+ The repeated arguments
+
+ Returns
+ -------
+ costs: List[float]
+ The evaluator results
+ """
+ evaluator = rt_mod.time_evaluator(
+ func_name=rt_mod.entry_name,
+ dev=device,
+ number=evaluator_config.number,
+ repeat=evaluator_config.repeat,
+ min_repeat_ms=evaluator_config.min_repeat_ms,
+ f_preproc="cache_flush_cpu_non_first_arg"
+ if evaluator_config.enable_cpu_cache_flush
+ else "",
+ )
+
+ repeated_costs: List[List[float]] = []
+ for args in repeated_args:
+ device.sync()
+ profile_result = evaluator(*args)
+ repeated_costs.append(profile_result.results)
+ costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
+
+ assert len(repeated_args) == 1, "Only support one set of arguments"
+ scheduled_res = [arg.numpy() for arg in repeated_args[0]] # type: ignore
+ # fetch comparison function
+ passed = check_and_run(
+ ARGS.check_metric_func,
+ to_tvm_ndarray(original_res),
+ to_tvm_ndarray(scheduled_res),
+ )
+
+ print_result(
+ result="pass" if passed else "wrong answer",
+ original_mod=original_mod,
+ scheduled_mod=scheduled_mod,
+ trace=trace,
+ inputs=inputs,
+ original_res=original_res,
+ scheduled_res=scheduled_res,
+ original_run_secs=original_run_secs,
+ scheduled_run_secs=costs,
)
- return False
+
+ return costs
+
+ return f_with_args_alloc_argument, f_with_args_run_evaluator
+
+
+def local_build_and_run(
+ mod: IRModule,
+ target: Target,
+ device: tvm.runtime.Device,
+ inputs: List[np.ndarray],
+) -> Tuple[List[np.ndarray], List[float]]:
+ """Build and run the module locally.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The module to build and run
+ target: Target
+ The target to build the module
+ device: Device
+ The device to run the module
+ inputs: List[np.ndarray]
+ The inputs to run the module
+
+ Returns
+ -------
+ res: List[np.ndarray]
+ The results of running the module
+ run_secs: List[float]
+ The running time of running the module
+ """
+ # potential memory leak https://github.com/apache/tvm/issues/11096
+ lib = tvm.build(mod, target=target)
+ tvm_inputs = [tvm.nd.array(inp, device=device) for inp in inputs]
+ device.sync()
+ func = lib.time_evaluator(lib.entry_name, dev=device, number=ARGS.number, repeat=ARGS.repeat)
+ benchmark_res = func(*tvm_inputs)
+ device.sync()
+ return [arg.numpy() for arg in tvm_inputs], list(benchmark_res.results)
+
+
+def _check_builder_result(builder_result: ms.builder.BuilderResult) -> None:
+ """Check if the builder result is defined.
+
+ Parameters
+ ----------
+ builder_result: BuilderResult
+ The builder result
+ """
+ assert builder_result.error_msg is None, "Builder failed: " + str(
+ builder_result.error_msg if builder_result.error_msg else "Empty error message"
+ )
+
+
+def _apply_trace(mod: IRModule, trace: Trace) -> IRModule:
+ """Apply the trace to the module.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The module to apply the trace to
+ trace: Trace
+ The trace to apply
+
+ Returns
+ -------
+ mod: IRModule
+ The module with the trace applied
+ """
+ sch = Schedule(mod)
+ trace.apply_to_schedule(sch, remove_postproc=False)
+ return sch.mod
+
+
+def _build_all_mods(
+ mods: List[IRModule], builder: ms.builder.Builder, target: Target
+) -> List[ms.builder.BuilderResult]:
+ """Build all the modules.
+
+ Parameters
+ ----------
+ mods: List[IRModule]
+ The modules to build
+ builder: Builder
+ The builder to build the modules
+ target: Target
+ The target to build the modules
+
+ Returns
+ -------
+ builder_results: List[BuilderResult]
+ The builder results
+ """
+ builder_results = builder.build([ms.builder.BuilderInput(mod, target) for mod in mods])
+ assert len(builder_results) == len(
+ mods
+ ), f"Unexpected number of build results, expected {len(mods)} got {len(builder_results)}"
+ return builder_results
+
+
+def _run_single_mod(
+ builder_result: ms.builder.BuilderResult,
+ runner: ms.runner.Runner,
+ dev_type: str,
+) -> None:
+ """Run a single module.
+
+ Parameters
+ ----------
+ builder_result: BuilderResult
+ The builder result
+ runner: Runner
+ The runner to run the module
+ dev_type: str
+ The device type
+ """
+ runner_futures = runner.run(
+ # arginfo is not used in this case so we can pass an empty list
+ [ms.runner.RunnerInput(builder_result.artifact_path, device_type=dev_type, args_info=[])]
+ )
+ assert (
+ len(runner_futures) == 1
+ ), f"Unexpected number of runner futures, expected 1 got {len(runner_futures)}"
+ (runner_future,) = runner_futures # pylint: disable=unbalanced-tuple-unpacking
+ runner_res = runner_future.result()
+ assert runner_res.error_msg is None, "Runner failed: " + (
+ runner_res.error_msg if runner_res.error_msg else "Empty error message"
+ )
def main():
"""Main function"""
describe()
- database = ms.database.create(work_dir=ARGS.work_dir)
- target = ARGS.target
- if target.kind.name == "llvm":
- dev_type = "cpu"
- elif target.kind.name == "cuda":
- dev_type = "cuda"
- else:
- raise RuntimeError(f"Unsupported target kind: {target.kind.name}")
- records = database.get_all_tuning_records()
with ms.Profiler() as profiler:
- for i, record in enumerate(records):
- scope_name = f"validate #{i}"
- with profiler.timeit(scope_name):
- original_mod = record.workload.mod
- sch = Schedule(original_mod)
- record.trace.apply_to_schedule(sch=sch, remove_postproc=False)
- scheduled_mod = sch.mod
- is_success = False
+ # initialize
+ target = ARGS.target
+ dev_type = get_device_type(target)
+ builder = ms.builder.LocalBuilder()
+ database = ms.database.create(work_dir=ARGS.work_dir)
+
+ # collect records
+ with profiler.timeit("collect records"):
+ records = database.get_all_tuning_records()
+ total = len(records)
+ print(
+ f"Total {total} records to be validated. "
+ f"Collected in {float(profiler.get()['collect records']): 3.3f} sec."
+ )
+
+ # collect unique original TIR
+ with profiler.timeit("deduplicate records"):
+ workloads = set()
+ for record in records:
+ workloads.add(OriginalModule(record.workload.mod))
+ print(
+ f"Total {len(workloads)} unique original TIR to validate. "
+ f"Deduplicated in {float(profiler.get()['deduplicate records']): 3.3f} sec."
+ )
+ if ARGS.top_k < 10**9:
+ print(f"Top {ARGS.top_k} records for each original TIR will be validated.")
+ total = len(workloads) * ARGS.top_k
+ print()
+
+ # validate correctness
+ counter = 0
+ for item in workloads:
+ original_mod = item.mod
+ records = database.get_top_k(
+ workload=database.commit_workload(original_mod), top_k=ARGS.top_k
+ )
+ if len(records) < ARGS.top_k:
+ total -= ARGS.top_k - len(records)
+ inputs = to_numpy(check_and_run(ARGS.input_generator_func, original_mod))
+ original_res, original_run_secs = local_build_and_run(
+ original_mod,
+ target=ARGS.baseline_target,
+ inputs=inputs,
+ device=get_runtime_device(ARGS.baseline_target),
+ )
+ scheduled_mods = [_apply_trace(original_mod, record.trace) for record in records]
+ builder_results = _build_all_mods(scheduled_mods, builder, target) # type: ignore
+ for i, record in enumerate(records):
+ counter += 1
+ print_result = print_with_counter_func(counter=counter, total=total)
+ if is_failed_record(record):
+ # skip failed records where run_secs is 1e9
+ # these records are only negative samples for cost model
+ print_result(result="skip")
+ continue
try:
- is_success = validate_correctness(
- original_mod=original_mod,
- scheduled_mod=scheduled_mod,
- target=target,
- baseline_target=ARGS.baseline_target,
- dev_type=dev_type,
+ # prepare scheduled module
+ scheduled_mod = scheduled_mods[i]
+ # check build result
+ builder_result = builder_results[i]
+ _check_builder_result(builder_result)
+ # fetch functions
+ (
+ f_with_args_alloc_argument,
+ f_with_args_run_evaluator,
+ ) = make_alloc_arg_and_check(
+ inputs,
+ original_mod,
+ scheduled_mod,
+ str(record.trace),
+ original_res=original_res,
+ original_run_secs=original_run_secs,
+ print_result=print_result,
+ )
+ # create rpc runner
+ runner = ms.runner.RPCRunner(
Review Comment:
Would it be useful to make RPC optional, i.e., fall back to LocalRunner when no RPC-related arguments are provided?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
[GitHub] [tvm] zxybazh commented on a diff in pull request #13459: [MetaSchedule] Enhance Database Validation Script
Posted by GitBox <gi...@apache.org>.
zxybazh commented on code in PR #13459:
URL: https://github.com/apache/tvm/pull/13459#discussion_r1028656239
##########
python/tvm/meta_schedule/testing/validate_database.py:
##########
@@ -124,158 +252,495 @@ def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]:
return inputs
-@register_func("tvm.meta_schedule.testing.default_check_metric")
-def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bool:
- assert len(a) == len(b), "Different number of outputs from two modules"
- for i, _ in enumerate(a):
- if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3):
- return False
- return True
+def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
+ """Convert a list of TVM NDArray to a list of numpy array
+ Parameters
+ ----------
+ a : List[tvm.nd.NDArray]
+ The list of TVM NDArray to be converted
-def validate_correctness(
- original_mod: IRModule, # compiled for "baseline_target"
- scheduled_mod: IRModule, # compiled for "target"
- *,
- baseline_target: Target,
- target: Target,
- dev_type: str,
- rpc_config: ms.runner.RPCConfig,
- f_input_generator: Union[
- str, Callable[[IRModule], List[tvm.nd.NDArray]]
- ] = default_input_generator,
- f_check_metric: Union[
- str, Callable[[tvm.nd.NDArray, tvm.nd.NDArray], bool]
- ] = default_check_metric,
-) -> bool:
- """Function to validate the correctness of a scheduled module.
+ Returns
+ -------
+ b : List[np.ndarray]
+ The list of numpy array
+ """
+ assert a is not None, "Empty result cannot be converted to numpy"
+ return [x.numpy() for x in a]
+
+
+def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
+ """Convert a list of numpy array to a list of TVM NDArray
Parameters
----------
- original_mod : IRModule
- The original module to be compiled.
- scheduled_mod : IRModule
- The scheduled module to be compiled.
- baseline_target : Target
- The baseline target to compile the original module.
- target : Target
- The target to compile the scheduled module.
- dev_type : str
- The device type to run the module via rpc.
- rpc_config : RPCConfig
- The RPCConfig to run the scheduled module.
- f_input_generator : Union[str, Callable]
- The function to generate the input data.
- f_check_metric : Union[str, Callable]
- The function to check the metric.
+ a : List[np.ndarray]
+ The list of numpy array to be converted.
Returns
-------
- result : bool
- The result of the validation.
+ b : List[tvm.nd.NDArray]
+ The list of TVM NDArray.
"""
+ assert a is not None, "Empty result cannot be converted to TVM NDArray"
+ return [tvm.nd.array(x) for x in a]
- def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
- """Convert a list of TVM NDArray to a list of numpy array"""
- assert a is not None, "Empty result cannot be converted to numpy"
- return [x.numpy() for x in a]
-
- def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
- """Convert a list of numpy array to a list of TVM NDArray"""
- assert a is not None, "Empty result cannot be converted to TVM NDArray"
- return [tvm.nd.array(x) for x in a]
-
- def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray:
- """Build and run the module on the target device."""
- rt_mod = tvm.build(mod, target=target)
- return run_module_via_rpc(
- rpc_config=rpc_config,
- lib=rt_mod,
- dev_type=dev_type,
- args={i: v for i, v in enumerate(inputs)}, # pylint: disable=unnecessary-comprehension
- continuation=create_calculator(backend="tir"),
- backend="tir",
- )
- # fetch functions & prepare inputs
- if isinstance(f_input_generator, str):
- f_input_generator = get_global_func(f_input_generator)
- if isinstance(f_check_metric, str):
- f_check_metric = get_global_func(f_check_metric)
- inputs = to_numpy(f_input_generator(original_mod)) # type: ignore
- # build & run original result
- original_res = to_numpy(build_and_run(original_mod, target=baseline_target, dev_type="cpu"))
- scheduled_res = to_numpy(build_and_run(scheduled_mod, target=target, dev_type=dev_type))
- # check metric
- if f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)): # type: ignore
- return True
- else:
- print(
- ("\n\n").join(
+def is_failed_record(record: ms.database.TuningRecord) -> bool:
+ """Check if a tuning record is failed.
+
+ Parameters
+ ----------
+ record : TuningRecord
+ The tuning record to check.
+
+ Returns
+ -------
+ is_failed : bool
+ """
+ return len(record.run_secs) == 1 and record.run_secs[0] == 1e9
+
+
+def print_with_counter_func(counter: int, total: int) -> Callable:
+ """Print with counter
+
+ Parameters
+ ----------
+ counter : int
+ The counter to print with.
+ total : int
+ The total number of items to print with.
+
+ Returns
+ -------
+ print_result : Callable
+ The print result function.
+ """
+
+ def print_result(
+ result: str,
+ *,
+ original_mod: IRModule = None,
+ scheduled_mod: IRModule = None,
+ inputs: List[np.ndarray] = None,
+ original_res: List[np.ndarray] = None,
+ scheduled_res: List[np.ndarray] = None,
+ original_run_secs: List[float] = None,
+ scheduled_run_secs: List[float] = None,
+ exception: Exception = None,
+ trace: str = None,
+ ) -> None:
+ """Print the validation result."""
+ status = f"Progress {counter: 6d} / {total: 6d} (estimated) checked, result: {result:>10}, "
+
+ if result in ["pass", "wrong answer"]:
+ status += (
+ f"original: {mean(original_run_secs) * 1e3: 10.3f} ms, "
+ f"scheduled: {mean(scheduled_run_secs) * 1e3: 10.3f} ms"
+ )
+
+ output = [status]
+ if result not in ["pass", "skip"]:
+ output.extend(
[
- "Validation failed!",
- "Original Result:" + DELIMITOR + str(original_res),
- "Scheduled Result:" + DELIMITOR + str(scheduled_res),
- "Input:" + DELIMITOR + str(inputs),
"Original IRModule:" + DELIMITOR + original_mod.script(),
"Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(),
+ "Trace" + DELIMITOR + str(trace),
]
)
+ if result == "wrong answer":
+ output.extend(
+ [
+ "Input:" + DELIMITOR + str(inputs),
+ "Original Result:" + DELIMITOR + str(original_res),
+ "Scheduled Result:" + DELIMITOR + str(scheduled_res),
+ "Max Diff:"
+ + DELIMITOR
+ + str(
+ [
+ np.max(np.abs(original_res[i] - scheduled_res[i]))
+ for i in range(len(original_res))
+ ]
+ )
+ + "\n",
+ ]
+ )
+ elif result == "exception":
+ output.extend(["Exception:" + DELIMITOR + str(exception) + "\n"])
+ else:
+ raise ValueError(f"Unknown result: {result}")
+ print("\n\n".join(output))
+
+ return print_result
+
+
+def make_alloc_arg_and_check(
+ inputs: List[np.ndarray],
+ original_mod: IRModule,
+ scheduled_mod: IRModule,
+ trace: str,
+ original_res: List[np.ndarray],
+ original_run_secs: List[float],
+ print_result: Callable,
+) -> Tuple[Callable, Callable]:
+ """Make alloc_arg and check functions for the given inputs and collect results.
+
+ Parameters
+ ----------
+ inputs : List[np.ndarray]
+ The inputs to the two modules.
+ original_mod : IRModule
+ The original IRModule.
+ scheduled_mod : IRModule
+ The scheduled IRModule.
+ trace : str
+ The trace of the scheduled IRModule.
+ original_res : List[np.ndarray]
+ The original results.
+ original_run_secs : List[float]
+ The original run times.
+ print_result : Callable
+ The print result function.
+
+ Returns
+ -------
+ f_with_args_alloc_argument : Callable
+ The function to allocate arguments.
+
+ f_with_args_run_evaluator : Callable
+ The function to run evaluator.
+ """
+
+ def f_with_args_alloc_argument(
+ # pylint: disable=unused-argument
+ session: tvm.rpc.RPCSession,
+ device: tvm.runtime.Device,
+ args_info: ms.runner.rpc_runner.T_ARG_INFO_JSON_OBJ_LIST,
+ alloc_repeat: int,
+ # pylint: enable=unused-argument
+ ) -> List[ms.runner.rpc_runner.T_ARGUMENT_LIST]:
+ """Allocate arguments using the given inputs.
+
+ Parameters
+ ----------
+ session : RPCSession
+ The RPC session.
+ device : Device
+ The device.
+ args_info : T_ARG_INFO_JSON_OBJ_LIST
+ argument information.
+ alloc_repeat : int
+ The number of times to repeat the allocation.
+
+ Returns
+ -------
+ args_list : List[T_ARGUMENT_LIST]
+ The list of argument lists.
+ """
+ return [[tvm.nd.array(arg, device=device) for arg in inputs] for _ in range(alloc_repeat)]
+
+ def f_with_args_run_evaluator(
+ session: tvm.rpc.RPCSession, # pylint: disable=unused-argument
+ rt_mod: tvm.runtime.Module,
+ device: tvm.runtime.Device,
+ evaluator_config: ms.runner.EvaluatorConfig,
+ repeated_args: List[ms.runner.rpc_runner.T_ARGUMENT_LIST],
+ ) -> List[float]:
+ """With args function to run the evaluator
+
+ Parameters
+ ----------
+ session : tvm.rpc.RPCSession
+ The RPC session
+ rt_mod: Module
+ The runtime module
+ device: Device
+ The device to run the evaluator
+ evaluator_config: EvaluatorConfig
+ The evaluator config
+ repeated_args: List[T_ARGUMENT_LIST]
+ The repeated arguments
+
+ Returns
+ -------
+ costs: List[float]
+ The evaluator results
+ """
+ evaluator = rt_mod.time_evaluator(
+ func_name=rt_mod.entry_name,
+ dev=device,
+ number=evaluator_config.number,
+ repeat=evaluator_config.repeat,
+ min_repeat_ms=evaluator_config.min_repeat_ms,
+ f_preproc="cache_flush_cpu_non_first_arg"
+ if evaluator_config.enable_cpu_cache_flush
+ else "",
+ )
+
+ repeated_costs: List[List[float]] = []
+ for args in repeated_args:
+ device.sync()
+ profile_result = evaluator(*args)
+ repeated_costs.append(profile_result.results)
+ costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
+
+ assert len(repeated_args) == 1, "Only support one set of arguments"
+ scheduled_res = [arg.numpy() for arg in repeated_args[0]] # type: ignore
+ # fetch comparison function
+ passed = check_and_run(
+ ARGS.check_metric_func,
+ to_tvm_ndarray(original_res),
+ to_tvm_ndarray(scheduled_res),
+ )
+
+ print_result(
+ result="pass" if passed else "wrong answer",
+ original_mod=original_mod,
+ scheduled_mod=scheduled_mod,
+ trace=trace,
+ inputs=inputs,
+ original_res=original_res,
+ scheduled_res=scheduled_res,
+ original_run_secs=original_run_secs,
+ scheduled_run_secs=costs,
)
- return False
+
+ return costs
+
+ return f_with_args_alloc_argument, f_with_args_run_evaluator
+
+
+def local_build_and_run(
+ mod: IRModule,
+ target: Target,
+ device: tvm.runtime.Device,
+ inputs: List[np.ndarray],
+) -> Tuple[List[np.ndarray], List[float]]:
+ """Build and run the module locally.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The module to build and run
+ target: Target
+ The target to build the module
+ device: Device
+ The device to run the module
+ inputs: List[np.ndarray]
+ The inputs to run the module
+
+ Returns
+ -------
+ res: List[np.ndarray]
+ The results of running the module
+ run_secs: List[float]
+ The running time of running the module
+ """
+ # potential memory leak https://github.com/apache/tvm/issues/11096
+ lib = tvm.build(mod, target=target)
+ tvm_inputs = [tvm.nd.array(inp, device=device) for inp in inputs]
+ device.sync()
+ func = lib.time_evaluator(lib.entry_name, dev=device, number=ARGS.number, repeat=ARGS.repeat)
+ benchmark_res = func(*tvm_inputs)
+ device.sync()
+ return [arg.numpy() for arg in tvm_inputs], list(benchmark_res.results)
+
+
+def _check_builder_result(builder_result: ms.builder.BuilderResult) -> None:
+ """Check if the builder result is defined.
+
+ Parameters
+ ----------
+ builder_result: BuilderResult
+ The builder result
+ """
+ assert builder_result.error_msg is None, "Builder failed: " + str(
+ builder_result.error_msg if builder_result.error_msg else "Empty error message"
+ )
+
+
+def _apply_trace(mod: IRModule, trace: Trace) -> IRModule:
+ """Apply the trace to the module.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The module to apply the trace to
+ trace: Trace
+ The trace to apply
+
+ Returns
+ -------
+ mod: IRModule
+ The module with the trace applied
+ """
+ sch = Schedule(mod)
+ trace.apply_to_schedule(sch, remove_postproc=False)
+ return sch.mod
+
+
+def _build_all_mods(
+ mods: List[IRModule], builder: ms.builder.Builder, target: Target
+) -> List[ms.builder.BuilderResult]:
+ """Build all the modules.
+
+ Parameters
+ ----------
+ mods: List[IRModule]
+ The modules to build
+ builder: Builder
+ The builder to build the modules
+ target: Target
+ The target to build the modules
+
+ Returns
+ -------
+ builder_results: List[BuilderResult]
+ The builder results
+ """
+ builder_results = builder.build([ms.builder.BuilderInput(mod, target) for mod in mods])
+ assert len(builder_results) == len(
+ mods
+ ), f"Unexpected number of build results, expected {len(mods)} got {len(builder_results)}"
+ return builder_results
+
+
+def _run_single_mod(
+ builder_result: ms.builder.BuilderResult,
+ runner: ms.runner.Runner,
+ dev_type: str,
+) -> None:
+ """Run a single module.
+
+ Parameters
+ ----------
+ builder_result: BuilderResult
+ The builder result
+ runner: Runner
+ The runner to run the module
+ dev_type: str
+ The device type
+ """
+ runner_futures = runner.run(
+ # arginfo is not used in this case so we can pass an empty list
+ [ms.runner.RunnerInput(builder_result.artifact_path, device_type=dev_type, args_info=[])]
+ )
+ assert (
+ len(runner_futures) == 1
+ ), f"Unexpected number of runner futures, expected 1 got {len(runner_futures)}"
+ (runner_future,) = runner_futures # pylint: disable=unbalanced-tuple-unpacking
+ runner_res = runner_future.result()
+ assert runner_res.error_msg is None, "Runner failed: " + (
+ runner_res.error_msg if runner_res.error_msg else "Empty error message"
+ )
def main():
"""Main function"""
describe()
- database = ms.database.create(work_dir=ARGS.work_dir)
- target = ARGS.target
- if target.kind.name == "llvm":
- dev_type = "cpu"
- elif target.kind.name == "cuda":
- dev_type = "cuda"
- else:
- raise RuntimeError(f"Unsupported target kind: {target.kind.name}")
- records = database.get_all_tuning_records()
with ms.Profiler() as profiler:
- for i, record in enumerate(records):
- scope_name = f"validate #{i}"
- with profiler.timeit(scope_name):
- original_mod = record.workload.mod
- sch = Schedule(original_mod)
- record.trace.apply_to_schedule(sch=sch, remove_postproc=False)
- scheduled_mod = sch.mod
- is_success = False
+ # initialize
+ target = ARGS.target
+ dev_type = get_device_type(target)
+ builder = ms.builder.LocalBuilder()
+ database = ms.database.create(work_dir=ARGS.work_dir)
+
+ # collect records
+ with profiler.timeit("collect records"):
+ records = database.get_all_tuning_records()
+ total = len(records)
+ print(
+ f"Total {total} records to be validated. "
+ f"Collected in {float(profiler.get()['collect records']): 3.3f} sec."
+ )
+
+ # collect unique original TIR
+ with profiler.timeit("deduplicate records"):
+ workloads = set()
+ for record in records:
+ workloads.add(OriginalModule(record.workload.mod))
+ print(
+ f"Total {len(workloads)} unique original TIR to validate. "
+ f"Deduplicated in {float(profiler.get()['deduplicate records']): 3.3f} sec."
+ )
+ if ARGS.top_k < 10**9:
+ print(f"Top {ARGS.top_k} records for each original TIR will be validated.")
+ total = len(workloads) * ARGS.top_k
+ print()
+
+ # validate correctness
+ counter = 0
+ for item in workloads:
+ original_mod = item.mod
+ records = database.get_top_k(
+ workload=database.commit_workload(original_mod), top_k=ARGS.top_k
+ )
+ if len(records) < ARGS.top_k:
+ total -= ARGS.top_k - len(records)
+ inputs = to_numpy(check_and_run(ARGS.input_generator_func, original_mod))
+ original_res, original_run_secs = local_build_and_run(
+ original_mod,
+ target=ARGS.baseline_target,
+ inputs=inputs,
+ device=get_runtime_device(ARGS.baseline_target),
+ )
+ scheduled_mods = [_apply_trace(original_mod, record.trace) for record in records]
+ builder_results = _build_all_mods(scheduled_mods, builder, target) # type: ignore
+ for i, record in enumerate(records):
+ counter += 1
+ print_result = print_with_counter_func(counter=counter, total=total)
+ if is_failed_record(record):
+ # skip failed records where run_secs is 1e9
+ # these records are only negative samples for cost model
+ print_result(result="skip")
+ continue
try:
- is_success = validate_correctness(
- original_mod=original_mod,
- scheduled_mod=scheduled_mod,
- target=target,
- baseline_target=ARGS.baseline_target,
- dev_type=dev_type,
+ # prepare scheduled module
+ scheduled_mod = scheduled_mods[i]
+ # check build result
+ builder_result = builder_results[i]
+ _check_builder_result(builder_result)
+ # fetch functions
+ (
+ f_with_args_alloc_argument,
+ f_with_args_run_evaluator,
+ ) = make_alloc_arg_and_check(
+ inputs,
+ original_mod,
+ scheduled_mod,
+ str(record.trace),
+ original_res=original_res,
+ original_run_secs=original_run_secs,
+ print_result=print_result,
+ )
+ # create rpc runner
+ runner = ms.runner.RPCRunner(
Review Comment:
Sure, can do that.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
[GitHub] [tvm] tvm-bot commented on pull request #13459: [MetaSchedule] Enhance Database Validation Script
Posted by GitBox <gi...@apache.org>.
tvm-bot commented on PR #13459:
URL: https://github.com/apache/tvm/pull/13459#issuecomment-1322779362
<!---bot-comment-->
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.
<sub>Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)</sub>
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
[GitHub] [tvm] zxybazh commented on pull request #13459: [MetaSchedule] Enhance Database Validation Script
Posted by GitBox <gi...@apache.org>.
zxybazh commented on PR #13459:
URL: https://github.com/apache/tvm/pull/13459#issuecomment-1322832239
Sample output looks like this:
```text
Python Environment
TVM version = 0.11.dev0
Python version = 3.8.13 (default, Mar 28 2022, 11:38:47) [GCC 7.5.0] (64 bit)
os.uname() = Linux 5.18.10-76051810-generic #202207071639~1657252310~21.10~7d5e891 SMP PREEMPT_DYNAMIC Fri J x86_64
CMake Options:
{
"BACKTRACE_ON_SEGFAULT": "OFF",
"BUILD_STATIC_RUNTIME": "OFF",
"COMPILER_RT_PATH": "3rdparty/compiler-rt",
"CUDA_VERSION": "NOT-FOUND",
"DLPACK_PATH": "3rdparty/dlpack/include",
"DMLC_PATH": "3rdparty/dmlc-core/include",
"GIT_COMMIT_HASH": "490e0e3120f304a98607770502b5700ec6ab9d55",
"GIT_COMMIT_TIME": "2022-11-18 10:02:55 -0800",
"HIDE_PRIVATE_SYMBOLS": "ON",
"INDEX_DEFAULT_I64": "ON",
"INSTALL_DEV": "OFF",
"LLVM_VERSION": "12.0.1",
"PICOJSON_PATH": "3rdparty/picojson",
"RANG_PATH": "3rdparty/rang/include",
"ROCM_PATH": "/opt/rocm",
"SUMMARIZE": "OFF",
"TVM_CXX_COMPILER_PATH": "/usr/bin/c++",
"USE_ALTERNATIVE_LINKER": "AUTO",
"USE_AOT_EXECUTOR": "ON",
"USE_ARM_COMPUTE_LIB": "OFF",
"USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR": "OFF",
"USE_BLAS": "none",
"USE_BNNS": "OFF",
"USE_BYODT_POSIT": "OFF",
"USE_CCACHE": "AUTO",
"USE_CLML": "OFF",
"USE_CLML_GRAPH_EXECUTOR": "OFF",
"USE_CMSISNN": "OFF",
"USE_COREML": "OFF",
"USE_CPP_RPC": "OFF",
"USE_CUBLAS": "OFF",
"USE_CUDA": "/usr/local/cuda-11.8/",
"USE_CUDNN": "ON",
"USE_CURAND": "ON",
"USE_CUSTOM_LOGGING": "OFF",
"USE_CUTLASS": "OFF",
"USE_DNNL": "OFF",
"USE_ETHOSN": "OFF",
"USE_FALLBACK_STL_MAP": "OFF",
"USE_GRAPH_EXECUTOR": "ON",
"USE_GRAPH_EXECUTOR_CUDA_GRAPH": "OFF",
"USE_GTEST": "AUTO",
"USE_HEXAGON": "OFF",
"USE_HEXAGON_EXTERNAL_LIBS": "OFF",
"USE_HEXAGON_GTEST": "/path/to/hexagon/gtest",
"USE_HEXAGON_RPC": "OFF",
"USE_HEXAGON_SDK": "/path/to/sdk",
"USE_IOS_RPC": "OFF",
"USE_KHRONOS_SPIRV": "OFF",
"USE_LIBBACKTRACE": "ON",
"USE_LIBTORCH": "OFF",
"USE_LLVM": "llvm-config-12 --link-static",
"USE_METAL": "OFF",
"USE_MICRO": "OFF",
"USE_MICRO_STANDALONE_RUNTIME": "OFF",
"USE_MIOPEN": "OFF",
"USE_MKL": "OFF",
"USE_MSVC_MT": "OFF",
"USE_NNPACK": "OFF",
"USE_OPENCL": "OFF",
"USE_OPENCL_GTEST": "/path/to/opencl/gtest",
"USE_OPENMP": "none",
"USE_PAPI": "OFF",
"USE_PROFILER": "ON",
"USE_PT_TVMDSOOP": "OFF",
"USE_RANDOM": "ON",
"USE_RELAY_DEBUG": "OFF",
"USE_ROCBLAS": "OFF",
"USE_ROCM": "OFF",
"USE_RPC": "ON",
"USE_RTTI": "ON",
"USE_RUST_EXT": "OFF",
"USE_SORT": "ON",
"USE_SPIRV_KHR_INTEGER_DOT_PRODUCT": "OFF",
"USE_STACKVM_RUNTIME": "OFF",
"USE_TARGET_ONNX": "OFF",
"USE_TENSORFLOW_PATH": "none",
"USE_TENSORRT_CODEGEN": "OFF",
"USE_TENSORRT_RUNTIME": "OFF",
"USE_TFLITE": "OFF",
"USE_TF_TVMDSOOP": "OFF",
"USE_THREADS": "ON",
"USE_THRUST": "OFF",
"USE_UMA": "OFF",
"USE_VITIS_AI": "OFF",
"USE_VULKAN": "OFF"
}
2022-11-21 16:19:36.788 INFO LocalBuilder: max_workers = 24
Total 1221 records to be validated. Collected in 0.000 sec.
Total 20 unique original TIR to validate. Deduplicated in 0.358 sec.
Top 10 records for each original TIR will be validated.
Progress 1 / 200 (estimated) checked, result: pass, original: 0.006 ms, scheduled: 0.003 ms
Progress 2 / 200 (estimated) checked, result: pass, original: 0.006 ms, scheduled: 0.003 ms
Progress 3 / 200 (estimated) checked, result: pass, original: 0.006 ms, scheduled: 0.003 ms
Progress 4 / 200 (estimated) checked, result: pass, original: 0.006 ms, scheduled: 0.003 ms
Progress 5 / 200 (estimated) checked, result: pass, original: 0.006 ms, scheduled: 0.003 ms
Progress 6 / 200 (estimated) checked, result: pass, original: 0.006 ms, scheduled: 0.003 ms
Progress 7 / 200 (estimated) checked, result: pass, original: 0.006 ms, scheduled: 0.003 ms
Progress 8 / 200 (estimated) checked, result: pass, original: 0.006 ms, scheduled: 0.003 ms
Progress 9 / 200 (estimated) checked, result: pass, original: 0.006 ms, scheduled: 0.003 ms
Progress 10 / 200 (estimated) checked, result: pass, original: 0.006 ms, scheduled: 0.003 ms
Progress 11 / 200 (estimated) checked, result: pass, original: 52.827 ms, scheduled: 0.030 ms
Progress 12 / 200 (estimated) checked, result: pass, original: 52.827 ms, scheduled: 0.035 ms
Progress 13 / 200 (estimated) checked, result: pass, original: 52.827 ms, scheduled: 0.035 ms
Progress 14 / 200 (estimated) checked, result: pass, original: 52.827 ms, scheduled: 0.037 ms
Progress 15 / 200 (estimated) checked, result: pass, original: 52.827 ms, scheduled: 0.038 ms
Progress 16 / 200 (estimated) checked, result: pass, original: 52.827 ms, scheduled: 0.043 ms
Progress 17 / 200 (estimated) checked, result: pass, original: 52.827 ms, scheduled: 0.044 ms
Progress 18 / 200 (estimated) checked, result: pass, original: 52.827 ms, scheduled: 0.045 ms
Progress 19 / 200 (estimated) checked, result: pass, original: 52.827 ms, scheduled: 0.047 ms
Progress 20 / 200 (estimated) checked, result: pass, original: 52.827 ms, scheduled: 0.047 ms
Progress 21 / 200 (estimated) checked, result: pass, original: 146.474 ms, scheduled: 0.104 ms
Progress 22 / 200 (estimated) checked, result: pass, original: 146.474 ms, scheduled: 0.105 ms
Progress 23 / 200 (estimated) checked, result: pass, original: 146.474 ms, scheduled: 0.116 ms
Progress 24 / 200 (estimated) checked, result: pass, original: 146.474 ms, scheduled: 0.118 ms
Progress 25 / 200 (estimated) checked, result: pass, original: 146.474 ms, scheduled: 0.122 ms
Progress 26 / 200 (estimated) checked, result: pass, original: 146.474 ms, scheduled: 0.152 ms
Progress 27 / 200 (estimated) checked, result: pass, original: 146.474 ms, scheduled: 0.151 ms
Progress 28 / 200 (estimated) checked, result: pass, original: 146.474 ms, scheduled: 0.190 ms
Progress 29 / 200 (estimated) checked, result: pass, original: 146.474 ms, scheduled: 0.218 ms
Progress 30 / 200 (estimated) checked, result: pass, original: 146.474 ms, scheduled: 0.230 ms
Progress 31 / 200 (estimated) checked, result: pass, original: 7.804 ms, scheduled: 0.013 ms
Progress 32 / 200 (estimated) checked, result: pass, original: 7.804 ms, scheduled: 0.016 ms
Progress 33 / 200 (estimated) checked, result: pass, original: 7.804 ms, scheduled: 0.016 ms
Progress 34 / 200 (estimated) checked, result: pass, original: 7.804 ms, scheduled: 0.017 ms
Progress 35 / 200 (estimated) checked, result: pass, original: 7.804 ms, scheduled: 0.018 ms
Progress 36 / 200 (estimated) checked, result: pass, original: 7.804 ms, scheduled: 0.019 ms
Progress 37 / 200 (estimated) checked, result: pass, original: 7.804 ms, scheduled: 0.021 ms
Progress 38 / 200 (estimated) checked, result: pass, original: 7.804 ms, scheduled: 0.021 ms
Progress 39 / 200 (estimated) checked, result: pass, original: 7.804 ms, scheduled: 0.024 ms
Progress 40 / 200 (estimated) checked, result: pass, original: 7.804 ms, scheduled: 0.024 ms
Progress 41 / 200 (estimated) checked, result: pass, original: 43.212 ms, scheduled: 0.086 ms
Progress 42 / 200 (estimated) checked, result: pass, original: 43.212 ms, scheduled: 0.083 ms
Progress 43 / 200 (estimated) checked, result: pass, original: 43.212 ms, scheduled: 0.084 ms
Progress 44 / 200 (estimated) checked, result: pass, original: 43.212 ms, scheduled: 0.084 ms
Progress 45 / 200 (estimated) checked, result: pass, original: 43.212 ms, scheduled: 0.084 ms
Progress 46 / 200 (estimated) checked, result: pass, original: 43.212 ms, scheduled: 0.086 ms
Progress 47 / 200 (estimated) checked, result: pass, original: 43.212 ms, scheduled: 0.086 ms
Progress 48 / 200 (estimated) checked, result: pass, original: 43.212 ms, scheduled: 0.086 ms
Progress 49 / 200 (estimated) checked, result: pass, original: 43.212 ms, scheduled: 0.087 ms
Progress 50 / 200 (estimated) checked, result: pass, original: 43.212 ms, scheduled: 0.089 ms
Progress 51 / 200 (estimated) checked, result: pass, original: 65.038 ms, scheduled: 0.057 ms
Progress 52 / 200 (estimated) checked, result: pass, original: 65.038 ms, scheduled: 0.071 ms
Progress 53 / 200 (estimated) checked, result: pass, original: 65.038 ms, scheduled: 0.074 ms
Progress 54 / 200 (estimated) checked, result: pass, original: 65.038 ms, scheduled: 0.076 ms
Progress 55 / 200 (estimated) checked, result: pass, original: 65.038 ms, scheduled: 0.082 ms
Progress 56 / 200 (estimated) checked, result: pass, original: 65.038 ms, scheduled: 0.088 ms
Progress 57 / 200 (estimated) checked, result: pass, original: 65.038 ms, scheduled: 0.090 ms
Progress 58 / 200 (estimated) checked, result: pass, original: 65.038 ms, scheduled: 0.097 ms
Progress 59 / 200 (estimated) checked, result: pass, original: 65.038 ms, scheduled: 0.098 ms
Progress 60 / 200 (estimated) checked, result: pass, original: 65.038 ms, scheduled: 0.101 ms
Progress 61 / 200 (estimated) checked, result: pass, original: 31.743 ms, scheduled: 0.074 ms
Progress 62 / 200 (estimated) checked, result: pass, original: 31.743 ms, scheduled: 0.067 ms
Progress 63 / 200 (estimated) checked, result: pass, original: 31.743 ms, scheduled: 0.066 ms
Progress 64 / 200 (estimated) checked, result: pass, original: 31.743 ms, scheduled: 0.067 ms
Progress 65 / 200 (estimated) checked, result: pass, original: 31.743 ms, scheduled: 0.075 ms
Progress 66 / 200 (estimated) checked, result: pass, original: 31.743 ms, scheduled: 0.078 ms
Progress 67 / 200 (estimated) checked, result: pass, original: 31.743 ms, scheduled: 0.079 ms
Progress 68 / 200 (estimated) checked, result: pass, original: 31.743 ms, scheduled: 0.089 ms
Progress 69 / 200 (estimated) checked, result: pass, original: 31.743 ms, scheduled: 0.096 ms
Progress 70 / 200 (estimated) checked, result: pass, original: 31.743 ms, scheduled: 0.110 ms
Progress 71 / 200 (estimated) checked, result: pass, original: 3.623 ms, scheduled: 0.006 ms
Progress 72 / 200 (estimated) checked, result: pass, original: 3.623 ms, scheduled: 0.005 ms
Progress 73 / 200 (estimated) checked, result: pass, original: 3.623 ms, scheduled: 0.005 ms
Progress 74 / 200 (estimated) checked, result: pass, original: 3.623 ms, scheduled: 0.005 ms
Progress 75 / 200 (estimated) checked, result: pass, original: 3.623 ms, scheduled: 0.006 ms
Progress 76 / 200 (estimated) checked, result: pass, original: 3.623 ms, scheduled: 0.006 ms
Progress 77 / 200 (estimated) checked, result: pass, original: 3.623 ms, scheduled: 0.007 ms
Progress 78 / 200 (estimated) checked, result: pass, original: 3.623 ms, scheduled: 0.007 ms
Progress 79 / 200 (estimated) checked, result: pass, original: 3.623 ms, scheduled: 0.008 ms
Progress 80 / 200 (estimated) checked, result: pass, original: 3.623 ms, scheduled: 0.009 ms
Progress 81 / 196 (estimated) checked, result: pass, original: 0.034 ms, scheduled: 0.002 ms
Progress 82 / 196 (estimated) checked, result: pass, original: 0.034 ms, scheduled: 0.002 ms
Progress 83 / 196 (estimated) checked, result: pass, original: 0.034 ms, scheduled: 0.002 ms
Progress 84 / 196 (estimated) checked, result: pass, original: 0.034 ms, scheduled: 0.003 ms
Progress 85 / 196 (estimated) checked, result: pass, original: 0.034 ms, scheduled: 0.003 ms
Progress 86 / 196 (estimated) checked, result: pass, original: 0.034 ms, scheduled: 0.005 ms
Progress 87 / 196 (estimated) checked, result: pass, original: 6.110 ms, scheduled: 0.007 ms
Progress 88 / 196 (estimated) checked, result: pass, original: 6.110 ms, scheduled: 0.007 ms
Progress 89 / 196 (estimated) checked, result: pass, original: 6.110 ms, scheduled: 0.009 ms
Progress 90 / 196 (estimated) checked, result: pass, original: 6.110 ms, scheduled: 0.009 ms
Progress 91 / 196 (estimated) checked, result: pass, original: 6.110 ms, scheduled: 0.009 ms
Progress 92 / 196 (estimated) checked, result: pass, original: 6.110 ms, scheduled: 0.009 ms
Progress 93 / 196 (estimated) checked, result: pass, original: 6.110 ms, scheduled: 0.010 ms
Progress 94 / 196 (estimated) checked, result: pass, original: 6.110 ms, scheduled: 0.010 ms
Progress 95 / 196 (estimated) checked, result: pass, original: 6.110 ms, scheduled: 0.012 ms
Progress 96 / 196 (estimated) checked, result: pass, original: 6.110 ms, scheduled: 0.013 ms
Progress 97 / 196 (estimated) checked, result: pass, original: 43.432 ms, scheduled: 0.082 ms
Progress 98 / 196 (estimated) checked, result: pass, original: 43.432 ms, scheduled: 0.098 ms
Progress 99 / 196 (estimated) checked, result: pass, original: 43.432 ms, scheduled: 0.101 ms
Progress 100 / 196 (estimated) checked, result: pass, original: 43.432 ms, scheduled: 0.106 ms
Progress 101 / 196 (estimated) checked, result: pass, original: 43.432 ms, scheduled: 0.117 ms
Progress 102 / 196 (estimated) checked, result: pass, original: 43.432 ms, scheduled: 0.120 ms
Progress 103 / 196 (estimated) checked, result: pass, original: 43.432 ms, scheduled: 0.125 ms
Progress 104 / 196 (estimated) checked, result: pass, original: 43.432 ms, scheduled: 0.143 ms
Progress 105 / 196 (estimated) checked, result: pass, original: 43.432 ms, scheduled: 0.147 ms
Progress 106 / 196 (estimated) checked, result: pass, original: 43.432 ms, scheduled: 0.150 ms
Progress 107 / 196 (estimated) checked, result: pass, original: 24.410 ms, scheduled: 0.051 ms
Progress 108 / 196 (estimated) checked, result: pass, original: 24.410 ms, scheduled: 0.048 ms
Progress 109 / 196 (estimated) checked, result: pass, original: 24.410 ms, scheduled: 0.051 ms
Progress 110 / 196 (estimated) checked, result: pass, original: 24.410 ms, scheduled: 0.053 ms
Progress 111 / 196 (estimated) checked, result: pass, original: 24.410 ms, scheduled: 0.061 ms
Progress 112 / 196 (estimated) checked, result: pass, original: 24.410 ms, scheduled: 0.068 ms
Progress 113 / 196 (estimated) checked, result: pass, original: 24.410 ms, scheduled: 0.082 ms
Progress 114 / 196 (estimated) checked, result: pass, original: 24.410 ms, scheduled: 0.086 ms
Progress 115 / 196 (estimated) checked, result: pass, original: 24.410 ms, scheduled: 0.078 ms
Progress 116 / 196 (estimated) checked, result: pass, original: 24.410 ms, scheduled: 0.102 ms
Progress 117 / 196 (estimated) checked, result: pass, original: 32.747 ms, scheduled: 0.067 ms
Progress 118 / 196 (estimated) checked, result: pass, original: 32.747 ms, scheduled: 0.081 ms
Progress 119 / 196 (estimated) checked, result: pass, original: 32.747 ms, scheduled: 0.095 ms
Progress 120 / 196 (estimated) checked, result: pass, original: 32.747 ms, scheduled: 0.112 ms
Progress 121 / 196 (estimated) checked, result: pass, original: 32.747 ms, scheduled: 0.113 ms
Progress 122 / 196 (estimated) checked, result: pass, original: 32.747 ms, scheduled: 0.125 ms
Progress 123 / 196 (estimated) checked, result: pass, original: 32.747 ms, scheduled: 0.129 ms
Progress 124 / 196 (estimated) checked, result: pass, original: 32.747 ms, scheduled: 0.130 ms
Progress 125 / 196 (estimated) checked, result: pass, original: 32.747 ms, scheduled: 0.142 ms
Progress 126 / 196 (estimated) checked, result: pass, original: 32.747 ms, scheduled: 0.142 ms
Progress 127 / 196 (estimated) checked, result: pass, original: 31.283 ms, scheduled: 0.051 ms
Progress 128 / 196 (estimated) checked, result: pass, original: 31.283 ms, scheduled: 0.059 ms
Progress 129 / 196 (estimated) checked, result: pass, original: 31.283 ms, scheduled: 0.059 ms
Progress 130 / 196 (estimated) checked, result: pass, original: 31.283 ms, scheduled: 0.059 ms
Progress 131 / 196 (estimated) checked, result: pass, original: 31.283 ms, scheduled: 0.060 ms
Progress 132 / 196 (estimated) checked, result: pass, original: 31.283 ms, scheduled: 0.060 ms
Progress 133 / 196 (estimated) checked, result: pass, original: 31.283 ms, scheduled: 0.063 ms
Progress 134 / 196 (estimated) checked, result: pass, original: 31.283 ms, scheduled: 0.065 ms
Progress 135 / 196 (estimated) checked, result: pass, original: 31.283 ms, scheduled: 0.071 ms
Progress 136 / 196 (estimated) checked, result: pass, original: 31.283 ms, scheduled: 0.077 ms
Progress 137 / 191 (estimated) checked, result: pass, original: 0.000 ms, scheduled: 0.002 ms
Progress 138 / 191 (estimated) checked, result: pass, original: 0.000 ms, scheduled: 0.002 ms
Progress 139 / 191 (estimated) checked, result: pass, original: 0.000 ms, scheduled: 0.002 ms
Progress 140 / 191 (estimated) checked, result: pass, original: 0.000 ms, scheduled: 0.002 ms
Progress 141 / 191 (estimated) checked, result: pass, original: 0.000 ms, scheduled: 0.002 ms
Progress 142 / 191 (estimated) checked, result: pass, original: 0.338 ms, scheduled: 0.010 ms
Progress 143 / 191 (estimated) checked, result: pass, original: 0.338 ms, scheduled: 0.016 ms
Progress 144 / 191 (estimated) checked, result: pass, original: 0.338 ms, scheduled: 0.021 ms
Progress 145 / 191 (estimated) checked, result: pass, original: 0.338 ms, scheduled: 0.025 ms
Progress 146 / 191 (estimated) checked, result: pass, original: 0.338 ms, scheduled: 0.029 ms
Progress 147 / 191 (estimated) checked, result: pass, original: 0.338 ms, scheduled: 0.030 ms
Progress 148 / 191 (estimated) checked, result: pass, original: 0.338 ms, scheduled: 0.030 ms
Progress 149 / 191 (estimated) checked, result: pass, original: 0.338 ms, scheduled: 0.033 ms
Progress 150 / 191 (estimated) checked, result: pass, original: 0.338 ms, scheduled: 0.033 ms
Progress 151 / 191 (estimated) checked, result: pass, original: 0.338 ms, scheduled: 0.033 ms
Progress 152 / 191 (estimated) checked, result: pass, original: 0.177 ms, scheduled: 0.005 ms
Progress 153 / 191 (estimated) checked, result: pass, original: 0.177 ms, scheduled: 0.005 ms
Progress 154 / 191 (estimated) checked, result: pass, original: 0.177 ms, scheduled: 0.005 ms
Progress 155 / 191 (estimated) checked, result: pass, original: 0.177 ms, scheduled: 0.005 ms
Progress 156 / 191 (estimated) checked, result: pass, original: 0.177 ms, scheduled: 0.005 ms
Progress 157 / 191 (estimated) checked, result: pass, original: 0.177 ms, scheduled: 0.005 ms
Progress 158 / 191 (estimated) checked, result: pass, original: 0.177 ms, scheduled: 0.005 ms
Progress 159 / 191 (estimated) checked, result: pass, original: 0.177 ms, scheduled: 0.005 ms
Progress 160 / 191 (estimated) checked, result: pass, original: 0.177 ms, scheduled: 0.005 ms
Progress 161 / 191 (estimated) checked, result: pass, original: 0.177 ms, scheduled: 0.005 ms
Progress 162 / 191 (estimated) checked, result: pass, original: 70.768 ms, scheduled: 0.049 ms
Progress 163 / 191 (estimated) checked, result: pass, original: 70.768 ms, scheduled: 0.048 ms
Progress 164 / 191 (estimated) checked, result: pass, original: 70.768 ms, scheduled: 0.053 ms
Progress 165 / 191 (estimated) checked, result: pass, original: 70.768 ms, scheduled: 0.055 ms
Progress 166 / 191 (estimated) checked, result: pass, original: 70.768 ms, scheduled: 0.056 ms
Progress 167 / 191 (estimated) checked, result: pass, original: 70.768 ms, scheduled: 0.056 ms
Progress 168 / 191 (estimated) checked, result: pass, original: 70.768 ms, scheduled: 0.057 ms
Progress 169 / 191 (estimated) checked, result: pass, original: 70.768 ms, scheduled: 0.059 ms
Progress 170 / 191 (estimated) checked, result: pass, original: 70.768 ms, scheduled: 0.062 ms
Progress 171 / 191 (estimated) checked, result: pass, original: 70.768 ms, scheduled: 0.064 ms
Progress 172 / 191 (estimated) checked, result: pass, original: 31.186 ms, scheduled: 0.058 ms
Progress 173 / 191 (estimated) checked, result: pass, original: 31.186 ms, scheduled: 0.055 ms
Progress 174 / 191 (estimated) checked, result: pass, original: 31.186 ms, scheduled: 0.059 ms
Progress 175 / 191 (estimated) checked, result: pass, original: 31.186 ms, scheduled: 0.063 ms
Progress 176 / 191 (estimated) checked, result: pass, original: 31.186 ms, scheduled: 0.063 ms
Progress 177 / 191 (estimated) checked, result: pass, original: 31.186 ms, scheduled: 0.072 ms
Progress 178 / 191 (estimated) checked, result: pass, original: 31.186 ms, scheduled: 0.074 ms
Progress 179 / 191 (estimated) checked, result: pass, original: 31.186 ms, scheduled: 0.080 ms
Progress 180 / 191 (estimated) checked, result: pass, original: 31.186 ms, scheduled: 0.074 ms
Progress 181 / 191 (estimated) checked, result: pass, original: 31.186 ms, scheduled: 0.085 ms
Progress 182 / 191 (estimated) checked, result: pass, original: 24.627 ms, scheduled: 0.048 ms
Progress 183 / 191 (estimated) checked, result: pass, original: 24.627 ms, scheduled: 0.044 ms
Progress 184 / 191 (estimated) checked, result: pass, original: 24.627 ms, scheduled: 0.050 ms
Progress 185 / 191 (estimated) checked, result: pass, original: 24.627 ms, scheduled: 0.052 ms
Progress 186 / 191 (estimated) checked, result: pass, original: 24.627 ms, scheduled: 0.065 ms
Progress 187 / 191 (estimated) checked, result: pass, original: 24.627 ms, scheduled: 0.067 ms
Progress 188 / 191 (estimated) checked, result: pass, original: 24.627 ms, scheduled: 0.068 ms
Progress 189 / 191 (estimated) checked, result: pass, original: 24.627 ms, scheduled: 0.077 ms
Progress 190 / 191 (estimated) checked, result: pass, original: 24.627 ms, scheduled: 0.078 ms
Progress 191 / 191 (estimated) checked, result: pass, original: 24.627 ms, scheduled: 0.082 ms
Validation finished! Total time spent: 168.474 sec.
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
[GitHub] [tvm] yelite commented on a diff in pull request #13459: [MetaSchedule] Enhance Database Validation Script
Posted by GitBox <gi...@apache.org>.
yelite commented on code in PR #13459:
URL: https://github.com/apache/tvm/pull/13459#discussion_r1028655821
##########
python/tvm/meta_schedule/testing/validate_database.py:
##########
@@ -124,158 +252,495 @@ def default_input_generator(mod: IRModule) -> List[tvm.nd.NDArray]:
return inputs
-@register_func("tvm.meta_schedule.testing.default_check_metric")
-def default_check_metric(a: List[tvm.nd.NDArray], b: List[tvm.nd.NDArray]) -> bool:
- assert len(a) == len(b), "Different number of outputs from two modules"
- for i, _ in enumerate(a):
- if not np.allclose(a[i].numpy(), b[i].numpy(), rtol=1e-3, atol=2e-3):
- return False
- return True
+def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
+ """Convert a list of TVM NDArray to a list of numpy array
+ Parameters
+ ----------
+ a : List[tvm.nd.NDArray]
+ The list of TVM NDArray to be converted
-def validate_correctness(
- original_mod: IRModule, # compiled for "baseline_target"
- scheduled_mod: IRModule, # compiled for "target"
- *,
- baseline_target: Target,
- target: Target,
- dev_type: str,
- rpc_config: ms.runner.RPCConfig,
- f_input_generator: Union[
- str, Callable[[IRModule], List[tvm.nd.NDArray]]
- ] = default_input_generator,
- f_check_metric: Union[
- str, Callable[[tvm.nd.NDArray, tvm.nd.NDArray], bool]
- ] = default_check_metric,
-) -> bool:
- """Function to validate the correctness of a scheduled module.
+ Returns
+ -------
+ b : List[np.ndarray]
+ The list of numpy array
+ """
+ assert a is not None, "Empty result cannot be converted to numpy"
+ return [x.numpy() for x in a]
+
+
+def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
+ """Convert a list of numpy array to a list of TVM NDArray
Parameters
----------
- original_mod : IRModule
- The original module to be compiled.
- scheduled_mod : IRModule
- The scheduled module to be compiled.
- baseline_target : Target
- The baseline target to compile the original module.
- target : Target
- The target to compile the scheduled module.
- dev_type : str
- The device type to run the module via rpc.
- rpc_config : RPCConfig
- The RPCConfig to run the scheduled module.
- f_input_generator : Union[str, Callable]
- The function to generate the input data.
- f_check_metric : Union[str, Callable]
- The function to check the metric.
+ a : List[np.ndarray]
+ The list of numpy array to be converted.
Returns
-------
- result : bool
- The result of the validation.
+ b : List[tvm.nd.NDArray]
+ The list of TVM NDArray.
"""
+ assert a is not None, "Empty result cannot be converted to TVM NDArray"
+ return [tvm.nd.array(x) for x in a]
- def to_numpy(a: List[tvm.nd.NDArray]) -> List[np.ndarray]:
- """Convert a list of TVM NDArray to a list of numpy array"""
- assert a is not None, "Empty result cannot be converted to numpy"
- return [x.numpy() for x in a]
-
- def to_tvm_ndarray(a: List[np.ndarray]) -> List[tvm.nd.NDArray]:
- """Convert a list of numpy array to a list of TVM NDArray"""
- assert a is not None, "Empty result cannot be converted to TVM NDArray"
- return [tvm.nd.array(x) for x in a]
-
- def build_and_run(mod: IRModule, target: Target, dev_type: str) -> np.ndarray:
- """Build and run the module on the target device."""
- rt_mod = tvm.build(mod, target=target)
- return run_module_via_rpc(
- rpc_config=rpc_config,
- lib=rt_mod,
- dev_type=dev_type,
- args={i: v for i, v in enumerate(inputs)}, # pylint: disable=unnecessary-comprehension
- continuation=create_calculator(backend="tir"),
- backend="tir",
- )
- # fetch functions & prepare inputs
- if isinstance(f_input_generator, str):
- f_input_generator = get_global_func(f_input_generator)
- if isinstance(f_check_metric, str):
- f_check_metric = get_global_func(f_check_metric)
- inputs = to_numpy(f_input_generator(original_mod)) # type: ignore
- # build & run original result
- original_res = to_numpy(build_and_run(original_mod, target=baseline_target, dev_type="cpu"))
- scheduled_res = to_numpy(build_and_run(scheduled_mod, target=target, dev_type=dev_type))
- # check metric
- if f_check_metric(to_tvm_ndarray(original_res), to_tvm_ndarray(scheduled_res)): # type: ignore
- return True
- else:
- print(
- ("\n\n").join(
+def is_failed_record(record: ms.database.TuningRecord) -> bool:
+ """Check if a tuning record is failed.
+
+ Parameters
+ ----------
+ record : TuningRecord
+ The tuning record to check.
+
+ Returns
+ -------
+ is_failed : bool
+ """
+ return len(record.run_secs) == 1 and record.run_secs[0] == 1e9
+
+
+def print_with_counter_func(counter: int, total: int) -> Callable:
+ """Print with counter
+
+ Parameters
+ ----------
+ counter : int
+ The counter to print with.
+ total : int
+ The total number of items to print with.
+
+ Returns
+ -------
+ print_result : Callable
+ The print result function.
+ """
+
+ def print_result(
+ result: str,
+ *,
+ original_mod: IRModule = None,
+ scheduled_mod: IRModule = None,
+ inputs: List[np.ndarray] = None,
+ original_res: List[np.ndarray] = None,
+ scheduled_res: List[np.ndarray] = None,
+ original_run_secs: List[float] = None,
+ scheduled_run_secs: List[float] = None,
+ exception: Exception = None,
+ trace: str = None,
+ ) -> None:
+ """Print the validation result."""
+ status = f"Progress {counter: 6d} / {total: 6d} (estimated) checked, result: {result:>10}, "
+
+ if result in ["pass", "wrong answer"]:
+ status += (
+ f"original: {mean(original_run_secs) * 1e3: 10.3f} ms, "
+ f"scheduled: {mean(scheduled_run_secs) * 1e3: 10.3f} ms"
+ )
+
+ output = [status]
+ if result not in ["pass", "skip"]:
+ output.extend(
[
- "Validation failed!",
- "Original Result:" + DELIMITOR + str(original_res),
- "Scheduled Result:" + DELIMITOR + str(scheduled_res),
- "Input:" + DELIMITOR + str(inputs),
"Original IRModule:" + DELIMITOR + original_mod.script(),
"Scheduled IRModule:" + DELIMITOR + scheduled_mod.script(),
+ "Trace" + DELIMITOR + str(trace),
]
)
+ if result == "wrong answer":
+ output.extend(
+ [
+ "Input:" + DELIMITOR + str(inputs),
+ "Original Result:" + DELIMITOR + str(original_res),
+ "Scheduled Result:" + DELIMITOR + str(scheduled_res),
+ "Max Diff:"
+ + DELIMITOR
+ + str(
+ [
+ np.max(np.abs(original_res[i] - scheduled_res[i]))
+ for i in range(len(original_res))
+ ]
+ )
+ + "\n",
+ ]
+ )
+ elif result == "exception":
+ output.extend(["Exception:" + DELIMITOR + str(exception) + "\n"])
+ else:
+ raise ValueError(f"Unknown result: {result}")
+ print("\n\n".join(output))
+
+ return print_result
+
+
+def make_alloc_arg_and_check(
+ inputs: List[np.ndarray],
+ original_mod: IRModule,
+ scheduled_mod: IRModule,
+ trace: str,
+ original_res: List[np.ndarray],
+ original_run_secs: List[float],
+ print_result: Callable,
+) -> Tuple[Callable, Callable]:
+ """Make alloc_arg and check functions for the given inputs and collect results.
+
+ Parameters
+ ----------
+ inputs : List[np.ndarray]
+ The inputs to the two modules.
+ original_mod : IRModule
+ The original IRModule.
+ scheduled_mod : IRModule
+ The scheduled IRModule.
+ trace : str
+ The trace of the scheduled IRModule.
+ original_res : List[np.ndarray]
+ The original results.
+ original_run_secs : List[float]
+ The original run times.
+ print_result : Callable
+ The print result function.
+
+ Returns
+ -------
+ f_with_args_alloc_argument : Callable
+ The function to allocate arguments.
+
+ f_with_args_run_evaluator : Callable
+ The function to run evaluator.
+ """
+
+ def f_with_args_alloc_argument(
+ # pylint: disable=unused-argument
+ session: tvm.rpc.RPCSession,
+ device: tvm.runtime.Device,
+ args_info: ms.runner.rpc_runner.T_ARG_INFO_JSON_OBJ_LIST,
+ alloc_repeat: int,
+ # pylint: enable=unused-argument
+ ) -> List[ms.runner.rpc_runner.T_ARGUMENT_LIST]:
+ """Allocate arguments using the given inputs.
+
+ Parameters
+ ----------
+ session : RPCSession
+ The RPC session.
+ device : Device
+ The device.
+ args_info : T_ARG_INFO_JSON_OBJ_LIST
+ argument information.
+ alloc_repeat : int
+ The number of times to repeat the allocation.
+
+ Returns
+ -------
+ args_list : List[T_ARGUMENT_LIST]
+ The list of argument lists.
+ """
+ return [[tvm.nd.array(arg, device=device) for arg in inputs] for _ in range(alloc_repeat)]
+
+ def f_with_args_run_evaluator(
+ session: tvm.rpc.RPCSession, # pylint: disable=unused-argument
+ rt_mod: tvm.runtime.Module,
+ device: tvm.runtime.Device,
+ evaluator_config: ms.runner.EvaluatorConfig,
+ repeated_args: List[ms.runner.rpc_runner.T_ARGUMENT_LIST],
+ ) -> List[float]:
+ """With args function to run the evaluator
+
+ Parameters
+ ----------
+ session : tvm.rpc.RPCSession
+ The RPC session
+ rt_mod: Module
+ The runtime module
+ device: Device
+ The device to run the evaluator
+ evaluator_config: EvaluatorConfig
+ The evaluator config
+ repeated_args: List[T_ARGUMENT_LIST]
+ The repeated arguments
+
+ Returns
+ -------
+ costs: List[float]
+ The evaluator results
+ """
+ evaluator = rt_mod.time_evaluator(
+ func_name=rt_mod.entry_name,
+ dev=device,
+ number=evaluator_config.number,
+ repeat=evaluator_config.repeat,
+ min_repeat_ms=evaluator_config.min_repeat_ms,
+ f_preproc="cache_flush_cpu_non_first_arg"
+ if evaluator_config.enable_cpu_cache_flush
+ else "",
+ )
+
+ repeated_costs: List[List[float]] = []
+ for args in repeated_args:
+ device.sync()
+ profile_result = evaluator(*args)
+ repeated_costs.append(profile_result.results)
+ costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
+
+ assert len(repeated_args) == 1, "Only support one set of arguments"
+ scheduled_res = [arg.numpy() for arg in repeated_args[0]] # type: ignore
+ # fetch comparison function
+ passed = check_and_run(
+ ARGS.check_metric_func,
+ to_tvm_ndarray(original_res),
+ to_tvm_ndarray(scheduled_res),
+ )
+
+ print_result(
+ result="pass" if passed else "wrong answer",
+ original_mod=original_mod,
+ scheduled_mod=scheduled_mod,
+ trace=trace,
+ inputs=inputs,
+ original_res=original_res,
+ scheduled_res=scheduled_res,
+ original_run_secs=original_run_secs,
+ scheduled_run_secs=costs,
)
- return False
+
+ return costs
+
+ return f_with_args_alloc_argument, f_with_args_run_evaluator
+
+
+def local_build_and_run(
+ mod: IRModule,
+ target: Target,
+ device: tvm.runtime.Device,
+ inputs: List[np.ndarray],
+) -> Tuple[List[np.ndarray], List[float]]:
+ """Build and run the module locally.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The module to build and run
+ target: Target
+ The target to build the module
+ device: Device
+ The device to run the module
+ inputs: List[np.ndarray]
+ The inputs to run the module
+
+ Returns
+ -------
+ res: List[np.ndarray]
+ The results of running the module
+ run_secs: List[float]
+ The running time of running the module
+ """
+ # potential memory leak https://github.com/apache/tvm/issues/11096
+ lib = tvm.build(mod, target=target)
+ tvm_inputs = [tvm.nd.array(inp, device=device) for inp in inputs]
+ device.sync()
+ func = lib.time_evaluator(lib.entry_name, dev=device, number=ARGS.number, repeat=ARGS.repeat)
+ benchmark_res = func(*tvm_inputs)
+ device.sync()
+ return [arg.numpy() for arg in tvm_inputs], list(benchmark_res.results)
+
+
+def _check_builder_result(builder_result: ms.builder.BuilderResult) -> None:
+ """Check if the builder result is defined.
+
+ Parameters
+ ----------
+ builder_result: BuilderResult
+ The builder result
+ """
+ assert builder_result.error_msg is None, "Builder failed: " + str(
+ builder_result.error_msg if builder_result.error_msg else "Empty error message"
+ )
+
+
+def _apply_trace(mod: IRModule, trace: Trace) -> IRModule:
+ """Apply the trace to the module.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The module to apply the trace to
+ trace: Trace
+ The trace to apply
+
+ Returns
+ -------
+ mod: IRModule
+ The module with the trace applied
+ """
+ sch = Schedule(mod)
+ trace.apply_to_schedule(sch, remove_postproc=False)
+ return sch.mod
+
+
+def _build_all_mods(
+ mods: List[IRModule], builder: ms.builder.Builder, target: Target
+) -> List[ms.builder.BuilderResult]:
+ """Build all the modules.
+
+ Parameters
+ ----------
+ mods: List[IRModule]
+ The modules to build
+ builder: Builder
+ The builder to build the modules
+ target: Target
+ The target to build the modules
+
+ Returns
+ -------
+ builder_results: List[BuilderResult]
+ The builder results
+ """
+ builder_results = builder.build([ms.builder.BuilderInput(mod, target) for mod in mods])
+ assert len(builder_results) == len(
+ mods
+ ), f"Unexpected number of build results, expected {len(mods)} got {len(builder_results)}"
+ return builder_results
+
+
+def _run_single_mod(
+ builder_result: ms.builder.BuilderResult,
+ runner: ms.runner.Runner,
+ dev_type: str,
+) -> None:
+ """Run a single module.
+
+ Parameters
+ ----------
+ builder_result: BuilderResult
+ The builder result
+ runner: Runner
+ The runner to run the module
+ dev_type: str
+ The device type
+ """
+ runner_futures = runner.run(
+ # arginfo is not used in this case so we can pass an empty list
+ [ms.runner.RunnerInput(builder_result.artifact_path, device_type=dev_type, args_info=[])]
+ )
+ assert (
+ len(runner_futures) == 1
+ ), f"Unexpected number of runner futures, expected 1 got {len(runner_futures)}"
+ (runner_future,) = runner_futures # pylint: disable=unbalanced-tuple-unpacking
+ runner_res = runner_future.result()
+ assert runner_res.error_msg is None, "Runner failed: " + (
+ runner_res.error_msg if runner_res.error_msg else "Empty error message"
+ )
def main():
"""Main function"""
describe()
- database = ms.database.create(work_dir=ARGS.work_dir)
- target = ARGS.target
- if target.kind.name == "llvm":
- dev_type = "cpu"
- elif target.kind.name == "cuda":
- dev_type = "cuda"
- else:
- raise RuntimeError(f"Unsupported target kind: {target.kind.name}")
- records = database.get_all_tuning_records()
with ms.Profiler() as profiler:
- for i, record in enumerate(records):
- scope_name = f"validate #{i}"
- with profiler.timeit(scope_name):
- original_mod = record.workload.mod
- sch = Schedule(original_mod)
- record.trace.apply_to_schedule(sch=sch, remove_postproc=False)
- scheduled_mod = sch.mod
- is_success = False
+ # initialize
+ target = ARGS.target
+ dev_type = get_device_type(target)
+ builder = ms.builder.LocalBuilder()
+ database = ms.database.create(work_dir=ARGS.work_dir)
+
+ # collect records
+ with profiler.timeit("collect records"):
+ records = database.get_all_tuning_records()
+ total = len(records)
+ print(
+ f"Total {total} records to be validated. "
+ f"Collected in {float(profiler.get()['collect records']): 3.3f} sec."
+ )
+
+ # collect unique original TIR
+ with profiler.timeit("deduplicate records"):
+ workloads = set()
+ for record in records:
+ workloads.add(OriginalModule(record.workload.mod))
+ print(
+ f"Total {len(workloads)} unique original TIR to validate. "
+ f"Deduplicated in {float(profiler.get()['deduplicate records']): 3.3f} sec."
+ )
+ if ARGS.top_k < 10**9:
+ print(f"Top {ARGS.top_k} records for each original TIR will be validated.")
+ total = len(workloads) * ARGS.top_k
+ print()
+
+ # validate correctness
+ counter = 0
+ for item in workloads:
+ original_mod = item.mod
+ records = database.get_top_k(
+ workload=database.commit_workload(original_mod), top_k=ARGS.top_k
+ )
+ if len(records) < ARGS.top_k:
+ total -= ARGS.top_k - len(records)
+ inputs = to_numpy(check_and_run(ARGS.input_generator_func, original_mod))
+ original_res, original_run_secs = local_build_and_run(
+ original_mod,
+ target=ARGS.baseline_target,
+ inputs=inputs,
+ device=get_runtime_device(ARGS.baseline_target),
+ )
+ scheduled_mods = [_apply_trace(original_mod, record.trace) for record in records]
+ builder_results = _build_all_mods(scheduled_mods, builder, target) # type: ignore
+ for i, record in enumerate(records):
+ counter += 1
+ print_result = print_with_counter_func(counter=counter, total=total)
+ if is_failed_record(record):
+ # skip failed records where run_secs is 1e9
+ # these records are only negative samples for cost model
+ print_result(result="skip")
+ continue
try:
- is_success = validate_correctness(
- original_mod=original_mod,
- scheduled_mod=scheduled_mod,
- target=target,
- baseline_target=ARGS.baseline_target,
- dev_type=dev_type,
+ # prepare scheduled module
+ scheduled_mod = scheduled_mods[i]
+ # check build result
+ builder_result = builder_results[i]
+ _check_builder_result(builder_result)
+ # fetch functions
+ (
+ f_with_args_alloc_argument,
+ f_with_args_run_evaluator,
+ ) = make_alloc_arg_and_check(
+ inputs,
+ original_mod,
+ scheduled_mod,
+ str(record.trace),
+ original_res=original_res,
+ original_run_secs=original_run_secs,
+ print_result=print_result,
+ )
+ # create rpc runner
+ runner = ms.runner.RPCRunner(
Review Comment:
Maybe just make the `--rpc-host` argument and its related RPC arguments optional, and use a local runner if `--rpc-host` is not given by the user.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org