You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/31 20:00:31 UTC

[GitHub] sandeep-krishnamurthy commented on a change in pull request #12412: Infer dtype in SymbolBlock import from input symbol

sandeep-krishnamurthy commented on a change in pull request #12412: Infer dtype in SymbolBlock import from input symbol
URL: https://github.com/apache/incubator-mxnet/pull/12412#discussion_r214461609
 
 

 ##########
 File path: python/mxnet/gluon/block.py
 ##########
 @@ -1086,3 +1107,49 @@ def _clear_cached_op(self):
 
     def hybrid_forward(self, F, x, *args, **kwargs):
         """Not implemented for this block; calling it always raises.

         NOTE(review): presumably the forward pass of this block is driven
         elsewhere (e.g. by the cached op cleared in `_clear_cached_op`) --
         confirm against the enclosing class.

         Raises
         ------
         NotImplementedError
             Unconditionally.
         """
         raise NotImplementedError
+
+def _infer_param_types(in_params, out_params, arg_params, aux_params):
+    """Utility function that helps in inferring DType of args and auxs params
+    from given input param(s).
+
+    Parameters
+    ----------
+    in_params: Symbol or list of Symbol
+        Input symbol variable(s). A single Symbol is treated as a
+        one-element list.
+    out_params: Symbol
+        Output symbol variable.
+    arg_params: List of Str
+        List of names of argument parameters.
+    aux_params: List of Str
+        List of names of auxiliary parameters.
+
+    Returns
+    -------
+    infer_type_success: Boolean
+        True if able to infer types for all given arg_params and aux_params.
+        False, otherwise.
+    arg_types: List of numpy.dtype
+        List of arg_params type. Order is same as arg_params.
+        None if unable to infer type.
+    aux_types: List of numpy.dtype
+        List of aux_params type. Order is same as aux_params.
+        None if unable to infer type.
+    """
+    # Accept one input symbol or a list/tuple of them. Symbol itself is not
+    # iterated here, so only explicit sequences are unpacked.
+    if not isinstance(in_params, (list, tuple)):
+        in_params = [in_params]
+
+    # Collect the dtype of every input whose type is already known; these
+    # seed type inference for the rest of the graph.
+    known_types = {}
+    for in_param in in_params:
+        input_sym_arg_type = in_param.infer_type()[0]
+        if input_sym_arg_type:
+            known_types[in_param.name] = input_sym_arg_type[0]
+
+    # Nothing to seed inference with.
+    if not known_types:
+        return (False, None, None)
+
+    # Try to infer types of other parameters.
+    arg_types, _, aux_types = out_params.infer_type(**known_types)
+    if arg_types is None or len(arg_types) != len(arg_params) or \
+       aux_types is None or len(aux_types) != len(aux_params):
+        # Honour the documented contract: never hand back partial or
+        # mismatched type lists on failure.
+        return (False, None, None)
+
+    return (True, arg_types, aux_types)
 
 Review comment:
   Done.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services