Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/04/19 12:45:48 UTC

[GitHub] [tvm] pfk-beta commented on issue #11058: [Bug] tuning cannot send task but tracker - device connection is ok

pfk-beta commented on issue #11058:
URL: https://github.com/apache/tvm/issues/11058#issuecomment-1102600200

   tuning script:
   ```python
   import os
   import sys
   import argparse
   
   import numpy as np
   from tcl_scripts.common import load_model
   
   import tvm
   from tvm import relay, autotvm
   import tvm.relay.testing
   from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
   from tvm.contrib.utils import tempdir
   import tvm.contrib.graph_executor as runtime
   from tvm.contrib import ndk
   
   if not sys.warnoptions:
       import warnings
       warnings.simplefilter("ignore") # Change the filter in this process
       os.environ["PYTHONWARNINGS"] = "ignore" # Also affect subprocesses
   
   
   def create_tuning_log_name(args):
       target = args.target.replace(' ', '_')
   
       return f"{args.model_name}__{args.rpc_key}__{target}__{args.n_trials}.log"
   
   
   def get_task_tuner(task, tuner_name):
       if tuner_name == "xgb" or tuner_name == "xgb-rank":
           tuner = XGBTuner(task, loss_type="rank")
       elif tuner_name == "ga":
           tuner = GATuner(task, pop_size=50)
       elif tuner_name == "random":
           tuner = RandomTuner(task)
       elif tuner_name == "gridsearch":
           tuner = GridSearchTuner(task)
       else:
           raise ValueError("Invalid tuner: " + tuner_name)
       return tuner
   
   
   def tune_model(args, mod, params, target):
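       # Note: build_func="ndk" below uses tvm.contrib.ndk.create_shared, which
       # reads the TVM_NDK_CC environment variable to locate the Android NDK
       # compiler used to cross-compile the measurement artifacts.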
       measure_option = autotvm.measure_option(
           builder=autotvm.LocalBuilder(build_func="ndk"),
           runner=autotvm.RPCRunner(
               args.rpc_key,
               host=args.rpc_tracker,
               port=args.rpc_port,
               number=args.runner_number,
               repeat=args.runner_repeat,
               timeout=60,
           )
       )
       
       tasks = autotvm.task.extract_from_program(
           mod["main"],
           target=target,
           params=params,
           ops=(relay.op.get("nn.conv2d"),),
       ) # TODO: parametrize ops which will be tuned
   
       if not tasks:
           print("No tasks...")
   
       task_log_filename = None
       for i, task in enumerate(reversed(tasks)):
           prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
           task_log_filename = args.output_tuning_log + ".tmp"
   
           task_tuner = get_task_tuner(task, "xgb")
   
           if os.path.isfile(task_log_filename):
               task_tuner.load_history(
                   autotvm.record.load_from_file(task_log_filename))
   
           n_trials = min(args.n_trials, len(task.config_space))
           task_tuner.tune(
               n_trial=n_trials,
               early_stopping=args.early_stopping,
               measure_option=measure_option,
               callbacks=[
                   autotvm.callback.progress_bar(n_trials, prefix=prefix),
                   autotvm.callback.log_to_file(task_log_filename),
               ],
           )
   
       if task_log_filename:
           # pick best records to a cache file
           autotvm.record.pick_best(task_log_filename, args.output_tuning_log)
           os.remove(task_log_filename)
   
   
   
   def main(args):
       target = tvm.target.Target(args.target, host=args.target_host)
   
       mod, params = load_model(args)
   
       tune_model(args, mod, params, target)
   
   
   if __name__ == "__main__":
       parser = argparse.ArgumentParser()
   
       parser.add_argument('--model_name', required=True,
           help="Name of the model. Used to generate the tuning log file "
           "name when --output_tuning_log is not specified. No spaces.")
       parser.add_argument('--input_model', required=True,
           help="Full path to the model file")
       parser.add_argument('--input_name', required=True,
           help="Name of input node")
       parser.add_argument('--input_shape', required=True,
           help="Shape of the input node, comma-separated, no spaces.")
       parser.add_argument('--input_dtype',
           default="float32",
           help="Dtype of the input node")
   
       parser.add_argument('--rpc_tracker', 
           required=True,
           help="IP address of RPC tracker")
       parser.add_argument('--rpc_port',
           type=int, required=True,
           help="Port of the RPC tracker")
       parser.add_argument('--rpc_key',
           required=True,
           help="Device key registered with the RPC tracker")
   
       parser.add_argument('--output_tuning_log', 
           default=None,
           help="Where to save tuning output to be used for benchmark.")
       parser.add_argument('--runner_number', type=int, default=4,
           help="Number of runs averaged within one measurement "
           "(passed to RPCRunner as 'number')")
       parser.add_argument('--runner_repeat', type=int, default=3,
           help="Number of repeated measurements per config "
           "(passed to RPCRunner as 'repeat')")
       parser.add_argument('--n_trials',
           type=int, default=10,
           help="Number of trials per task. Must be larger than 1. Typically 2000.")
       parser.add_argument('--early_stopping',
           type=int, default=400,
           help="Stop tuning a task early after this many trials without improvement.")
   
       parser.add_argument('--target', default="opencl",
           help="Target string passed to tvm.target.Target")
       parser.add_argument('--target_host',
           default="llvm -mtriple=aarch64-linux-gnu",
           help="Target host string passed to tvm.target.Target")
   
       args = parser.parse_args()
       args.input_shape = eval(args.input_shape)
       if not args.output_tuning_log:
           args.output_tuning_log = create_tuning_log_name(args)
   
       main(args)
   ```
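
   To double-check the tracker side independently of autotvm, here is a minimal sketch (the host, port, and key below are placeholders; substitute the same values passed via --rpc_tracker/--rpc_port/--rpc_key) that prints the tracker queue and requests a session the same way the autotvm RPCRunner does:
   ```python
   import tvm.rpc

   # Placeholder values; use the same --rpc_tracker/--rpc_port/--rpc_key
   # that are passed to the tuning script.
   tracker_host = "0.0.0.0"
   tracker_port = 9190
   rpc_key = "android"

   # Print the tracker queue: the device should appear under rpc_key.
   tracker = tvm.rpc.connect_tracker(tracker_host, tracker_port)
   print(tracker.text_summary())

   # Request a session through the tracker, as the autotvm runner does internally.
   remote = tracker.request(rpc_key, priority=1, session_timeout=60)
   print(remote.cpu(0))
   ```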

