Posted to commits@tvm.apache.org by ma...@apache.org on 2020/03/07 04:30:21 UTC

[incubator-tvm] branch master updated: [relay][external codegen] outline and inline lifted functions for external codegen (#4996)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 28ee806  [relay][external codegen] outline and inline lifted functions for external codegen (#4996)
28ee806 is described below

commit 28ee806dcbd803f4079365dd308a673bd1a89588
Author: Zhi <51...@users.noreply.github.com>
AuthorDate: Fri Mar 6 20:30:13 2020 -0800

    [relay][external codegen] outline and inline lifted functions for external codegen (#4996)
    
    * outline and inline lifted functions for external codegen
    
    * add batch_norm test
    
    * test batch_norm inline
---
 src/relay/backend/build_module.cc               |   7 +
 src/relay/backend/vm/compiler.cc                |   7 +
 src/relay/backend/vm/inline_primitives.cc       |   1 +
 src/relay/backend/vm/lambda_lift.cc             |   1 +
 src/relay/pass/partition_graph.cc               |  79 ++++----
 src/relay/pass/to_a_normal_form.cc              |   3 +
 tests/python/relay/test_pass_partition_graph.py | 257 +++++++++++++++++++++---
 7 files changed, 294 insertions(+), 61 deletions(-)
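
For context, the behavior this commit introduces can be driven from Python. The sketch below is based on the test_function_lifting_inline test added in this patch; it is not part of the commit itself, and it assumes the WhiteListAnnotator helper from tests/python/relay/test_pass_partition_graph.py is in scope and uses the relay.build_config API of this TVM version.

    import tvm
    from tvm import relay
    from tvm.relay import transform

    # Build a small module whose batch_norm should be handed to an external backend.
    data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
    bn_gamma = relay.var("bn_gamma", relay.TensorType((16,), "float32"))
    bn_beta = relay.var("bn_beta", relay.TensorType((16,), "float32"))
    bn_mmean = relay.var("bn_mean", relay.TensorType((16,), "float32"))
    bn_mvar = relay.var("bn_var", relay.TensorType((16,), "float32"))
    bn = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)

    mod = tvm.IRModule()
    mod["main"] = relay.Function(
        [data, bn_gamma, bn_beta, bn_mmean, bn_mvar], bn.astuple())

    # Mark batch_norm for the "test_compiler" backend. WhiteListAnnotator is a
    # helper defined in the test file, not a TVM API.
    mod = WhiteListAnnotator(["nn.batch_norm"], "test_compiler")(mod)

    # PartitionGraph is now a module-level pass: each annotated region is
    # outlined into a global function tagged Primitive/Compiler/ExternalSymbol/
    # Inline, so function-level passes such as SimplifyInference skip it. The
    # Inline pass then folds those global functions back into their call sites
    # before code generation.
    seq = transform.Sequential([
        transform.InferType(),
        transform.PartitionGraph(),
        transform.SimplifyInference(),
        transform.Inline(),
    ])
    with relay.build_config(opt_level=3):
        mod = seq(mod)
    print(mod)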

diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 61ec281..41833c4 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -334,6 +334,13 @@ class RelayBuildModule : public runtime::ModuleNode {
     // Fuse the operations if it is needed.
     relay_module = transform::FuseOps()(relay_module);
     relay_module = transform::InferType()(relay_module);
+    // Inline the functions that have been lifted to the module scope.
+    //
+    // TODO(@zhiics) Note that we need to be careful about the subgraphs with
+    // global function calls. We should make sure that these callees are also
+    // inline functions. However, this should be very unlikely for accelerators
+    // and vendor-provided libraries. So we don't handle it for now.
+    relay_module = transform::Inline()(relay_module);
     CHECK(relay_module.defined());
 
     return relay_module;
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 2129b64..fc52a8e 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -921,6 +921,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
   pass_seqs.push_back(transform::LambdaLift());
   pass_seqs.push_back(transform::InlinePrimitives());
 
+  // Inline the functions that are lifted to the module scope. We perform this
+  // pass after all other optimization passes but before the memory allocation
+  // pass. This is because the memory allocation pass will insert `invoke_tvm_op`
+  // and we use these ops to invoke the symbols in the module generated by
+  // external codegen.
+  pass_seqs.push_back(transform::Inline());
+
   // Manifest the allocations.
   pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
   // Compute away possibly introduced constant computation.
diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc
index 1d6ba4a..25a9bcd 100644
--- a/src/relay/backend/vm/inline_primitives.cc
+++ b/src/relay/backend/vm/inline_primitives.cc
@@ -122,6 +122,7 @@ struct PrimitiveInliner : ExprMutator {
       auto global = pair.first;
       auto base_func = pair.second;
       if (auto* n = base_func.as<FunctionNode>()) {
+        if (!n->UseDefaultCompiler()) continue;
         auto func = GetRef<Function>(n);
 
         DLOG(INFO) << "Before inlining primitives: " << global
diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc
index 1cf671a..5cf66c5 100644
--- a/src/relay/backend/vm/lambda_lift.cc
+++ b/src/relay/backend/vm/lambda_lift.cc
@@ -189,6 +189,7 @@ class LambdaLifter : public ExprMutator {
     auto glob_funcs = module_->functions;
     for (auto pair : glob_funcs) {
       if (auto* n = pair.second.as<FunctionNode>()) {
+        if (!n->UseDefaultCompiler()) continue;
         auto func = GetRef<Function>(n);
         func = FunctionNode::make(func->params,
                                   VisitExpr(func->body),
diff --git a/src/relay/pass/partition_graph.cc b/src/relay/pass/partition_graph.cc
index e58bf61..d9600bd 100644
--- a/src/relay/pass/partition_graph.cc
+++ b/src/relay/pass/partition_graph.cc
@@ -110,6 +110,8 @@ class AnnotationChecker : public ExprVisitor {
  */
 class Partitioner : public ExprMutator {
  public:
+  explicit Partitioner(const IRModule& module) : module_(module) {}
+
   std::shared_ptr<Subgraph> GetSubgraph(const Expr node) {
     for (auto candidate : this->subgraphs_) {
       if (candidate->nodes.find(node) != candidate->nodes.end()) {
@@ -163,8 +165,10 @@ class Partitioner : public ExprMutator {
 
       // Replace the begin annotation with an external call input variable.
       auto compiler_attrs = call->attrs.as<CompilerAttrs>();
+      // The type of the created variable is the same as that of the
+      // compiler_begin node.
       auto var = VarNode::make(compiler_attrs->compiler + "_input" + std::to_string(var_id_++),
-                               input_expr->checked_type_);
+                               call->checked_type_);
 
       // Find the corresponding subgraph and add the argument.
       auto subgraph = GetSubgraph(GetRef<Call>(call));
@@ -182,7 +186,7 @@ class Partitioner : public ExprMutator {
 
       auto compiler_attrs = call->attrs.as<CompilerAttrs>();
 
-      // Check if the argument already belongs to an exist subgraph
+      // Check if the argument already belongs to an existing subgraph
       auto subgraph = GetSubgraph(call->args[0]);
       if (!subgraph) {
         auto ret = this->subgraphs_.emplace(std::make_shared<Subgraph>());
@@ -207,16 +211,28 @@ class Partitioner : public ExprMutator {
       }
 
       auto subgraph_func =
-          FunctionNode::make(params, input, call->args[0]->checked_type_, {}, Attrs());
+          FunctionNode::make(params, input, call->checked_type_, {}, Attrs());
 
-      Expr arg0 = call->args[0];
       std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id);
       subgraph_func =
           FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tir::StringImmNode::make(name));
       subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1));
       subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler,
                                       tvm::tir::StringImmNode::make(compiler_attrs->compiler));
-      return CallNode::make(subgraph_func, args);
+      subgraph_func = FunctionSetAttr(subgraph_func, attr::kInline, tvm::Integer(1));
+      CHECK(!module_->ContainGlobalVar(name))
+          << "Global function " << name << " already exists";
+      // Create a global function and add it to the IRModule for the subgraph.
+      // This way we lift the functions that should be handled by external
+      // codegen to the module scope and rely on the pass manager to prevent relay
+      // function level passes (i.e. simplify inference and fusion) optimizing it.
+      GlobalVar glob_func(name);
+      module_->Add(glob_func, subgraph_func);
+      // The return type of the call node is the same as the type of the
+      // compiler_end node.
+      auto ret = CallNode::make(glob_func, args);
+      ret->checked_type_ = call->checked_type_;
+      return std::move(ret);
     }
   }
 
@@ -330,50 +346,39 @@ class Partitioner : public ExprMutator {
     }
   }
 
+  IRModule Partition() {
+    auto glob_funcs = module_->functions;
+    for (const auto& pair : glob_funcs) {
+      if (auto* fn = pair.second.as<FunctionNode>()) {
+        auto func = GetRef<Function>(fn);
+        func = FunctionNode::make(func->params,
+                                  VisitExpr(func->body),
+                                  func->ret_type,
+                                  func->type_params,
+                                  func->attrs);
+        module_->Update(pair.first, func);
+      }
+    }
+    return module_;
+  }
+
  private:
   int var_id_{0};
   int subgraph_id_{0};
   std::unordered_set<std::shared_ptr<Subgraph>> subgraphs_;
+  IRModule module_;
 };
 
-/*!
- * \brief TODO(@zhiics, @comaniac) Combine parallel regions that belong to
- * the same codegen backend. This reduces rounds trips between TVM and external
- * backends. Likely we can borrow some ideas from operator fusion.
- *
- * For example, sg1 and sg2 should be combined if they belong to the same
- * codegen tool in the following case.
- *
- *      op1
- *     /   \
- *   sg1   sg2
- *
- *       |
- *      \|/
- *
- *      op1
- *       |
- *    sg1_sg2
- *
- * where the return type of the new subgraph sg1_sg2 is a tuple, and op1 has two
- * inputs that obtained from the tuple.
- */
-
-Expr PartitionGraph(const Expr& expr) {
-  Partitioner part;
-  return part.Mutate(expr);
-}
-
 }  // namespace partitioning
 
 namespace transform {
 
 Pass PartitionGraph() {
-  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func =
-      [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(partitioning::PartitionGraph(f));
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
+      [=](IRModule m, PassContext pc) {
+        return partitioning::Partitioner(m).Partition();
       };
-  auto partitioned = CreateFunctionPass(part_func, 0, "PartitionGraph", {});
+  auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
   return Sequential({partitioned, InferType()});
 }
 
diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc
index 9322e49..c75afd1 100644
--- a/src/relay/pass/to_a_normal_form.cc
+++ b/src/relay/pass/to_a_normal_form.cc
@@ -298,6 +298,9 @@ IRModule ToANormalForm(const IRModule& m) {
   auto funcs = m->functions;
   for (const auto& it : funcs) {
     CHECK_EQ(FreeVars(it.second).size(), 0);
+    if (const auto* n = it.second.as<FunctionNode>()) {
+      if (!n->UseDefaultCompiler()) continue;
+    }
     Expr ret =
       TransformF([&](const Expr& e) {
         return ToANormalFormAux(e);
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 9c3228f..209376a 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -18,14 +18,12 @@
 import os
 import sys
 import numpy as np
-import pytest
 
 import tvm
-from tvm import te
 import tvm.relay.testing
-import tvm.relay.transform as transform
 from tvm import relay
 from tvm import runtime
+from tvm.relay import transform
 from tvm.contrib import util
 from tvm.relay.annotation import compiler_begin, compiler_end
 from tvm.relay.expr_functor import ExprMutator
@@ -189,7 +187,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
         return lib
 
     def check_vm_result():
-        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+        with relay.build_config(opt_level=3):
             exe = relay.vm.compile(mod, target=target, params=params)
         code, lib = exe.save()
         lib = update_lib(lib)
@@ -200,7 +198,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
         tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
 
     def check_graph_runtime_result():
-        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+        with relay.build_config(opt_level=3):
             json, lib, param = relay.build(mod, target=target, params=params)
         lib = update_lib(lib)
         rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
@@ -297,6 +295,7 @@ def test_extern_ccompiler_single_op():
 
 def test_extern_ccompiler_default_ops():
     def expected():
+        mod = tvm.IRModule()
         x = relay.var("x", shape=(8, 8))
         y = relay.var("y", shape=(8, 8))
         x0 = relay.var("x0", shape=(8, 8))
@@ -305,11 +304,14 @@ def test_extern_ccompiler_default_ops():
         # Function that uses C compiler
         func = relay.Function([x0, y0], add)
         func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
+        func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
         func = func.set_attribute("Compiler",
                                   tvm.tir.StringImm("ccompiler"))
         func = func.set_attribute("ExternalSymbol",
                                   tvm.tir.StringImm("ccompiler_0"))
-        add_call = relay.Call(func, [x, y])
+        glb_0 = relay.GlobalVar("ccompiler_0")
+        mod[glb_0] = func
+        add_call = relay.Call(glb_0, [x, y])
         # Function that uses default compiler. Ops are fused in this function.
         p0 = relay.var("p0", shape=(8, 8))
         log = relay.log(p0)
@@ -320,7 +322,6 @@ def test_extern_ccompiler_default_ops():
                                               tvm.tir.IntImm("int32", 1))
         fused_call = relay.Call(fused_func, [add_call])
         main = relay.Function([x, y], fused_call)
-        mod = tvm.IRModule()
         mod["main"] = main
         return mod
 
@@ -371,28 +372,65 @@ def test_extern_dnnl():
     dtype = 'float32'
     ishape = (1, 32, 14, 14)
     w1shape = (32, 1, 3, 3)
-    data = relay.var('data', shape=(ishape), dtype=dtype)
-    weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype)
-    depthwise_conv2d_1 = relay.nn.conv2d(data,
-                                         weight1,
-                                         kernel_size=(3, 3),
-                                         padding=(1, 1),
-                                         groups=32)
-    depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
-                                         weight1,
-                                         kernel_size=(3, 3),
-                                         padding=(1, 1),
-                                         groups=32)
-    out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
-
-    f = relay.Function([data, weight1], out)
+
+    def expected():
+        data0 = relay.var("data", shape=(ishape), dtype=dtype)
+        input0 = relay.var("input0", shape=(w1shape), dtype=dtype)
+        input1 = relay.var("input1", shape=(w1shape), dtype=dtype)
+        depthwise_conv2d_1 = relay.nn.conv2d(data0,
+                                             input0,
+                                             kernel_size=(3, 3),
+                                             padding=(1, 1),
+                                             groups=32)
+        depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
+                                             input1,
+                                             kernel_size=(3, 3),
+                                             padding=(1, 1),
+                                             groups=32)
+        out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
+
+        func = relay.Function([data0, input0, input1], out)
+        func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
+        func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
+        func = func.set_attribute("Compiler", tvm.tir.StringImm("dnnl"))
+        func = func.set_attribute("ExternalSymbol",
+                                  tvm.tir.StringImm("dnnl_0"))
+        glb_var = relay.GlobalVar("dnnl_0")
+        mod = tvm.IRModule()
+        mod[glb_var] = func
+
+        data = relay.var("data", shape=(ishape), dtype=dtype)
+        weight = relay.var("input", shape=(w1shape), dtype=dtype)
+        main_f = relay.Function([data, weight], glb_var(data, weight, weight))
+        mod["main"] = main_f
+
+        return mod
+
+    def get_func():
+        data = relay.var("data", shape=(ishape), dtype=dtype)
+        weight1 = relay.var("weight1", shape=(w1shape), dtype=dtype)
+        depthwise_conv2d_1 = relay.nn.conv2d(data,
+                                             weight1,
+                                             kernel_size=(3, 3),
+                                             padding=(1, 1),
+                                             groups=32)
+        depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
+                                             weight1,
+                                             kernel_size=(3, 3),
+                                             padding=(1, 1),
+                                             groups=32)
+        out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
+
+        return relay.Function([data, weight1], out)
 
     mod = tvm.IRModule()
-    mod['main'] = WholeGraphAnnotator('dnnl').visit(f)
+    mod["main"] = WholeGraphAnnotator("dnnl").visit(get_func())
     mod = transform.PartitionGraph()(mod)
 
+    assert relay.alpha_equal(mod, expected())
+
     ref_mod = tvm.IRModule()
-    ref_mod['main'] = f
+    ref_mod["main"] = get_func()
 
     i_data = np.random.uniform(0, 1, ishape).astype(dtype)
     w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
@@ -427,6 +465,175 @@ def test_extern_dnnl_mobilenet():
                  (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
 
 
+def test_function_lifting():
+    def partition():
+        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+        weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
+        bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
+        bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
+        bn_mmean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
+        bn_mvar = relay.var("bn_var", relay.TensorType((16, ), "float32"))
+
+        conv = relay.nn.conv2d(
+            data=data,
+            weight=weight,
+            kernel_size=(3, 3),
+            channels=16,
+            padding=(1, 1))
+        bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta, bn_mmean,
+                                        bn_mvar)
+
+        func = relay.Function([data, weight, bn_gamma, bn_beta, bn_mmean,
+                               bn_mvar], bn_output.astuple())
+        mod = tvm.IRModule()
+        mod["main"] = func
+        op_list = ["nn.batch_norm", "nn.conv2d"]
+        mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
+
+        opt_pass = transform.Sequential([
+            transform.InferType(),
+            transform.PartitionGraph(),
+            transform.SimplifyInference(),
+            transform.FoldConstant(),
+            transform.AlterOpLayout(),
+        ])
+
+        with relay.build_config(opt_level=3):
+            mod = opt_pass(mod)
+
+        return mod
+
+    def expected():
+        # function for batch_norm
+        data0 = relay.var("data0", relay.TensorType((1, 16, 224, 224),
+                                                    "float32"))
+        mod = tvm.IRModule()
+        bn_gamma = relay.var("bn_gamma1", relay.TensorType((16, ), "float32"))
+        bn_beta = relay.var("bn_beta1", relay.TensorType((16, ), "float32"))
+        bn_mmean = relay.var("bn_mean1", relay.TensorType((16, ), "float32"))
+        bn_mvar = relay.var("bn_var1", relay.TensorType((16, ), "float32"))
+
+        bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+        func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
+                               bn.astuple())
+        func0 = func0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
+        func0 = func0.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
+        func0 = func0.set_attribute("Compiler",
+                                    tvm.tir.StringImm("test_compiler"))
+        func0 = func0.set_attribute("ExternalSymbol",
+                                    tvm.tir.StringImm("test_compiler_0"))
+        gv0 = relay.GlobalVar("test_compiler_0")
+        mod[gv0] = func0
+
+        # function for conv2d
+        data1 = relay.var("data1", relay.TensorType((1, 3, 224, 224), "float32"))
+        weight1 = relay.var("weight1", relay.TensorType((16, 3, 3, 3), "float32"))
+        conv = relay.nn.conv2d(
+            data=data1,
+            weight=weight1,
+            kernel_size=(3, 3),
+            channels=16,
+            padding=(1, 1))
+        func1 = relay.Function([data1, weight1], conv)
+        func1 = func1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
+        func1 = func1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
+        func1 = func1.set_attribute("Compiler",
+                                    tvm.tir.StringImm("test_compiler"))
+        func1 = func1.set_attribute("ExternalSymbol",
+                                    tvm.tir.StringImm("test_compiler_1"))
+        gv1 = relay.GlobalVar("test_compiler_1")
+        mod[gv1] = func1
+
+        # main function
+        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+        weight = relay.var("weight", relay.TensorType((16, 3, 3, 3), "float32"))
+        bn_gamma0 = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
+        bn_beta0 = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
+        bn_mmean0 = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
+        bn_mvar0 = relay.var("bn_var", relay.TensorType((16, ), "float32"))
+
+        call1 = gv1(data, weight)
+        call0 = gv0(call1, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0)
+        mod["main"] = relay.Function([data, weight, bn_gamma0, bn_beta0, bn_mmean0,
+                                      bn_mvar0], call0)
+        mod = transform.InferType()(mod)
+        return mod
+
+    partitioned = partition()
+    ref_mod = expected()
+    assert relay.analysis.alpha_equal(partitioned, ref_mod)
+
+
+def test_function_lifting_inline():
+    def partition():
+        data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
+        bn_gamma = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
+        bn_beta = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
+        bn_mmean = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
+        bn_mvar = relay.var("bn_var", relay.TensorType((16, ), "float32"))
+
+        bn_output = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean,
+                                        bn_mvar)
+
+        func = relay.Function([data, bn_gamma, bn_beta, bn_mmean,
+                               bn_mvar], bn_output.astuple())
+        mod = tvm.IRModule()
+        mod["main"] = func
+        op_list = ["nn.batch_norm", "nn.conv2d"]
+        mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
+
+        opt_pass = transform.Sequential([
+            transform.InferType(),
+            transform.PartitionGraph(),
+            transform.SimplifyInference(),
+            transform.FoldConstant(),
+            transform.AlterOpLayout(),
+            transform.Inline(),
+        ])
+
+        with relay.build_config(opt_level=3):
+            mod = opt_pass(mod)
+
+        return mod
+
+    def expected():
+        # function for batch_norm
+        data0 = relay.var("data0", relay.TensorType((1, 16, 224, 224),
+                                                    "float32"))
+        mod = tvm.IRModule()
+        bn_gamma = relay.var("bn_gamma1", relay.TensorType((16, ), "float32"))
+        bn_beta = relay.var("bn_beta1", relay.TensorType((16, ), "float32"))
+        bn_mmean = relay.var("bn_mean1", relay.TensorType((16, ), "float32"))
+        bn_mvar = relay.var("bn_var1", relay.TensorType((16, ), "float32"))
+
+        bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+        func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
+                               bn.astuple())
+        func0 = func0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
+        func0 = func0.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
+        func0 = func0.set_attribute("Compiler",
+                                    tvm.tir.StringImm("test_compiler"))
+        func0 = func0.set_attribute("ExternalSymbol",
+                                    tvm.tir.StringImm("test_compiler_0"))
+
+        # main function
+        data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
+        bn_gamma0 = relay.var("bn_gamma", relay.TensorType((16, ), "float32"))
+        bn_beta0 = relay.var("bn_beta", relay.TensorType((16, ), "float32"))
+        bn_mmean0 = relay.var("bn_mean", relay.TensorType((16, ), "float32"))
+        bn_mvar0 = relay.var("bn_var", relay.TensorType((16, ), "float32"))
+
+        call0 = func0(data, bn_gamma0, bn_beta0, bn_mmean0, bn_mvar0)
+        mod["main"] = relay.Function([data, bn_gamma0, bn_beta0, bn_mmean0,
+                                      bn_mvar0], call0)
+        mod = transform.InferType()(mod)
+        return mod
+
+    partitioned = partition()
+    ref_mod = expected()
+    assert relay.analysis.alpha_equal(partitioned, ref_mod)
+
+
 if __name__ == "__main__":
     test_multi_node_compiler()
     test_extern_ccompiler_single_op()
@@ -434,3 +641,5 @@ if __name__ == "__main__":
     test_extern_ccompiler()
     test_extern_dnnl()
     test_extern_dnnl_mobilenet()
+    test_function_lifting()
+    test_function_lifting_inline()
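

As a usage note, a module partitioned this way compiles through the normal build path; the tests above no longer need to disable AlterOpLayout because the lifted external functions are not visited by function-level passes. Below is a minimal sketch mirroring check_graph_runtime_result in the test file, assuming `mod` was partitioned for the built-in "ccompiler" codegen as in test_extern_ccompiler_default_ops (inputs "x" and "y" of shape (8, 8)) and that the test's update_lib helper, which links the generated C source into the library, is in scope.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.contrib import graph_runtime

    x_data = np.random.rand(8, 8).astype("float32")
    y_data = np.random.rand(8, 8).astype("float32")

    with relay.build_config(opt_level=3):
        json, lib, params = relay.build(mod, target="llvm")
    lib = update_lib(lib)  # test helper (assumed in scope)

    ctx = tvm.cpu()
    rt_mod = graph_runtime.create(json, lib, ctx)
    rt_mod.set_input("x", x_data)
    rt_mod.set_input("y", y_data)
    rt_mod.run()
    out = rt_mod.get_output(0)
    print(out.asnumpy())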