You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/05/01 20:27:54 UTC
[incubator-tvm] branch master updated: [REFACTOR][BOYC] Non
recursive partitioning (#5493)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 9c1e74c [REFACTOR][BOYC] Non recursive partitioning (#5493)
9c1e74c is described below
commit 9c1e74ce0727ac7aacd012b35ac068a25cbc9a42
Author: Zhi <51...@users.noreply.github.com>
AuthorDate: Fri May 1 13:27:44 2020 -0700
[REFACTOR][BOYC] Non recursive partitioning (#5493)
* non recursive partitioning
* refactor maps
* rebase upstream
* refactor shared output
* address comments
Co-authored-by: Cody Yu <co...@gmail.com>
---
src/relay/transforms/partition_graph.cc | 393 ++++++++++----------------------
1 file changed, 115 insertions(+), 278 deletions(-)
diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index 3b0d6bc..634434d 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -54,39 +54,30 @@ namespace partitioning {
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
-/*!
- * \brief The checker that verifies if a Relay program is annotated correctly
- * for partitioning.
+/*! \brief This struct maintains the required metadata for a region to generate a corresponding
+ * global function and function call. The global function is passed to the target-specific codegen,
+ * and the function call is used in the transformed Relay graph to invoke the function at runtime.
*/
-class AnnotationChecker : public ExprVisitor {
- public:
- bool Check() {
- if (!found_start_ && !found_end_) {
- LOG(WARNING) << "No compiler annotation found";
- } else if (!found_start_) {
- LOG(ERROR) << "compiler_begin annotation is missing";
- return false;
- } else if (!found_end_) {
- LOG(ERROR) << "compiler_end annotation is missing";
- return false;
- }
- return true;
- }
+struct RegionFuncMetadata {
+ /*! \brief The call node of the generated global function for this region. */
+ Call func_call;
- void VisitExpr_(const CallNode* call) final {
- auto op_node = call->op.as<OpNode>();
- if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
- return;
- } else if (call->op == compiler_begin_op) {
- found_start_ = true;
- } else if (call->op == compiler_end_op) {
- found_end_ = true;
- }
- }
+ /*! \brief A list of (var, expr) argument pairs. Each var is used as a parameter of the
+ * generated function; the paired expr is the region input that the var stands for.
+ */
+ std::vector<std::pair<Var, Expr>> args;
- private:
- bool found_start_{false};
- bool found_end_{false};
+ /*! \brief Map from each region output expr (the argument of a compiler_end node) to the
+ * expr that replaces it: the function call itself, or a TupleGetItem on it if multi-output.
+ */
+ std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> region_func_out;
+
+ /*! \brief Map from each region input expr (the first non-annotation ancestor of a
+ * compiler_begin node) to the corresponding function parameter. This cache makes sure
+ * a region function does not get duplicated parameters even if the region refers to
+ * the same expr multiple times.
+ */
+ std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> region_func_in;
};
/*! \brief This class partitions the expr labeled with begin and end annotations
@@ -124,37 +115,35 @@ class AnnotationChecker : public ExprVisitor {
* the compiler name.
*/
-class Partitioner : public ExprMutator {
+class Partitioner : public MixedModeMutator {
public:
explicit Partitioner(const IRModule& module) : module_(module) {
for (auto f : module->functions) {
GlobalVar f_var = f.first;
BaseFunc f_func = f.second;
- // Creating regionset per function in the module
+ // Creating regionset per function in the module.
auto region_set = AnnotatedRegionSet::Create(f_func, partitioning::compiler_begin_op,
partitioning::compiler_end_op);
regions_sets_[region_set] = f_func;
}
}
- Expr VisitExpr_(const CallNode* call) final {
+ Expr Rewrite_(const CallNode* call, const Expr& post) final {
auto op_node = call->op.as<OpNode>();
if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
- return ExprMutator::VisitExpr_(call);
+ return post;
} else if (call->op == compiler_begin_op) {
- // The annotation node is inserted on edge so it must have only one
- // argument.
+ // The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);
// Traverse the rest graph.
Expr parent = call->args[0];
- auto input_expr = VisitExpr(parent);
+ auto input_expr = Downcast<Call>(post)->args[0];
// Backtrace the parent to find the first ancestor node that is not a begin or end op
while (const auto* parent_call = parent.as<CallNode>()) {
- if (parent_call->op == compiler_begin_op ||
- parent_call->op == compiler_end_op) {
+ if (parent_call->op == compiler_begin_op || parent_call->op == compiler_end_op) {
parent = parent_call->args[0];
} else {
break;
@@ -165,8 +154,8 @@ class Partitioner : public ExprMutator {
int index = GetArgIdx(sg, GetRef<Call>(call));
CHECK_NE(index, -1);
- if (shared_output_.count(parent) && shared_output_[parent].count(sg)) {
- return shared_output_[parent][sg];
+ if (region_func_meta_[sg].region_func_in.count(parent)) {
+ return region_func_meta_[sg].region_func_in[parent];
} else {
// The type of the created variable is the same as the compiler_begin
// node.
@@ -177,11 +166,11 @@ class Partitioner : public ExprMutator {
std::pair<Var, Expr> cand = std::make_pair(var, input_expr);
- if (std::find(region_args[sg].begin(), region_args[sg].end(), cand) ==
- region_args[sg].end()) {
- region_args[sg].push_back(cand);
+ if (std::find(region_func_meta_[sg].args.begin(), region_func_meta_[sg].args.end(), cand) ==
+ region_func_meta_[sg].args.end()) {
+ region_func_meta_[sg].args.push_back(cand);
}
- shared_output_[parent][sg] = var;
+ region_func_meta_[sg].region_func_in[parent] = var;
return std::move(var);
}
} else {
@@ -197,114 +186,21 @@ class Partitioner : public ExprMutator {
BaseFunc f = GetFunc(GetRef<Call>(call));
// Traverse subgraph inputs.
- auto input = VisitExpr(call->args[0]);
+ auto input = Downcast<Call>(post)->args[0];
CHECK(region.defined()) << "Region not defined for " << GetRef<Call>(call);
// functions are created for each annotated regions,
// when their first output is encountered.
// If multiple outputs are there, a tuple node is inserted at the end.
- // region_function_calls is map that maintains
- // (each annotated regions) --> created function
- if (region_function_calls.find(region) == region_function_calls.end()) {
- // First time this region is encountered in the traversal.
- // Creating the function.
+ if (!region_func_meta_[region].func_call.defined()) {
+ // First time this region is encountered in the traversal. Creating the function.
CreateFunction(region, call);
}
- // Retrieve this particular output of function.
- return GetFunctionOutput(region, GetRef<Call>(call));
- }
- }
-
- Expr VisitExpr_(const TupleNode* op) final {
- auto region = GetRegion(GetRef<Tuple>(op));
- if (!region.defined()) {
- return ExprMutator::VisitExpr_(op);
- } else {
- Array<Expr> fields;
- for (auto field : op->fields) {
- fields.push_back(VisitExpr(field));
- }
- return Tuple(fields);
- }
- }
-
- Expr VisitExpr_(const TupleGetItemNode* g) final {
- auto region = GetRegion(GetRef<TupleGetItem>(g));
- if (!region.defined()) {
- return ExprMutator::VisitExpr_(g);
- } else {
- auto t = VisitExpr(g->tuple);
- return TupleGetItem(t, g->index);
- }
- }
-
- Expr VisitExpr_(const FunctionNode* op) final {
- auto region = GetRegion(GetRef<Function>(op));
- if (!region.defined()) {
- return ExprMutator::VisitExpr_(op);
- } else {
- Array<Var> params;
- for (auto param : op->params) {
- Var new_param = Downcast<Var>(VisitExpr(param));
- params.push_back(new_param);
- }
- auto body = VisitExpr(op->body);
- return Function(params, body, op->ret_type, op->type_params, op->attrs);
- }
- }
-
- Expr VisitExpr_(const LetNode* op) final {
- auto region = GetRegion(GetRef<Let>(op));
- if (!region.defined()) {
- return ExprMutator::VisitExpr_(op);
- } else {
- Var var = Downcast<Var>(VisitExpr(op->var));
- auto value = VisitExpr(op->value);
- auto body = VisitExpr(op->body);
- return Let(var, value, body);
- }
- }
-
- Expr VisitExpr_(const IfNode* op) final {
- auto region = GetRegion(GetRef<If>(op));
- if (!region.defined()) {
- return ExprMutator::VisitExpr_(op);
- } else {
- auto guard = VisitExpr(op->cond);
- auto true_b = VisitExpr(op->true_branch);
- auto false_b = VisitExpr(op->false_branch);
- return If(guard, true_b, false_b);
- }
- }
-
- Expr VisitExpr_(const RefCreateNode* op) final {
- auto region = GetRegion(GetRef<RefCreate>(op));
- if (!region.defined()) {
- return ExprMutator::VisitExpr_(op);
- } else {
- Expr value = VisitExpr(op->value);
- return RefCreate(value);
- }
- }
- Expr VisitExpr_(const RefReadNode* op) final {
- auto region = GetRegion(GetRef<RefRead>(op));
- if (!region.defined()) {
- return ExprMutator::VisitExpr_(op);
- } else {
- Expr ref = VisitExpr(op->ref);
- return RefRead(ref);
- }
- }
-
- Expr VisitExpr_(const RefWriteNode* op) final {
- auto region = GetRegion(GetRef<RefWrite>(op));
- if (!region.defined()) {
- return ExprMutator::VisitExpr_(op);
- } else {
- Expr ref = VisitExpr(op->ref);
- Expr value = VisitExpr(op->value);
- return RefWrite(ref, value);
+ // Retrieve this particular output of function.
+ Expr region_out_expr = Downcast<Call>(GetRef<Call>(call))->args[0];
+ CHECK(region_func_meta_[region].region_func_out.count(region_out_expr));
+ return region_func_meta_[region].region_func_out[region_out_expr];
}
}
@@ -370,24 +266,22 @@ class Partitioner : public ExprMutator {
}
/*!
- * \brief This function is called first time that we encounter a compiler_end
- * node to create the function for the subgraph.
+ * \brief Create a function and its function call for the given region. If the function has
+ * multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes
+ * will be created to serve output consumers.
*/
- void CreateFunction(AnnotatedRegion region, const CallNode* call) {
- // Create fields which is a unique list of outputs. Also populate
- // region_return_indices_ map which maps parent of compiler_end node to
- // corresponding index in fields.
+ void CreateFunction(AnnotatedRegion region, const CallNode* end_node) {
+ // Create fields which is a unique list of outputs.
Array<Expr> fields;
- int i = 0;
- for (auto ret : region->GetOutputs()) {
- auto ret_node = Downcast<Call>(ret)->args[0];
+ std::unordered_map<Expr, int, ObjectHash, ObjectEqual> out_expr_to_idx;
+ int out_idx = 0;
+ for (auto region_end_node : region->GetOutputs()) {
+ auto ret_node = Downcast<Call>(region_end_node)->args[0];
// Don't duplicate outputs.
- if (!region_return_indices_.count(region) ||
- !region_return_indices_[region].count(ret_node)) {
- auto ret_expr = VisitExpr(ret_node);
+ if (!out_expr_to_idx.count(ret_node)) {
+ auto ret_expr = MixedModeMutator::VisitExpr(ret_node);
fields.push_back(ret_expr);
- region_return_indices_[region][ret_node] = i;
- i++;
+ out_expr_to_idx[ret_node] = out_idx++;
}
}
@@ -396,20 +290,14 @@ class Partitioner : public ExprMutator {
Map<Var, Expr> params_bind;
auto IsConstant = [](const Expr& expr) {
- if (expr->IsInstance<ConstantNode>())
- return true;
- if (expr->IsInstance<TupleNode>()) {
- auto tuple = expr.as<TupleNode>();
- for (const auto& field : tuple->fields) {
- if (!field->IsInstance<ConstantNode>())
- return false;
- }
- return true;
- }
- return false;
+ if (expr->IsInstance<ConstantNode>()) return true;
+ if (!expr->IsInstance<TupleNode>()) return false;
+ const auto* tn = expr.as<TupleNode>();
+ return std::all_of(tn->fields.begin(), tn->fields.end(),
+ [](const Expr& e) { return e->IsInstance<ConstantNode>(); });
};
- for (auto pair : region_args[region]) {
+ for (auto pair : region_func_meta_[region].args) {
params.push_back(pair.first);
if (IsConstant(pair.second)) {
params_bind.Set(pair.first, pair.second);
@@ -422,23 +310,21 @@ class Partitioner : public ExprMutator {
if (fields.size() == 1) {
// If there are only a single output; no need to add a tuple
global_region_func =
- Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
+ Function(params, fields[0], end_node->args[0]->checked_type_, {}, DictAttrs());
} else {
auto tuple = Tuple(fields);
global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
}
- std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+ std::string target = end_node->attrs.as<CompilerAttrs>()->compiler;
std::string name = target + "_" + std::to_string(region->GetID());
- global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
- runtime::String(name));
global_region_func =
- WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
- global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
- tvm::runtime::String(target));
+ WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, runtime::String(name));
+ global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func =
- WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
+ WithAttr(std::move(global_region_func), attr::kCompiler, tvm::runtime::String(target));
+ global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
// Constant propagation
if (!params_bind.empty()) {
@@ -446,8 +332,7 @@ class Partitioner : public ExprMutator {
}
std::string fname = name;
- CHECK(!module_->ContainGlobalVar(fname))
- << "Global function " << fname << " already exists";
+ CHECK(!module_->ContainGlobalVar(fname)) << "Global function " << fname << " already exists";
// Create a global function and add it to the IRModule for the region.
// This way we lift the functions that should be handled by external
// codegen to the module scope and rely on the pass manager to prevent
@@ -456,129 +341,81 @@ class Partitioner : public ExprMutator {
GlobalVar glob_func(fname);
module_->Add(glob_func, global_region_func);
- // The return type of callnode is the same as the type of the
- // compiler_end node.
- auto ret = Call(glob_func, param_expr);
- region_function_calls[region] = ret;
- }
+ // Create a call node for the function.
+ auto call = Call(glob_func, param_expr);
+ region_func_meta_[region].func_call = call;
- /*!
- * \brief Get the return(output) of the function for compiler end node "end_arg".
- * This will return either a Call (for a function with a single output) or a
- * TupleGetItem (for a function with multiple outputs).
- */
- Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
- Expr arg = Downcast<Call>(end_arg)->args[0];
- // Function has one output.
- if (region_return_indices_[region].size() == 1) {
- return region_function_calls[region];
- }
- // Function has multiple outputs.
- // Use already made TupleGetItem.
- if (region_return_tuplegetitem_.count(region) &&
- region_return_tuplegetitem_[region].count(arg)) {
- return region_return_tuplegetitem_[region][arg];
+ // Create output expr(s) for the function call.
+ if (out_expr_to_idx.size() == 1) {
+ // A single output directly uses the call node as the output expr.
+ region_func_meta_[region].region_func_out[out_expr_to_idx.begin()->first] = call;
+ } else {
+ // Multiple outputs need TupleGetItem nodes on the call as output exprs.
+ for (auto pair : out_expr_to_idx) {
+ Expr region_out_expr = pair.first; // The arg of a compiler end node of this region.
+ int idx = pair.second; // Corresponding function output tuple index.
+ auto tuple_get_item = TupleGetItem(call, idx);
+ tuple_get_item->checked_type_ = region_out_expr->checked_type_;
+ region_func_meta_[region].region_func_out[region_out_expr] = tuple_get_item;
+ }
}
- // Create new TupleGetItem.
- CHECK(region_return_indices_.count(region) &&
- region_return_indices_[region].count(arg));
- int index = region_return_indices_[region][arg];
-
- auto func_call = region_function_calls[region];
- auto tuple_get_item_ = TupleGetItem(func_call, index);
- tuple_get_item_->checked_type_ = arg->checked_type_;
- region_return_tuplegetitem_[region][arg] = tuple_get_item_;
- return std::move(tuple_get_item_);
}
- /*!
- * \brief This map maintains the already created function calls.
- * This is required in the multi-output scenario, to link rest of the outputs
- * to call
- */
- std::unordered_map<AnnotatedRegion, Call, ObjectHash, ObjectEqual> region_function_calls;
-
- /*!
- * \brief This map maintains arguments (of region) visits through visitor
- * patterns. Those arguement var and expression will be used to when creating
- * the function.
- */
- std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
- region_args;
-
- /*!
- * \brief This map maintains the index of an output in the subgraph function
- * for a given region. If there are multiple entries for a region, then the
- * function has a tuple of multiple outputs for its return.
- */
- using RegionRetIndexMap = std::unordered_map<Expr, int, ObjectHash, ObjectEqual>;
- std::unordered_map<AnnotatedRegion, RegionRetIndexMap, ObjectHash, ObjectEqual>
- region_return_indices_;
+ /*! \brief Map from each region to its metadata of the generated function. */
+ std::unordered_map<AnnotatedRegion, RegionFuncMetadata, ObjectHash, ObjectEqual>
+ region_func_meta_;
- /*!
- * \brief This map holds already created TupleGetItem nodes for accessing
- * outputs of a function.
- */
- using RegionRetTupleGetItemMap = std::unordered_map<Expr, TupleGetItem, ObjectHash, ObjectEqual>;
- std::unordered_map<AnnotatedRegion, RegionRetTupleGetItemMap, ObjectHash, ObjectEqual>
- region_return_tuplegetitem_;
-
- /*!
- * \brief Each region set is associated with a function in the module.
+ /*! \brief Each region set is associated with a function in the module.
* This map maintains the mapping between regionsets and the function it
* belongs to
*/
std::unordered_map<AnnotatedRegionSet, BaseFunc, ObjectHash, ObjectEqual> regions_sets_;
- /*!\brief Cache the output that is shared by different nodes. */
- using RegionOutputMap = std::unordered_map<AnnotatedRegion, Var, ObjectHash, ObjectEqual>;
- std::unordered_map<Expr, RegionOutputMap, ObjectHash, ObjectEqual> shared_output_;
-
/*!\brief The IRModule used for partitioning. */
IRModule module_;
};
-class DefaultRemover : public ExprMutator {
- public:
- explicit DefaultRemover(const IRModule& module) : module_(module) {}
+IRModule RemoveDefaultAnnotations(IRModule module) {
+ class DefaultRemover : public ExprRewriter {
+ public:
+ DefaultRemover() = default;
- IRModule Remove() {
- auto glob_funcs = module_->functions;
- for (const auto& pair : glob_funcs) {
- if (auto* fn = pair.second.as<FunctionNode>()) {
- auto func = GetRef<Function>(fn);
- func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
- func->attrs);
- module_->Update(pair.first, func);
+ Expr Rewrite_(const CallNode* call, const Expr& post) final {
+ auto attrs = call->attrs.as<CompilerAttrs>();
+ if (attrs != nullptr && attrs->compiler == "default") {
+ return Downcast<Call>(post)->args[0];
}
+ return post;
}
- return module_;
- }
+ };
- Expr VisitExpr_(const CallNode* call) final {
- auto attrs = call->attrs.as<CompilerAttrs>();
- if (attrs != nullptr && attrs->compiler == "default") {
- return VisitExpr(call->args[0]);
+ auto glob_funcs = module->functions;
+ // Detach this handle (copy-on-write) before the in-place Update calls below, so a shared module is not mutated.
+ module.CopyOnWrite();
+ for (const auto& pair : glob_funcs) {
+ if (auto* fn = pair.second.as<FunctionNode>()) {
+ auto func = GetRef<Function>(fn);
+ DefaultRemover remover;
+ auto removed = PostOrderRewrite(func->body, &remover);
+ func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs);
+ module->Update(pair.first, func);
}
- return ExprMutator::VisitExpr_(call);
}
-
- private:
- IRModule module_;
-};
+ return module;
+}
} // namespace partitioning
namespace transform {
Pass PartitionGraph() {
- runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
- [=](IRModule m, PassContext pc) {
- // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute
- // by treating them as un-annotated, but we don't have it yet. This workaround pass removes
- // all "default" annotations and should be deleted in the future.
- auto new_m = partitioning::DefaultRemover(m).Remove();
- return partitioning::Partitioner(new_m).Partition();
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func = [=](IRModule m,
+ PassContext pc) {
+ // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute
+ // by treating them as un-annotated, but we don't have it yet. This workaround pass removes
+ // all "default" annotations and should be deleted in the future.
+ auto new_m = partitioning::RemoveDefaultAnnotations(m);
+ return partitioning::Partitioner(new_m).Partition();
};
auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
return Sequential({partitioned, InferType()});