Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/08/11 10:09:07 UTC

[GitHub] [tvm] lhutton1 commented on a diff in pull request #12353: [CMSIS-NN] Pad fusion with QNN Conv2D

lhutton1 commented on code in PR #12353:
URL: https://github.com/apache/tvm/pull/12353#discussion_r943274640


##########
src/relay/backend/contrib/cmsisnn/fuse_pads.cc:
##########
@@ -0,0 +1,219 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file fuse_pads.cc

Review Comment:
   Nit: src/relay/backend/contrib/cmsisnn/fuse_pads.cc



##########
tests/python/contrib/test_cmsisnn/test_fuse_pads.py:
##########
@@ -0,0 +1,279 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""CMSIS-NN integration tests: fuse_pads pass"""
+import numpy as np
+import pytest
+import tvm
+import tvm.testing
+from tvm import relay
+
+tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)
+
+
+class CheckForPadsWithinCompositeFunc(tvm.relay.ExprVisitor):
+    """Provides method to test number of pads present inside the function being visited."""
+
+    def __init__(self):
+        super().__init__()
+        self.num_pads_ = 0
+
+    def visit_call(self, call):
+        super().visit_call(call)
+        if (
+            isinstance(call, tvm.relay.Call)
+            and isinstance(call.op, tvm.ir.op.Op)
+            and call.op.name == "nn.pad"
+        ):
+            self.num_pads_ += 1
+
+    def check_num_pads(self):
+        assert self.num_pads_ == 0, "CMSIS-NN composite function should not have pads"
+
+
+def set_external_func_attr(func, compiler, ext_symbol):
+    func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+    func = func.with_attr("Compiler", compiler)
+    func = func.with_attr("global_symbol", ext_symbol)
+    return func
+
+
+def set_composite_func_attr(func, name):
+    func = func.with_attr("Composite", name)
+    return func
+
+
+@pytest.mark.parametrize(
+    "ifm_shape, pad_width, conv2d_padding, ofm_shape",
+    [
+        [(1, 25, 25, 12), ((0, 0), (0, 2), (1, 2), (0, 0)), (1, 1, 1, 1), (1, 26, 28, 2)],
+        [(1, 64, 100, 4), ((0, 0), (1, 3), (1, 1), (0, 0)), (0, 0, 0, 0), (1, 64, 100, 2)],
+        [(1, 55, 55, 3), ((0, 0), (2, 1), (3, 5), (0, 0)), (0, 0, 1, 1), (1, 57, 59, 2)],
+    ],
+)
+def test_invalid_padding_for_fusion(ifm_shape, pad_width, conv2d_padding, ofm_shape):
+    """Negative tests for pads preceding Conv2D that cannot be fused."""
+    dtype = "int8"
+    kernel_size = (3, 3)
+    ofm_channels = 2
+    local_input = relay.var("local_input", shape=ifm_shape, dtype=dtype)
+    pad = relay.nn.pad(
+        local_input,
+        pad_width=pad_width,  # ((), (top, bottom), (left, right), ())
+        pad_value=10,
+        pad_mode="constant",
+    )
+    rng = np.random.default_rng(12321)
+    local_weight = tvm.nd.array(
+        rng.integers(
+            np.iinfo(dtype).min,
+            high=np.iinfo(dtype).max,
+            size=(ofm_channels, kernel_size[0], kernel_size[1], ifm_shape[3]),
+            dtype=dtype,
+        )
+    )
+    local_weight = relay.const(local_weight, dtype)
+    conv2d = relay.qnn.op.conv2d(
+        pad,
+        local_weight,
+        relay.const(1, "int32"),
+        relay.const(1, "int32"),
+        relay.const(1, "float32"),
+        relay.const(1, "float32"),
+        data_layout="NHWC",
+        kernel_layout="OHWI",
+        channels=ofm_channels,
+        kernel_size=(3, 3),
+        padding=conv2d_padding,
+        out_dtype="int32",
+    )
+    requantize = relay.qnn.op.requantize(
+        conv2d,
+        relay.const(1, "float32"),
+        relay.const(1, "int32"),
+        relay.const(1, "float32"),
+        relay.const(1, "int32"),
+        axis=0,
+        out_dtype=dtype,
+    )
+    local_func = relay.Function(relay.analysis.free_vars(requantize), requantize)
+    local_func = set_composite_func_attr(local_func, "cmsis-nn.qnn_conv2d")
+
+    mod = tvm.IRModule()
+    ext_input = relay.var("ext_input", shape=ifm_shape, dtype=dtype)
+    call_local_func = relay.Call(local_func, [ext_input])
+    extern_func = relay.Function(relay.analysis.free_vars(call_local_func), call_local_func)
+    extern_var = relay.GlobalVar("external_function")
+    extern_func = set_external_func_attr(extern_func, "cmsis-nn", extern_var.name_hint)
+    mod[extern_var] = extern_func
+
+    main_input = relay.var("main_input", shape=ifm_shape, dtype=dtype)
+    call_extern_func = relay.Call(extern_var, [main_input])
+    main_func = relay.Function([main_input], call_extern_func, relay.TensorType(ofm_shape, dtype))
+    main_var = relay.GlobalVar("main")
+    mod[main_var] = main_func
+
+    mod = relay.transform.InferType()(mod)
+
+    error_regex = r"Difference on each side of a dimension should be either 0 or 1"
+
+    with pytest.raises(tvm.TVMError, match=error_regex):
+        mod = CMSISNNFusePads()(mod)
+
+
+@pytest.mark.parametrize(
+    "ifm_shape, pad_width, conv2d_padding, ofm_shape",
+    [
+        [(1, 25, 25, 12), ((0, 0), (0, 1), (1, 2), (0, 0)), (1, 1, 1, 1), (1, 26, 28, 2)],
+        [(1, 64, 100, 4), ((0, 0), (1, 1), (1, 1), (0, 0)), (0, 0, 0, 0), (1, 64, 100, 2)],
+        [(1, 55, 55, 3), ((0, 0), (2, 1), (3, 2), (0, 0)), (0, 0, 1, 1), (1, 57, 59, 2)],
+    ],
+)
+def test_pad_conv2d_fusion(ifm_shape, pad_width, conv2d_padding, ofm_shape):
+    """Tests the pads and conv2d fusion."""
+    dtype = "int8"
+    kernel_size = (3, 3)
+    ofm_channels = 2
+    local_input = relay.var("local_input", shape=ifm_shape, dtype=dtype)
+    pad = relay.nn.pad(
+        local_input,
+        pad_width=pad_width,  # ((), (top, bottom), (left, right), ())
+        pad_value=10,
+        pad_mode="constant",
+    )
+    rng = np.random.default_rng(12321)
+    local_weight = tvm.nd.array(
+        rng.integers(
+            np.iinfo(dtype).min,
+            high=np.iinfo(dtype).max,
+            size=(ofm_channels, kernel_size[0], kernel_size[1], ifm_shape[3]),
+            dtype=dtype,
+        )
+    )
+    local_weight = relay.const(local_weight, dtype)
+    conv2d = relay.qnn.op.conv2d(
+        pad,
+        local_weight,
+        relay.const(1, "int32"),
+        relay.const(1, "int32"),
+        relay.const(1, "float32"),
+        relay.const(1, "float32"),
+        data_layout="NHWC",
+        kernel_layout="OHWI",
+        channels=ofm_channels,
+        kernel_size=(3, 3),
+        padding=conv2d_padding,
+        out_dtype="int32",
+    )
+    requantize = relay.qnn.op.requantize(
+        conv2d,
+        relay.const(1, "float32"),
+        relay.const(1, "int32"),
+        relay.const(1, "float32"),
+        relay.const(1, "int32"),
+        axis=0,
+        out_dtype=dtype,
+    )
+    local_func = relay.Function(relay.analysis.free_vars(requantize), requantize)
+    local_func = set_composite_func_attr(local_func, "cmsis-nn.qnn_conv2d")
+
+    mod = tvm.IRModule()
+    ext_input = relay.var("ext_input", shape=ifm_shape, dtype=dtype)
+    call_local_func = relay.Call(local_func, [ext_input])
+    extern_func = relay.Function(relay.analysis.free_vars(call_local_func), call_local_func)
+    extern_var = relay.GlobalVar("external_function")
+    extern_func = set_external_func_attr(extern_func, "cmsis-nn", extern_var.name_hint)
+    mod[extern_var] = extern_func
+
+    main_input = relay.var("main_input", shape=ifm_shape, dtype=dtype)
+    call_extern_func = relay.Call(extern_var, [main_input])
+    main_func = relay.Function([main_input], call_extern_func, relay.TensorType(ofm_shape, dtype))
+    main_var = relay.GlobalVar("main")
+    mod[main_var] = main_func
+
+    mod = relay.transform.InferType()(mod)
+
+    mod = CMSISNNFusePads()(mod)
+    pad_verifier = CheckForPadsWithinCompositeFunc()
+    pad_verifier.visit_function(mod[extern_var])
+    pad_verifier.check_num_pads()
+
+
+def test_without_preceding_pad():
+    """Tests the pass FusePads when padding is not present before qnn.conv2d."""
+    dtype = "int8"
+    ifm_shape = (1, 56, 56, 64)
+    ofm_shape = (1, 56, 56, 64)
+    local_input = relay.var("local_input", shape=ifm_shape, dtype=dtype)
+    rng = np.random.default_rng(12321)
+    local_weight = tvm.nd.array(
+        rng.integers(
+            np.iinfo(dtype).min,
+            high=np.iinfo(dtype).max,
+            size=(64, 3, 3, 64),
+            dtype=dtype,
+        )
+    )
+    local_weight = relay.const(local_weight, dtype)
+    conv2d = relay.qnn.op.conv2d(
+        local_input,
+        local_weight,
+        relay.const(1, "int32"),
+        relay.const(1, "int32"),
+        relay.const(1, "float32"),
+        relay.const(1, "float32"),
+        data_layout="NHWC",
+        kernel_layout="OHWI",
+        channels=64,
+        kernel_size=(3, 3),
+        padding=(1, 1, 1, 1),
+        out_dtype="int32",
+    )
+    requantize = relay.qnn.op.requantize(
+        conv2d,
+        relay.const(1, "float32"),
+        relay.const(1, "int32"),
+        relay.const(1, "float32"),
+        relay.const(1, "int32"),
+        axis=0,
+        out_dtype=dtype,
+    )
+    relu = relay.nn.relu(requantize)
+    local_func = relay.Function(relay.analysis.free_vars(relu), relu)
+    local_func = set_composite_func_attr(local_func, "cmsis-nn.qnn_conv2d")
+
+    mod = tvm.IRModule()
+    ext_input = relay.var("ext_input", shape=ifm_shape, dtype=dtype)
+    call_local_func = relay.Call(local_func, [ext_input])
+    extern_func = relay.Function(relay.analysis.free_vars(call_local_func), call_local_func)
+    extern_var = relay.GlobalVar("external_function")
+    extern_func = set_external_func_attr(extern_func, "cmsis-nn", extern_var.name_hint)
+    mod[extern_var] = extern_func
+
+    main_input = relay.var("main_input", shape=ifm_shape, dtype=dtype)
+    call_extern_func = relay.Call(extern_var, [main_input])
+    main_func = relay.Function([main_input], call_extern_func, relay.TensorType(ofm_shape, dtype))
+    main_var = relay.GlobalVar("main")
+    mod[main_var] = main_func
+
+    mod = relay.transform.InferType()(mod)
+
+    mod = CMSISNNFusePads()(mod)
+    pad_verifier = CheckForPadsWithinCompositeFunc()
+    pad_verifier.visit_function(mod[extern_var])
+    pad_verifier.check_num_pads()

Review Comment:
   Could we add a test to check that padding for non-cmsisnn functions is not altered?



##########
src/relay/backend/contrib/cmsisnn/fuse_pads.cc:
##########
@@ -0,0 +1,219 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file fuse_pads.cc
+ * \brief Fuses pads that precede qnn.conv2d ops inside CMSIS-NN composite functions.
+ */
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../op/make_op.h"
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+#include "convolutions.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+/*!
+ * \brief This Mutator will find all partitioned functions meant for CMSIS-NN Conv2D.
+ * Then, it will fuse preceding pads with qnn.conv2d.
+ */
+class FusePadsMutator : public MixedModeMutator {
+ public:
+  explicit FusePadsMutator(const IRModule& mod) : mod_(mod) {}
+
+ private:
+  /*!  * \brief In order to eliminate preceding nn.pad op, pad_width of nn.pad is passed onto
+   * convolution layer to update Conv2DAttrs's padding attribute. */
+  void UpdateConv2DPadding(const CallNode* conv2d_call, const Array<Array<Integer>>& pad_width,
+                           const Conv2DAttrs* conv2d_attrs, Attrs* new_attrs) {
+    auto attrs = make_object<Conv2DAttrs>();
+    attrs->strides = std::move(conv2d_attrs->strides);
+    attrs->dilation = std::move(conv2d_attrs->dilation);
+    attrs->groups = conv2d_attrs->groups;
+    attrs->channels = std::move(conv2d_attrs->channels);
+    attrs->kernel_size = std::move(conv2d_attrs->kernel_size);
+    attrs->data_layout = std::move(conv2d_attrs->data_layout);
+    attrs->kernel_layout = std::move(conv2d_attrs->kernel_layout);
+    attrs->out_layout = std::move(conv2d_attrs->out_layout);
+    attrs->out_dtype = std::move(conv2d_attrs->out_dtype);

Review Comment:
   Is it possible to copy conv2d_attrs instead?
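
   For illustration, a minimal sketch of what copying could look like (assuming `Conv2DAttrs` can be copy-constructed through `make_object`, which I haven't verified):
   ```c++
   // Copy every existing field in one step, then overwrite only the padding.
   auto attrs = make_object<Conv2DAttrs>(*conv2d_attrs);
   attrs->padding = {pad_top, pad_left, pad_bottom, pad_right};
   *new_attrs = tvm::Attrs{attrs};
   ```
   That would also make the `std::move` calls unnecessary; since `conv2d_attrs` is const-qualified they degrade to copies anyway.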



##########
src/relay/backend/contrib/cmsisnn/fuse_pads.cc:
##########
@@ -0,0 +1,219 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file fuse_pads.cc
+ * \brief Fuses pads that precede qnn.conv2d ops inside CMSIS-NN composite functions.
+ */
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../op/make_op.h"
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+#include "convolutions.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+/*!
+ * \brief This Mutator will find all partitioned functions meant for CMSIS-NN Conv2D.
+ * Then, it will fuse preceding pads with qnn.conv2d.
+ */
+class FusePadsMutator : public MixedModeMutator {
+ public:
+  explicit FusePadsMutator(const IRModule& mod) : mod_(mod) {}
+
+ private:
+  /*!  * \brief In order to eliminate preceding nn.pad op, pad_width of nn.pad is passed onto
+   * convolution layer to update Conv2DAttrs's padding attribute. */
+  void UpdateConv2DPadding(const CallNode* conv2d_call, const Array<Array<Integer>>& pad_width,
+                           const Conv2DAttrs* conv2d_attrs, Attrs* new_attrs) {
+    auto attrs = make_object<Conv2DAttrs>();
+    attrs->strides = std::move(conv2d_attrs->strides);
+    attrs->dilation = std::move(conv2d_attrs->dilation);
+    attrs->groups = conv2d_attrs->groups;
+    attrs->channels = std::move(conv2d_attrs->channels);
+    attrs->kernel_size = std::move(conv2d_attrs->kernel_size);
+    attrs->data_layout = std::move(conv2d_attrs->data_layout);
+    attrs->kernel_layout = std::move(conv2d_attrs->kernel_layout);
+    attrs->out_layout = std::move(conv2d_attrs->out_layout);
+    attrs->out_dtype = std::move(conv2d_attrs->out_dtype);
+
+    // pad_width: ((), (top, bottom), (left, right), ()) for NHWC layout
+    // conv2d_attrs->padding: (top, left, bottom, right)
+    std::string data_layout = conv2d_attrs->data_layout.c_str();
+    int pos_h = data_layout.find("H");
+    int pos_w = data_layout.find("W");
+
+    int pad_top =
+        qnn::get_const_int(conv2d_attrs->padding[0]) + qnn::get_const_int(pad_width[pos_h][0]);
+    int pad_left =
+        qnn::get_const_int(conv2d_attrs->padding[1]) + qnn::get_const_int(pad_width[pos_w][0]);
+    int pad_bottom =
+        qnn::get_const_int(conv2d_attrs->padding[2]) + qnn::get_const_int(pad_width[pos_h][1]);
+    int pad_right =
+        qnn::get_const_int(conv2d_attrs->padding[3]) + qnn::get_const_int(pad_width[pos_w][1]);
+
+    int pad_diff_w = pad_right - pad_left;
+    int pad_diff_h = pad_bottom - pad_top;
+    bool can_pad_be_fused =
+        ((pad_diff_w == 0 || pad_diff_w == 1) && (pad_diff_h == 0 || pad_diff_h == 1));
+    std::string error = "Difference on each side of a dimension should be either 0 or 1. ";
+    error += "Effective padding in this case: (pad_top, pad_left, pad_bottom, pad_right)=(";
+    error += std::to_string(pad_top);
+    error += ", ";
+    error += std::to_string(pad_left);
+    error += ", ";
+    error += std::to_string(pad_bottom);
+    error += ", ";
+    error += std::to_string(pad_right);
+    error += ")";
+    ICHECK(can_pad_be_fused) << error;
+
+    attrs->padding = {pad_top, pad_left, pad_bottom, pad_right};
+    *new_attrs = tvm::Attrs{attrs};
+  }
+
+  /*!
+   * \brief Identifies the sequence for qnn.conv2d and fuses the preceding nn.pad present within the
+   * CMSIS-NN partitioned function. */
+  Expr FusePadConv2d(const Expr& expr) {
+    const CallNode* clip_call = nullptr;
+    const CallNode* requantize_call = nullptr;
+    const CallNode* bias_add_call = nullptr;

Review Comment:
   Curious, why do we need to maintain references to these ops and reconstruct them at the end of this function? Will a call to `VisitExpr` after the new conv2d has been created do this for us instead?
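
   To make the question concrete, a rough sketch (purely hypothetical, not taken from this PR) of the alternative: substitute only the conv2d call and let an `ExprMutator` rebuild the surrounding bias_add/requantize/clip calls on the way back up:
   ```c++
   #include <tvm/relay/expr_functor.h>

   namespace tvm {
   namespace relay {

   // Sketch only: swaps one sub-expression and lets ExprMutator reconstruct
   // every consumer of it automatically.
   class SubstituteExpr : public ExprMutator {
    public:
     SubstituteExpr(Expr from, Expr to) : from_(from), to_(to) {}

     Expr VisitExpr(const Expr& expr) final {
       if (expr.same_as(from_)) return to_;
       return ExprMutator::VisitExpr(expr);
     }

    private:
     Expr from_, to_;
   };

   }  // namespace relay
   }  // namespace tvm

   // Hypothetical use inside FusePadConv2d, once the new conv2d has been built:
   //   Expr new_body =
   //       SubstituteExpr(GetRef<Expr>(conv2d_call), new_conv2d).Mutate(composite_func->body);
   ```
   If something along these lines works, the clip/requantize/bias_add references would only be needed for pattern matching, not for reconstruction.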



##########
tests/python/contrib/test_cmsisnn/test_conv2d.py:
##########
@@ -729,7 +830,7 @@ def test_invalid_parameters(
         in_dtype,
         kernel_dtype,
         in_dtype,
-        False,
+        is_depthwise=False,
     )
     model, params = make_model(
         shape=ifm_shape,

Review Comment:
   Would be good to check a couple of other things here:
   - The pass runs as part of the pipeline, since these tests could pass even when the pass is not run
   - Attempting to fuse a pad operation that is unsupported falls back to offloading just the conv2d (I think we might have missed this case in the microNPU)



##########
src/relay/backend/contrib/cmsisnn/fuse_pads.cc:
##########
@@ -0,0 +1,219 @@
+

Review Comment:
   Nit: additional line here?



##########
src/relay/backend/contrib/cmsisnn/fuse_pads.cc:
##########
@@ -0,0 +1,219 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file fuse_pads.cc
+ * \brief Fuses pads that precede qnn.conv2d ops inside CMSIS-NN composite functions.
+ */
+
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../op/make_op.h"
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+#include "convolutions.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+/*!
+ * \brief This Mutator will find all partitioned functions meant for CMSIS-NN Conv2D.
+ * Then, it will fuse preceding pads with qnn.conv2d.
+ */
+class FusePadsMutator : public MixedModeMutator {
+ public:
+  explicit FusePadsMutator(const IRModule& mod) : mod_(mod) {}
+
+ private:
+  /*!  * \brief In order to eliminate preceding nn.pad op, pad_width of nn.pad is passed onto
+   * convolution layer to update Conv2DAttrs's padding attribute. */
+  void UpdateConv2DPadding(const CallNode* conv2d_call, const Array<Array<Integer>>& pad_width,
+                           const Conv2DAttrs* conv2d_attrs, Attrs* new_attrs) {
+    auto attrs = make_object<Conv2DAttrs>();
+    attrs->strides = std::move(conv2d_attrs->strides);
+    attrs->dilation = std::move(conv2d_attrs->dilation);
+    attrs->groups = conv2d_attrs->groups;
+    attrs->channels = std::move(conv2d_attrs->channels);
+    attrs->kernel_size = std::move(conv2d_attrs->kernel_size);
+    attrs->data_layout = std::move(conv2d_attrs->data_layout);
+    attrs->kernel_layout = std::move(conv2d_attrs->kernel_layout);
+    attrs->out_layout = std::move(conv2d_attrs->out_layout);
+    attrs->out_dtype = std::move(conv2d_attrs->out_dtype);
+
+    // pad_width: ((), (top, bottom), (left, right), ()) for NHWC layout
+    // conv2d_attrs->padding: (top, left, bottom, right)
+    std::string data_layout = conv2d_attrs->data_layout.c_str();
+    int pos_h = data_layout.find("H");
+    int pos_w = data_layout.find("W");
+
+    int pad_top =
+        qnn::get_const_int(conv2d_attrs->padding[0]) + qnn::get_const_int(pad_width[pos_h][0]);
+    int pad_left =
+        qnn::get_const_int(conv2d_attrs->padding[1]) + qnn::get_const_int(pad_width[pos_w][0]);
+    int pad_bottom =
+        qnn::get_const_int(conv2d_attrs->padding[2]) + qnn::get_const_int(pad_width[pos_h][1]);
+    int pad_right =
+        qnn::get_const_int(conv2d_attrs->padding[3]) + qnn::get_const_int(pad_width[pos_w][1]);
+
+    int pad_diff_w = pad_right - pad_left;
+    int pad_diff_h = pad_bottom - pad_top;

Review Comment:
   Is it possible to separate this logic and register it globally, so we can use it in python in the check function as well to avoid duplication?
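
   For illustration, a rough sketch of what that separation might look like (the helper name and registration string below are hypothetical):
   ```c++
   #include <tvm/runtime/registry.h>

   namespace tvm {
   namespace relay {
   namespace contrib {
   namespace cmsisnn {

   // True when the combined padding differs by at most 1 on each spatial axis,
   // i.e. when the CMSIS-NN conv2d can absorb the preceding nn.pad.
   bool IsPaddingFusible(int pad_top, int pad_left, int pad_bottom, int pad_right) {
     int pad_diff_w = pad_right - pad_left;
     int pad_diff_h = pad_bottom - pad_top;
     return (pad_diff_w == 0 || pad_diff_w == 1) && (pad_diff_h == 0 || pad_diff_h == 1);
   }

   TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.IsPaddingFusible")
       .set_body_typed(IsPaddingFusible);

   }  // namespace cmsisnn
   }  // namespace contrib
   }  // namespace relay
   }  // namespace tvm
   ```
   The existing `tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)` call on the Python side would then expose it next to `CMSISNNFusePads`, so the pattern-table check could call the same logic instead of re-implementing the arithmetic.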



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org