You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/28 16:38:16 UTC

[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6532: [BYOC][ACL] Support add operation

comaniac commented on a change in pull request #6532:
URL: https://github.com/apache/incubator-tvm/pull/6532#discussion_r496082766



##########
File path: python/tvm/relay/op/contrib/arm_compute_lib.py
##########
@@ -345,3 +345,23 @@ def maximum(attrs, args):
     type_a = args[0].checked_type
     type_b = args[0].checked_type
     return (type_a.dtype == "float32") and (type_b.dtype == "float32")
+
+
+@tvm.ir.register_op_attr("add", "target.arm_compute_lib")
+def add(attrs, args):
+    """Check if the external ACL codegen for add should be used."""
+    for typ in [args[0].checked_type, args[1].checked_type]:
+        if typ.dtype not in ["float32"]:

Review comment:
       ```suggestion
           if typ.dtype != "float32":
       ```

##########
File path: python/tvm/relay/op/contrib/arm_compute_lib.py
##########
@@ -345,3 +345,23 @@ def maximum(attrs, args):
     type_a = args[0].checked_type
     type_b = args[0].checked_type
     return (type_a.dtype == "float32") and (type_b.dtype == "float32")
+
+
+@tvm.ir.register_op_attr("add", "target.arm_compute_lib")
+def add(attrs, args):
+    """Check if the external ACL codegen for add should be used."""
+    for typ in [args[0].checked_type, args[1].checked_type]:
+        if typ.dtype not in ["float32"]:
+            return False
+
+    return True
+
+
+@tvm.ir.register_op_attr("qnn.add", "target.arm_compute_lib")
+def qnn_add(attrs, args):
+    """Check if the external ACL codegen for add should be used."""
+    for typ in [args[0].checked_type, args[1].checked_type]:
+        if typ.dtype not in ["uint8"]:

Review comment:
       ```suggestion
           if typ.dtype != "uint8":
       ```

##########
File path: src/runtime/contrib/arm_compute_lib/acl_runtime.cc
##########
@@ -140,8 +141,13 @@ class ACLRuntime : public JSONRuntimeBase {
           CreateGlobalPoolingLayer(&layer_, node);
         } else if ("reshape" == op_name) {
           CreateReshapeLayer(&layer_, node);
+<<<<<<< HEAD
         } else if ("maximum" == op_name) {
           CreateMaximumLayer(&layer_, node);
+=======
+        } else if ("add" == op_name || "qnn.add" == op_name) {
+          CreateAddLayer(&layer_, node);
+>>>>>>> a7fa43daf... ACL: add operation

Review comment:
       Please resolve the merge conflict; the conflict markers are still present in this hunk.

##########
File path: src/runtime/contrib/arm_compute_lib/acl_runtime.cc
##########
@@ -416,15 +422,54 @@ class ACLRuntime : public JSONRuntimeBase {
     auto function = std::make_shared<arm_compute::NEElementwiseMax>();
     function->configure(&layer->inputs[0], &layer->inputs[1], &layer->outputs[0]);
     layer->function = function;
-  }
 
-  /*! \brief Allow ACL functions to request auxiliary memory from TVM. */
-  ACLAllocator allocator_;
-  /*!
-   * \brief The network layers represented by acl functions.
-   * \note Currently only supports a single layer.
-   */
-  CachedLayer layer_;
+    /*!
+     * \brief Creates an add/qnn.add layer
+     *
+     * \param layer The ACL layer to build. Containing inputs, outputs and the ACL function.
+     * \param node  The JSON representation of the operator.
+     */
+    void CreateAddLayer(CachedLayer * layer, const JSONGraphNode& node) {
+      auto op_name = node.GetOpName();
+      if ("add" == op_name) {
+        layer->inputs.push_back(MakeACLTensorFromJSONEntry(node.GetInputs()[0]));
+        layer->inputs.push_back(MakeACLTensorFromJSONEntry(node.GetInputs()[1]));
+        layer->outputs.push_back(MakeACLTensorFromJSONNode(node));
+      } else if ("qnn.add" == op_name) {
+        layer->inputs.push_back(MakeACLTensorFromJSONEntry(
+            node.GetInputs()[0], &node.GetInputs()[2], &node.GetInputs()[3]));
+        layer->inputs.push_back(MakeACLTensorFromJSONEntry(
+            node.GetInputs()[1], &node.GetInputs()[4], &node.GetInputs()[5]));
+        layer->outputs.push_back(
+            MakeACLTensorFromJSONNode(node, &node.GetInputs()[6], &node.GetInputs()[7]));
+      } else {
+        LOG(FATAL) << "Unsupported op: " << op_name;

Review comment:
       IMO this can simply be an assertion, since you have already checked that `op_name` is `add` or `qnn.add` in `BuildEngine` before invoking this function.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org