Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/03 21:41:34 UTC

[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #6351: Dynamic ONNX Importer

jwfromm commented on a change in pull request #6351:
URL: https://github.com/apache/incubator-tvm/pull/6351#discussion_r483107517



##########
File path: include/tvm/relay/transform.h
##########
@@ -208,6 +208,17 @@ TVM_DLL Pass SimplifyInference();
  */
 TVM_DLL Pass FastMath();
 
+/*!
+ * \brief Find Dynamic ops and make them static
+ *
+ * Searches the graph for dynamic ops. If the dynamic inputs to those ops are constants, it replaces
+ * them with static ops and re-performs type inference and constant folding. The pass repeats
+ * istself until the graph stops changing or we run too many iterations.

Review comment:
       'istself' is a typo; should be 'itself'.
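
       (For context, the docstring describes a fixed-point iteration. A minimal
       Python sketch of the idea, assuming the pass is exposed to Python as
       `relay.transform.DynamicToStatic()` — the pass runs this loop internally,
       so this is an illustration of the described behavior, not the
       implementation:

           import tvm
           from tvm import relay

           def make_static(mod, max_iters=10):
               """Fold constants and rewrite dynamic ops until the graph is stable."""
               for _ in range(max_iters):
                   before = mod["main"]
                   mod = relay.transform.InferType()(mod)        # re-perform type inference
                   mod = relay.transform.FoldConstant()(mod)     # fold newly-constant inputs
                   mod = relay.transform.DynamicToStatic()(mod)  # replace dynamic ops whose
                                                                 # inputs are now constant
                   if tvm.ir.structural_equal(before, mod["main"]):
                       break  # graph stopped changing: fixed point reached
               return mod
       )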

##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -2056,113 +2015,25 @@ class GraphProto(ExprFunctor):
     dtype : str or dict of str to str
         The input types to the graph
     """
-
     def __init__(self, shape, dtype):
         self._nodes = {}
         self._params = {}
+        self._inputs = {}
         self._renames = {}
         self._num_input = 0
         self._num_param = 0
         self._shape = shape if shape else {}
         self._dtype = dtype
 
-        #For infering Values
-        self._tmp_params = {}
-        self._infer_simulated = True
-        self._mod = None
-        super(GraphProto, self).__init__()
-
-    def infer_value(self, input_val, params, mod=None):
-        self._tmp_params = params
-        self._infer_simulated = False
-        self._mod = mod
-        return self.visit(input_val).data
-
-    def infer_value_simulated(self, input_val, params):
-        self._tmp_params = params
-        self._infer_simulated = True
-        return self.visit(input_val).data
-
-    def infer(self, expr):
-        if self._infer_simulated:
-            out = _infer_value_simulated(expr, self._tmp_params)
-        else:
-            out = _infer_value(expr, self._tmp_params)
-        return _expr.const(out.asnumpy())
-
-    def visit_function(self, fn):
-        new_params = [self.visit(x) for x in fn.params]
-        new_body = self.visit(fn.body)
-        return self.infer(Function(
-            list(new_params),
-            new_body,
-            fn.ret_type,
-            fn.type_params,
-            fn.attrs))
-
-    def visit_let(self, let):
-        newvar = self.visit(let.var)
-        newval = self.visit(let.value)
-        newbody = self.visit(let.body)
-        return self.infer(Let(newvar, newval, newbody))
-
-    def visit_call(self, call):
-        new_fn = self.visit(call.op)
-        new_args = [self.visit(arg) for arg in call.args]
-        call = Call(new_fn, new_args, call.attrs)
-        if new_fn == _op.get("nn.batch_norm"):
-            return call
-        return self.infer(call)
-
-    def visit_var(self, var):
-        return self.infer(var)
-
-    def visit_global_id(self, global_var):
-        return self.infer(global_var)
-
-    def visit_if(self, ite):
-        return self.infer(If(
-            self.visit(ite.cond),
-            self.visit(ite.true_branch),
-            self.visit(ite.false_branch)))
-
-    def visit_tuple(self, tup):
-        return Tuple([self.visit(field) for field in tup.fields])
-
-    def visit_tuple_getitem(self, op):
-        tuple_value = self.visit(op.tuple_value)
-        if not tuple_value.same_as(op.tuple_value):
-            return self.infer(TupleGetItem(tuple_value, op.index))
-        return self.infer(op)
-
-    def visit_global_var(self, gvar):
-        return self.infer(gvar)
-
-    def visit_op(self, op):
-        return op
-
-    def visit_constant(self, const):
-        return const
-
-    def visit_constructor(self, con):
-        return con
-
-    def visit_match(self, m):
-        return self.infer(Match(
-            self.visit(m.data),
-            [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses],
-            complete=m.complete))
-
-    def visit_ref_create(self, r):
-        return RefCreate(self.visit(r.value))
-
-    def visit_ref_write(self, r):
-        return RefWrite(self.visit(r.ref), self.visit(r.value))
-
-    def visit_ref_read(self, r):
-        return RefRead(self.visit(r.ref))
-
-    def from_onnx(self, graph, opset):
+    def freeze(self, func, params):
+        bind_map = {}
+        for name in params.keys():
+            bind_map[self._nodes[name]] = _expr.const(params[name])
+        body = _expr.bind(func.body, bind_map)
+        fn = _function.Function(analysis.free_vars(body), body)
+        return fn, {}
+
+    def from_onnx(self, graph, opset, freeze_params=False):

Review comment:
       Add a description for `freeze_params`. Specifically, it's unclear when we would need or want to use it.
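
       For reference, a sketch of the intended use, going by the `freeze`
       helper above (the model path is hypothetical; with `freeze_params=True`
       the weights are bound into the function as constants, so the returned
       params dict comes back empty):

           import onnx
           from tvm import relay

           model = onnx.load("model.onnx")  # hypothetical model file
           # Binding weights as relay constants lets later passes fold shape
           # computations that depend only on the weights.
           mod, params = relay.frontend.from_onnx(model, freeze_params=True)
           assert not params  # `freeze` returns an empty params dict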

##########
File path: tests/python/frontend/onnx/test_forward.py
##########
@@ -44,17 +44,19 @@ def get_input_data_shape_dict(graph_def, input_data):
     return input_names, shape_dict
 
 
-def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None):
+def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None, freeze_params=False):
     """ Generic function to execute and get tvm output with vm executor"""
+    if not isinstance(input_data, list):
+        input_data = [input_data]
+    input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data)

Review comment:
       `input_names` doesn't seem to be used; maybe put a `_` in its place to indicate this.
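
       i.e.:

           # `_` marks the unused first return value
           _, shape_dict = get_input_data_shape_dict(graph_def, input_data)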

##########
File path: python/tvm/relay/frontend/onnx.py
##########
@@ -2372,10 +2248,8 @@ def _fix_outputs(self, op_name, outputs):
             outputs = outputs[:-1]
         return outputs
 
-def from_onnx(model,
-              shape=None,
-              dtype="float32",
-              opset=None):
+
+def from_onnx(model, shape=None, dtype="float32", opset=None, freeze_params=False):

Review comment:
       This also needs an argument description for `freeze_params`.
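
       A possible wording, inferred from the `freeze` helper above (a
       suggestion, not the PR's actual text):

           freeze_params : bool
               If True, bind all weights found in `params` into the Relay
               function as constants and return an empty params dict. This
               lets passes such as constant folding and dynamic-to-static
               conversion resolve shapes that depend only on the weights,
               at the cost of not being able to swap weights at runtime.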




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org