Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/12/08 09:42:19 UTC

[GitHub] [tvm] shingjan opened a new pull request #9680: [WIP][TVMScript] Improve printer for TIR syntax sugar

shingjan opened a new pull request #9680:
URL: https://github.com/apache/tvm/pull/9680


   1. For reads & writes:
   
   Before this PR
   ```
   @T.prim_func
   def func(a: T.handle, b: T.handle, c: T.handle) -> None:
               ...
               T.reads([C[vi, vj], A[vi, vk], B[vj, vk]])
               T.writes([C[vi, vj], A[vi, vk]])
               ...
   ```
   After this PR
   ```
   @T.prim_func
   def func(a: T.handle, b: T.handle, c: T.handle) -> None:
               ...
               T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
               T.writes(C[vi, vj], A[vi, vk])
               ...
   ```
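The difference between the two forms above can be modelled with a minimal Python sketch. This is not the printer's C++ code. The region strings and the helper names `print_as_list` and `print_expanded` are hypothetical. The sketch only shows how dropping the list brackets turns one list argument into several positional arguments.

```python
from typing import List


def print_as_list(regions: List[str]) -> str:
    # Old form: the array is printed as a Python list literal,
    # so T.reads()/T.writes() receive a single list argument.
    return "[" + ", ".join(regions) + "]"


def print_expanded(regions: List[str]) -> str:
    # New form: the brackets are dropped, so each region becomes
    # a separate positional argument of T.reads()/T.writes().
    return ", ".join(regions)


regions = ["C[vi, vj]", "A[vi, vk]", "B[vj, vk]"]
print("T.reads(" + print_as_list(regions) + ")")
print("T.reads(" + print_expanded(regions) + ")")
```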
   
   2. For loops:
   
   Before this PR
   ```
   @T.prim_func
   def func(a: T.handle) -> None:
       ...
       for i in T.serial(0, 128):
           for j in T.parallel(0, 128):
               for k in T.vectorized(0, 128):
                   for x in T.unroll(0, 128):
                       for y in T.thread_binding(0, 128, thread="threadIdx.x"):
                           ...
   ```
   After this PR
   ```
   ```
   
   3. For T.match_buffer():
   
   Before this PR
   ```
   ```
   After this PR
   ```
   ```
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] shingjan commented on a change in pull request #9680: [TVMScript] Improve printer for TIR syntax sugar

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9680:
URL: https://github.com/apache/tvm/pull/9680#discussion_r765504067



##########
File path: src/printer/tvmscript_printer.cc
##########
@@ -471,6 +479,47 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
   return doc;
 }
 
+// check if all arguments, except the first two, are specified for T.match_buffer
+// if not, then this match buffer is printed out as syntax sugar
+bool TVMScriptPrinter::IsMatchBufferSugarred(const Buffer& buf) {
+  if (memo_var_.find(buf->data) != memo_var_.end()) {
+    return false;
+  }
+  if (!buf->strides.empty()) {
+    return false;
+  }
+  if (buf->elem_offset->IsInstance<VarNode>()) {
+    Var elem_offset = Downcast<Var>(buf->elem_offset);
+    if (memo_var_.find(elem_offset) != memo_var_.end()) {
+      return false;
+    }
+  } else if (buf->elem_offset->IsInstance<IntImmNode>()) {
+    IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
+    if (elem_offset->value != 0) {
+      return false;
+    }
+  }
+  if (buf.scope() != "global") {
+    return false;
+  }
+  if (buf->data_alignment != runtime::kAllocAlignment) {
+    return false;
+  }
+  if (buf->offset_factor != 1) {
+    return false;
+  }
+  if (buf->buffer_type != 1) {
+    return false;
+  }
+  return true;
+}
+
+Doc TVMScriptPrinter::MatchBufferDeclaration(const Buffer& buffer) {
+  Doc doc = Print(buffer->shape);
+  doc << ", dtype=" << PrintDType(buffer->dtype);

Review comment:
       The `dtype=` keyword is dropped here and the bracket syntax is used instead.
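The following rough Python model shows the bracket form of a buffer parameter, `name: T.Buffer[(shape), dtype]`. It makes two assumptions: shapes are plain integers, and the dtype is printed as a quoted string. The function name `match_buffer_declaration` is hypothetical, and this is not the C++ `MatchBufferDeclaration`.

```python
from typing import Sequence


def match_buffer_declaration(name: str, shape: Sequence[int], dtype: str,
                             tir_prefix: str = "T") -> str:
    # repr() of a Python tuple already emits "(16,)" for one-element shapes.
    shape_doc = repr(tuple(shape))
    # Bracket syntax: no "dtype=" keyword, the dtype follows the shape.
    return f'{name}: {tir_prefix}.Buffer[{shape_doc}, "{dtype}"]'


print(match_buffer_declaration("A", [128, 128], "float32"))
```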







[GitHub] [tvm] vinx13 commented on a change in pull request #9680: [TVMScript] Improve printer for TIR syntax sugar

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9680:
URL: https://github.com/apache/tvm/pull/9680#discussion_r766036466



##########
File path: src/printer/tvmscript_printer.cc
##########
@@ -471,6 +480,62 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
   return doc;
 }
 
+// check if all arguments, except the first two, are specified for T.match_buffer
+// if not, then this match buffer is printed out as T.buffer in prim_func arguments
+bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) {
+  if (memo_var_.find(buf->data) != memo_var_.end()) {
+    return false;
+  }
+  if (!buf->strides.empty()) {
+    return false;
+  }
+  if (buf->elem_offset->IsInstance<VarNode>()) {
+    Var elem_offset = Downcast<Var>(buf->elem_offset);
+    if (memo_var_.find(elem_offset) != memo_var_.end()) {
+      return false;
+    }
+  } else if (buf->elem_offset->IsInstance<IntImmNode>()) {
+    IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
+    if (elem_offset->value != 0) {
+      return false;
+    }
+  }
+  if (buf.scope() != "global") {
+    return false;
+  }
+  if (buf->data_alignment != runtime::kAllocAlignment) {
+    return false;
+  }
+  if (buf->offset_factor != 1) {
+    return false;
+  }
+  if (buf->buffer_type != BufferType::kDefault) {
+    return false;
+  }
+  return true;
+}
+
+Doc TVMScriptPrinter::MatchBufferDeclaration(const Buffer& buffer) {
+  Doc doc;
+  doc << tir_prefix_ << ".Buffer[" << PrintTuple(buffer->shape.as<ArrayNode>());
+  doc << ", " << PrintDType(buffer->dtype) << "]";
+  return doc;
+}
+
+// print array out as tuple with parentheses
+Doc TVMScriptPrinter::PrintTuple(const ArrayNode* op) {
+  Doc doc;
+  doc << '(';
+  for (size_t i = 0; i < op->size(); ++i) {
+    if (i != 0) {
+      doc << ", ";
+    }
+    doc << Print(op->at(i));
+  }
+  doc << ')';

Review comment:
       If the tuple has only one element, a trailing `,` is needed, i.e. `(n,)` instead of `(n)`.
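The fix can be sketched in Python by mirroring the loop in `PrintTuple` on already-printed element strings. The name `print_tuple` is hypothetical. The trailing comma matters because in Python `(n)` is only a parenthesized expression, while `(n,)` is a one-element tuple.

```python
from typing import Sequence


def print_tuple(items: Sequence[str]) -> str:
    # Join the already-printed elements with ", " inside parentheses.
    body = ", ".join(items)
    # A single element needs a trailing comma to stay a tuple.
    if len(items) == 1:
        body += ","
    return "(" + body + ")"


print(print_tuple(["n"]), print_tuple(["m", "n"]))
```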

##########
File path: tests/python/unittest/test_tvmscript_roundtrip.py
##########
@@ -79,6 +79,7 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None:
 
 def test_opt_gemm_normalize():
     mod = Module1
+    print(mod.script(show_meta=True))

Review comment:
       remove this

##########
File path: src/printer/tvmscript_printer.cc
##########
@@ -220,6 +221,14 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
   Doc AllocBuf(const Buffer& buffer);
   void TryDeallocVar(const Var& var);
   bool ContainsOptionalInfo(const Stmt& stmt);
+  /*! Helper function for match buffer printing */
+  /*!
+   * @brief check if a T.match_buffer decl is syntax sugarred

Review comment:
       Use `\brief` for consistency; the comment should also be made clearer.

##########
File path: src/printer/tvmscript_printer.cc
##########
@@ -1220,6 +1300,19 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
   std::vector<Doc> params;
   for (const auto& param : op->params) {
     var_not_in_headers_.insert(param.get());
+    auto it = op->buffer_map.find(param);
+    // check if this param is a T.handle
+    if (it != op->buffer_map.end()) {
+      // check if this match_buffer has only the first two arguments specified
+      const auto buf = (*it).second;
+      if (IsSimpleBuffer(buf)) {
+        buf_not_in_headers_.insert(buf.get());
+        Doc buf_param_doc;
+        buf_param_doc << buf->name << ": " << MatchBufferDeclaration(buf);

Review comment:
       This is problematic. Looking at the implementation of `AllocBuf` and `AllocBufferDeclaration`, at least `memo_buf_` and `memo_var_` are not correctly updated if the buffer doesn't go through these two functions. A possible solution is to unify the new behavior into these two functions.
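The suggestion to unify the two paths can be illustrated with a toy Python model. The idea is that both the plain `T.handle` parameter and the new `T.Buffer[...]` parameter go through one allocation function, so the memo tables are always updated. Everything here is hypothetical: `PrinterSketch`, `alloc_buf`, and `print_param` are illustrative names, and the sets only stand in for `memo_buf_`/`memo_var_`. The real printer is C++ and its `AllocBuf` does more than this.

```python
from typing import List, Set


class PrinterSketch:
    """Toy model of the printer's bookkeeping; not TVM code."""

    def __init__(self) -> None:
        self.memo_buf: Set[str] = set()  # buffers already declared
        self.memo_var: Set[str] = set()  # variables already declared

    def alloc_buf(self, name: str, data: str) -> str:
        # Shared entry point: every declared buffer records itself and its
        # data variable, so later references resolve consistently.
        self.memo_buf.add(name)
        self.memo_var.add(data)
        return name

    def print_param(self, param: str, name: str, data: str,
                    shape: List[int], dtype: str, simple: bool) -> str:
        printed = self.alloc_buf(name, data)  # both paths update the memos
        if simple:
            return f'{printed}: T.Buffer[{tuple(shape)!r}, "{dtype}"]'
        # Otherwise keep a handle parameter; the T.match_buffer call
        # would then be printed in the function body.
        return f"{param}: T.handle"


p = PrinterSketch()
print(p.print_param("a", "A", "A_data", [128, 128], "float32", simple=True))
```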







[GitHub] [tvm] vinx13 commented on a change in pull request #9680: [TVMScript] Improve printer for TIR syntax sugar

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9680:
URL: https://github.com/apache/tvm/pull/9680#discussion_r765383155



##########
File path: src/printer/tvmscript_printer.cc
##########
@@ -471,6 +479,47 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
   return doc;
 }
 
+// check if all arguments, except the first two, are specified for T.match_buffer
+// if not, then this match buffer is printed out as syntax sugar

Review comment:
       Describe the exact behavior; 'syntax sugar' is not a clear description.
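The conditions checked in the quoted diff can be spelled out as a Python model. A buffer qualifies for the short parameter form only if all of the following hold:

- Its data variable has not been printed yet.
- It has no explicit strides.
- Its `elem_offset` is either a not-yet-printed variable or the constant 0.
- It uses global scope, default alignment, and `offset_factor == 1`.
- Its buffer type is the default one.

`BufferInfo` and `is_simple_buffer` are hypothetical stand-ins, and `K_ALLOC_ALIGNMENT = 64` is an assumed value for `runtime::kAllocAlignment`.

```python
from dataclasses import dataclass, field
from typing import List, Set, Union

K_ALLOC_ALIGNMENT = 64       # assumed value of runtime::kAllocAlignment
K_DEFAULT_BUFFER_TYPE = 1    # BufferType::kDefault


@dataclass
class BufferInfo:
    """Hypothetical stand-in for the tir.Buffer fields the check reads."""
    data: str                                # name of the data pointer var
    strides: List[int] = field(default_factory=list)
    elem_offset: Union[int, str] = 0         # int constant or a var name
    scope: str = "global"
    data_alignment: int = K_ALLOC_ALIGNMENT
    offset_factor: int = 1
    buffer_type: int = K_DEFAULT_BUFFER_TYPE


def is_simple_buffer(buf: BufferInfo, printed_vars: Set[str]) -> bool:
    """True if the buffer can be printed as `name: T.Buffer[shape, dtype]`."""
    # The data var must not already have been printed elsewhere.
    if buf.data in printed_vars:
        return False
    # Explicit strides cannot be expressed in the short form.
    if buf.strides:
        return False
    # elem_offset must be a fresh variable or the constant 0.
    if isinstance(buf.elem_offset, str):
        if buf.elem_offset in printed_vars:
            return False
    elif buf.elem_offset != 0:
        return False
    # All remaining fields must have their default values.
    return (buf.scope == "global"
            and buf.data_alignment == K_ALLOC_ALIGNMENT
            and buf.offset_factor == 1
            and buf.buffer_type == K_DEFAULT_BUFFER_TYPE)
```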







[GitHub] [tvm] vinx13 merged pull request #9680: [TVMScript] Improve printer for TIR syntax sugar

Posted by GitBox <gi...@apache.org>.
vinx13 merged pull request #9680:
URL: https://github.com/apache/tvm/pull/9680


   





[GitHub] [tvm] shingjan commented on a change in pull request #9680: [WIP][TVMScript] Improve printer for TIR syntax sugar

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9680:
URL: https://github.com/apache/tvm/pull/9680#discussion_r765363578



##########
File path: src/printer/tvmscript_printer.cc
##########
@@ -1105,6 +1108,19 @@ Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) {
   return block_attr_doc;
 }
 
+// This function is to make sure arguments of T.reads() and T.writes() is not parsed by printer as a
+// List. Therefore the brackets are removed before and after printing arguments out
+Doc TVMScriptPrinter::PrintBlockAttrArray(const ArrayNode* op) {

Review comment:
       done







[GitHub] [tvm] Hzfengsy commented on a change in pull request #9680: [TVMScript] Improve printer for TIR syntax sugar

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on a change in pull request #9680:
URL: https://github.com/apache/tvm/pull/9680#discussion_r765510977



##########
File path: src/printer/tvmscript_printer.cc
##########
@@ -471,6 +479,47 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
   return doc;
 }
 
+// check if all arguments, except the first two, are specified for T.match_buffer
+// if not, then this match buffer is printed out as syntax sugar
+bool TVMScriptPrinter::IsMatchBufferSugarred(const Buffer& buf) {
+  if (memo_var_.find(buf->data) != memo_var_.end()) {
+    return false;
+  }
+  if (!buf->strides.empty()) {
+    return false;
+  }
+  if (buf->elem_offset->IsInstance<VarNode>()) {
+    Var elem_offset = Downcast<Var>(buf->elem_offset);
+    if (memo_var_.find(elem_offset) != memo_var_.end()) {
+      return false;
+    }
+  } else if (buf->elem_offset->IsInstance<IntImmNode>()) {
+    IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
+    if (elem_offset->value != 0) {
+      return false;
+    }
+  }
+  if (buf.scope() != "global") {
+    return false;
+  }
+  if (buf->data_alignment != runtime::kAllocAlignment) {
+    return false;
+  }
+  if (buf->offset_factor != 1) {
+    return false;
+  }
+  if (buf->buffer_type != 1) {

Review comment:
       Please use `BufferType::kDefault`.







[GitHub] [tvm] shingjan commented on a change in pull request #9680: [TVMScript] Improve printer for TIR syntax sugar

Posted by GitBox <gi...@apache.org>.
shingjan commented on a change in pull request #9680:
URL: https://github.com/apache/tvm/pull/9680#discussion_r765503503



##########
File path: src/printer/tvmscript_printer.cc
##########
@@ -471,6 +479,47 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
   return doc;
 }
 
+// check if all arguments, except the first two, are specified for T.match_buffer
+// if not, then this match buffer is printed out as syntax sugar
+bool TVMScriptPrinter::IsMatchBufferSugarred(const Buffer& buf) {
+  if (memo_var_.find(buf->data) != memo_var_.end()) {
+    return false;
+  }
+  if (!buf->strides.empty()) {
+    return false;
+  }
+  if (buf->elem_offset->IsInstance<VarNode>()) {
+    Var elem_offset = Downcast<Var>(buf->elem_offset);
+    if (memo_var_.find(elem_offset) != memo_var_.end()) {
+      return false;
+    }
+  } else if (buf->elem_offset->IsInstance<IntImmNode>()) {
+    IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
+    if (elem_offset->value != 0) {
+      return false;
+    }
+  }
+  if (buf.scope() != "global") {
+    return false;
+  }
+  if (buf->data_alignment != runtime::kAllocAlignment) {
+    return false;
+  }
+  if (buf->offset_factor != 1) {
+    return false;
+  }
+  if (buf->buffer_type != 1) {

Review comment:
       Yes. `BufferType` is an enum in which 1 is for `kDefault` and 2 is for `kAutoBroadcast`. Details [here](https://github.com/apache/tvm/blob/0e0adf514c41dcab27ef60df60e84ce28a6effaf/include/tvm/tir/buffer.h#L42)
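Comparing against the named enumerator instead of the literal `1` can be shown with a small Python mirror of the enum values described above (`kDefault = 1`, `kAutoBroadcast = 2`). This is an illustrative `IntEnum`, not TVM's own Python binding, and `is_default_buffer_type` is a hypothetical helper.

```python
from enum import IntEnum


class BufferType(IntEnum):
    # Python mirror of the C++ enum values quoted above; not TVM's binding.
    kDefault = 1
    kAutoBroadcast = 2


def is_default_buffer_type(buffer_type: int) -> bool:
    # The named member documents intent, unlike a bare magic number 1.
    return buffer_type == BufferType.kDefault


print(is_default_buffer_type(1), is_default_buffer_type(2))  # True False
```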







[GitHub] [tvm] vinx13 commented on a change in pull request #9680: [TVMScript] Improve printer for TIR syntax sugar

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9680:
URL: https://github.com/apache/tvm/pull/9680#discussion_r765380321



##########
File path: src/printer/tvmscript_printer.cc
##########
@@ -471,6 +479,47 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
   return doc;
 }
 
+// check if all arguments, except the first two, are specified for T.match_buffer
+// if not, then this match buffer is printed out as syntax sugar
+bool TVMScriptPrinter::IsMatchBufferSugarred(const Buffer& buf) {

Review comment:
       ```suggestion
   bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) {
   ```
   

##########
File path: src/printer/tvmscript_printer.cc
##########
@@ -471,6 +479,47 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
   return doc;
 }
 
+// check if all arguments, except the first two, are specified for T.match_buffer
+// if not, then this match buffer is printed out as syntax sugar
+bool TVMScriptPrinter::IsMatchBufferSugarred(const Buffer& buf) {
+  if (memo_var_.find(buf->data) != memo_var_.end()) {
+    return false;
+  }
+  if (!buf->strides.empty()) {
+    return false;
+  }
+  if (buf->elem_offset->IsInstance<VarNode>()) {
+    Var elem_offset = Downcast<Var>(buf->elem_offset);
+    if (memo_var_.find(elem_offset) != memo_var_.end()) {
+      return false;
+    }
+  } else if (buf->elem_offset->IsInstance<IntImmNode>()) {
+    IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
+    if (elem_offset->value != 0) {
+      return false;
+    }
+  }
+  if (buf.scope() != "global") {
+    return false;
+  }
+  if (buf->data_alignment != runtime::kAllocAlignment) {
+    return false;
+  }
+  if (buf->offset_factor != 1) {
+    return false;
+  }
+  if (buf->buffer_type != 1) {
+    return false;
+  }
+  return true;
+}
+
+Doc TVMScriptPrinter::MatchBufferDeclaration(const Buffer& buffer) {
+  Doc doc = Print(buffer->shape);

Review comment:
       Printing `T.Buffer` should also be part of this function

##########
File path: src/printer/tvmscript_printer.cc
##########
@@ -471,6 +479,47 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
   return doc;
 }
 
+// check if all arguments, except the first two, are specified for T.match_buffer
+// if not, then this match buffer is printed out as syntax sugar
+bool TVMScriptPrinter::IsMatchBufferSugarred(const Buffer& buf) {
+  if (memo_var_.find(buf->data) != memo_var_.end()) {
+    return false;
+  }
+  if (!buf->strides.empty()) {
+    return false;
+  }
+  if (buf->elem_offset->IsInstance<VarNode>()) {
+    Var elem_offset = Downcast<Var>(buf->elem_offset);
+    if (memo_var_.find(elem_offset) != memo_var_.end()) {
+      return false;
+    }
+  } else if (buf->elem_offset->IsInstance<IntImmNode>()) {
+    IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
+    if (elem_offset->value != 0) {
+      return false;
+    }
+  }
+  if (buf.scope() != "global") {
+    return false;
+  }
+  if (buf->data_alignment != runtime::kAllocAlignment) {
+    return false;
+  }
+  if (buf->offset_factor != 1) {
+    return false;
+  }
+  if (buf->buffer_type != 1) {
+    return false;
+  }
+  return true;
+}
+
+Doc TVMScriptPrinter::MatchBufferDeclaration(const Buffer& buffer) {
+  Doc doc = Print(buffer->shape);
+  doc << ", dtype=" << PrintDType(buffer->dtype);

Review comment:
       Let's use the bracket syntax, which is simpler.

##########
File path: src/printer/tvmscript_printer.cc
##########
@@ -471,6 +479,47 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
   return doc;
 }
 
+// check if all arguments, except the first two, are specified for T.match_buffer
+// if not, then this match buffer is printed out as syntax sugar
+bool TVMScriptPrinter::IsMatchBufferSugarred(const Buffer& buf) {
+  if (memo_var_.find(buf->data) != memo_var_.end()) {
+    return false;
+  }
+  if (!buf->strides.empty()) {
+    return false;
+  }
+  if (buf->elem_offset->IsInstance<VarNode>()) {
+    Var elem_offset = Downcast<Var>(buf->elem_offset);
+    if (memo_var_.find(elem_offset) != memo_var_.end()) {
+      return false;
+    }
+  } else if (buf->elem_offset->IsInstance<IntImmNode>()) {
+    IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
+    if (elem_offset->value != 0) {
+      return false;
+    }
+  }
+  if (buf.scope() != "global") {
+    return false;
+  }
+  if (buf->data_alignment != runtime::kAllocAlignment) {
+    return false;
+  }
+  if (buf->offset_factor != 1) {
+    return false;
+  }
+  if (buf->buffer_type != 1) {

Review comment:
       What is 1? Is it some enum?







[GitHub] [tvm] vinx13 commented on a change in pull request #9680: [WIP][TVMScript] Improve printer for TIR syntax sugar

Posted by GitBox <gi...@apache.org>.
vinx13 commented on a change in pull request #9680:
URL: https://github.com/apache/tvm/pull/9680#discussion_r765234469



##########
File path: src/printer/tvmscript_printer.cc
##########
@@ -1105,6 +1108,19 @@ Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) {
   return block_attr_doc;
 }
 
+// This function is to make sure arguments of T.reads() and T.writes() is not parsed by printer as a
+// List. Therefore the brackets are removed before and after printing arguments out
+Doc TVMScriptPrinter::PrintBlockAttrArray(const ArrayNode* op) {

Review comment:
       nit:
   ```suggestion
   Doc TVMScriptPrinter::PrintExpandedArray(const ArrayNode* op) {
   ```



