You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@tvm.apache.org by ma...@apache.org on 2020/07/24 14:23:30 UTC
[incubator-tvm] branch master updated: Add 'get_num_inputs' to GraphRuntime (#6118)
This is an automated email from the ASF dual-hosted git repository.
marisa pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new bfa4eae Add 'get_num_inputs' to GraphRuntime (#6118)
bfa4eae is described below
commit bfa4eae1dcac7f2493e543823e51eb420b0f8b2c
Author: Alexander Booth <ad...@gmail.com>
AuthorDate: Fri Jul 24 07:22:39 2020 -0700
Add 'get_num_inputs' to GraphRuntime (#6118)
---
python/tvm/contrib/graph_runtime.py | 11 +++++++++++
src/runtime/graph/graph_runtime.cc | 9 +++++++++
src/runtime/graph/graph_runtime.h | 6 ++++++
3 files changed, 26 insertions(+)
diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py
index ec102f5..326eccb 100644
--- a/python/tvm/contrib/graph_runtime.py
+++ b/python/tvm/contrib/graph_runtime.py
@@ -133,6 +133,7 @@ class GraphModule(object):
self._get_output = module["get_output"]
self._get_input = module["get_input"]
self._get_num_outputs = module["get_num_outputs"]
+ self._get_num_inputs = module["get_num_inputs"]
self._load_params = module["load_params"]
self._share_params = module["share_params"]
@@ -187,6 +188,16 @@ class GraphModule(object):
"""
return self._get_num_outputs()
+ def get_num_inputs(self):
+ """Get the number of inputs to the graph
+
+ Returns
+ -------
+ count : int
+ The number of inputs.
+ """
+ return self._get_num_inputs()
+
def get_input(self, index, out=None):
"""Get index-th input to out
diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc
index e984861..18245ba 100644
--- a/src/runtime/graph/graph_runtime.cc
+++ b/src/runtime/graph/graph_runtime.cc
@@ -135,6 +135,12 @@ void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) {
*/
int GraphRuntime::NumOutputs() const { return outputs_.size(); }
/*!
+ * \brief Get the number of inputs
+ *
+ * \return The number of inputs to the graph.
+ */
+int GraphRuntime::NumInputs() const { return input_nodes_.size(); }
+/*!
* \brief Return NDArray for given input index.
* \param index The input index.
*
@@ -433,6 +439,9 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name,
} else if (name == "get_num_outputs") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); });
+ } else if (name == "get_num_inputs") {
+ return PackedFunc(
+ [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumInputs(); });
} else if (name == "run") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); });
} else if (name == "load_params") {
diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h
index d0c9822..dcef1e4 100644
--- a/src/runtime/graph/graph_runtime.h
+++ b/src/runtime/graph/graph_runtime.h
@@ -125,6 +125,12 @@ class TVM_DLL GraphRuntime : public ModuleNode {
*/
int NumOutputs() const;
/*!
+ * \brief Get the number of inputs
+ *
+ * \return The number of inputs to the graph.
+ */
+ int NumInputs() const;
+ /*!
* \brief Return NDArray for given input index.
* \param index The input index.
*