Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/01/19 17:56:05 UTC

[GitHub] [incubator-tvm] masahi commented on a change in pull request #4741: [External codegen] Add test cases for fused ops with manual annotation

URL: https://github.com/apache/incubator-tvm/pull/4741#discussion_r368310856
 
 

 ##########
 File path: src/relay/backend/contrib/dnnl/codegen.cc
 ##########
 @@ -50,82 +51,109 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase {
     out_.push_back({node->name_hint(), 0});
   }
 
-  void VisitExpr_(const TupleGetItemNode* op) final {
-    // Do nothing
-  }
-
   void VisitExpr_(const CallNode* call) final {
-    std::ostringstream decl_stream;
-    std::ostringstream buf_stream;
-    // Args: ID
-    std::vector<std::string> args;
+    struct Output {
+      std::string decl, buf;
+      int out_size = 1;
+      std::string out;
+    };
+
+    auto generate_body = [=](const CallNode* root_call, const std::string& func_name,
+                             const std::vector<std::string>& args,
+                             const std::vector<std::string>& fused_func_args) {
+      // Make function call with input buffers when visiting arguments
+      bool first = true;
+      std::ostringstream arg_stream;
+      arg_stream << "(";
+      for (size_t i = 0; i < root_call->args.size(); ++i) {
+        VisitExpr(root_call->args[i]);
+        for (auto out : out_) {
+          if (!first) {
+            arg_stream << ", ";
+          }
+          first = false;
+          arg_stream << out.first;
+        }
+      }
+
+      for (auto arg_name : fused_func_args) {
+        arg_stream << ", " << arg_name;
+      }
+
+      // Analyze the output buffer
+      auto type_node = root_call->checked_type().as<TensorTypeNode>();
+      CHECK(type_node != nullptr && runtime::TypeMatch(type_node->dtype, kDLFloat, 32))
+          << "Only support single output tensor with float type";
+
+      auto out_shape = GetShape(root_call->checked_type());
+
+      Output ret;
+      ret.out = "buf_" + std::to_string(buf_idx_++);
+      ret.out_size = std::accumulate(out_shape.begin(), out_shape.end(), 1, std::multiplies<int>());
+
+      this->PrintIndents();
+
+      std::ostringstream buf_stream;
+      buf_stream << "float* " << ret.out << " = (float*)std::malloc(4 * " << ret.out_size << ");";
+      ret.buf = buf_stream.str();
 
-    // Get the arguments for various DNNL kernels.
-    if (IsOp(call, "nn.conv2d")) {
-      decl_stream << "dnnl_conv2d";
-      args = Conv2d(call);
+      arg_stream << ", " << ret.out;
+      // Attach attribute arguments
+      for (size_t i = 0; i < args.size(); ++i) {
+        arg_stream << ", " << args[i];
+      }
+      arg_stream << ");";
+      ret.decl = func_name + arg_stream.str();
+
+      return ret;
+    };
+
+    Output ret;
+    if (auto conv_call = DetectFusedConv2DBiasReLU(call)) {
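
For readers following the diff, here is a minimal, standalone sketch of the string-building pattern generate_body uses: it assembles the buffer declaration (ret.buf) and the kernel call (ret.decl) from argument names. The kernel name dnnl_fused_conv2d_bias_relu and the attribute values below are illustrative assumptions, not taken from the patch.

    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <sstream>
    #include <string>
    #include <vector>

    int main() {
      // Illustrative inputs: argument buffer names, attribute strings,
      // and an output shape (all hypothetical values).
      std::vector<std::string> inputs = {"data", "weight", "bias"};
      std::vector<std::string> attrs = {"1", "16", "3", "3"};
      std::vector<int> out_shape = {1, 16, 222, 222};

      // Mirrors ret.out / ret.out_size in the diff above.
      std::string out = "buf_0";
      int out_size = std::accumulate(out_shape.begin(), out_shape.end(), 1,
                                     std::multiplies<int>());

      // Mirrors ret.buf: the output buffer allocation statement.
      std::ostringstream buf_stream;
      buf_stream << "float* " << out << " = (float*)std::malloc(4 * "
                 << out_size << ");";

      // Mirrors ret.decl: inputs first, then the output buffer,
      // then the attribute arguments.
      std::ostringstream arg_stream;
      arg_stream << "(";
      for (size_t i = 0; i < inputs.size(); ++i) {
        if (i != 0) arg_stream << ", ";
        arg_stream << inputs[i];
      }
      arg_stream << ", " << out;
      for (const auto& a : attrs) arg_stream << ", " << a;
      arg_stream << ");";

      std::cout << buf_stream.str() << "\n"
                << "dnnl_fused_conv2d_bias_relu" << arg_stream.str() << "\n";
      return 0;
    }

Running this prints the two generated C statements: the malloc of the output buffer followed by the call to the (assumed) fused kernel entry point.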
 
 Review comment:
  The idea is for it to serve as an example of handling fused ops inside external codegen. I assume the DNNL backend itself is not meant to be used in production; the purpose is to be a more realistic example than CodegenC, so I thought we should add an example of how to handle fused ops. I never intended to cover other fusion cases.
   
  Since we are trying to be so nice to new backend implementers (who might not be familiar with TVM internals) as to add convenient op-level annotation, a semi-automatic fusion mechanism, and so on for them, I don't think it is reasonable to expect them to figure out how to handle more complicated but common cases (like fusion) on their own. Hope this makes sense.
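
As a concrete illustration of the detection step referenced in the diff (DetectFusedConv2DBiasReLU), a hypothetical sketch of how such a helper could match the pattern in Relay's C++ API follows. Treat the signature and structure as assumptions; the actual helper in the PR may differ.

    // Hypothetical sketch: match relu(bias_add(conv2d(...))) starting from
    // the outermost call and return the inner conv2d call, or nullptr if
    // the chain does not match. Op names follow Relay's registered
    // operators; this is not the exact PR code.
    const CallNode* DetectFusedConv2DBiasReLU(const CallNode* call) {
      auto is_op = [](const CallNode* c, const std::string& name) {
        const auto* op_node = c->op.as<OpNode>();
        return op_node != nullptr && op_node->name == name;
      };
      if (!is_op(call, "nn.relu")) return nullptr;
      const auto* bias_add = call->args[0].as<CallNode>();
      if (bias_add == nullptr || !is_op(bias_add, "nn.bias_add")) return nullptr;
      const auto* conv2d = bias_add->args[0].as<CallNode>();
      if (conv2d == nullptr || !is_op(conv2d, "nn.conv2d")) return nullptr;
      return conv2d;
    }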

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services