You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/03/07 00:41:01 UTC

[incubator-tvm] branch master updated: [Frontend][Torch] Check graph inputs match expected (#4992)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new de34649  [Frontend][Torch] Check graph inputs match expected (#4992)
de34649 is described below

commit de34649330b17d4278b01785893a862874a35ce3
Author: Jeremy Johnson <je...@arm.com>
AuthorDate: Sat Mar 7 00:40:50 2020 +0000

    [Frontend][Torch] Check graph inputs match expected (#4992)
    
    * [Frontend][Torch] Check graph inputs match expected
    
    * Raise an error when a graph input is missing from input_shapes; warn when input_shapes contains unused names
    
    * Change to use get_graph_input_names
---
 python/tvm/relay/frontend/pytorch.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 5716837..ff37f82 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -905,6 +905,20 @@ def _report_missing_conversion(op_names):
         msg = "The following operators are not implemented: {}".format(missing)
         raise NotImplementedError(msg)
 
+def _check_input_names(script_module, input_shapes):
+    """ Raise RuntimeError if a graph input is missing from input_shapes; warn about unused names """
+    ir_inputs = get_graph_input_names(script_module)  # presumably the graph's input names - verify helper
+
+    for ir_input in ir_inputs:  # every graph input must be supplied, otherwise conversion aborts
+        if ir_input not in input_shapes:
+            msg = "Missing graph input {} in input_shapes".format(ir_input)
+            raise RuntimeError(msg)
+
+    for input_name in input_shapes:  # extra user-supplied names are tolerated, only logged
+        if input_name not in ir_inputs:
+            msg = "Unused graph input {} in input_shapes".format(input_name)
+            logging.warning(msg)
+
 
 def _getattr_attr_name(node):
     attribute_names = node.attributeNames()
@@ -1150,6 +1164,7 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
 
     op_names = get_all_op_names(graph)
     _report_missing_conversion(op_names)
+    _check_input_names(script_module, input_shapes)
 
     params = script_module.state_dict()
     input_vars = parse_inputs(graph.inputs(), input_shapes)