Posted to commits@tvm.apache.org by lm...@apache.org on 2020/07/21 19:58:59 UTC
[incubator-tvm] branch master updated: [Ansor][AutoTVM v2.0] Phase
1: Add annotation/compute_at/compute_root/compute_inline steps (#6073)
This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 71533a5 [Ansor][AutoTVM v2.0] Phase 1: Add annotation/compute_at/compute_root/compute_inline steps (#6073)
71533a5 is described below
commit 71533a5c252d9c501a20b906c9cc5e6471d3f686
Author: Chenfan <ch...@alibaba-inc.com>
AuthorDate: Wed Jul 22 03:58:48 2020 +0800
[Ansor][AutoTVM v2.0] Phase 1: Add annotation/compute_at/compute_root/compute_inline steps (#6073)
* Add annotation step
* Add compute_at/compute_root/compute_inline
* Doc update
* Update
* Update
* Update measure record UT
* Update
* Update
* Update
* Move state implementation to step
* Move measure_record implementation to step
* Order update & API update
* Update the order of state api
* Update
---
python/tvm/auto_scheduler/loop_state.py | 228 ++++++-
src/auto_scheduler/compute_dag.cc | 26 +-
src/auto_scheduler/loop_state.cc | 359 +++++-----
src/auto_scheduler/loop_state.h | 274 ++++----
src/auto_scheduler/measure_record.cc | 122 +---
src/auto_scheduler/transform_step.cc | 754 ++++++++++++++++++++-
src/auto_scheduler/transform_step.h | 503 +++++++++++++-
src/auto_scheduler/utils.h | 32 +-
.../python/unittest/test_auto_scheduler_common.py | 2 +-
.../unittest/test_auto_scheduler_loop_state.py | 89 ++-
.../python/unittest/test_auto_scheduler_measure.py | 42 +-
11 files changed, 1914 insertions(+), 517 deletions(-)
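Before the patch body, a rough usage sketch of the Python-side schedule primitives this commit adds. This is an illustration only, not part of the patch: it assumes ComputeDAG can be built from the output tensors and exposes get_init_state(), and that stage iterators are reachable via state.stages[...].iters as in the updated unit tests; the tensors and tile sizes are made up.

    from tvm import te, auto_scheduler

    # Illustrative matmul compute definition.
    A = te.placeholder((512, 512), name="A")
    B = te.placeholder((512, 512), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((512, 512), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

    dag = auto_scheduler.ComputeDAG([A, B, C])
    state = dag.get_init_state()

    # A stage can be addressed by its index, Operation, or output tensor.
    i, j, kk = state.stages[2].iters       # stage 2 is C; stages 0 and 1 are the placeholders
    fused = state.fuse(C, [i, j])          # FuseStep over the two spatial loops
    fo, fi = state.split(C, fused, [8])    # SplitStep; returns [outer, inner]
    state.parallel(C, fo)                  # AnnotationStep: parallel
    state.vectorize(C, fi)                 # AnnotationStep: vectorize
    state.unroll(C, kk, max_unroll=16)     # no-op here: extent 512 exceeds max_unroll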
diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py
index 693a668..ab041cf 100644
--- a/python/tvm/auto_scheduler/loop_state.py
+++ b/python/tvm/auto_scheduler/loop_state.py
@@ -83,6 +83,24 @@ class State:
-----
This is a wrapper class of StateObject to deal with copy-on-write property
"""
+
+ # Static translation table for annotation names
+ # This is used to transform an annotation name to the corresponding C++ enum value
+ ANNOTATION_TRANS_TABLE = {
+ "none": 0,
+ "unroll": 1,
+ "vectorize": 2,
+ "parallel": 3,
+ "vthread": 4,
+ "blockIdx.x": 5,
+ "threadIdx.x": 6,
+ "blockIdx.y": 7,
+ "threadIdx.y": 8,
+ "blockIdx.z": 9,
+ "threadIdx.z": 10,
+ "tensorize": 11
+ }
+
def __init__(self, state_object, dag):
self.state_object = state_object
self.compute_dag = dag
@@ -108,20 +126,140 @@ class State:
"""
return [stage.op for stage in self.stages]
+ def bind(self, stage, iterator, thread_name):
+ """ Schedule primitive corresponds to te.bind.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be bound, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
+ iterator : Iterator
+ The iterator to be bound.
+ thread_name : str
+ The thread type to be bound. Candidates:
+ - vthread
+ - blockIdx.x
+ - threadIdx.x
+ - blockIdx.y
+ - threadIdx.y
+ - blockIdx.z
+ - threadIdx.z
+
+ Returns
+ -------
+ res_it : Iterator
+ The bound Iterator.
+ """
+ if thread_name not in State.ANNOTATION_TRANS_TABLE:
+ raise ValueError("Invalid thread_name: ", thread_name)
+
+ self.state_object, res = _ffi_api.StateBind(self.state_object,
+ self._resolve_stage_id(stage), iterator,
+ State.ANNOTATION_TRANS_TABLE[thread_name])
+ return res
+
+ def parallel(self, stage, iterator):
+ """ Schedule primitive corresponds to te.parallel.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be parallelized, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
+ iterator : Iterator
+ The iterator to be parallelized.
+
+ Returns
+ -------
+ res_it : Iterator
+ The parallelized Iterator.
+ """
+ self.state_object, res = _ffi_api.StateParallel(self.state_object,
+ self._resolve_stage_id(stage), iterator)
+ return res
+
+ def unroll(self, stage, iterator, max_unroll=None):
+ """ Schedule primitive corresponds to te.unroll.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be unrolled, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
+ iterator : Iterator
+ The iterator to be unrolled.
+ max_unroll : Optional[int]
+ The max unroll limit. Iterators with extent larger than this limit will be skipped.
+
+ Returns
+ -------
+ res_it : Iterator
+ The unrolled Iterator.
+ """
+ self.state_object, res = _ffi_api.StateUnroll(self.state_object,
+ self._resolve_stage_id(stage), iterator,
+ max_unroll if max_unroll else -1)
+ return res
+
+ def vectorize(self, stage, iterator):
+ """ Schedule primitive corresponds to te.vectorize.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be vectorized, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
+ iterator : Iterator
+ The iterator to be vectorized.
+
+ Returns
+ -------
+ res_it : Iterator
+ The vectorized Iterator.
+ """
+ self.state_object, res = _ffi_api.StateVectorize(self.state_object,
+ self._resolve_stage_id(stage), iterator)
+ return res
+
+ def fuse(self, stage, iters):
+ """ Schedule primitive corresponds to te.fuse.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be fused, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
+ iters : List[Iterator]
+ The iterators to be fused.
+
+ Returns
+ -------
+ res_it : Iterator
+ The fused Iterator.
+
+ Notes
+ -----
+ If the iterators to be fused have stages attached to them (by compute_at), the fused
+ result will become the new attach point.
+ """
+ self.state_object, res = _ffi_api.StateFuse(self.state_object,
+ self._resolve_stage_id(stage), iters)
+ return res
+
def reorder(self, stage, order):
""" Schedule primitive corresponds to te.reorder.
Parameters
----------
stage : Union[int, Operation, Tensor]
- The Stage to be reordered, can be a Stage order index, Stage operation or stage
- output tensor.
+ The Stage to be reordered, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
order : List[Iterator]
Iterators in the expected order.
"""
- stage_id = self._resolve_stage_id(stage)
-
- self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order)
+ self.state_object = _ffi_api.StateReorder(self.state_object, self._resolve_stage_id(stage),
+ order)
def split(self, stage, iterator, lengths, inner_to_outer=True):
""" Schedule primitive corresponds to te.split.
@@ -132,8 +270,8 @@ class State:
Parameters
----------
stage : Union[int, Operation, Tensor]
- The Stage to be split, can be a Stage order index, Stage operation or stage
- output tensor.
+ The Stage to be split, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
iterator : Iterator
The iterator to be split.
lengths: List[int]
@@ -144,34 +282,74 @@ class State:
Returns
-------
res_its : List[Iterator]
- The splitted new Iterators
- """
- stage_id = self._resolve_stage_id(stage)
+ The new Iterators after split.
- self.state_object, res = _ffi_api.StateSplit(self.state_object, stage_id, iterator, lengths,
- inner_to_outer)
+ Notes
+ -----
+ If we split an iterator that has stages attached to it (by compute_at), the innermost
+ iterator of the split results will become the new attach point.
+ """
+ self.state_object, res = _ffi_api.StateSplit(self.state_object,
+ self._resolve_stage_id(stage),
+ iterator, lengths, inner_to_outer)
return res
- def fuse(self, stage, iters):
- """ Schedule primitive corresponds to te.fuse.
+ def compute_at(self, stage, target_stage, target_iter):
+ """ Schedule primitive corresponds to te.compute_at.
Parameters
----------
stage : Union[int, Operation, Tensor]
- The Stage to be fused, can be a Stage order index, Stage operation or stage
- output tensor.
- iters : List[Iterator]
- The iterators to be fused
+ The source Stage of compute_at, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
+ target_stage : Union[int, Operation, Tensor]
+ The target stage of compute_at, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
+ target_iter : Iterator
+ The target Iterator of compute_at.
+
+ Notes
+ -----
+ After compute_at, we need careful dependency analysis to compute the accurate bound
+ information. However, it is relatively expensive and complicated, so we just fill "None"
+ as bound for the newly created iterators.
+ Call ComputeDAG::InferBound on the returned state to get the complete bound information.
+ """
+ self.state_object = _ffi_api.StateComputeAt(self.state_object,
+ self._resolve_stage_id(stage),
+ self._resolve_stage_id(target_stage),
+ target_iter)
- Returns
- -------
- res_it : Iterator
- The fused Iterator
+ def compute_inline(self, stage):
+ """ Schedule primitive corresponds to te.compute_inline.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be inlined, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
"""
- stage_id = self._resolve_stage_id(stage)
+ self.state_object = _ffi_api.StateComputeInline(self.state_object,
+ self._resolve_stage_id(stage))
- self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters)
- return res
+ def compute_root(self, stage):
+ """ Schedule primitive corresponds to te.compute_root.
+
+ Parameters
+ ----------
+ stage : Union[int, Operation, Tensor]
+ The Stage to be computed at root, which can be specified by the integer index, Operation,
+ or output tensor of the stage.
+
+ Notes
+ -----
+ After compute_root, we need careful dependency analysis to compute the accurate bound
+ information. However, it is relatively expensive and complicated, so we just fill "None"
+ as bound for the newly created iterators.
+ Call ComputeDAG::InferBound on the returned state to get the complete bound information.
+ """
+ self.state_object = _ffi_api.StateComputeRoot(self.state_object,
+ self._resolve_stage_id(stage))
def copy(self):
""" Do deep copy of this State. """
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index a7abcb8..d81dff6 100644
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -270,19 +270,9 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
}
// Apply the history steps to TVM schedule
+ // Call each step's ApplyToSchedule method
for (const auto& step : transform_steps) {
- // Call each step's ApplyToSchedule method
- // Note: some steps have extra parameters that must be passed and they may need different
- // return value, so the ApplyToSchedule is not able to be merged to single interface
- if (auto ps = step.as<ReorderStepNode>()) {
- ps->ApplyToSchedule(stages, stage_to_axes);
- } else if (auto ps = step.as<SplitStepNode>()) {
- ps->ApplyToSchedule(stages, stage_to_axes);
- } else if (auto ps = step.as<FuseStepNode>()) {
- ps->ApplyToSchedule(stages, stage_to_axes);
- } else {
- LOG(FATAL) << "Invalid Step";
- }
+ StepApplyToSchedule(step, stages, stage_to_axes);
}
return std::make_pair(schedule, operator->()->tensors);
@@ -326,15 +316,7 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
}
// Call each step's PrintAsPythonAPI method
for (const auto& step : transform_steps) {
- if (auto ps = step.as<ReorderStepNode>()) {
- ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
- } else if (auto ps = step.as<SplitStepNode>()) {
- ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
- } else if (auto ps = step.as<FuseStepNode>()) {
- ss << ps->PrintAsPythonAPI(&stages, &stage_to_axes);
- } else {
- LOG(FATAL) << "Invalid Step";
- }
+ ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes);
}
return ss.str();
@@ -352,7 +334,7 @@ State ComputeDAG::InferBound(const State& state) const {
ret_state = operator->()->init_state;
pstate = ret_state.CopyOnWrite();
pstate->transform_steps = state->transform_steps;
- ret_state.DoSteps(*this);
+ ret_state.ApplySteps(*this);
} else {
ret_state = state;
pstate = ret_state.CopyOnWrite();
diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc
index 1bfcb9e..bfe5478 100644
--- a/src/auto_scheduler/loop_state.cc
+++ b/src/auto_scheduler/loop_state.cc
@@ -90,36 +90,122 @@ Stage::Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters,
data_ = std::move(node);
}
+/********** AttachMap **********/
+void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) {
+ AttachMapNode* pnode = CopyOnWrite();
+
+ // Delete the current entry of this stage
+ DeleteStageEntry(pnode, stage_id);
+
+ // Store the new stage/iterator relations to map
+ IterKey iter_key(target_stage_id, target_iter_id);
+ pnode->stage_to_attach_iter[stage_id] = iter_key;
+ pnode->iter_to_attached_stages[iter_key].push_back(stage_id);
+}
+
+void AttachMap::DeleteStage(int stage_id) {
+ AttachMapNode* pnode = CopyOnWrite();
+ // Delete the original stage entry
+ DeleteStageEntry(pnode, stage_id);
+}
+
+void AttachMap::UpdateIters(const std::vector<IterKey>& original_iters,
+ const std::vector<IterKey>& new_iters) {
+ CHECK_EQ(original_iters.size(), new_iters.size());
+ AttachMapNode* pnode = CopyOnWrite();
+ for (size_t i = 0; i < original_iters.size(); ++i) {
+ auto entry = pnode->iter_to_attached_stages.find(original_iters[i]);
+ // We get <IterKey, std::vector<StageKey>> from this map
+ if (entry == pnode->iter_to_attached_stages.end()) {
+ // Skip if this iterator does not have any attach relations
+ continue;
+ }
+
+ // Update the attach target of a stage to the new iterator in `stage_to_attach_iter`
+ for (const auto& s : entry->second) {
+ pnode->stage_to_attach_iter[s] = new_iters[i];
+ }
+
+ // Remove the original iterator relation from `iter_to_attached_stages` and add the new
+ // iterator to it
+ std::vector<int> attached_stages = std::move(entry->second);
+ pnode->iter_to_attached_stages.erase(entry);
+ pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages);
+ }
+}
+
+void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) {
+ auto old_entry = pnode->stage_to_attach_iter.find(stage_id);
+ // We get <StageKey, IterKey> from this map
+ if (old_entry != pnode->stage_to_attach_iter.end()) {
+ // Delete the stage in `iter_to_attached_stages`; if the corresponding iterator does not have
+ // any attached stage, delete that item too
+ auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second);
+ // We get <IterKey, std::vector<StageKey>> from this map
+ FindAndDeleteItem(&entry2->second, stage_id);
+ if (entry2->second.size() == 0) {
+ pnode->iter_to_attached_stages.erase(entry2);
+ }
+ // Delete the stage in `stage_to_attach_iter`
+ pnode->stage_to_attach_iter.erase(old_entry);
+ }
+}
+
/********** State **********/
State::State(const Array<te::Operation>& ops) {
auto node = make_object<StateNode>();
for (const auto& op : ops) {
node->stages.push_back(Stage(op));
}
+ node->attach_map = AttachMap(make_object<AttachMapNode>());
node->concrete = true;
data_ = std::move(node);
}
/********** Schedule primitives apis for state **********/
-void State::reorder(int stage_id, const Array<Iterator>& order) {
+Iterator State::bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type) {
const Stage& stage = operator->()->stages[stage_id];
- CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators "
- << "should be specified";
- Array<Integer> after_ids;
- GetIndices(stage->iters, order, &after_ids);
- ReorderStep step = ReorderStep(stage_id, after_ids);
+ if (thread_type < IteratorAnnotation::kVThread || thread_type > IteratorAnnotation::kThreadZ) {
+ LOG(FATAL) << "thread_type error, valid: kVThread, kBlockX, kBlockY, "
+ << "kThreadX, kThreadY, kBlockZ, kThreadZ";
+ }
+ AnnotationStep step = AnnotationStep(stage_id, GetIndex(stage->iters, it), thread_type);
CopyOnWrite()->transform_steps.push_back(step);
- DoReorderStep(step);
+ return step->ApplyToState(this);
}
-Array<Iterator> State::split(int stage_id, const Iterator& it,
- const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
+Iterator State::parallel(int stage_id, const Iterator& it) {
const Stage& stage = operator->()->stages[stage_id];
- SplitStep step =
- SplitStep(stage_id, GetIndex(stage->iters, it),
- it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer);
+ AnnotationStep step =
+ AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kParallel);
+ CopyOnWrite()->transform_steps.push_back(step);
+ return step->ApplyToState(this);
+}
+
+Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) {
+ const Stage& stage = operator->()->stages[stage_id];
+
+ // Don't unroll if the extent is larger than max_unroll
+ if (max_unroll != -1 && it->range.defined()) {
+ if (auto imm = it->range->extent.as<IntImmNode>()) {
+ if (imm->value > max_unroll) {
+ return it;
+ }
+ }
+ }
+
+ AnnotationStep step =
+ AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kUnroll);
+ CopyOnWrite()->transform_steps.push_back(step);
+ return step->ApplyToState(this);
+}
+
+Iterator State::vectorize(int stage_id, const Iterator& it) {
+ const Stage& stage = operator->()->stages[stage_id];
+ AnnotationStep step =
+ AnnotationStep(stage_id, GetIndex(stage->iters, it), IteratorAnnotation::kVectorize);
CopyOnWrite()->transform_steps.push_back(step);
- return DoSplitStep(step);
+ return step->ApplyToState(this);
}
Iterator State::fuse(int stage_id, const Array<Iterator>& iters) {
@@ -128,174 +214,59 @@ Iterator State::fuse(int stage_id, const Array<Iterator>& iters) {
GetIndices(stage->iters, iters, &indices);
FuseStep step = FuseStep(stage_id, indices);
CopyOnWrite()->transform_steps.push_back(step);
- return DoFuseStep(step);
+ return step->ApplyToState(this);
}
-/********** Step implementations for state **********/
-void State::DoReorderStep(const ReorderStep& step) {
- const Stage& stage = operator->()->stages[step->stage_id];
- Array<Iterator> iters;
- for (auto x : step->after_ids) {
- iters.push_back(stage->iters[x]);
- }
- StateNode* pstate = CopyOnWrite();
- pstate->stages.Set(step->stage_id,
- Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs));
+void State::reorder(int stage_id, const Array<Iterator>& order) {
+ const Stage& stage = operator->()->stages[stage_id];
+ CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators "
+ << "should be specified";
+ Array<Integer> after_ids;
+ GetIndices(stage->iters, order, &after_ids);
+ ReorderStep step = ReorderStep(stage_id, after_ids);
+ CopyOnWrite()->transform_steps.push_back(step);
+ step->ApplyToState(this);
}
-// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep
-Array<Iterator> State::DoSplitStepCommon(int stage_id, int iter_id,
- const Array<Optional<Integer>>& lengths,
- bool inner_to_outer) {
+Array<Iterator> State::split(int stage_id, const Iterator& it,
+ const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
const Stage& stage = operator->()->stages[stage_id];
- const Iterator& it = stage->iters[iter_id];
- bool concrete = true;
-
- Optional<PrimExpr> tosplit_min, tosplit_extent;
- if (it->range.defined()) {
- tosplit_min = it->range->min;
- tosplit_extent = it->range->extent;
- } else {
- tosplit_min = NullOpt;
- tosplit_extent = NullOpt;
- }
-
- Array<Iterator> outs;
- for (size_t i = 0; i < lengths.size(); ++i) {
- Optional<Integer> l;
- String name;
- if (inner_to_outer) {
- l = lengths[lengths.size() - i - 1];
- name = it->name + "." + std::to_string(lengths.size() - i);
- } else {
- l = lengths[i];
- name = it->name + "." + std::to_string(i);
- }
- Iterator res;
- if (l && tosplit_min && tosplit_extent) {
- res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l.value()), it->iter_kind,
- IteratorAnnotation::kNone);
- tosplit_min = Integer(0);
- tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, l.value());
- } else {
- res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone);
- tosplit_min = NullOpt;
- tosplit_extent = NullOpt;
- concrete = false;
- }
- outs.push_back(std::move(res));
- }
-
- Range range;
- if (tosplit_min && tosplit_extent) {
- range = Range::FromMinExtent(tosplit_min.value(), tosplit_extent.value());
- }
- if (inner_to_outer) {
- outs.push_back(Iterator(it->name + ".0", range, it->iter_kind, IteratorAnnotation::kNone));
- // Reverse the Iterator array
- Array<Iterator> temp(outs.rbegin(), outs.rend());
- outs = std::move(temp);
- } else {
- outs.push_back(Iterator(it->name + "." + std::to_string(lengths.size()), range, it->iter_kind,
- IteratorAnnotation::kNone));
- }
-
- Array<Iterator> new_iters;
- new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id);
- new_iters.insert(new_iters.end(), outs.begin(), outs.end());
- new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end());
-
- StateNode* pstate = CopyOnWrite();
- pstate->stages.Set(stage_id,
- Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
- pstate->concrete &= concrete;
-
- return outs;
+ SplitStep step =
+ SplitStep(stage_id, GetIndex(stage->iters, it),
+ it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer);
+ CopyOnWrite()->transform_steps.push_back(step);
+ return step->ApplyToState(this);
}
-Array<Iterator> State::DoSplitStep(const SplitStep& step) {
- return DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths, step->inner_to_outer);
+void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
+ const Stage& target_stage = operator->()->stages[target_stage_id];
+ ComputeAtStep step =
+ ComputeAtStep(stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter));
+ CopyOnWrite()->transform_steps.push_back(step);
+ step->ApplyToState(this);
}
-Iterator State::DoFuseStep(const FuseStep& step) {
- int stage_id = step->stage_id;
- const Stage& stage = operator->()->stages[stage_id];
-
- String new_name;
- PrimExpr new_extent = 1;
- IteratorKind new_iter_kind = IteratorKind::kSpecial;
-
- for (size_t i = 0; i < step->fused_ids.size(); ++i) {
- if (i > 0) {
- CHECK_EQ(step->fused_ids[i]->value, step->fused_ids[i - 1]->value + 1);
- }
-
- const Iterator& it = stage->iters[step->fused_ids[i]];
- new_name = new_name + it->name + "@";
-
- if (it->range.defined() && new_extent.defined()) {
- new_extent = new_extent * it->range->extent;
- } else {
- new_extent = PrimExpr();
- }
-
- if (i == 0) {
- new_iter_kind = it->iter_kind;
- } else {
- if (new_iter_kind != it->iter_kind) {
- new_iter_kind = IteratorKind::kMixed;
- }
- }
- }
+void State::compute_inline(int stage_id) {
+ ComputeInlineStep step = ComputeInlineStep(stage_id);
+ CopyOnWrite()->transform_steps.push_back(step);
+ step->ApplyToState(this);
+}
- Range range;
- if (new_extent.defined()) {
- range = Range::FromMinExtent(0, new_extent);
- }
- Iterator new_it = Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone);
- Array<Iterator> new_iters;
- new_iters.insert(new_iters.end(), stage->iters.begin(),
- stage->iters.begin() + step->fused_ids.front());
- new_iters.push_back(new_it);
- new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back() + 1,
- stage->iters.end());
-
- StateNode* pstate = CopyOnWrite();
- pstate->stages.Set(stage_id,
- Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
-
- return new_it;
+void State::compute_root(int stage_id) {
+ ComputeRootStep step = ComputeRootStep(stage_id);
+ CopyOnWrite()->transform_steps.push_back(step);
+ step->ApplyToState(this);
}
-void State::DoSteps(const ComputeDAG& dag) {
+void State::ApplySteps(const ComputeDAG& dag) {
CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages.";
+ // Call each step's ApplyToState method
for (const auto& step : operator->()->transform_steps) {
- if (auto ps = step.as<ReorderStepNode>()) {
- DoReorderStep(GetRef<ReorderStep>(ps));
- } else if (auto ps = step.as<SplitStepNode>()) {
- DoSplitStep(GetRef<SplitStep>(ps));
- } else if (auto ps = step.as<FuseStepNode>()) {
- DoFuseStep(GetRef<FuseStep>(ps));
- } else {
- LOG(FATAL) << "Invalid step: " << step;
- }
+ StepApplyToState(step, this, dag);
}
}
-static const char* IteratorAnnotationString[] = {
- "for", // kNone = 0
- "unroll", // kUnroll = 1
- "vectorize", // kVectorize = 2
- "parallel", // kParallel = 3
- "vthread", // kVThread = 4
- "gpu.blockIdx.x", // kBlockX = 5
- "gpu.threadIdx.x", // kThreadX = 6
- "gpu.blockIdx.y", // kBlockY = 7
- "gpu.threadIdx.y", // kThreadY = 8
- "tensorize" // kTensorized = 9
-};
-
// Print stage to ostream
void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_indent,
bool delete_trivial_loop) {
@@ -332,6 +303,17 @@ void PrintStage(std::ostream* os, int stage_id, const State& state, size_t base_
indent += 2;
}
+
+ if (state.defined()) {
+ IterKey iter_key(stage_id, i);
+ auto pair = state->attach_map->iter_to_attached_stages.find(iter_key);
+ if (pair != state->attach_map->iter_to_attached_stages.end()) {
+ // Print the attached stage
+ for (const auto& attach_stage_id : pair->second) {
+ PrintStage(os, attach_stage_id, state, base_indent + indent, delete_trivial_loop);
+ }
+ }
+ }
}
for (size_t j = 0; j < base_indent + indent; ++j) {
@@ -386,6 +368,36 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});
/********** State interface API for ffi **********/
+TVM_REGISTER_GLOBAL("auto_scheduler.StateBind")
+ .set_body_typed([](State state, int stage_id, const Iterator& it, int thread_type) {
+ const auto& res = state.bind(stage_id, it, IteratorAnnotation(thread_type));
+ return Array<ObjectRef>{state, res};
+ });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateParallel")
+ .set_body_typed([](State state, int stage_id, const Iterator& it) {
+ const auto& res = state.parallel(stage_id, it);
+ return Array<ObjectRef>{state, res};
+ });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateUnroll")
+ .set_body_typed([](State state, int stage_id, const Iterator& it, int max_unroll) {
+ const auto& res = state.unroll(stage_id, it, max_unroll);
+ return Array<ObjectRef>{state, res};
+ });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateVectorize")
+ .set_body_typed([](State state, int stage_id, const Iterator& it) {
+ const auto& res = state.vectorize(stage_id, it);
+ return Array<ObjectRef>{state, res};
+ });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse")
+ .set_body_typed([](State state, int stage_id, const Array<Iterator>& iters) {
+ const auto& res = state.fuse(stage_id, iters);
+ return Array<ObjectRef>{state, res};
+ });
+
TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder")
.set_body_typed([](State state, int stage_id, const Array<Iterator>& order) {
state.reorder(stage_id, order);
@@ -399,10 +411,23 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit")
return Array<ObjectRef>{state, res};
});
-TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse")
- .set_body_typed([](State state, int stage_id, const Array<Iterator>& iters) {
- const auto& res = state.fuse(stage_id, iters);
- return Array<ObjectRef>{state, res};
+TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt")
+ .set_body_typed([](State state, int stage_id, int target_stage_id,
+ const Iterator& target_iter) {
+ state.compute_at(stage_id, target_stage_id, target_iter);
+ return state;
+ });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeInline")
+ .set_body_typed([](State state, int stage_id) {
+ state.compute_inline(stage_id);
+ return state;
+ });
+
+TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot")
+ .set_body_typed([](State state, int stage_id) {
+ state.compute_root(stage_id);
+ return state;
});
TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) {
diff --git a/src/auto_scheduler/loop_state.h b/src/auto_scheduler/loop_state.h
index 04e5304..4d6477b 100644
--- a/src/auto_scheduler/loop_state.h
+++ b/src/auto_scheduler/loop_state.h
@@ -51,6 +51,9 @@
#include <tvm/runtime/container.h>
#include <functional>
+#include <unordered_map>
+#include <utility>
+#include <vector>
#include "transform_step.h"
@@ -79,84 +82,6 @@ enum class ComputeAtKind : int {
kIter = 2,
};
-/*! \brief The type of an iterator. */
-enum class IteratorKind : int {
- /*! \brief Spatial iterator. */
- kSpatial = 0,
- /*! \brief Reduction iterator. */
- kReduction = 1,
- /*! \brief Fused spatial and reduction iterator. */
- kMixed = 2,
- /*! \brief Special iterator. (e.g. virtual root iterator) */
- kSpecial = 3
-};
-
-/*! \brief The type of an iterator's annotation. */
-enum class IteratorAnnotation : int {
- /*! \brief This iterator has no annotation. */
- kNone = 0,
- /*! \brief This iterator has been unrolled. */
- kUnroll = 1,
- /*! \brief This iterator has been vectorized. */
- kVectorize = 2,
- /*! \brief This iterator has been paralleld. */
- kParallel = 3,
- /*! \brief This iterator has been bind to vthread. */
- kVThread = 4,
- /*! \brief This iterator has been bind to blockIdx.x. */
- kBlockX = 5,
- /*! \brief This iterator has been bind to threadIdx.x. */
- kThreadX = 6,
- /*! \brief This iterator has been bind to blockIdx.y. */
- kBlockY = 7,
- /*! \brief This iterator has been bind to threadIdx.y. */
- kThreadY = 8,
- /*! \brief This iterator has been mapped with a tensorize intrinsic. */
- kTensorized = 9
-};
-
-/*!
- * \brief A for loop iterator
- * Similar to tvm::IterVar in `include/tvm/tir/expr.h`
- */
-class IteratorNode : public Object {
- public:
- /*! \brief The name of this iterator. */
- String name;
- /*! \brief The range of this iterator. */
- Range range;
- /*! \brief The iterator type of this iterator. */
- IteratorKind iter_kind;
- /*! \brief The annotation type of this iterator. */
- IteratorAnnotation annotation;
-
- void VisitAttrs(tvm::AttrVisitor* v) {
- v->Visit("name", &name);
- v->Visit("range", &range);
- }
-
- static constexpr const char* _type_key = "auto_scheduler.Iterator";
- TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object);
-};
-
-/*!
- * \brief Managed reference to IteratorNode.
- * \sa IteratorNode
- */
-class Iterator : public ObjectRef {
- public:
- /*!
- * \brief The constructor.
- * \param name The name of this iterator.
- * \param range The range of this iterator.
- * \param iter_kind The iterator type of this iterator.
- * \param annotation The annotation type of this iterator.
- */
- Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation);
-
- TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode);
-};
-
/*! \brief Stage-level attributes. */
struct StageAttributes {
/*! \brief The maximum steps for the pragma `auto_unroll_max_step`. */
@@ -167,16 +92,16 @@ struct StageAttributes {
/*!
* \brief A op stage in the compute declaration.
- * Similar to te::Stage in `include/schedule.h`.
+ * Similar to te::Stage in `include/tvm/te/schedule.h`.
*/
class StageNode : public Object {
public:
/*! \brief The operator of this stage */
te::Operation op;
- /*! \brief The type of this stage. */
- StageKind op_type;
/*! \brief The iterators in this stage. */
Array<Iterator> iters;
+ /*! \brief The type of this stage. */
+ StageKind op_type;
/*! \brief The compute location of this stage. */
ComputeAtKind compute_at;
/*! \brief Other stage-level attributes. */
@@ -185,6 +110,8 @@ class StageNode : public Object {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("op", &op);
v->Visit("iters", &iters);
+ v->Visit("op_type", &op_type);
+ v->Visit("compute_at", &compute_at);
}
static constexpr const char* _type_key = "auto_scheduler.Stage";
@@ -217,6 +144,70 @@ class Stage : public ObjectRef {
TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode);
};
+/*! \brief Use stage_id to represent a stage. */
+using StageKey = int;
+/*! \brief Use stage_id and iter_id to represent an iterator. */
+using IterKey = std::pair<int, int>;
+
+/*!
+ * \brief Stores the compute_at relation between stages.
+ * This stores a bi-directional mapping between stages and iterators:
+ * 1. Stage to its attached iterator
+ * 2. Iterator to the stages attached to it
+ * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages
+ * to query the relations
+ */
+class AttachMapNode : public Object {
+ public:
+ /*! \brief A Map to store the mapping of stage to its attached iterator. */
+ std::unordered_map<StageKey, IterKey> stage_to_attach_iter;
+ /*! \brief A Map to store the mapping of iterator to the stage attached to it. */
+ std::unordered_map<IterKey, std::vector<StageKey>> iter_to_attached_stages;
+
+ static constexpr const char* _type_key = "auto_scheduler.AttachMap";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object);
+};
+
+/*!
+ * \brief Managed reference to AttachMapNode.
+ * \sa AttachMapNode
+ */
+class AttachMap : public ObjectRef {
+ public:
+ /*!
+ * \brief Process the stage/iterator mapping after compute_at.
+ * \param stage_id The index of the source stage of compute_at.
+ * \param target_stage_id The index of the target stage of compute_at.
+ * \param target_iter_id The index of the target iterator in the target stage.
+ */
+ void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id);
+ /*!
+ * \brief A public wrapper of `DeleteStageEntry`, used to delete the entry of a specific stage.
+ * \param stage_id The index of the stage whose entry will be deleted.
+ */
+ void DeleteStage(int stage_id);
+ /*!
+ * \brief Find the relations of original iterators in AttachMap, and update them with the new
+ * iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated.
+ * \param original_iters The original IterKeys.
+ * \param new_iters The new IterKeys to be updated to.
+ */
+ void UpdateIters(const std::vector<IterKey>& original_iters,
+ const std::vector<IterKey>& new_iters);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode);
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode);
+
+ private:
+ /*!
+ * \brief Delete the entry of a specific stage. This will remove the items related to this
+ * stage in both `stage_to_attach_iter` and `iter_to_attached_stages` map.
+ * \param pnode A mutable pointer to AttachMapNode.
+ * \param stage_id The index of stage that will be removed from the map.
+ */
+ static void DeleteStageEntry(AttachMapNode* pnode, int stage_id);
+};
+
/*!
* \brief A state in the search process.
* It consists of the current loop structure and a list of transformation steps used to construct
@@ -230,6 +221,11 @@ class StateNode : public Object {
/*! \brief History transformation steps. */
Array<Step> transform_steps;
/*!
+ * \brief The attach relations of stages and iterators. This is used to track the compute at
+ * operation.
+ */
+ AttachMap attach_map;
+ /*!
* \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all
* tile sizes of the state is filled. Only concrete state can be apply to TVM schedule.
*/
@@ -275,17 +271,60 @@ class State : public ObjectRef {
String ToStr(bool delete_trivial_loop = true) const;
/*!
- * \brief General do step functions with a runtime dynamic dispatcher. This will re-apply all the
- * transform steps with the initial state.
+ * \brief General call step functions with a runtime dynamic dispatcher. This will re-apply all
+ * the transform steps from the initial state.
* \param dag The original ComputeDAG of this state.
- * \note This is different from the class member `current_compute_dag`, for some transform step
- * may change the op stage structure of the ComputeDAG.
+ * \note The input `dag` is different from the class member `current_compute_dag`.
+ * This function takes the initial ComputeDAG as input to replay all the history, while
+ * `current_compute_dag` is used to track the current stage status, since some transform steps may
+ * change the op stage structure.
*/
- void DoSteps(const ComputeDAG& dag);
+ void ApplySteps(const ComputeDAG& dag);
- /* Step APIs for State. */
+ /********** Step APIs working on single stage **********/
/*!
+ * \brief Schedule primitive corresponds to te.bind.
+ * \param stage_id The index of the stage to be bound.
+ * \param it The iterator to be bound.
+ * \param thread_type The thread type to be bound. We directly use the IteratorAnnotation as
+ * this input.
+ * \return The iterator result after bind.
+ */
+ Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type);
+ /*!
+ * \brief Schedule primitive corresponds to te.parallel.
+ * \param stage_id The index of the stage to be parallelized.
+ * \param it The iterator to be parallelized.
+ * \return The iterator result after parallel.
+ */
+ Iterator parallel(int stage_id, const Iterator& it);
+ /*!
+ * \brief Schedule primitive corresponds to te.unroll.
+ * \param stage_id The index of the stage to be unrolled.
+ * \param it The iterator to be unrolled.
+ * \param max_unroll The max unroll limit. Iterators with extent larger than this limit will be
+ * skipped.
+ * \return The iterator result after unroll.
+ */
+ Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
+ /*!
+ * \brief Schedule primitive corresponds to te.vectorize.
+ * \param stage_id The index of the stage to be vectorized.
+ * \param it The iterator to be vectorized.
+ * \return The iterator result after vectorize.
+ */
+ Iterator vectorize(int stage_id, const Iterator& it);
+ /*!
+ * \brief Schedule primitive corresponds to te.fuse.
+ * \param stage_id The index of the stage to be fused.
+ * \param iters The iterators to be fused.
+ * \return The iterator result after fuse.
+ * \note If the iterators to be fused have stages attached to them (by compute_at), the fused
+ * result will become the new attach point.
+ */
+ Iterator fuse(int stage_id, const Array<Iterator>& iters);
+ /*!
* \brief Schedule primitive corresponds to te.reorder.
* \param stage_id The index of the stage to be reordered.
* \param order The expected iterator order.
@@ -294,57 +333,46 @@ class State : public ObjectRef {
/*!
* \brief Schedule primitive corresponds to te.split.
* \param stage_id The index of the stage to be split.
- * \param it The iterator the be split.
+ * \param it The iterator to be split.
* \param lengths The multiple split factors. Can be None to be filled by search policy.
* \param inner_to_outer Whether the factor go from inner to outer, or from outer to inner.
* \return The iterator results after split.
+ * \note If we split an iterator that has stages attached to it (by compute_at), the innermost
+ * iterator of the split results will become the new attach point.
*/
Array<Iterator> split(int stage_id, const Iterator& it, const Array<Optional<Integer>>& lengths,
bool inner_to_outer = true);
- /*!
- * \brief Schedule primitive corresponds to te.fuse.
- * \param stage_id The index of the stage to be fused.
- * \param iters The iterators to be fused.
- * \return The iterator result after fuse.
- */
- Iterator fuse(int stage_id, const Array<Iterator>& iters);
- TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
- TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
-
- private:
- /* Do transform steps
- * Note: The following functions only change loop state but do not change transform_history.
- * We separate these functions out, so you can call them for replay easily given history steps */
+ /********** Step APIs working on multiple stages **********/
/*!
- * \brief Apply reorder step to current state.
- * \param step A ReorderStep.
+ * \brief Schedule primitive corresponds to te.compute_at.
+ * \param stage_id The index of the source stage of compute_at.
+ * \param target_stage_id The index of stage that this step will compute at to.
+ * \param target_iter The iterator in target stage that this step will compute at to.
+ * \note After compute_at, we need careful dependency analysis to compute the accurate bound
+ * information. However, it is relatively expensive and complicated, so we just fill "None" as
+ * bound for the newly created iterators.
+ * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
*/
- void DoReorderStep(const ReorderStep& step);
+ void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
/*!
- * \brief Apply split step to current state.
- * \param step A SplitStep.
- * \return The iterator results after split.
+ * \brief Schedule primitive corresponds to te.compute_inline.
+ * \param stage_id The index of the stage to be inlined.
*/
- Array<Iterator> DoSplitStep(const SplitStep& step);
+ void compute_inline(int stage_id);
/*!
- * \brief Apply fuse step to current state.
- * \param step A FuseStep.
- * \return The iterator result after fuse.
+ * \brief Schedule primitive corresponds to te.compute_root.
+ * \param stage_id The index of the stage to be computed at root.
+ * \note After compute_root, we need careful dependency analysis to compute the accurate bound
+ * information. However, it is relatively expensive and complicated, so we just fill "None" as
+ * bound for the newly created iterators.
+ * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
*/
- Iterator DoFuseStep(const FuseStep& step);
+ void compute_root(int stage_id);
- /*!
- * \brief Common function for DoSplitStep and DoFollowSplitStep(Will be added later).
- * \param stage_id The index of the stage to be split.
- * \param iter_id The index of the iterator to be split.
- * \param lengths The multiple split factors.
- * \param inner_to_outer The split direction.
- * \return The iterator results after split.
- */
- Array<Iterator> DoSplitStepCommon(int stage_id, int iter_id,
- const Array<Optional<Integer>>& lengths, bool inner_to_outer);
+ TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
};
} // namespace auto_scheduler
diff --git a/src/auto_scheduler/measure_record.cc b/src/auto_scheduler/measure_record.cc
index f6f882e..39f9ad8 100644
--- a/src/auto_scheduler/measure_record.cc
+++ b/src/auto_scheduler/measure_record.cc
@@ -42,25 +42,6 @@
namespace dmlc {
namespace json {
-inline std::vector<int> IntArrayToVector(const ::tvm::Array<::tvm::Integer>& data) {
- std::vector<int> out;
- for (const auto& x : data) {
- CHECK(x.defined());
- out.push_back(x);
- }
- return out;
-}
-
-inline std::vector<int> IntArrayToVector(
- const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& data) {
- std::vector<int> out;
- for (const auto& x : data) {
- CHECK(x);
- out.push_back(x.value());
- }
- return out;
-}
-
template <>
struct Handler<::tvm::Array<::tvm::auto_scheduler::Stage>> {
inline static void Write(dmlc::JSONWriter* writer,
@@ -82,28 +63,10 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> {
inline static void Write(dmlc::JSONWriter* writer,
const ::tvm::Array<::tvm::auto_scheduler::Step>& data) {
writer->BeginArray(false);
- for (size_t i = 0; i < data.size(); ++i) {
+ for (const auto& step : data) {
writer->WriteArraySeperator();
writer->BeginArray(false);
- if (auto ps = data[i].as<::tvm::auto_scheduler::ReorderStepNode>()) {
- writer->WriteArrayItem(std::string("RE"));
- writer->WriteArrayItem(ps->stage_id);
- writer->WriteArrayItem(IntArrayToVector(ps->after_ids));
- } else if (auto ps = data[i].as<::tvm::auto_scheduler::SplitStepNode>()) {
- writer->WriteArrayItem(std::string("SP"));
- writer->WriteArrayItem(ps->stage_id);
- writer->WriteArrayItem(ps->iter_id);
- writer->WriteArrayItem(ps->extent ? ::tvm::auto_scheduler::GetIntImm(ps->extent.value())
- : 0);
- writer->WriteArrayItem(IntArrayToVector(ps->lengths));
- writer->WriteArrayItem(static_cast<int>(ps->inner_to_outer));
- } else if (auto ps = data[i].as<::tvm::auto_scheduler::FuseStepNode>()) {
- writer->WriteArrayItem(std::string("FU"));
- writer->WriteArrayItem(ps->stage_id);
- writer->WriteArrayItem(IntArrayToVector(ps->fused_ids));
- } else {
- LOG(FATAL) << "Invalid step: " << data[i];
- }
+ step->WriteToRecord(writer);
writer->EndArray();
}
writer->EndArray();
@@ -111,67 +74,12 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> {
inline static void Read(dmlc::JSONReader* reader,
::tvm::Array<::tvm::auto_scheduler::Step>* data) {
- std::vector<int> int_list;
- bool s, inner_to_outer;
- std::string name, scope_name, pragma_type, ti_func_name;
- int stage_id, iter_id, extent;
-
+ bool s;
reader->BeginArray();
data->clear();
while (reader->NextArrayItem()) {
reader->BeginArray();
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&name);
- if (name == "RE") {
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&stage_id);
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&int_list);
- ::tvm::Array<::tvm::Integer> after_ids;
- for (const auto& i : int_list) {
- after_ids.push_back(i);
- }
- data->push_back(::tvm::auto_scheduler::ReorderStep(stage_id, after_ids));
- } else if (name == "SP") {
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&stage_id);
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&iter_id);
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&extent);
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&int_list);
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&inner_to_outer);
- ::tvm::Array<::tvm::Optional<::tvm::Integer>> lengths;
- for (const auto& i : int_list) {
- lengths.push_back(::tvm::Integer(i));
- }
- data->push_back(::tvm::auto_scheduler::SplitStep(
- stage_id, iter_id, extent == 0 ? ::tvm::PrimExpr() : extent, lengths, inner_to_outer));
- } else if (name == "FU") {
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&stage_id);
- s = reader->NextArrayItem();
- CHECK(s);
- reader->Read(&int_list);
- ::tvm::Array<::tvm::Integer> fused_ids;
- for (const auto& i : int_list) {
- fused_ids.push_back(i);
- }
- data->push_back(::tvm::auto_scheduler::FuseStep(stage_id, fused_ids));
- } else {
- LOG(FATAL) << "Invalid step format";
- }
+ data->push_back(::tvm::auto_scheduler::StepReadFromRecord(reader));
s = reader->NextArrayItem();
CHECK(!s);
}
@@ -187,8 +95,8 @@ struct Handler<::tvm::auto_scheduler::StateNode> {
writer->EndArray();
}
inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::StateNode* data) {
- reader->BeginArray();
bool s;
+ reader->BeginArray();
s = reader->NextArrayItem();
CHECK(s);
reader->Read(&data->stages);
@@ -210,18 +118,17 @@ struct Handler<::tvm::auto_scheduler::SearchTaskNode> {
writer->EndArray();
}
inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::SearchTaskNode* data) {
- std::string target_str;
bool s;
-
+ std::string str_value;
reader->BeginArray();
s = reader->NextArrayItem();
CHECK(s);
- reader->Read(&target_str);
- data->workload_key = std::move(target_str);
+ reader->Read(&str_value);
+ data->workload_key = std::move(str_value);
s = reader->NextArrayItem();
CHECK(s);
- reader->Read(&target_str);
- data->target = ::tvm::Target::Create(target_str);
+ reader->Read(&str_value);
+ data->target = ::tvm::Target::Create(str_value);
s = reader->NextArrayItem();
CHECK(!s);
}
@@ -237,11 +144,11 @@ struct Handler<::tvm::auto_scheduler::MeasureInputNode> {
writer->EndArray();
}
inline static void Read(dmlc::JSONReader* reader, ::tvm::auto_scheduler::MeasureInputNode* data) {
- bool s;
auto task_node = ::tvm::make_object<::tvm::auto_scheduler::SearchTaskNode>();
auto state_node = ::tvm::make_object<::tvm::auto_scheduler::StateNode>();
state_node->concrete = true;
+ bool s;
reader->BeginArray();
s = reader->NextArrayItem();
CHECK(s);
@@ -277,15 +184,14 @@ struct Handler<::tvm::auto_scheduler::MeasureResultNode> {
}
inline static void Read(dmlc::JSONReader* reader,
::tvm::auto_scheduler::MeasureResultNode* data) {
+ std::vector<double> double_list;
bool s;
- std::vector<double> tmp;
-
reader->BeginArray();
s = reader->NextArrayItem();
CHECK(s);
- reader->Read(&tmp);
+ reader->Read(&double_list);
data->costs.clear();
- for (const auto& i : tmp) {
+ for (const auto& i : double_list) {
data->costs.push_back(::tvm::FloatImm(::tvm::DataType::Float(64), i));
}
s = reader->NextArrayItem();
diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc
index 90b4db8..6c672a5 100644
--- a/src/auto_scheduler/transform_step.cc
+++ b/src/auto_scheduler/transform_step.cc
@@ -28,7 +28,9 @@
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
+#include <string>
#include <utility>
+#include <vector>
#include "loop_state.h"
#include "utils.h"
@@ -36,6 +38,404 @@
namespace tvm {
namespace auto_scheduler {
+const char* IteratorAnnotationString[] = {
+ "for", // kNone = 0
+ "unroll", // kUnroll = 1
+ "vectorize", // kVectorize = 2
+ "parallel", // kParallel = 3
+ "vthread", // kVThread = 4
+ "blockIdx.x", // kBlockX = 5
+ "threadIdx.x", // kThreadX = 6
+ "blockIdx.y", // kBlockY = 7
+ "threadIdx.y", // kThreadY = 8
+ "blockIdx.z", // kBlockZ = 9
+ "threadIdx.z", // kThreadZ = 10
+ "tensorize" // kTensorized = 11
+};
+
+Step StepReadFromRecord(dmlc::JSONReader* reader) {
+ std::string name;
+ bool s;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&name);
+ if (name == AnnotationStepNode::record_prefix_str) {
+ return AnnotationStep(reader);
+ } else if (name == FuseStepNode::record_prefix_str) {
+ return FuseStep(reader);
+ } else if (name == ReorderStepNode::record_prefix_str) {
+ return ReorderStep(reader);
+ } else if (name == SplitStepNode::record_prefix_str) {
+ return SplitStep(reader);
+ } else if (name == ComputeAtStepNode::record_prefix_str) {
+ return ComputeAtStep(reader);
+ } else if (name == ComputeInlineStepNode::record_prefix_str) {
+ return ComputeInlineStep(reader);
+ } else if (name == ComputeRootStepNode::record_prefix_str) {
+ return ComputeRootStep(reader);
+ } else {
+ LOG(FATAL) << "Invalid step format: " << name;
+ }
+ return Step();
+}
+
+void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
+ if (auto ps = step.as<AnnotationStepNode>()) {
+ ps->ApplyToState(state);
+ } else if (auto ps = step.as<FuseStepNode>()) {
+ ps->ApplyToState(state);
+ } else if (auto ps = step.as<ReorderStepNode>()) {
+ ps->ApplyToState(state);
+ } else if (auto ps = step.as<SplitStepNode>()) {
+ ps->ApplyToState(state);
+ } else if (auto ps = step.as<ComputeAtStepNode>()) {
+ ps->ApplyToState(state);
+ } else if (auto ps = step.as<ComputeInlineStepNode>()) {
+ ps->ApplyToState(state);
+ } else if (auto ps = step.as<ComputeRootStepNode>()) {
+ ps->ApplyToState(state);
+ } else {
+ LOG(FATAL) << "Invalid step: " << step;
+ }
+}
+
+void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) {
+ if (auto ps = step.as<AnnotationStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes);
+ } else if (auto ps = step.as<FuseStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes);
+ } else if (auto ps = step.as<ReorderStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes);
+ } else if (auto ps = step.as<SplitStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes);
+ } else if (auto ps = step.as<ComputeAtStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes);
+ } else if (auto ps = step.as<ComputeInlineStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes);
+ } else if (auto ps = step.as<ComputeRootStepNode>()) {
+ ps->ApplyToSchedule(stages, stage_to_axes);
+ } else {
+ LOG(FATAL) << "Invalid Step: " << step;
+ }
+}
+
+String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) {
+ if (auto ps = step.as<AnnotationStepNode>()) {
+ return ps->PrintAsPythonAPI(stages, stage_to_axes);
+ } else if (auto ps = step.as<FuseStepNode>()) {
+ return ps->PrintAsPythonAPI(stages, stage_to_axes);
+ } else if (auto ps = step.as<ReorderStepNode>()) {
+ return ps->PrintAsPythonAPI(stages, stage_to_axes);
+ } else if (auto ps = step.as<SplitStepNode>()) {
+ return ps->PrintAsPythonAPI(stages, stage_to_axes);
+ } else if (auto ps = step.as<ComputeAtStepNode>()) {
+ return ps->PrintAsPythonAPI(stages, stage_to_axes);
+ } else if (auto ps = step.as<ComputeInlineStepNode>()) {
+ return ps->PrintAsPythonAPI(stages, stage_to_axes);
+ } else if (auto ps = step.as<ComputeRootStepNode>()) {
+ return ps->PrintAsPythonAPI(stages, stage_to_axes);
+ } else {
+ LOG(FATAL) << "Invalid Step: " << step;
+ }
+ return "";
+}
+
+/********** Primitives working on single stage **********/
+
+/********** Annotation **********/
+AnnotationStep::AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann) {
+ auto node = make_object<AnnotationStepNode>();
+ node->stage_id = stage_id;
+ node->iter_id = iter_id;
+ node->annotation = ann;
+ data_ = std::move(node);
+}
+
+AnnotationStep::AnnotationStep(dmlc::JSONReader* reader) {
+ auto node = make_object<AnnotationStepNode>();
+ bool s;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->stage_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->iter_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ int int_val;
+ reader->Read(&int_val);
+ node->annotation = IteratorAnnotation(int_val);
+ data_ = std::move(node);
+}
+
+void AnnotationStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+ writer->WriteArraySeperator();
+ writer->WriteString(record_prefix_str);
+ writer->WriteArrayItem(stage_id);
+ writer->WriteArrayItem(iter_id);
+ writer->WriteArrayItem(static_cast<int>(annotation));
+}
+
+Iterator AnnotationStepNode::ApplyToState(State* state) const {
+ const Stage& stage = (*state)->stages[stage_id];
+ Iterator it = stage->iters[iter_id];
+
+ CHECK(it->annotation == IteratorAnnotation::kNone);
+ Iterator new_it = Iterator(it->name, it->range, it->iter_kind, annotation);
+ Stage new_stage = stage;
+ new_stage.CopyOnWrite()->iters.Set(iter_id, new_it);
+ state->CopyOnWrite()->stages.Set(stage_id, std::move(new_stage));
+ return new_it;
+}
+
+void AnnotationStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
+ te::Stage stage = (*stages)[stage_id];
+ const Array<IterVar>& axes = (*stage_to_axes)[stage];
+
+ switch (annotation) {
+ case IteratorAnnotation::kUnroll:
+ stage.unroll(axes[iter_id]);
+ break;
+ case IteratorAnnotation::kVectorize:
+ stage.vectorize(axes[iter_id]);
+ break;
+ case IteratorAnnotation::kParallel:
+ stage.parallel(axes[iter_id]);
+ break;
+ case IteratorAnnotation::kVThread:
+ case IteratorAnnotation::kBlockX:
+ case IteratorAnnotation::kBlockY:
+ case IteratorAnnotation::kBlockZ:
+ case IteratorAnnotation::kThreadX:
+ case IteratorAnnotation::kThreadY:
+ case IteratorAnnotation::kThreadZ:
+ stage.bind(axes[iter_id],
+ te::thread_axis(Range(), IteratorAnnotationString[static_cast<int>(annotation)]));
+ break;
+ case IteratorAnnotation::kNone:
+ break;
+ default:
+ LOG(FATAL) << "Invalid Annotation " << static_cast<int>(annotation);
+ break;
+ }
+
+ stages->Set(stage_id, std::move(stage));
+}
+
+String AnnotationStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
+ std::stringstream ss;
+ const auto& stage = (*stages)[stage_id];
+ const auto& iter = (*stage_to_axes)[stage][iter_id];
+
+ ss << "s[" << CleanName(stage->op->name) << "].";
+ switch (annotation) {
+ case IteratorAnnotation::kUnroll:
+ ss << "unroll(";
+ break;
+ case IteratorAnnotation::kVectorize:
+ ss << "vectorize(";
+ break;
+ case IteratorAnnotation::kParallel:
+ ss << "parallel(";
+ break;
+ case IteratorAnnotation::kVThread:
+ case IteratorAnnotation::kBlockX:
+ case IteratorAnnotation::kBlockY:
+ case IteratorAnnotation::kBlockZ:
+ case IteratorAnnotation::kThreadX:
+ case IteratorAnnotation::kThreadY:
+ case IteratorAnnotation::kThreadZ:
+ ss << "bind(";
+ break;
+ case IteratorAnnotation::kNone:
+ break;
+ default:
+ LOG(FATAL) << "Invalid annotation " << static_cast<int>(annotation);
+ break;
+ }
+ ss << CleanName(iter->var->name_hint);
+ switch (annotation) {
+ case IteratorAnnotation::kVThread:
+ case IteratorAnnotation::kBlockX:
+ case IteratorAnnotation::kBlockY:
+ case IteratorAnnotation::kBlockZ:
+ case IteratorAnnotation::kThreadX:
+ case IteratorAnnotation::kThreadY:
+ case IteratorAnnotation::kThreadZ:
+ ss << ", tvm.thread_axis(\"" << IteratorAnnotationString[static_cast<int>(annotation)]
+ << "\")";
+ break;
+ default:
+ break;
+ }
+ ss << ")\n";
+
+ ApplyToSchedule(stages, stage_to_axes);
+ return ss.str();
+}
+
+/********** Fuse **********/
+FuseStep::FuseStep(int stage_id, const Array<Integer>& fused_ids) {
+ auto node = make_object<FuseStepNode>();
+ node->stage_id = stage_id;
+ for (const auto& x : fused_ids) {
+ CHECK(x->IsInstance<IntImmNode>());
+ }
+ node->fused_ids = fused_ids;
+ data_ = std::move(node);
+}
+
+FuseStep::FuseStep(dmlc::JSONReader* reader) {
+ auto node = make_object<FuseStepNode>();
+ bool s;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->stage_id);
+ std::vector<int> int_list;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&int_list);
+ ::tvm::Array<::tvm::Integer> fused_ids;
+ for (const auto& i : int_list) {
+ fused_ids.push_back(i);
+ }
+ node->fused_ids = fused_ids;
+ data_ = std::move(node);
+}
+
+void FuseStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+ writer->WriteArraySeperator();
+ writer->WriteString(record_prefix_str);
+ writer->WriteArrayItem(stage_id);
+ writer->WriteArrayItem(IntArrayToVector(fused_ids));
+}
+
+Iterator FuseStepNode::ApplyToState(State* state) const {
+ const Stage& stage = (*state)->stages[stage_id];
+ size_t old_iter_size = static_cast<int>(stage->iters.size());
+
+ String new_name;
+ PrimExpr new_extent = 1;
+ IteratorKind new_iter_kind = IteratorKind::kSpecial;
+
+ for (size_t i = 0; i < fused_ids.size(); ++i) {
+ if (i > 0) {
+ CHECK_EQ(fused_ids[i]->value, fused_ids[i - 1]->value + 1);
+ }
+
+ if (i != fused_ids.size() - 1) {
+ const auto& iter_to_attached_stage = (*state)->attach_map->iter_to_attached_stages;
+ if (iter_to_attached_stage.find(std::make_pair(stage_id, fused_ids[i])) !=
+ iter_to_attached_stage.end()) {
+ LOG(FATAL) << "Invalid Fuse. Trying to fuse iterators that have been attached by some "
+ << "stages. State before fusion:\n"
+ << (*state);
+ }
+ }
+
+ const Iterator& it = stage->iters[fused_ids[i]];
+ new_name = new_name + it->name + "@";
+
+ if (it->range.defined() && new_extent.defined()) {
+ new_extent = new_extent * it->range->extent;
+ } else {
+ new_extent = PrimExpr();
+ }
+
+ if (i == 0) {
+ new_iter_kind = it->iter_kind;
+ } else {
+ if (new_iter_kind != it->iter_kind) {
+ new_iter_kind = IteratorKind::kMixed;
+ }
+ }
+ }
+
+ Range range;
+ if (new_extent.defined()) {
+ range = Range::FromMinExtent(0, new_extent);
+ }
+ Iterator new_it = Iterator(new_name, range, new_iter_kind, IteratorAnnotation::kNone);
+ Array<Iterator> new_iters;
+ new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + fused_ids.front());
+ new_iters.push_back(new_it);
+ new_iters.insert(new_iters.end(), stage->iters.begin() + fused_ids.back() + 1,
+ stage->iters.end());
+
+ StateNode* pstate = state->CopyOnWrite();
+ pstate->stages.Set(stage_id,
+ Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
+
+ // Two vectors are used to represent the iterator relation before and after fuse
+ // The original iterators in AttachMap will be updated with the new iterators
+ std::vector<IterKey> from_iters;
+ std::vector<IterKey> to_iters;
+ const size_t begin_id = fused_ids.front(), end_id = fused_ids.back();
+ for (size_t i = 0; i < old_iter_size; ++i) {
+ if (i <= begin_id) {
+ continue;
+ } else if (i > end_id) {
+ // move forward
+ from_iters.emplace_back(stage_id, i);
+ to_iters.emplace_back(stage_id, i - end_id + begin_id);
+ } else {
+ // move to the fused id
+ from_iters.emplace_back(stage_id, i);
+ to_iters.emplace_back(stage_id, begin_id);
+ }
+ }
+ pstate->attach_map.UpdateIters(from_iters, to_iters);
+
+ return new_it;
+}
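
To make the attach map update above concrete, a small sketch assuming a hypothetical four-iterator stage and fused_ids = [1, 2]:

    # iters:  [i0, i1, i2, i3]  ->  [i0, i1@i2@, i3]
    # attach at (stage_id, 3)   ->  (stage_id, 2)   # shifted forward past the fused range
    # attach at (stage_id, 2)   ->  (stage_id, 1)   # moved onto the fused iterator
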
+
+IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
+ auto stage = (*stages)[stage_id];
+ const Array<IterVar>& axes = stage_to_axes->at(stage);
+
+ Array<IterVar> to_fuse;
+ for (const auto& i : fused_ids) {
+ to_fuse.push_back(axes[i]);
+ }
+ IterVar fused_axis;
+ stage.fuse(to_fuse, &fused_axis);
+
+ Array<IterVar> new_axes;
+ new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front());
+ new_axes.push_back(fused_axis);
+ new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end());
+
+ stage_to_axes->Set(stage, std::move(new_axes));
+ stages->Set(stage_id, std::move(stage));
+ return fused_axis;
+}
+
+String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
+ const auto& stage = (*stages)[stage_id];
+ std::stringstream to_fuse;
+
+ for (size_t i = 0; i < fused_ids.size(); ++i) {
+ to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint);
+ if (i != fused_ids.size() - 1) {
+ to_fuse << ", ";
+ }
+ }
+
+ std::stringstream ss;
+ const auto& fused = ApplyToSchedule(stages, stage_to_axes);
+
+ ss << CleanName(fused->var->name_hint) << " = s[" << CleanName(stage->op->name) << "].fuse("
+ << to_fuse.str() << ")\n";
+
+ return ss.str();
+}
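
As a sketch of the printed result, fusing the first three iterators i, j, k of a stage C emits one fuse call; the left-hand name is whatever CleanName yields for the fused IterVar, so all names here are illustrative:

    i_j_k_fused = s[C].fuse(i, j, k)

The same step is serialized by WriteToRecord as a JSON fragment of the form ["FU", stage_id, [0, 1, 2]].
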
+
/********** Reorder **********/
ReorderStep::ReorderStep(int stage_id, const Array<Integer>& after_ids) {
auto node = make_object<ReorderStepNode>();
@@ -47,6 +447,41 @@ ReorderStep::ReorderStep(int stage_id, const Array<Integer>& after_ids) {
data_ = std::move(node);
}
+ReorderStep::ReorderStep(dmlc::JSONReader* reader) {
+ auto node = make_object<ReorderStepNode>();
+ bool s;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->stage_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ std::vector<int> int_list;
+ reader->Read(&int_list);
+ ::tvm::Array<::tvm::Integer> after_ids;
+ for (const auto& i : int_list) {
+ after_ids.push_back(i);
+ }
+ node->after_ids = after_ids;
+ data_ = std::move(node);
+}
+
+void ReorderStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+ writer->WriteArraySeperator();
+ writer->WriteString(record_prefix_str);
+ writer->WriteArrayItem(stage_id);
+ writer->WriteArrayItem(IntArrayToVector(after_ids));
+}
+
+void ReorderStepNode::ApplyToState(State* state) const {
+ const Stage& stage = (*state)->stages[stage_id];
+ Array<Iterator> iters;
+ for (auto x : after_ids) {
+ iters.push_back(stage->iters[x]);
+ }
+ state->CopyOnWrite()->stages.Set(
+ stage_id, Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs));
+}
+
void ReorderStepNode::ApplyToSchedule(Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes) const {
auto stage = (*stages)[stage_id];
@@ -83,6 +518,86 @@ String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
}
/********** Split **********/
+// common part for SplitStep, FollowSplitStep, and FollowFusedSplitStep
+Array<Iterator> ApplySplitToState(State* state, int stage_id, int iter_id,
+ const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
+ const Stage& stage = (*state)->stages[stage_id];
+ const Iterator& it = stage->iters[iter_id];
+ size_t old_iter_size = stage->iters.size();
+ bool concrete = true;
+
+ Optional<PrimExpr> tosplit_min, tosplit_extent;
+ if (it->range.defined()) {
+ tosplit_min = it->range->min;
+ tosplit_extent = it->range->extent;
+ } else {
+ tosplit_min = NullOpt;
+ tosplit_extent = NullOpt;
+ }
+
+ Array<Iterator> outs;
+ for (size_t i = 0; i < lengths.size(); ++i) {
+ Optional<Integer> l;
+ String name;
+ if (inner_to_outer) {
+ l = lengths[lengths.size() - i - 1];
+ name = it->name + "." + std::to_string(lengths.size() - i);
+ } else {
+ l = lengths[i];
+ name = it->name + "." + std::to_string(i);
+ }
+ Iterator res;
+ if (l && tosplit_min && tosplit_extent) {
+ res = Iterator(name, Range::FromMinExtent(tosplit_min.value(), l.value()), it->iter_kind,
+ IteratorAnnotation::kNone);
+ tosplit_min = Integer(0);
+ tosplit_extent = indexdiv(tosplit_extent.value() + l.value() - 1, l.value());
+ } else {
+ res = Iterator(name, Range(), it->iter_kind, IteratorAnnotation::kNone);
+ tosplit_min = NullOpt;
+ tosplit_extent = NullOpt;
+ concrete = false;
+ }
+ outs.push_back(std::move(res));
+ }
+
+ Range range;
+ if (tosplit_min && tosplit_extent) {
+ range = Range::FromMinExtent(tosplit_min.value(), tosplit_extent.value());
+ }
+ if (inner_to_outer) {
+ outs.push_back(Iterator(it->name + ".0", range, it->iter_kind, IteratorAnnotation::kNone));
+ // Reverse the Iterator array
+ Array<Iterator> temp(outs.rbegin(), outs.rend());
+ outs = std::move(temp);
+ } else {
+ outs.push_back(Iterator(it->name + "." + std::to_string(lengths.size()), range, it->iter_kind,
+ IteratorAnnotation::kNone));
+ }
+
+ Array<Iterator> new_iters;
+ new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id);
+ new_iters.insert(new_iters.end(), outs.begin(), outs.end());
+ new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end());
+
+ StateNode* pstate = state->CopyOnWrite();
+ pstate->stages.Set(stage_id,
+ Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
+ pstate->concrete &= concrete;
+
+ // Two vectors are used to represent the iterator relation before and after split
+ // The original iterators in AttachMap will be updated with the new iterators
+ std::vector<IterKey> from_iters;
+ std::vector<IterKey> to_iters;
+ for (size_t i = iter_id; i < old_iter_size; ++i) {
+ from_iters.emplace_back(stage_id, i);
+ to_iters.emplace_back(stage_id, i + lengths.size());
+ }
+ pstate->attach_map.UpdateIters(from_iters, to_iters);
+
+ return outs;
+}
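
A worked example of the extent computation above, with hypothetical numbers: splitting an iterator i of extent 512 by lengths [8, 16] with inner_to_outer = true keeps the given lengths as the inner extents and derives the outer extent through the repeated indexdiv:

    # i (extent 512), lengths = [8, 16], inner_to_outer = True
    #   i.0  extent 4    (= ceil(512 / (8 * 16)))
    #   i.1  extent 8
    #   i.2  extent 16
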
+
Array<IterVar> ApplySplitToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
int stage_id, int iter_id,
const Array<Optional<Integer>>& lengths, bool inner_to_outer) {
@@ -171,6 +686,51 @@ SplitStep::SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent,
data_ = std::move(node);
}
+SplitStep::SplitStep(dmlc::JSONReader* reader) {
+ auto node = make_object<SplitStepNode>();
+ bool s;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->stage_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->iter_id);
+ int int_val;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&int_val);
+ if (int_val) {
+ node->extent = Integer(int_val);
+ }
+ s = reader->NextArrayItem();
+ CHECK(s);
+ std::vector<int> int_list;
+ reader->Read(&int_list);
+ ::tvm::Array<::tvm::Optional<::tvm::Integer>> lengths;
+ for (const auto& i : int_list) {
+ lengths.push_back(::tvm::Integer(i));
+ }
+ node->lengths = lengths;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->inner_to_outer);
+ data_ = std::move(node);
+}
+
+void SplitStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+ writer->WriteArraySeperator();
+ writer->WriteString(record_prefix_str);
+ writer->WriteArrayItem(stage_id);
+ writer->WriteArrayItem(iter_id);
+ writer->WriteArrayItem(extent ? GetIntImm(extent.value()) : 0);
+ writer->WriteArrayItem(IntArrayToVector(lengths));
+ writer->WriteArrayItem(static_cast<int>(inner_to_outer));
+}
+
+Array<Iterator> SplitStepNode::ApplyToState(State* state) const {
+ return ApplySplitToState(state, stage_id, iter_id, lengths, inner_to_outer);
+}
+
Array<IterVar> SplitStepNode::ApplyToSchedule(Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes) const {
return ApplySplitToSchedule(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
@@ -181,57 +741,185 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
}
-/********** Fuse **********/
-FuseStep::FuseStep(int stage_id, const Array<Integer>& fused_ids) {
- auto node = make_object<FuseStepNode>();
+/********** Primitives working on multiple stages **********/
+
+/********** Compute At **********/
+ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id) {
+ auto node = make_object<ComputeAtStepNode>();
node->stage_id = stage_id;
- for (const auto& x : fused_ids) {
- CHECK(x->IsInstance<IntImmNode>());
- }
- node->fused_ids = fused_ids;
+ node->target_stage_id = target_stage_id;
+ node->target_iter_id = target_iter_id;
data_ = std::move(node);
}
-IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>* stages,
- StageToAxesMap* stage_to_axes) const {
- auto stage = (*stages)[stage_id];
- const Array<IterVar>& axes = stage_to_axes->at(stage);
+ComputeAtStep::ComputeAtStep(dmlc::JSONReader* reader) {
+ auto node = make_object<ComputeAtStepNode>();
+ bool s;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->stage_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->target_stage_id);
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->target_iter_id);
+ data_ = std::move(node);
+}
- Array<IterVar> to_fuse;
- for (const auto& i : fused_ids) {
- to_fuse.push_back(axes[i]);
+void ComputeAtStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+ writer->WriteArraySeperator();
+ writer->WriteString(record_prefix_str);
+ writer->WriteArrayItem(stage_id);
+ writer->WriteArrayItem(target_stage_id);
+ writer->WriteArrayItem(target_iter_id);
+}
+void ComputeAtStepNode::ApplyToState(State* state) const {
+ const Stage& stage = (*state)->stages[stage_id];
+
+ // Remove the bound information of each iterator since they may not be accurate after
+ // compute at
+ Array<Iterator> new_iters;
+ for (const Iterator& it : stage->iters) {
+ new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation));
}
- IterVar fused_axis;
- stage.fuse(to_fuse, &fused_axis);
- Array<IterVar> new_axes;
- new_axes.insert(new_axes.end(), axes.begin(), axes.begin() + fused_ids.front());
- new_axes.push_back(fused_axis);
- new_axes.insert(new_axes.end(), axes.begin() + fused_ids.back() + 1, axes.end());
+ StateNode* pstate = state->CopyOnWrite();
+ pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
+ ComputeAtKind::kIter, stage->attrs));
+ // Update attach map
+ pstate->attach_map.SetComputeAtIter(stage_id, target_stage_id, target_iter_id);
+}
+
+void ComputeAtStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
+ te::Stage stage = (*stages)[stage_id];
+ const auto& target_stage = (*stages)[target_stage_id];
+ const auto& target_axis = (*stage_to_axes)[target_stage][target_iter_id];
+ stage.compute_at(target_stage, target_axis);
- stage_to_axes->Set(stage, std::move(new_axes));
stages->Set(stage_id, std::move(stage));
- return fused_axis;
}
-String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
- StageToAxesMap* stage_to_axes) const {
+String ComputeAtStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
+ std::stringstream ss;
const auto& stage = (*stages)[stage_id];
- std::stringstream to_fuse;
+ const auto& target_stage = (*stages)[target_stage_id];
+ ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" << CleanName(target_stage->op->name)
+ << "], " << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint) << ")\n";
+ ApplyToSchedule(stages, stage_to_axes);
+ return ss.str();
+}
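
The printed form of a compute_at step is likewise a single schedule call; a minimal sketch with illustrative names, where ax2 stands for the cleaned var name of the target iterator:

    s[conv].compute_at(s[relu], ax2)
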
- for (size_t i = 0; i < fused_ids.size(); ++i) {
- to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint);
- if (i != fused_ids.size() - 1) {
- to_fuse << ", ";
- }
+/********** Compute Inline **********/
+ComputeInlineStep::ComputeInlineStep(int stage_id) {
+ auto node = make_object<ComputeInlineStepNode>();
+ node->stage_id = stage_id;
+ data_ = std::move(node);
+}
+
+ComputeInlineStep::ComputeInlineStep(dmlc::JSONReader* reader) {
+ auto node = make_object<ComputeInlineStepNode>();
+ bool s;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->stage_id);
+ data_ = std::move(node);
+}
+
+void ComputeInlineStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+ writer->WriteArraySeperator();
+ writer->WriteString(record_prefix_str);
+ writer->WriteArrayItem(stage_id);
+}
+
+void ComputeInlineStepNode::ApplyToState(State* state) const {
+ const Stage& stage = (*state)->stages[stage_id];
+
+ // Check the validity of compute_inline
+ for (size_t i = 0; i < stage->iters.size(); ++i) {
+ CHECK_EQ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, i)), 0)
+ << "Invalid compute_inline: There are some other stages that are attached to the "
+ << "target stage";
}
+ StateNode* pstate = state->CopyOnWrite();
+ auto new_stage = pstate->stages[stage_id];
+ new_stage.CopyOnWrite()->compute_at = ComputeAtKind::kInlined;
+ pstate->stages.Set(stage_id, std::move(new_stage));
+ // Update attach map
+ pstate->attach_map.DeleteStage(stage_id);
+}
+
+void ComputeInlineStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
+ auto stage = (*stages)[stage_id];
+ stage.compute_inline();
+ stages->Set(stage_id, std::move(stage));
+}
+
+String ComputeInlineStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
std::stringstream ss;
- const auto& fused = ApplyToSchedule(stages, stage_to_axes);
+ const auto& stage = (*stages)[stage_id];
+ ss << "s[" << CleanName(stage->op->name) << "].compute_inline()\n";
+ ApplyToSchedule(stages, stage_to_axes);
+ return ss.str();
+}
- ss << CleanName(fused->var->name_hint) << " = s[" << CleanName(stage->op->name) << "].fuse("
- << to_fuse.str() << ")\n";
+/********** Compute Root **********/
+ComputeRootStep::ComputeRootStep(int stage_id) {
+ auto node = make_object<ComputeRootStepNode>();
+ node->stage_id = stage_id;
+ data_ = std::move(node);
+}
+ComputeRootStep::ComputeRootStep(dmlc::JSONReader* reader) {
+ auto node = make_object<ComputeRootStepNode>();
+ bool s;
+ s = reader->NextArrayItem();
+ CHECK(s);
+ reader->Read(&node->stage_id);
+ data_ = std::move(node);
+}
+
+void ComputeRootStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+ writer->WriteArraySeperator();
+ writer->WriteString(record_prefix_str);
+ writer->WriteArrayItem(stage_id);
+}
+
+void ComputeRootStepNode::ApplyToState(State* state) const {
+ const Stage& stage = (*state)->stages[stage_id];
+
+ // Remove the bound information of each iterator since they may not be accurate after
+ // compute root
+ Array<Iterator> new_iters;
+ for (const Iterator& it : stage->iters) {
+ new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation));
+ }
+
+ StateNode* pstate = state->CopyOnWrite();
+ pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
+ ComputeAtKind::kRoot, stage->attrs));
+ // Update attach map
+ pstate->attach_map.DeleteStage(stage_id);
+}
+
+void ComputeRootStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
+ auto stage = (*stages)[stage_id];
+ stage.compute_root();
+ stages->Set(stage_id, std::move(stage));
+}
+
+String ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes) const {
+ std::stringstream ss;
+ const auto& stage = (*stages)[stage_id];
+ ss << "s[" << CleanName(stage->op->name) << "].compute_root()\n";
+ ApplyToSchedule(stages, stage_to_axes);
return ss.str();
}
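
The compute_inline and compute_root steps print to one-line schedule calls as well; a sketch with an illustrative stage name:

    s[D].compute_inline()
    s[D].compute_root()

On the State side, compute_at and compute_root clear the stored iterator ranges, which is why the unit tests below call dag.infer_bound_from_state after applying them.
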
diff --git a/src/auto_scheduler/transform_step.h b/src/auto_scheduler/transform_step.h
index d840cc0..ce3ca50 100644
--- a/src/auto_scheduler/transform_step.h
+++ b/src/auto_scheduler/transform_step.h
@@ -20,29 +20,34 @@
/*!
* \file auto_scheduler/transform_step.h
* \brief Transformation steps. For each schedule primitive, there is a corresponding transform
- * step. The implementation of each step consists of 2 parts:
- * - transform_step.cc: How each step interacts with TE and TE's schedule primitives
- * - loop_state.cc: How each step updates LoopState
+ * step.
*
* \note To add a new transform step:
* Take fuse step for example:
- * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its construction
- * function `FuseStep::FuseStep(...)` in `transform_steps.cc`
- * 2. Implement `FuseStepNode::ApplyToSchedule` and `FuseStepNode::PrintAsPythonAPI`.
+ * 1. Define class `FuseStepNode`, `FuseStep` in `transform_steps.h`, and implement its first
+ *    constructor `FuseStep::FuseStep()` in `transform_steps.cc`.
+ * 2. Implement `FuseStepNode::ApplyToSchedule()` and `FuseStepNode::PrintAsPythonAPI()`.
* - In these two functions you need to lower this step with tvm's te schedule API
- * 3. Implement `State::fuse` and `State::DoFuseStep`.
+ * 3. Implement `FuseStepNode::ApplyToState` and the state API `State::fuse`.
* - In these two functions you need to incrementally update all data structures in State with
- * CopyOnWrite style
- * 4. Add you step to `ComputeDAG::ApplySteps` and make sure it works.
- * 5. Add log record serialization support in `struct Handler<Array<::tvm::auto_scheduler::Step>>`
- * in `record.cc`.
- * 6. Add its corresponding Python API to `loop_state.py` and necessary unit test.
+ * CopyOnWrite style.
+ * 4. Add your step implementation to `StepApplyToState`, `StepApplyToSchedule` and
+ *    `StepPrintAsPythonAPI`, and make sure it works.
+ * 5. Log record serialization support (a sketch of the record shapes follows this comment block):
+ *    - Add `FuseStepNode::WriteToRecord`, which takes a mutable JSONWriter pointer as input and
+ *      outputs the record to it.
+ *    - Add another constructor that takes a mutable JSONReader as input; it reads a step record
+ *      from the reader and creates the step.
+ * - Add the step implementation to `StepReadFromRecord`.
+ * 6. Add its corresponding Python API to `loop_state.py` and the necessary unit tests; the tests
+ *    should at least consist of two parts: a functional test and a record serialization test.
*/
#ifndef TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_
#define TVM_AUTO_SCHEDULER_TRANSFORM_STEP_H_
#include <dmlc/common.h>
+#include <dmlc/json.h>
#include <tvm/node/node.h>
#include <tvm/te/schedule.h>
@@ -53,6 +58,92 @@ namespace auto_scheduler {
typedef Map<tvm::te::Stage, Array<tir::IterVar>, ObjectHash, ObjectEqual> StageToAxesMap;
+/*! \brief The type of an iterator. */
+enum class IteratorKind : int {
+ /*! \brief Spatial iterator. */
+ kSpatial = 0,
+ /*! \brief Reduction iterator. */
+ kReduction = 1,
+ /*! \brief Fused spatial and reduction iterator. */
+ kMixed = 2,
+ /*! \brief Special iterator. (e.g. virtual root iterator) */
+ kSpecial = 3
+};
+
+/*! \brief The type of an iterator's annotation. */
+enum class IteratorAnnotation : int {
+ /*! \brief This iterator has no annotation. */
+ kNone = 0,
+ /*! \brief This iterator has been unrolled. */
+ kUnroll = 1,
+ /*! \brief This iterator has been vectorized. */
+ kVectorize = 2,
+  /*! \brief This iterator has been parallelized. */
+  kParallel = 3,
+  /*! \brief This iterator has been bound to vthread. */
+  kVThread = 4,
+  /*! \brief This iterator has been bound to blockIdx.x. */
+  kBlockX = 5,
+  /*! \brief This iterator has been bound to threadIdx.x. */
+  kThreadX = 6,
+  /*! \brief This iterator has been bound to blockIdx.y. */
+  kBlockY = 7,
+  /*! \brief This iterator has been bound to threadIdx.y. */
+  kThreadY = 8,
+  /*! \brief This iterator has been bound to blockIdx.z. */
+  kBlockZ = 9,
+  /*! \brief This iterator has been bound to threadIdx.z. */
+  kThreadZ = 10,
+ /*! \brief This iterator has been mapped with a tensorize intrinsic. */
+ kTensorize = 11
+};
+
+extern const char* IteratorAnnotationString[];
+
+/*!
+ * \brief A for-loop iterator.
+ * Similar to tvm::IterVar in `include/tvm/tir/expr.h`.
+ */
+class IteratorNode : public Object {
+ public:
+ /*! \brief The name of this iterator. */
+ String name;
+ /*! \brief The range of this iterator. */
+ Range range;
+ /*! \brief The iterator type of this iterator. */
+ IteratorKind iter_kind;
+ /*! \brief The annotation type of this iterator. */
+ IteratorAnnotation annotation;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("name", &name);
+ v->Visit("range", &range);
+ v->Visit("iter_kind", &iter_kind);
+ v->Visit("annotation", &annotation);
+ }
+
+ static constexpr const char* _type_key = "auto_scheduler.Iterator";
+ TVM_DECLARE_FINAL_OBJECT_INFO(IteratorNode, Object);
+};
+
+/*!
+ * \brief Managed reference to IteratorNode.
+ * \sa IteratorNode
+ */
+class Iterator : public ObjectRef {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param name The name of this iterator.
+ * \param range The range of this iterator.
+ * \param iter_kind The iterator type of this iterator.
+ * \param annotation The annotation type of this iterator.
+ */
+ Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode);
+};
+
/*!
* \brief The base class of transformation steps. Each step has its corresponding tvm.te
* schedule primitives.
@@ -62,6 +153,12 @@ class StepNode : public Object {
/*! \brief The index of the stage. */
int stage_id;
+ /*!
+ * \brief Serialize the current step record to JSONWriter.
+ * \param writer The output JSONWriter.
+ */
+ virtual void WriteToRecord(dmlc::JSONWriter* writer) const = 0;
+
static constexpr const char* _type_key = "auto_scheduler.Step";
TVM_DECLARE_BASE_OBJECT_INFO(StepNode, Object);
};
@@ -75,6 +172,172 @@ class Step : public ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode);
};
+// Forward declaration
+class State;
+class ComputeDAG;
+
+/*!
+ * \brief Read a step record from JSONReader and create the corresponding step.
+ * \param reader The input JSONReader.
+ */
+Step StepReadFromRecord(dmlc::JSONReader* reader);
+
+/*!
+ * \brief Apply the step to State.
+ * \param step The step to be applied to State.
+ * \param state A mutable pointer to State.
+ * \param dag The original ComputeDAG of this state.
+ */
+void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag);
+
+/*!
+ * \brief Apply the step to tvm.schedule.
+ * \param step The step to be applied to tvm.schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ */
+void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes);
+
+/*!
+ * \brief Print the step as equivalent python schedule API.
+ * \param step The step to be applied to python API.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \return Python schedule code.
+ */
+String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
+ StageToAxesMap* stage_to_axes);
+
+/********** Primitives working on single stage **********/
+
+/*!
+ * \brief Annotation step that corresponds to vectorize, parallel, unroll and thread binding.
+ * (i.e. te::Stage::vectorize, te::Stage::parallel, te::Stage::unroll, te::Stage::bind)
+ */
+class AnnotationStepNode : public StepNode {
+ public:
+ /*! \brief The index of the iterator to add annotation. */
+ int iter_id;
+ /*! \brief The annotation type of this step. */
+ IteratorAnnotation annotation;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
+   * \return The iterator result after annotation.
+ */
+ Iterator ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ */
+ void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+ /*!
+ * \brief Print the current step as equivalent python schedule API.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \return Python schedule code.
+ */
+ String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+ static constexpr const char* record_prefix_str = "AN";
+
+ static constexpr const char* _type_key = "auto_scheduler.AnnotationStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to AnnotationStepNode.
+ * \sa AnnotationStepNode
+ */
+class AnnotationStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to add annotation.
+ * \param iter_id The index of the iterator to add annotation.
+ * \param ann The annotation type of this step.
+ */
+ AnnotationStep(int stage_id, int iter_id, IteratorAnnotation ann);
+
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader.
+ */
+ explicit AnnotationStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(AnnotationStep, Step, AnnotationStepNode);
+};
+
+/*! \brief Fuse step that corresponds to te::Stage::fuse */
+class FuseStepNode : public StepNode {
+ public:
+ /*! \brief The ids of iterators to fuse. */
+ Array<Integer> fused_ids;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
+ * \return The iterator result after fuse.
+   * \note If the iterators to be fused have stages attached to them (by compute_at), the fused
+ * result will become the new attach point.
+ */
+ Iterator ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \return The iterator result after fuse.
+ */
+ tir::IterVar ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+ /*!
+ * \brief Print the current step as equivalent python schedule API.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \return Python schedule code.
+ */
+ String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+ static constexpr const char* record_prefix_str = "FU";
+
+ static constexpr const char* _type_key = "auto_scheduler.FuseStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to FuseStepNode.
+ * \sa FuseStepNode
+ */
+class FuseStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+ * \param stage_id The index of the stage to be fused.
+ * \param fused_ids The index of the iterators to be fused.
+ */
+ FuseStep(int stage_id, const Array<Integer>& fused_ids);
+
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader.
+ */
+ explicit FuseStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode);
+};
+
/*! \brief Reorder step that corresponds to te::Stage::reorder */
class ReorderStepNode : public StepNode {
public:
@@ -84,21 +347,31 @@ class ReorderStepNode : public StepNode {
*/
Array<Integer> after_ids;
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
+ */
+ void ApplyToState(State* state) const;
+
/*!
- * \brief Apply the current state to tvm.schedule
+ * \brief Apply the current step to tvm.schedule.
* \param stages A pointer to a `te::Stage` Array.
* \param stage_to_axes A pointer to a StageToAxesMap.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
- * \brief Print step as equivalent python schedule API.
+ * \brief Print the current step as equivalent python schedule API.
* \param stages A pointer to a `te::Stage` Array.
* \param stage_to_axes A pointer to a StageToAxesMap.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+ static constexpr const char* record_prefix_str = "RE";
+
static constexpr const char* _type_key = "auto_scheduler.ReorderStep";
TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object);
};
@@ -116,6 +389,13 @@ class ReorderStep : public Step {
*/
ReorderStep(int stage_id, const Array<Integer>& after_ids);
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader.
+ */
+ explicit ReorderStep(dmlc::JSONReader* reader);
+
TVM_DEFINE_OBJECT_REF_METHODS(ReorderStep, Step, ReorderStepNode);
};
@@ -137,8 +417,19 @@ class SplitStepNode : public StepNode {
*/
bool inner_to_outer;
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
/*!
- * \brief Apply the current state to tvm.schedule
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
+ * \return The iterator results after split.
+ * \note If we do split on an iterator which has stages attached at it(by compute_at), the inner
+ * most iterator of split results will become the new attach point.
+ */
+ Array<Iterator> ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
* \param stages A pointer to a `te::Stage` Array.
* \param stage_to_axes A pointer to a StageToAxesMap.
* \return The iterator results after split.
@@ -147,13 +438,15 @@ class SplitStepNode : public StepNode {
StageToAxesMap* stage_to_axes) const;
/*!
- * \brief Print step as equivalent python schedule API.
+ * \brief Print the current step as equivalent python schedule API.
* \param stages A pointer to a `te::Stage` Array.
* \param stage_to_axes A pointer to a StageToAxesMap.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+ static constexpr const char* record_prefix_str = "SP";
+
static constexpr const char* _type_key = "auto_scheduler.SplitStep";
TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object);
};
@@ -175,49 +468,195 @@ class SplitStep : public Step {
SplitStep(int stage_id, int iter_id, Optional<PrimExpr> extent,
const Array<Optional<Integer>>& lengths, bool inner_to_outer);
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader.
+ */
+ explicit SplitStep(dmlc::JSONReader* reader);
+
TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
};
-/*! \brief Fuse step that corresponds to te::Stage::fuse */
-class FuseStepNode : public StepNode {
+/********** Primitives working on multiple stages **********/
+
+/*! \brief Compute at step that corresponds to te::Stage::compute_at */
+class ComputeAtStepNode : public StepNode {
public:
- /*! \brief The ids of iterators to fuse. */
- Array<Integer> fused_ids;
+  /*! \brief The index of the target stage that this step will compute at. */
+  int target_stage_id;
+  /*! \brief The index of the iterator in the target stage to compute at. */
+  int target_iter_id;
+
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
+ * \note After compute_at, we need careful dependency analysis to compute the accurate bound
+ * information. However, it is relatively expensive and complicated, so we just fill "None" as
+ * bound for the newly created iterators.
+ * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
+ */
+ void ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ */
+ void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+ /*!
+ * \brief Print the current step as equivalent python schedule API.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \return Python schedule code.
+ */
+ String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+ static constexpr const char* record_prefix_str = "CA";
+
+ static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to ComputeAtStepNode.
+ * \sa ComputeAtStepNode
+ */
+class ComputeAtStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+   * \param stage_id The index of the stage to be computed at the target.
+   * \param target_stage_id The index of the target stage that this step will compute at.
+   * \param target_iter_id The index of the iterator in the target stage to compute at.
+ */
+ ComputeAtStep(int stage_id, int target_stage_id, int target_iter_id);
+
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader.
+ */
+ explicit ComputeAtStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(ComputeAtStep, Step, ComputeAtStepNode);
+};
+
+/*! \brief Compute inline step that corresponds to te::Stage::compute_inline */
+class ComputeInlineStepNode : public StepNode {
+ public:
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
/*!
- * \brief Apply the current state to tvm.schedule
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
+ */
+ void ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
* \param stages A pointer to a `te::Stage` Array.
* \param stage_to_axes A pointer to a StageToAxesMap.
* \return The iterator result after fuse.
*/
- tir::IterVar ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+ void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
/*!
- * \brief Print step as equivalent python schedule API.
+ * \brief Print the current step as equivalent python schedule API.
* \param stages A pointer to a `te::Stage` Array.
* \param stage_to_axes A pointer to a StageToAxesMap.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
- static constexpr const char* _type_key = "auto_scheduler.FuseStep";
- TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object);
+ static constexpr const char* record_prefix_str = "CI";
+
+ static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object);
};
/*!
- * \brief Managed reference to FuseStepNode.
- * \sa FuseStepNode
+ * \brief Managed reference to ComputeInlineStepNode.
+ * \sa ComputeInlineStepNode
*/
-class FuseStep : public Step {
+class ComputeInlineStep : public Step {
public:
/*!
* \brief The constructor.
- * \param stage_id The index of the stage to be fused.
- * \param fused_ids The index of the iterators to be fused.
+   * \param stage_id The index of the stage to be computed inline.
*/
- FuseStep(int stage_id, const Array<Integer>& fused_ids);
+ explicit ComputeInlineStep(int stage_id);
- TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode);
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader.
+ */
+ explicit ComputeInlineStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(ComputeInlineStep, Step, ComputeInlineStepNode);
+};
+
+/*! \brief Compute root step that corresponds to te::Stage::compute_root */
+class ComputeRootStepNode : public StepNode {
+ public:
+ void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+ /*!
+ * \brief Apply the current step to State.
+ * \param state A mutable pointer to State.
+   * \note After compute_root, we need careful dependency analysis to compute the accurate bound
+ * information. However, it is relatively expensive and complicated, so we just fill "None" as
+ * bound for the newly created iterators.
+ * Call ComputeDAG::InferBound on the updated state to get the complete bound information.
+ */
+ void ApplyToState(State* state) const;
+
+ /*!
+ * \brief Apply the current step to tvm.schedule.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ */
+ void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+ /*!
+ * \brief Print the current step as equivalent python schedule API.
+ * \param stages A pointer to a `te::Stage` Array.
+ * \param stage_to_axes A pointer to a StageToAxesMap.
+ * \return Python schedule code.
+ */
+ String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+ static constexpr const char* record_prefix_str = "CR";
+
+ static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep";
+ TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to ComputeRootStepNode.
+ * \sa ComputeRootStepNode
+ */
+class ComputeRootStep : public Step {
+ public:
+ /*!
+ * \brief The constructor.
+   * \param stage_id The index of the stage to be computed at root.
+ */
+ explicit ComputeRootStep(int stage_id);
+
+ /*!
+ * \brief The constructor used to read a step record from JSONReader and create the
+ * corresponding step.
+ * \param reader The input JSONReader.
+ */
+ explicit ComputeRootStep(dmlc::JSONReader* reader);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode);
};
} // namespace auto_scheduler
diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h
index 5637780..de800da 100644
--- a/src/auto_scheduler/utils.h
+++ b/src/auto_scheduler/utils.h
@@ -63,7 +63,7 @@ struct hash<std::tuple<T1, T2, T3>> {
namespace tvm {
namespace auto_scheduler {
-/********** Utilities for Array, std::string **********/
+/********** Utilities for Array, std::vector, std::string **********/
/*! \brief Get the first appearance index of elements in an Array */
template <typename T>
inline void GetIndices(const Array<T>& array, const Array<T>& to_locate, Array<Integer>* indices) {
@@ -89,6 +89,15 @@ inline int GetIndex(const Array<T>& array, const T& to_locate) {
return -1;
}
+/*! \brief Delete the item in a std::vector if it exists. */
+template <typename T>
+inline void FindAndDeleteItem(std::vector<T>* array, const T& to_delete) {
+ auto iter = std::find(array->begin(), array->end(), to_delete);
+ if (iter != array->end()) {
+ array->erase(iter);
+ }
+}
+
/*! \brief Replace a sub-string to another sub-string in a string */
inline void StrReplace(std::string* base, const std::string& from, const std::string& to) {
auto pos = base->find(from);
@@ -98,6 +107,27 @@ inline void StrReplace(std::string* base, const std::string& from, const std::st
}
}
+/*! \brief Convert an Array<Integer> to std::vector<int>. */
+inline std::vector<int> IntArrayToVector(const ::tvm::Array<::tvm::Integer>& data) {
+ std::vector<int> out;
+ for (const auto& x : data) {
+ CHECK(x.defined());
+ out.push_back(x);
+ }
+ return out;
+}
+
+/*! \brief Convert an Array<Optional<Integer>> to std::vector<int>. */
+inline std::vector<int> IntArrayToVector(
+ const ::tvm::Array<::tvm::Optional<::tvm::Integer>>& data) {
+ std::vector<int> out;
+ for (const auto& x : data) {
+ CHECK(x);
+ out.push_back(x.value());
+ }
+ return out;
+}
+
/********** Utilities for TVM Containers / ByteArray **********/
/*! \brief Compute mean of a FloatImm array */
inline double FloatArrayMean(const Array<PrimExpr>& float_array) {
diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py
index 078e1ae..fa22fdc 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -40,7 +40,7 @@ def matmul_auto_scheduler_test_rename_0(N, M, K):
C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
return [A, B, C]
-
+@auto_scheduler.register_workload
def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1):
data = te.placeholder((N, CI, H, W), name='Data')
kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='Kernel')
diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py
index 0801d92..32ea8fa 100644
--- a/tests/python/unittest/test_auto_scheduler_loop_state.py
+++ b/tests/python/unittest/test_auto_scheduler_loop_state.py
@@ -26,8 +26,8 @@ import topi
from test_auto_scheduler_common import matmul_auto_scheduler_test, conv2d_nchw_bn_relu
-def test_split_fuse_reorder():
- A, B, C = matmul_auto_scheduler_test(512, 512, 512)
+def test_split_fuse_reorder_annotation():
+ A, B, C = matmul_auto_scheduler_test(N=512, M=512, K=512)
dag = auto_scheduler.ComputeDAG([A, B, C])
s0 = dag.get_init_state()
i, j, k = s0[C].iters
@@ -61,5 +61,88 @@ def test_split_fuse_reorder():
assert s1[C].iters[4].range.extent == 8
assert s1[C].iters[5].range.extent == 2
+ res = s1.bind(C, i1, "blockIdx.x")
+ assert res == s1[C].iters[0]
+ assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["blockIdx.x"]
+
+ res = s1.bind(C, i2, "vthread")
+ assert res == s1[C].iters[1]
+ assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vthread"]
+
+ res = s1.bind(C, i3, "threadIdx.y")
+ assert res == s1[C].iters[2]
+ assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["threadIdx.y"]
+
+ res = s1.parallel(C, j1)
+ assert res == s1[C].iters[3]
+ assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["parallel"]
+
+ res = s1.unroll(C, j2)
+ assert res == s1[C].iters[4]
+ assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["unroll"]
+
+ res = s1.vectorize(C, j3)
+ assert res == s1[C].iters[5]
+ assert res.annotation == auto_scheduler.loop_state.State.ANNOTATION_TRANS_TABLE["vectorize"]
+
+
+def test_compute_at_root_inline():
+ dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(N=1, H=224, W=224, CI=3, CO=64,
+ kernel_size=7, strides=2, padding=3))
+ s0 = dag.get_init_state()
+
+ # data, padding, kernel = 0, 1, 2
+ conv = s0.stage_ops[3]
+ # bias = 4
+ bias_add = s0.stage_ops[5]
+ # bn_scale = 6
+ bn_mul = s0.stage_ops[7]
+ # bn_offset = 8
+ bn_add = s0.stage_ops[9]
+ relu = s0.stage_ops[10]
+
+ s0.compute_inline(bn_add)
+ assert s0[bn_add].compute_at == 1
+
+ s0.compute_inline(bn_mul)
+ assert s0[bn_mul].compute_at == 1
+
+ s0.compute_inline(bias_add)
+ assert s0[bias_add].compute_at == 1
+
+ assert s0[conv].iters[0].range.extent == 1
+ assert s0[conv].iters[1].range.extent == 64
+ assert s0[conv].iters[2].range.extent == 112
+ assert s0[conv].iters[3].range.extent == 112
+ assert s0[conv].iters[4].range.extent == 3
+ assert s0[conv].iters[5].range.extent == 7
+ assert s0[conv].iters[6].range.extent == 7
+ s0.compute_at(conv, relu, s0[relu].iters[2])
+ assert s0[conv].compute_at == 2
+ s0 = dag.infer_bound_from_state(s0)
+ assert s0[conv].iters[0].range.extent == 1
+ assert s0[conv].iters[1].range.extent == 1
+ assert s0[conv].iters[2].range.extent == 1
+ assert s0[conv].iters[3].range.extent == 112
+ assert s0[conv].iters[4].range.extent == 3
+ assert s0[conv].iters[5].range.extent == 7
+ assert s0[conv].iters[6].range.extent == 7
+
+ s0.compute_root(bn_mul)
+ assert s0[bn_mul].compute_at == 0
+
+ s0.compute_root(conv)
+ assert s0[conv].compute_at == 0
+ s0 = dag.infer_bound_from_state(s0)
+ assert s0[conv].iters[0].range.extent == 1
+ assert s0[conv].iters[1].range.extent == 64
+ assert s0[conv].iters[2].range.extent == 112
+ assert s0[conv].iters[3].range.extent == 112
+ assert s0[conv].iters[4].range.extent == 3
+ assert s0[conv].iters[5].range.extent == 7
+ assert s0[conv].iters[6].range.extent == 7
+
+
if __name__ == "__main__":
- test_split_fuse_reorder()
+ test_split_fuse_reorder_annotation()
+ test_compute_at_root_inline()
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index d6e6c51..333d20e 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -18,7 +18,8 @@
""" Test measurement and log serialization. """
import tvm
-from tvm import auto_scheduler
+import topi
+from tvm import te, auto_scheduler
import tempfile
from test_auto_scheduler_common import get_tiled_matmul
@@ -28,7 +29,44 @@ def test_record():
if not tvm.runtime.enabled("llvm"):
return
- dag, s = get_tiled_matmul()
+ A = te.placeholder((512, 512), name='A')
+ B = te.placeholder((512, 512), name='B')
+ k = te.reduce_axis((0, 512), name='k')
+ C = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
+ D = topi.nn.relu(C)
+ k = te.reduce_axis((0, 512), name='k')
+ E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='C')
+ F = topi.nn.relu(E)
+
+ dag = auto_scheduler.ComputeDAG([A, B, F])
+ s = dag.get_init_state()
+
+ # Split
+ its0 = s.split(C, s[C].iters[0], [4, 8, 8])
+ its1 = s.split(C, s[C].iters[4], [8, 4, 4])
+ # Reorder
+ s.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], its0[3], s[C].iters[8],
+ its1[3]])
+ # Fuse
+ s.fuse(C, [s[C].iters[0], s[C].iters[1], s[C].iters[2]])
+ # Compute at
+ s.split(F, s[F].iters[0], [2])
+ s.compute_at(E, F, s[F].iters[0])
+ # Compute inline
+ s.compute_inline(D)
+ # Compute root
+ s.compute_root(D)
+ # Parallel
+ s.parallel(C, s[C].iters[0])
+    # Thread bind (blockIdx & threadIdx are GPU-specific; used here only to test record serialization)
+ s.bind(C, s[C].iters[1], "blockIdx.x")
+ s.bind(C, s[C].iters[2], "threadIdx.z")
+ s.bind(C, s[C].iters[3], "vthread")
+ # Unroll
+ s.unroll(C, s[C].iters[4])
+ # Vectorize
+ s.vectorize(C, s[C].iters[6])
+
target = tvm.target.create("llvm")
task = auto_scheduler.SearchTask(dag, "test", target)