You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/04/19 12:45:48 UTC
[GitHub] [tvm] pfk-beta commented on issue #11058: [Bug] tuning cannot send task but tracker - device connection is ok
pfk-beta commented on issue #11058:
URL: https://github.com/apache/tvm/issues/11058#issuecomment-1102600200
tuning script:
```
import os
import sys
import argparse
import numpy as np
from tcl_scripts.common import load_model
import tvm
from tvm import relay, autotvm
import tvm.relay.testing
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib.utils import tempdir
import tvm.contrib.graph_executor as runtime
from tvm.contrib import ndk
# When the interpreter was started without any explicit warning options,
# silence all warnings for this tuning run.
if not sys.warnoptions:
    import warnings

    # Exported through the environment so that subprocesses are muted too.
    os.environ["PYTHONWARNINGS"] = "ignore"
    # Mute warnings in the current process.
    warnings.simplefilter("ignore")
def create_tuning_log_name(args):
    """Derive a default tuning-log file name from the parsed arguments.

    The name has the form
    ``<model_name>__<rpc_key>__<target>__<n_trials>.log``, where every
    space in ``args.target`` is replaced by an underscore.
    """
    sanitized_target = "_".join(args.target.split(" "))
    fields = (args.model_name, args.rpc_key, sanitized_target, args.n_trials)
    return "__".join(str(field) for field in fields) + ".log"
def get_task_tuner(task, tuner_name):
    """Instantiate the AutoTVM tuner selected by ``tuner_name`` for ``task``.

    Recognised names:

    * ``"xgb"`` / ``"xgb-rank"`` -- ``XGBTuner`` with ``loss_type="rank"``
    * ``"ga"`` -- ``GATuner`` with ``pop_size=50``
    * ``"random"`` -- ``RandomTuner``
    * ``"gridsearch"`` -- ``GridSearchTuner``

    Raises:
        ValueError: if ``tuner_name`` is not one of the names above.
    """
    if tuner_name in ("xgb", "xgb-rank"):
        return XGBTuner(task, loss_type="rank")
    if tuner_name == "ga":
        return GATuner(task, pop_size=50)
    if tuner_name == "random":
        return RandomTuner(task)
    if tuner_name == "gridsearch":
        return GridSearchTuner(task)
    raise ValueError("Invalid tuner: " + tuner_name)
def tune_model(args, mod, params, target):
    """Tune the ``nn.conv2d`` tasks of ``mod`` over RPC and keep the best records.

    Tasks are extracted from ``mod["main"]``, built locally with the ``ndk``
    build function and measured through an RPC tracker.  All tasks append
    their measurements to one temporary log (``args.output_tuning_log`` plus
    ``".tmp"``); once tuning is done the best records are picked into
    ``args.output_tuning_log`` and the temporary log is deleted.

    Args:
        args: Parsed command-line namespace.  Reads ``rpc_key``,
            ``rpc_tracker``, ``rpc_port``, ``runner_number``,
            ``runner_repeat``, ``n_trials``, ``early_stopping`` and
            ``output_tuning_log``.
        mod: Module whose ``"main"`` function is searched for tunable tasks.
        params: Parameters forwarded to the task extraction.
        target: Compilation target forwarded to the task extraction.
    """
    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(build_func="ndk"),
        runner=autotvm.RPCRunner(
            args.rpc_key,
            host=args.rpc_tracker,
            port=args.rpc_port,
            number=args.runner_number,
            repeat=args.runner_repeat,
            timeout=60,
        ),
    )

    # TODO: parametrize ops which will be tuned
    tasks = autotvm.task.extract_from_program(
        mod["main"],
        target=target,
        params=params,
        ops=(relay.op.get("nn.conv2d"),),
    )
    if not tasks:
        # Nothing to tune, so there is no log to create or post-process.
        print("No tasks...")
        return

    # One temporary log is shared by every task; its name does not depend on
    # the task, so compute it once instead of on every iteration.
    task_log_filename = args.output_tuning_log + ".tmp"
    task_count = len(tasks)

    # Tasks are tuned in reverse extraction order.
    for i, task in enumerate(reversed(tasks)):
        prefix = "[Task %2d/%2d] " % (i + 1, task_count)
        task_tuner = get_task_tuner(task, "xgb")

        # Seed the tuner with whatever is already in the temporary log:
        # records from tasks tuned earlier in this loop, or a log left behind
        # by a previous run.
        if os.path.isfile(task_log_filename):
            task_tuner.load_history(
                autotvm.record.load_from_file(task_log_filename))

        # Never request more trials than the task has configurations.
        n_trials = min(args.n_trials, len(task.config_space))
        task_tuner.tune(
            n_trial=n_trials,
            early_stopping=args.early_stopping,
            measure_option=measure_option,
            callbacks=[
                autotvm.callback.progress_bar(n_trials, prefix=prefix),
                autotvm.callback.log_to_file(task_log_filename),
            ],
        )

    if not os.path.isfile(task_log_filename):
        # No measurement was ever logged (e.g. zero trials were requested);
        # picking from / removing a missing file would raise.
        print("No tuning records were written, nothing to pick.")
        return

    # pick best records to a cache file
    autotvm.record.pick_best(task_log_filename, args.output_tuning_log)
    os.remove(task_log_filename)
def main(args):
    """Build the target, load the model and run tuning.

    ``args`` is the parsed command-line namespace.  The target is created
    from ``args.target`` with ``args.target_host`` as its host, the model is
    obtained from ``load_model(args)`` (unpacked as a two-element result),
    and everything is handed to ``tune_model``.
    """
    tvm_target = tvm.target.Target(args.target, host=args.target_host)
    loaded = load_model(args)
    module, weights = loaded
    tune_model(args, module, weights, tvm_target)
def parse_cli_args(argv=None):
    """Parse the tuning script's command line.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with ``input_shape`` already converted
        from its textual form to a Python literal (e.g. ``"1,3,224,224"``
        becomes ``(1, 3, 224, 224)``).

    Raises:
        ValueError, SyntaxError: if ``--input_shape`` is not a Python literal.
    """
    # Local import: the file's import header does not provide ``ast``.
    import ast

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', required=True,
                        help="How do you name this model? "
                             "This value is used for generating name, "
                             "if output_tuning_log is not specified. No spaces.")
    parser.add_argument('--input_model', required=True,
                        help="Fullpath to model")
    parser.add_argument('--input_name', required=True,
                        help="Name of input node")
    parser.add_argument('--input_shape', required=True,
                        help="Shape of input node, coma-separated, no spaces.")
    # Was declared with both ``default`` and ``required=True``, which made
    # the default unreachable; the option is now genuinely optional.
    parser.add_argument('--input_dtype',
                        default="float32",
                        help="Dtype of input node")
    parser.add_argument('--rpc_tracker',
                        required=True,
                        help="IP address of RPC tracker")
    parser.add_argument('--rpc_port',
                        type=int, required=True,
                        help="IP port of RPC tracker")
    parser.add_argument('--rpc_key',
                        required=True,
                        help="Key of RPC tracker")
    parser.add_argument('--output_tuning_log',
                        default=None,
                        help="Where to save tuning output to be used for benchmark.")
    parser.add_argument('--runner_number', type=int, default=4,
                        help="Number of separate benchmark runs")
    parser.add_argument('--runner_repeat', type=int, default=3,
                        help="Number of inference in one run")
    parser.add_argument('--n_trials',
                        type=int, default=10,
                        help="Number of trials. Must be larnger than 1. Typically 2000...")
    parser.add_argument('--early_stopping',
                        type=int, default=400,
                        help='Early stopping for tuning. Ignore when no tuning.')
    parser.add_argument('--target', default="opencl", help="")
    parser.add_argument('--target_host',
                        default="llvm -mtriple=aarch64-linux-gnu", help="")
    args = parser.parse_args(argv)

    # NOTE(security): this used to be ``eval(args.input_shape)``, which runs
    # arbitrary code taken from the command line.  ``ast.literal_eval``
    # yields the same value for literal input such as "1,3,224,224" but
    # rejects anything that is not a plain Python literal.
    args.input_shape = ast.literal_eval(args.input_shape)
    return args


if __name__ == "__main__":
    args = parse_cli_args()
    # Fall back to a name derived from model, RPC key, target and trial count.
    if not args.output_tuning_log:
        args.output_tuning_log = create_tuning_log_name(args)
    main(args)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org