You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/04/24 17:11:57 UTC
[tvm] branch main updated: [Frontend][Tensorflow] SelectV2 and
BroadcastArgs op support for tf2 models (#7901)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fad10d7 [Frontend][Tensorflow] SelectV2 and BroadcastArgs op support for tf2 models (#7901)
fad10d7 is described below
commit fad10d7914d3807e56501d7cc239d505da175e07
Author: srinidhigoud <sr...@gmail.com>
AuthorDate: Sat Apr 24 10:11:31 2021 -0700
[Frontend][Tensorflow] SelectV2 and BroadcastArgs op support for tf2 models (#7901)
---
python/tvm/relay/frontend/tensorflow.py | 43 ++++++++++++++++++++++--
tests/python/frontend/tensorflow/test_forward.py | 29 +++++++++++++++-
2 files changed, 68 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 997f68b..f566a3f 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -19,6 +19,7 @@
"""TF: Tensorflow frontend."""
import warnings
from collections import defaultdict
+from collections import deque
# Numpy support
import numpy as np
@@ -1770,6 +1771,43 @@ def _bias_add():
return _impl
+def _broadcast_args():
+    """Convert TF ``BroadcastArgs`` into a relay constant.
+
+    Both inputs are shape vectors whose values must be available at
+    conversion time: a relay ``Var`` is looked up in ``params``, anything
+    else is constant-folded with ``_infer_value``. The broadcast shape is
+    computed here and returned as a constant of dtype ``attr["T"]``.
+    """
+
+    def _impl(inputs, attr, params, mod):
+        if isinstance(inputs[0], _expr.Var):
+            s0 = params[inputs[0].name_hint]
+        else:
+            s0 = _infer_value(inputs[0], params, mod)
+        if isinstance(inputs[1], _expr.Var):
+            s1 = params[inputs[1].name_hint]
+        else:
+            s1 = _infer_value(inputs[1], params, mod)
+        # Flatten to plain 1-D lists of dimension sizes.
+        s0 = list(s0.asnumpy().reshape([-1]))
+        s1 = list(s1.asnumpy().reshape([-1]))
+        s0_size, s1_size = len(s0), len(s1)
+
+        # Align the two shapes at their trailing dimension and walk towards the
+        # front. The shorter shape behaves as if padded with leading 1s, so the
+        # longer shape's leading dimensions pass through unchanged.
+        out = deque([])
+        for i in range(1, max(s0_size, s1_size) + 1):
+            d0 = s0[s0_size - i] if i <= s0_size else 1
+            d1 = s1[s1_size - i] if i <= s1_size else 1
+            if d0 == d1 or d1 == 1:
+                out.appendleft(d0)
+            elif d0 == 1:
+                out.appendleft(d1)
+            else:
+                # Raise rather than assert so the check is not stripped by `python -O`.
+                raise ValueError("Incompatible broadcast type %s and %s" % (d0, d1))
+        return _expr.const(list(out), attr["T"].name)
+
+    return _impl
+
+
def _broadcast_to():
def _impl(inputs, attr, params, mod):
if isinstance(inputs[1], _expr.Var):
@@ -2745,6 +2783,7 @@ _convert_map = {
"BatchToSpaceND": _batch_to_space_nd(),
"BiasAdd": _bias_add(),
"BroadcastTo": _broadcast_to(),
+ "BroadcastArgs": _broadcast_args(),
"Cast": _cast(),
"Ceil": AttrCvt("ceil"),
"CheckNumerics": _check_numerics(),
@@ -2838,6 +2877,7 @@ _convert_map = {
"Round": AttrCvt("round"),
"Rsqrt": _rsqrt(),
"Select": _where(),
+ "SelectV2": _where(),
"Selu": _selu(),
"Shape": _shape(),
"Sigmoid": AttrCvt("sigmoid"),
@@ -3941,7 +3981,6 @@ class GraphProto(object):
raise ImportError("Unable to import tensorflow which is required {}".format(e))
input_op_name = node_name.split(":")[0].split("^")[-1]
-
if input_op_name not in self._nodes:
node = self._tf_node_map[input_op_name]
attr = self._parse_attr(node.attr)
@@ -4002,7 +4041,6 @@ class GraphProto(object):
inputs[i] = actual_input
op = self._convert_operator(node.op, node.name, inputs, attr)
-
if isinstance(op, np.ndarray):
self._params[node.name] = tvm.nd.array(op)
op = [
@@ -4024,7 +4062,6 @@ class GraphProto(object):
tn = node_name.split(":")
tensor_slot = int(tn[1]) if len(tn) > 1 else 0
return out[tensor_slot]
-
return out[0]
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 80a70e4..f4e7522 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -3079,6 +3079,33 @@ def test_forward_resize():
#######################################################################
+# BroadcastArgs
+# -------------
+
+
+def _test_broadcast_args(in_shape_1, in_shape_2):
+    """ Build one BroadcastArgs graph from two shapes and compare TF against TVM """
+
+    np_s0 = np.array(in_shape_1).astype("int32")
+    np_s1 = np.array(in_shape_2).astype("int32")
+
+    with tf.Graph().as_default():
+        # Feed both shape vectors to the op as graph constants.
+        tf_s0 = constant_op.constant(np_s0, shape=np_s0.shape, dtype=np_s0.dtype)
+        tf_s1 = constant_op.constant(np_s1, shape=np_s1.shape, dtype=np_s1.dtype)
+        tf.raw_ops.BroadcastArgs(s0=tf_s0, s1=tf_s1)
+
+        compare_tf_with_tvm(None, "", "BroadcastArgs:0", opt_level=0)
+
+
+def test_forward_broadcast_args():
+    """ BroadcastArgs: broadcast shape of two shape vectors """
+
+    _test_broadcast_args((4, 1, 32, 32), [4, 8, 32, 32])
+    _test_broadcast_args((6, 32, 32, 1), [6, 32, 32, 16])
+    _test_broadcast_args((32, 32, 16), [6, 32, 32, 16])
+    # Mirror cases: the second operand is the shorter / size-1 one.
+    _test_broadcast_args((6, 32, 32, 16), [32, 32, 16])
+    _test_broadcast_args((4, 8, 32, 32), [4, 1, 32, 32])
+
+
+#######################################################################
# BroadcastTo
# -----------
@@ -3621,7 +3648,7 @@ def test_forward_logical():
#######################################################################
-# Where, Select
+# Where, Select, SelectV2
# -------------
def test_forward_where():
""" Where: return elements depending on conditions"""