Posted to commits@tvm.apache.org by le...@apache.org on 2022/07/26 13:38:02 UTC

[tvm] 02/03: This is PR #12130.

This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch ci-docker-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 2bcf79b059cb124f2e035a9ba06386d975bd9925
Author: Leandro Nunes <le...@arm.com>
AuthorDate: Fri Jul 22 16:41:03 2022 +0100

    This is PR #12130.
---
 python/tvm/relay/frontend/keras.py           |  8 ++++---
 tests/python/frontend/tflite/test_forward.py | 35 +++++++++++++++++++++-------
 2 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index 3f7a96544a..8c8a4a1ddc 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -635,9 +635,11 @@ def _convert_pooling(
             _op.nn.global_max_pool2d(inexpr, **global_pool_params), keras_layer, etab, data_layout
         )
     if pool_type == "GlobalAveragePooling2D":
-        return _convert_flatten(
-            _op.nn.global_avg_pool2d(inexpr, **global_pool_params), keras_layer, etab, data_layout
-        )
+        global_avg_pool2d = _op.nn.global_avg_pool2d(inexpr, **global_pool_params)
+        keep_dims = len(keras_layer.input.shape) == len(keras_layer.output.shape)
+        if keep_dims:
+            return global_avg_pool2d
+        return _convert_flatten(global_avg_pool2d, keras_layer, etab, data_layout)
     pool_h, pool_w = keras_layer.pool_size
     stride_h, stride_w = keras_layer.strides
     params = {
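
For context on the keras.py hunk above: Keras' GlobalAveragePooling2D can be
constructed with keepdims=True (available in recent TF releases), in which case
the layer's output keeps the same rank as its input, so the converter must
return the global_avg_pool2d result as-is instead of flattening it. A minimal
standalone sketch of the shape difference (not part of the patch; assumes a
TensorFlow version that supports the keepdims argument):

    import numpy as np
    import tensorflow as tf

    x = np.zeros((1, 8, 8, 3), dtype="float32")  # NHWC input

    # Default: spatial axes are reduced away, output rank drops to 2.
    print(tf.keras.layers.GlobalAveragePooling2D()(x).shape)               # (1, 3)

    # keepdims=True: spatial axes stay as size-1 dims, output rank stays 4.
    print(tf.keras.layers.GlobalAveragePooling2D(keepdims=True)(x).shape)  # (1, 1, 1, 3)

The rank comparison len(keras_layer.input.shape) == len(keras_layer.output.shape)
distinguishes the two cases without inspecting the keepdims attribute directly.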
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 6acc8554b4..709ed3f2bf 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -935,7 +935,11 @@ def _test_tflite2_quantized_convolution(
     )
 
     tflite_output = run_tflite_graph(tflite_model_quant, data)
-    tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0", ""))
+    if tf.__version__ < LooseVersion("2.9"):
+        input_node = data_in.name.replace(":0", "")
+    else:
+        input_node = "serving_default_" + data_in.name + ":0"
+    tvm_output = run_tvm_graph(tflite_model_quant, data, input_node)
     tvm.testing.assert_allclose(
         np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-2, atol=1e-2
     )
@@ -1934,10 +1938,12 @@ def _test_abs(data, quantized, int_quant_dtype=tf.int8):
         # TFLite 2.6.x upgrade support
         if tf.__version__ < LooseVersion("2.6.1"):
             in_node = ["serving_default_input_int8"]
-        else:
+        elif tf.__version__ < LooseVersion("2.9"):
             in_node = (
                 ["serving_default_input_int16"] if int_quant_dtype == tf.int16 else ["tfl.quantize"]
             )
+        else:
+            in_node = "serving_default_input"
 
         tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
         tvm.testing.assert_allclose(
@@ -1965,8 +1971,10 @@ def _test_rsqrt(data, quantized, int_quant_dtype=tf.int8):
             tf.math.rsqrt, data, int_quant_dtype=int_quant_dtype
         )
         tflite_output = run_tflite_graph(tflite_model_quant, data)
-        in_node = ["tfl.quantize"]
-
+        if tf.__version__ < LooseVersion("2.9"):
+            in_node = ["tfl.quantize"]
+        else:
+            in_node = "serving_default_input"
         tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
         tvm.testing.assert_allclose(
             np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
@@ -2047,7 +2055,10 @@ def _test_cos(data, quantized, int_quant_dtype=tf.int8):
             tf.math.cos, data, int_quant_dtype=int_quant_dtype
         )
         tflite_output = run_tflite_graph(tflite_model_quant, data)
-        in_node = ["tfl.quantize"]
+        if tf.__version__ < LooseVersion("2.9"):
+            in_node = ["tfl.quantize"]
+        else:
+            in_node = "serving_default_input"
         tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
         tvm.testing.assert_allclose(
             np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
@@ -2955,7 +2966,6 @@ def _test_quantize_dequantize(data):
     add = tf.keras.layers.Add()([data_in, relu])
     concat = tf.keras.layers.Concatenate(axis=0)([relu, add])
     keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat)
-    input_name = data_in.name.split(":")[0]
 
     # To create quantized values with dynamic range of activations, needs representative dataset
     def representative_data_gen():
@@ -2965,7 +2975,11 @@ def _test_quantize_dequantize(data):
     tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True)
 
     tflite_output = run_tflite_graph(tflite_model_quant, data)
-    tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
+    if tf.__version__ < LooseVersion("2.9"):
+        in_node = data_in.name.split(":")[0]
+    else:
+        in_node = "serving_default_" + data_in.name + ":0"
+    tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
     tvm.testing.assert_allclose(
         np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
     )
@@ -2982,7 +2996,6 @@ def _test_quantize_dequantize_const(data):
     add = tf.keras.layers.Add()([data, relu])
     concat = tf.keras.layers.Concatenate(axis=0)([relu, add])
     keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat)
-    input_name = data_in.name.split(":")[0]
 
     # To create quantized values with dynamic range of activations, needs representative dataset
     def representative_data_gen():
@@ -2992,7 +3005,11 @@ def _test_quantize_dequantize_const(data):
     tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True)
 
     tflite_output = run_tflite_graph(tflite_model_quant, data)
-    tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
+    if tf.__version__ < LooseVersion("2.9"):
+        in_node = data_in.name.split(":")[0]
+    else:
+        in_node = "serving_default_" + data_in.name + ":0"
+    tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
     tvm.testing.assert_allclose(
         np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
     )
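
For context on the test_forward.py hunks: the tests assume that from TensorFlow
2.9 onwards the TFLite converter exposes quantized model inputs under
"serving_default_..." signature names, so the input node passed to
run_tvm_graph is now chosen based on tf.__version__. A minimal standalone
sketch of that version gate (not part of the patch; quantized_input_name and
data_in are illustrative names, while the naming scheme itself follows the
diff above):

    from distutils.version import LooseVersion

    import tensorflow as tf

    def quantized_input_name(data_in):
        # Pre-2.9 converters keep the Keras tensor name (minus the ":0" suffix);
        # from 2.9 onwards the signature input is named "serving_default_<name>:0".
        if tf.__version__ < LooseVersion("2.9"):
            return data_in.name.replace(":0", "")
        return "serving_default_" + data_in.name + ":0"

    # e.g. tvm_output = run_tvm_graph(tflite_model_quant, data,
    #                                 quantized_input_name(data_in))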