You are viewing a plain text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text version.
Posted to commits@tvm.apache.org by jw...@apache.org on 2021/04/03 03:36:09 UTC

[tvm] branch main updated: [TVMC] Allow direct numpy inputs to run_module (#7788)

This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1c1a2b5  [TVMC] Allow direct numpy inputs to run_module (#7788)
1c1a2b5 is described below

commit 1c1a2b59e3cc24aff66820da4987340a135fba66
Author: CircleSpin <2k...@gmail.com>
AuthorDate: Fri Apr 2 23:35:43 2021 -0400

    [TVMC] Allow direct numpy inputs to run_module (#7788)
    
    * progress; handling of graph params still needs to be figured out
    
    * black and lint
    
    * change np.load(inputs_file) to happen in drive_run
    
    * make inputs optional
    
    Co-authored-by: Jocelyn <jo...@pop-os.localdomain>
---
 python/tvm/driver/tvmc/runner.py        | 29 ++++++++++++++++-------------
 tests/python/driver/tvmc/test_runner.py |  4 +++-
 2 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py
index b4c4e75..d69e71f 100644
--- a/python/tvm/driver/tvmc/runner.py
+++ b/python/tvm/driver/tvmc/runner.py
@@ -107,12 +107,17 @@ def drive_run(args):
 
     rpc_hostname, rpc_port = common.tracker_host_port_from_cli(args.rpc_tracker)
 
+    try:
+        inputs = np.load(args.inputs) if args.inputs else {}
+    except IOError as ex:
+        raise TVMCException("Error loading inputs file: %s" % ex)
+
     outputs, times = run_module(
         args.FILE,
         rpc_hostname,
         rpc_port,
         args.rpc_key,
-        inputs_file=args.inputs,
+        inputs=inputs,
         device=args.device,
         fill_mode=args.fill_mode,
         repeat=args.repeat,
@@ -221,7 +226,7 @@ def generate_tensor_data(shape, dtype, fill_mode):
     return tensor
 
 
-def make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode):
+def make_inputs_dict(shape_dict, dtype_dict, inputs=None, fill_mode="random"):
     """Make the inputs dictionary for a graph.
 
     Use data from 'inputs' where specified. For input tensors
@@ -230,13 +235,13 @@ def make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode):
 
     Parameters
     ----------
-    inputs_file : str
-        Path to a .npz file containing the inputs.
     shape_dict : dict
         Shape dictionary - {input_name: tuple}.
     dtype_dict : dict
         dtype dictionary - {input_name: dtype}.
-    fill_mode : str
+    inputs : dict, optional
+        A dictionary that maps input names to numpy values.
+    fill_mode : str, optional
         The fill-mode to use when generating tensor data.
         Can be either "zeros", "ones" or "random".
 
@@ -247,10 +252,8 @@ def make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode):
     """
     logger.debug("creating inputs dict")
 
-    try:
-        inputs = np.load(inputs_file) if inputs_file else {}
-    except IOError as ex:
-        raise TVMCException("Error loading inputs file: %s" % ex)
+    if inputs is None:
+        inputs = {}
 
     # First check all the keys in inputs exist in the graph
     for input_name in inputs:
@@ -291,7 +294,7 @@ def run_module(
     port=9090,
     rpc_key=None,
     device=None,
-    inputs_file=None,
+    inputs=None,
     fill_mode="random",
     repeat=1,
     profile=False,
@@ -316,8 +319,8 @@ def run_module(
     device: str, optional
         the device (e.g. "cpu" or "gpu") to be targeted by the RPC
         session, local or remote).
-    inputs_file : str, optional
-        Path to an .npz file containing the inputs.
+    inputs : dict, optional
+        A dictionary that maps input names to numpy values.
     fill_mode : str, optional
         The fill-mode to use when generating data for input tensors.
         Valid options are "zeros", "ones" and "random".
@@ -379,7 +382,7 @@ def run_module(
         module.load_params(params)
 
         shape_dict, dtype_dict = get_input_info(graph, params)
-        inputs_dict = make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode)
+        inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode)
 
         logger.debug("setting inputs to the module")
         module.set_input(**inputs_dict)
diff --git a/tests/python/driver/tvmc/test_runner.py b/tests/python/driver/tvmc/test_runner.py
index 5fdf58f..366a6df 100644
--- a/tests/python/driver/tvmc/test_runner.py
+++ b/tests/python/driver/tvmc/test_runner.py
@@ -73,9 +73,11 @@ def test_run_tflite_module__with_profile__valid_input(
     # some CI environments wont offer TFLite, so skip in case it is not present
     pytest.importorskip("tflite")
 
+    inputs = np.load(imagenet_cat)
+
     outputs, times = tvmc.run(
         tflite_compiled_module_as_tarfile,
-        inputs_file=imagenet_cat,
+        inputs=inputs,
         hostname=None,
         device="cpu",
         profile=True,