You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/08/25 06:29:00 UTC

[tvm] branch main updated: [MetaSchedule] Add software pipeline in CUDA tensor core auto tensorization (#12544)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9aac161a46 [MetaSchedule] Add software pipeline in CUDA tensor core auto tensorization (#12544)
9aac161a46 is described below

commit 9aac161a46e5aca4c433ccb901c1bb84e6c8bd0c
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed Aug 24 23:28:54 2022 -0700

    [MetaSchedule] Add software pipeline in CUDA tensor core auto tensorization (#12544)
    
    cc @Hzfengsy @junrushao @junrushao1994 @masahi @spectrometerHBH
---
 include/tvm/meta_schedule/schedule_rule.h          |   3 +-
 python/tvm/meta_schedule/default_config.py         |   1 +
 .../schedule_rule/multi_level_tiling.py            |   4 +
 python/tvm/meta_schedule/testing/schedule_rule.py  |   2 +
 .../multi_level_tiling_tensor_core.cc              | 122 +++++++++++++++++++-
 ...ta_schedule_schedule_rule_multi_level_tiling.py | 125 +++++++++++++++++++++
 6 files changed, 255 insertions(+), 2 deletions(-)

diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index b5f4a17b69..2da441c95e 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -190,13 +190,14 @@ class ScheduleRule : public runtime::ObjectRef {
    * NullOpt means disable vectorization
    * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
    * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
+   * \param use_software_pipeline Whether to use the software pipeline.
    * \return The schedule rule created
    */
   TVM_DLL static ScheduleRule MultiLevelTilingTensorCore(
       Array<Map<String, String>> intrin_groups, String structure,
       Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor,
       Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
-      Optional<Map<String, ObjectRef>> reuse_write);
+      Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);
 
   /*!
    * \brief Create a rule: add-rfactor to some blocks if needed
diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py
index 105b3467de..0f1f7d3c2c 100644
--- a/python/tvm/meta_schedule/default_config.py
+++ b/python/tvm/meta_schedule/default_config.py
@@ -381,6 +381,7 @@ class _DefaultCUDATensorCore:
                     levels=[2],
                     scope="shared",
                 ),
+                use_software_pipeline=False,
             ),
             *_DefaultCUDA.schedule_rules(),
         ]
diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
index a728a91eb7..6703bc5716 100644
--- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
+++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -161,6 +161,8 @@ class MultiLevelTilingTensorCore(ScheduleRule):
         Data reuse configuration for reading. None means no reuse.
     reuse_write : Optional[ReuseType]
         Data reuse configuration for writing. None means no reuse.
+    use_software_pipeline : bool
+        Whether to use the software pipeline.
     """
 
     def __init__(
@@ -172,6 +174,7 @@ class MultiLevelTilingTensorCore(ScheduleRule):
         vector_load_lens: Optional[List[int]] = None,
         reuse_read: Optional[ReuseType] = None,
         reuse_write: Optional[ReuseType] = None,
+        use_software_pipeline: bool = False,
     ) -> None:
         self.__init_handle_by_constructor__(
             _ffi_api.ScheduleRuleMultiLevelTilingTensorCore,  # type: ignore # pylint: disable=no-member
@@ -182,4 +185,5 @@ class MultiLevelTilingTensorCore(ScheduleRule):
             vector_load_lens,
             reuse_read.as_dict() if reuse_read is not None else None,
             reuse_write.as_dict() if reuse_write is not None else None,
+            use_software_pipeline,
         )
diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index 441ca930f8..46df4b95ce 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -119,6 +119,7 @@ def multi_level_tiling_tensor_core(
     in_dtype: Union[str, List[str]] = "float16",
     out_dtype: Union[str, List[str]] = "float32",
     trans_b: Union[bool, List[bool]] = False,
+    use_software_pipeline: bool = False,
 ) -> ScheduleRule:
     """Default schedule rules for with multi-level tiling reuse for tensor core"""
     assert write_reuse_scope in ["shared", "global"]
@@ -154,6 +155,7 @@ def multi_level_tiling_tensor_core(
                 levels=[2],
                 scope=write_reuse_scope,
             ),
+            use_software_pipeline=use_software_pipeline,
         )
     raise NotImplementedError(f"{target.kind.name} is not supported")
 
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 7a3ec513db..49704fb66b 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -128,6 +128,8 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
   inline std::vector<State> AddReadReuseTensorCore(TensorCoreState state) const;
   // Subrule: Add tensorized store
   inline std::vector<State> AddWriteReuseTensorCore(TensorCoreState state) const;
+  // Subrule: Add software pipeline
+  inline std::vector<State> AddSoftwarePipeline(TensorCoreState state) const;
 
   // Override ApplySubRules to apply tensorization-specific sub-rules
   std::vector<State> ApplySubRules(std::vector<State> states) final;
@@ -155,6 +157,8 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
  public:
   /*! \brief The candidate tensor core intrin groups to apply */
   std::vector<TensorCoreIntrinGroup> intrin_groups;
+  /*! \brief Whether to use the software pipeline */
+  bool use_software_pipeline = false;
   static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingTensorCore";
   TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingTensorCoreNode, MultiLevelTilingNode);
 
@@ -222,6 +226,9 @@ std::vector<State> MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<Sta
   states = SubRule(std::move(states), [&](State state) {
     return AddReadReuseTensorCore(Downcast<TensorCoreState>(state));
   });
+  states = SubRule(std::move(states), [&](State state) {
+    return AddSoftwarePipeline(Downcast<TensorCoreState>(state));
+  });
   return states;
 }
 
@@ -286,6 +293,117 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
   return {state};
 }
 
+std::vector<State> MultiLevelTilingTensorCoreNode::AddSoftwarePipeline(
+    TensorCoreState state) const {
+  if (!use_software_pipeline) {
+    return {state};
+  }
+  // Pipelining annotates the first two reduction tile levels below, so at least two are needed.
+  if (r_indices_.size() < 2) {
+    return {state};
+  }
+
+  Schedule& sch = state->sch;
+  // Check reduction length after blockize.
+  int64_t reduction_length = 1;
+  for (int r_index : r_indices_) {
+    const Array<LoopRV>& tiles = state->tiles[r_index];
+    for (const LoopRV& tile : tiles) {
+      const auto* extent = sch->Get(tile)->extent.as<IntImmNode>();
+      ICHECK(extent != nullptr) << "Dynamic extent is not supported.";
+      reduction_length *= extent->value;
+    }
+  }
+  if (reduction_length <= 1) {
+    return {state};
+  }
+
+  // Add local stage and double buffering
+  for (int i = 0; i < 2; ++i) {
+    const tir::BlockRV cache_read = state->read_reuse.at(i);
+    sch->Annotate(cache_read, tir::attr::manifest_shared_memory_local_stage, Bool(true));
+    sch->Annotate(cache_read, tir::attr::double_buffer_scope, Integer(0));
+  }
+
+  // Add annotations of software pipeline
+  //
+  // Before pipelining, the original loop can be expressed as the pseudo code below:
+  //
+  // for k0 in [0, K0):
+  //   load tile k0 to registers
+  //   load tile k0 from registers to shared memory
+  //
+  //   for k1 in [0, K1):
+  //     load fragment k1 of tile k0
+  //     compute matmul with fragment k1
+  //
+
+  // Inner software pipeline: prefetch the tensor core fragment one iteration ahead.
+  // The following annotation for the inner loop is equivalent to the pseudo code below:
+  //
+  // Pipelined inner loop:
+  //
+  // prologue:
+  //   load fragment 0
+  // body:
+  //   for k1 in [0, K1 - 1):
+  //     load fragment k1 + 1
+  //     compute matmul with fragment k1
+  // epilogue:
+  //   compute matmul with fragment K1 - 1
+  //
+  sch->Annotate(state->tiles[r_indices_[1]].back(), tir::attr::software_pipeline_stage,
+                Array<Integer>{0, 0, 1});
+  sch->Annotate(state->tiles[r_indices_[1]].back(), tir::attr::software_pipeline_order,
+                Array<Integer>{0, 1, 2});
+  // Outer software pipeline: Interleave the outer loop with the (pipelined) inner loop.
+  // The prefetching stage of the inner pipeline runs one iteration ahead in the outer loop.
+  // The following annotation for the outer loop is equivalent to the pseudo code below:
+  //
+  // Pipelined outer loop with nested inner pipeline:
+  //
+  // prologue:
+  //   load tile 0 to registers
+  //   load tile 0 from registers to shared memory
+  //
+  //   // prologue of the inner pipeline
+  //   load fragment 0 of tile 0
+  //
+  // body:
+  //   for k0 in [0, K0 - 1):
+  //     load tile k0 + 1 to registers
+  //
+  //     // body of the inner pipeline
+  //     for k1 in [0, K1 - 1):
+  //       load fragment k1 + 1 of tile k0
+  //       compute matmul with fragment k1 of tile k0
+  //
+  //     load tile k0 + 1 from registers to shared memory
+  //
+  //     // prologue of the inner pipeline
+  //     load fragment 0 of tile k0 + 1
+  //
+  //     // epilogue of the inner pipeline
+  //     compute matmul with fragment K1 - 1 of tile k0
+  //
+  // epilogue:
+  //
+  //   // body of the inner pipeline
+  //   for k1 in [0, K1 - 1):
+  //     load fragment k1 + 1 of tile K0 - 1
+  //     compute matmul with fragment k1 of tile K0 - 1
+  //
+  //   // epilogue of the inner pipeline
+  //   compute matmul with fragment K1 - 1 of tile K0 - 1
+  //
+  sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_stage,
+                Array<Integer>{0, 0, 0, 0, 0, 1, 1});
+  sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_order,
+                Array<Integer>{0, 3, 1, 4, 5, 2, 6});
+
+  return {state};
+}
+
 Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
     TensorCoreStateNode* state, const String& intrin_name) const {
   BlockRV block_rv = state->block_rv;
@@ -418,7 +536,8 @@ inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorizat
 ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
     Array<Map<String, String>> intrin_groups, String structure, Optional<Array<String>> tile_binds,
     Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
-    Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
+    Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write,
+    bool use_software_pipeline) {
   auto node = MultiLevelTilingInitCommon<MultiLevelTilingTensorCoreNode>(
       structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
 
@@ -426,6 +545,7 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
   for (const auto& intrin_group_config : intrin_groups) {
     node->intrin_groups.emplace_back(TensorCoreIntrinGroup::FromConfig(intrin_group_config));
   }
+  node->use_software_pipeline = use_software_pipeline;
   return ScheduleRule(node);
 }
 
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index 4da870e455..87159fcb31 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -709,6 +709,131 @@ sch.reverse_compute_inline(block=b1)""".split(
     check_trace(spaces, expected)
 
 
+def test_cuda_tensor_core_software_pipeline_matmul_relu():
+    m = n = k = 128
+    target = Target("cuda", host="llvm")
+    ctx = _create_context(
+        create_prim_func(
+            te_workload.matmul_relu(
+                n=n,
+                m=m,
+                k=k,
+                in_dtype="float16",
+                out_dtype="float32",
+            )
+        ),
+        target=target,
+        rule=[
+            multi_level_tiling_tensor_core(
+                target=target, write_reuse_scope="shared", use_software_pipeline=True
+            ),
+            auto_inline(target),
+        ],
+    )
+    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+    assert len(spaces) == 1
+
+    expected = [
+        """b0 = sch.get_block(name="C", func_name="main")
+b1 = sch.get_block(name="compute", func_name="main")
+sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
+b2 = sch.reindex(block=b0, buffer=("write", 0))
+b3 = sch.reindex(block=b0, buffer=("read", 0))
+b4 = sch.reindex(block=b0, buffer=("read", 1))
+sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ))
+sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ))
+sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
+sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b4, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, ))
+l5, l6, l7 = sch.get_loops(block=b0)
+l8, l9 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True)
+l10, l11 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True)
+l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
+l14, l15, l16, l17, l18, l19 = sch.get_loops(block=b0)
+sch.reorder(l16, l18, l13, l11, l9)
+b20 = sch.blockize(loop=l13)
+sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32")
+sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32")
+sch.annotate(block_or_loop=b20, ann_key="warp_execution", ann_val=1)
+l21, l22, l23 = sch.get_loops(block=b20)
+v24, v25, v26, v27, v28 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4)
+l29, l30, l31, l32, l33 = sch.split(loop=l21, factors=[v24, v25, v26, v27, v28], preserve_unit_iters=True)
+v34, v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l22, n=5, max_innermost_factor=4)
+l39, l40, l41, l42, l43 = sch.split(loop=l22, factors=[v34, v35, v36, v37, v38], preserve_unit_iters=True)
+v44, v45, v46 = sch.sample_perfect_tile(loop=l23, n=3, max_innermost_factor=4)
+l47, l48, l49 = sch.split(loop=l23, factors=[v44, v45, v46], preserve_unit_iters=True)
+sch.reorder(l29, l39, l30, l40, l31, l41, l47, l48, l32, l42, l49, l33, l43)
+l50 = sch.fuse(l29, l39, preserve_unit_iters=True)
+sch.bind(loop=l50, thread_axis="blockIdx.y")
+l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
+sch.bind(loop=l51, thread_axis="blockIdx.x")
+l52 = sch.fuse(l31, l41, preserve_unit_iters=True)
+sch.bind(loop=l52, thread_axis="threadIdx.y")
+b53 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="shared")
+sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True)
+b54 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="wmma.accumulator")
+sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True)
+v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v55)
+sch.reverse_compute_inline(block=b2)
+l56, l57, l58, l59, l60 = sch.get_loops(block=b54)
+l61, l62 = sch.split(loop=l60, factors=[None, 16], preserve_unit_iters=True)
+l63, l64 = sch.split(loop=l59, factors=[None, 16], preserve_unit_iters=True)
+l65, l66, l67, l68, l69, l70, l71 = sch.get_loops(block=b54)
+sch.reorder(l70, l64, l62)
+b72 = sch.blockize(loop=l64)
+sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared")
+b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared")
+sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True)
+l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73)
+l80 = sch.fuse(l78, l79, preserve_unit_iters=True)
+v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81)
+b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared")
+sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True)
+l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82)
+l89 = sch.fuse(l87, l88, preserve_unit_iters=True)
+v90 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90)
+b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a")
+sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True)
+l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b91)
+l99, l100 = sch.split(loop=l98, factors=[None, 16], preserve_unit_iters=True)
+l101, l102 = sch.split(loop=l97, factors=[None, 16], preserve_unit_iters=True)
+l103, l104, l105, l106, l107, l108, l109, l110, l111 = sch.get_loops(block=b91)
+sch.reorder(l110, l102, l100)
+b112 = sch.blockize(loop=l102)
+sch.annotate(block_or_loop=b112, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
+b113 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="wmma.matrix_b")
+sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True)
+l114, l115, l116, l117, l118, l119, l120 = sch.get_loops(block=b113)
+l121, l122 = sch.split(loop=l120, factors=[None, 16], preserve_unit_iters=True)
+l123, l124 = sch.split(loop=l119, factors=[None, 16], preserve_unit_iters=True)
+l125, l126, l127, l128, l129, l130, l131, l132, l133 = sch.get_loops(block=b113)
+sch.reorder(l132, l124, l122)
+b134 = sch.blockize(loop=l124)
+sch.annotate(block_or_loop=b134, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b")
+sch.compute_inline(block=b3)
+sch.compute_inline(block=b4)
+sch.storage_align(block=b73, buffer_index=0, axis=-2, factor=32, offset=8)
+sch.storage_align(block=b82, buffer_index=0, axis=-2, factor=32, offset=8)
+sch.annotate(block_or_loop=b73, ann_key="tir.manifest_shared_memory_local_stage", ann_val=1)
+sch.annotate(block_or_loop=b73, ann_key="double_buffer_scope", ann_val=0)
+sch.annotate(block_or_loop=b82, ann_key="tir.manifest_shared_memory_local_stage", ann_val=1)
+sch.annotate(block_or_loop=b82, ann_key="double_buffer_scope", ann_val=0)
+sch.annotate(block_or_loop=l48, ann_key="software_pipeline_stage", ann_val=[0, 0, 1])
+sch.annotate(block_or_loop=l48, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
+sch.annotate(block_or_loop=l47, ann_key="software_pipeline_stage", ann_val=[0, 0, 0, 0, 0, 1, 1])
+sch.annotate(block_or_loop=l47, ann_key="software_pipeline_order", ann_val=[0, 3, 1, 4, 5, 2, 6])
+sch.reverse_compute_inline(block=b1)""".split(
+            "\n"
+        )
+    ]
+    check_trace(spaces, expected)
+
+
 def test_cuda_tensor_core_matmul_relu_global():
     m = n = k = 128
     target = Target("cuda", host="llvm")