You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/07/12 03:20:51 UTC

[GitHub] [tvm] masahi commented on a diff in pull request #12059: [MetaSchedule] Add MultiLevelTilingTensorCore rule for auto-tensorization on CUDA

masahi commented on code in PR #12059:
URL: https://github.com/apache/tvm/pull/12059#discussion_r918501599


##########
python/tvm/meta_schedule/testing/schedule_rule.py:
##########
@@ -110,6 +112,38 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
     raise NotImplementedError(f"{target.kind.name} is not supported")
 
 
+def multi_level_tiling_tensor_core(target: Target, scope="shared") -> ScheduleRule:
+    """Default schedule rules with multi-level tiling and reuse for tensor core"""
+    assert scope in ["shared", "global"]
+    if target.kind.name == "cuda":
+        return MultiLevelTilingTensorCore(
+            intrin_group={
+                "init": tensor_intrin.WMMA_FILL_16x16x16_F32_INTRIN,
+                "load_a": tensor_intrin.WMMA_LOAD_16x16x16_F16_A_INTRIN,
+                "load_b": tensor_intrin.WMMA_LOAD_16x16x16_F16_B_INTRIN,
+                "compute": tensor_intrin.WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
+                "store": tensor_intrin.WMMA_STORE_16x16x16_F32_SHARED_INTRIN
+                if scope == "shared"
+                else tensor_intrin.WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN,

Review Comment:
   Need to check dtype



##########
include/tvm/meta_schedule/schedule_rule.h:
##########
@@ -173,6 +173,32 @@ class ScheduleRule : public runtime::ObjectRef {
       Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
       Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
 
+  /*!
+   * \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
+   * intrinsics
+   * \param intrin_group A group of tensor core intrinsics. The map should contain the keys "init",
+   * "load_a", "load_b", "compute", and "store", which represent the tensor intrins for initialization,
+   * loading operand A, loading operand B, tensor core computation, and storing the result. The values
+   * of the map should be names of tensor intrinsics, which must be registered via
+   * TensorIntrin.register(...) beforehand
+   * \param structure The tiling structure. Recommended:
+   * - 'SSRSRS' on CPU
+   * - 'SSSRRSRS' on GPU
+   * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
+   * - NullOpt on CPU

Review Comment:
   CPU not relevant at L185 and L188



##########
include/tvm/meta_schedule/schedule_rule.h:
##########
@@ -173,6 +173,32 @@ class ScheduleRule : public runtime::ObjectRef {
       Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
       Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
 
+  /*!
+   * \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
+   * intrinsics
+   * \param intrin_group A group of tensor core intrinsics. The map should contain the keys "init",
+   * "load_a", "load_b", "compute", and "store", which represent the tensor intrins for initialization,
+   * loading operand A, loading operand B, tensor core computation, and storing the result. The values
+   * of the map should be names of tensor intrinsics, which must be registered via
+   * TensorIntrin.register(...) beforehand
+   * \param structure The tiling structure. Recommended:
+   * - 'SSRSRS' on CPU
+   * - 'SSSRRSRS' on GPU
+   * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
+   * - NullOpt on CPU

Review Comment:
   Or do we expect this class to be useful for Intel AMX? (CPU with matrix intrinsic)



##########
src/meta_schedule/postproc/rewrite_tensorize.cc:
##########
@@ -35,26 +35,24 @@ void CollectTensorizationJobs(
   tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) {
     if (const auto* block = obj.as<tir::BlockNode>()) {
       tir::StmtSRef block_sref = sch->GetSRef(block);
+      std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
       if (Optional<String> intrin_name =
               tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
-        std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
-        if (block_name.find("init") == std::string::npos) {
-          jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) {
-            try {
-              sch->Tensorize(block, intrin_name.value());
-            } catch (const std::exception& e) {
-              LOG(WARNING) << "Tensorize failed with error " << e.what();
-            }
-          });
-        } else if (vectorize_init_loop) {
-          jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) {
-            Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
-            ICHECK(child_blocks.size() == 1);
-            Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
-            ICHECK(init_loops.size() == 1);
-            sch->Vectorize(init_loops[0]);
-          });
-        }
+        jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) {
+          try {
+            sch->Tensorize(block, intrin_name.value());
+          } catch (const std::exception& e) {
+            LOG(WARNING) << "Tensorize failed with error " << e.what();
+          }
+        });
+      } else if (block_name.find("init") && vectorize_init_loop) {

Review Comment:
   Do we ever hit this condition after your change in [rewrite_reduction_block.cc](https://github.com/apache/tvm/pull/12059/files#diff-470a1ee8bb8d9ce151669661a93e24fe3b9df3094d431026d0d54fda5b2e2adf)?
   
   To vectorize init loop, should we switch to using `tir::attr::meta_schedule_auto_tensorize_init`?



##########
src/meta_schedule/schedule_rule/multi_level_tiling.h:
##########
@@ -112,6 +117,31 @@ class State : public ObjectRef {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
 };
 
+class TensorCoreStateNode : public StateNode {
+ public:
+  /*! \brief The Tensor Core reindex block A for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_A;
+  /*! \brief The Tensor Core reindex block B for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_B;
+  /*! \brief The Tensor Core reindex store block for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_store;
+
+  State Copy() const final;
+
+  static constexpr const char* _type_key = "meta_schedule.TensorCoreState";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorCoreStateNode, StateNode);
+};
+
+class TensorCoreState : public State {
+ public:
+  explicit TensorCoreState(tir::Schedule sch, tir::BlockRV block_rv,
+                           Array<Array<tir::LoopRV>> tiles = {});
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode);
+};

Review Comment:
   This class can be moved to `multi_level_tiling_tensor_core.cc` I think



##########
src/meta_schedule/schedule_rule/multi_level_tiling.h:
##########
@@ -112,6 +117,31 @@ class State : public ObjectRef {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
 };
 
+class TensorCoreStateNode : public StateNode {
+ public:
+  /*! \brief The Tensor Core reindex block A for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_A;
+  /*! \brief The Tensor Core reindex block B for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_B;
+  /*! \brief The Tensor Core reindex store block for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_store;
+
+  State Copy() const final;
+
+  static constexpr const char* _type_key = "meta_schedule.TensorCoreState";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorCoreStateNode, StateNode);
+};
+
+class TensorCoreState : public State {
+ public:
+  explicit TensorCoreState(tir::Schedule sch, tir::BlockRV block_rv,
+                           Array<Array<tir::LoopRV>> tiles = {});
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode);
+};
+
+struct AutoTensorizationState : public State {};

Review Comment:
   Unused?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org