Posted to commits@tvm.apache.org by ma...@apache.org on 2021/11/01 20:10:36 UTC

[tvm] branch main updated: [VMCompiler] Support shape func lowering for nested function call (#9405)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1a3dac7  [VMCompiler] Support shape func lowering for nested function call (#9405)
1a3dac7 is described below

commit 1a3dac71407151681f6edbaf5def5788bdd2d8d2
Author: masahi <ma...@gmail.com>
AuthorDate: Tue Nov 2 05:09:56 2021 +0900

    [VMCompiler] Support shape func lowering for nested function call (#9405)
    
    * Support nested function call in shape func lowering
    
    * add test
---
 src/relay/backend/te_compiler_cache.cc | 21 ++++++++++++++++++---
 tests/python/relay/test_vm.py          | 26 ++++++++++++++++++++++++++
 2 files changed, 44 insertions(+), 3 deletions(-)
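
For context: nested function calls of this kind typically arise after
BYOC-style partitioning (MergeComposite + AnnotateTarget + PartitionGraph,
the same pipeline the new test below uses). Previously, MakeShapeFunc
aborted with "Do not support sub function" whenever the VM compiler had to
build a shape function for a dynamically shaped module containing such a
call. Roughly, the IR involved looks like this (an illustrative sketch,
not exact TVM output):

    // Illustrative sketch only; real partitioned IR uses other names and
    // attributes.
    def @main(%data: Tensor[(?, 16)], %weight: Tensor[(?, 16)]) {
      %f = fn (%x, %y) { nn.dense(%x, %y) };  // nested function
      %f(%data, %weight)  // a Call whose op is a FunctionNode, not an OpNode
    }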

diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index be5b172..3970b0e 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -466,8 +466,13 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
 
   Array<te::Tensor> VisitExpr_(const VarNode* var_node) final {
     auto var = GetRef<Var>(var_node);
-    auto it = param_states_.find(var);
-    if (it == param_states_.end()) {
+    auto it = param_arg_map_.find(var);
+    if (it != param_arg_map_.end()) {
+      // This var is a parameter of a nested function. Visit the corresponding
+      // argument at the call site instead.
+      return VisitExpr(it->second);
+    }
+    if (param_states_.find(var) == param_states_.end()) {
       LOG(FATAL) << "Unexpected free variable " << var->name_hint();
       return {};
     } else {
@@ -542,6 +547,12 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
   }
 
   Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
+    if (auto* func = call_node->op.as<FunctionNode>()) {
+      for (size_t i = 0; i < func->params.size(); ++i) {
+        param_arg_map_[func->params[i]] = call_node->args[i];
+      }
+      return VisitExpr(func->body);
+    }
     static auto fshape_func = Op::GetAttrMap<FShapeFunc>("FShapeFunc");
     static auto tshape_data_dependent = Op::GetAttrMap<TShapeDataDependent>("TShapeDataDependent");
     ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
@@ -601,7 +612,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
   }
 
   Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
-    LOG(FATAL) << "Do not support sub function";
+    LOG(FATAL) << "Nested functions should not be visited directly; they are handled at their call sites.";
     return Array<te::Tensor>();
   }
 
@@ -644,6 +655,10 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
   std::vector<bool> data_dependents_per_input_;
   /*! \brief Scalars used in the shape function */
   Array<te::Tensor> scalars_;
+  /*! \brief Map from the parameters of a nested function to the corresponding
+   * arguments at the call site.
+   */
+  std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_arg_map_;
 };
 
 CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
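
In effect, the new CallNode case inlines a nested function for
shape-function purposes: each parameter is recorded in param_arg_map_, and
the VarNode case later resolves such parameters to their call-site
arguments. A minimal Python analogue of the two visitor cases (a sketch
with stand-in node classes, not TVM code):

    # Stand-ins for Relay nodes; illustrative only, not the TVM API.
    class Var:
        def __init__(self, name):
            self.name = name

    class Func:
        def __init__(self, params, body):
            self.params, self.body = params, body

    class Call:
        def __init__(self, op, args):
            self.op, self.args = op, args

    param_arg_map = {}  # Var -> expr, mirroring MakeShapeFunc::param_arg_map_

    def visit(expr):
        if isinstance(expr, Call) and isinstance(expr.op, Func):
            # New CallNode case: bind each parameter to its call-site
            # argument, then visit the nested function's body.
            for param, arg in zip(expr.op.params, expr.args):
                param_arg_map[param] = arg
            return visit(expr.op.body)
        if isinstance(expr, Var) and expr in param_arg_map:
            # New VarNode case: a nested-function parameter resolves to the
            # argument it was bound to at the call site.
            return visit(param_arg_map[expr])
        return expr  # everything else falls through to the existing logic

    # fn(x) { x } applied to `data` is inlined down to `data` itself:
    x, data = Var("x"), Var("data")
    assert visit(Call(Func([x], x), [data])) is data
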
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 8ec4152..52a2fef 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -32,6 +32,8 @@ from tvm import rpc
 import tvm.testing
 from tvm.relay.transform import InferType
 from tvm.relay.testing import mlp
+from tvm.relay.dataflow_pattern import wildcard, is_op
+from tvm.relay.backend.vm import VMCompiler
 
 
 def check_result(target, dev, args, expected_result, mod=None):
@@ -973,6 +975,30 @@ def test_benchmark_end_to_end_rpc():
     assert result.mean > 0
 
 
+def test_shape_func_nested_function():
+    data_shape = (relay.Any(), 16)
+    weight_shape = (relay.Any(), 16)
+
+    dense = relay.nn.dense(
+        relay.var("data", shape=data_shape), relay.var("weight", shape=weight_shape)
+    )
+    mod = tvm.IRModule.from_expr(dense)
+
+    patterns = [("test.dense", is_op("nn.dense")(wildcard(), wildcard()))]
+    passes = tvm.transform.Sequential(
+        [
+            relay.transform.MergeComposite(patterns),
+            relay.transform.AnnotateTarget(["test"]),
+            relay.transform.PartitionGraph(),
+        ]
+    )
+
+    mod = passes(mod)
+
+    compiler = VMCompiler()
+    compiler.lower(mod, "llvm")
+
+
 if __name__ == "__main__":
     import sys
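
The new regression test can be run on its own from a TVM checkout:

    pytest tests/python/relay/test_vm.py::test_shape_func_nested_function

Note that the test deliberately stops at compiler.lower(mod, "llvm"): the
partitioned module targets a "test" external codegen with no real backend
here, so only the shape-function lowering path fixed by this commit is
exercised, not end-to-end execution.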