Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/10/19 22:09:28 UTC

[GitHub] [tvm] comaniac commented on a change in pull request #9328: [Op] Do not override specified layout in pooling (2nd PR)

comaniac commented on a change in pull request #9328:
URL: https://github.com/apache/tvm/pull/9328#discussion_r732275775



##########
File path: tests/python/relay/test_pass_convert_op_layout.py
##########
@@ -2039,5 +2039,337 @@ def expected():
         assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_conv_max_pool_uses_specified_convert_layout():
+    relay.op.get("nn.max_pool2d").reset_attr("FTVMConvertOpLayout")
+
+    @tvm.ir.register_op_attr("nn.max_pool2d", "FTVMConvertOpLayout")
+    def convert_maxpool2d(attrs, inputs, tinfos, desired_layouts):
+        # Stick to the requested layouts by converting both layout and out_layout to NHWC,
+        #   as specified later in the arguments to transforms.ConvertLayout()
+        new_attrs = dict(attrs)
+        new_attrs["layout"] = str(desired_layouts[0])
+        new_attrs["out_layout"] = str(desired_layouts[0])
+        return relay.nn.max_pool2d(*inputs, **new_attrs)

Review comment:
       We should register this in `relay/op/nn/_nn.py` once, alongside the other converters such as conv2d's. Otherwise, users will need to specify this function manually to use the feature.

##########
File path: tests/python/relay/test_pass_convert_op_layout.py
##########
@@ -2039,5 +2039,337 @@ def expected():
         assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_conv_max_pool_uses_specified_convert_layout():

Review comment:
       It would be better to put this test alongside the other existing pooling tests.

##########
File path: src/relay/op/nn/pooling.cc
##########
@@ -49,8 +49,12 @@ InferCorrectLayoutOutput PoolInferCorrectLayout(const Attrs& attrs,
   ICHECK(attrs_ptr);
   ObjectPtr<T> params = make_object<T>(*attrs_ptr);
 
-  if (new_in_layouts.defined()) {
-    // Set the pool with the new layout.
+  if (params->out_layout != "") {
+    // when users specify the out_layout of pooling, transforms.ConvertLayout pass will
+    //   follow user's preference
+    ICHECK_EQ(params->layout, params->out_layout);

Review comment:
       ```suggestion
     if (params->out_layout != "") {
       // when users specify the out_layout of pooling, follow user's preference.
       ICHECK_EQ(params->layout, params->out_layout) << "Pooling input/output layouts mismatch: " << params->layout << " vs. " << params->out_layout;
   ```
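For reference, a pure-Python sketch of the layout decision in the new branch of `PoolInferCorrectLayout`. It is a toy model, not the C++ implementation: `infer_pool_layout` and `LayoutMismatchError` are hypothetical, the code after the `ICHECK_EQ` is not shown in the hunk, and the sketch assumes that an empty `out_layout` means "same as the input layout" and that, without a user preference, pooling adopts the first incoming layout as the pre-existing `new_in_layouts.defined()` branch did. The error message mirrors the suggested `ICHECK_EQ` message.

```python
# Toy model of the pooling layout decision; names are hypothetical, not TVM's API.
from typing import List, Optional, Tuple


class LayoutMismatchError(ValueError):
    """Stands in for a failed ICHECK_EQ in the C++ code."""


def infer_pool_layout(layout: str, out_layout: str,
                      new_in_layouts: Optional[List[str]]) -> Tuple[str, str]:
    """Return the (input_layout, output_layout) pair pooling would use."""
    if out_layout != "":
        # The user pinned out_layout: keep it, but require it to match the input layout.
        if layout != out_layout:
            raise LayoutMismatchError(
                f"Pooling input/output layouts mismatch: {layout} vs. {out_layout}")
        return layout, out_layout
    if new_in_layouts is not None:
        # No user preference (assumed): adopt the layout proposed by the producer op.
        layout = new_in_layouts[0]
    # Assumption: an empty out_layout means "same as the input layout".
    return layout, layout


print(infer_pool_layout("NHWC", "NHWC", ["NCHW"]))  # ('NHWC', 'NHWC')
print(infer_pool_layout("NCHW", "", ["NHWC"]))      # ('NHWC', 'NHWC')
```

The first call shows the point of the PR: a pinned `out_layout` wins over the incoming layout instead of being overridden.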




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org