You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this text.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/10/02 13:38:33 UTC
[incubator-tvm] branch master updated: [Ansor] Support multiple output ops and fix Python API printing (#6584)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 72969b2 [Ansor] Support multiple output ops and fix Python API printing (#6584)
72969b2 is described below
commit 72969b27c6476d2269a3b0caf0c100cdae000c7a
Author: Cody Yu <co...@gmail.com>
AuthorDate: Fri Oct 2 06:38:22 2020 -0700
[Ansor] Support multiple output ops and fix Python API printing (#6584)
---
src/auto_scheduler/compute_dag.cc | 32 ++++----
.../search_policy/sketch_policy_rules.h | 8 ++
src/auto_scheduler/transform_step.cc | 87 +++++++++++++---------
src/auto_scheduler/utils.h | 8 +-
4 files changed, 81 insertions(+), 54 deletions(-)
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index 7c9ce4c..23b3817 100755
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -969,6 +969,9 @@ void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
}
} // end for placeholder
} // end for stage
+ p_dag->access_analyzer = AccessAnalyzer(p_dag->tensors);
+ p_dag->ops = p_dag->access_analyzer->ops_topo_order;
+ p_dag->flop_ct = FlopEstimator().EstimateFlop(p_dag->ops);
}
std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
@@ -989,16 +992,15 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
if (stage_to_axes == nullptr) {
stage_to_axes = &temp_stage_to_axes;
}
- Array<te::Operation> ops;
+ Array<te::Operation> out_ops;
for (const auto& op : operator->()->ops) {
- if (!op->IsInstance<te::PlaceholderOpNode>()) {
- ops.push_back(op);
+ if (operator->()->access_analyzer.IsOutput(op)) {
+ out_ops.push_back(op);
}
}
+
// Create the initial schedule
- // TODO(jcf94): Currently we only checked single output dag for TVM Auto-scheduler,
- // update this after testing with multiple outputs.
- te::Schedule schedule = te::create_schedule({ops.back()});
+ te::Schedule schedule = te::create_schedule(out_ops);
// init axes
for (const auto& x : operator->()->ops) {
@@ -1019,16 +1021,14 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const {
Array<te::Stage> stages;
StageToAxesMap stage_to_axes;
- Array<te::Operation> ops;
+ Array<te::Operation> out_ops;
for (const auto& op : operator->()->ops) {
- if (!op->IsInstance<te::PlaceholderOpNode>()) {
- ops.push_back(op);
+ if (operator->()->access_analyzer.IsOutput(op)) {
+ out_ops.push_back(op);
}
}
// Create the initial schedule
- // TODO(jcf94): Currently we only checked single output dag for TVM Auto-scheduler,
- // update this after testing with multiple outputs.
- te::Schedule schedule = te::create_schedule({ops.back()});
+ te::Schedule schedule = te::create_schedule(out_ops);
// init axes
for (const auto& x : operator->()->ops) {
@@ -1040,16 +1040,18 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
std::stringstream ss;
for (const auto& stage : stages) {
if (stage->op->IsInstance<te::ComputeOpNode>()) {
+ auto op_name = CleanName(stage->op->name);
+
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
- ss << stage->leaf_iter_vars[i]->var->name_hint;
+ ss << CleanName(stage->leaf_iter_vars[i]->var->name_hint, op_name);
if (i != stage->leaf_iter_vars.size() - 1) {
ss << ", ";
}
}
ss << " = "
- << "tuple(" << stage->op->name << ".op.axis)"
+ << "tuple(" << op_name << ".op.axis)"
<< " + "
- << "tuple(" << stage->op->name << ".op.reduce_axis)\n";
+ << "tuple(" << op_name << ".op.reduce_axis)\n";
}
}
// Call each step's PrintAsPythonAPI method
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h
index 928efc5..035dc89 100644
--- a/src/auto_scheduler/search_policy/sketch_policy_rules.h
+++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h
@@ -28,6 +28,7 @@
#include <tvm/auto_scheduler/loop_state.h>
#include <tvm/auto_scheduler/search_task.h>
+#include <string>
#include <utility>
#include <vector>
@@ -74,6 +75,12 @@ class SketchGenerationRule {
*/
virtual std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy,
const State& state, int stage_id) const = 0;
+
+ /*!
+ * \brief Get the name of this rule.
+ * \return A string of the rule name.
+ */
+ virtual std::string GetRuleName() const = 0;
};
#define DEFINE_SKETCH_GENERATION_RULE(rule_name) \
@@ -83,6 +90,7 @@ class SketchGenerationRule {
int stage_id) const final; \
std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state, \
int stage_id) const final; \
+ std::string GetRuleName() const final { return #rule_name; } \
};
/*! \brief The rule that simply skips the current stage. It returns an unchanged state and move to
diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc
index 2a93497..73f6734 100755
--- a/src/auto_scheduler/transform_step.cc
+++ b/src/auto_scheduler/transform_step.cc
@@ -356,8 +356,9 @@ String AnnotationStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
std::stringstream ss;
const auto& stage = (*stages)[stage_id];
const auto& iter = (*stage_to_axes)[stage][iter_id];
+ const auto& op_name = CleanName(stage->op->name);
- ss << "s[" << CleanName(stage->op->name) << "].";
+ ss << "s[" << op_name << "].";
switch (annotation) {
case IteratorAnnotation::kUnroll:
ss << "unroll(";
@@ -383,7 +384,7 @@ String AnnotationStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
LOG(FATAL) << "Invalid annotation " << static_cast<int>(annotation);
break;
}
- ss << CleanName(iter->var->name_hint);
+ ss << CleanName(iter->var->name_hint, op_name);
switch (annotation) {
case IteratorAnnotation::kVThread:
case IteratorAnnotation::kBlockX:
@@ -392,7 +393,7 @@ String AnnotationStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
case IteratorAnnotation::kThreadX:
case IteratorAnnotation::kThreadY:
case IteratorAnnotation::kThreadZ:
- ss << ", tvm.thread_axis(\"" << IteratorAnnotationString[static_cast<int>(annotation)]
+ ss << ", te.thread_axis(\"" << IteratorAnnotationString[static_cast<int>(annotation)]
<< "\")";
break;
default:
@@ -541,10 +542,11 @@ IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>* stages,
String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes) const {
const auto& stage = (*stages)[stage_id];
+ const auto& op_name = CleanName(stage->op->name);
std::stringstream to_fuse;
for (size_t i = 0; i < fused_ids.size(); ++i) {
- to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint);
+ to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint, op_name);
if (i != fused_ids.size() - 1) {
to_fuse << ", ";
}
@@ -553,7 +555,7 @@ String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
std::stringstream ss;
const auto& fused = ApplyToSchedule(stages, stage_to_axes);
- ss << CleanName(fused->var->name_hint) << " = s[" << CleanName(stage->op->name) << "].fuse("
+ ss << CleanName(fused->var->name_hint, op_name) << " = s[" << op_name << "].fuse("
<< to_fuse.str() << ")\n";
return ss.str();
@@ -640,6 +642,7 @@ String PragmaStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes) const {
std::stringstream ss;
const auto& stage = (*stages)[stage_id];
+ const auto& op_name = CleanName(stage->op->name);
if (StrStartsWith(pragma_type, "auto_unroll_max_step")) {
size_t pos = 0;
@@ -650,16 +653,16 @@ String PragmaStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
}
CHECK_LT(pos, pragma_type.size()) << "max step value not found.";
int value = atoi(pragma_type.c_str() + pos + 1);
- ss << "s[" << CleanName(stage->op->name) << "].pragma("
- << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint)
+ ss << "s[" << op_name << "].pragma("
+ << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name)
<< ", \"auto_unroll_max_step\", " << value << ")\n";
- ss << "s[" << CleanName(stage->op->name) << "].pragma("
- << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint)
+ ss << "s[" << op_name << "].pragma("
+ << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name)
<< ", \"unroll_explicit\", True)\n";
} else {
- ss << "s[" << CleanName(stage->op->name) << "].pragma("
- << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" << pragma_type
- << "\")\n";
+ ss << "s[" << op_name << "].pragma("
+ << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) << ", \""
+ << pragma_type << "\")\n";
}
ApplyToSchedule(stages, stage_to_axes);
@@ -726,11 +729,12 @@ void ReorderStepNode::ApplyToSchedule(Array<te::Stage>* stages,
String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes) const {
const auto& stage = (*stages)[stage_id];
+ const auto& op_name = CleanName(stage->op->name);
std::stringstream ss;
- ss << "s[" << CleanName(stage->op->name) << "].reorder(";
+ ss << "s[" << op_name << "].reorder(";
for (size_t i = 0; i < after_ids.size(); ++i) {
- ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint);
+ ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint, op_name);
if (i != after_ids.size() - 1) {
ss << ", ";
}
@@ -881,16 +885,17 @@ String PrintSplitAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_
int size = static_cast<int>(lengths.size());
if (inner_to_outer) {
for (int i = size - 1; i >= 0; i--) {
- ss << CleanName(outs[size - i]->var->name_hint) << ", "
- << CleanName(outs[size - i - 1]->var->name_hint) << " = s[" << func_name << "].split("
- << CleanName(to_split->var->name_hint) << ", factor=" << lengths[i] << ")\n";
+ ss << CleanName(outs[size - i]->var->name_hint, func_name) << ", "
+ << CleanName(outs[size - i - 1]->var->name_hint, func_name) << " = s[" << func_name
+ << "].split(" << CleanName(to_split->var->name_hint, func_name)
+ << ", factor=" << lengths[i] << ")\n";
to_split = outs[size - i];
}
} else {
for (int i = 0; i < size; i++) {
- ss << CleanName(outs[i]->var->name_hint) << ", " << CleanName(outs[i + 1]->var->name_hint)
- << " = s[" << func_name << "].split(" << CleanName(to_split->var->name_hint)
- << ", nparts=" << lengths[i] << ")\n";
+ ss << CleanName(outs[i]->var->name_hint, func_name) << ", "
+ << CleanName(outs[i + 1]->var->name_hint, func_name) << " = s[" << func_name << "].split("
+ << CleanName(to_split->var->name_hint, func_name) << ", nparts=" << lengths[i] << ")\n";
to_split = outs[i + 1];
}
}
@@ -1195,9 +1200,10 @@ String StorageAlignStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes) const {
std::stringstream ss;
const auto& stage = (*stages)[stage_id];
- ss << "s[" << CleanName(stage->op->name) << "].storage_align("
- << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " << factor << ", "
- << offset << ")\n";
+ const auto& op_name = CleanName(stage->op->name);
+ ss << "s[" << op_name << "].storage_align("
+ << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) << ", " << factor
+ << ", " << offset << ")\n";
ApplyToSchedule(stages, stage_to_axes);
return ss.str();
@@ -1269,8 +1275,11 @@ String ComputeAtStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
std::stringstream ss;
const auto& stage = (*stages)[stage_id];
const auto& target_stage = (*stages)[target_stage_id];
- ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" << CleanName(target_stage->op->name)
- << "], " << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint) << ")\n";
+ const auto& op_name = CleanName(stage->op->name);
+ const auto& target_op_name = CleanName(target_stage->op->name);
+ ss << "s[" << op_name << "].compute_at(s[" << target_op_name << "], "
+ << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint, target_op_name)
+ << ")\n";
ApplyToSchedule(stages, stage_to_axes);
return ss.str();
}
@@ -1516,7 +1525,8 @@ String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxes
}
auto out = ApplyToSchedule(stages, stage_to_axes, schedule);
- ss << CleanName(out->op->name) << " = "
+ const auto& op_name = CleanName(out->op->name);
+ ss << op_name << " = "
<< "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name << "\", ["
<< CleanName(reader_stages[0]->op->name);
for (size_t i = 1; i < reader_stage_ids.size(); ++i) {
@@ -1527,13 +1537,13 @@ String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxes
// Print the iterators of the new added stage
const auto& iters = out->op->root_iter_vars();
for (size_t i = 0; i < iters.size(); ++i) {
- ss << CleanName(iters[i]->var->name_hint);
+ ss << CleanName(iters[i]->var->name_hint, op_name);
if (i != iters.size() - 1) {
ss << ", ";
}
}
ss << " = "
- << "tuple(" << CleanName(out->op->name) << ".op.axis)\n";
+ << "tuple(" << op_name << ".op.axis)\n";
return ss.str();
}
@@ -1652,16 +1662,17 @@ String CacheWriteStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxe
// Print the iterators of the new added stage
for (const auto& out : outs) {
const auto& iters = out->op->root_iter_vars();
+ const auto& op_name = CleanName(out->op->name);
for (size_t i = 0; i < iters.size(); ++i) {
- ss << CleanName(iters[i]->var->name_hint);
+ ss << CleanName(iters[i]->var->name_hint, op_name);
if (i != iters.size() - 1) {
ss << ", ";
}
}
ss << " = "
- << "tuple(" << CleanName(out->op->name) << ".op.axis)"
+ << "tuple(" << op_name << ".op.axis)"
<< " + "
- << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n";
+ << "tuple(" << op_name << ".op.reduce_axis)\n";
}
return ss.str();
@@ -1764,30 +1775,32 @@ String RfactorStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMa
for (const auto& out : outs) {
const auto& iters = out->op->root_iter_vars();
+ const auto& op_name = CleanName(out->op->name);
for (size_t i = 0; i < iters.size(); ++i) {
- ss << CleanName(iters[i]->var->name_hint);
+ ss << CleanName(iters[i]->var->name_hint, op_name);
if (i != iters.size() - 1) {
ss << ", ";
}
}
ss << " = "
- << "tuple(" << CleanName(out->op->name) << ".op.axis)"
+ << "tuple(" << op_name << ".op.axis)"
<< " + "
- << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n";
+ << "tuple(" << op_name << ".op.reduce_axis)\n";
}
const auto& output = (*stages)[stage_id + 1]->op.output(0);
const auto& iters = output->op->root_iter_vars();
+ const auto& op_name = CleanName(output->op->name);
for (size_t i = 0; i < iters.size(); ++i) {
- ss << CleanName(iters[i]->var->name_hint);
+ ss << CleanName(iters[i]->var->name_hint, op_name);
if (i != iters.size() - 1) {
ss << ", ";
}
}
ss << " = "
- << "tuple(s[" << CleanName(output->op->name) << "].op.axis)"
+ << "tuple(s[" << op_name << "].op.axis)"
<< " + "
- << "tuple(s[" << CleanName(output->op->name) << "].op.reduce_axis)\n";
+ << "tuple(s[" << op_name << "].op.reduce_axis)\n";
return ss.str();
}
diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h
index d036743..610fec9 100755
--- a/src/auto_scheduler/utils.h
+++ b/src/auto_scheduler/utils.h
@@ -209,16 +209,20 @@ inline int64_t AxisLengthProd(const Array<tir::IterVar>& axes) {
}
/*!
- * \brief Clean the name of an iterator to make it valid in python code.
+ * \brief Clean the name of an iterator or an op to make it valid in python code.
* \param str The original name.
+ * \param prefix The name prefix to differentiate the same name (e.g., the same iterator names).
* \return The cleaned name.
*/
-inline std::string CleanName(const std::string& str) {
+inline std::string CleanName(const std::string& str, const std::string& prefix = "") {
std::string ret = str;
StrReplace(&ret, ".", "_");
StrReplace(&ret, "@", "_");
StrReplace(&ret, "outer", "o");
StrReplace(&ret, "inner", "i");
+ if (prefix != "") {
+ return prefix + "_" + ret;
+ }
return ret;
}