Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/11/29 20:03:25 UTC

[GitHub] [tvm] Laurawly commented on a change in pull request #9595: [CUTLASS] Initial conv2d support

Laurawly commented on a change in pull request #9595:
URL: https://github.com/apache/tvm/pull/9595#discussion_r758694948



##########
File path: src/relay/backend/contrib/cutlass/codegen.cc
##########
@@ -234,6 +238,112 @@ std::string BatchMatmulOp(std::string id, const Str2StrMap& attrs,
   return gemm_decl.str();
 }
 
+Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {
+  Str2StrMap args = ArgsCommon(attrs);
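+  // arg0 is the NHWC activation and arg1 the OHWI weight (K, R, S, C);
+  // ret is the NPQK output. Each extent is stringified for the emitted code.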
+  auto arg0_shape = attrs["arg0_shape"].as<ArrayNode>();
+  auto arg1_shape = attrs["arg1_shape"].as<ArrayNode>();
+  auto out_shape = attrs["ret_shape"].as<ArrayNode>();
+  args["N"] = GetDimAsStr(arg0_shape->at(0));
+  args["H"] = GetDimAsStr(arg0_shape->at(1));
+  args["W"] = GetDimAsStr(arg0_shape->at(2));
+  args["C"] = GetDimAsStr(arg0_shape->at(3));
+  args["K"] = GetDimAsStr(arg1_shape->at(0));
+  args["R"] = GetDimAsStr(arg1_shape->at(1));
+  args["S"] = GetDimAsStr(arg1_shape->at(1));
+  args["P"] = GetDimAsStr(out_shape->at(1));
+  args["Q"] = GetDimAsStr(out_shape->at(2));
+  args["pad_h"] = GetDimAsStr(attrs["padding"].as<ArrayNode>()->at(0));
+  args["pad_w"] = GetDimAsStr(attrs["padding"].as<ArrayNode>()->at(1));
+  args["stride_h"] = GetDimAsStr(attrs["strides"].as<ArrayNode>()->at(0));
+  args["stride_w"] = GetDimAsStr(attrs["strides"].as<ArrayNode>()->at(1));
+  args["dilation_h"] = GetDimAsStr(attrs["dilation"].as<ArrayNode>()->at(0));
+  args["dilation_w"] = GetDimAsStr(attrs["dilation"].as<ArrayNode>()->at(1));
+  return args;
+}
+
+std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
+                     const std::vector<std::string>& func_args) {
+  std::ostringstream conv2d_decl;
+  CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n");
+  CutlassPrint(conv2d_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n");
+  CutlassPrint(conv2d_decl, "using ElementOutput = " + attrs.at("ElementOutput") + ";\n");
+
+  CutlassPrint(conv2d_decl, "using ElementComputeEpilogue = " + attrs.at("ElementOutput") + ";\n");
+  CutlassPrint(conv2d_decl, attrs.at("op_def"));
+  CutlassPrint(conv2d_decl, "using Operation_" + attrs.at("op_name") +
+                                " = cutlass::conv::device::ImplicitGemmConvolution<" +
+                                attrs.at("op_name") + ">;\n");
+  CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + attrs.at("op_name") + ";\n");
+
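+  // Emit static extents as literals; dynamic (Any) extents are read from the
+  // DLTensor's shape array at runtime.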
+  auto get_dim = [&attrs](const std::string& axis, const std::string& var_name, int axis_idx) {
+    if (attrs.at(axis) == kAnyDim) {
+      return var_name + "->shape[" + std::to_string(axis_idx) + "]";
+    } else {
+      return attrs.at(axis);
+    }
+  };
+
+  CutlassPrint(conv2d_decl, "int N = " + get_dim("N", func_args[0], 0) + ";\n");
+  CutlassPrint(conv2d_decl, "int H = " + get_dim("H", func_args[0], 1) + ";\n");
+  CutlassPrint(conv2d_decl, "int W = " + get_dim("W", func_args[0], 2) + ";\n");
+  CutlassPrint(conv2d_decl, "int C = " + attrs.at("C") + ";\n");
+  CutlassPrint(conv2d_decl, "int K = " + attrs.at("K") + ";\n");
+  CutlassPrint(conv2d_decl, "int R = " + attrs.at("R") + ";\n");
+  CutlassPrint(conv2d_decl, "int S = " + attrs.at("S") + ";\n");
+  CutlassPrint(conv2d_decl, "int P = " + get_dim("P", "out0", 1) + ";\n");
+  CutlassPrint(conv2d_decl, "int Q = " + get_dim("Q", "out0", 2) + ";\n");
+  CutlassPrint(conv2d_decl, "int pad_h = " + attrs.at("pad_h") + ";\n");
+  CutlassPrint(conv2d_decl, "int pad_w = " + attrs.at("pad_w") + ";\n");
+  CutlassPrint(conv2d_decl, "int stride_h = " + attrs.at("stride_h") + ";\n");
+  CutlassPrint(conv2d_decl, "int stride_w = " + attrs.at("stride_w") + ";\n");
+  CutlassPrint(conv2d_decl, "int dilation_h = " + attrs.at("dilation_h") + ";\n");
+  CutlassPrint(conv2d_decl, "int dilation_w = " + attrs.at("dilation_w") + ";\n");
+
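+  // kCrossCorrelation matches framework conv2d semantics (no kernel flip);
+  // the trailing 1 requests a single split-K slice.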
+  CutlassPrint(
+      conv2d_decl,
+      "cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, Q, pad_h, pad_w, "
+      "stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, 1);\n");
+
+  CutlassPrint(conv2d_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n");
+  CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n");
+  CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n");
+  CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n");
+  CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n");
+
+  CutlassPrint(conv2d_decl, "using cutlass::layout::TensorNHWC;\n");
+  CutlassPrint(conv2d_decl,
+               "TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(N, H, W, C)));\n");
+  CutlassPrint(conv2d_decl,
+               "TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(K, R, S, C)));\n");
+  CutlassPrint(conv2d_decl,
+               "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n");
+  CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n");
+  CutlassPrint(conv2d_decl, " problem_size,\n");
+  CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
+  CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");
+  CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
+  CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
+  CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
+  CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");
+
+  CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n");
+  // Allocate workspace memory
+  CutlassPrint(conv2d_decl,
+               "cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);\n");

Review comment:
       There's a memory leak when the workspace is allocated this way. @ZihengJiang 
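
       One way to avoid it (a minimal sketch, not the PR's actual fix; RunConv2dWithWorkspace and its error handling are illustrative names, not TVM or CUTLASS API) is to pair the allocation with a free on every exit path of the generated function:

#include <cuda_runtime.h>

#include <cutlass/cutlass.h>

// Sketch: allocate the workspace right before launch and release it on every
// exit path, so repeated calls cannot accumulate device memory.
// (RunConv2dWithWorkspace is an illustrative helper, not existing TVM code.)
template <typename Conv2d>
cutlass::Status RunConv2dWithWorkspace(Conv2d& conv2d_op,
                                       typename Conv2d::Arguments const& arguments,
                                       cudaStream_t stream = nullptr) {
  void* workspace = nullptr;
  size_t workspace_size = conv2d_op.get_workspace_size(arguments);
  if (workspace_size > 0 && cudaMalloc(&workspace, workspace_size) != cudaSuccess) {
    return cutlass::Status::kErrorInternal;
  }
  cutlass::Status status = conv2d_op.initialize(arguments, workspace, stream);
  if (status == cutlass::Status::kSuccess) {
    status = conv2d_op(stream);  // launch the implicit GEMM kernel
  }
  if (workspace != nullptr) {
    cudaFree(workspace);  // paired free, even when initialize/launch failed
  }
  return status;
}

       Alternatively, the emitted code could request the buffer through TVMBackendAllocWorkspace / TVMBackendFreeWorkspace so that TVM's device API owns the memory.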




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org