Posted to commits@tvm.apache.org by mb...@apache.org on 2021/05/25 17:47:41 UTC

[tvm] branch main updated: [Relay][dismantler] Added handling of packed func (#8004)

This is an automated email from the ASF dual-hosted git repository.

mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new aefa0c8  [Relay][dismantler] Added handling of packed func (#8004)
aefa0c8 is described below

commit aefa0c85e46fc5ed15e71805f52bf7be6e238e33
Author: Dmitriy Smirnov <dm...@arm.com>
AuthorDate: Tue May 25 18:47:20 2021 +0100

    [Relay][dismantler] Added handling of packed func (#8004)
    
    Added handling of CallNode objects created via packed
    function invocation, plus test cases.
    
    Change-Id: I5374abc59a3b0f79f27364c45f1a5789536df940
---
 include/tvm/relay/expr.h                   |  6 +++
 src/relay/ir/expr.cc                       | 34 ++++++++++---
 tests/cpp/relay_dismantler_test.cc         | 77 +++++++++++++++++++++++++++++-
 tests/python/relay/test_ir_text_printer.py | 12 +++++
 4 files changed, 121 insertions(+), 8 deletions(-)

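Background for the change: Relay expressions built as long chains (for
example, a million nested calls produced through the packed-function path
used by the Python API) form deeply nested CallNode graphs, and tearing
them down through the default recursive destructor can overflow the C++
stack. The commit routes destruction through the iterative Dismantle
routine instead. Below is a minimal standalone sketch of that idea, using
a hypothetical Chain type rather than TVM's classes:

    #include <memory>
    #include <utility>
    #include <vector>

    struct Chain {
      std::shared_ptr<Chain> next;
      ~Chain() {
        // Detach children this node solely owns onto an explicit stack, so
        // shared_ptr teardown never recurses more than one level deep.
        std::vector<std::shared_ptr<Chain>> stack;
        if (next.use_count() == 1) stack.push_back(std::move(next));
        while (!stack.empty()) {
          std::shared_ptr<Chain> cur = std::move(stack.back());
          stack.pop_back();
          if (cur->next.use_count() == 1) stack.push_back(std::move(cur->next));
          // `cur` dies here with `next` already detached, so its own
          // destructor does no further work.
        }
      }
    };

    int main() {
      auto head = std::make_shared<Chain>();
      for (int i = 0; i < 1000000; ++i) {
        auto n = std::make_shared<Chain>();
        n->next = head;
        head = n;
      }
      return 0;  // head is torn down iteratively on exit; a naive recursive
                 // destructor would overflow the stack at this depth
    }
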
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 17718d1..daad851 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -227,6 +227,11 @@ class Var : public Expr {
 class Call;
 /*! \brief Call container. */
 class CallNode : public ExprNode {
+ protected:
+  // CallNode uses its own deleter to indirectly call the non-recursive destructor
+  Object::FDeleter saved_deleter_;
+  static void Deleter_(Object* ptr);
+
  public:
   /*!
    * \brief The operator(function) being invoked
@@ -290,6 +295,7 @@ class CallNode : public ExprNode {
 
   static constexpr const char* _type_key = "relay.Call";
   TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
+  friend class Call;
 };
 
 class Call : public Expr {
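
The two members added above implement a deleter swap: Call's constructor
saves the deleter installed when the CallNode was allocated and substitutes
CallNode::Deleter_, which later restores the original before teardown. A
rough standalone sketch of the pattern, with a hand-rolled reference count
and hypothetical Obj/DecRef names rather than TVM's Object machinery:

    #include <cstdio>

    struct Obj {
      typedef void (*FDeleter)(Obj*);
      int ref_count = 0;
      FDeleter deleter = nullptr;        // installed at allocation time
      FDeleter saved_deleter = nullptr;  // original, restored on delete

      static void DefaultDeleter(Obj* self) { delete self; }

      // Substitute deleter: restore the original so the object is still
      // freed the normal way, run custom teardown, then delegate.
      static void SwappedDeleter(Obj* self) {
        self->deleter = self->saved_deleter;
        std::printf("non-recursive cleanup would run here\n");
        self->deleter(self);
      }
    };

    void DecRef(Obj* o) {
      if (--o->ref_count == 0 && o->deleter != nullptr) o->deleter(o);
    }

    int main() {
      Obj* o = new Obj();
      o->ref_count = 1;
      o->deleter = Obj::DefaultDeleter;
      o->saved_deleter = o->deleter;     // the swap done in Call's ctor
      o->deleter = Obj::SwappedDeleter;
      DecRef(o);  // fires SwappedDeleter, which restores and frees
      return 0;
    }

The sketch delegates to the restored deleter explicitly; in the actual
commit it fires through TVM's reference-counting machinery, as the expr.cc
diff below shows.
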
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index 62ff0b1..3b3c879 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -115,6 +115,8 @@ Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span s
   n->attrs = std::move(attrs);
   n->type_args = std::move(type_args);
   n->span = std::move(span);
+  n->saved_deleter_ = n->deleter_;
+  n->deleter_ = CallNode::Deleter_;
   data_ = std::move(n);
 }
 
@@ -288,16 +290,24 @@ inline void Dismantle(const Expr& expr) {
 
       // special handling
       if (const CallNode* op = node.as<CallNode>()) {
-        for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
-          fpush_to_stack(*it);
+        // do not process args if used elsewhere
+        if (op->args.use_count() < 2) {
+          for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
+            fpush_to_stack(*it);
+          }
         }
-        fpush_to_stack(op->op);
       } else if (const TupleNode* op = node.as<TupleNode>()) {
-        for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
-          fpush_to_stack(*it);
+        // do not process fields if used elsewhere
+        if (op->fields.use_count() < 2) {
+          for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
+            fpush_to_stack(*it);
+          }
         }
       } else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
-        fpush_to_stack(op->tuple);
+        // do not process tuple if used elsewhere
+        if (op->tuple.use_count() < 2) {
+          fpush_to_stack(op->tuple);
+        }
       }
     }
   }
@@ -306,7 +316,6 @@ inline void Dismantle(const Expr& expr) {
 /*
  * Non-recursive destructor
  */
-
 Call::~Call() {
   // attempt to dismantle if referenced one or zero times
   if (this->use_count() < 2) {
@@ -316,5 +325,16 @@ Call::~Call() {
   }
 }
 
+/*
+ * CallNode's deleter
+ */
+void CallNode::Deleter_(Object* ptr) {
+  auto p = reinterpret_cast<CallNode*>(ptr);
+  // restore original deleter
+  p->deleter_ = p->saved_deleter_;
+  // create Call reference in order to invoke ~Call
+  auto c = GetRef<Call>(p);
+}
+
 }  // namespace relay
 }  // namespace tvm
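
Two details of the implementation above are worth spelling out. First,
CallNode::Deleter_ runs when the reference count reaches zero: it restores
the saved deleter, then materialises a temporary Call handle via GetRef, so
the count briefly rises to one again. When that handle dies at the end of
the function, ~Call dismantles the graph iteratively and the count drops
back to zero, this time dispatching through the restored original deleter,
which actually frees the node. Second, the use_count() < 2 guards keep
Dismantle from detaching children that are still referenced from elsewhere.
A sketch of that guard, with std::shared_ptr standing in for TVM's
ObjectRef and hypothetical Node/Dismantle names (note the commit applies
the test to the args/fields arrays and the tuple child, not per element):

    #include <memory>
    #include <utility>
    #include <vector>

    struct Node {
      std::vector<std::shared_ptr<Node>> args;
    };

    void Dismantle(std::shared_ptr<Node> root) {
      std::vector<std::shared_ptr<Node>> stack;
      stack.push_back(std::move(root));
      while (!stack.empty()) {
        std::shared_ptr<Node> cur = std::move(stack.back());
        stack.pop_back();
        for (auto& a : cur->args) {
          // use_count() == 1: only `cur` owns this child, safe to detach.
          // Anything higher means it is shared and must be left intact.
          if (a.use_count() == 1) stack.push_back(std::move(a));
        }
      }
    }

    int main() {
      auto shared = std::make_shared<Node>();
      auto a = std::make_shared<Node>();
      auto b = std::make_shared<Node>();
      a->args.push_back(shared);
      b->args.push_back(shared);  // `shared` now has several owners
      Dismantle(std::move(a));    // leaves `shared` in place for b
      return 0;                   // b->args[0] is still valid here
    }
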
diff --git a/tests/cpp/relay_dismantler_test.cc b/tests/cpp/relay_dismantler_test.cc
index d5c089b..8c74d41 100644
--- a/tests/cpp/relay_dismantler_test.cc
+++ b/tests/cpp/relay_dismantler_test.cc
@@ -16,7 +16,6 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
 #include <gtest/gtest.h>
 #include <tvm/ir/expr.h>
 #include <tvm/ir/type_functor.h>
@@ -38,6 +37,8 @@
 #include <tvm/topi/broadcast.h>
 #include <tvm/topi/generic/injective.h>
 
+#include <memory>
+
 using namespace tvm;
 using namespace tvm::relay;
 
@@ -69,6 +70,80 @@ TEST(Relay, OutOfStack_cast) {
   ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*");
 }
 
+TEST(Relay, OutOfStack_packed_func) {
+  constexpr int len = 1e6;
+  auto foo = [] {
+    auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32)));
+    auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0}));
+    auto add_func = tvm::runtime::Registry::Get("relay.op._make.add");
+    auto y = (*add_func)(x, one);
+    for (int i = 0; i < len; ++i) {
+      y = (*add_func)(y, one);
+    }
+
+    // check if still reachable
+    int k = 0;
+    Expr e = y;
+    while (e.defined() && e.as<CallNode>() != nullptr) {
+      e = e.as<CallNode>()->args[0];
+      ++k;
+    }
+    ASSERT_EQ(len + 1, k);
+  };
+  ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*");
+}
+
+TEST(Relay, CallNodeSharedArgs) {
+  auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32)));
+  auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0}));
+  auto relu_op = relay::Op::Get("nn.relu");
+  Call y = relay::Call(relu_op, {x}, Attrs(), {});
+  y = relay::Call(relu_op, {y}, Attrs(), {});
+  ASSERT_EQ(1, y.get()->args[0].as<CallNode>()->args.size());
+  y = relay::Call(y.get()->op, y.get()->args, y.get()->attrs, y.get()->type_args);
+  ASSERT_EQ(1, y.get()->args[0].as<CallNode>()->args.size());
+}
+
+TEST(Relay, TupleSharedFields) {
+  auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32)));
+  auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0}));
+  auto relu_op = relay::Op::Get("nn.relu");
+  Expr y = relay::Call(relu_op, {x}, Attrs(), {});
+  y = relay::Call(relu_op, {y}, Attrs(), {});
+  {
+    Expr y1 = relay::Tuple(y.as<CallNode>()->args);
+    Expr y2 = relay::Tuple(y.as<CallNode>()->args);
+
+    y1 = relay::Call(relu_op, {y1});
+    y2 = relay::Call(relu_op, {y2});
+    y = y1;
+  }
+  ASSERT_EQ(1, y.as<CallNode>()->args[0].as<TupleNode>()->fields[0].as<CallNode>()->args.size());
+}
+
+TEST(Relay, TupleGetItemSharedTuple) {
+  auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32)));
+  auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0}));
+  auto relu_op = relay::Op::Get("nn.relu");
+  Expr y = relay::Call(relu_op, {x}, Attrs(), {});
+  y = relay::Tuple({y});
+  {
+    Expr y1 = relay::TupleGetItem(y, 0);
+    Expr y2 = relay::TupleGetItem(y, 0);
+
+    y1 = relay::Call(relu_op, {y1});
+    y2 = relay::Call(relu_op, {y2});
+    y = y1;
+  }
+  ASSERT_EQ(1, y.as<CallNode>()
+                   ->args[0]
+                   .as<TupleGetItemNode>()
+                   ->tuple.as<TupleNode>()
+                   ->fields[0]
+                   .as<CallNode>()
+                   ->args.size());
+}
+
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   testing::FLAGS_gtest_death_test_style = "threadsafe";
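
The OutOfStack tests above rely on gtest's exit assertions: with the
death-test style set to "threadsafe", the guarded statement runs in a child
process, so a stack overflow during destruction kills only that child and
surfaces as an assertion failure instead of taking down the whole test
runner. A minimal standalone version of the pattern, independent of TVM:

    #include <gtest/gtest.h>

    #include <cstdlib>

    TEST(Example, SurvivesDeepDestruction) {
      auto build_and_drop = [] {
        // ... construct and destroy a very deep object graph here ...
      };
      // Runs in a forked child; a crash there fails this assertion.
      ASSERT_EXIT((build_and_drop(), exit(0)),
                  ::testing::ExitedWithCode(0), ".*");
    }

    int main(int argc, char** argv) {
      testing::InitGoogleTest(&argc, argv);
      testing::FLAGS_gtest_death_test_style = "threadsafe";
      return RUN_ALL_TESTS();
    }
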
diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py
index b2ae286..4968660 100644
--- a/tests/python/relay/test_ir_text_printer.py
+++ b/tests/python/relay/test_ir_text_printer.py
@@ -30,6 +30,7 @@ SEMVER = '#[version = "0.0.5"]\n'
 
 def astext(program, unify_free_vars=False):
     text = program.astext()
+
     print(text)
     if isinstance(program, Expr):
         roundtrip_program = tvm.parser.parse_expr(text)
@@ -47,6 +48,17 @@ def show(text):
         print(text)
 
 
+def test_large_graph():
+    x = relay.var("x", shape=(3, 2))
+    y = relay.var("y")
+    one = relay.const(10e10, dtype="float32")
+    z = relay.add(x, one)
+    for i in range(int(1e6)):
+        z = relay.add(z, one)
+    f = relay.Function([x, y], z)
+    show(astext(f))
+
+
 def test_func():
     x = relay.var("x", shape=(3, 2))
     y = relay.var("y")