You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/13 08:06:21 UTC
[tvm] branch main updated: [MetaSchedule] Filter vector_load_lens based on buffer dtype (#12408)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 96cac7501d [MetaSchedule] Filter vector_load_lens based on buffer dtype (#12408)
96cac7501d is described below
commit 96cac7501dbce42af6c5bb6f75b51a4b70c4077a
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Sat Aug 13 01:06:16 2022 -0700
[MetaSchedule] Filter vector_load_lens based on buffer dtype (#12408)
This makes the same configuration generic enough to work across workloads with different data types.
---
python/tvm/meta_schedule/default_config.py | 4 +--
python/tvm/meta_schedule/testing/schedule_rule.py | 4 +--
.../schedule_rule/multi_level_tiling.cc | 29 +++++++++++++++++++---
...ta_schedule_schedule_rule_multi_level_tiling.py | 16 ++++++------
4 files changed, 38 insertions(+), 15 deletions(-)
diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py
index e27b6ad4b4..58f82a248b 100644
--- a/python/tvm/meta_schedule/default_config.py
+++ b/python/tvm/meta_schedule/default_config.py
@@ -293,7 +293,7 @@ class _DefaultCUDA:
structure="SSSRRSRS",
tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
max_innermost_factor=64,
- vector_load_lens=[1, 2, 3, 4],
+ vector_load_lens=[1, 2, 3, 4, 8, 16],
reuse_read=M.ReuseType(
req="must",
levels=[4],
@@ -374,7 +374,7 @@ class _DefaultCUDATensorCore:
structure="SSSRRSRS",
tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
max_innermost_factor=4,
- vector_load_lens=[1, 2, 3, 4],
+ vector_load_lens=[1, 2, 3, 4, 8, 16],
reuse_read=M.ReuseType(req="must", levels=[4], scope="shared"),
reuse_write=M.ReuseType(
req="must",
diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index ea748ddc05..f5a936f491 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -98,7 +98,7 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
structure="SSSRRSRS",
tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
max_innermost_factor=64,
- vector_load_lens=[1, 2, 3, 4],
+ vector_load_lens=[1, 2, 3, 4, 8, 16],
reuse_read=ReuseType(
req="must",
levels=[4],
@@ -141,7 +141,7 @@ def multi_level_tiling_tensor_core(
structure="SSSRRSRS",
tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
max_innermost_factor=4, # 64 // tensor intrin size
- vector_load_lens=[1, 2, 3, 4],
+ vector_load_lens=[1, 2, 3, 4, 8, 16],
reuse_read=ReuseType(
req="must",
levels=[4],
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 5f048dec00..76c3f5fa8b 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -261,11 +261,34 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
const tir::BlockRV& block) const {
- if (!vector_load_lens.empty()) {
- int n = vector_load_lens.size();
+ // Filter out invalid vector lanes according to the data type.
+ const tir::BlockNode* block_node = (*sch)->GetSRef(block)->StmtAs<tir::BlockNode>();
+ ICHECK_EQ(block_node->writes.size(), 1);
+ const runtime::DataType dtype = block_node->writes[0]->buffer->dtype;
+ std::function<bool(int)> f_filter = nullptr;
+ if (dtype == runtime::DataType::Float(32)) {
+ f_filter = [&](int vector_len) { return vector_len <= 4; };
+ } else if (dtype == runtime::DataType::Float(16)) {
+ f_filter = [&](int vector_len) {
+ return (vector_len == 1 || vector_len % 2 == 0) && vector_len <= 8;
+ };
+ } else if (dtype == runtime::DataType::Int(8)) {
+ f_filter = [&](int vector_len) { return vector_len <= 16; };
+ }
+ std::vector<int> valid_vector_lens;
+ valid_vector_lens.reserve(vector_load_lens.size());
+ if (f_filter != nullptr) {
+ std::copy_if(vector_load_lens.begin(), vector_load_lens.end(),
+ std::back_inserter(valid_vector_lens), f_filter);
+ } else {
+ valid_vector_lens = vector_load_lens;
+ }
+
+ if (!valid_vector_lens.empty()) {
+ int n = valid_vector_lens.size();
double prob = 1.0 / n;
tir::ExprRV vector_load_len =
- (*sch)->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
+ (*sch)->SampleCategorical(support::AsArray<int, Integer>(valid_vector_lens),
Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
(*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len);
}
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index 35da2e96b3..d415ae9ce6 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -649,13 +649,13 @@ b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared")
sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True)
l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73)
l80 = sch.fuse(l78, l79, preserve_unit_iters=True)
-v81 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81)
b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared")
sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True)
l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82)
l89 = sch.fuse(l87, l88, preserve_unit_iters=True)
-v90 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+v90 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90)
b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a")
sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True)
@@ -783,13 +783,13 @@ b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared")
sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True)
l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70)
l77 = sch.fuse(l75, l76, preserve_unit_iters=True)
-v78 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78)
b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared")
sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True)
l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79)
l86 = sch.fuse(l84, l85, preserve_unit_iters=True)
-v87 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+v87 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87)
b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a")
sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True)
@@ -883,13 +883,13 @@ b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared")
sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True)
l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70)
l77 = sch.fuse(l75, l76, preserve_unit_iters=True)
-v78 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78)
b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared")
sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True)
l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79)
l86 = sch.fuse(l84, l85, preserve_unit_iters=True)
-v87 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+v87 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87)
b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a")
sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True)
@@ -1025,13 +1025,13 @@ b85 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="shared")
sch.compute_at(block=b85, loop=l59, preserve_unit_loops=True)
l86, l87, l88, l89, l90, l91 = sch.get_loops(block=b85)
l92 = sch.fuse(l90, l91, preserve_unit_iters=True)
-v93 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+v93 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch", ann_val=v93)
b94 = sch.cache_read(block=b21, read_buffer_index=1, storage_scope="shared")
sch.compute_at(block=b94, loop=l59, preserve_unit_loops=True)
l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b94)
l101 = sch.fuse(l99, l100, preserve_unit_iters=True)
-v102 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+v102 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b94, ann_key="meta_schedule.cooperative_fetch", ann_val=v102)
b103 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="wmma.matrix_a")
sch.compute_at(block=b103, loop=l60, preserve_unit_loops=True)