You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/08/13 08:06:21 UTC

[tvm] branch main updated: [MetaSchedule] Filter vector_load_lens based on buffer dtype (#12408)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 96cac7501d [MetaSchedule] Filter vector_load_lens based on buffer dtype (#12408)
96cac7501d is described below

commit 96cac7501dbce42af6c5bb6f75b51a4b70c4077a
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Sat Aug 13 01:06:16 2022 -0700

    [MetaSchedule] Filter vector_load_lens based on buffer dtype (#12408)
    
    This makes the same config generic, so it works across workloads with different buffer data types.
---
 python/tvm/meta_schedule/default_config.py         |  4 +--
 python/tvm/meta_schedule/testing/schedule_rule.py  |  4 +--
 .../schedule_rule/multi_level_tiling.cc            | 29 +++++++++++++++++++---
 ...ta_schedule_schedule_rule_multi_level_tiling.py | 16 ++++++------
 4 files changed, 38 insertions(+), 15 deletions(-)

diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py
index e27b6ad4b4..58f82a248b 100644
--- a/python/tvm/meta_schedule/default_config.py
+++ b/python/tvm/meta_schedule/default_config.py
@@ -293,7 +293,7 @@ class _DefaultCUDA:
                 structure="SSSRRSRS",
                 tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
                 max_innermost_factor=64,
-                vector_load_lens=[1, 2, 3, 4],
+                vector_load_lens=[1, 2, 3, 4, 8, 16],
                 reuse_read=M.ReuseType(
                     req="must",
                     levels=[4],
@@ -374,7 +374,7 @@ class _DefaultCUDATensorCore:
                 structure="SSSRRSRS",
                 tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
                 max_innermost_factor=4,
-                vector_load_lens=[1, 2, 3, 4],
+                vector_load_lens=[1, 2, 3, 4, 8, 16],
                 reuse_read=M.ReuseType(req="must", levels=[4], scope="shared"),
                 reuse_write=M.ReuseType(
                     req="must",
diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index ea748ddc05..f5a936f491 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -98,7 +98,7 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
             structure="SSSRRSRS",
             tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
             max_innermost_factor=64,
-            vector_load_lens=[1, 2, 3, 4],
+            vector_load_lens=[1, 2, 3, 4, 8, 16],
             reuse_read=ReuseType(
                 req="must",
                 levels=[4],
@@ -141,7 +141,7 @@ def multi_level_tiling_tensor_core(
             structure="SSSRRSRS",
             tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
             max_innermost_factor=4,  # 64 // tensor intrin size
-            vector_load_lens=[1, 2, 3, 4],
+            vector_load_lens=[1, 2, 3, 4, 8, 16],
             reuse_read=ReuseType(
                 req="must",
                 levels=[4],
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 5f048dec00..76c3f5fa8b 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -261,11 +261,34 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
 
 void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
                                                        const tir::BlockRV& block) const {
-  if (!vector_load_lens.empty()) {
-    int n = vector_load_lens.size();
+  // Filter out invalid vector lanes according to the data type.
+  const tir::BlockNode* block_node = (*sch)->GetSRef(block)->StmtAs<tir::BlockNode>();
+  ICHECK_EQ(block_node->writes.size(), 1);
+  const runtime::DataType dtype = block_node->writes[0]->buffer->dtype;
+  std::function<bool(int)> f_filter = nullptr;
+  if (dtype == runtime::DataType::Float(32)) {
+    f_filter = [&](int vector_len) { return vector_len <= 4; };
+  } else if (dtype == runtime::DataType::Float(16)) {
+    f_filter = [&](int vector_len) {
+      return (vector_len == 1 || vector_len % 2 == 0) && vector_len <= 8;
+    };
+  } else if (dtype == runtime::DataType::Int(8)) {
+    f_filter = [&](int vector_len) { return vector_len <= 16; };
+  }
+  std::vector<int> valid_vector_lens;
+  valid_vector_lens.reserve(vector_load_lens.size());
+  if (f_filter != nullptr) {
+    std::copy_if(vector_load_lens.begin(), vector_load_lens.end(),
+                 std::back_inserter(valid_vector_lens), f_filter);
+  } else {
+    valid_vector_lens = vector_load_lens;
+  }
+
+  if (!valid_vector_lens.empty()) {
+    int n = valid_vector_lens.size();
     double prob = 1.0 / n;
     tir::ExprRV vector_load_len =
-        (*sch)->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
+        (*sch)->SampleCategorical(support::AsArray<int, Integer>(valid_vector_lens),
                                   Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
     (*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len);
   }
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index 35da2e96b3..d415ae9ce6 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -649,13 +649,13 @@ b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared")
 sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True)
 l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73)
 l80 = sch.fuse(l78, l79, preserve_unit_iters=True)
-v81 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81)
 b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared")
 sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True)
 l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82)
 l89 = sch.fuse(l87, l88, preserve_unit_iters=True)
-v90 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+v90 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90)
 b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a")
 sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True)
@@ -783,13 +783,13 @@ b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared")
 sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True)
 l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70)
 l77 = sch.fuse(l75, l76, preserve_unit_iters=True)
-v78 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78)
 b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared")
 sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True)
 l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79)
 l86 = sch.fuse(l84, l85, preserve_unit_iters=True)
-v87 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+v87 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87)
 b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a")
 sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True)
@@ -883,13 +883,13 @@ b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared")
 sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True)
 l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70)
 l77 = sch.fuse(l75, l76, preserve_unit_iters=True)
-v78 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78)
 b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared")
 sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True)
 l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79)
 l86 = sch.fuse(l84, l85, preserve_unit_iters=True)
-v87 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+v87 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87)
 b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a")
 sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True)
@@ -1025,13 +1025,13 @@ b85 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="shared")
 sch.compute_at(block=b85, loop=l59, preserve_unit_loops=True)
 l86, l87, l88, l89, l90, l91 = sch.get_loops(block=b85)
 l92 = sch.fuse(l90, l91, preserve_unit_iters=True)
-v93 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+v93 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch", ann_val=v93)
 b94 = sch.cache_read(block=b21, read_buffer_index=1, storage_scope="shared")
 sch.compute_at(block=b94, loop=l59, preserve_unit_loops=True)
 l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b94)
 l101 = sch.fuse(l99, l100, preserve_unit_iters=True)
-v102 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+v102 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
 sch.annotate(block_or_loop=b94, ann_key="meta_schedule.cooperative_fetch", ann_val=v102)
 b103 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="wmma.matrix_a")
 sch.compute_at(block=b103, loop=l60, preserve_unit_loops=True)