You are viewing a plain-text version of this content. The canonical link to it was not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by le...@apache.org on 2022/07/26 13:38:02 UTC
[tvm] 02/03: This is PR #12130.
This is an automated email from the ASF dual-hosted git repository.
leandron pushed a commit to branch ci-docker-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 2bcf79b059cb124f2e035a9ba06386d975bd9925
Author: Leandro Nunes <le...@arm.com>
AuthorDate: Fri Jul 22 16:41:03 2022 +0100
This is PR #12130.
---
python/tvm/relay/frontend/keras.py | 8 ++++---
tests/python/frontend/tflite/test_forward.py | 35 +++++++++++++++++++++-------
2 files changed, 31 insertions(+), 12 deletions(-)
diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index 3f7a96544a..8c8a4a1ddc 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -635,9 +635,11 @@ def _convert_pooling(
_op.nn.global_max_pool2d(inexpr, **global_pool_params), keras_layer, etab, data_layout
)
if pool_type == "GlobalAveragePooling2D":
- return _convert_flatten(
- _op.nn.global_avg_pool2d(inexpr, **global_pool_params), keras_layer, etab, data_layout
- )
+ global_avg_pool2d = _op.nn.global_avg_pool2d(inexpr, **global_pool_params)
+ keep_dims = len(keras_layer.input.shape) == len(keras_layer.output.shape)
+ if keep_dims:
+ return global_avg_pool2d
+ return _convert_flatten(global_avg_pool2d, keras_layer, etab, data_layout)
pool_h, pool_w = keras_layer.pool_size
stride_h, stride_w = keras_layer.strides
params = {
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 6acc8554b4..709ed3f2bf 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -935,7 +935,11 @@ def _test_tflite2_quantized_convolution(
)
tflite_output = run_tflite_graph(tflite_model_quant, data)
- tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0", ""))
+ if tf.__version__ < LooseVersion("2.9"):
+ input_node = data_in.name.replace(":0", "")
+ else:
+ input_node = "serving_default_" + data_in.name + ":0"
+ tvm_output = run_tvm_graph(tflite_model_quant, data, input_node)
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-2, atol=1e-2
)
@@ -1934,10 +1938,12 @@ def _test_abs(data, quantized, int_quant_dtype=tf.int8):
# TFLite 2.6.x upgrade support
if tf.__version__ < LooseVersion("2.6.1"):
in_node = ["serving_default_input_int8"]
- else:
+ elif tf.__version__ < LooseVersion("2.9"):
in_node = (
["serving_default_input_int16"] if int_quant_dtype == tf.int16 else ["tfl.quantize"]
)
+ else:
+ in_node = "serving_default_input"
tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
tvm.testing.assert_allclose(
@@ -1965,8 +1971,10 @@ def _test_rsqrt(data, quantized, int_quant_dtype=tf.int8):
tf.math.rsqrt, data, int_quant_dtype=int_quant_dtype
)
tflite_output = run_tflite_graph(tflite_model_quant, data)
- in_node = ["tfl.quantize"]
-
+ if tf.__version__ < LooseVersion("2.9"):
+ in_node = ["tfl.quantize"]
+ else:
+ in_node = "serving_default_input"
tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
@@ -2047,7 +2055,10 @@ def _test_cos(data, quantized, int_quant_dtype=tf.int8):
tf.math.cos, data, int_quant_dtype=int_quant_dtype
)
tflite_output = run_tflite_graph(tflite_model_quant, data)
- in_node = ["tfl.quantize"]
+ if tf.__version__ < LooseVersion("2.9"):
+ in_node = ["tfl.quantize"]
+ else:
+ in_node = "serving_default_input"
tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
@@ -2955,7 +2966,6 @@ def _test_quantize_dequantize(data):
add = tf.keras.layers.Add()([data_in, relu])
concat = tf.keras.layers.Concatenate(axis=0)([relu, add])
keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat)
- input_name = data_in.name.split(":")[0]
# To create quantized values with dynamic range of activations, needs representative dataset
def representative_data_gen():
@@ -2965,7 +2975,11 @@ def _test_quantize_dequantize(data):
tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True)
tflite_output = run_tflite_graph(tflite_model_quant, data)
- tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
+ if tf.__version__ < LooseVersion("2.9"):
+ in_node = data_in.name.split(":")[0]
+ else:
+ in_node = "serving_default_" + data_in.name + ":0"
+ tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
)
@@ -2982,7 +2996,6 @@ def _test_quantize_dequantize_const(data):
add = tf.keras.layers.Add()([data, relu])
concat = tf.keras.layers.Concatenate(axis=0)([relu, add])
keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat)
- input_name = data_in.name.split(":")[0]
# To create quantized values with dynamic range of activations, needs representative dataset
def representative_data_gen():
@@ -2992,7 +3005,11 @@ def _test_quantize_dequantize_const(data):
tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True)
tflite_output = run_tflite_graph(tflite_model_quant, data)
- tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
+ if tf.__version__ < LooseVersion("2.9"):
+ in_node = data_in.name.split(":")[0]
+ else:
+ in_node = "serving_default_" + data_in.name + ":0"
+ tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
)