You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/05/25 17:47:41 UTC
[tvm] branch main updated: [Relay][dismantler] Added handling of
packed func (#8004)
This is an automated email from the ASF dual-hosted git repository.
mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new aefa0c8 [Relay][dismantler] Added handling of packed func (#8004)
aefa0c8 is described below
commit aefa0c85e46fc5ed15e71805f52bf7be6e238e33
Author: Dmitriy Smirnov <dm...@arm.com>
AuthorDate: Tue May 25 18:47:20 2021 +0100
[Relay][dismantler] Added handling of packed func (#8004)
Added handling of CallNode objects created via packed-function
invocation, plus test cases.
Change-Id: I5374abc59a3b0f79f27364c45f1a5789536df940
---
include/tvm/relay/expr.h | 6 +++
src/relay/ir/expr.cc | 34 ++++++++++---
tests/cpp/relay_dismantler_test.cc | 77 +++++++++++++++++++++++++++++-
tests/python/relay/test_ir_text_printer.py | 12 +++++
4 files changed, 121 insertions(+), 8 deletions(-)
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 17718d1..daad851 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -227,6 +227,11 @@ class Var : public Expr {
class Call;
/*! \brief Call container. */
class CallNode : public ExprNode {
+ protected:
+ // CallNode uses own deleter to indirectly call non-recursive destructor
+ Object::FDeleter saved_deleter_;
+ static void Deleter_(Object* ptr);
+
public:
/*!
* \brief The operator(function) being invoked
@@ -290,6 +295,7 @@ class CallNode : public ExprNode {
static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
+ friend class Call;
};
class Call : public Expr {
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index 62ff0b1..3b3c879 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -115,6 +115,8 @@ Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span s
n->attrs = std::move(attrs);
n->type_args = std::move(type_args);
n->span = std::move(span);
+ n->saved_deleter_ = n->deleter_;
+ n->deleter_ = CallNode::Deleter_;
data_ = std::move(n);
}
@@ -288,16 +290,24 @@ inline void Dismantle(const Expr& expr) {
// special handling
if (const CallNode* op = node.as<CallNode>()) {
- for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
- fpush_to_stack(*it);
+ // do not process args if used elsewhere
+ if (op->args.use_count() < 2) {
+ for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
+ fpush_to_stack(*it);
+ }
}
- fpush_to_stack(op->op);
} else if (const TupleNode* op = node.as<TupleNode>()) {
- for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
- fpush_to_stack(*it);
+ // do not process fields if used elsewhere
+ if (op->fields.use_count() < 2) {
+ for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
+ fpush_to_stack(*it);
+ }
}
} else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
- fpush_to_stack(op->tuple);
+ // do not process tuple if used elsewhere
+ if (op->tuple.use_count() < 2) {
+ fpush_to_stack(op->tuple);
+ }
}
}
}
@@ -306,7 +316,6 @@ inline void Dismantle(const Expr& expr) {
/*
* Non-recursive destructor
*/
-
Call::~Call() {
// attempt to dismantle if referenced one or zero times
if (this->use_count() < 2) {
@@ -316,5 +325,16 @@ Call::~Call() {
}
}
+/*
+ * CallNode's deleter, installed by Call::Call; it routes node destruction through ~Call.
+ */
+void CallNode::Deleter_(Object* ptr) {
+ auto p = reinterpret_cast<CallNode*>(ptr);  // NOTE(review): static_cast should suffice if CallNode derives from Object — confirm
+ // restore the deleter saved by Call::Call (presumably invoked when `c` below releases the node)
+ p->deleter_ = p->saved_deleter_;
+ // create a Call reference so that ~Call (the non-recursive destructor) runs when `c` goes out of scope
+ auto c = GetRef<Call>(p);  // intentionally unused: leaving this scope is what triggers ~Call
+}
+
} // namespace relay
} // namespace tvm
diff --git a/tests/cpp/relay_dismantler_test.cc b/tests/cpp/relay_dismantler_test.cc
index d5c089b..8c74d41 100644
--- a/tests/cpp/relay_dismantler_test.cc
+++ b/tests/cpp/relay_dismantler_test.cc
@@ -16,7 +16,6 @@
* specific language governing permissions and limitations
* under the License.
*/
-
#include <gtest/gtest.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/type_functor.h>
@@ -38,6 +37,8 @@
#include <tvm/topi/broadcast.h>
#include <tvm/topi/generic/injective.h>
+#include <memory>
+
using namespace tvm;
using namespace tvm::relay;
@@ -69,6 +70,80 @@ TEST(Relay, OutOfStack_cast) {
ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*");
}
+TEST(Relay, OutOfStack_packed_func) {  // deep chain of adds built through the packed func "relay.op._make.add"
+ constexpr int len = 1e6;
+ auto foo = [] {  // run inside ASSERT_EXIT's child process, presumably so a stack overflow fails the test instead of the runner
+ auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32)));
+ auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0}));
+ auto add_func = tvm::runtime::Registry::Get("relay.op._make.add");  // NOTE(review): result is dereferenced without a null check
+ auto y = (*add_func)(x, one);
+ for (int i = 0; i < len; ++i) {
+ y = (*add_func)(y, one);
+ }
+
+ // walk args[0] down the chain: all len + 1 Call nodes must still be reachable
+ int k = 0;
+ Expr e = y;
+ while (e.defined() && e.as<CallNode>() != nullptr) {
+ e = e.as<CallNode>()->args[0];
+ ++k;
+ }
+ ASSERT_EQ(len + 1, k);
+ };
+ ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*");  // building and then destroying the chain must exit cleanly
+}
+
+TEST(Relay, CallNodeSharedArgs) {  // releasing a Call must not dismantle an args array still shared with another Call
+ auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32)));
+ auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0}));  // NOTE(review): unused in this test
+ auto relu_op = relay::Op::Get("nn.relu");
+ Call y = relay::Call(relu_op, {x}, Attrs(), {});
+ y = relay::Call(relu_op, {y}, Attrs(), {});
+ ASSERT_EQ(1, y.get()->args[0].as<CallNode>()->args.size());
+ y = relay::Call(y.get()->op, y.get()->args, y.get()->attrs, y.get()->type_args);  // new Call reuses the old one's args; the old Call is released by this assignment
+ ASSERT_EQ(1, y.get()->args[0].as<CallNode>()->args.size());
+}
+
+TEST(Relay, TupleSharedFields) {  // two Tuples built from the same args array; destroying one must leave the other's fields intact
+ auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32)));
+ auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0}));  // NOTE(review): unused in this test
+ auto relu_op = relay::Op::Get("nn.relu");
+ Expr y = relay::Call(relu_op, {x}, Attrs(), {});
+ y = relay::Call(relu_op, {y}, Attrs(), {});
+ {  // y2 (and its Tuple) is destroyed at the end of this scope while y1 still shares the fields
+ Expr y1 = relay::Tuple(y.as<CallNode>()->args);
+ Expr y2 = relay::Tuple(y.as<CallNode>()->args);
+
+ y1 = relay::Call(relu_op, {y1});
+ y2 = relay::Call(relu_op, {y2});
+ y = y1;
+ }
+ ASSERT_EQ(1, y.as<CallNode>()->args[0].as<TupleNode>()->fields[0].as<CallNode>()->args.size());
+}
+
+TEST(Relay, TupleiGetItemSharedTuple) {  // two TupleGetItems share one Tuple; NOTE(review): "Tuplei" looks like a typo for "Tuple"
+ auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32)));
+ auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0}));  // NOTE(review): unused in this test
+ auto relu_op = relay::Op::Get("nn.relu");
+ Expr y = relay::Call(relu_op, {x}, Attrs(), {});
+ y = relay::Tuple({y});
+ {  // y2 (and its TupleGetItem) is destroyed at the end of this scope while y1 still refers to the same Tuple
+ Expr y1 = relay::TupleGetItem(y, 0);
+ Expr y2 = relay::TupleGetItem(y, 0);
+
+ y1 = relay::Call(relu_op, {y1});
+ y2 = relay::Call(relu_op, {y2});
+ y = y1;
+ }
+ ASSERT_EQ(1, y.as<CallNode>()
+ ->args[0]
+ .as<TupleGetItemNode>()
+ ->tuple.as<TupleNode>()
+ ->fields[0]
+ .as<CallNode>()
+ ->args.size());
+}
+
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py
index b2ae286..4968660 100644
--- a/tests/python/relay/test_ir_text_printer.py
+++ b/tests/python/relay/test_ir_text_printer.py
@@ -30,6 +30,7 @@ SEMVER = '#[version = "0.0.5"]\n'
def astext(program, unify_free_vars=False):
text = program.astext()
+
print(text)
if isinstance(program, Expr):
roundtrip_program = tvm.parser.parse_expr(text)
@@ -47,6 +48,17 @@ def show(text):
print(text)
+def test_large_graph():  # a chain of 1e6 + 1 nested adds passed through astext (prints it and attempts a parse round trip)
+ x = relay.var("x", shape=(3, 2))
+ y = relay.var("y")
+ one = relay.const(10e10, dtype="float32")  # NOTE(review): named "one" but holds 10e10
+ z = relay.add(x, one)
+ for i in range(int(1e6)):  # NOTE(review): loop variable is unused
+ z = relay.add(z, one)
+ f = relay.Function([x, y], z)
+ show(astext(f))
+
+
def test_func():
x = relay.var("x", shape=(3, 2))
y = relay.var("y")