Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/22 11:15:25 UTC

[GitHub] [incubator-tvm] lhutton1 opened a new pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

lhutton1 opened a new pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109


   Adds support for asymmetric per-layer quantization in the Arm Compute Library runtime module. This includes support for qnn.conv2d, nn.max_pool2d and reshape.
   
   * Adds support for different data types (as opposed to hard-coded fp32).
   * Adds support for creating ACL tensors with scale and offset values.
   * Improves qnn.conv2d layout conversion to support different kernel layouts.
   * Adds a qnn.conv2d composite operator (pad?, qnn.conv2d, bias?, relu?, qnn.requantize).
   * Updates tests to reflect these changes.
   * Adds a table of supported operators to the tutorial.
   
   Change-Id: I8f610bd37af1e3740fd48c2d502bcc4727d9d712
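For readers new to the terminology: "asymmetric per-layer" quantization uses one scale and one zero point (ACL calls it the "offset") for a whole tensor, with real_value ~= scale * (q - zero_point). The sketch below illustrates that mapping for uint8 in plain Python. It is not code from this PR; the function names and the example parameters (scale 0.1, zero point 128) are hypothetical.

```python
# Sketch of asymmetric, per-layer (per-tensor) uint8 quantization:
#   real_value ~= scale * (q - zero_point)
# A single scale and zero point cover every element of the tensor.


def quantize(values, scale, zero_point, qmin=0, qmax=255):
    """Map floats to uint8 codes, clamping to the representable range."""
    codes = []
    for v in values:
        q = round(v / scale) + zero_point  # Python round(): nearest, ties to even
        codes.append(max(qmin, min(qmax, q)))
    return codes


def dequantize(codes, scale, zero_point):
    """Recover approximate floats from uint8 codes."""
    return [scale * (q - zero_point) for q in codes]


codes = quantize([-1.0, 0.0, 0.5, 20.0], scale=0.1, zero_point=128)
# codes == [118, 128, 133, 255]; 20.0 saturates at 255
```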
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] lhutton1 commented on a change in pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on a change in pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109#discussion_r459603811



##########
File path: tests/python/contrib/test_arm_compute_lib/test_conv2d.py
##########
@@ -127,51 +241,31 @@ def test_conv2d():
 
     device = Device()
     np.random.seed(0)
+    r = random.Random(0)

Review comment:
       The discussion in #6084 seems to suggest not relying on a seed.







[GitHub] [incubator-tvm] zhiics merged pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
zhiics merged pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109


   





[GitHub] [incubator-tvm] lhutton1 commented on a change in pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on a change in pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109#discussion_r459604438



##########
File path: src/runtime/contrib/arm_compute_lib/acl_utils.cc
##########
@@ -108,6 +111,30 @@ arm_compute::PadStrideInfo ToACLPadStride(const std::vector<std::string>& pad,
                                     arm_compute::DimensionRoundingType::FLOOR);
 }
 
+arm_compute::DataType MakeDataType(const DLDataType& data_type) {
+  if (data_type.code == DLDataTypeCode::kDLFloat && data_type.bits == 32) {
+    return arm_compute::DataType::F32;
+  } else if (data_type.code == DLDataTypeCode::kDLUInt && data_type.bits == 8) {
+    return arm_compute::DataType::QASYMM8;
+  } else if (data_type.code == DLDataTypeCode::kDLInt && data_type.bits == 32) {
+    return arm_compute::DataType::S32;
+  } else {
+    LOG(FATAL) << "Datatype " << data_type << " unsupported by ACL runtime";
+    return arm_compute::DataType::UNKNOWN;
+  }
+}
+
+template <typename T>
+std::vector<T> GetVectorFromDLTensor(const DLTensor* tensor) {

Review comment:
       I looked for something like this in the runtime but couldn't find anything yet.







[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109#discussion_r459616949



##########
File path: tests/python/contrib/test_arm_compute_lib/test_conv2d.py
##########
@@ -127,51 +241,31 @@ def test_conv2d():
 
     device = Device()
     np.random.seed(0)
+    r = random.Random(0)

Review comment:
       Although randomly generating input tensors is not best practice either, as you pointed out, it is generally acceptable. The reasons are:
   
   1. The module we are testing does not depend heavily on the input values.
   2. It is tedious to put a large constant tensor (e.g., a 1x3x224x224 tensor, about 150k numbers) in the code base.
   3. Sometimes people use `np.ones` or `np.zeros` as inputs for the above two reasons, but these tensors are far from real-world data and may not fit our testing purposes.
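One way to keep random inputs while avoiding dependence on global RNG state is to draw them from a local, explicitly seeded generator, as the test diff does with `random.Random(0)`. A minimal sketch using only the standard library (the helper name `make_uint8_input` is hypothetical; the PR's tests use NumPy):

```python
import random


def make_uint8_input(shape, seed=0):
    """Return a flat list of uint8-range values for a tensor of `shape`.

    A local random.Random instance is used, so the result depends only on
    `seed` and not on global RNG state touched by other tests.
    """
    rng = random.Random(seed)
    size = 1
    for dim in shape:
        size *= dim
    return [rng.randint(0, 255) for _ in range(size)]
```

For a fixed seed this returns the same values on every call within a given Python version, which is the property the seeding discussion relies on.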







[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109#discussion_r459011905



##########
File path: src/relay/backend/contrib/arm_compute_lib/codegen.cc
##########
@@ -49,6 +49,18 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
  public:
   ACLJSONSerializer(const std::string& symbol, const Expr& expr) : JSONSerializer(symbol, expr) {}
 
+  /*!
+   * \brief A series of operators that form a composite
+   * convolution. Supports both nn.conv2d and qnn.conv2d.
+   */
+  struct CompositeConvNodes {

Review comment:
       It seems to me that this struct represents one composite convolution node, so I would suggest s/CompositeConvNodes/CompositeConvNode/.
   

##########
File path: src/relay/backend/contrib/arm_compute_lib/codegen.cc
##########
@@ -78,57 +90,83 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
 
  private:
   /*!
-   * \brief Create a JSON representation of a composite convolution.
+   * \brief Extract convolution nodes from a composite function.
    *
-   * \param call The call to be represented.
-   * \return A JSON representation of a specific operator.
+   * \param cn The call node of the composite function.
+   * \return Extracted composite convolution nodes.
    */
-  std::shared_ptr<JSONGraphNode> CreateCompositeConvJSONNode(const CallNode* cn) {
-    const std::string name = "nn.conv2d";
-    const CallNode* pad = nullptr;
-    const CallNode* conv = nullptr;
-    const CallNode* bias = nullptr;
-    bool has_activation = false;
-
-    // Unpack composite function
+  static CompositeConvNodes UnpackCompositeConvolution(const CallNode* cn) {
+    CompositeConvNodes nodes{};
     const auto* fn = cn->op.as<FunctionNode>();
     CHECK(fn);
     const auto* current_call = fn->body.as<CallNode>();
+    if (backend::IsOp(current_call, "qnn.requantize")) {
+      nodes.requantize = current_call;
+      current_call = current_call->args[0].as<CallNode>();
+    }
     if (backend::IsOp(current_call, "nn.relu")) {
-      has_activation = true;
+      nodes.activation = current_call;
       current_call = current_call->args[0].as<CallNode>();
     }
     if (backend::IsOp(current_call, "nn.bias_add")) {
-      bias = current_call;
+      nodes.bias = current_call;
       current_call = current_call->args[0].as<CallNode>();
     }
-    CHECK(backend::IsOp(current_call, "nn.conv2d"));
-    conv = current_call;
+    if (nodes.requantize) {
+      CHECK(backend::IsOp(current_call, "qnn.conv2d"));

Review comment:
       It would be clearer to add those CHECKs before each assignment of `current_call`.

##########
File path: src/runtime/contrib/arm_compute_lib/acl_runtime.cc
##########
@@ -163,24 +149,61 @@ class ACLRuntime : public JSONRuntimeBase {
   struct CachedLayer {
     std::shared_ptr<arm_compute::IFunction> function;
     std::vector<arm_compute::Tensor> inputs;
-    std::vector<arm_compute::Tensor> const_inputs;
     std::vector<arm_compute::Tensor> outputs;
   };
 
+  /*!
+   * \brief Create an ACL tensor given the JSON representation.

Review comment:
       ```suggestion
      * \brief Create an ACL tensor given the JSON representation. If scale and offset are given, then create a quantized ACL tensor.
   ```
   
   - IMHO, this function could be an overload of `MakeTensor` and moved to `acl_utils`.
   - `MakeACLTensor` might be clearer.

##########
File path: src/runtime/contrib/arm_compute_lib/acl_utils.cc
##########
@@ -108,6 +111,30 @@ arm_compute::PadStrideInfo ToACLPadStride(const std::vector<std::string>& pad,
                                     arm_compute::DimensionRoundingType::FLOOR);
 }
 
+arm_compute::DataType MakeDataType(const DLDataType& data_type) {
+  if (data_type.code == DLDataTypeCode::kDLFloat && data_type.bits == 32) {
+    return arm_compute::DataType::F32;
+  } else if (data_type.code == DLDataTypeCode::kDLUInt && data_type.bits == 8) {
+    return arm_compute::DataType::QASYMM8;
+  } else if (data_type.code == DLDataTypeCode::kDLInt && data_type.bits == 32) {
+    return arm_compute::DataType::S32;
+  } else {
+    LOG(FATAL) << "Datatype " << data_type << " unsupported by ACL runtime";
+    return arm_compute::DataType::UNKNOWN;
+  }
+}
+
+template <typename T>
+std::vector<T> GetVectorFromDLTensor(const DLTensor* tensor) {

Review comment:
       I feel that we have an existing utility for this. @zhiics.

##########
File path: src/runtime/contrib/arm_compute_lib/acl_utils.h
##########
@@ -58,35 +58,26 @@ void CheckACLError(const arm_compute::Status& status);
  *
  * \param tensor_rep A JSON tensor representation.
  * \param data (optional) Initialize the tensor with memory.
+ * \param scale (optional) The quantization scale.
+ * \param offset (optional) The quantization offset.
  * \return arm_compute::Tensor.
  */
-arm_compute::Tensor MakeTensor(const JSONGraphNode& tensor_rep, void* data = nullptr);
-
-/*!
- * \brief Make an acl tensor from type and shape, without having a JSON representation.
- *
- * \param shape The shape of the tensor to create.
- * \return arm_compute::Tensor.
- */
-arm_compute::Tensor MakeOutputTensor(const std::vector<int64_t>& shape);
+arm_compute::Tensor MakeTensor(const JSONGraphNode& tensor_rep, void* data = nullptr,

Review comment:
       I would suggest renaming all these functions to `MakeACLXX` so that we can easily know what they are doing.

##########
File path: tests/python/contrib/test_arm_compute_lib/test_conv2d.py
##########
@@ -127,51 +241,31 @@ def test_conv2d():
 
     device = Device()
     np.random.seed(0)
+    r = random.Random(0)

Review comment:
       I'm not sure it's a good idea to use random shapes in unit tests. It could make CI nondeterministic across machines, even if the random seed is fixed.
   
   cc @zhiics @tqchen 

##########
File path: src/runtime/contrib/arm_compute_lib/acl_runtime.cc
##########
@@ -163,24 +149,61 @@ class ACLRuntime : public JSONRuntimeBase {
   struct CachedLayer {
     std::shared_ptr<arm_compute::IFunction> function;
     std::vector<arm_compute::Tensor> inputs;
-    std::vector<arm_compute::Tensor> const_inputs;
     std::vector<arm_compute::Tensor> outputs;
   };
 
+  /*!
+   * \brief Create an ACL tensor given the JSON representation.
+   *
+   * \param tensor The tensor to represent.
+   * \param scale (optional) The scale of the tensor as an input.
+   * \param offset (optional) The offset of the tensor as an input.
+   * \return ACL Tensor.
+   */
+  arm_compute::Tensor GetACLTensor(const JSONGraphNodeEntry& tensor,
+                                   JSONGraphNodeEntry* scale = nullptr,
+                                   JSONGraphNodeEntry* offset = nullptr) {
+    JSONGraphNode node = nodes_[tensor.id_];
+    void* node_data = nullptr;
+    if (node.GetOpType() == "const") {
+      node_data = data_entry_[EntryID(tensor)]->data;
+    }
+    return GetACLTensor(node, scale, offset, node_data);
+  }
+
+  /*!
+   * \brief Create an ACL tensor given the JSON representation.

Review comment:
       ditto.

##########
File path: src/runtime/contrib/arm_compute_lib/acl_utils.cc
##########
@@ -38,35 +38,38 @@ void CheckACLError(const arm_compute::Status& status) {
   CHECK(status.error_code() == arm_compute::ErrorCode::OK) << "ACL: " << status.error_description();
 }
 
-arm_compute::Tensor MakeTensor(const JSONGraphNode& tensor_rep, void* data) {
-  CHECK(tensor_rep.GetOpType() == "input" || tensor_rep.GetOpType() == "const");
+arm_compute::Tensor MakeTensor(const JSONGraphNode& tensor_rep, void* data, const DLTensor* scale,
+                               const DLTensor* offset) {
   arm_compute::Tensor tensor;
-  arm_compute::TensorInfo info = MakeTensorInfo(tensor_rep.GetOpShape()[0]);
+  std::vector<int64_t> shape = tensor_rep.GetOpShape()[0];
+  DLDataType dtype = tensor_rep.GetOpDataType()[0];
+  arm_compute::TensorInfo info = MakeTensorInfo(shape, dtype, scale, offset);
   tensor.allocator()->init(info);
   if (data != nullptr) {
     CheckACLError(tensor.allocator()->import_memory(data));
   }
   return tensor;
 }
 
-arm_compute::Tensor MakeOutputTensor(const std::vector<int64_t>& shape) {
-  arm_compute::Tensor tensor;
-  tensor.allocator()->init(MakeTensorInfo(shape));
-  return tensor;
-}
-
-arm_compute::TensorInfo MakeTensorInfo(const std::vector<int64_t>& shape) {
-  arm_compute::TensorShape acl_shape = MakeTensorShape(shape);
-  return arm_compute::TensorInfo(acl_shape, 1, arm_compute::DataType::F32,
-                                 arm_compute::DataLayout::NHWC);
-}
-
-arm_compute::TensorShape MakeTensorShape(const std::vector<int64_t>& shape) {
+arm_compute::TensorInfo MakeTensorInfo(const std::vector<int64_t>& shape, const DLDataType& dtype,
+                                       const DLTensor* scale, const DLTensor* offset) {
   arm_compute::TensorShape acl_shape;
   for (unsigned int i = shape.size(); i > 0; --i) {
     acl_shape.set(shape.size() - i, shape[i - 1]);
   }
-  return acl_shape;
+  arm_compute::DataType acl_dtype = MakeDataType(dtype);
+  arm_compute::TensorInfo info(acl_shape, 1, acl_dtype, arm_compute::DataLayout::NHWC);
+
+  if (scale != nullptr && offset != nullptr) {

Review comment:
       Add a comment here saying we are making a quantized tensor.
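A Python sketch of the logic in this hunk, including the comment the reviewer asked for. `TensorInfoSketch` and `make_tensor_info` are hypothetical stand-ins for `arm_compute::TensorInfo` and `MakeTensorInfo`; the dimension reversal mirrors the loop in the diff, which stores the last TVM dimension as ACL dimension 0.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class TensorInfoSketch:  # hypothetical stand-in for arm_compute::TensorInfo
    acl_shape: List[int]
    dtype: str
    quant: Optional[Tuple[float, int]] = None  # (scale, offset) when quantized


def make_tensor_info(shape, dtype, scale=None, offset=None):
    # Reverse the dimension order, as the loop in the diff does:
    # an NHWC shape [N, H, W, C] becomes [C, W, H, N].
    info = TensorInfoSketch(list(reversed(shape)), dtype)
    if scale is not None and offset is not None:
        # Both quantization parameters are present: make a quantized tensor.
        info.quant = (scale, offset)
    return info
```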

##########
File path: src/runtime/contrib/arm_compute_lib/acl_utils.cc
##########
@@ -108,6 +111,30 @@ arm_compute::PadStrideInfo ToACLPadStride(const std::vector<std::string>& pad,
                                     arm_compute::DimensionRoundingType::FLOOR);
 }
 
+arm_compute::DataType MakeDataType(const DLDataType& data_type) {

Review comment:
       `MakeACLDataType`?
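If the function is renamed, a table-driven mapping is one possible alternative to the if/else chain. The sketch below is hypothetical Python: the names are illustrative and the strings stand for `arm_compute::DataType` enumerators. The type codes are the DLPack values from `dlpack.h`. Note that, as in the diff, every 8-bit unsigned tensor maps to `QASYMM8`.

```python
# DLPack type codes as defined in dlpack.h.
K_DL_INT, K_DL_UINT, K_DL_FLOAT = 0, 1, 2

# (type code, bits) -> name of the corresponding arm_compute::DataType.
_DL_TO_ACL = {
    (K_DL_FLOAT, 32): "F32",
    (K_DL_UINT, 8): "QASYMM8",
    (K_DL_INT, 32): "S32",
}


def make_acl_data_type(code, bits):
    """Translate a DLPack (code, bits) pair, failing loudly if unsupported."""
    try:
        return _DL_TO_ACL[(code, bits)]
    except KeyError:
        raise ValueError(
            f"Datatype (code={code}, bits={bits}) unsupported by ACL runtime"
        ) from None
```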







[GitHub] [incubator-tvm] lhutton1 commented on a change in pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on a change in pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109#discussion_r459571834



##########
File path: src/runtime/contrib/arm_compute_lib/acl_runtime.cc
##########
@@ -163,24 +149,61 @@ class ACLRuntime : public JSONRuntimeBase {
   struct CachedLayer {
     std::shared_ptr<arm_compute::IFunction> function;
     std::vector<arm_compute::Tensor> inputs;
-    std::vector<arm_compute::Tensor> const_inputs;
     std::vector<arm_compute::Tensor> outputs;
   };
 
+  /*!
+   * \brief Create an ACL tensor given the JSON representation.

Review comment:
       Sounds good, thanks for the suggestion







[GitHub] [incubator-tvm] zhiics commented on pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
zhiics commented on pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109#issuecomment-665762293


   Thanks @lhutton1 @comaniac 





[GitHub] [incubator-tvm] zhiics commented on a change in pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109#discussion_r459618012



##########
File path: tests/python/contrib/test_arm_compute_lib/test_conv2d.py
##########
@@ -127,51 +241,31 @@ def test_conv2d():
 
     device = Device()
     np.random.seed(0)
+    r = random.Random(0)

Review comment:
       BTW, why do we need a random input? It looks like we could just give the output channel a value, or multiple values according to the number of trials, right?

##########
File path: src/runtime/contrib/arm_compute_lib/acl_utils.cc
##########
@@ -108,6 +111,30 @@ arm_compute::PadStrideInfo ToACLPadStride(const std::vector<std::string>& pad,
                                     arm_compute::DimensionRoundingType::FLOOR);
 }
 
+arm_compute::DataType MakeDataType(const DLDataType& data_type) {
+  if (data_type.code == DLDataTypeCode::kDLFloat && data_type.bits == 32) {
+    return arm_compute::DataType::F32;
+  } else if (data_type.code == DLDataTypeCode::kDLUInt && data_type.bits == 8) {
+    return arm_compute::DataType::QASYMM8;
+  } else if (data_type.code == DLDataTypeCode::kDLInt && data_type.bits == 32) {
+    return arm_compute::DataType::S32;
+  } else {
+    LOG(FATAL) << "Datatype " << data_type << " unsupported by ACL runtime";
+    return arm_compute::DataType::UNKNOWN;
+  }
+}
+
+template <typename T>
+std::vector<T> GetVectorFromDLTensor(const DLTensor* tensor) {

Review comment:
       Yeah, maybe we have it, but I am not sure where it is either. But do we have to flatten the tensor at runtime?

##########
File path: src/runtime/contrib/arm_compute_lib/acl_runtime.cc
##########
@@ -198,13 +225,31 @@ class ACLRuntime : public JSONRuntimeBase {
 
     arm_compute::Size2D dilation_2d(std::stoi(dilation[0]), std::stoi(dilation[1]));
 
-    layer->outputs.push_back(MakeOutputTensor(node.GetOpShape()[0]));
+    // Collect inputs and outputs, handling both nn.conv2d and qnn.conv2d cases.
+    std::vector<JSONGraphNodeEntry> inputs = node.GetInputs();
+    size_t num_inputs = inputs.size();
+    bool has_bias;
+    if (node.GetOpName() == "qnn.conv2d") {

Review comment:
       `CHECK_GE(num_inputs, 8U)`
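`CHECK_GE(num_inputs, 8U)` bounds the input count from below only. The sketch below also bounds it from above and derives `has_bias` from the count. The input layouts in the docstring are an assumption for illustration (six `qnn.conv2d` operands: data, weight, two zero points, two scales; an optional bias; then the requantize output scale and zero point), not an order confirmed by this diff.

```python
def conv_has_bias(op_name, num_inputs):
    """Validate a composite convolution's input count and report bias presence.

    Assumed (hypothetical) input layouts:
      nn.conv2d : data, kernel[, bias]                           -> 2 or 3
      qnn.conv2d: data, kernel, input_zp, kernel_zp, input_scale,
                  kernel_scale[, bias], output_scale, output_zp  -> 8 or 9
    """
    if op_name == "qnn.conv2d":
        min_inputs = 8
    elif op_name == "nn.conv2d":
        min_inputs = 2
    else:
        raise ValueError(f"unexpected op {op_name}")
    if not min_inputs <= num_inputs <= min_inputs + 1:
        raise ValueError(
            f"{op_name} expects {min_inputs} or {min_inputs + 1} inputs, got {num_inputs}"
        )
    return num_inputs == min_inputs + 1
```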







[GitHub] [incubator-tvm] lhutton1 commented on a change in pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on a change in pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109#discussion_r459304915



##########
File path: src/runtime/contrib/arm_compute_lib/acl_runtime.cc
##########
@@ -163,24 +149,61 @@ class ACLRuntime : public JSONRuntimeBase {
   struct CachedLayer {
     std::shared_ptr<arm_compute::IFunction> function;
     std::vector<arm_compute::Tensor> inputs;
-    std::vector<arm_compute::Tensor> const_inputs;
     std::vector<arm_compute::Tensor> outputs;
   };
 
+  /*!
+   * \brief Create an ACL tensor given the JSON representation.

Review comment:
       The issue with moving these to `acl_utils` is that they require `nodes_` and `data_entry_` from the JSON runtime. It might be simpler to remove utils and move the remaining functions into the runtime as static functions?
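One way to reconcile the two positions is to split the function: a free function that needs no runtime state (and so could live in `acl_utils`), plus a thin member wrapper that looks up `nodes_` and `data_entry_`. A hypothetical Python sketch of that split; dicts stand in for the JSON nodes and ACL tensors, and node IDs double as entry IDs for brevity, whereas the runtime uses `EntryID`.

```python
def make_acl_tensor(node, data=None, scale=None, offset=None):
    """Free function: needs no runtime state, so it could live in a utils file."""
    return {"shape": node["shape"], "data": data, "scale": scale, "offset": offset}


class AclRuntimeSketch:
    """Hypothetical stand-in for ACLRuntime, holding the JSON runtime state."""

    def __init__(self, nodes, data_entry):
        self.nodes = nodes            # plays the role of nodes_
        self.data_entry = data_entry  # plays the role of data_entry_

    def make_acl_tensor_from_entry(self, node_id, scale=None, offset=None):
        node = self.nodes[node_id]
        # Only constants have data available up front; inputs are bound per run.
        data = self.data_entry.get(node_id) if node["op"] == "const" else None
        return make_acl_tensor(node, data, scale, offset)
```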







[GitHub] [incubator-tvm] lhutton1 commented on a change in pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on a change in pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109#discussion_r459624243



##########
File path: src/runtime/contrib/arm_compute_lib/acl_utils.cc
##########
@@ -108,6 +111,30 @@ arm_compute::PadStrideInfo ToACLPadStride(const std::vector<std::string>& pad,
                                     arm_compute::DimensionRoundingType::FLOOR);
 }
 
+arm_compute::DataType MakeDataType(const DLDataType& data_type) {
+  if (data_type.code == DLDataTypeCode::kDLFloat && data_type.bits == 32) {
+    return arm_compute::DataType::F32;
+  } else if (data_type.code == DLDataTypeCode::kDLUInt && data_type.bits == 8) {
+    return arm_compute::DataType::QASYMM8;
+  } else if (data_type.code == DLDataTypeCode::kDLInt && data_type.bits == 32) {
+    return arm_compute::DataType::S32;
+  } else {
+    LOG(FATAL) << "Datatype " << data_type << " unsupported by ACL runtime";
+    return arm_compute::DataType::UNKNOWN;
+  }
+}
+
+template <typename T>
+std::vector<T> GetVectorFromDLTensor(const DLTensor* tensor) {

Review comment:
       It's not so much about flattening the tensor as converting it to the type that ACL expects: `arm_compute::QuantizationInfo(std::vector<float> scale, std::vector<int32_t> offset)`.
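For illustration, the conversion amounts to reinterpreting a tensor's raw buffer as typed values (float32 for scales, int32 for offsets) to match the `QuantizationInfo(std::vector<float>, std::vector<int32_t>)` signature quoted above. A minimal Python sketch, assuming a contiguous little-endian buffer; the helper name is hypothetical. For per-layer quantization each list holds a single element.

```python
import struct


def vector_from_buffer(raw, fmt_char):
    """Reinterpret a flat little-endian byte buffer as a list of scalars.

    fmt_char: 'f' for float32 (scales) or 'i' for int32 (offsets).
    """
    item_size = struct.calcsize("<" + fmt_char)
    if len(raw) % item_size != 0:
        raise ValueError("buffer size is not a multiple of the element size")
    count = len(raw) // item_size
    return list(struct.unpack(f"<{count}{fmt_char}", raw))
```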







[GitHub] [incubator-tvm] lhutton1 commented on a change in pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on a change in pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109#discussion_r459578673



##########
File path: src/relay/backend/contrib/arm_compute_lib/codegen.cc
##########
@@ -78,57 +90,83 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
 
  private:
   /*!
-   * \brief Create a JSON representation of a composite convolution.
+   * \brief Extract convolution nodes from a composite function.
    *
-   * \param call The call to be represented.
-   * \return A JSON representation of a specific operator.
+   * \param cn The call node of the composite function.
+   * \return Extracted composite convolution nodes.
    */
-  std::shared_ptr<JSONGraphNode> CreateCompositeConvJSONNode(const CallNode* cn) {
-    const std::string name = "nn.conv2d";
-    const CallNode* pad = nullptr;
-    const CallNode* conv = nullptr;
-    const CallNode* bias = nullptr;
-    bool has_activation = false;
-
-    // Unpack composite function
+  static CompositeConvNodes UnpackCompositeConvolution(const CallNode* cn) {
+    CompositeConvNodes nodes{};
     const auto* fn = cn->op.as<FunctionNode>();
     CHECK(fn);
     const auto* current_call = fn->body.as<CallNode>();
+    if (backend::IsOp(current_call, "qnn.requantize")) {
+      nodes.requantize = current_call;
+      current_call = current_call->args[0].as<CallNode>();
+    }
     if (backend::IsOp(current_call, "nn.relu")) {
-      has_activation = true;
+      nodes.activation = current_call;
       current_call = current_call->args[0].as<CallNode>();
     }
     if (backend::IsOp(current_call, "nn.bias_add")) {
-      bias = current_call;
+      nodes.bias = current_call;
       current_call = current_call->args[0].as<CallNode>();
     }
-    CHECK(backend::IsOp(current_call, "nn.conv2d"));
-    conv = current_call;
+    if (nodes.requantize) {
+      CHECK(backend::IsOp(current_call, "qnn.conv2d"));

Review comment:
       I hope I understood correctly, but I'm not sure that would work: a CHECK in the first condition would enforce that nn.conv2d comes immediately before nn.relu, with no other calls in between. For example, with nn.conv2d -> nn.bias_add -> nn.relu, compilation would fail under that approach, because after peeling nn.relu the next call would be expected to be nn.conv2d.







[GitHub] [incubator-tvm] lhutton1 commented on a change in pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on a change in pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109#discussion_r459321017



##########
File path: src/relay/backend/contrib/arm_compute_lib/codegen.cc
##########
@@ -78,57 +90,83 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
 
  private:
   /*!
-   * \brief Create a JSON representation of a composite convolution.
+   * \brief Extract convolution nodes from a composite function.
    *
-   * \param call The call to be represented.
-   * \return A JSON representation of a specific operator.
+   * \param cn The call node of the composite function.
+   * \return Extracted composite convolution nodes.
    */
-  std::shared_ptr<JSONGraphNode> CreateCompositeConvJSONNode(const CallNode* cn) {
-    const std::string name = "nn.conv2d";
-    const CallNode* pad = nullptr;
-    const CallNode* conv = nullptr;
-    const CallNode* bias = nullptr;
-    bool has_activation = false;
-
-    // Unpack composite function
+  static CompositeConvNodes UnpackCompositeConvolution(const CallNode* cn) {
+    CompositeConvNodes nodes{};
     const auto* fn = cn->op.as<FunctionNode>();
     CHECK(fn);
     const auto* current_call = fn->body.as<CallNode>();
+    if (backend::IsOp(current_call, "qnn.requantize")) {
+      nodes.requantize = current_call;
+      current_call = current_call->args[0].as<CallNode>();
+    }
     if (backend::IsOp(current_call, "nn.relu")) {
-      has_activation = true;
+      nodes.activation = current_call;
       current_call = current_call->args[0].as<CallNode>();
     }
     if (backend::IsOp(current_call, "nn.bias_add")) {
-      bias = current_call;
+      nodes.bias = current_call;
       current_call = current_call->args[0].as<CallNode>();
     }
-    CHECK(backend::IsOp(current_call, "nn.conv2d"));
-    conv = current_call;
+    if (nodes.requantize) {
+      CHECK(backend::IsOp(current_call, "qnn.conv2d"));

Review comment:
       Not sure I understand, would you be able to elaborate?







[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109#discussion_r459606833



##########
File path: tests/python/contrib/test_arm_compute_lib/test_conv2d.py
##########
@@ -127,51 +241,31 @@ def test_conv2d():
 
     device = Device()
     np.random.seed(0)
+    r = random.Random(0)

Review comment:
       Exactly, so I meant we should explicitly list some shapes here and only use them for testing instead of randomly generating a set of shapes.
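A sketch of what an explicit list could look like: every combination of a few hand-picked parameters, so each CI run covers exactly the same cases regardless of machine. The shapes, kernel sizes and channel counts below are hypothetical placeholders, not values from the PR.

```python
import itertools

# Hand-picked (hypothetical) configurations; no randomness involved.
INPUT_SHAPES = [(1, 14, 14, 16), (1, 25, 25, 32)]
KERNEL_SIZES = [(1, 1), (3, 3)]
OUT_CHANNELS = [8, 16]


def conv2d_trials():
    """Yield every combination of the explicitly listed parameters."""
    for shape, kernel, out_channels in itertools.product(
        INPUT_SHAPES, KERNEL_SIZES, OUT_CHANNELS
    ):
        yield {"shape": shape, "kernel": kernel, "out_channels": out_channels}
```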







[GitHub] [incubator-tvm] lhutton1 commented on a change in pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on a change in pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109#discussion_r459608769



##########
File path: tests/python/contrib/test_arm_compute_lib/test_conv2d.py
##########
@@ -127,51 +241,31 @@ def test_conv2d():
 
     device = Device()
     np.random.seed(0)
+    r = random.Random(0)

Review comment:
       I'm wondering if this point could be extended to the input as well. Currently I generate a seeded random input; should that not be the case?







[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109#discussion_r459604117



##########
File path: src/relay/backend/contrib/arm_compute_lib/codegen.cc
##########
@@ -78,57 +90,83 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
 
  private:
   /*!
-   * \brief Create a JSON representation of a composite convolution.
+   * \brief Extract convolution nodes from a composite function.
    *
-   * \param call The call to be represented.
-   * \return A JSON representation of a specific operator.
+   * \param cn The call node of the composite function.
+   * \return Extracted composite convolution nodes.
    */
-  std::shared_ptr<JSONGraphNode> CreateCompositeConvJSONNode(const CallNode* cn) {
-    const std::string name = "nn.conv2d";
-    const CallNode* pad = nullptr;
-    const CallNode* conv = nullptr;
-    const CallNode* bias = nullptr;
-    bool has_activation = false;
-
-    // Unpack composite function
+  static CompositeConvNodes UnpackCompositeConvolution(const CallNode* cn) {
+    CompositeConvNodes nodes{};
     const auto* fn = cn->op.as<FunctionNode>();
     CHECK(fn);
     const auto* current_call = fn->body.as<CallNode>();
+    if (backend::IsOp(current_call, "qnn.requantize")) {
+      nodes.requantize = current_call;
+      current_call = current_call->args[0].as<CallNode>();
+    }
     if (backend::IsOp(current_call, "nn.relu")) {
-      has_activation = true;
+      nodes.activation = current_call;
       current_call = current_call->args[0].as<CallNode>();
     }
     if (backend::IsOp(current_call, "nn.bias_add")) {
-      bias = current_call;
+      nodes.bias = current_call;
       current_call = current_call->args[0].as<CallNode>();
     }
-    CHECK(backend::IsOp(current_call, "nn.conv2d"));
-    conv = current_call;
+    if (nodes.requantize) {
+      CHECK(backend::IsOp(current_call, "qnn.conv2d"));

Review comment:
       Ah, I see. `current_call` may be updated several times, and you only want to check the final one. In that case we can keep the current solution and just add a comment to clarify it. Sorry for the misunderstanding.
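The control flow under discussion can be modelled outside TVM to show why the convolution check belongs after all optional wrappers have been peeled. This is a minimal Python sketch that assumes the composite ordering from the PR description (pad?, qnn.conv2d, bias?, relu?, qnn.requantize). `Call`, `CompositeConvNodes` and `unpack_composite_convolution` are hypothetical stand-ins for the C++ `CallNode`, the struct and `UnpackCompositeConvolution`. The `nn.pad` handling is an assumption, since the hunk above is cut off before it.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Call:
    """Hypothetical stand-in for a Relay CallNode: an op name plus arguments."""
    op: str
    args: List["Call"] = field(default_factory=list)


@dataclass
class CompositeConvNodes:
    pad: Optional[Call] = None
    conv: Optional[Call] = None
    bias: Optional[Call] = None
    activation: Optional[Call] = None
    requantize: Optional[Call] = None


def unpack_composite_convolution(body: Call) -> CompositeConvNodes:
    nodes = CompositeConvNodes()
    current = body
    # Peel the optional wrappers, outermost first. `current` is reassigned at
    # each step, so it only points at the convolution once every optional
    # layer has been consumed.
    if current.op == "qnn.requantize":
        nodes.requantize = current
        current = current.args[0]
    if current.op == "nn.relu":
        nodes.activation = current
        current = current.args[0]
    if current.op == "nn.bias_add":
        nodes.bias = current
        current = current.args[0]
    # A single check on the final node. The expected op depends on whether a
    # requantize wrapper was found.
    expected = "qnn.conv2d" if nodes.requantize is not None else "nn.conv2d"
    if current.op != expected:
        raise ValueError(f"expected {expected}, got {current.op}")
    nodes.conv = current
    # Assumed: optional padding feeds the convolution's data input.
    if current.args and current.args[0].op == "nn.pad":
        nodes.pad = current.args[0]
    return nodes
```

Keeping one trailing check means the expected op is stated once, whatever combination of optional wrappers is present.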







[GitHub] [incubator-tvm] lhutton1 commented on a change in pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
lhutton1 commented on a change in pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109#discussion_r459626275



##########
File path: tests/python/contrib/test_arm_compute_lib/test_conv2d.py
##########
@@ -127,51 +241,31 @@ def test_conv2d():
 
     device = Device()
     np.random.seed(0)
+    r = random.Random(0)

Review comment:
       Makes sense, thanks for the explanation @comaniac. I'll change this to fixed shapes and channels. Yep, that's correct, @zhiics.
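Switching from sampled parameters to fixed shapes and channels could look roughly like the following. This is a sketch only: the parameter values are illustrative, not the ones used in the PR's tests, and `fixed_trials` and `conv_output_hw` are hypothetical names. The output size uses the standard conv2d formula without dilation, out = (in + 2 * pad - kernel) // stride + 1.

```python
import itertools

# Illustrative fixed parameters (NHWC input layout).
INPUT_HW = [(25, 25), (14, 14)]
KERNEL_HW = [(1, 1), (3, 3)]
STRIDES = [(1, 1), (2, 2)]
PADDING = [(0, 0), (1, 1)]
CHANNELS = [(16, 32), (32, 32)]  # (input channels, output channels)


def conv_output_hw(in_hw, kernel_hw, strides, padding):
    # Standard conv2d output size, no dilation:
    # out = (in + 2 * pad - kernel) // stride + 1
    return tuple(
        (i + 2 * p - k) // s + 1
        for i, k, s, p in zip(in_hw, kernel_hw, strides, padding)
    )


def fixed_trials():
    """Yield every combination, so each run covers exactly the same cases."""
    for in_hw, k, s, p, (cin, cout) in itertools.product(
        INPUT_HW, KERNEL_HW, STRIDES, PADDING, CHANNELS
    ):
        out_hw = conv_output_hw(in_hw, k, s, p)
        yield {
            "input_shape": (1, *in_hw, cin),
            "kernel_size": k,
            "strides": s,
            "padding": p,
            "output_shape": (1, *out_hw, cout),
        }
```

A full product grows quickly, so in practice one might pick a handful of representative combinations by hand rather than every one.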




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6109: [BYOC][ACL] Support asymmetric per-layer quantized operators

Posted by GitBox <gi...@apache.org>.
comaniac commented on a change in pull request #6109:
URL: https://github.com/apache/incubator-tvm/pull/6109#discussion_r459551485



##########
File path: src/relay/backend/contrib/arm_compute_lib/codegen.cc
##########
@@ -78,57 +90,83 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
 
  private:
   /*!
-   * \brief Create a JSON representation of a composite convolution.
+   * \brief Extract convolution nodes from a composite function.
    *
-   * \param call The call to be represented.
-   * \return A JSON representation of a specific operator.
+   * \param cn The call node of the composite function.
+   * \return Extracted composite convolution nodes.
    */
-  std::shared_ptr<JSONGraphNode> CreateCompositeConvJSONNode(const CallNode* cn) {
-    const std::string name = "nn.conv2d";
-    const CallNode* pad = nullptr;
-    const CallNode* conv = nullptr;
-    const CallNode* bias = nullptr;
-    bool has_activation = false;
-
-    // Unpack composite function
+  static CompositeConvNodes UnpackCompositeConvolution(const CallNode* cn) {
+    CompositeConvNodes nodes{};
     const auto* fn = cn->op.as<FunctionNode>();
     CHECK(fn);
     const auto* current_call = fn->body.as<CallNode>();
+    if (backend::IsOp(current_call, "qnn.requantize")) {
+      nodes.requantize = current_call;
+      current_call = current_call->args[0].as<CallNode>();
+    }
     if (backend::IsOp(current_call, "nn.relu")) {
-      has_activation = true;
+      nodes.activation = current_call;
       current_call = current_call->args[0].as<CallNode>();
     }
     if (backend::IsOp(current_call, "nn.bias_add")) {
-      bias = current_call;
+      nodes.bias = current_call;
       current_call = current_call->args[0].as<CallNode>();
     }
-    CHECK(backend::IsOp(current_call, "nn.conv2d"));
-    conv = current_call;
+    if (nodes.requantize) {
+      CHECK(backend::IsOp(current_call, "qnn.conv2d"));

Review comment:
       I meant that we could embed the CHECK in each branch to make the logic clearer.
   
   ```cpp
   if (backend::IsOp(current_call, "nn.relu")) {
     nodes.activation = current_call;
     current_call = current_call->args[0].as<CallNode>();
     CHECK(backend::IsOp(current_call, "qnn.conv2d"));
   }
   if ( ... ) {
     CHECK(backend::IsOp(current_call, "nn.conv2d"));
   }
   if ( ... ) {
     CHECK(backend::IsOp(current_call, "nn.conv2d"));
   }
   ```
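Embedding the check in the `nn.relu` branch, as in the snippet above, only holds when nothing sits between the activation and the convolution. With the optional `nn.bias_add` in between, the node directly under `nn.relu` is the bias add. The comparison below is a minimal sketch that models each composite body as a list of op names, outermost first. The representation and both function names are hypothetical.

```python
# Each composite body is modelled as a list of op names, outermost first,
# e.g. ["qnn.requantize", "nn.relu", "nn.bias_add", "qnn.conv2d"].
OPTIONAL_WRAPPERS = ["qnn.requantize", "nn.relu", "nn.bias_add"]


def check_after_peeling(chain):
    """Skip every optional wrapper that is present, then check the last op once."""
    i = 0
    for wrapper in OPTIONAL_WRAPPERS:
        if i < len(chain) and chain[i] == wrapper:
            i += 1
    expected = "qnn.conv2d" if chain and chain[0] == "qnn.requantize" else "nn.conv2d"
    return i < len(chain) and chain[i] == expected


def check_directly_under_relu(chain):
    """Check the op immediately below nn.relu, assuming it is always the conv."""
    if "nn.relu" not in chain:
        return True  # this variant performs no check outside the relu branch
    below = chain.index("nn.relu") + 1
    return below < len(chain) and chain[below] in ("qnn.conv2d", "nn.conv2d")
```

With a bias present, `check_directly_under_relu` rejects a valid composite, while the trailing check accepts it.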

##########
File path: src/runtime/contrib/arm_compute_lib/acl_runtime.cc
##########
@@ -163,24 +149,61 @@ class ACLRuntime : public JSONRuntimeBase {
   struct CachedLayer {
     std::shared_ptr<arm_compute::IFunction> function;
     std::vector<arm_compute::Tensor> inputs;
-    std::vector<arm_compute::Tensor> const_inputs;
     std::vector<arm_compute::Tensor> outputs;
   };
 
+  /*!
+   * \brief Create an ACL tensor given the JSON representation.

Review comment:
       Hmm, I see. Then this is different from the other functions in `acl_utils`, so we could use a different naming convention. Specifically, we could keep the function here and rename it to something like `MakeACLTensorFromDataEntry` (you might have a better name for it).
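The function being renamed builds an ACL tensor from a JSON data entry, optionally attaching a scale and offset for asymmetric quantization. The Python sketch below mirrors only the shape of that logic. `TensorDesc`, `QuantInfo` and `make_tensor_desc_from_data_entry` are hypothetical, the `entry` dict layout is an assumption, and nothing here calls the real Arm Compute Library API. It assumes the usual asymmetric mapping x = scale * (q - offset).

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class QuantInfo:
    scale: float
    offset: int  # zero point for asymmetric quantization

    def quantize(self, x: float, qmin: int = 0, qmax: int = 255) -> int:
        # q = round(x / scale) + offset, clamped to the uint8 range by default
        q = round(x / self.scale) + self.offset
        return max(qmin, min(qmax, q))

    def dequantize(self, q: int) -> float:
        # x = scale * (q - offset)
        return self.scale * (q - self.offset)


@dataclass(frozen=True)
class TensorDesc:
    shape: Tuple[int, ...]
    dtype: str
    quant: Optional[QuantInfo] = None


def make_tensor_desc_from_data_entry(entry, scale=None, offset=None):
    """Hypothetical: build a tensor description from a JSON-like data entry."""
    shape = tuple(int(d) for d in entry["shape"])
    dtype = entry["dtype"]
    if (scale is None) != (offset is None):
        raise ValueError("scale and offset must be given together")
    quant = None if scale is None else QuantInfo(float(scale), int(offset))
    return TensorDesc(shape, dtype, quant)
```

Keeping the quantization parameters optional lets the same helper serve both fp32 and quantized entries.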



