Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/12/24 08:42:41 UTC

[GitHub] [tvm] comaniac commented on a change in pull request #7161: [AutoScheduler] Add layout rewrite support for dense and batch matmul on CPU

comaniac commented on a change in pull request #7161:
URL: https://github.com/apache/tvm/pull/7161#discussion_r548446708



##########
File path: python/tvm/topi/nn/conv2d.py
##########
@@ -361,6 +361,12 @@ def conv2d_nhwc(
     dilation: int or a list/tuple of two ints
         dilation size, or [dilation_height, dilation_width]
 
+    out_dtype: str = "float32",
+        The type of outpu tensor

Review comment:
       ```suggestion
               The type of output tensor
       ```
   

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -1452,5 +1453,29 @@ TVM_REGISTER_GLOBAL("auto_scheduler.RewriteIndexForNewLayout")
       return index_rewriter.Rewrite(body);
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.GetShapeFromRewrittenLayout")
+    .set_body_typed([](String rewritten_layout, Array<String> axis_names) {
+      Array<PrimExpr> shape;
+      std::vector<std::string> extracted_names;
+      topi::parse_auto_scheduler_layout(rewritten_layout, &shape, &extracted_names);
+
+      Array<PrimExpr> ret(axis_names.size(), 1);
+
+      for (size_t i = 0; i < axis_names.size(); ++i) {
+        bool found = false;
+        for (size_t j = 0; j < extracted_names.size(); ++j) {
+          if (axis_names[i] == extracted_names[j]) {
+            ret.Set(i, ret[i] * shape[j]);
+            found = true;

Review comment:
       nit: break the loop when found.
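Below is a minimal standalone sketch of the lookup in this hunk. It uses plain `std::vector`s in place of TVM's `Array<PrimExpr>` and assumes the parsed layout is already available as parallel name/extent lists. `ShapeFromRewrittenLayout` is a hypothetical name, not a TVM API. The multiply-accumulate into `ret[i]` suggests that one original axis may appear several times in a rewritten layout, once per split factor. Under that assumption every match has to be visited, so an early `break` would only be safe if each name appeared at most once.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical analogue of the GetShapeFromRewrittenLayout body.
// extracted_names[j] / shape[j] stand in for the output of
// topi::parse_auto_scheduler_layout (assumed already parsed here).
std::vector<int64_t> ShapeFromRewrittenLayout(
    const std::vector<std::string>& extracted_names,
    const std::vector<int64_t>& shape,
    const std::vector<std::string>& axis_names) {
  // Names and extents are parallel lists.
  assert(extracted_names.size() == shape.size());

  // Every original axis starts with extent 1, like `ret(axis_names.size(), 1)`.
  std::vector<int64_t> ret(axis_names.size(), 1);
  std::size_t matched = 0;
  for (std::size_t i = 0; i < axis_names.size(); ++i) {
    for (std::size_t j = 0; j < extracted_names.size(); ++j) {
      if (axis_names[i] == extracted_names[j]) {
        // Assumption: a split axis appears once per factor, so all
        // matching extents are multiplied together (no early break).
        ret[i] *= shape[j];
        ++matched;
      }
    }
  }
  // Simple consistency check: the number of matches should equal the
  // number of rewritten dimensions (catches names absent from axis_names).
  if (matched != extracted_names.size()) {
    throw std::invalid_argument("rewritten layout has unmatched axis names");
  }
  return ret;
}
```

For example, rewritten names `{i, j, i, j}` with extents `{4, 8, 16, 2}` map back to original extents `i = 64` and `j = 16`.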

##########
File path: src/relay/transforms/auto_scheduler_layout_rewrite.cc
##########
@@ -150,6 +156,12 @@ TVM_REGISTER_GLOBAL("relay.attrs.get_auto_scheduler_rewritten_layout")
     .set_body_typed([](const Attrs& attrs) {
       if (attrs->IsInstance<Conv2DAttrs>()) {
         return attrs.as<Conv2DAttrs>()->auto_scheduler_rewritten_layout;
+      } else if (attrs->IsInstance<DenseAttrs>()) {
+        return attrs.as<DenseAttrs>()->auto_scheduler_rewritten_layout;
+      } else if (attrs->IsInstance<BatchMatmulAttrs>()) {
+        return attrs.as<BatchMatmulAttrs>()->auto_scheduler_rewritten_layout;
+      } else {
+        LOG(FATAL) << "Unhandled attribute: " << attrs;

Review comment:
       LOG(FATAL) throws an error, so the `return` that follows it is unreachable.
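Here is a minimal sketch of why a trailing `return` after a throwing fatal call is dead code. It uses hypothetical stand-in types (`AttrsBase`, `DenseAttrsLike`, and so on, which are not TVM's real attribute classes). It also uses a hypothetical `FatalUnhandled` helper that throws, as the review says `LOG(FATAL)` does. Marking the helper `[[noreturn]]` lets the compiler see that control cannot fall off the end of the function, so no dummy `return` is needed after it. Whether a compiler treats a streaming `LOG(FATAL) << ...` the same way depends on how that macro is implemented, so this is an illustration of the idea, not of TVM's macro.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for Relay attribute nodes; only the field name
// auto_scheduler_rewritten_layout mirrors the diff above.
struct AttrsBase { virtual ~AttrsBase() = default; };
struct Conv2DAttrsLike : AttrsBase { std::string auto_scheduler_rewritten_layout; };
struct DenseAttrsLike : AttrsBase { std::string auto_scheduler_rewritten_layout; };
struct BatchMatmulAttrsLike : AttrsBase { std::string auto_scheduler_rewritten_layout; };

// Hypothetical fatal helper: it always throws and never returns.
[[noreturn]] void FatalUnhandled(const std::string& what) {
  throw std::runtime_error("Unhandled attribute: " + what);
}

std::string GetRewrittenLayout(const AttrsBase& attrs) {
  if (const auto* conv = dynamic_cast<const Conv2DAttrsLike*>(&attrs)) {
    return conv->auto_scheduler_rewritten_layout;
  }
  if (const auto* dense = dynamic_cast<const DenseAttrsLike*>(&attrs)) {
    return dense->auto_scheduler_rewritten_layout;
  }
  if (const auto* bmm = dynamic_cast<const BatchMatmulAttrsLike*>(&attrs)) {
    return bmm->auto_scheduler_rewritten_layout;
  }
  // No trailing `return` is needed: FatalUnhandled is [[noreturn]],
  // so any statement placed after this call would be unreachable.
  FatalUnhandled("unrecognized attrs type");
}
```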




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org