You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by zh...@apache.org on 2021/02/20 04:27:13 UTC

[tvm] branch main updated: [Frontend][Tensorflow] Support explicit_paddings for TF 2.x (#7445)

This is an automated email from the ASF dual-hosted git repository.

zhaowu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5688068  [Frontend][Tensorflow] Support explicit_paddings for TF 2.x (#7445)
5688068 is described below

commit 5688068eb02912a4ec926a88f5cad3f0f370454e
Author: Trevor Morris <tr...@amazon.com>
AuthorDate: Fri Feb 19 20:26:55 2021 -0800

    [Frontend][Tensorflow] Support explicit_paddings for TF 2.x (#7445)
    
    * Ignore some TF 2.x attributes (`explicit_paddings`, `_cloned`)
    
    * Support explicit padding for conv2d, max_pool, conv3d
    
    * Remove conv3d explicit padding test since the TF API doesn't allow it
---
 python/tvm/relay/frontend/tensorflow.py          | 44 +++++++++++++++++++++---
 tests/python/frontend/tensorflow/test_forward.py | 40 +++++++++++++++++++++
 2 files changed, 79 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 6a29ce2..ac52ab7 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -268,6 +268,13 @@ def _pooling(name):
             pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
 
             attr["padding"] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
+        elif attr["padding"] == "EXPLICIT":
+            paddings = attr["explicit_paddings"]
+            assert len(paddings) == 8
+            if flip_layout or attr["data_format"] == "NHWC":
+                attr["padding"] = [paddings[2], paddings[4], paddings[3], paddings[5]]
+            else:
+                attr["padding"] = [paddings[4], paddings[6], paddings[5], paddings[7]]
         else:
             msg = 'Value {} in attribute "padding" of operator Pooling is ' "not valid."
             raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"]))
@@ -278,7 +285,7 @@ def _pooling(name):
         out = AttrCvt(
             op_name=_dimension_picker(name),
             transforms={"kernel_shape": "pool_size", "data_format": "layout"},
-            ignores=["ksize"],
+            ignores=["ksize", "explicit_paddings"],
             extras={"ceil_mode": False},
             custom_check=_dimension_constraint(),
         )(inputs, attr)
@@ -418,6 +425,13 @@ def _conv(opname):
             pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
 
             attr["padding"] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
+        elif attr["padding"] == "EXPLICIT":
+            paddings = attr["explicit_paddings"]
+            assert len(paddings) == 8
+            if flip_layout or attr["data_format"] == "NHWC":
+                attr["padding"] = [paddings[2], paddings[4], paddings[3], paddings[5]]
+            else:
+                attr["padding"] = [paddings[4], paddings[6], paddings[5], paddings[7]]
         else:
             msg = 'Value {} in attribute "padding" of operator Conv is not ' "valid."
             raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"]))
@@ -626,7 +640,27 @@ def _conv3d(opname):
             pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
 
             attr["padding"] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]]
-
+        elif attr["padding"] == "EXPLICIT":
+            paddings = attr["explicit_paddings"]
+            assert len(paddings) == 10
+            if flip_layout or attr["data_format"] == "NDHWC":
+                attr["padding"] = [
+                    paddings[2],
+                    paddings[4],
+                    paddings[6],
+                    paddings[3],
+                    paddings[5],
+                    paddings[7],
+                ]
+            else:
+                attr["padding"] = [
+                    paddings[4],
+                    paddings[6],
+                    paddings[8],
+                    paddings[5],
+                    paddings[7],
+                    paddings[9],
+                ]
         else:
             msg = 'Value {} in attribute "padding" of operator Conv is not ' "valid."
             raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"]))
@@ -1445,9 +1479,9 @@ def _squeeze():
     def _impl(inputs, attr, params, mod):
         if len(attr["squeeze_dims"]) == 0:
             attr["squeeze_dims"] = None
-        return AttrCvt(op_name="squeeze", transforms={"squeeze_dims": "axis"}, ignores=["T"])(
-            inputs, attr
-        )
+        return AttrCvt(
+            op_name="squeeze", transforms={"squeeze_dims": "axis"}, ignores=["T", "_cloned"]
+        )(inputs, attr)
 
     return _impl
 
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index f956ea0..ecf6441 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -414,6 +414,16 @@ def test_forward_pooling():
             pooling_type=pool_type,
             dilation_rate=[2],
         )
+    # Explicit padding
+    if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"):
+        _test_pooling(
+            input_shape=[2, 9, 10, 2],
+            window_shape=[4, 4],
+            padding=[[0, 0], [0, 1], [2, 3], [0, 0]],
+            pooling_type="MAX",
+            dilation_rate=[1, 1],
+            strides=[1, 1],
+        )
 
 
 #######################################################################
@@ -830,6 +840,36 @@ def test_forward_convolution():
         [4, 8, 8, 176],
         add_shapes_to_graph_def=False,
     )
+    # Explicit padding
+    if package_version.parse(tf.VERSION) >= package_version.parse("2.4.1"):
+        _test_convolution(
+            "conv",
+            [4, 8, 8, 16],
+            [1, 1, 16, 32],
+            [1, 1],
+            [1, 1],
+            [[0, 0], [2, 3], [0, 1], [0, 0]],
+            "NHWC",
+        )
+        _test_convolution(
+            "depthwise",
+            [4, 8, 8, 16],
+            [1, 1, 16, 1],
+            [1, 1],
+            [1, 1],
+            [[0, 0], [2, 3], [0, 1], [0, 0]],
+            "NHWC",
+        )
+        _test_convolution(
+            "conv_transpose",
+            [4, 8, 8, 32],
+            [3, 3, 176, 32],
+            [1, 1],
+            [2, 2],
+            [[0, 0], [1, 0], [1, 0], [0, 0]],
+            "NHWC",
+            [4, 16, 16, 176],
+        )
 
 
 #######################################################################