You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/08/25 06:29:00 UTC
[tvm] branch main updated: [MetaSchedule] Add software pipeline in CUDA tensor core auto tensorization (#12544)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9aac161a46 [MetaSchedule] Add software pipeline in CUDA tensor core auto tensorization (#12544)
9aac161a46 is described below
commit 9aac161a46e5aca4c433ccb901c1bb84e6c8bd0c
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed Aug 24 23:28:54 2022 -0700
[MetaSchedule] Add software pipeline in CUDA tensor core auto tensorization (#12544)
cc @Hzfengsy @junrushao @junrushao1994 @masahi @spectrometerHBH
---
include/tvm/meta_schedule/schedule_rule.h | 3 +-
python/tvm/meta_schedule/default_config.py | 1 +
.../schedule_rule/multi_level_tiling.py | 4 +
python/tvm/meta_schedule/testing/schedule_rule.py | 2 +
.../multi_level_tiling_tensor_core.cc | 122 +++++++++++++++++++-
...ta_schedule_schedule_rule_multi_level_tiling.py | 125 +++++++++++++++++++++
6 files changed, 255 insertions(+), 2 deletions(-)
diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index b5f4a17b69..2da441c95e 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -190,13 +190,14 @@ class ScheduleRule : public runtime::ObjectRef {
* NullOpt means disable vectorization
* \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
* \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
+ * \param use_software_pipeline Whether to use the software pipeline.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTilingTensorCore(
Array<Map<String, String>> intrin_groups, String structure,
Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor,
Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
- Optional<Map<String, ObjectRef>> reuse_write);
+ Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);
/*!
* \brief Create a rule: add-rfactor to some blocks if needed
diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py
index 105b3467de..0f1f7d3c2c 100644
--- a/python/tvm/meta_schedule/default_config.py
+++ b/python/tvm/meta_schedule/default_config.py
@@ -381,6 +381,7 @@ class _DefaultCUDATensorCore:
levels=[2],
scope="shared",
),
+ use_software_pipeline=False,
),
*_DefaultCUDA.schedule_rules(),
]
diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
index a728a91eb7..6703bc5716 100644
--- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
+++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -161,6 +161,8 @@ class MultiLevelTilingTensorCore(ScheduleRule):
Data reuse configuration for reading. None means no reuse.
reuse_write : Optional[ReuseType]
Data reuse configuration for writing. None means no reuse.
+ use_software_pipeline : bool
+ Whether to use the software pipeline.
"""
def __init__(
@@ -172,6 +174,7 @@ class MultiLevelTilingTensorCore(ScheduleRule):
vector_load_lens: Optional[List[int]] = None,
reuse_read: Optional[ReuseType] = None,
reuse_write: Optional[ReuseType] = None,
+ use_software_pipeline: bool = False,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleMultiLevelTilingTensorCore, # type: ignore # pylint: disable=no-member
@@ -182,4 +185,5 @@ class MultiLevelTilingTensorCore(ScheduleRule):
vector_load_lens,
reuse_read.as_dict() if reuse_read is not None else None,
reuse_write.as_dict() if reuse_write is not None else None,
+ use_software_pipeline,
)
diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index 441ca930f8..46df4b95ce 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -119,6 +119,7 @@ def multi_level_tiling_tensor_core(
in_dtype: Union[str, List[str]] = "float16",
out_dtype: Union[str, List[str]] = "float32",
trans_b: Union[bool, List[bool]] = False,
+ use_software_pipeline: bool = False,
) -> ScheduleRule:
"""Default schedule rules for with multi-level tiling reuse for tensor core"""
assert write_reuse_scope in ["shared", "global"]
@@ -154,6 +155,7 @@ def multi_level_tiling_tensor_core(
levels=[2],
scope=write_reuse_scope,
),
+ use_software_pipeline=use_software_pipeline,
)
raise NotImplementedError(f"{target.kind.name} is not supported")
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 7a3ec513db..49704fb66b 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -128,6 +128,8 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
inline std::vector<State> AddReadReuseTensorCore(TensorCoreState state) const;
// Subrule: Add tensorized store
inline std::vector<State> AddWriteReuseTensorCore(TensorCoreState state) const;
+ // Subrule: Add software pipeline
+ inline std::vector<State> AddSoftwarePipeline(TensorCoreState state) const;
// Override ApplySubRules to apply tensorization-specific sub-rules
std::vector<State> ApplySubRules(std::vector<State> states) final;
@@ -155,6 +157,8 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
public:
/*! \brief The candidate tensor core intrin groups to apply */
std::vector<TensorCoreIntrinGroup> intrin_groups;
+ /*! \brief Whether to use software pipeline */
+ bool use_software_pipeline = false;
static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingTensorCore";
TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingTensorCoreNode, MultiLevelTilingNode);
@@ -222,6 +226,9 @@ std::vector<State> MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<Sta
states = SubRule(std::move(states), [&](State state) {
return AddReadReuseTensorCore(Downcast<TensorCoreState>(state));
});
+ states = SubRule(std::move(states), [&](State state) {
+ return AddSoftwarePipeline(Downcast<TensorCoreState>(state));
+ });
return states;
}
@@ -286,6 +293,117 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
return {state};
}
+std::vector<State> MultiLevelTilingTensorCoreNode::AddSoftwarePipeline(
+    TensorCoreState state) const {  // Subrule: add software pipeline annotations to `state`
+  if (!use_software_pipeline) {
+    return {state};
+  }
+  // Both r_indices_[0] and r_indices_[1] are annotated below, so at least two entries are needed.
+  if (r_indices_.size() < 2) {
+    return {state};
+  }
+
+  Schedule& sch = state->sch;
+  // Reduction length after blockize: product of the extents of all tiles listed in r_indices_.
+  int64_t reduction_length = 1;
+  for (int r_index : r_indices_) {
+    const Array<LoopRV>& tiles = state->tiles[r_index];
+    for (const LoopRV& tile : tiles) {
+      const auto* extent = sch->Get(tile)->extent.as<IntImmNode>();
+      ICHECK(extent != nullptr) << "Dynamic extent is not supported.";
+      reduction_length *= extent->value;
+    }
+  }
+  if (reduction_length <= 1) {
+    return {state};
+  }
+
+  // Add local stage and double buffering to entries 0 and 1 of state->read_reuse
+  for (int i = 0; i < 2; ++i) {
+    const tir::BlockRV cache_read = state->read_reuse.at(i);
+    sch->Annotate(cache_read, tir::attr::manifest_shared_memory_local_stage, Bool(true));
+    sch->Annotate(cache_read, tir::attr::double_buffer_scope, Integer(0));
+  }
+
+  // Add annotations of software pipeline
+  //
+  // Before pipelining, the original loop can be expressed as the pseudo code below:
+  //
+  // for k0 in [0, K0):
+  //   load tile k0 to registers
+  //   load tile k0 from registers to shared memory
+  //
+  //   for k1 in [0, K1):
+  //     load fragment k1 of tile k0
+  //     compute matmul with fragment k1
+  //
+
+  // Inner software pipeline: Prefetch to tensor core fragment by one iteration
+  // The following annotation for the inner loop is equivalent to the pseudo code below:
+  //
+  // Pipelined inner loop:
+  //
+  // prologue:
+  //   load fragment 0
+  // body:
+  //   for k1 in [0, K1 - 1):
+  //     load fragment k1 + 1
+  //     compute matmul with fragment k1
+  // epilogue:
+  //   compute matmul with fragment K1 - 1
+  //
+  sch->Annotate(state->tiles[r_indices_[1]].back(), tir::attr::software_pipeline_stage,
+                Array<Integer>{0, 0, 1});
+  sch->Annotate(state->tiles[r_indices_[1]].back(), tir::attr::software_pipeline_order,
+                Array<Integer>{0, 1, 2});
+  // Outer software pipeline: Interleave the outer loop with the (pipelined) inner loop.
+  // The prefetching stage of the inner pipeline is executed one iteration ahead in the outer loop.
+  // The following annotation for the outer loop is equivalent to the pseudo code below:
+  //
+  // Pipelined outer loop with nested inner pipeline:
+  //
+  // prologue:
+  //   load tile 0 to registers
+  //   load tile 0 from registers to shared memory
+  //
+  //   // prologue of the inner pipeline
+  //   load fragment 0 of tile 0
+  //
+  // body:
+  //   for k0 in [0, K0 - 1):
+  //     load tile k0 + 1 to registers
+  //
+  //     // body of the inner pipeline
+  //     for k1 in [0, K1 - 1):
+  //       load fragment k1 + 1 of tile k0
+  //       compute matmul with fragment k1 of tile k0
+  //
+  //     load tile k0 + 1 from registers to shared memory
+  //
+  //     // prologue of the inner pipeline
+  //     load fragment 0 of tile k0 + 1
+  //
+  //     // epilogue of the inner pipeline
+  //     compute matmul with fragment K1 - 1 of tile k0
+  //
+  // epilogue:
+  //
+  //   // body of the inner pipeline
+  //   for k1 in [0, K1 - 1):
+  //     load fragment k1 + 1 of tile K0 - 1
+  //     compute matmul with fragment k1 of tile K0 - 1
+  //
+  //   // epilogue of the inner pipeline
+  //   compute matmul with fragment K1 - 1 of tile K0 - 1
+  //
+  sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_stage,
+                Array<Integer>{0, 0, 0, 0, 0, 1, 1});
+  sch->Annotate(state->tiles[r_indices_[0]].back(), tir::attr::software_pipeline_order,
+                Array<Integer>{0, 3, 1, 4, 5, 2, 6});
+
+  return {state};
+}
+
Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
TensorCoreStateNode* state, const String& intrin_name) const {
BlockRV block_rv = state->block_rv;
@@ -418,7 +536,8 @@ inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorizat
ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
Array<Map<String, String>> intrin_groups, String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
- Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
+ Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write,
+ bool use_software_pipeline) {
auto node = MultiLevelTilingInitCommon<MultiLevelTilingTensorCoreNode>(
structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
@@ -426,6 +545,7 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
for (const auto& intrin_group_config : intrin_groups) {
node->intrin_groups.emplace_back(TensorCoreIntrinGroup::FromConfig(intrin_group_config));
}
+ node->use_software_pipeline = use_software_pipeline;
return ScheduleRule(node);
}
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index 4da870e455..87159fcb31 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -709,6 +709,131 @@ sch.reverse_compute_inline(block=b1)""".split(
check_trace(spaces, expected)
+def test_cuda_tensor_core_software_pipeline_matmul_relu():
+ m = n = k = 128
+ target = Target("cuda", host="llvm")
+ ctx = _create_context(
+ create_prim_func(
+ te_workload.matmul_relu(
+ n=n,
+ m=m,
+ k=k,
+ in_dtype="float16",
+ out_dtype="float32",
+ )
+ ),
+ target=target,
+ rule=[
+ multi_level_tiling_tensor_core(
+ target=target, write_reuse_scope="shared", use_software_pipeline=True
+ ),
+ auto_inline(target),
+ ],
+ )
+ spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+ assert len(spaces) == 1
+
+ expected = [
+ """b0 = sch.get_block(name="C", func_name="main")
+b1 = sch.get_block(name="compute", func_name="main")
+sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
+b2 = sch.reindex(block=b0, buffer=("write", 0))
+b3 = sch.reindex(block=b0, buffer=("read", 0))
+b4 = sch.reindex(block=b0, buffer=("read", 1))
+sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ))
+sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ))
+sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
+sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b4, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, ))
+l5, l6, l7 = sch.get_loops(block=b0)
+l8, l9 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True)
+l10, l11 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True)
+l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
+l14, l15, l16, l17, l18, l19 = sch.get_loops(block=b0)
+sch.reorder(l16, l18, l13, l11, l9)
+b20 = sch.blockize(loop=l13)
+sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32")
+sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32")
+sch.annotate(block_or_loop=b20, ann_key="warp_execution", ann_val=1)
+l21, l22, l23 = sch.get_loops(block=b20)
+v24, v25, v26, v27, v28 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4)
+l29, l30, l31, l32, l33 = sch.split(loop=l21, factors=[v24, v25, v26, v27, v28], preserve_unit_iters=True)
+v34, v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l22, n=5, max_innermost_factor=4)
+l39, l40, l41, l42, l43 = sch.split(loop=l22, factors=[v34, v35, v36, v37, v38], preserve_unit_iters=True)
+v44, v45, v46 = sch.sample_perfect_tile(loop=l23, n=3, max_innermost_factor=4)
+l47, l48, l49 = sch.split(loop=l23, factors=[v44, v45, v46], preserve_unit_iters=True)
+sch.reorder(l29, l39, l30, l40, l31, l41, l47, l48, l32, l42, l49, l33, l43)
+l50 = sch.fuse(l29, l39, preserve_unit_iters=True)
+sch.bind(loop=l50, thread_axis="blockIdx.y")
+l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
+sch.bind(loop=l51, thread_axis="blockIdx.x")
+l52 = sch.fuse(l31, l41, preserve_unit_iters=True)
+sch.bind(loop=l52, thread_axis="threadIdx.y")
+b53 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="shared")
+sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True)
+b54 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="wmma.accumulator")
+sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True)
+v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v55)
+sch.reverse_compute_inline(block=b2)
+l56, l57, l58, l59, l60 = sch.get_loops(block=b54)
+l61, l62 = sch.split(loop=l60, factors=[None, 16], preserve_unit_iters=True)
+l63, l64 = sch.split(loop=l59, factors=[None, 16], preserve_unit_iters=True)
+l65, l66, l67, l68, l69, l70, l71 = sch.get_loops(block=b54)
+sch.reorder(l70, l64, l62)
+b72 = sch.blockize(loop=l64)
+sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared")
+b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared")
+sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True)
+l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73)
+l80 = sch.fuse(l78, l79, preserve_unit_iters=True)
+v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81)
+b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared")
+sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True)
+l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82)
+l89 = sch.fuse(l87, l88, preserve_unit_iters=True)
+v90 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90)
+b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a")
+sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True)
+l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b91)
+l99, l100 = sch.split(loop=l98, factors=[None, 16], preserve_unit_iters=True)
+l101, l102 = sch.split(loop=l97, factors=[None, 16], preserve_unit_iters=True)
+l103, l104, l105, l106, l107, l108, l109, l110, l111 = sch.get_loops(block=b91)
+sch.reorder(l110, l102, l100)
+b112 = sch.blockize(loop=l102)
+sch.annotate(block_or_loop=b112, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
+b113 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="wmma.matrix_b")
+sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True)
+l114, l115, l116, l117, l118, l119, l120 = sch.get_loops(block=b113)
+l121, l122 = sch.split(loop=l120, factors=[None, 16], preserve_unit_iters=True)
+l123, l124 = sch.split(loop=l119, factors=[None, 16], preserve_unit_iters=True)
+l125, l126, l127, l128, l129, l130, l131, l132, l133 = sch.get_loops(block=b113)
+sch.reorder(l132, l124, l122)
+b134 = sch.blockize(loop=l124)
+sch.annotate(block_or_loop=b134, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b")
+sch.compute_inline(block=b3)
+sch.compute_inline(block=b4)
+sch.storage_align(block=b73, buffer_index=0, axis=-2, factor=32, offset=8)
+sch.storage_align(block=b82, buffer_index=0, axis=-2, factor=32, offset=8)
+sch.annotate(block_or_loop=b73, ann_key="tir.manifest_shared_memory_local_stage", ann_val=1)
+sch.annotate(block_or_loop=b73, ann_key="double_buffer_scope", ann_val=0)
+sch.annotate(block_or_loop=b82, ann_key="tir.manifest_shared_memory_local_stage", ann_val=1)
+sch.annotate(block_or_loop=b82, ann_key="double_buffer_scope", ann_val=0)
+sch.annotate(block_or_loop=l48, ann_key="software_pipeline_stage", ann_val=[0, 0, 1])
+sch.annotate(block_or_loop=l48, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
+sch.annotate(block_or_loop=l47, ann_key="software_pipeline_stage", ann_val=[0, 0, 0, 0, 0, 1, 1])
+sch.annotate(block_or_loop=l47, ann_key="software_pipeline_order", ann_val=[0, 3, 1, 4, 5, 2, 6])
+sch.reverse_compute_inline(block=b1)""".split(
+ "\n"
+ )
+ ]
+ check_trace(spaces, expected)
+
+
def test_cuda_tensor_core_matmul_relu_global():
m = n = k = 128
target = Target("cuda", host="llvm")