You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/06/12 22:11:44 UTC
[incubator-tvm] branch master updated: Edit onnx parser to infer values in post order (#5755)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 995b9ff Edit onnx parser to infer values in post order (#5755)
995b9ff is described below
commit 995b9ff8a452bb46b64080b9b1fc0f10f0a778cf
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Fri Jun 12 15:11:34 2020 -0700
Edit onnx parser to infer values in post order (#5755)
* edit onnx parser to infer values in post order to speed up onnx imports with many calls to infer_value
* fix pylint
---
python/tvm/relay/frontend/onnx.py | 119 +++++++++++++++++++++++++++++++++++++-
1 file changed, 116 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 17cb148..dabe55f 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -27,12 +27,29 @@ from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from .. import vision as _vision
+
+from ..function import Function
+from ..expr import Call, Let
+from ..expr import If, Tuple, TupleGetItem
+from ..expr import RefCreate, RefRead, RefWrite
+from ..expr_functor import ExprFunctor
+from ..adt import Match, Clause
+
from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels
-from .common import infer_type, infer_value, infer_value_simulated, get_name
+from .common import infer_type, get_name
+from .common import infer_value as _infer_value
+from .common import infer_value_simulated as _infer_value_simulated
__all__ = ['from_onnx']
+g = None
+
+def infer_value(input_val, params, mod=None):
+ return g.infer_value(input_val, params, mod)
+
+def infer_value_simulated(input_val, params):
+ return g.infer_value_simulated(input_val, params)
class onnx_input():
""" Dual purpose list or dictionary access object."""
@@ -1891,8 +1908,7 @@ def _get_convert_map(opset):
'NonZero': NonZero.get_converter(opset),
}
-
-class GraphProto(object):
+class GraphProto(ExprFunctor):
"""A helper class for handling Relay expression copying from pb2.GraphProto.
Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
@@ -1914,6 +1930,101 @@ class GraphProto(object):
self._shape = shape if shape else {}
self._dtype = dtype
+ #For infering Values
+ self._tmp_params = {}
+ self._infer_simulated = True
+ self._mod = None
+ super(GraphProto, self).__init__()
+
+ def infer_value(self, input_val, params, mod=None):
+ self._tmp_params = params
+ self._infer_simulated = False
+ self._mod = mod
+ return self.visit(input_val).data
+ #return _infer_value(input_val, params, mod)
+
+ def infer_value_simulated(self, input_val, params):
+ self._tmp_params = params
+ self._infer_simulated = True
+ return self.visit(input_val).data
+ #return _infer_value_simulated(input_val, params)
+
+ def infer(self, expr):
+ if self._infer_simulated:
+ out = _infer_value_simulated(expr, self._tmp_params)
+ else:
+ out = _infer_value(expr, self._tmp_params)
+ return _expr.const(out.asnumpy())
+
+ def visit_function(self, fn):
+ new_params = [self.visit(x) for x in fn.params]
+ new_body = self.visit(fn.body)
+ return self.infer(Function(
+ list(new_params),
+ new_body,
+ fn.ret_type,
+ fn.type_params,
+ fn.attrs))
+
+ def visit_let(self, let):
+ newvar = self.visit(let.var)
+ newval = self.visit(let.value)
+ newbody = self.visit(let.body)
+ return self.infer(Let(newvar, newval, newbody))
+
+ def visit_call(self, call):
+ new_fn = self.visit(call.op)
+ new_args = [self.visit(arg) for arg in call.args]
+ return self.infer(Call(new_fn, new_args, call.attrs))
+
+ def visit_var(self, var):
+ return self.infer(var)
+
+ def visit_global_id(self, global_var):
+ return self.infer(global_var)
+
+ def visit_if(self, ite):
+ return self.infer(If(
+ self.visit(ite.cond),
+ self.visit(ite.true_branch),
+ self.visit(ite.false_branch)))
+
+ def visit_tuple(self, tup):
+ return Tuple([self.visit(field) for field in tup.fields])
+
+ def visit_tuple_getitem(self, op):
+ tuple_value = self.visit(op.tuple_value)
+ if not tuple_value.same_as(op.tuple_value):
+ return self.infer(TupleGetItem(tuple_value, op.index))
+ return self.infer(op)
+
+ def visit_global_var(self, gvar):
+ return self.infer(gvar)
+
+ def visit_op(self, op):
+ return op
+
+ def visit_constant(self, const):
+ return const
+
+ def visit_constructor(self, con):
+ return con
+
+ def visit_match(self, m):
+ return self.infer(Match(
+ self.visit(m.data),
+ [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses],
+ complete=m.complete))
+
+ def visit_ref_create(self, r):
+ return RefCreate(self.visit(r.value))
+
+ def visit_ref_write(self, r):
+ return RefWrite(self.visit(r.ref), self.visit(r.value))
+
+ def visit_ref_read(self, r):
+ return RefRead(self.visit(r.ref))
+
def from_onnx(self, graph, opset):
"""Construct Relay expression from ONNX graph.
@@ -2172,6 +2283,7 @@ def from_onnx(model,
warnings.warn(str(e))
except ImportError:
pass
+ global g
g = GraphProto(shape, dtype)
graph = model.graph
if opset is None:
@@ -2180,4 +2292,5 @@ def from_onnx(model,
except AttributeError:
opset = 1
mod, params = g.from_onnx(graph, opset)
+ g = None
return mod, params