You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by xi...@apache.org on 2022/08/09 21:44:31 UTC
[tvm] branch main updated: [MetaSchedule] Extend tune_tir to support tuning of specific blocks. (#12342)
This is an automated email from the ASF dual-hosted git repository.
xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new aea82c6417 [MetaSchedule] Extend tune_tir to support tuning of specific blocks. (#12342)
aea82c6417 is described below
commit aea82c64178cbdb458dcc031bc7a88a10d9f742a
Author: Josh Fromm <jw...@octoml.ai>
AuthorDate: Tue Aug 9 14:44:26 2022 -0700
[MetaSchedule] Extend tune_tir to support tuning of specific blocks. (#12342)
* Added optional target blocks.
* Checkpoint for debugging.
* Building with packedfunc filter.
* Extended tune_tir API to support named blocks.
* Remove accidental import.
* Improve integration test.
* Change names for more consistency.
* Update integration test.
---
include/tvm/meta_schedule/space_generator.h | 2 +-
.../space_generator/post_order_apply.py | 12 +++-
python/tvm/meta_schedule/tune.py | 27 +++++++
.../space_generator/post_order_apply.cc | 30 ++++++--
.../test_meta_schedule_post_order_apply.py | 82 ++++++++++++++++------
.../python/unittest/test_meta_schedule_tune_tir.py | 57 +++++++++++++--
6 files changed, 171 insertions(+), 39 deletions(-)
diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h
index f7d6cac31c..2df040e5d9 100644
--- a/include/tvm/meta_schedule/space_generator.h
+++ b/include/tvm/meta_schedule/space_generator.h
@@ -153,7 +153,7 @@ class SpaceGenerator : public runtime::ObjectRef {
* to blocks in post-DFS order.
* \return The design space generator created.
*/
- TVM_DLL static SpaceGenerator PostOrderApply();
+ TVM_DLL static SpaceGenerator PostOrderApply(runtime::PackedFunc f_block_filter = nullptr);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode);
};
diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py
index 80f372a448..6e2a2c52b1 100644
--- a/python/tvm/meta_schedule/space_generator/post_order_apply.py
+++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py
@@ -27,10 +27,18 @@ class PostOrderApply(SpaceGenerator):
"""
PostOrderApply is the design space generator that generates design spaces by applying schedule
rules to blocks in post-DFS order.
+
+ Parameters
+ ----------
+ f_block_filter : Optional[Callable[[Block], bool]]
+ An optional callback function that is used to filter which blocks have schedules generated
+ for them. The function should take in a block and return True if a schedule should
+ be generated or False if that block should be skipped. If no function is provided
+ all blocks will have schedules generated.
"""
- def __init__(self):
+ def __init__(self, f_block_filter=None):
"""Constructor"""
self.__init_handle_by_constructor__(
- _ffi_api.SpaceGeneratorPostOrderApply, # type: ignore # pylint: disable=no-member
+ _ffi_api.SpaceGeneratorPostOrderApply, f_block_filter # type: ignore # pylint: disable=no-member
)
diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
index fbbe24b32e..447fb56637 100644
--- a/python/tvm/meta_schedule/tune.py
+++ b/python/tvm/meta_schedule/tune.py
@@ -24,6 +24,7 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union
from tvm.ir import IRModule
from tvm.ir.transform import PassContext
+from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
from tvm.runtime import Module, NDArray, vm
from tvm.target import Target
from tvm.te import Tensor, create_prim_func
@@ -364,6 +365,7 @@ def tune_tir(
cost_model: Optional[CostModel] = None,
measure_callbacks: Optional[List[MeasureCallback]] = None,
space: Optional[FnSpaceGenerator] = None,
+ blocks: Optional[List[str]] = None,
sch_rules: Optional[FnScheduleRule] = None,
postprocs: Optional[FnPostproc] = None,
mutator_probs: Optional[FnMutatorProb] = None,
@@ -392,6 +394,22 @@ def tune_tir(
The cost model to use.
measure_callbacks : Optional[List[MeasureCallback]]
The callbacks used during tuning.
+ space : Optional[FnSpaceGenerator]
+ The space generator to use.
+ blocks : Optional[List[str]]
+ A list of block names specifying blocks to be tuned. Note that if
+ the list is not None, blocks outside this list will not be tuned.
+ At most one of this argument and space may be provided.
+ sch_rules : Optional[FnScheduleRule]
+ The schedule rules to use.
+ postprocs : Optional[FnPostproc]
+ The postprocessors to use.
+ mutator_probs : Optional[FnMutatorProb]
+ The probability distribution used to select among the different mutators.
+ task_name : str
+ The name of the function to extract schedules from.
+ num_threads : Optional[int]
+ The number of threads to use.
Returns
-------
@@ -407,6 +425,15 @@ def tune_tir(
params=[{"log_dir": log_dir, "logger_name": __name__ + f".task_{task_name}"}],
)
+ if blocks is not None:
+ assert space is None, "Can not specify blocks to tune when a search space is given."
+ # Create a filter function to identify named blocks.
+ def _f_block_filter(block, target_names) -> bool:
+ return block.name_hint in target_names
+
+ # Create a space generator that targets specific blocks.
+ space = PostOrderApply(f_block_filter=lambda block: _f_block_filter(block, blocks))
+
# pylint: disable=protected-access
mod = default_config.mod(mod)
target = default_config.target(target)
diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc
index 50b49943f5..51dea2c2fe 100644
--- a/src/meta_schedule/space_generator/post_order_apply.cc
+++ b/src/meta_schedule/space_generator/post_order_apply.cc
@@ -24,8 +24,9 @@ namespace meta_schedule {
/*! \brief Collecting all the blocks */
class BlockCollector : public tir::StmtVisitor {
public:
- static Array<tir::BlockRV> Collect(const tir::Schedule& sch) { //
- return BlockCollector(sch).Run();
+ static Array<tir::BlockRV> Collect(const tir::Schedule& sch,
+ const runtime::PackedFunc f_block_filter = nullptr) { //
+ return BlockCollector(sch, f_block_filter).Run();
}
private:
@@ -48,7 +49,9 @@ class BlockCollector : public tir::StmtVisitor {
return results;
}
/*! \brief Constructor */
- explicit BlockCollector(const tir::Schedule& sch) : sch_(sch) {}
+ explicit BlockCollector(const tir::Schedule& sch,
+ const runtime::PackedFunc f_block_filter = nullptr)
+ : sch_(sch), f_block_filter_(f_block_filter) {}
/*! \brief Override the Stmt visiting behaviour */
void VisitStmt_(const tir::BlockNode* block) override {
tir::StmtVisitor::VisitStmt_(block);
@@ -56,11 +59,22 @@ class BlockCollector : public tir::StmtVisitor {
<< "Duplicated block name " << block->name_hint << " in function " << func_name_
<< " not supported!";
block_names_.insert(block->name_hint);
- blocks_to_collect_.push_back(block->name_hint);
+
+ // If filter function is provided, use it to selectively collect blocks.
+ // Otherwise collect all blocks.
+ Bool collect_block = Bool(true);
+ if (f_block_filter_ != nullptr) {
+ collect_block = f_block_filter_(GetRef<tir::Block>(block));
+ }
+ if (collect_block) {
+ blocks_to_collect_.push_back(block->name_hint);
+ }
}
/*! \brief The schedule to be collected */
const tir::Schedule& sch_;
+ /*! \brief An optional packed func that allows only certain blocks to be collected. */
+ const runtime::PackedFunc f_block_filter_;
/*! \brief The set of func name and block name pair */
std::unordered_set<String> block_names_;
/* \brief The list of blocks to collect in order */
@@ -81,6 +95,9 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
Array<ScheduleRule> sch_rules_{nullptr};
/*! \brief The logging function to use. */
PackedFunc logging_func;
+ /*! \brief Optional filter function selecting the blocks to target. If not specified, all blocks
+ * will have spaces generated. */
+ runtime::PackedFunc f_block_filter_ = nullptr;
void VisitAttrs(tvm::AttrVisitor* v) {
// `rand_state_` is not visited
@@ -107,7 +124,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
Array<tir::Schedule> result{sch};
// Enumerate the schedule rules first because you can
// always concat multiple schedule rules as one
- Array<tir::BlockRV> all_blocks = BlockCollector::Collect(sch);
+ Array<tir::BlockRV> all_blocks = BlockCollector::Collect(sch, f_block_filter_);
Array<Optional<ScheduleRule>> rules{NullOpt};
rules.insert(rules.end(), sch_rules_.begin(), sch_rules_.end());
for (Optional<ScheduleRule> sch_rule : rules) {
@@ -177,8 +194,9 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode);
};
-SpaceGenerator SpaceGenerator::PostOrderApply() {
+SpaceGenerator SpaceGenerator::PostOrderApply(runtime::PackedFunc f_block_filter) {
ObjectPtr<PostOrderApplyNode> n = make_object<PostOrderApplyNode>();
+ n->f_block_filter_ = f_block_filter;
return SpaceGenerator(n);
}
diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py
index 21d29ac74d..97a49602fb 100644
--- a/tests/python/unittest/test_meta_schedule_post_order_apply.py
+++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py
@@ -195,6 +195,29 @@ class DoubleScheduleRule(PyScheduleRule):
return result
+@derived_object
+class TrinityDoubleRule(PyScheduleRule):
+ def _initialize_with_tune_context(self, context: "TuneContext") -> None:
+ pass
+
+ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
+ if _is_root(sch, block):
+ return [sch]
+ new_sch = sch.copy()
+ i, j = new_sch.get_loops(block=block)
+ i_0, i_1 = new_sch.split(loop=i, factors=[16, 64])
+ j_0, j_1 = new_sch.split(loop=j, factors=[64, 16])
+ new_sch.reorder(i_0, j_0, i_1, j_1)
+ result = [new_sch]
+ new_sch = sch.copy()
+ i, j = new_sch.get_loops(block=block)
+ i_0, i_1 = new_sch.split(loop=i, factors=[2, 512])
+ j_0, j_1 = new_sch.split(loop=j, factors=[2, 512])
+ new_sch.reorder(i_0, j_0, i_1, j_1)
+ result.append(new_sch)
+ return result
+
+
@derived_object
class ReorderScheduleRule(PyScheduleRule):
def _initialize_with_tune_context(self, context: "TuneContext") -> None:
@@ -283,28 +306,6 @@ def test_meta_schedule_post_order_apply_duplicate_matmul():
def test_meta_schedule_post_order_apply_remove_block():
- @derived_object
- class TrinityDouble(PyScheduleRule):
- def _initialize_with_tune_context(self, context: "TuneContext") -> None:
- pass
-
- def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
- if _is_root(sch, block):
- return [sch]
- new_sch = sch.copy()
- i, j = new_sch.get_loops(block=block)
- i_0, i_1 = new_sch.split(loop=i, factors=[16, 64])
- j_0, j_1 = new_sch.split(loop=j, factors=[64, 16])
- new_sch.reorder(i_0, j_0, i_1, j_1)
- result = [new_sch]
- new_sch = sch.copy()
- i, j = new_sch.get_loops(block=block)
- i_0, i_1 = new_sch.split(loop=i, factors=[2, 512])
- j_0, j_1 = new_sch.split(loop=j, factors=[2, 512])
- new_sch.reorder(i_0, j_0, i_1, j_1)
- result.append(new_sch)
- return result
-
@derived_object
class RemoveBlock(PyScheduleRule):
def _initialize_with_tune_context(self, context: "TuneContext") -> None:
@@ -342,7 +343,7 @@ def test_meta_schedule_post_order_apply_remove_block():
target=Target("llvm"),
task_name="Remove Block Task",
space_generator=PostOrderApply(),
- sch_rules=[RemoveBlock(), TrinityDouble()],
+ sch_rules=[RemoveBlock(), TrinityDoubleRule()],
)
post_order_apply = context.space_generator
schs = post_order_apply.generate_design_space(mod)
@@ -385,5 +386,40 @@ def test_meta_schedule_custom_search_space():
assert called
+def test_target_blocks_search_space():
+ # Test that specific blocks of trinity matmul can be targeted.
+ def filter_fn(block, target_names) -> bool:
+ return block.name_hint in target_names
+
+ def _get_sch(filter_fn):
+ mod = TrinityMatmul
+ context = TuneContext(
+ mod=mod,
+ target=Target("llvm"),
+ task_name="Custom Search Space Task",
+ space_generator=PostOrderApply(f_block_filter=filter_fn),
+ sch_rules=[TrinityDoubleRule()],
+ )
+ post_order_apply = context.space_generator
+ schs = post_order_apply.generate_design_space(mod)
+ return schs
+
+ # Start by checking that by default each block has a space generated.
+ schs = _get_sch(None)
+ assert len(schs) == 8
+
+ # Next check that we can target a specific block and only get its relevant schedules.
+ schs = _get_sch(lambda block: filter_fn(block, ["B"]))
+ assert len(schs) == 2
+
+ ## Check that extracting two blocks works.
+ schs = _get_sch(lambda block: filter_fn(block, ["A", "C"]))
+ assert len(schs) == 4
+
+ ## Finally check that all blocks can be extracted by name.
+ schs = _get_sch(lambda block: filter_fn(block, ["A", "B", "C"]))
+ assert len(schs) == 8
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py
index 0e8c205230..6ab5f9b8c5 100644
--- a/tests/python/unittest/test_meta_schedule_tune_tir.py
+++ b/tests/python/unittest/test_meta_schedule_tune_tir.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=missing-docstring
+# pylint: disable=missing-docstring,no-member,invalid-name,unused-variable
import logging
import tempfile
import numpy as np
@@ -23,20 +23,19 @@ import pytest
import tvm
from tvm import meta_schedule as ms
-from tvm.meta_schedule import TuneConfig, tune_tir
+from tvm.meta_schedule import TuneContext, TuneConfig, tune_tir
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.local_rpc import LocalRPC
+from tvm.meta_schedule.schedule_rule import PyScheduleRule
+from tvm.meta_schedule.utils import derived_object
from tvm.script import tir as T
from tvm.target import Target
-from tvm.tir import Schedule
+from tvm.tir.schedule import BlockRV, Schedule
logging.basicConfig()
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
-# pylint: disable=no-member,invalid-name,unused-variable
-
-
@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128])
@@ -50,7 +49,19 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
-# pylint: enable=no-member,invalid-name,unused-variable
+@T.prim_func
+def two_step(a: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (1024, 1024), "float32")
+ B = T.alloc_buffer((1024, 1024), "float32")
+ C = T.match_buffer(c, (1024, 1024), "float32")
+ for i, j in T.grid(1024, 1024):
+ with T.block("A"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] * 2.0
+ for i, j in T.grid(1024, 1024):
+ with T.block("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi, vj] + 3.0
@pytest.mark.skip("Integration test")
@@ -74,6 +85,37 @@ def test_tune_matmul_cpu():
print(sch.trace)
+@pytest.mark.skip("Integration test")
+def test_tune_block_cpu():
+ @derived_object
+ class RemoveBlock(PyScheduleRule):
+ def _initialize_with_tune_context(self, context: TuneContext) -> None:
+ pass
+
+ def apply(self, sch: Schedule, block: BlockRV):
+ if sch.get(block).name_hint == "root":
+ return [sch]
+ sch = sch.copy()
+ sch.compute_inline(block)
+ return [sch]
+
+ with tempfile.TemporaryDirectory() as work_dir:
+ sch: Schedule = tune_tir(
+ mod=two_step,
+ target=Target("llvm --num-cores=16"),
+ config=TuneConfig(
+ strategy="replay_trace",
+ num_trials_per_iter=32,
+ max_trials_per_task=32,
+ max_trials_global=32,
+ ),
+ work_dir=work_dir,
+ blocks=["A"],
+ sch_rules=lambda *args: [RemoveBlock()],
+ )
+ assert sch is not None
+
+
@pytest.mark.skip("Integration test")
def test_tune_matmul_cuda():
with tempfile.TemporaryDirectory() as work_dir:
@@ -141,3 +183,4 @@ if __name__ == """__main__""":
test_tune_matmul_cpu()
test_tune_matmul_cuda()
test_tune_run_module_via_rpc()
+ test_tune_block_cpu()