You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2022/08/16 01:32:13 UTC
[tvm] branch main updated: [MetaSchedule] Add logging of usage of tensor intrinsics (#12445)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ecbe4ca0ed [MetaSchedule] Add logging of usage of tensor intrinsics (#12445)
ecbe4ca0ed is described below
commit ecbe4ca0edadeca8fee4d0c2c9f7a9093043b5ee
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Mon Aug 15 18:32:04 2022 -0700
[MetaSchedule] Add logging of usage of tensor intrinsics (#12445)
* [MetaSchedule] Add logging of usage of tensor intrinsics
* fix
---
src/meta_schedule/schedule_rule/multi_level_tiling.cc | 1 +
src/meta_schedule/schedule_rule/multi_level_tiling.h | 2 ++
src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc | 7 +++++--
src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc | 3 +++
4 files changed, 11 insertions(+), 2 deletions(-)
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 76c3f5fa8b..eefc2eea41 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -87,6 +87,7 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context)
TVM_PY_LOG(INFO, context->logging_func) << "'thread_warp_size' is not defined in the target";
}
}
+ logging_func = context->logging_func;
}
// Entry of the mega rule; Inherited from ScheduleRuleNode
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index 3398255674..9161a972c1 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -186,6 +186,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
int thread_warp_size_;
/*! \brief The maximum number of threads per block */
int max_threads_per_block_;
+ /*! \brief The logging function */
+ PackedFunc logging_func;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("structure", &structure);
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 6d34f7b64e..7a3ec513db 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -198,9 +198,12 @@ Array<Schedule> MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch,
}
Array<Schedule> results;
for (auto&& state : ApplySubRules(initial_states)) {
+ TVM_PY_LOG(INFO, logging_func) << "Sketch " << results.size() << ": tensorizing with "
+ << state.as<TensorCoreStateNode>()->intrin_group.compute_intrin;
results.push_back(std::move(state->sch));
}
if (results.empty()) {
+ TVM_PY_LOG(INFO, logging_func) << "The workload cannot be tensorized.";
return {original_sch};
}
return results;
@@ -276,8 +279,8 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
} else if (dtype.is_int() && dtype.bits() == 8) {
sch->StorageAlign(cache_read, 0, -2, 32, 16);
} else {
- LOG(WARNING) << "StorageAlign is not applied for data type " << dtype
- << ", shared memory accesses might be inefficient.";
+ TVM_PY_LOG(WARNING, logging_func) << "StorageAlign is not applied for data type " << dtype
+ << ", shared memory accesses might be inefficient.";
}
}
return {state};
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
index 009e5a0d92..3a299ed041 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
@@ -49,14 +49,17 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode {
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
auto desc_func = tir::TensorIntrin::Get(intrin_name)->desc;
if (!CheckAutoTensorizeApplicable(sch, block_rv, desc_func)) {
+ TVM_PY_LOG(INFO, logging_func) << "The workload cannot be tensorized.";
return {sch};
}
auto res = MultiLevelTilingNode::Apply(sch->Copy(), block_rv);
if (res.empty()) {
+ TVM_PY_LOG(INFO, logging_func) << "The workload cannot be tensorized.";
return {sch};
}
+ TVM_PY_LOG(INFO, logging_func) << "Tensorizing with " << intrin_name;
return res;
}