You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/11/01 20:10:36 UTC
[tvm] branch main updated: [VMCompiler] Support shape func lowering for nested function call (#9405)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1a3dac7 [VMCompiler] Support shape func lowering for nested function call (#9405)
1a3dac7 is described below
commit 1a3dac71407151681f6edbaf5def5788bdd2d8d2
Author: masahi <ma...@gmail.com>
AuthorDate: Tue Nov 2 05:09:56 2021 +0900
[VMCompiler] Support shape func lowering for nested function call (#9405)
* Support nested function call in shape func lowering
* add test
---
src/relay/backend/te_compiler_cache.cc | 21 ++++++++++++++++++---
tests/python/relay/test_vm.py | 26 ++++++++++++++++++++++++++
2 files changed, 44 insertions(+), 3 deletions(-)
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index be5b172..3970b0e 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -466,8 +466,13 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
Array<te::Tensor> VisitExpr_(const VarNode* var_node) final {
auto var = GetRef<Var>(var_node);
- auto it = param_states_.find(var);
- if (it == param_states_.end()) {
+ auto it = param_arg_map_.find(var);
+ if (it != param_arg_map_.end()) {
+ // This var is a parameter of a nested function. Visit the corresponding argument in the
+ // function call site.
+ return VisitExpr(it->second);
+ }
+ if (param_states_.find(var) == param_states_.end()) {
LOG(FATAL) << "Unexpected free variable " << var->name_hint();
return {};
} else {
@@ -542,6 +547,12 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
}
Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
+ if (auto* func = call_node->op.as<FunctionNode>()) {
+ for (size_t i = 0; i < func->params.size(); ++i) {
+ param_arg_map_[func->params[i]] = call_node->args[i];
+ }
+ return VisitExpr(func->body);
+ }
static auto fshape_func = Op::GetAttrMap<FShapeFunc>("FShapeFunc");
static auto tshape_data_dependent = Op::GetAttrMap<TShapeDataDependent>("TShapeDataDependent");
ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
@@ -601,7 +612,7 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
}
Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
- LOG(FATAL) << "Do not support sub function";
+ LOG(FATAL) << "Nested functions are not allowed to be visited.";
return Array<te::Tensor>();
}
@@ -644,6 +655,10 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
std::vector<bool> data_dependents_per_input_;
/*! \brief Scalars used in the shape function */
Array<te::Tensor> scalars_;
+ /*! \brief Map from parameters of a nested function to corresponding arguments in a function
+ * call site.
+ */
+ std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_arg_map_;
};
CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 8ec4152..52a2fef 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -32,6 +32,8 @@ from tvm import rpc
import tvm.testing
from tvm.relay.transform import InferType
from tvm.relay.testing import mlp
+from tvm.relay.dataflow_pattern import wildcard, is_op
+from tvm.relay.backend.vm import VMCompiler
def check_result(target, dev, args, expected_result, mod=None):
@@ -973,6 +975,30 @@ def test_benchmark_end_to_end_rpc():
assert result.mean > 0
+def test_shape_func_nested_function():
+ data_shape = (relay.Any(), 16)
+ weight_shape = (relay.Any(), 16)
+
+ dense = relay.nn.dense(
+ relay.var("data", shape=data_shape), relay.var("weight", shape=weight_shape)
+ )
+ mod = tvm.IRModule.from_expr(dense)
+
+ patterns = [("test.dense", is_op("nn.dense")(wildcard(), wildcard()))]
+ passes = tvm.transform.Sequential(
+ [
+ relay.transform.MergeComposite(patterns),
+ relay.transform.AnnotateTarget(["test"]),
+ relay.transform.PartitionGraph(),
+ ]
+ )
+
+ mod = passes(mod)
+
+ compiler = VMCompiler()
+ compiler.lower(mod, "llvm")
+
+
if __name__ == "__main__":
import sys