You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/12/03 03:57:25 UTC

[GitHub] [tvm] masahi commented on a change in pull request #9494: [Runtime] Pipeline Executor Add Set and Get Input/Output interfaces.

masahi commented on a change in pull request #9494:
URL: https://github.com/apache/tvm/pull/9494#discussion_r761629988



##########
File path: python/tvm/contrib/pipeline_executor.py
##########
@@ -93,8 +107,75 @@ def __init__(self, module):
         else:
             self.module = module
         # Get the packed functions from the pipeline executor.
+        self._run = self.module["run"]
+        self._stop = self.module["stop"]
+        self._set_input = self.module["set_input"]
+        self._set_param = self.module["set_param"]
+        self._get_input = self.module["get_input"]
+        self._get_output = self.module["get_output"]
+        self._get_num_inputs = self.module["get_num_inputs"]
         self._get_num_outputs = self.module["get_num_outputs"]
 
+    def run(self, sync=False):
+        """Run the pipeline executor."""
+        self._run(sync)
+
+    def stop(self):
+        """Stop the pipeline executor."""
+        self._stop()
+
+    def set_input(self, key, value):
+        """Set inputs to the module via "value".
+        Parameters
+        ----------
+        key : str
+            The input key
+
+        value : array_like.
+            The input value
+        """
+        v = self._get_input(key)
+        if v is None:
+            raise RuntimeError("Could not find '%s' in pipeline's inputs" % key)
+        v.copyfrom(value)
+
+    def set_params(self, params_name, params_data):
+        """Set params to the module via param name and params data.
+        Parameters
+        ----------
+        params_name : str
+            The params name
+
+        params_data : dict of str to NDArray
+            A list of params data and params key name.
+        """
+        for key, val in params_data.items():
+            self._set_param(params_name, key, val)
+
+    def get_input(self, key):
+        """Get the input via a input name.

Review comment:
       Typo: "a input" should be "an input".

##########
File path: python/tvm/contrib/pipeline_executor.py
##########
@@ -93,8 +107,75 @@ def __init__(self, module):
         else:
             self.module = module
         # Get the packed functions from the pipeline executor.
+        self._run = self.module["run"]
+        self._stop = self.module["stop"]
+        self._set_input = self.module["set_input"]
+        self._set_param = self.module["set_param"]
+        self._get_input = self.module["get_input"]
+        self._get_output = self.module["get_output"]
+        self._get_num_inputs = self.module["get_num_inputs"]
         self._get_num_outputs = self.module["get_num_outputs"]
 
+    def run(self, sync=False):
+        """Run the pipeline executor."""
+        self._run(sync)
+
+    def stop(self):
+        """Stop the pipeline executor."""
+        self._stop()
+
+    def set_input(self, key, value):
+        """Set inputs to the module via "value".

Review comment:
       The phrase `via "value"` doesn't make sense here; please reword the docstring summary.

##########
File path: python/tvm/contrib/pipeline_executor.py
##########
@@ -93,8 +107,75 @@ def __init__(self, module):
         else:
             self.module = module
         # Get the packed functions from the pipeline executor.
+        self._run = self.module["run"]
+        self._stop = self.module["stop"]
+        self._set_input = self.module["set_input"]
+        self._set_param = self.module["set_param"]
+        self._get_input = self.module["get_input"]
+        self._get_output = self.module["get_output"]
+        self._get_num_inputs = self.module["get_num_inputs"]
         self._get_num_outputs = self.module["get_num_outputs"]
 
+    def run(self, sync=False):
+        """Run the pipeline executor."""
+        self._run(sync)
+
+    def stop(self):
+        """Stop the pipeline executor."""
+        self._stop()
+
+    def set_input(self, key, value):
+        """Set inputs to the module via "value".
+        Parameters
+        ----------
+        key : str
+            The input key
+
+        value : array_like.
+            The input value
+        """
+        v = self._get_input(key)
+        if v is None:
+            raise RuntimeError("Could not find '%s' in pipeline's inputs" % key)
+        v.copyfrom(value)
+
+    def set_params(self, params_name, params_data):
+        """Set params to the module via param name and params data.
+        Parameters
+        ----------
+        params_name : str
+            The params name
+
+        params_data : dict of str to NDArray
+            A list of params data and params key name.

Review comment:
       What is the difference between the params key and `params_name`?

##########
File path: python/tvm/contrib/pipeline_executor.py
##########
@@ -93,8 +107,75 @@ def __init__(self, module):
         else:
             self.module = module
         # Get the packed functions from the pipeline executor.
+        self._run = self.module["run"]
+        self._stop = self.module["stop"]
+        self._set_input = self.module["set_input"]
+        self._set_param = self.module["set_param"]
+        self._get_input = self.module["get_input"]
+        self._get_output = self.module["get_output"]
+        self._get_num_inputs = self.module["get_num_inputs"]
         self._get_num_outputs = self.module["get_num_outputs"]
 
+    def run(self, sync=False):
+        """Run the pipeline executor."""
+        self._run(sync)
+
+    def stop(self):
+        """Stop the pipeline executor."""
+        self._stop()
+
+    def set_input(self, key, value):
+        """Set inputs to the module via "value".
+        Parameters
+        ----------
+        key : str
+            The input key
+
+        value : array_like.
+            The input value
+        """
+        v = self._get_input(key)
+        if v is None:
+            raise RuntimeError("Could not find '%s' in pipeline's inputs" % key)
+        v.copyfrom(value)
+
+    def set_params(self, params_name, params_data):
+        """Set params to the module via param name and params data.
+        Parameters
+        ----------
+        params_name : str
+            The params name
+
+        params_data : dict of str to NDArray
+            A list of params data and params key name.
+        """
+        for key, val in params_data.items():
+            self._set_param(params_name, key, val)

Review comment:
       Related to the above comment about the params key versus `params_name`, this is a weird API.

##########
File path: python/tvm/contrib/pipeline_executor.py
##########
@@ -93,8 +107,75 @@ def __init__(self, module):
         else:
             self.module = module
         # Get the packed functions from the pipeline executor.
+        self._run = self.module["run"]
+        self._stop = self.module["stop"]
+        self._set_input = self.module["set_input"]
+        self._set_param = self.module["set_param"]
+        self._get_input = self.module["get_input"]
+        self._get_output = self.module["get_output"]
+        self._get_num_inputs = self.module["get_num_inputs"]
         self._get_num_outputs = self.module["get_num_outputs"]
 
+    def run(self, sync=False):
+        """Run the pipeline executor."""
+        self._run(sync)
+
+    def stop(self):
+        """Stop the pipeline executor."""
+        self._stop()
+
+    def set_input(self, key, value):
+        """Set inputs to the module via "value".
+        Parameters
+        ----------
+        key : str
+            The input key
+
+        value : array_like.
+            The input value
+        """
+        v = self._get_input(key)
+        if v is None:
+            raise RuntimeError("Could not find '%s' in pipeline's inputs" % key)
+        v.copyfrom(value)
+
+    def set_params(self, params_name, params_data):
+        """Set params to the module via param name and params data.
+        Parameters
+        ----------
+        params_name : str
+            The params name
+
+        params_data : dict of str to NDArray
+            A list of params data and params key name.
+        """
+        for key, val in params_data.items():
+            self._set_param(params_name, key, val)
+
+    def get_input(self, key):
+        """Get the input via a input name.
+        Parameters
+        ----------
+        key : str
+            The input key
+
+        Returns
+        -------
+        data : NDArray
+            Then input data.

Review comment:
       Typo: "Then" should be "The".

##########
File path: python/tvm/contrib/pipeline_executor.py
##########
@@ -93,8 +107,75 @@ def __init__(self, module):
         else:
             self.module = module
         # Get the packed functions from the pipeline executor.
+        self._run = self.module["run"]
+        self._stop = self.module["stop"]
+        self._set_input = self.module["set_input"]
+        self._set_param = self.module["set_param"]
+        self._get_input = self.module["get_input"]
+        self._get_output = self.module["get_output"]
+        self._get_num_inputs = self.module["get_num_inputs"]
         self._get_num_outputs = self.module["get_num_outputs"]
 
+    def run(self, sync=False):
+        """Run the pipeline executor."""
+        self._run(sync)
+
+    def stop(self):
+        """Stop the pipeline executor."""
+        self._stop()
+
+    def set_input(self, key, value):
+        """Set inputs to the module via "value".
+        Parameters
+        ----------
+        key : str
+            The input key
+
+        value : array_like.
+            The input value
+        """
+        v = self._get_input(key)
+        if v is None:
+            raise RuntimeError("Could not find '%s' in pipeline's inputs" % key)
+        v.copyfrom(value)
+
+    def set_params(self, params_name, params_data):
+        """Set params to the module via param name and params data.

Review comment:
       Choose either "param" or "params" and use it consistently.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org