You are viewing a plain-text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/07/24 14:23:30 UTC

[incubator-tvm] branch master updated: Add 'get_num_inputs' to GraphRuntime (#6118)

This is an automated email from the ASF dual-hosted git repository.

marisa pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new bfa4eae  Add 'get_num_inputs' to GraphRuntime (#6118)
bfa4eae is described below

commit bfa4eae1dcac7f2493e543823e51eb420b0f8b2c
Author: Alexander Booth <ad...@gmail.com>
AuthorDate: Fri Jul 24 07:22:39 2020 -0700

    Add 'get_num_inputs' to GraphRuntime (#6118)
---
 python/tvm/contrib/graph_runtime.py | 11 +++++++++++
 src/runtime/graph/graph_runtime.cc  |  9 +++++++++
 src/runtime/graph/graph_runtime.h   |  6 ++++++
 3 files changed, 26 insertions(+)

diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py
index ec102f5..326eccb 100644
--- a/python/tvm/contrib/graph_runtime.py
+++ b/python/tvm/contrib/graph_runtime.py
@@ -133,6 +133,7 @@ class GraphModule(object):
         self._get_output = module["get_output"]
         self._get_input = module["get_input"]
         self._get_num_outputs = module["get_num_outputs"]
+        self._get_num_inputs = module["get_num_inputs"]
         self._load_params = module["load_params"]
         self._share_params = module["share_params"]
 
@@ -187,6 +188,16 @@ class GraphModule(object):
         """
         return self._get_num_outputs()
 
+    def get_num_inputs(self):
+        """Get the number of inputs to the graph
+
+        Returns
+        -------
+        count : int
+            The number of inputs.
+        """
+        return self._get_num_inputs()
+
     def get_input(self, index, out=None):
         """Get index-th input to out
 
diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc
index e984861..18245ba 100644
--- a/src/runtime/graph/graph_runtime.cc
+++ b/src/runtime/graph/graph_runtime.cc
@@ -135,6 +135,12 @@ void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) {
  */
 int GraphRuntime::NumOutputs() const { return outputs_.size(); }
 /*!
+ * \brief Get the number of inputs
+ *
+ * \return The number of inputs to the graph.
+ */
+int GraphRuntime::NumInputs() const { return input_nodes_.size(); }
+/*!
  * \brief Return NDArray for given input index.
  * \param index The input index.
  *
@@ -433,6 +439,9 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name,
   } else if (name == "get_num_outputs") {
     return PackedFunc(
         [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); });
+  } else if (name == "get_num_inputs") {
+    return PackedFunc(
+        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumInputs(); });
   } else if (name == "run") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); });
   } else if (name == "load_params") {
diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h
index d0c9822..dcef1e4 100644
--- a/src/runtime/graph/graph_runtime.h
+++ b/src/runtime/graph/graph_runtime.h
@@ -125,6 +125,12 @@ class TVM_DLL GraphRuntime : public ModuleNode {
    */
   int NumOutputs() const;
   /*!
+   * \brief Get the number of inputs
+   *
+   * \return The number of inputs to the graph.
+   */
+  int NumInputs() const;
+  /*!
    * \brief Return NDArray for given input index.
    * \param index The input index.
    *