Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2022/01/20 16:13:04 UTC

[GitHub] [incubator-mxnet] agrabows opened a new pull request #20835: [FEATURE] Add quantized version of reshape with DNNL reorder primitive.

agrabows opened a new pull request #20835:
URL: https://github.com/apache/incubator-mxnet/pull/20835


   ## Description ##
   Reshape operators from both NDArray and NumPy modules of MXNet had no quantized version so if it appeared between two quantized operators data would have to be dequantized and quantized again after reshape (as shown on picture below with version after change as well). Goal was to create quantized reshape operator.
   Previous version:
   ![image](https://user-images.githubusercontent.com/59651240/150377082-bf394c8d-6df5-45e0-a4b9-d4e39a1a1ce4.png)
   New version:
   ![image](https://user-images.githubusercontent.com/59651240/150377132-d73db877-dcf0-41b8-a056-80fbbeb1bc4c.png)
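The reason the dequantize/quantize pair around a reshape is redundant can be shown with a small NumPy sketch. The sketch assumes a symmetric int8 scheme with real = q * max(|min|, |max|) / 127. The helper names `dequantize_int8` and `quantize_int8` are hypothetical and are not MXNet APIs. Reshape only reorders the shape metadata, so both paths yield the same int8 tensor. The direct path also skips two full passes over the data.

```python
import numpy as np

def dequantize_int8(q, min_range, max_range):
    # Assumed symmetric int8 scheme: real = q * max(|min|, |max|) / 127.
    scale = np.float32(max(abs(min_range), abs(max_range)) / 127.0)
    return q.astype(np.float32) * scale

def quantize_int8(x, min_range, max_range):
    # Inverse of the assumed scheme: round to nearest, clip to [-127, 127].
    scale = np.float32(max(abs(min_range), abs(max_range)) / 127.0)
    return np.clip(np.rint(x / scale), -127, 127).astype(np.int8)

rng = np.random.default_rng(0)
q = rng.integers(-127, 128, size=(2, 3, 5, 5), dtype=np.int8)
min_r, max_r = -1023.343, 2343.324275  # calibration range used in the PR's test

# Before: dequantize -> reshape in float32 -> quantize again.
before = quantize_int8(dequantize_int8(q, min_r, max_r).reshape(2, 75), min_r, max_r)
# After: reshape the int8 tensor directly and forward min/max unchanged.
after = q.reshape(2, 75)

assert np.array_equal(before, after)
```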
   
   ## Checklist ##
   ### Essentials ###
   - [x] PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc)
   - [x] Changes are complete (i.e. I finished coding on this PR)
   - [x] All changes have test coverage
   
   ### Changes ###
   - [x] Add quantized version of numpy/numpy_extension reshape
   - [x] Add quantized version of ndarray reshape


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@mxnet.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] RafLit commented on a change in pull request #20835: [FEATURE] Add quantized version of reshape with DNNL reorder primitive.

Posted by GitBox <gi...@apache.org>.
RafLit commented on a change in pull request #20835:
URL: https://github.com/apache/incubator-mxnet/pull/20835#discussion_r800398959



##########
File path: src/operator/quantization/quantized_reshape.cc
##########
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file quantized_reshape.cc
+ * \author: Adam Grabowski, adam.grabowski@intel.com
+ */
+
+#include <utility>
+#include "quantized_reshape-inl.h"
+
+namespace mxnet {
+namespace op {
+
+void QuantizedReshapeCompute(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<TBlob>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 3U);
+  CHECK_EQ(req.size(), 3U);
+
+  if (req[0] != kWriteInplace)
+    UnaryOp::IdentityCompute<cpu>(attrs, ctx, inputs, req, outputs);
+
+  *outputs[1].dptr<float>() = *inputs[1].dptr<float>();
+  *outputs[2].dptr<float>() = *inputs[2].dptr<float>();
+}
+
+#define MXNET_OPERATOR_REGISTER_QUANTIZED_RESHAPE(name)                                      \
+  NNVM_REGISTER_OP(name)                                                                     \
+      .set_num_inputs(3)                                                                     \
+      .set_num_outputs(3)                                                                    \
+      .set_attr<nnvm::FListInputNames>(                                                      \
+          "FListInputNames",                                                                 \
+          [](const NodeAttrs& attrs) {                                                       \
+            return std::vector<std::string>{"data", "min_data", "max_data"};                 \
+          })                                                                                 \
+      .set_attr<nnvm::FListOutputNames>(                                                     \
+          "FListOutputNames",                                                                \
+          [](const NodeAttrs& attrs) {                                                       \
+            return std::vector<std::string>{"output", "min_output", "max_output"};           \
+          })                                                                                 \
+      .set_attr<nnvm::FInplaceOption>(                                                       \
+          "FInplaceOption",                                                                  \
+          [](const NodeAttrs& attrs) {                                                       \
+            return std::vector<std::pair<int, int> >{{0, 0}, {1, 1}, {2, 2}};                \
+          })                                                                                 \
+      .set_attr<FCompute>("FCompute<cpu>", QuantizedReshapeCompute)                          \
+      .set_attr<FResourceRequest>(                                                           \
+          "FResourceRequest",                                                                \
+          [](const NodeAttrs& n) {                                                           \
+            return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};                \
+          })                                                                                 \
+      .set_attr<nnvm::FInferType>("FInferType", QuantizedReshapeType)                        \
+      .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)                             \
+      .set_attr<FQuantizable>("FQuantizable",                                                \
+                              [](const NodeAttrs& attrs) { return QuantizeType::kSupport; }) \
+      .add_argument("data", "NDArray-or-Symbol", "Array to be reshaped.")                    \
+      .add_argument("min_data",                                                              \
+                    "NDArray-or-Symbol",                                                     \
+                    "The minimum scalar value "                                              \
+                    "possibly produced for the data")                                        \
+      .add_argument("max_data",                                                              \
+                    "NDArray-or-Symbol",                                                     \
+                    "The maximum scalar value "                                              \
+                    "possibly produced for the data")
+
+MXNET_OPERATOR_REGISTER_QUANTIZED_RESHAPE(_contrib_quantized_reshape)
+    .add_alias("quantized_reshape")
+    .set_attr_parser(ParamParser<ReshapeParam>)
+    .set_attr<mxnet::FInferShape>("FInferShape", QuantizedReshapeInferShape<ReshapeShape>)
+    .add_arguments(ReshapeParam::__FIELDS__());
+
+MXNET_OPERATOR_REGISTER_QUANTIZED_RESHAPE(_npx_quantized_reshape)
+    .set_attr_parser(ParamParser<NumpyXReshapeParam>)
+    .set_attr<mxnet::FInferShape>("FInferShape", QuantizedReshapeInferShape<NumpyXReshapeShape>)
+    .add_arguments(NumpyXReshapeParam::__FIELDS__());
+
+template <bool is_numpy_op>
+nnvm::ObjectPtr QuantizedReshapeNode(const NodeAttrs& attrs) {
+  nnvm::ObjectPtr node = nnvm::Node::Create();
+
+  if constexpr (is_numpy_op) {
+    node->attrs.op = Op::Get("_npx_quantized_reshape");
+  } else {
+    node->attrs.op = Op::Get("_contrib_quantized_reshape");
+  }

Review comment:
       You could make the op name a string template parameter instead.
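The following NumPy sketch models the observable behaviour that `QuantizedReshapeCompute` above is meant to have. It assumes that reshaping is a metadata-only change: the int8/uint8 payload keeps its bytes and element order, and the float32 `min_data`/`max_data` scalars are forwarded unchanged. The function `quantized_reshape_reference` is hypothetical. NumPy views stand in for the `kWriteInplace` case allowed by the `{0, 0}, {1, 1}, {2, 2}` in-place options; in the non-in-place case the C++ code copies the data through `UnaryOp::IdentityCompute<cpu>` instead.

```python
import numpy as np

def quantized_reshape_reference(qdata, min_data, max_data, newshape):
    """Hypothetical NumPy model of the quantized reshape outputs.

    The quantized payload gets a new shape but keeps its bytes, and the
    float32 min/max scalars are copied to the outputs unchanged.
    """
    output = qdata.reshape(newshape)  # a view for contiguous input: no data copy
    min_output = np.array(min_data, dtype=np.float32)  # copy of the range scalars
    max_output = np.array(max_data, dtype=np.float32)
    return output, min_output, max_output

qdata = np.arange(-60, 60, dtype=np.int8).reshape(2, 3, 4, 5)
out, lo, hi = quantized_reshape_reference(qdata, [-1023.343], [2343.324275], (6, 20))
print(out.shape, np.shares_memory(out, qdata))  # (6, 20) True
```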







[GitHub] [incubator-mxnet] RafLit commented on a change in pull request #20835: [FEATURE] Add quantized version of reshape with DNNL reorder primitive.

Posted by GitBox <gi...@apache.org>.
RafLit commented on a change in pull request #20835:
URL: https://github.com/apache/incubator-mxnet/pull/20835#discussion_r795432242



##########
File path: tests/python/quantization/test_quantization.py
##########
@@ -945,6 +945,46 @@ def check_quantized_bn(data_shape, qdtype):
       check_quantized_bn((32, 3, 224, 224), qdtype)
 
 
+def test_quantized_reshape():
+    test_cases = [((2, 3, 5, 5),  (-2, -1),         False, (2, 75)), 
+                  ((2, 3, 5, 5),  (-2, -2, -1),     False, (2, 3, 25)), 
+                  ((5, 3, 4, 5),  (-2, -1, -2),     False, (5, 15, 4)), 
+                  ((2, 3, 5, 4),  (-1, -2, -2),     False, (8, 3, 5)), 
+                  ((2, 3, 5, 5),  (-2, -2, -2, -2), False, (2, 3, 5, 5)), 
+                  ((2, 1, 4, 5),  (-2, -3, -2, -2), False, (2, 4, 5)), 
+                  ((1, 1, 4, 1),  (-3, -3, -2, -2), False, (4, 1)), 
+                  ((1, 1, 1, 1),  (-3, -3, -3, -3), False, ()), 
+                  ((2, 4, 5, 3),  (-1, 2, 2, 1),    False, (30, 2, 2, 1)), 
+                  ((2, 3, 5, 6),  (-4,),            False, (2, 3, 5, 6)), 
+                  ((2, 3, 5, 6),  (6, 1, -4),       False, (6, 1, 5, 6)), 
+                  ((2, 3, 5, 6),  (-5, -5),         False, (6, 30)), 
+                  ((2, 3, 5, 6),  (-5, -1),         False, (6, 30)), 
+                  ((64,),         (-6, 16, 4),      False, (16, 4)), 
+                  ((64,),         (-6, 16, -1),     False, (16, 4)),
+                  ((64, 1, 2, 3), (-6, 16, -1, -4), False, (16, 4, 1, 2, 3)), 
+                  ((8, 5, 4, 6),  (-4, -1, 3, -6),  True,  (8, 5, 4, 2, 3))]
+
+    def check_quantized_reshape(shape, qdtype, newshape, reverse, expected_ret_shape):
+        if qdtype == 'uint8':
+            data_low = 0.0
+            data_high = 127.0
+        else:
+            data_low = -127.0
+            data_high = 127.0
+        qdata = mx.np.random.uniform(low=data_low, high=data_high, size=shape).astype(qdtype)
+        min_data = mx.np.array([-1023.343], dtype='float32')
+        max_data = mx.np.array([2343.324275], dtype='float32')
+        qoutput, min_output, max_output = npx.quantized_reshape(qdata, min_data, max_data, newshape=newshape, reverse=reverse)
+        assert qoutput.shape == expected_ret_shape
+        assert same(qdata.asnumpy().flatten(), qoutput.asnumpy().flatten())

Review comment:
       Is flatten necessary here?
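One way to answer the review question is to check how the comparison behaves with and without `flatten`. The sketch below assumes that `same` from `mxnet.test_utils` is a thin wrapper around `np.array_equal` and redefines it locally under that assumption. `np.array_equal` returns False whenever the shapes differ. Without flattening, the check would therefore fail even when the elements are identical. Reshaping one side to the other's shape works as an alternative.

```python
import numpy as np

def same(a, b):
    # Assumed to mirror mxnet.test_utils.same, i.e. np.array_equal.
    return np.array_equal(a, b)

x = np.arange(2 * 3 * 5 * 5, dtype=np.uint8).reshape(2, 3, 5, 5)
y = x.reshape(2, 75)

print(same(x, y))                      # False: shapes differ
print(same(x.flatten(), y.flatten()))  # True: same elements in the same order
print(same(x.reshape(y.shape), y))     # True: alternative without flatten
```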







[GitHub] [incubator-mxnet] agrabows commented on pull request #20835: [FEATURE] Add quantized version of reshape with DNNL reorder primitive.

Posted by GitBox <gi...@apache.org>.
agrabows commented on pull request #20835:
URL: https://github.com/apache/incubator-mxnet/pull/20835#issuecomment-1034635202


   @mxnet-bot run ci [centos-gpu]





[GitHub] [incubator-mxnet] RafLit commented on a change in pull request #20835: [FEATURE] Add quantized version of reshape with DNNL reorder primitive.

Posted by GitBox <gi...@apache.org>.
RafLit commented on a change in pull request #20835:
URL: https://github.com/apache/incubator-mxnet/pull/20835#discussion_r795424225



##########
File path: src/operator/quantization/quantized_reshape-inl.h
##########
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file quantized_reshape-inl.h
+ * \author: Adam Grabowski, adam.grabowski@intel.com
+ */
+
+#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RESHAPE_INL_H_
+#define MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RESHAPE_INL_H_
+
+#include <string>
+#include <vector>
+#include "operator/tensor/matrix_op-inl.h"
+#include "operator/numpy/np_matrix_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+struct QuantizedReshapeParam : public dmlc::Parameter<QuantizedReshapeParam> {
+  mxnet::TShape newshape;
+  mxnet::Tuple<int> shape;
+  bool reverse, keep_highest, is_numpy_op;
+  std::string order;

Review comment:
       I don't think merging both parameter types is a good idea. Duplicating type information in a variable is counter-intuitive. What do you think about separating the operator into two versions, a NumPy one and an NDArray one?

##########
File path: tests/python/quantization/test_quantization.py
##########
@@ -945,6 +945,46 @@ def check_quantized_bn(data_shape, qdtype):
       check_quantized_bn((32, 3, 224, 224), qdtype)
 
 
+def test_quantized_reshape():
+    test_cases = [((2, 3, 5, 5),  (-2, -1),         False, (2, 75)), 
+                  ((2, 3, 5, 5),  (-2, -2, -1),     False, (2, 3, 25)), 
+                  ((5, 3, 4, 5),  (-2, -1, -2),     False, (5, 15, 4)), 
+                  ((2, 3, 5, 4),  (-1, -2, -2),     False, (8, 3, 5)), 
+                  ((2, 3, 5, 5),  (-2, -2, -2, -2), False, (2, 3, 5, 5)), 
+                  ((2, 1, 4, 5),  (-2, -3, -2, -2), False, (2, 4, 5)), 
+                  ((1, 1, 4, 1),  (-3, -3, -2, -2), False, (4, 1)), 
+                  ((1, 1, 1, 1),  (-3, -3, -3, -3), False, ()), 
+                  ((2, 4, 5, 3),  (-1, 2, 2, 1),    False, (30, 2, 2, 1)), 
+                  ((2, 3, 5, 6),  (-4,),            False, (2, 3, 5, 6)), 
+                  ((2, 3, 5, 6),  (6, 1, -4),       False, (6, 1, 5, 6)), 
+                  ((2, 3, 5, 6),  (-5, -5),         False, (6, 30)), 
+                  ((2, 3, 5, 6),  (-5, -1),         False, (6, 30)), 
+                  ((64,),         (-6, 16, 4),      False, (16, 4)), 
+                  ((64,),         (-6, 16, -1),     False, (16, 4)),
+                  ((64, 1, 2, 3), (-6, 16, -1, -4), False, (16, 4, 1, 2, 3)), 
+                  ((8, 5, 4, 6),  (-4, -1, 3, -6),  True,  (8, 5, 4, 2, 3))]
+
+    def check_quantized_reshape(shape, qdtype, newshape, reverse, expected_ret_shape):
+        if qdtype == 'uint8':
+            data_low = 0.0
+            data_high = 127.0
+        else:
+            data_low = -127.0
+            data_high = 127.0
+        qdata = mx.np.random.uniform(low=data_low, high=data_high, size=shape).astype(qdtype)
+        min_data = mx.np.array([-1023.343], dtype='float32')
+        max_data = mx.np.array([2343.324275], dtype='float32')
+        qoutput, min_output, max_output = npx.quantized_reshape(qdata, min_data, max_data, newshape=newshape, reverse=reverse)
+        assert qoutput.shape == expected_ret_shape
+        assert same(qdata.asnumpy().flatten(), qoutput.asnumpy().flatten())

Review comment:
       Is flatten necessary here?
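The `test_cases` table quoted above exercises the special values accepted by `npx.reshape`'s `newshape`. The following is a minimal pure-Python sketch of those rules as they can be read from the table, and it should be treated as an assumption, not as MXNet's shape-inference code:
- `-1` infers one dimension.
- `-2` copies the current input dimension.
- `-3` drops a size-1 input dimension.
- `-4` copies all remaining dimensions.
- `-5` merges two input dimensions.
- `-6` splits one input dimension into the next two values.
- `reverse=True` applies the rules right to left.

The function name `npx_reshape_shape` is hypothetical.

```python
from math import prod

def npx_reshape_shape(src, newshape, reverse=False):
    """Hypothetical reference for npx.reshape's special values (sketch only)."""
    if reverse:
        # Right-to-left inference: reverse both shapes, infer, reverse back.
        out = npx_reshape_shape(tuple(reversed(src)), tuple(reversed(newshape)))
        return tuple(reversed(out))
    out, unknown, s, i = [], None, 0, 0
    while i < len(newshape):
        code = newshape[i]
        if code == -1:    # inferred below; consumes one input dim
            unknown = len(out)
            out.append(-1)
            s += 1
        elif code == -2:  # copy the current input dim
            out.append(src[s])
            s += 1
        elif code == -3:  # drop the current input dim, which must be 1
            assert src[s] == 1, "-3 may only skip a dimension of size 1"
            s += 1
        elif code == -4:  # copy all remaining input dims
            out.extend(src[s:])
            s = len(src)
        elif code == -5:  # merge two consecutive input dims
            out.append(src[s] * src[s + 1])
            s += 2
        elif code == -6:  # split one input dim into the next two values
            d0, d1, d2 = src[s], newshape[i + 1], newshape[i + 2]
            if d1 == -1:
                d1 = d0 // d2
            if d2 == -1:
                d2 = d0 // d1
            assert d1 * d2 == d0
            out.extend([d1, d2])
            s += 1
            i += 2
        else:             # explicit size; consumes one input dim
            out.append(code)
            s += 1
        i += 1
    if unknown is not None:
        known = prod(d for j, d in enumerate(out) if j != unknown)
        out[unknown] = prod(src) // known
    return tuple(out)

print(npx_reshape_shape((5, 3, 4, 5), (-2, -1, -2)))                     # (5, 15, 4)
print(npx_reshape_shape((8, 5, 4, 6), (-4, -1, 3, -6), reverse=True))    # (8, 5, 4, 2, 3)
```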

##########
File path: tests/python/dnnl/subgraphs/test_conv_subgraph.py
##########
@@ -73,6 +73,28 @@ def forward(self, x):
   check_fusion(net, data_shape, attr)
 
 
+@mx.util.use_np
+@pytest.mark.parametrize('data_shape', DATA_SHAPE)
+@pytest.mark.parametrize('use_bias', [True, False])
+def test_conv_reshape_conv(use_bias, data_shape):
+
+  class Conv_Reshape_Conv(nn.HybridBlock):
+    def __init__(self, **kwargs):
+        super(Conv_Reshape_Conv, self).__init__(**kwargs)
+        self.conv0 = nn.Conv2D(channels=64, kernel_size=(3, 3), strides=1, use_bias=use_bias)
+        self.conv1 = nn.Conv2D(channels=32, kernel_size=(5, 5), strides=1, use_bias=use_bias)
+
+    def forward(self, x):
+      out = self.conv0(x)
+      out = mx.npx.reshape(out, newshape=(-1, int(out.shape[1]/4), out.shape[2]*2, out.shape[3]*2))
+      out = self.conv1(out)
+      return out
+
+  attr = {'conv': []}
+  net = Conv_Reshape_Conv()
+  check_fusion(net, data_shape, attr)

Review comment:
       Wouldn't it be better to use check_quantize explicitly? Is anything fused in this example?
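The shape arithmetic of `Conv_Reshape_Conv` helps frame the reviewer's question. The sketch below walks the shapes through the block. `DATA_SHAPE` is not quoted in this thread, so the input shape used here is hypothetical. The sketch also assumes Gluon's default `padding=0` for both `nn.Conv2D` layers. The `npx.reshape` call folds channels into the spatial axes (C/4 channels, 2H x 2W) while keeping the element count. This places a real reshape between two convolutions, which is the pattern the PR targets.

```python
def conv2d_out(size, kernel, stride=1, pad=0):
    # Standard output-size formula for a 2-D convolution along one spatial axis.
    return (size + 2 * pad - kernel) // stride + 1

def conv_reshape_conv_shapes(n, c, h, w):
    """Hypothetical shape walk-through for the Conv_Reshape_Conv test block."""
    # conv0: 64 output channels, 3x3 kernel, stride 1, no padding
    s0 = (n, 64, conv2d_out(h, 3), conv2d_out(w, 3))
    # npx.reshape(out, newshape=(-1, C/4, 2H, 2W)): channels folded into space
    c1, h1, w1 = s0[1] // 4, s0[2] * 2, s0[3] * 2
    total = s0[0] * s0[1] * s0[2] * s0[3]
    s1 = (total // (c1 * h1 * w1), c1, h1, w1)
    # conv1: 32 output channels, 5x5 kernel, stride 1, no padding
    s2 = (s1[0], 32, conv2d_out(h1, 5), conv2d_out(w1, 5))
    return s0, s1, s2

print(conv_reshape_conv_shapes(1, 3, 24, 24))
# ((1, 64, 22, 22), (1, 16, 44, 44), (1, 32, 40, 40))
```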

##########
File path: tests/python/quantization/test_quantization.py
##########
@@ -945,6 +945,46 @@ def check_quantized_bn(data_shape, qdtype):
       check_quantized_bn((32, 3, 224, 224), qdtype)
 
 
+def test_quantized_reshape():
+    test_cases = [((2, 3, 5, 5),  (-2, -1),         False, (2, 75)), 
+                  ((2, 3, 5, 5),  (-2, -2, -1),     False, (2, 3, 25)), 
+                  ((5, 3, 4, 5),  (-2, -1, -2),     False, (5, 15, 4)), 
+                  ((2, 3, 5, 4),  (-1, -2, -2),     False, (8, 3, 5)), 
+                  ((2, 3, 5, 5),  (-2, -2, -2, -2), False, (2, 3, 5, 5)), 
+                  ((2, 1, 4, 5),  (-2, -3, -2, -2), False, (2, 4, 5)), 
+                  ((1, 1, 4, 1),  (-3, -3, -2, -2), False, (4, 1)), 
+                  ((1, 1, 1, 1),  (-3, -3, -3, -3), False, ()), 
+                  ((2, 4, 5, 3),  (-1, 2, 2, 1),    False, (30, 2, 2, 1)), 
+                  ((2, 3, 5, 6),  (-4,),            False, (2, 3, 5, 6)), 
+                  ((2, 3, 5, 6),  (6, 1, -4),       False, (6, 1, 5, 6)), 
+                  ((2, 3, 5, 6),  (-5, -5),         False, (6, 30)), 
+                  ((2, 3, 5, 6),  (-5, -1),         False, (6, 30)), 
+                  ((64,),         (-6, 16, 4),      False, (16, 4)), 
+                  ((64,),         (-6, 16, -1),     False, (16, 4)),
+                  ((64, 1, 2, 3), (-6, 16, -1, -4), False, (16, 4, 1, 2, 3)), 
+                  ((8, 5, 4, 6),  (-4, -1, 3, -6),  True,  (8, 5, 4, 2, 3))]
+
+    def check_quantized_reshape(shape, qdtype, newshape, reverse, expected_ret_shape):
+        if qdtype == 'uint8':
+            data_low = 0.0
+            data_high = 127.0

Review comment:
       Can this be increased to 255?
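Two points affect whether raising the bound to 255 actually exercises the full uint8 range. First, `uniform(low, high)` samples are typically drawn from the half-open interval [low, high). Second, `.astype()` truncates toward zero. So `uniform(0, 255).astype('uint8')` would essentially never produce 255. The NumPy sketch below draws integers directly with an inclusive upper bound instead. The helper `full_range_qdata` is hypothetical. The int8 range of [-127, 127] assumes symmetric quantization, which leaves -128 unused.

```python
import numpy as np

def full_range_qdata(shape, qdtype, rng):
    """Hypothetical helper: random quantized test data spanning the dtype's range."""
    if qdtype == 'uint8':
        low, high = 0, 255
    else:
        # Assumes symmetric int8 quantization, which leaves -128 unused.
        low, high = -127, 127
    # endpoint=True makes `high` itself reachable, unlike uniform(...).astype(...).
    return rng.integers(low, high, size=shape, dtype=np.dtype(qdtype), endpoint=True)

rng = np.random.default_rng(0)
u8 = full_range_qdata((100000,), 'uint8', rng)
i8 = full_range_qdata((100000,), 'int8', rng)
print(u8.dtype, u8.min(), u8.max())
print(i8.dtype, i8.min(), i8.max())
```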

##########
File path: src/operator/quantization/quantized_reshape-inl.h
##########
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file quantized_reshape-inl.h
+ * \author: Adam Grabowski, adam.grabowski@intel.com
+ */
+
+#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RESHAPE_INL_H_
+#define MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RESHAPE_INL_H_
+
+#include <string>
+#include <vector>
+#include "operator/tensor/matrix_op-inl.h"
+#include "operator/numpy/np_matrix_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+struct QuantizedReshapeParam : public dmlc::Parameter<QuantizedReshapeParam> {
+  mxnet::TShape newshape;
+  mxnet::Tuple<int> shape;
+  bool reverse, keep_highest, is_numpy_op;
+  std::string order;
+
+  DMLC_DECLARE_PARAMETER(QuantizedReshapeParam) {
+    DMLC_DECLARE_FIELD(newshape).set_default(mxnet::TShape(0, -1));
+    DMLC_DECLARE_FIELD(shape).set_default(mxnet::Tuple<int>());
+    DMLC_DECLARE_FIELD(reverse).set_default(false);
+    DMLC_DECLARE_FIELD(order).set_default("C");
+    DMLC_DECLARE_FIELD(keep_highest).set_default(false);
+    DMLC_DECLARE_FIELD(is_numpy_op).set_default(true);

Review comment:
       I think redundant type information should be avoided.

##########
File path: src/operator/quantization/quantized_reshape.cc
##########
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file quantized_reshape.cc
+ * \author: Adam Grabowski, adam.grabowski@intel.com
+ */
+
+#include <utility>
+#include "quantized_reshape-inl.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(QuantizedReshapeParam);
+
+void QuantizedReshapeCompute(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<TBlob>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 3U);
+  CHECK_EQ(req.size(), 3U);
+
+  if (req[0] != kWriteInplace)
+    UnaryOp::IdentityCompute<cpu>(attrs, ctx, inputs, req, outputs);
+
+  *outputs[1].dptr<float>() = *inputs[1].dptr<float>();
+  *outputs[2].dptr<float>() = *inputs[2].dptr<float>();
+}
+
+NNVM_REGISTER_OP(_contrib_quantized_reshape)
+    .add_alias("_npx_quantized_reshape")
+    .set_num_inputs(3)
+    .set_num_outputs(3)
+    .set_attr_parser(ParamParser<QuantizedReshapeParam>)
+    .set_attr<nnvm::FListInputNames>(
+        "FListInputNames",
+        [](const NodeAttrs& attrs) {
+          return std::vector<std::string>{"data", "min_data", "max_data"};
+        })
+    .set_attr<nnvm::FListOutputNames>(
+        "FListOutputNames",
+        [](const NodeAttrs& attrs) {
+          return std::vector<std::string>{"output", "min_output", "max_output"};
+        })
+    .set_attr<nnvm::FInplaceOption>(
+        "FInplaceOption",
+        [](const NodeAttrs& attrs) {
+          return std::vector<std::pair<int, int> >{{0, 0}, {1, 1}, {2, 2}};
+        })
+    .set_attr<FCompute>("FCompute<cpu>", QuantizedReshapeCompute)
+    .set_attr<FResourceRequest>("FResourceRequest",
+                                [](const NodeAttrs& n) {
+                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+                                })
+    .set_attr<mxnet::FInferShape>("FInferShape", QuantizedReshapeInferShape)
+    .set_attr<nnvm::FInferType>("FInferType", QuantizedReshapeType)
+    .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+    .set_attr<FQuantizable>("FQuantizable",
+                            [](const NodeAttrs& attrs) { return QuantizeType::kSupport; })
+    .add_argument("data", "NDArray-or-Symbol", "Array to be reshaped.")
+    .add_argument("min_data",
+                  "NDArray-or-Symbol",
+                  "The minimum scalar value "
+                  "possibly produced for the data")
+    .add_argument("max_data",
+                  "NDArray-or-Symbol",
+                  "The maximum scalar value "
+                  "possibly produced for the data")
+    .add_arguments(QuantizedReshapeParam::__FIELDS__());
+
+template <bool is_numpy_op>
+nnvm::ObjectPtr QuantizedReshapeNode(const NodeAttrs& attrs) {
+  QuantizedReshapeParam param;
+  if (is_numpy_op) {

Review comment:
       This will be checked at runtime for both the NumPy and the NDArray versions. It would be more intuitive to separate them.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@mxnet.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #20835: [FEATURE] Add quantized version of reshape with DNNL reorder primitive.

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #20835:
URL: https://github.com/apache/incubator-mxnet/pull/20835#issuecomment-1017669396


   Hey @agrabows, thanks for submitting the PR.
   All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands: 
   - To trigger all jobs: @mxnet-bot run ci [all] 
   - To trigger specific jobs: @mxnet-bot run ci [job1, job2] 
   *** 
   **CI supported jobs**: [miscellaneous, centos-gpu, clang, centos-cpu, unix-gpu, website, windows-cpu, windows-gpu, unix-cpu, edge, sanity]
   *** 
   _Note_: 
   Only the following three categories can trigger CI: PR Author, MXNet Committer, Jenkins Admin.
   All CI tests must pass before the PR can be merged. 
   





[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #20835: [FEATURE] Add quantized version of reshape with DNNL reorder primitive.

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #20835:
URL: https://github.com/apache/incubator-mxnet/pull/20835#issuecomment-1034635350


   Jenkins CI successfully triggered : [centos-gpu]





[GitHub] [incubator-mxnet] bgawrych merged pull request #20835: [FEATURE] Add quantized version of reshape with DNNL reorder primitive.

Posted by GitBox <gi...@apache.org>.
bgawrych merged pull request #20835:
URL: https://github.com/apache/incubator-mxnet/pull/20835


   





[GitHub] [incubator-mxnet] agrabows commented on a change in pull request #20835: [FEATURE] Add quantized version of reshape with DNNL reorder primitive.

Posted by GitBox <gi...@apache.org>.
agrabows commented on a change in pull request #20835:
URL: https://github.com/apache/incubator-mxnet/pull/20835#discussion_r800765469



##########
File path: src/operator/quantization/quantized_reshape.cc
##########
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file quantized_reshape.cc
+ * \author: Adam Grabowski, adam.grabowski@intel.com
+ */
+
+#include <utility>
+#include "quantized_reshape-inl.h"
+
+namespace mxnet {
+namespace op {
+
+void QuantizedReshapeCompute(const nnvm::NodeAttrs& attrs,
+                             const OpContext& ctx,
+                             const std::vector<TBlob>& inputs,
+                             const std::vector<OpReqType>& req,
+                             const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 3U);
+  CHECK_EQ(req.size(), 3U);
+
+  if (req[0] != kWriteInplace)
+    UnaryOp::IdentityCompute<cpu>(attrs, ctx, inputs, req, outputs);
+
+  *outputs[1].dptr<float>() = *inputs[1].dptr<float>();
+  *outputs[2].dptr<float>() = *inputs[2].dptr<float>();
+}
+
+#define MXNET_OPERATOR_REGISTER_QUANTIZED_RESHAPE(name)                                      \
+  NNVM_REGISTER_OP(name)                                                                     \
+      .set_num_inputs(3)                                                                     \
+      .set_num_outputs(3)                                                                    \
+      .set_attr<nnvm::FListInputNames>(                                                      \
+          "FListInputNames",                                                                 \
+          [](const NodeAttrs& attrs) {                                                       \
+            return std::vector<std::string>{"data", "min_data", "max_data"};                 \
+          })                                                                                 \
+      .set_attr<nnvm::FListOutputNames>(                                                     \
+          "FListOutputNames",                                                                \
+          [](const NodeAttrs& attrs) {                                                       \
+            return std::vector<std::string>{"output", "min_output", "max_output"};           \
+          })                                                                                 \
+      .set_attr<nnvm::FInplaceOption>(                                                       \
+          "FInplaceOption",                                                                  \
+          [](const NodeAttrs& attrs) {                                                       \
+            return std::vector<std::pair<int, int> >{{0, 0}, {1, 1}, {2, 2}};                \
+          })                                                                                 \
+      .set_attr<FCompute>("FCompute<cpu>", QuantizedReshapeCompute)                          \
+      .set_attr<FResourceRequest>(                                                           \
+          "FResourceRequest",                                                                \
+          [](const NodeAttrs& n) {                                                           \
+            return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};                \
+          })                                                                                 \
+      .set_attr<nnvm::FInferType>("FInferType", QuantizedReshapeType)                        \
+      .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)                             \
+      .set_attr<FQuantizable>("FQuantizable",                                                \
+                              [](const NodeAttrs& attrs) { return QuantizeType::kSupport; }) \
+      .add_argument("data", "NDArray-or-Symbol", "Array to be reshaped.")                    \
+      .add_argument("min_data",                                                              \
+                    "NDArray-or-Symbol",                                                     \
+                    "The minimum scalar value "                                              \
+                    "possibly produced for the data")                                        \
+      .add_argument("max_data",                                                              \
+                    "NDArray-or-Symbol",                                                     \
+                    "The maximum scalar value "                                              \
+                    "possibly produced for the data")
+
+MXNET_OPERATOR_REGISTER_QUANTIZED_RESHAPE(_contrib_quantized_reshape)
+    .add_alias("quantized_reshape")
+    .set_attr_parser(ParamParser<ReshapeParam>)
+    .set_attr<mxnet::FInferShape>("FInferShape", QuantizedReshapeInferShape<ReshapeShape>)
+    .add_arguments(ReshapeParam::__FIELDS__());
+
+MXNET_OPERATOR_REGISTER_QUANTIZED_RESHAPE(_npx_quantized_reshape)
+    .set_attr_parser(ParamParser<NumpyXReshapeParam>)
+    .set_attr<mxnet::FInferShape>("FInferShape", QuantizedReshapeInferShape<NumpyXReshapeShape>)
+    .add_arguments(NumpyXReshapeParam::__FIELDS__());
+
+template <bool is_numpy_op>
+nnvm::ObjectPtr QuantizedReshapeNode(const NodeAttrs& attrs) {
+  nnvm::ObjectPtr node = nnvm::Node::Create();
+
+  if constexpr (is_numpy_op) {
+    node->attrs.op = Op::Get("_npx_quantized_reshape");
+  } else {
+    node->attrs.op = Op::Get("_contrib_quantized_reshape");
+  }

Review comment:
       We could do it using global strings; however, global variables are not advisable. Instead, after consulting with @bgawrych and @mozga-intel, I created an enum-based version of this functionality.



