You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/03/07 04:30:21 UTC
[incubator-tvm] branch master updated: [relay][external codegen]
outline and inline lifted functions for external codegen (#4996)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 28ee806 [relay][external codegen] outline and inline lifted functions for external codegen (#4996)
28ee806 is described below
commit 28ee806dcbd803f4079365dd308a673bd1a89588
Author: Zhi <51...@users.noreply.github.com>
AuthorDate: Fri Mar 6 20:30:13 2020 -0800
[relay][external codegen] outline and inline lifted functions for external codegen (#4996)
* outline and inline lifted functions for external codegen
* add batch_norm test
* test batch_norm inline
---
src/relay/backend/build_module.cc | 7 +
src/relay/backend/vm/compiler.cc | 7 +
src/relay/backend/vm/inline_primitives.cc | 1 +
src/relay/backend/vm/lambda_lift.cc | 1 +
src/relay/pass/partition_graph.cc | 79 ++++----
src/relay/pass/to_a_normal_form.cc | 3 +
tests/python/relay/test_pass_partition_graph.py | 257 +++++++++++++++++++++---
7 files changed, 294 insertions(+), 61 deletions(-)
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 61ec281..41833c4 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -334,6 +334,13 @@ class RelayBuildModule : public runtime::ModuleNode {
// Fuse the operations if it is needed.
relay_module = transform::FuseOps()(relay_module);
relay_module = transform::InferType()(relay_module);
+ // Inline the functions that have been lifted by the module scope.
+ //
+ // TODO(@zhiics) Note that we need to be careful about the subgraphs with
+ // global function calls. We should make sure that these callees are also
+ // inline functions. However, this should be very unlikely for accelerators
+ // and vendor-provided libraries. So we don't handle it for now.
+ relay_module = transform::Inline()(relay_module);
CHECK(relay_module.defined());
return relay_module;
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 2129b64..fc52a8e 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -921,6 +921,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives());
+ // Inline the functions that are lifted to the module scope. We perform this
+ // pass after all other optimization passes but before the memory allocation
+ // pass. This is because the memory allocation pass will insert `invoke_tvm_op`
+ // and we use these ops to invoke the symbols in the module generated by
+ // external codegen.
+ pass_seqs.push_back(transform::Inline());
+
// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc
index 1d6ba4a..25a9bcd 100644
--- a/src/relay/backend/vm/inline_primitives.cc
+++ b/src/relay/backend/vm/inline_primitives.cc
@@ -122,6 +122,7 @@ struct PrimitiveInliner : ExprMutator {
auto global = pair.first;
auto base_func = pair.second;
if (auto* n = base_func.as<FunctionNode>()) {
+ if (!n->UseDefaultCompiler()) continue;
auto func = GetRef<Function>(n);
DLOG(INFO) << "Before inlining primitives: " << global
diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc
index 1cf671a..5cf66c5 100644
--- a/src/relay/backend/vm/lambda_lift.cc
+++ b/src/relay/backend/vm/lambda_lift.cc
@@ -189,6 +189,7 @@ class LambdaLifter : public ExprMutator {
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
if (auto* n = pair.second.as<FunctionNode>()) {
+ if (!n->UseDefaultCompiler()) continue;
auto func = GetRef<Function>(n);
func = FunctionNode::make(func->params,
VisitExpr(func->body),
diff --git a/src/relay/pass/partition_graph.cc b/src/relay/pass/partition_graph.cc
index e58bf61..d9600bd 100644
--- a/src/relay/pass/partition_graph.cc
+++ b/src/relay/pass/partition_graph.cc
@@ -110,6 +110,8 @@ class AnnotationChecker : public ExprVisitor {
*/
class Partitioner : public ExprMutator {
public:
+ explicit Partitioner(const IRModule& module) : module_(module) {}
+
std::shared_ptr<Subgraph> GetSubgraph(const Expr node) {
for (auto candidate : this->subgraphs_) {
if (candidate->nodes.find(node) != candidate->nodes.end()) {
@@ -163,8 +165,10 @@ class Partitioner : public ExprMutator {
// Replace the begin annotation with an external call input variable.
auto compiler_attrs = call->attrs.as<CompilerAttrs>();
+ // The type of the created variable is the same as the compiler_begin
+ // node.
auto var = VarNode::make(compiler_attrs->compiler + "_input" + std::to_string(var_id_++),
- input_expr->checked_type_);
+ call->checked_type_);
// Find the corresponding subgraph and add the argument.
auto subgraph = GetSubgraph(GetRef<Call>(call));
@@ -182,7 +186,7 @@ class Partitioner : public ExprMutator {
auto compiler_attrs = call->attrs.as<CompilerAttrs>();
- // Check if the argument already belongs to an exist subgraph
+ // Check if the argument already belongs to an existing subgraph
auto subgraph = GetSubgraph(call->args[0]);
if (!subgraph) {
auto ret = this->subgraphs_.emplace(std::make_shared<Subgraph>());
@@ -207,16 +211,28 @@ class Partitioner : public ExprMutator {
}
auto subgraph_func =
- FunctionNode::make(params, input, call->args[0]->checked_type_, {}, Attrs());
+ FunctionNode::make(params, input, call->checked_type_, {}, Attrs());
- Expr arg0 = call->args[0];
std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id);
subgraph_func =
FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tir::StringImmNode::make(name));
subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1));
subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler,
tvm::tir::StringImmNode::make(compiler_attrs->compiler));
- return CallNode::make(subgraph_func, args);
+ subgraph_func = FunctionSetAttr(subgraph_func, attr::kInline, tvm::Integer(1));
+ CHECK(!module_->ContainGlobalVar(name))
+ << "Global function " << name << " already exists";
+ // Create a global function and add it to the IRModule for the subgraph.
+ // This way we lift the functions that should be handled by external
+ // codegen to the module scope and rely on the pass manager to prevent relay
+ // function level passes (i.e. simplify inference and fusion) from optimizing it.
+ GlobalVar glob_func(name);
+ module_->Add(glob_func, subgraph_func);
+ // The return type of callnode is the same as the type of the
+ // compiler_end node.
+ auto ret = CallNode::make(glob_func, args);
+ ret->checked_type_ = call->checked_type_;
+ return std::move(ret);
}
}
@@ -330,50 +346,39 @@ class Partitioner : public ExprMutator {
}
}
+ IRModule Partition() {
+ auto glob_funcs = module_->functions;
+ for (const auto& pair : glob_funcs) {
+ if (auto* fn = pair.second.as<FunctionNode>()) {
+ auto func = GetRef<Function>(fn);
+ func = FunctionNode::make(func->params,
+ VisitExpr(func->body),
+ func->ret_type,
+ func->type_params,
+ func->attrs);
+ module_->Update(pair.first, func);
+ }
+ }
+ return module_;
+ }
+
private:
int var_id_{0};
int subgraph_id_{0};
std::unordered_set<std::shared_ptr<Subgraph>> subgraphs_;
+ IRModule module_;
};
-/*!
- * \brief TODO(@zhiics, @comaniac) Combine parallel regions that belong to
- * the same codegen backend. This reduces rounds trips between TVM and external
- * backends. Likely we can borrow some ideas from operator fusion.
- *
- * For example, sg1 and sg2 should be combined if they belong to the same
- * codegen tool in the following case.
- *
- * op1
- * / \
- * sg1 sg2
- *
- * |
- * \|/
- *
- * op1
- * |
- * sg1_sg2
- *
- * where the return type of the new subgraph sg1_sg2 is a tuple, and op1 has two
- * inputs that obtained from the tuple.
- */
-
-Expr PartitionGraph(const Expr& expr) {
- Partitioner part;
- return part.Mutate(expr);
-}
-
} // namespace partitioning
namespace transform {
Pass PartitionGraph() {
- runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func =
- [=](Function f, IRModule m, PassContext pc) {
- return Downcast<Function>(partitioning::PartitionGraph(f));
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
+ [=](IRModule m, PassContext pc) {
+ return partitioning::Partitioner(m).Partition();
};
- auto partitioned = CreateFunctionPass(part_func, 0, "PartitionGraph", {});
+ auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
return Sequential({partitioned, InferType()});
}
diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc
index 9322e49..c75afd1 100644
--- a/src/relay/pass/to_a_normal_form.cc
+++ b/src/relay/pass/to_a_normal_form.cc
@@ -298,6 +298,9 @@ IRModule ToANormalForm(const IRModule& m) {
auto funcs = m->functions;
for (const auto& it : funcs) {
CHECK_EQ(FreeVars(it.second).size(), 0);
+ if (const auto* n = it.second.as<FunctionNode>()) {
+ if (!n->UseDefaultCompiler()) continue;
+ }
Expr ret =
TransformF([&](const Expr& e) {
return ToANormalFormAux(e);
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 9c3228f..209376a 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -18,14 +18,12 @@
import os
import sys
import numpy as np
-import pytest
import tvm
-from tvm import te
import tvm.relay.testing
-import tvm.relay.transform as transform
from tvm import relay
from tvm import runtime
+from tvm.relay import transform
from tvm.contrib import util
from tvm.relay.annotation import compiler_begin, compiler_end
from tvm.relay.expr_functor import ExprMutator
@@ -189,7 +187,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
return lib
def check_vm_result():
- with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+ with relay.build_config(opt_level=3):
exe = relay.vm.compile(mod, target=target, params=params)
code, lib = exe.save()
lib = update_lib(lib)
@@ -200,7 +198,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
def check_graph_runtime_result():
- with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+ with relay.build_config(opt_level=3):
json, lib, param = relay.build(mod, target=target, params=params)
lib = update_lib(lib)
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
@@ -297,6 +295,7 @@ def test_extern_ccompiler_single_op():
def test_extern_ccompiler_default_ops():
def expected():
+ mod = tvm.IRModule()
x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
x0 = relay.var("x0", shape=(8, 8))
@@ -305,11 +304,14 @@ def test_extern_ccompiler_default_ops():
# Function that uses C compiler
func = relay.Function([x0, y0], add)
func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
+ func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Compiler",
tvm.tir.StringImm("ccompiler"))
func = func.set_attribute("ExternalSymbol",
tvm.tir.StringImm("ccompiler_0"))
- add_call = relay.Call(func, [x, y])
+ glb_0 = relay.GlobalVar("ccompiler_0")
+ mod[glb_0] = func
+ add_call = relay.Call(glb_0, [x, y])
# Function that uses default compiler. Ops are fused in this function.
p0 = relay.var("p0", shape=(8, 8))
log = relay.log(p0)
@@ -320,7 +322,6 @@ def test_extern_ccompiler_default_ops():
tvm.tir.IntImm("int32", 1))
fused_call = relay.Call(fused_func, [add_call])
main = relay.Function([x, y], fused_call)
- mod = tvm.IRModule()
mod["main"] = main
return mod
@@ -371,28 +372,65 @@ def test_extern_dnnl():
dtype = 'float32'
ishape = (1, 32, 14, 14)
w1shape = (32, 1, 3, 3)
- data = relay.var('data', shape=(ishape), dtype=dtype)
- weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype)
- depthwise_conv2d_1 = relay.nn.conv2d(data,
- weight1,
- kernel_size=(3, 3),
- padding=(1, 1),
- groups=32)
- depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
- weight1,
- kernel_size=(3, 3),
- padding=(1, 1),
- groups=32)
- out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
-
- f = relay.Function([data, weight1], out)
+
+ def expected():
+ data0 = relay.var("data", shape=(ishape), dtype=dtype)
+ input0 = relay.var("input0", shape=(w1shape), dtype=dtype)
+ input1 = relay.var("input1", shape=(w1shape), dtype=dtype)
+ depthwise_conv2d_1 = relay.nn.conv2d(data0,
+ input0,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ groups=32)
+ depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
+ input1,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ groups=32)
+ out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
+
+ func = relay.Function([data0, input0, input1], out)
+ func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
+ func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
+ func = func.set_attribute("Compiler", tvm.tir.StringImm("dnnl"))
+ func = func.set_attribute("ExternalSymbol",
+ tvm.tir.StringImm("dnnl_0"))
+ glb_var = relay.GlobalVar("dnnl_0")
+ mod = tvm.IRModule()
+ mod[glb_var] = func
+
+ data = relay.var("data", shape=(ishape), dtype=dtype)
+ weight = relay.var("input", shape=(w1shape), dtype=dtype)
+ main_f = relay.Function([data, weight], glb_var(data, weight, weight))
+ mod["main"] = main_f
+
+ return mod
+
+ def get_func():
+ data = relay.var("data", shape=(ishape), dtype=dtype)
+ weight1 = relay.var("weight1", shape=(w1shape), dtype=dtype)
+ depthwise_conv2d_1 = relay.nn.conv2d(data,
+ weight1,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ groups=32)
+ depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
+ weight1,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ groups=32)
+ out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
+
+ return relay.Function([data, weight1], out)
mod = tvm.IRModule()
- mod['main'] = WholeGraphAnnotator('dnnl').visit(f)
+ mod["main"] = WholeGraphAnnotator("dnnl").visit(get_func())
mod = transform.PartitionGraph()(mod)
+ assert relay.alpha_equal(mod, expected())
+
ref_mod = tvm.IRModule()
- ref_mod['main'] = f
+ ref_mod["main"] = get_func()
i_data = np.random.uniform(0, 1, ishape).astype(dtype)
w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
@@ -427,6 +465,175 @@ def test_extern_dnnl_mobilenet():
(1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
+def test_function_lifting():
+ def partition():
+ data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+ weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
+ bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
+ bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
+ bn_mmean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
+ bn_mvar = relay.var("bn_var", relay.TensorType((16, ), "float32"))
+
+ conv = relay.nn.conv2d(
+ data=data,
+ weight=weight,
+ kernel_size=(3, 3),
+ channels=16,
+ padding=(1, 1))
+ bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta, bn_mmean,
+ bn_mvar)
+
+ func = relay.Function([data, weight, bn_gamma, bn_beta, bn_mmean,
+ bn_mvar], bn_output.astuple())
+ mod = tvm.IRModule()
+ mod["main"] = func
+ op_list = ["nn.batch_norm", "nn.conv2d"]
+ mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
+
+ opt_pass = transform.Sequential([
+ transform.InferType(),
+ transform.PartitionGraph(),
+ transform.SimplifyInference(),
+ transform.FoldConstant(),
+ transform.AlterOpLayout(),
+ ])
+
+ with relay.build_config(opt_level=3):
+ mod = opt_pass(mod)
+
+ return mod
+
+ def expected():
+ # function for batch_norm
+ data0 = relay.var("data0", relay.TensorType((1, 16, 224, 224),
+ "float32"))
+ mod = tvm.IRModule()
+ bn_gamma = relay.var("bn_gamma1", relay.TensorType((16, ), "float32"))
+ bn_beta = relay.var("bn_beta1", relay.TensorType((16, ), "float32"))
+ bn_mmean = relay.var("bn_mean1", relay.TensorType((16, ), "float32"))
+ bn_mvar = relay.var("bn_var1", relay.TensorType((16, ), "float32"))
+
+ bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+ func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
+ bn.astuple())
+ func0 = func0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
+ func0 = func0.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
+ func0 = func0.set_attribute("Compiler",
+ tvm.tir.StringImm("test_compiler"))
+ func0 = func0.set_attribute("ExternalSymbol",
+ tvm.tir.StringImm("test_compiler_0"))
+ gv0 = relay.GlobalVar("test_compiler_0")
+ mod[gv0] = func0
+
+ # function for conv2d
+ data1 = relay.var("data1", relay.TensorType((1, 3, 224, 224), "float32"))
+ weight1 = relay.var("weight1", relay.TensorType((16, 3, 3, 3), "float32"))
+ conv = relay.nn.conv2d(
+ data=data1,
+ weight=weight1,
+ kernel_size=(3, 3),
+ channels=16,
+ padding=(1, 1))
+ func1 = relay.Function([data1, weight1], conv)
+ func1 = func1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
+ func1 = func1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
+ func1 = func1.set_attribute("Compiler",
+ tvm.tir.StringImm("test_compiler"))
+ func1 = func1.set_attribute("ExternalSymbol",
+ tvm.tir.StringImm("test_compiler_1"))
+ gv1 = relay.GlobalVar("test_compiler_1")
+ mod[gv1] = func1
+
+ # main function
+ data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+ weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
+ bn_gamma0 = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
+ bn_beta0 = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
+ bn_mmean0 = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
+ bn_mvar0 = relay.var("bn_var", relay.TensorType((16, ), "float32"))
+
+ call1 = gv1(data, weight)
+ call0 = gv0(call1, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0)
+ mod["main"] = relay.Function([data, weight, bn_gamma0, bn_beta0, bn_mmean0,
+ bn_mvar0], call0)
+ mod = transform.InferType()(mod)
+ return mod
+
+ partitioned = partition()
+ ref_mod = expected()
+ assert relay.analysis.alpha_equal(partitioned, ref_mod)
+
+
+def test_function_lifting_inline():
+ def partition():
+ data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
+ bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
+ bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
+ bn_mmean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
+ bn_mvar = relay.var("bn_var", relay.TensorType((16, ), "float32"))
+
+ bn_output = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean,
+ bn_mvar)
+
+ func = relay.Function([data, bn_gamma, bn_beta, bn_mmean,
+ bn_mvar], bn_output.astuple())
+ mod = tvm.IRModule()
+ mod["main"] = func
+ op_list = ["nn.batch_norm", "nn.conv2d"]
+ mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
+
+ opt_pass = transform.Sequential([
+ transform.InferType(),
+ transform.PartitionGraph(),
+ transform.SimplifyInference(),
+ transform.FoldConstant(),
+ transform.AlterOpLayout(),
+ transform.Inline(),
+ ])
+
+ with relay.build_config(opt_level=3):
+ mod = opt_pass(mod)
+
+ return mod
+
+ def expected():
+ # function for batch_norm
+ data0 = relay.var("data0", relay.TensorType((1, 16, 224, 224),
+ "float32"))
+ mod = tvm.IRModule()
+ bn_gamma = relay.var("bn_gamma1", relay.TensorType((16, ), "float32"))
+ bn_beta = relay.var("bn_beta1", relay.TensorType((16, ), "float32"))
+ bn_mmean = relay.var("bn_mean1", relay.TensorType((16, ), "float32"))
+ bn_mvar = relay.var("bn_var1", relay.TensorType((16, ), "float32"))
+
+ bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+ func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
+ bn.astuple())
+ func0 = func0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
+ func0 = func0.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
+ func0 = func0.set_attribute("Compiler",
+ tvm.tir.StringImm("test_compiler"))
+ func0 = func0.set_attribute("ExternalSymbol",
+ tvm.tir.StringImm("test_compiler_0"))
+
+ # main function
+ data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
+ bn_gamma0 = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
+ bn_beta0 = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
+ bn_mmean0 = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
+ bn_mvar0 = relay.var("bn_var", relay.TensorType((16, ), "float32"))
+
+ call0 = func0(data, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0)
+ mod["main"] = relay.Function([data, bn_gamma0, bn_beta0, bn_mmean0,
+ bn_mvar0], call0)
+ mod = transform.InferType()(mod)
+ return mod
+
+ partitioned = partition()
+ ref_mod = expected()
+ assert relay.analysis.alpha_equal(partitioned, ref_mod)
+
+
if __name__ == "__main__":
test_multi_node_compiler()
test_extern_ccompiler_single_op()
@@ -434,3 +641,5 @@ if __name__ == "__main__":
test_extern_ccompiler()
test_extern_dnnl()
test_extern_dnnl_mobilenet()
+ test_function_lifting()
+ test_function_lifting_inline()