You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/06/12 22:11:44 UTC

[incubator-tvm] branch master updated: Edit onnx parser to infer values in post order (#5755)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 995b9ff  Edit onnx parser to infer values in post order (#5755)
995b9ff is described below

commit 995b9ff8a452bb46b64080b9b1fc0f10f0a778cf
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Fri Jun 12 15:11:34 2020 -0700

    Edit onnx parser to infer values in post order (#5755)
    
    * edit onnx parser to infer values in post order to speed up onnx imports with many calls to infer_value
    
    * fix pylint
---
 python/tvm/relay/frontend/onnx.py | 119 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 116 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 17cb148..dabe55f 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -27,12 +27,29 @@ from .. import expr as _expr
 from .. import function as _function
 from .. import op as _op
 from .. import vision as _vision
+# Relay IR classes and the ExprFunctor base used by GraphProto's post-order value folding.
+from ..function import Function
+from ..expr import Call, Let
+from ..expr import If, Tuple, TupleGetItem
+from ..expr import RefCreate, RefRead, RefWrite
+from ..expr_functor import ExprFunctor
+from ..adt import Match, Clause
+
 from .common import AttrCvt, Renamer
 from .common import get_relay_op, new_var, infer_shape, infer_channels
-from .common import infer_type, infer_value, infer_value_simulated, get_name
+from .common import infer_type, get_name
+from .common import infer_value as _infer_value
+from .common import infer_value_simulated as _infer_value_simulated
 
 __all__ = ['from_onnx']
 
+g = None  # active GraphProto: assigned in from_onnx(), reset to None once g.from_onnx() returns
+# Shadow common.infer_value*: calls to infer_value* in this module are routed through g.
+def infer_value(input_val, params, mod=None):
+    return g.infer_value(input_val, params, mod)
+
+def infer_value_simulated(input_val, params):
+    return g.infer_value_simulated(input_val, params)
 
 class onnx_input():
     """ Dual purpose list or dictionary access object."""
@@ -1891,8 +1908,7 @@ def _get_convert_map(opset):
         'NonZero': NonZero.get_converter(opset),
     }
 
-
-class GraphProto(object):
+class GraphProto(ExprFunctor):  # visit_* overrides fold subexpressions to constants
     """A helper class for handling Relay expression copying from pb2.GraphProto.
     Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
 
@@ -1914,6 +1930,101 @@ class GraphProto(object):
         self._shape = shape if shape else {}
         self._dtype = dtype
 
+        # State consulted by infer(); set by infer_value / infer_value_simulated.
+        self._tmp_params = {}
+        self._infer_simulated = True
+        self._mod = None
+        super(GraphProto, self).__init__()
+
+    def infer_value(self, input_val, params, mod=None):
+        self._tmp_params = params
+        self._infer_simulated = False
+        self._mod = mod
+        return self.visit(input_val).data
+        # visit() folds every subexpression to a constant, children first (post order).
+
+    def infer_value_simulated(self, input_val, params):
+        self._tmp_params = params
+        self._infer_simulated = True
+        return self.visit(input_val).data
+        # Same traversal as infer_value, but infer() uses _infer_value_simulated.
+
+    def infer(self, expr):  # fold expr to a Relay constant via the selected infer helper
+        if self._infer_simulated:
+            out = _infer_value_simulated(expr, self._tmp_params)
+        else:
+            out = _infer_value(expr, self._tmp_params, self._mod)
+        return _expr.const(out.asnumpy())
+
+    def visit_function(self, fn):
+        new_params = [self.visit(x) for x in fn.params]
+        new_body = self.visit(fn.body)
+        return self.infer(Function(
+            list(new_params),
+            new_body,
+            fn.ret_type,
+            fn.type_params,
+            fn.attrs))
+
+    def visit_let(self, let):
+        newvar = self.visit(let.var)
+        newval = self.visit(let.value)
+        newbody = self.visit(let.body)
+        return self.infer(Let(newvar, newval, newbody))
+
+    def visit_call(self, call):
+        new_fn = self.visit(call.op)
+        new_args = [self.visit(arg) for arg in call.args]
+        return self.infer(Call(new_fn, new_args, call.attrs))
+
+    def visit_var(self, var):
+        return self.infer(var)
+
+    def visit_global_id(self, global_var):  # NOTE(review): likely dead; see visit_global_var
+        return self.infer(global_var)
+
+    def visit_if(self, ite):
+        return self.infer(If(
+            self.visit(ite.cond),
+            self.visit(ite.true_branch),
+            self.visit(ite.false_branch)))
+
+    def visit_tuple(self, tup):
+        return Tuple([self.visit(field) for field in tup.fields])
+
+    def visit_tuple_getitem(self, op):
+        tuple_value = self.visit(op.tuple_value)
+        if not tuple_value.same_as(op.tuple_value):
+            return self.infer(TupleGetItem(tuple_value, op.index))
+        return self.infer(op)
+
+    def visit_global_var(self, gvar):
+        return self.infer(gvar)
+
+    def visit_op(self, op):
+        return op
+
+    def visit_constant(self, const):
+        return const
+
+    def visit_constructor(self, con):
+        return con
+
+    def visit_match(self, m):
+        return self.infer(Match(
+            self.visit(m.data),
+            [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses],
+            complete=m.complete))
+
+    def visit_ref_create(self, r):
+        return RefCreate(self.visit(r.value))
+
+    def visit_ref_write(self, r):
+        return RefWrite(self.visit(r.ref), self.visit(r.value))
+
+    def visit_ref_read(self, r):
+        return RefRead(self.visit(r.ref))
+
     def from_onnx(self, graph, opset):
         """Construct Relay expression from ONNX graph.
 
@@ -2172,6 +2283,7 @@ def from_onnx(model,
                 warnings.warn(str(e))
     except ImportError:
         pass
+    global g  # rebinds the module-level g read by the infer_value shims
     g = GraphProto(shape, dtype)
     graph = model.graph
     if opset is None:
@@ -2180,4 +2292,5 @@ def from_onnx(model,
         except AttributeError:
             opset = 1
     mod, params = g.from_onnx(graph, opset)
+    g = None  # NOTE(review): skipped if g.from_onnx raises; try/finally would always reset g
     return mod, params