You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/01/08 06:38:40 UTC

[GitHub] [tvm] Hzfengsy commented on a change in pull request #9871: [TIR][Schedule] Blockize and Tensorize

Hzfengsy commented on a change in pull request #9871:
URL: https://github.com/apache/tvm/pull/9871#discussion_r780634559



##########
File path: include/tvm/tir/function.h
##########
@@ -187,6 +187,56 @@ class LinkedParam : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
 };
 
+/*!
+ * \brief Tensor intrinsics for tensorization
+ */
+class TensorIntrinNode : public Object {
+ public:
+  /*! \brief The function to describe the computation. */
+  PrimFunc description;
+  /*! \brief The intrinsic function for lower-level implementation. */
+  PrimFunc implementation;

Review comment:
       Vote for `desc` and `impl`

##########
File path: include/tvm/tir/schedule/schedule.h
##########
@@ -465,6 +465,25 @@ class ScheduleNode : public runtime::Object {
    */
   virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0;
   /******** Schedule: Blockize & Tensorize ********/
+  /*!
+   * \brief Convert the subtree rooted at a specific loop into a block.
+   * \param loop_rv the root of the subtree
+   * \return the new block
+   */
+  virtual BlockRV Blockize(const LoopRV& loop_rv) = 0;
+  /*!
+   * \brief Tensorize the computation enclosed by loop with tensor_intrin
+   * \param loop_rv the loop to be tensorized
+   * \param intrin the tensor intrinsic
+   */
+  virtual void Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) = 0;
+  /*!
+   * \brief Tensorize the computation enclosed by loop with tensor_intrin
+   * \param loop_rv The loop to be tensorized
+   * \param intrin_name Name of the tensor intrinsic
+   */
+  virtual void Tensorize(const LoopRV& loop_rv, const String& intrin_name) = 0;

Review comment:
       Should we also add a Tensorize API that takes a block, i.e. one that tensorizes a block we have already blockized manually?

##########
File path: src/tir/schedule/analysis.h
##########
@@ -442,6 +442,79 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S
 bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
                          const StmtSRef& loop_sref, bool preserve_unit_loops);
 
+/******** Tensorization Related ********/
+
+using ExprComparator = ExprFunctor<bool(const PrimExpr& n, const PrimExpr& other)>;
+using StmtComparator = StmtFunctor<bool(const Stmt& n, const Stmt& other)>;
+
+/*! \brief Deep comparison to check if two IR ASTs are equivalent */
+class TensorizeComparator : public ExprComparator, public StmtComparator {
+ public:
+  /*!
+   * \brief Constructor of TensorizeComparator
+   * \param assert_mode Whether to raise an error if the two IR ASTs do not match.
+   */
+  explicit TensorizeComparator(bool assert_mode = true) : assert_mode_(assert_mode) {}
+
+  // Map from rhs buffer to lhs buffer
+  std::unordered_map<Buffer, Buffer, ObjectHash, ObjectEqual> rhs_buffer_map_;

Review comment:
       Keep them private if they are internal vars

##########
File path: src/tir/schedule/analysis.h
##########
@@ -442,6 +442,79 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S
 bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
                          const StmtSRef& loop_sref, bool preserve_unit_loops);
 
+/******** Tensorization Related ********/
+
+using ExprComparator = ExprFunctor<bool(const PrimExpr& n, const PrimExpr& other)>;
+using StmtComparator = StmtFunctor<bool(const Stmt& n, const Stmt& other)>;
+
+/*! \brief Deep comparison to check if two IR ASTs are equivalent */
+class TensorizeComparator : public ExprComparator, public StmtComparator {
+ public:
+  /*!
+   * \brief Constructor of TensorizeComparator
+   * \param assert_mode Whether to raise an error if the two IR ASTs do not match.
+   */
+  explicit TensorizeComparator(bool assert_mode = true) : assert_mode_(assert_mode) {}
+
+  // Map from rhs buffer to lhs buffer
+  std::unordered_map<Buffer, Buffer, ObjectHash, ObjectEqual> rhs_buffer_map_;
+  // Buffer indices mapping
+  std::unordered_map<Buffer, std::vector<PrimExpr>, ObjectPtrHash, ObjectPtrEqual> buffer_indices_;
+  std::vector<IterVar> extra_block_vars_;
+  // variable remap if any
+  std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_;
+
+  bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override;
+  bool VisitStmt(const Stmt& n, const Stmt& other) override;
+
+  bool VisitStmt_(const ForNode* op, const Stmt& other) override;
+  bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BlockNode* op, const Stmt& other) override;

Review comment:
       Yes, only some of the node types are supported for now. We may add the rest when needed.

##########
File path: src/tir/schedule/analysis.h
##########
@@ -442,6 +442,79 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S
 bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
                          const StmtSRef& loop_sref, bool preserve_unit_loops);
 
+/******** Tensorization Related ********/
+
+using ExprComparator = ExprFunctor<bool(const PrimExpr& n, const PrimExpr& other)>;
+using StmtComparator = StmtFunctor<bool(const Stmt& n, const Stmt& other)>;
+
+/*! \brief Deep comparison to check if two IR ASTs are equivalent */
+class TensorizeComparator : public ExprComparator, public StmtComparator {
+ public:
+  /*!
+   * \brief Constructor of TensorizeComparator
+   * \param assert_mode Whether to raise an error if the two IR ASTs do not match.
+   */
+  explicit TensorizeComparator(bool assert_mode = true) : assert_mode_(assert_mode) {}
+
+  // Map from rhs buffer to lhs buffer
+  std::unordered_map<Buffer, Buffer, ObjectHash, ObjectEqual> rhs_buffer_map_;
+  // Buffer indices mapping
+  std::unordered_map<Buffer, std::vector<PrimExpr>, ObjectPtrHash, ObjectPtrEqual> buffer_indices_;
+  std::vector<IterVar> extra_block_vars_;
+  // variable remap if any
+  std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_;
+
+  bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override;
+  bool VisitStmt(const Stmt& n, const Stmt& other) override;
+
+  bool VisitStmt_(const ForNode* op, const Stmt& other) override;
+  bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BlockRealizeNode* op, const Stmt& other) override;
+  bool VisitStmt_(const BlockNode* op, const Stmt& other) override;
+
+  bool VisitExpr_(const AddNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const SubNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const MulNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const DivNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const ModNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const EQNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const NENode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const LTNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const LENode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const GTNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const GENode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const AndNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const OrNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const MinNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const MaxNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const FloorDivNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const FloorModNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const CastNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const VarNode* op, const PrimExpr& other) override;
+  bool VisitExpr_(const BufferLoadNode* op, const PrimExpr& other) override;
+
+  bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs);
+  virtual bool CompareBuffer(const Buffer& lhs, const Buffer& rhs);
+  bool CompareBufferRegion(const BufferRegion& lhs, const BufferRegion& rhs);
+  bool CompareAnnotation(const std::pair<String, ObjectRef>& lhs,
+                         const std::pair<String, ObjectRef>& rhs);
+  bool CompareAnnotationMap(const Map<String, ObjectRef>& lhs, const Map<String, ObjectRef>& rhs);
+  template <typename T>
+  bool CompareBufferAccess(const T* lhs, const T* rhs);
+  template <typename T, typename F>
+  bool CompareArray(const Array<T>& lhs, const Array<T>& rhs, F cmp);
+  bool CompareRange(const Range& lhs, const Range& rhs);
+  bool CompareType(const DataType& lhs, const DataType& rhs);
+
+ protected:

Review comment:
       Why `protected`?

##########
File path: src/tir/schedule/analysis/analysis.cc
##########
@@ -1376,5 +1376,333 @@ void CheckStorageScope(const ScheduleState& self, String storage_scope) {
   }
 }
 
+/******** Tensorize Comparator ********/
+
+bool TensorizeComparator::VisitStmt(const Stmt& n, const Stmt& other) {
+  if (n.same_as(other)) return true;
+  if (n->type_index() != other->type_index()) return false;
+  bool equal = StmtComparator::VisitStmt(n, other);
+  if (!equal && assert_mode_)
+    LOG(FATAL) << "Stmts are not matching between:\n" << n << "\nand\n" << other;

Review comment:
       Can we use `ScheduleError` here? Since it has a better error renderer

##########
File path: src/tir/ir/function.cc
##########
@@ -64,6 +64,67 @@ FuncType PrimFuncNode::func_type_annotation() const {
 
 TVM_REGISTER_NODE_TYPE(PrimFuncNode);
 
+class TensorIntrinManager {
+ public:
+  Map<String, tir::TensorIntrin> reg;
+
+  static TensorIntrinManager* Global() {
+    static TensorIntrinManager* inst = new TensorIntrinManager();
+    return inst;
+  }
+};
+
+TensorIntrin::TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func) {
+  // check the number of func var is equal
+  CHECK_EQ(desc_func->params.size(), intrin_func->params.size());
+  CHECK_EQ(desc_func->buffer_map.size(), intrin_func->buffer_map.size());
+
+  // check both functions' bodies are directly block
+  const auto* desc_realize =
+      Downcast<BlockRealize>(desc_func->body)->block->body.as<BlockRealizeNode>();
+  const auto* intrin_realize =
+      Downcast<BlockRealize>(intrin_func->body)->block->body.as<BlockRealizeNode>();
+  CHECK(desc_realize != nullptr) << "description function's body expect a directly block";
+  CHECK(intrin_realize != nullptr) << "intrinsic function's body expect a directly block";
+
+  const Block& desc_block = desc_realize->block;
+  const Block& intrin_block = intrin_realize->block;
+
+  // check block var number and iter type
+  CHECK_EQ(desc_block->iter_vars.size(), intrin_block->iter_vars.size())
+      << "Two blocks should have the same number of block vars";
+  for (size_t i = 0; i < desc_block->iter_vars.size(); i++) {
+    const IterVar& desc_var = desc_block->iter_vars[i];
+    const IterVar& intrin_var = intrin_block->iter_vars[i];
+    CHECK(desc_var->iter_type == intrin_var->iter_type)
+        << "Block iter_type mismatch between " << desc_var->iter_type << " and "
+        << intrin_var->iter_type;
+  }
+
+  auto n = make_object<TensorIntrinNode>();

Review comment:
       The type `TensorIntrinNode` is already spelled out in the `make_object` call, so I personally think `auto` is OK here.

##########
File path: src/tir/ir/function.cc
##########
@@ -64,6 +64,67 @@ FuncType PrimFuncNode::func_type_annotation() const {
 
 TVM_REGISTER_NODE_TYPE(PrimFuncNode);
 
+class TensorIntrinManager {
+ public:
+  Map<String, tir::TensorIntrin> reg;
+
+  static TensorIntrinManager* Global() {
+    static TensorIntrinManager* inst = new TensorIntrinManager();
+    return inst;
+  }
+};
+
+TensorIntrin::TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func) {
+  // check the number of func var is equal
+  CHECK_EQ(desc_func->params.size(), intrin_func->params.size());

Review comment:
       Do we need to require that all params are Buffers (to be specific, buffer handles)?

##########
File path: src/tir/schedule/analysis.h
##########
@@ -442,6 +442,79 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S
 bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
                          const StmtSRef& loop_sref, bool preserve_unit_loops);
 
+/******** Tensorization Related ********/
+
+using ExprComparator = ExprFunctor<bool(const PrimExpr& n, const PrimExpr& other)>;

Review comment:
       +1




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org