You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/03/07 00:41:01 UTC
[incubator-tvm] branch master updated: [Frontend][Torch] Check graph inputs match expected (#4992)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new de34649 [Frontend][Torch] Check graph inputs match expected (#4992)
de34649 is described below
commit de34649330b17d4278b01785893a862874a35ce3
Author: Jeremy Johnson <je...@arm.com>
AuthorDate: Sat Mar 7 00:40:50 2020 +0000
[Frontend][Torch] Check graph inputs match expected (#4992)
* [Frontend][Torch] Check graph inputs match expected
* error/warn when missing/unused graph inputs
* Change to use get_graph_input_names
---
python/tvm/relay/frontend/pytorch.py | 15 +++++++++++++++
1 file changed, 15 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 5716837..ff37f82 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -905,6 +905,20 @@ def _report_missing_conversion(op_names):
msg = "The following operators are not implemented: {}".format(missing)
raise NotImplementedError(msg)
+def _check_input_names(script_module, input_shapes):
+    """ Check that the graph's input names match the names in input_shapes.
+
+    Parameters
+    ----------
+    script_module : the module passed to from_pytorch
+        The module whose graph input names are obtained via
+        get_graph_input_names.
+
+    input_shapes : container of input names
+        Only membership tests and iteration are applied to it here, so its
+        entries are treated as input names (presumably the keys of a
+        name -> shape mapping; TODO confirm against parse_inputs).
+
+    Raises
+    ------
+    RuntimeError
+        On the first graph input whose name is absent from input_shapes.
+
+    Entries of input_shapes that are not graph inputs only produce a
+    logging.warning each; this function does not raise for them.
+    """
+    ir_inputs = get_graph_input_names(script_module)
+
+    # Every graph input needs a matching entry; fail fast on the first miss.
+    for ir_input in ir_inputs:
+        if ir_input not in input_shapes:
+            msg = "Missing graph input {} in input_shapes".format(ir_input)
+            raise RuntimeError(msg)
+
+    # Reverse direction: names the graph does not take are reported, not fatal.
+    for input_name in input_shapes:
+        if input_name not in ir_inputs:
+            msg = "Unused graph input {} in input_shapes".format(input_name)
+            logging.warning(msg)
+
def _getattr_attr_name(node):
attribute_names = node.attributeNames()
@@ -1150,6 +1164,7 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
op_names = get_all_op_names(graph)
_report_missing_conversion(op_names)
+ _check_input_names(script_module, input_shapes)
params = script_module.state_dict()
input_vars = parse_inputs(graph.inputs(), input_shapes)