You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/01/16 00:00:38 UTC

[GitHub] [tvm] junrushao1994 commented on a change in pull request #8509: [TIR] Tir constants integration into compilation pipeline

junrushao1994 commented on a change in pull request #8509:
URL: https://github.com/apache/tvm/pull/8509#discussion_r785373106



##########
File path: include/tvm/tir/stmt.h
##########
@@ -585,6 +585,118 @@ class Allocate : public Stmt {
   TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
 };
 
+/*!
+ * \brief Describes one parameter that should be linked into the generated module.
+ *
+ * When parameters are to be linked in with generated code (i.e. on target_host-compatible
+ * backends), Relay attaches instances of this object to a global TIR function. Code-generators
+ * use the information contained in this node to include the parameter data in the generated
+ * module.
+ */
+class LinkedParamNode : public Object {
+ public:
+  /*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
+  int64_t id;
+
+  /*! \brief Parameter data which should get linked into the final module. */
+  ::tvm::runtime::NDArray param;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("id", &id);
+    v->Visit("param", &param);
+  }
+
+  static constexpr const char* _type_key = "tir.LinkedParam";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
+};
+
+/*!
+ * \brief Managed reference to LinkedParamNode.
+ */
+class LinkedParam : public ObjectRef {
+ public:
+  TVM_DLL LinkedParam(int64_t id, ::tvm::runtime::NDArray param);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
+};
+
+/*!
+ * \brief Allocate a buffer that can be used in body.
+ */
+class AllocateConstNode : public StmtNode {
+ public:
+  /*! \brief The buffer variable. */
+  Var buffer_var;
+  /*! \brief The optional data associated to the constant.
+   */
+  Optional<runtime::NDArray> data;
+  /*! \brief If the PrimFunc containing the Stmt is added to IRModule,
+       this is an optional index to indicate the index within
+       "Constants" attribute, that is a Array<NDArray> of IRModule.
+   */
+  Optional<Integer> irmod_storage_idx;
+  /*! \brief The type of the buffer. */
+  DataType dtype;
+  /*! \brief The extents of the buffer. */
+  Array<PrimExpr> extents;
+  /*! \brief The body to be executed. */
+  Stmt body;

Review comment:
       nit: let's organize the fields in the same way as AllocateNode a few lines above. BTW, did we forget the `annotations` field?

##########
File path: include/tvm/ir/module.h
##########
@@ -32,6 +32,7 @@
 #include <tvm/runtime/container/array.h>
 #include <tvm/runtime/container/map.h>
 #include <tvm/runtime/container/string.h>
+#include <tvm/tir/function.h>

Review comment:
       Note that in general, the header dependency direction is that `tvm/tir/*` depends on `tvm/ir/*`, so including a TIR header here may introduce a cyclic dependency. Let's find another way around it.

##########
File path: python/tvm/topi/nn/dense.py
##########
@@ -81,7 +81,9 @@ def matmul(
         out_dim, red_dim = tensor_b.shape
     else:
         red_dim, out_dim = tensor_b.shape
-    assert in_dim == red_dim
+
+    # cmp should be done by values (i.e. by string representation proxies)
+    assert str(in_dim) == str(red_dim)

Review comment:
       Good point, but using string sounds a bit scary to me :-)
   
   IIUC, how about:
   
   ```suggestion
       assert int(in_dim) == int(red_dim)
   ```
   

##########
File path: include/tvm/tir/stmt.h
##########
@@ -585,6 +585,118 @@ class Allocate : public Stmt {
   TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
 };
 
+/*!
+ * \brief Describes one parameter that should be linked into the generated module.
+ *
+ * When parameters are to be linked in with generated code (i.e. on target_host-compatible
+ * backends), Relay attaches instances of this object to a global TIR function. Code-generators
+ * use the information contained in this node to include the parameter data in the generated
+ * module.
+ */
+class LinkedParamNode : public Object {
+ public:
+  /*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
+  int64_t id;
+
+  /*! \brief Parameter data which should get linked into the final module. */
+  ::tvm::runtime::NDArray param;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("id", &id);
+    v->Visit("param", &param);
+  }
+
+  static constexpr const char* _type_key = "tir.LinkedParam";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
+};
+
+/*!
+ * \brief Managed reference to LinkedParamNode.
+ */
+class LinkedParam : public ObjectRef {
+ public:
+  TVM_DLL LinkedParam(int64_t id, ::tvm::runtime::NDArray param);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
+};
+
+/*!
+ * \brief Allocate a buffer that can be used in body.
+ */
+class AllocateConstNode : public StmtNode {
+ public:
+  /*! \brief The buffer variable. */
+  Var buffer_var;
+  /*! \brief The optional data associated to the constant.
+   */
+  Optional<runtime::NDArray> data;
+  /*! \brief If the PrimFunc containing the Stmt is added to IRModule,
+       this is an optional index to indicate the index within
+       "Constants" attribute, that is a Array<NDArray> of IRModule.
+   */
+  Optional<Integer> irmod_storage_idx;
+  /*! \brief The type of the buffer. */
+  DataType dtype;
+  /*! \brief The extents of the buffer. */
+  Array<PrimExpr> extents;
+  /*! \brief The body to be executed. */
+  Stmt body;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("buffer_var", &buffer_var);
+    v->Visit("dtype", &dtype);
+    v->Visit("extents", &extents);
+    v->Visit("body", &body);
+    v->Visit("span", &span);
+  }
+
+  bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const {
+    return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
+           equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce.DefHash(buffer_var);
+    hash_reduce(dtype);
+    hash_reduce(extents);
+    hash_reduce(body);
+    hash_reduce(data);
+  }
+
+  /*!
+   * \brief If the buffer size is constant, return the size.
+   *        Otherwise return 0.
+   * \return The result.
+   */
+  int32_t constant_allocation_size() const { return constant_allocation_size(extents); }
+  /*!
+   * \brief If the buffer size is constant, return the size.
+   *        Otherwise return 0.
+   * \param extents The extents of the buffer.
+   * \return The result.
+   */
+  TVM_DLL static int32_t constant_allocation_size(const Array<PrimExpr>& extents);
+
+  static constexpr const char* _type_key = "tir.AllocateConst";

Review comment:
       did we forget to declare that the sequal/shash fields exist?

##########
File path: include/tvm/tir/stmt.h
##########
@@ -585,6 +585,118 @@ class Allocate : public Stmt {
   TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
 };
 
+/*!
+ * \brief Describes one parameter that should be linked into the generated module.
+ *
+ * When parameters are to be linked in with generated code (i.e. on target_host-compatible
+ * backends), Relay attaches instances of this object to a global TIR function. Code-generators
+ * use the information contained in this node to include the parameter data in the generated
+ * module.
+ */
+class LinkedParamNode : public Object {
+ public:
+  /*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
+  int64_t id;
+
+  /*! \brief Parameter data which should get linked into the final module. */
+  ::tvm::runtime::NDArray param;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("id", &id);
+    v->Visit("param", &param);
+  }
+
+  static constexpr const char* _type_key = "tir.LinkedParam";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
+};
+
+/*!
+ * \brief Managed reference to LinkedParamNode.
+ */
+class LinkedParam : public ObjectRef {
+ public:
+  TVM_DLL LinkedParam(int64_t id, ::tvm::runtime::NDArray param);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
+};
+
+/*!
+ * \brief Allocate a buffer that can be used in body.
+ */
+class AllocateConstNode : public StmtNode {
+ public:
+  /*! \brief The buffer variable. */
+  Var buffer_var;
+  /*! \brief The optional data associated to the constant.
+   */
+  Optional<runtime::NDArray> data;
+  /*! \brief If the PrimFunc containing the Stmt is added to IRModule,
+       this is an optional index to indicate the index within
+       "Constants" attribute, that is a Array<NDArray> of IRModule.
+   */
+  Optional<Integer> irmod_storage_idx;

Review comment:
       Per the RFC discussion, IIUC we agreed to store constants as attributes of the IRModule rather than storing the `data` in-class. Therefore, I would suggest that we use a field that refers to the IRModule attribute instead. For example,
   
   ```
   Optional<String> mod_attr_key;  //  NullOpt means the constant is not bound to the IRModule yet
   ```

##########
File path: include/tvm/tir/stmt.h
##########
@@ -585,6 +585,118 @@ class Allocate : public Stmt {
   TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
 };
 
+/*!

Review comment:
       Note that LinkedParam is not a statement, so this header may not be the best place for it. Given that we are moving constants to the IRModule (which sits above the TIR level), it might make more sense to move it to `tvm/ir/module.h`.
   
   I'm not 100% sure. @tqchen @areusch please kindly share your ideas

##########
File path: include/tvm/tir/stmt.h
##########
@@ -585,6 +585,118 @@ class Allocate : public Stmt {
   TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
 };
 
+/*!
+ * \brief Describes one parameter that should be linked into the generated module.
+ *
+ * When parameters are to be linked in with generated code (i.e. on target_host-compatible
+ * backends), Relay attaches instances of this object to a global TIR function. Code-generators
+ * use the information contained in this node to include the parameter data in the generated
+ * module.
+ */
+class LinkedParamNode : public Object {
+ public:
+  /*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
+  int64_t id;
+
+  /*! \brief Parameter data which should get linked into the final module. */
+  ::tvm::runtime::NDArray param;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("id", &id);
+    v->Visit("param", &param);
+  }
+
+  static constexpr const char* _type_key = "tir.LinkedParam";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
+};
+
+/*!
+ * \brief Managed reference to LinkedParamNode.
+ */
+class LinkedParam : public ObjectRef {
+ public:
+  TVM_DLL LinkedParam(int64_t id, ::tvm::runtime::NDArray param);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
+};
+
+/*!
+ * \brief Allocate a buffer that can be used in body.
+ */
+class AllocateConstNode : public StmtNode {
+ public:
+  /*! \brief The buffer variable. */
+  Var buffer_var;
+  /*! \brief The optional data associated to the constant.
+   */
+  Optional<runtime::NDArray> data;
+  /*! \brief If the PrimFunc containing the Stmt is added to IRModule,
+       this is an optional index to indicate the index within
+       "Constants" attribute, that is a Array<NDArray> of IRModule.
+   */
+  Optional<Integer> irmod_storage_idx;
+  /*! \brief The type of the buffer. */
+  DataType dtype;
+  /*! \brief The extents of the buffer. */
+  Array<PrimExpr> extents;
+  /*! \brief The body to be executed. */
+  Stmt body;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("buffer_var", &buffer_var);
+    v->Visit("dtype", &dtype);
+    v->Visit("extents", &extents);
+    v->Visit("body", &body);
+    v->Visit("span", &span);
+  }
+
+  bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const {
+    return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
+           equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce.DefHash(buffer_var);
+    hash_reduce(dtype);
+    hash_reduce(extents);
+    hash_reduce(body);
+    hash_reduce(data);
+  }
+
+  /*!
+   * \brief If the buffer size is constant, return the size.
+   *        Otherwise return 0.
+   * \return The result.
+   */
+  int32_t constant_allocation_size() const { return constant_allocation_size(extents); }

Review comment:
       Note: in TVM, except for accessor methods, we prefer to use CamelCase.
   
   BTW, do you prefer int32_t or int64_t? I don't have a strong opinion.

##########
File path: include/tvm/ir/module.h
##########
@@ -349,6 +350,9 @@ class IRModuleNode : public Object {
    */
   std::unordered_set<String> import_set_;
   friend class IRModule;
+
+ public:
+  void ExtractPrimFuncConstants(tir::PrimFunc func);

Review comment:
       IMHO this looks like a pass instead of a method of IRModule.
   
   Update: After a second look, I don't believe we need this method, even as a pass. Instead, we should make sure that the constants are always stored in the IRModule attributes from the very beginning.

##########
File path: include/tvm/tir/transform.h
##########
@@ -25,10 +25,12 @@
 #define TVM_TIR_TRANSFORM_H_
 
 #include <tvm/ir/transform.h>
+#include <tvm/relay/expr.h>

Review comment:
       Note that by design, TIR doesn't depend on Relay (otherwise there might be cyclic dependency).

##########
File path: include/tvm/ir/module.h
##########
@@ -357,6 +361,8 @@ class IRModuleNode : public Object {
  */
 class IRModule : public ObjectRef {
  public:
+  static constexpr const char* _constants_attrs_key = "Constants";

Review comment:
       Usually we put TIR attributes outside a class. Would you mind moving it outside?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org