You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by xi...@apache.org on 2022/08/09 21:44:31 UTC

[tvm] branch main updated: [MetaSchedule] Extend tune_tir to support tuning of specific blocks. (#12342)

This is an automated email from the ASF dual-hosted git repository.

xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new aea82c6417 [MetaSchedule] Extend tune_tir to support tuning of specific blocks. (#12342)
aea82c6417 is described below

commit aea82c64178cbdb458dcc031bc7a88a10d9f742a
Author: Josh Fromm <jw...@octoml.ai>
AuthorDate: Tue Aug 9 14:44:26 2022 -0700

    [MetaSchedule] Extend tune_tir to support tuning of specific blocks. (#12342)
    
    * Added optional target blocks.
    
    * Checkpoint for debugging.
    
    * Building with packedfunc filter.
    
    * Extended tune_tir API to support named blocks.
    
    * Remove accidental import.
    
    * Improve integration test.
    
    * Change names for more consistency.
    
    * Update integration test.
---
 include/tvm/meta_schedule/space_generator.h        |  2 +-
 .../space_generator/post_order_apply.py            | 12 +++-
 python/tvm/meta_schedule/tune.py                   | 27 +++++++
 .../space_generator/post_order_apply.cc            | 30 ++++++--
 .../test_meta_schedule_post_order_apply.py         | 82 ++++++++++++++++------
 .../python/unittest/test_meta_schedule_tune_tir.py | 57 +++++++++++++--
 6 files changed, 171 insertions(+), 39 deletions(-)

diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h
index f7d6cac31c..2df040e5d9 100644
--- a/include/tvm/meta_schedule/space_generator.h
+++ b/include/tvm/meta_schedule/space_generator.h
@@ -153,7 +153,7 @@ class SpaceGenerator : public runtime::ObjectRef {
    *  to blocks in post-DFS order.
    * \return The design space generator created.
    */
-  TVM_DLL static SpaceGenerator PostOrderApply();
+  TVM_DLL static SpaceGenerator PostOrderApply(runtime::PackedFunc f_block_filter = nullptr);
   TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode);
 };
 
diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py
index 80f372a448..6e2a2c52b1 100644
--- a/python/tvm/meta_schedule/space_generator/post_order_apply.py
+++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py
@@ -27,10 +27,18 @@ class PostOrderApply(SpaceGenerator):
     """
     PostOrderApply is the design space generator that generates design spaces by applying schedule
     rules to blocks in post-DFS order.
+
+    Parameters
+    ----------
+    f_block_filter : Optional[function]
+        An optional callback function that is used to filter which blocks have schedules generated
+        for them. The function should take in a block and return True if a schedule should
+        be generated or False if that block should be skipped. If no function is provided
+        all blocks will have schedules generated.
     """
 
-    def __init__(self):
+    def __init__(self, f_block_filter=None):
         """Constructor"""
         self.__init_handle_by_constructor__(
-            _ffi_api.SpaceGeneratorPostOrderApply,  # type: ignore # pylint: disable=no-member
+            _ffi_api.SpaceGeneratorPostOrderApply, f_block_filter  # type: ignore # pylint: disable=no-member
         )
diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
index fbbe24b32e..447fb56637 100644
--- a/python/tvm/meta_schedule/tune.py
+++ b/python/tvm/meta_schedule/tune.py
@@ -24,6 +24,7 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union
 
 from tvm.ir import IRModule
 from tvm.ir.transform import PassContext
+from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
 from tvm.runtime import Module, NDArray, vm
 from tvm.target import Target
 from tvm.te import Tensor, create_prim_func
@@ -364,6 +365,7 @@ def tune_tir(
     cost_model: Optional[CostModel] = None,
     measure_callbacks: Optional[List[MeasureCallback]] = None,
     space: Optional[FnSpaceGenerator] = None,
+    blocks: Optional[List[str]] = None,
     sch_rules: Optional[FnScheduleRule] = None,
     postprocs: Optional[FnPostproc] = None,
     mutator_probs: Optional[FnMutatorProb] = None,
@@ -392,6 +394,22 @@ def tune_tir(
         The cost model to use.
     measure_callbacks : Optional[List[MeasureCallback]]
         The callbacks used during tuning.
+    space : Optional[FnSpaceGenerator]
+        The space generator to use.
+    blocks : Optional[List[str]]
+        A list of block names specifying blocks to be tuned. Note that if
+        the list is not None, blocks outside this list will not be tuned.
+        Only one of this argument and space may be provided.
+    sch_rules : Optional[FnScheduleRule]
+        The search rules to use.
+    postprocs : Optional[FnPostproc]
+        The postprocessors to use.
+    mutator_probs : Optional[FnMutatorProb]
+        The probability distribution to use different mutators.
+    task_name : str
+        The name of the function to extract schedules from.
+    num_threads : Optional[int]
+        The number of threads to use
 
     Returns
     -------
@@ -407,6 +425,15 @@ def tune_tir(
         params=[{"log_dir": log_dir, "logger_name": __name__ + f".task_{task_name}"}],
     )
 
+    if blocks is not None:
+        assert space is None, "Can not specify blocks to tune when a search space is given."
+        # Create a filter function to identify named blocks.
+        def _f_block_filter(block, target_names) -> bool:
+            return block.name_hint in target_names
+
+        # Create a space generator that targets specific blocks.
+        space = PostOrderApply(f_block_filter=lambda block: _f_block_filter(block, blocks))
+
     # pylint: disable=protected-access
     mod = default_config.mod(mod)
     target = default_config.target(target)
diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc
index 50b49943f5..51dea2c2fe 100644
--- a/src/meta_schedule/space_generator/post_order_apply.cc
+++ b/src/meta_schedule/space_generator/post_order_apply.cc
@@ -24,8 +24,9 @@ namespace meta_schedule {
 /*! \brief Collecting all the blocks */
 class BlockCollector : public tir::StmtVisitor {
  public:
-  static Array<tir::BlockRV> Collect(const tir::Schedule& sch) {  //
-    return BlockCollector(sch).Run();
+  static Array<tir::BlockRV> Collect(const tir::Schedule& sch,
+                                     const runtime::PackedFunc f_block_filter = nullptr) {  //
+    return BlockCollector(sch, f_block_filter).Run();
   }
 
  private:
@@ -48,7 +49,9 @@ class BlockCollector : public tir::StmtVisitor {
     return results;
   }
   /*! \brief Constructor */
-  explicit BlockCollector(const tir::Schedule& sch) : sch_(sch) {}
+  explicit BlockCollector(const tir::Schedule& sch,
+                          const runtime::PackedFunc f_block_filter = nullptr)
+      : sch_(sch), f_block_filter_(f_block_filter) {}
   /*! \brief Override the Stmt visiting behaviour */
   void VisitStmt_(const tir::BlockNode* block) override {
     tir::StmtVisitor::VisitStmt_(block);
@@ -56,11 +59,22 @@ class BlockCollector : public tir::StmtVisitor {
         << "Duplicated block name " << block->name_hint << " in function " << func_name_
         << " not supported!";
     block_names_.insert(block->name_hint);
-    blocks_to_collect_.push_back(block->name_hint);
+
+    // If filter function is provided, use it to selectively collect blocks.
+    // Otherwise collect all blocks.
+    Bool collect_block = Bool(true);
+    if (f_block_filter_ != nullptr) {
+      collect_block = f_block_filter_(GetRef<tir::Block>(block));
+    }
+    if (collect_block) {
+      blocks_to_collect_.push_back(block->name_hint);
+    }
   }
 
   /*! \brief The schedule to be collected */
   const tir::Schedule& sch_;
+  /*! \brief An optional packed func that allows only certain blocks to be collected. */
+  const runtime::PackedFunc f_block_filter_;
   /*! \brief The set of func name and block name pair */
   std::unordered_set<String> block_names_;
   /* \brief The list of blocks to collect in order */
@@ -81,6 +95,9 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
   Array<ScheduleRule> sch_rules_{nullptr};
   /*! \brief The logging function to use. */
   PackedFunc logging_func;
+  /*! \brief Optional block names to target. If not specified all blocks will have spaces generated.
+   */
+  runtime::PackedFunc f_block_filter_ = nullptr;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     // `rand_state_` is not visited
@@ -107,7 +124,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
     Array<tir::Schedule> result{sch};
     // Enumerate the schedule rules first because you can
     // always concat multiple schedule rules as one
-    Array<tir::BlockRV> all_blocks = BlockCollector::Collect(sch);
+    Array<tir::BlockRV> all_blocks = BlockCollector::Collect(sch, f_block_filter_);
     Array<Optional<ScheduleRule>> rules{NullOpt};
     rules.insert(rules.end(), sch_rules_.begin(), sch_rules_.end());
     for (Optional<ScheduleRule> sch_rule : rules) {
@@ -177,8 +194,9 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
   TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode);
 };
 
-SpaceGenerator SpaceGenerator::PostOrderApply() {
+SpaceGenerator SpaceGenerator::PostOrderApply(runtime::PackedFunc f_block_filter) {
   ObjectPtr<PostOrderApplyNode> n = make_object<PostOrderApplyNode>();
+  n->f_block_filter_ = f_block_filter;
   return SpaceGenerator(n);
 }
 
diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py
index 21d29ac74d..97a49602fb 100644
--- a/tests/python/unittest/test_meta_schedule_post_order_apply.py
+++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py
@@ -195,6 +195,29 @@ class DoubleScheduleRule(PyScheduleRule):
         return result
 
 
+@derived_object
+class TrinityDoubleRule(PyScheduleRule):  # Test rule: forks each non-root block into two variants.
+    def _initialize_with_tune_context(self, context: "TuneContext") -> None:
+        pass  # No state is read from the tune context.
+
+    def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
+        if _is_root(sch, block):  # Root block is passed through unchanged.
+            return [sch]
+        new_sch = sch.copy()  # Variant 1: split i by [16, 64] and j by [64, 16].
+        i, j = new_sch.get_loops(block=block)  # Unpacking requires exactly two loops.
+        i_0, i_1 = new_sch.split(loop=i, factors=[16, 64])  # 16*64 == 1024, presumably the extent.
+        j_0, j_1 = new_sch.split(loop=j, factors=[64, 16])
+        new_sch.reorder(i_0, j_0, i_1, j_1)  # Both outer loops first, then both inner loops.
+        result = [new_sch]
+        new_sch = sch.copy()  # Variant 2: split both i and j by [2, 512], from the original sch.
+        i, j = new_sch.get_loops(block=block)
+        i_0, i_1 = new_sch.split(loop=i, factors=[2, 512])
+        j_0, j_1 = new_sch.split(loop=j, factors=[2, 512])
+        new_sch.reorder(i_0, j_0, i_1, j_1)
+        result.append(new_sch)
+        return result  # Always two schedules for a non-root block.
+
+
 @derived_object
 class ReorderScheduleRule(PyScheduleRule):
     def _initialize_with_tune_context(self, context: "TuneContext") -> None:
@@ -283,28 +306,6 @@ def test_meta_schedule_post_order_apply_duplicate_matmul():
 
 
 def test_meta_schedule_post_order_apply_remove_block():
-    @derived_object
-    class TrinityDouble(PyScheduleRule):
-        def _initialize_with_tune_context(self, context: "TuneContext") -> None:
-            pass
-
-        def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
-            if _is_root(sch, block):
-                return [sch]
-            new_sch = sch.copy()
-            i, j = new_sch.get_loops(block=block)
-            i_0, i_1 = new_sch.split(loop=i, factors=[16, 64])
-            j_0, j_1 = new_sch.split(loop=j, factors=[64, 16])
-            new_sch.reorder(i_0, j_0, i_1, j_1)
-            result = [new_sch]
-            new_sch = sch.copy()
-            i, j = new_sch.get_loops(block=block)
-            i_0, i_1 = new_sch.split(loop=i, factors=[2, 512])
-            j_0, j_1 = new_sch.split(loop=j, factors=[2, 512])
-            new_sch.reorder(i_0, j_0, i_1, j_1)
-            result.append(new_sch)
-            return result
-
     @derived_object
     class RemoveBlock(PyScheduleRule):
         def _initialize_with_tune_context(self, context: "TuneContext") -> None:
@@ -342,7 +343,7 @@ def test_meta_schedule_post_order_apply_remove_block():
         target=Target("llvm"),
         task_name="Remove Block Task",
         space_generator=PostOrderApply(),
-        sch_rules=[RemoveBlock(), TrinityDouble()],
+        sch_rules=[RemoveBlock(), TrinityDoubleRule()],
     )
     post_order_apply = context.space_generator
     schs = post_order_apply.generate_design_space(mod)
@@ -385,5 +386,40 @@ def test_meta_schedule_custom_search_space():
     assert called
 
 
+def test_target_blocks_search_space():
+    # Check that PostOrderApply's f_block_filter restricts which TrinityMatmul blocks are scheduled.
+    def filter_fn(block, target_names) -> bool:  # True only for blocks whose name is listed.
+        return block.name_hint in target_names
+
+    def _get_sch(filter_fn):  # Param shadows the outer filter_fn; None means "no filter".
+        mod = TrinityMatmul
+        context = TuneContext(
+            mod=mod,
+            target=Target("llvm"),
+            task_name="Custom Search Space Task",
+            space_generator=PostOrderApply(f_block_filter=filter_fn),
+            sch_rules=[TrinityDoubleRule()],
+        )
+        post_order_apply = context.space_generator
+        schs = post_order_apply.generate_design_space(mod)
+        return schs
+
+    # Start by checking that by default each block has a space generated.
+    schs = _get_sch(None)
+    assert len(schs) == 8  # TrinityDoubleRule yields 2 variants per block: 2**3 for A, B, C.
+
+    # Next check that we can target a specific block and only get its relevant schedules.
+    schs = _get_sch(lambda block: filter_fn(block, ["B"]))
+    assert len(schs) == 2  # One targeted block -> 2**1.
+
+    # Check that extracting two blocks works.
+    schs = _get_sch(lambda block: filter_fn(block, ["A", "C"]))
+    assert len(schs) == 4  # Two targeted blocks -> 2**2.
+
+    # Finally check that all blocks can be extracted by name.
+    schs = _get_sch(lambda block: filter_fn(block, ["A", "B", "C"]))
+    assert len(schs) == 8  # Naming every block matches the unfiltered default.
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py
index 0e8c205230..6ab5f9b8c5 100644
--- a/tests/python/unittest/test_meta_schedule_tune_tir.py
+++ b/tests/python/unittest/test_meta_schedule_tune_tir.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=missing-docstring
+# pylint: disable=missing-docstring,no-member,invalid-name,unused-variable
 import logging
 import tempfile
 import numpy as np
@@ -23,20 +23,19 @@ import pytest
 import tvm
 
 from tvm import meta_schedule as ms
-from tvm.meta_schedule import TuneConfig, tune_tir
+from tvm.meta_schedule import TuneContext, TuneConfig, tune_tir
 from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
 from tvm.meta_schedule.testing.local_rpc import LocalRPC
+from tvm.meta_schedule.schedule_rule import PyScheduleRule
+from tvm.meta_schedule.utils import derived_object
 from tvm.script import tir as T
 from tvm.target import Target
-from tvm.tir import Schedule
+from tvm.tir.schedule import BlockRV, Schedule
 
 logging.basicConfig()
 logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
 
 
-# pylint: disable=no-member,invalid-name,unused-variable
-
-
 @T.prim_func
 def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, [128, 128])
@@ -50,7 +49,19 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
             C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
 
 
-# pylint: enable=no-member,invalid-name,unused-variable
+@T.prim_func
+def two_step(a: T.handle, c: T.handle) -> None:  # Two elementwise stages: C = A * 2.0 + 3.0.
+    A = T.match_buffer(a, (1024, 1024), "float32")
+    B = T.alloc_buffer((1024, 1024), "float32")  # Intermediate: written by block "A", read by "B".
+    C = T.match_buffer(c, (1024, 1024), "float32")
+    for i, j in T.grid(1024, 1024):
+        with T.block("A"):  # Stage 1 (block name "A" writes buffer B): B = A * 2.0.
+            vi, vj = T.axis.remap("SS", [i, j])
+            B[vi, vj] = A[vi, vj] * 2.0
+    for i, j in T.grid(1024, 1024):
+        with T.block("B"):  # Stage 2 (block name "B" writes buffer C): C = B + 3.0.
+            vi, vj = T.axis.remap("SS", [i, j])
+            C[vi, vj] = B[vi, vj] + 3.0
 
 
 @pytest.mark.skip("Integration test")
@@ -74,6 +85,37 @@ def test_tune_matmul_cpu():
             print(sch.trace)
 
 
+@pytest.mark.skip("Integration test")
+def test_tune_block_cpu():  # Tune only block "A" of two_step via tune_tir's `blocks` argument.
+    @derived_object
+    class RemoveBlock(PyScheduleRule):  # Rule that compute-inlines every non-root block it sees.
+        def _initialize_with_tune_context(self, context: TuneContext) -> None:
+            pass
+
+        def apply(self, sch: Schedule, block: BlockRV):
+            if sch.get(block).name_hint == "root":  # Root block is returned untouched.
+                return [sch]
+            sch = sch.copy()  # Inline on a copy so the incoming schedule is not mutated.
+            sch.compute_inline(block)
+            return [sch]
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        sch: Schedule = tune_tir(
+            mod=two_step,
+            target=Target("llvm --num-cores=16"),
+            config=TuneConfig(
+                strategy="replay_trace",
+                num_trials_per_iter=32,
+                max_trials_per_task=32,
+                max_trials_global=32,
+            ),
+            work_dir=work_dir,
+            blocks=["A"],  # Builds a filtered PostOrderApply; blocks not listed are not tuned.
+            sch_rules=lambda *args: [RemoveBlock()],
+        )
+        assert sch is not None
+
+
 @pytest.mark.skip("Integration test")
 def test_tune_matmul_cuda():
     with tempfile.TemporaryDirectory() as work_dir:
@@ -141,3 +183,4 @@ if __name__ == """__main__""":
     test_tune_matmul_cpu()
     test_tune_matmul_cuda()
     test_tune_run_module_via_rpc()
+    test_tune_block_cpu()