You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/06/13 06:03:59 UTC

[GitHub] [incubator-mxnet] ZhennanQin commented on a change in pull request #15118: Conversion from FP32 model to Mixed Precision model

ZhennanQin commented on a change in pull request #15118: Conversion from FP32 model to Mixed Precision model
URL: https://github.com/apache/incubator-mxnet/pull/15118#discussion_r293217028
 
 

 ##########
 File path: python/mxnet/contrib/amp/amp.py
 ##########
 @@ -342,3 +349,320 @@ def unscale(optimizer_or_trainer):
     else:
         raise TypeError("optimizer_or_trainer should be a Gluon Trainer or "
                         "an optimizer, instead is %s" % type(optimizer_or_trainer))
+
+def convert_symbol(sym, target_dtype="float16", target_dtype_ops=None,
+                   fp32_ops=None, conditional_fp32_ops=None,
+                   excluded_sym_names=None, data_names=None):
+    """Given a symbol object representing a neural network of data type FP32 and target_dtype,
+    add cast layers according to the op lists (target_dtype_ops, fp32_ops,
+    conditional_fp32_ops) if provided, otherwise use the default
+    lists provided by the framework.
+
+    Parameters
+    ----------
+    sym : Symbol
+        FP32 neural network symbol
+    target_dtype : str or numpy, optional defaults to float16
+        currently only supports float16. The target dtype indicates to add cast layers
+        when possible so that lower precision computation can be leveraged.
+    target_dtype_ops : list of strs, optional
+        Override the list of operator names casted to the target_dtype.
+        If None, uses the framework's default list to be casted to target_dtype.
+    fp32_ops : list of strs, optional
+        Override the list of operator names casted to FP32.
+        If None, uses the framework's default list to be casted to FP32.
+    conditional_fp32_ops : list of (string, string, list of string), optional
+        Override the list of functions to be casted to FP32.
+        The format of the list is
+        (name of the function, name of the parameter,
+         list of values of the parameter that make the operator to be casted to FP32)
+    excluded_sym_names : list of strs, optional
+        A list of strings that represent the names of symbols that users want to exclude
+        from being casted to FP16 or FP32.
+    data_names : list of strs, optional
+        A list of strings that represent input data tensor names to the model
+    """
+    if target_dtype != "float16":
+        raise ValueError("Only target_dtype float16 is supported currently")
+
+    if target_dtype_ops is not None:
+        assert isinstance(target_dtype_ops, list), "target_dtype_ops should be a list of strs"
+    else:
+        target_dtype_ops = lists.symbol.FP16_FUNCS
+
+    if fp32_ops is not None:
+        assert isinstance(fp32_ops, list), "fp32_ops should be a list of strs"
+    else:
+        fp32_ops = lists.symbol.FP32_FUNCS
+
+    if conditional_fp32_ops is not None:
+        assert isinstance(conditional_fp32_ops, list), "conditional_fp32_ops should be a list"
+    else:
+        conditional_fp32_ops = lists.symbol.CONDITIONAL_FP32_FUNCS
+
+    original_conditional_op_names = []
+    conditional_op_names = []
+    param_names = []
+    param_vals = []
+    indptr = [0]
+    for conditional_fp32_op in conditional_fp32_ops:
+        assert isinstance(conditional_fp32_op[0], str) and isinstance(conditional_fp32_op[1], str) \
+            and isinstance(conditional_fp32_op[2], list), "conditional_fp32_ops should be a list of " \
+                                                          "(str, str, list of str)"
+        param_vals += conditional_fp32_op[2]
+        indptr.append(len(param_vals))
+        param_names.append(conditional_fp32_op[1])
+        conditional_op_names.append(conditional_fp32_op[0])
+
+    if excluded_sym_names is not None:
+        assert isinstance(excluded_sym_names, list), "excluded_sym_names should be a list of strs"
+    else:
+        excluded_sym_names = []
+
+    for original_conditional_fp32_op in lists.symbol.CONDITIONAL_FP32_FUNCS:
+        original_conditional_op_names.append(original_conditional_fp32_op[0])
+
+    # Op lists should not have intersection
+    common_ops = set(target_dtype_ops) & set(fp32_ops)
+    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
+                                 "Common ops in target_dtype_ops and fp32_ops {}".format(common_ops)
+    common_ops = set(target_dtype_ops) & set(conditional_op_names)
+    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
+                                 "Common ops in target_dtype_ops and conditional_fp32_ops {}".format(common_ops)
+    common_ops = set(conditional_op_names) & set(fp32_ops)
+    assert len(common_ops) == 0, "Ops cannot be in two or more lists. " \
+                                 "Common ops in fp32_ops and conditional_fp32_ops {}".format(common_ops)
+
+    combined_ops = set(target_dtype_ops + fp32_ops + conditional_op_names)
+    all_fp16_fp32_ops = set(lists.symbol.FP16_FUNCS + lists.symbol.FP32_FUNCS
+                            + lists.symbol.FP16_FP32_FUNCS + original_conditional_op_names)
+
+    illegal_ops = combined_ops - all_fp16_fp32_ops
+    assert not illegal_ops, '''Can only choose ops from one of the three lists
+                            for fp16_ops and fp32_ops
+                            1. amp.list_fp16_ops()
+                            2. amp.list_fp32_ops()
+                            3. amp.list_fp16_fp32_ops()
+                            4. amp.list_conditional_fp32_ops()
+                            Op %s not in any of them''' % (illegal_ops)
+
+    widest_dtype_ops = lists.symbol.WIDEST_TYPE_CASTS
+    target_dtype = _DTYPE_NP_TO_MX[np.dtype(target_dtype).type]
+
+    # Prepare a data_names list based on list_inputs if its not provided
+    # Add all names in list for the nodes in the symbol which don't have
+    # __dtype__ set
+    attr_dict = sym.attr_dict()
+    if not data_names:
+        data_names = []
+        for sym_name in sym.list_inputs():
+            if not sym_name in attr_dict:
+                data_names.append(sym_name)
+                continue
+            if not "__dtype__" in attr_dict[sym_name]:
+                data_names.append(sym_name)
+    model_param_names = list(set(sym.list_inputs()) - set(data_names))
+
+    # Since assumption is that it is a FP32 model, set dtypes for all
+    # data_names to float32
+    str_keys = []
+    sdata = []
+    for k in data_names:
+        str_keys.append(k)
+        sdata.append(0)
+    keys = c_str_array(str_keys)
+
+    out = SymbolHandle()
+    check_call(_LIB.MXReducePrecisionSymbol(sym.handle,
+                                            ctypes.byref(out),
+                                            mx_uint(len(sdata)),
+                                            c_array_buf(ctypes.c_int, array('i', sdata)),
+                                            mx_uint(len(indptr)),
+                                            c_array_buf(ctypes.c_int, array('i', indptr)),
+                                            ctypes.byref(ctypes.c_int(target_dtype)),
+                                            mx_uint(len(target_dtype_ops)),
+                                            mx_uint(len(fp32_ops)),
+                                            mx_uint(len(widest_dtype_ops)),
+                                            mx_uint(len(conditional_op_names)),
+                                            mx_uint(len(excluded_sym_names)),
+                                            mx_uint(len(model_param_names)),
+                                            c_str_array(target_dtype_ops),
+                                            c_str_array(fp32_ops),
+                                            c_str_array(widest_dtype_ops),
+                                            c_str_array(conditional_op_names),
+                                            c_str_array(excluded_sym_names),
+                                            c_str_array(param_names),
+                                            c_str_array(param_vals),
+                                            c_str_array(model_param_names),
+                                            keys))
+    return Symbol(out)
+
+def convert_model(sym, arg_params, aux_params, target_dtype="float16", target_dtype_ops=None,
+                  fp32_ops=None, conditional_fp32_ops=None, excluded_sym_names=None):
+    """API for converting a model from FP32 model to a mixed precision model.
+    MXNet tries to convert the FP32 model to mixed precision model by adding
+    cast layers using amp_cast and amp_multicast operators which can be used for inference use cases.
+    The decision on which cast layer to add is based on hardcoded lists for Automatic Mixed Precision
+    in MXNet. These lists can be overridden by the user by providing their own lists
+    using : targe_precision_ops, fp32_ops, widest_precision_ops, conditional_fp32_ops
+
+    arg_params : dict
+        Dictionary of name to `NDArray`.
+    aux_params : dict
+        Dictionary of name to `NDArray`.
+    target_dtype : str
+        Currently only supports float16. The target dtype indicates to add cast layers
+        when possible so that lower precision computation can be leveraged.
+    target_dtype_ops : list of strs
+        Override the list of operator names casted to target_dtype.
+        If None, uses the framework's default list to be casted to target dtype.
+    fp32_ops : list of strs
+        Override the lists of operator names casted to FP32.
+        If None, uses the framework's default list to be casted to FP32.
+    widest_dtype_ops : list of strs
+        A list of op names provided by user which should run in widest precision among its inputs.
+        If None, uses the framework's default list of widest_precision_ops.
+    conditional_fp32_ops : list of (string, string, list of string)
+        Override the list of operators to be casted to FP32.
+        The format of the list is
+        (name of the function, name of the parameter,
+         list of values of the parameter that make the operator to be casted to
+        fp32)
+    excluded_sym_names : list of strs
+        A list of strings that represent the names of symbols that users want to exclude
+        from being quantized.
 
 Review comment:
  I'm not sure low precision counts as a kind of quantization; maybe change `quantized` to `low precision`?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services