You are viewing a plain-text version of this content. The hyperlink to the canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by ju...@apache.org on 2021/01/07 01:50:39 UTC
[tvm] branch main updated: [AutoScheduler][Relay] Control compile engine cache via PassContext (#7220)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 93d79ba [AutoScheduler][Relay] Control compile engine cache via PassContext (#7220)
93d79ba is described below
commit 93d79bafcf854a928d248aab92782da36eec3b4a
Author: Cody Yu <co...@gmail.com>
AuthorDate: Wed Jan 6 17:50:21 2021 -0800
[AutoScheduler][Relay] Control compile engine cache via PassContext (#7220)
* [AutoScheduler][Relay] Control compile engine cache via PassContext
* lint
* lint
---
python/tvm/auto_scheduler/relay_integration.py | 35 +++++++++-----------------
src/relay/backend/compile_engine.cc | 5 +++-
src/relay/backend/utils.h | 9 +++++++
3 files changed, 25 insertions(+), 24 deletions(-)
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index eecf88b..ea1a8cc 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -56,7 +56,10 @@ def call_all_topi_funcs(mod, params, target):
with transform.PassContext(
opt_level=3,
- config={"relay.backend.use_auto_scheduler": True},
+ config={
+ "relay.backend.use_auto_scheduler": True,
+ "relay.backend.disable_compile_engine_cache": True,
+ },
disabled_pass={"AutoSchedulerLayoutRewrite"},
):
try:
@@ -105,7 +108,6 @@ def extract_tasks(
The weight (i.e. the number of appearance) of extracted tasks
"""
# pylint: disable=import-outside-toplevel
- from tvm import relay
if isinstance(target, str):
target = tvm.target.Target(target)
@@ -123,17 +125,10 @@ def extract_tasks(
build_thread.start()
build_thread.join()
- # query the compile engine to get the number of occurrence of all tasks
- engine = relay.backend.compile_engine.get()
- use_count_dict = {}
- for k, v in engine.items():
- use_count_dict[k] = v.use_count
-
# create search tasks
tasks = []
weights = []
- for wkl_key, ccache_key in env.wkl_key_to_ccache_key.items():
- dag = ComputeDAG(wkl_key)
+ for wkl_key, weight in env.wkl_key_to_weight.items():
tasks.append(
SearchTask(
workload_key=wkl_key,
@@ -145,10 +140,7 @@ def extract_tasks(
layout_rewrite_option=LayoutRewriteOption.get_target_default(target, True),
)
)
- weights.append(use_count_dict[ccache_key] + 1)
-
- # clean the cached lowering results
- engine.clear()
+ weights.append(weight)
return tasks, weights
@@ -169,7 +161,7 @@ class TracingEnvironment:
def __init__(self, tracing_mode):
self.tracing_mode = tracing_mode
self.relay_disable_build_cache = "false"
- self.wkl_key_to_ccache_key = {}
+ self.wkl_key_to_weight = {}
def __enter__(self):
TracingEnvironment.current = self
@@ -178,17 +170,17 @@ class TracingEnvironment:
def __exit__(self, exc_type, exc_val, exc_tb):
TracingEnvironment.current = None
- def add_workload_key(self, workload_key, ccache_key):
+ def add_workload_key(self, workload_key):
"""Add the workload key of a search task
Parameters
----------
workload_key: str
The workload key of a task
- ccache_key: CCacheKey
- The corresponding ccache_key of the task
"""
- self.wkl_key_to_ccache_key[workload_key] = ccache_key
+ if workload_key not in self.wkl_key_to_weight:
+ self.wkl_key_to_weight[workload_key] = 0
+ self.wkl_key_to_weight[workload_key] += 1
@tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite")
@@ -278,7 +270,6 @@ def auto_schedule_topi(outs):
An initial schdule in the tracing mode.
"""
# pylint: disable=import-outside-toplevel
- from tvm import relay
io_tensors, has_layout_free, has_complex_op = traverse_to_get_io_tensors(outs)
if not io_tensors: # The compute includes dynamic shapes which are not supported yet.
@@ -305,9 +296,7 @@ def auto_schedule_topi(outs):
elif env.tracing_mode in [TracingMode.EXTRACT_TASK, TracingMode.EXTRACT_COMPLEX_TASK_ONLY]:
# in the task extraction mode
if has_complex_op or env.tracing_mode == TracingMode.EXTRACT_TASK:
- engine = relay.backend.compile_engine.get()
- ccache_key = engine.get_current_ccache_key()
- env.add_workload_key(key, ccache_key)
+ env.add_workload_key(key)
schedule = te.create_schedule([x.op for x in outs])
elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
# in prepare_layout_rewrite mode
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 789f39d..c969c3b 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -701,7 +701,9 @@ class CompileEngineImpl : public CompileEngineNode {
} else {
value = CCacheValue(make_object<CCacheValueNode>());
value->use_count = 0;
- cache_[key] = value;
+ if (!backend::IsCompileEngineCacheDisabled()) {
+ cache_[key] = value;
+ }
}
cur_ccache_key_ = key;
@@ -832,6 +834,7 @@ CompileEngine& CompileEngine::Global() {
}
TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.disable_compile_engine_cache", Bool);
TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput")
.set_body_typed([](tvm::Array<te::Tensor> outputs, OpImplementation impl) {
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index e167720..6908ca8 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -303,6 +303,15 @@ inline bool IsAutoSchedulerEnabled() {
.value();
}
+/*!
+ * \brief Return whether the compile engine cache is disabled in the pass context.
+ */
+inline bool IsCompileEngineCacheDisabled() {
+ return transform::PassContext::Current()
+ ->GetConfig<Bool>("relay.backend.disable_compile_engine_cache", Bool(false))
+ .value();
+}
+
} // namespace backend
} // namespace relay
} // namespace tvm