You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by jw...@apache.org on 2021/04/03 03:36:09 UTC
[tvm] branch main updated: [TVMC] Allow direct numpy inputs to run_module (#7788)
This is an automated email from the ASF dual-hosted git repository.
jwfromm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1c1a2b5 [TVMC] Allow direct numpy inputs to run_module (#7788)
1c1a2b5 is described below
commit 1c1a2b59e3cc24aff66820da4987340a135fba66
Author: CircleSpin <2k...@gmail.com>
AuthorDate: Fri Apr 2 23:35:43 2021 -0400
[TVMC] Allow direct numpy inputs to run_module (#7788)
* progress, graph params need to figure out
* black and lint
* change np.load(inputs_file) to happen in drive_run
* make inputs optional
Co-authored-by: Jocelyn <jo...@pop-os.localdomain>
---
python/tvm/driver/tvmc/runner.py | 29 ++++++++++++++++-------------
tests/python/driver/tvmc/test_runner.py | 4 +++-
2 files changed, 19 insertions(+), 14 deletions(-)
diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py
index b4c4e75..d69e71f 100644
--- a/python/tvm/driver/tvmc/runner.py
+++ b/python/tvm/driver/tvmc/runner.py
@@ -107,12 +107,17 @@ def drive_run(args):
rpc_hostname, rpc_port = common.tracker_host_port_from_cli(args.rpc_tracker)
+ try:
+ inputs = np.load(args.inputs) if args.inputs else {}
+ except IOError as ex:
+ raise TVMCException("Error loading inputs file: %s" % ex)
+
outputs, times = run_module(
args.FILE,
rpc_hostname,
rpc_port,
args.rpc_key,
- inputs_file=args.inputs,
+ inputs=inputs,
device=args.device,
fill_mode=args.fill_mode,
repeat=args.repeat,
@@ -221,7 +226,7 @@ def generate_tensor_data(shape, dtype, fill_mode):
return tensor
-def make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode):
+def make_inputs_dict(shape_dict, dtype_dict, inputs=None, fill_mode="random"):
"""Make the inputs dictionary for a graph.
Use data from 'inputs' where specified. For input tensors
@@ -230,13 +235,13 @@ def make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode):
Parameters
----------
- inputs_file : str
- Path to a .npz file containing the inputs.
shape_dict : dict
Shape dictionary - {input_name: tuple}.
dtype_dict : dict
dtype dictionary - {input_name: dtype}.
- fill_mode : str
+ inputs : dict, optional
+ A dictionary that maps input names to numpy values.
+ fill_mode : str, optional
The fill-mode to use when generating tensor data.
Can be either "zeros", "ones" or "random".
@@ -247,10 +252,8 @@ def make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode):
"""
logger.debug("creating inputs dict")
- try:
- inputs = np.load(inputs_file) if inputs_file else {}
- except IOError as ex:
- raise TVMCException("Error loading inputs file: %s" % ex)
+ if inputs is None:
+ inputs = {}
# First check all the keys in inputs exist in the graph
for input_name in inputs:
@@ -291,7 +294,7 @@ def run_module(
port=9090,
rpc_key=None,
device=None,
- inputs_file=None,
+ inputs=None,
fill_mode="random",
repeat=1,
profile=False,
@@ -316,8 +319,8 @@ def run_module(
device: str, optional
the device (e.g. "cpu" or "gpu") to be targeted by the RPC
session, local or remote).
- inputs_file : str, optional
- Path to an .npz file containing the inputs.
+ inputs : dict, optional
+ A dictionary that maps input names to numpy values.
fill_mode : str, optional
The fill-mode to use when generating data for input tensors.
Valid options are "zeros", "ones" and "random".
@@ -379,7 +382,7 @@ def run_module(
module.load_params(params)
shape_dict, dtype_dict = get_input_info(graph, params)
- inputs_dict = make_inputs_dict(inputs_file, shape_dict, dtype_dict, fill_mode)
+ inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode)
logger.debug("setting inputs to the module")
module.set_input(**inputs_dict)
diff --git a/tests/python/driver/tvmc/test_runner.py b/tests/python/driver/tvmc/test_runner.py
index 5fdf58f..366a6df 100644
--- a/tests/python/driver/tvmc/test_runner.py
+++ b/tests/python/driver/tvmc/test_runner.py
@@ -73,9 +73,11 @@ def test_run_tflite_module__with_profile__valid_input(
# some CI environments wont offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")
+ inputs = np.load(imagenet_cat)
+
outputs, times = tvmc.run(
tflite_compiled_module_as_tarfile,
- inputs_file=imagenet_cat,
+ inputs=inputs,
hostname=None,
device="cpu",
profile=True,