You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pt...@apache.org on 2020/07/14 21:13:04 UTC

[incubator-mxnet] branch master updated: Add better partial args/aux handling in symbol optimize_for (#18350)

This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7f7e1c5  Add better partial args/aux handling in symbol optimize_for (#18350)
7f7e1c5 is described below

commit 7f7e1c5a714262e8cd1015716258416e6ce1ff3e
Author: Serge Panev <sp...@nvidia.com>
AuthorDate: Tue Jul 14 14:12:00 2020 -0700

    Add better partial args/aux handling in symbol optimize_for (#18350)
    
    * Add missing args/aux support in optimize_for and deferred inference option
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Add input shape_dict, type_dict and stype_dict to optimize_for
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Remove warnings for Werror
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
    
    * Address PR comments
    
    Signed-off-by: Serge Panev <sp...@nvidia.com>
---
 include/mxnet/c_api.h         |  30 +++++++++++
 python/mxnet/symbol/symbol.py |  98 +++++++++++++++++++++++++++++++---
 src/c_api/c_api_symbolic.cc   | 120 +++++++++++++++++++++++++++++-------------
 src/common/exec_utils.h       |  72 -------------------------
 4 files changed, 204 insertions(+), 116 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 153d8c2..36c76e5 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -2235,6 +2235,25 @@ MXNET_DLL int MXGenAtomicSymbolFromSymbol(SymbolHandle sym_handle, SymbolHandle
  * \param num_options number of key value pairs
  * \param keys keys for options
  * \param vals values corresponding to keys
+ * \param num_input_shapes number of input shapes
+ * \param input_shape_names names of the input shapes
+ * \param input_shape_data pointer to the concatenated dimensions of all input shapes
+ * \param input_shape_idx array of per-shape starting indices into input_shape_data (num_input_shapes + 1
+ * entries); the length of the i-th input shape is calculated as input_shape_idx[i+1] - input_shape_idx[i]
+ * \param num_input_dtypes number of input data types
+ * \param input_dtype_names array of names of the input data types
+ * \param input_dtypes array of values of the input data types
+ * \param num_input_stypes number of input storage types
+ * \param input_stype_names array of names of the input storage types
+ * \param input_stypes array of values of input storage types
+ * \param skip_infer if the optimization should skip the attribute inferences
+ * (to use if the backend does not require shape inference)
+ * \param new_args_cnt pointer to a number that receives the number of new args
+ * \param new_args_handle pointer to an array that receives the new arg handles
+ * \param new_arg_names_handle pointer to an array that receives the new arg names
+ * \param new_aux_cnt pointer to a number that receives the number of new aux states
+ * \param new_aux_handle pointer to an array that receives the new aux handles
+ * \param new_aux_names_handle pointer to an array that receives the new aux names
  */
 MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle,
                                    const char* backend_name,
@@ -2247,6 +2266,17 @@ MXNET_DLL int MXOptimizeForBackend(SymbolHandle sym_handle,
                                    const mx_uint num_options,
                                    const char** keys,
                                    const char** vals,
+                                   const uint32_t num_input_shapes,
+                                   const char** input_shape_names,
+                                   const int64_t* input_shape_data,
+                                   const uint32_t* input_shape_idx,
+                                   const uint32_t num_input_dtypes,
+                                   const char** input_dtype_names,
+                                   const int* input_dtypes,
+                                   const uint32_t num_input_stypes,
+                                   const char** input_stype_names,
+                                   const int* input_stypes,
+                                   bool skip_infer,
                                    int* new_args_cnt,
                                    NDArrayHandle** new_args_handle,
                                    char*** new_arg_names_handle,
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 89ff6bf..9d23050 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1446,7 +1446,8 @@ class Symbol(SymbolBase):
 
 
     # pylint: disable=too-many-locals
-    def optimize_for(self, backend, args=None, aux=None, ctx=None, **kwargs):
+    def optimize_for(self, backend, args=None, aux=None, ctx=None,
+                     shape_dict=None, type_dict=None, stype_dict=None, skip_infer=False, **kwargs):
         """Partitions current symbol and optimizes it for a given backend,
         returns new partitioned symbol.
 
@@ -1457,19 +1458,33 @@ class Symbol(SymbolBase):
 
         args : dict of str to NDArray, optional
             Input arguments to the symbol, required to infer shapes/types before partitioning
-
             - If type is a dict of str to `NDArray`, then it maps the name of arguments
-              to the corresponding `NDArray`.
+              to the corresponding `NDArray`. Arguments whose `NDArray`s are not available
+              may be omitted from the dict.
 
         aux : dict of str to NDArray, optional
             Input auxiliary arguments to the symbol
-
             - If type is a dict of str to `NDArray`, then it maps the name of arguments
               to the corresponding `NDArray`.
 
         ctx : Context, optional
             Device context, used to infer stypes
 
+        shape_dict : dict of str to tuple, optional
+            Input shape dictionary.
+            Only used for arguments whose `NDArray` is not provided in `args`.
+
+        type_dict : dict of str to numpy.dtype, optional
+            Input type dictionary.
+            Only used for arguments whose `NDArray` is not provided in `args`.
+
+        stype_dict : dict of str to str, optional
+            Input storage type dictionary.
+            Only used for arguments whose `NDArray` is not provided in `args`.
+
+        skip_infer : bool, optional
+            If True, the optimization skips the shape, type and storage type inference pass.
+
         kwargs : optional arguments
             Passed on to `PrePartition` and `PostPartition` functions of `SubgraphProperty`
 
@@ -1488,18 +1503,78 @@ class Symbol(SymbolBase):
             args_handle = c_array(NDArrayHandle, [])
         else:
             args_handle, args_ = self._get_ndarray_inputs('args', args,
-                                                          self.list_arguments(), False)
+                                                          self.list_arguments(), True)
 
         if aux is None or len(aux) == 0:
             aux_ = []
             aux_handle = c_array(NDArrayHandle, [])
         else:
             aux_handle, aux_ = self._get_ndarray_inputs('aux_states', aux,
-                                                        self.list_auxiliary_states(), False)
+                                                        self.list_auxiliary_states(), True)
         if ctx is None:
             ctx = current_context()
         assert isinstance(ctx, Context)
 
+
+        # parse input data shape dict
+        num_input_shapes = 0
+        input_shape_names = ctypes.POINTER(ctypes.c_char_p)()
+        input_shape_data = ctypes.POINTER(mx_int64)()
+        input_shape_idx = ctypes.POINTER(mx_uint)()
+        if shape_dict is not None:
+            input_shape_names = []
+            input_shape_data = []
+            input_shape_idx = [0]
+            for k, v in shape_dict.items():
+                if isinstance(v, (tuple, list)):
+                    input_shape_names.append(k)
+                    input_shape_data.extend(v)
+                    input_shape_idx.append(len(input_shape_data))
+                else:
+                    raise ValueError(str(v) + " has to be a tuple or list.")
+            num_input_shapes = mx_uint(len(input_shape_names))
+            input_shape_names = c_str_array(input_shape_names)
+            input_shape_data = c_array_buf(mx_int64, array('q', input_shape_data))
+            input_shape_idx = c_array_buf(mx_uint, array('i', input_shape_idx))
+
+        # parse input data types dict
+        num_input_types = 0
+        input_type_names = ctypes.POINTER(ctypes.c_char_p)()  # provided type argument names
+        input_type_data = ctypes.POINTER(mx_uint)()  # provided types
+        if type_dict is not None:
+            input_type_names = []
+            input_type_data = []
+            for k, v in type_dict.items():
+                v = _numpy.dtype(v).type
+                if v in _DTYPE_NP_TO_MX:
+                    input_type_names.append(k)
+                    input_type_data.append(_DTYPE_NP_TO_MX[v])
+                else:
+                    raise ValueError(str(v) + " is not a MXNet type.")
+
+            num_input_types = mx_uint(len(input_type_names))
+            input_type_names = c_str_array(input_type_names)
+            input_type_data = c_array_buf(ctypes.c_int, array('i', input_type_data))
+
+        # parse input data storage types dict
+        num_input_stypes = 0
+        # provided storage type argument names
+        input_stype_names = ctypes.POINTER(ctypes.c_char_p)()
+        input_stype_data = ctypes.POINTER(mx_uint)()  # provided storage types
+        if stype_dict is not None:
+            input_stype_names = []
+            input_stype_data = []
+            for k, v in stype_dict.items():
+                if v in _STORAGE_TYPE_STR_TO_ID:
+                    input_stype_names.append(k)
+                    input_stype_data.append(_STORAGE_TYPE_STR_TO_ID[v])
+                else:
+                    raise ValueError(str(v) + " is not a MXNet storage type.")
+
+            num_input_stypes = mx_uint(len(input_stype_names))
+            input_stype_names = c_str_array(input_stype_names)
+            input_stype_data = c_array_buf(ctypes.c_int, array('i', input_stype_data))
+
         new_args_size = ctypes.c_uint()
         new_arg_names = ctypes.POINTER(ctypes.c_char_p)()
         new_args_handle = ctypes.POINTER(NDArrayHandle)()
@@ -1523,6 +1598,17 @@ class Symbol(SymbolBase):
                                              mx_uint(len(key_list)),
                                              c_str_array(key_list),
                                              c_str_array(val_list),
+                                             num_input_shapes,
+                                             input_shape_names,
+                                             input_shape_data,
+                                             input_shape_idx,
+                                             num_input_types,
+                                             input_type_names,
+                                             input_type_data,
+                                             num_input_stypes,
+                                             input_stype_names,
+                                             input_stype_data,
+                                             ctypes.c_bool(skip_infer),
                                              ctypes.byref(new_args_size),
                                              ctypes.byref(new_args_handle),
                                              ctypes.byref(new_arg_names),
diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc
index 963a2d5..0cc0ef7 100644
--- a/src/c_api/c_api_symbolic.cc
+++ b/src/c_api/c_api_symbolic.cc
@@ -1360,6 +1360,17 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
                          const mx_uint num_options,
                          const char** keys,
                          const char** vals,
+                         const uint32_t num_input_shapes,
+                         const char** input_shape_names,
+                         const int64_t* input_shape_data,
+                         const uint32_t* input_shape_idx,
+                         const uint32_t num_input_dtypes,
+                         const char** input_dtype_names,
+                         const int* input_dtypes,
+                         const uint32_t num_input_stypes,
+                         const char** input_stype_names,
+                         const int* input_stypes,
+                         bool skip_infer,
                          int* new_args_cnt,
                          NDArrayHandle** new_args_handle,
                          char*** new_arg_names_handle,
@@ -1383,47 +1394,80 @@ int MXOptimizeForBackend(SymbolHandle sym_handle,
   if (args_len || aux_len) {
     NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args_handle);
     NDArray **in_aux_ptr = reinterpret_cast<NDArray**>(in_aux_handle);
-    Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
-    mxnet::ShapeVector arg_shapes(args_len + aux_len);
-    nnvm::DTypeVector arg_dtypes(args_len + aux_len);
-    StorageTypeVector arg_stypes(args_len + aux_len);
-    size_t args_top = 0, aux_top = 0;
-    // loop over inputs to symbol in order and add to args/aux if mutable
-    for (size_t i = 0; i < num_forward_inputs; ++i) {
-      const uint32_t nid = indexed_graph.input_nodes().at(i);
-      if (mutable_nodes.count(nid)) {
-        CHECK_LT(aux_top, aux_len)
-          << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
-        const auto &in_arg = *(in_aux_ptr[aux_top++]);
-        arg_shapes[i] = in_arg.shape();
-        arg_dtypes[i] = in_arg.dtype();
-        arg_stypes[i] = in_arg.storage_type();
-      } else {
-        CHECK_LT(args_top, args_len)
-          << "Cannot find arg '" << input_names[i] << "' in provided args to optimize_for";
-        const auto &in_arg = *(in_args_ptr[args_top++]);
-        arg_shapes[i] = in_arg.shape();
-        arg_dtypes[i] = in_arg.dtype();
-        arg_stypes[i] = in_arg.storage_type();
+    if (!skip_infer) {
+      Context default_ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0);
+      mxnet::ShapeVector arg_shapes(args_len + aux_len);
+      nnvm::DTypeVector arg_dtypes(args_len + aux_len);
+      StorageTypeVector arg_stypes(args_len + aux_len);
+
+      // create the input shape, dtype and stype maps
+      std::unordered_map<std::string, mxnet::TShape> input_shape_map(num_input_shapes);
+      for (uint32_t i = 0; i < num_input_shapes; ++i) {
+        input_shape_map.emplace(input_shape_names[i],
+                    mxnet::TShape(input_shape_data + input_shape_idx[i],
+                    input_shape_data + input_shape_idx[i+1]));
+      }
+      std::unordered_map<std::string, int> input_dtype_map(num_input_dtypes);
+      for (uint32_t i = 0; i < num_input_dtypes; ++i) {
+        input_dtype_map.emplace(input_dtype_names[i], input_dtypes[i]);
+      }
+      std::unordered_map<std::string, int> input_stype_map(num_input_stypes);
+      for (uint32_t i = 0; i < num_input_stypes; ++i) {
+        input_stype_map.emplace(input_stype_names[i], input_stypes[i]);
       }
-    }
 
-    g.attrs["context"] = std::make_shared<nnvm::any>(
-        exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
+      size_t args_top = 0, aux_top = 0;
+      // loop over inputs to symbol in order and add to args/aux if mutable
+      for (size_t i = 0; i < num_forward_inputs; ++i) {
+        const uint32_t nid = indexed_graph.input_nodes().at(i);
+        if (mutable_nodes.count(nid)) {
+          CHECK_LT(aux_top, aux_len)
+            << "Cannot find aux '" << input_names[i] << "' in provided aux to optimize_for";
+          if (in_aux_ptr[aux_top] != nullptr) {
+            const auto &in_arg = *(in_aux_ptr[aux_top]);
+            arg_shapes[i] = in_arg.shape();
+            arg_dtypes[i] = in_arg.dtype();
+            arg_stypes[i] = in_arg.storage_type();
+          }
+          aux_top++;
+        } else {
+          auto name = input_names[i];
+          CHECK_LT(args_top, args_len)
+            << "Cannot find arg '" << name << "' in provided args to optimize_for";
+          if (in_args_ptr[args_top] != nullptr) {
+            const auto &in_arg = *(in_args_ptr[args_top]);
+            arg_shapes[i] = in_arg.shape();
+            arg_dtypes[i] = in_arg.dtype();
+            arg_stypes[i] = in_arg.storage_type();
+          } else {
+            // input_names[i] is not in args but can be in the optional
+            // shape/type/stype attribute dicts.
+            auto it_shape = input_shape_map.find(name);
+            if (it_shape != input_shape_map.end()) {
+              arg_shapes[i] = it_shape->second;
+            }
+            auto it_type = input_dtype_map.find(name);
+            if (it_type != input_dtype_map.end()) {
+              arg_dtypes[i] = it_type->second;
+            }
+            it_type = input_stype_map.find(name);
+            if (it_type != input_stype_map.end()) {
+              arg_stypes[i] = it_type->second;
+            }
+          }
+          args_top++;
+        }
+      }
 
-    // infer shapes
-    g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
-    // infer dtypes
-    g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
-    if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
-      common::HandleInferTypeError(num_forward_inputs, indexed_graph,
-                                   g.GetAttr<nnvm::DTypeVector>("dtype"));
-    }
-    // infer stypes
-    g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
-    if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
-      common::HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
-                                          g.GetAttr<StorageTypeVector>("storage_type"));
+      g.attrs["context"] = std::make_shared<nnvm::any>(
+          exec::ContextVector(indexed_graph.num_nodes(), default_ctx));
+
+      // infer shapes
+      g = exec::InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+      // infer dtypes
+      g = exec::InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+      // infer stypes
+      g = exec::InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
     }
     // set args/aux as attributes on graph so that subgraph property can use them
     std::vector<std::string> arg_names = sym->ListInputNames(nnvm::Symbol::kReadOnlyArgs);
diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h
index ff1c477..d69e0c5 100644
--- a/src/common/exec_utils.h
+++ b/src/common/exec_utils.h
@@ -369,78 +369,6 @@ inline void LogInferStorage(const nnvm::Graph& g) {
   }
 }
 
-// prints a helpful message after shape inference errors in executor.
-inline void HandleInferShapeError(const size_t num_forward_inputs,
-                                  const nnvm::IndexedGraph& idx,
-                                  const mxnet::ShapeVector& inferred_shapes) {
-  int cnt = 10;
-  std::ostringstream oss;
-  for (size_t i = 0; i < num_forward_inputs; ++i) {
-    const uint32_t nid = idx.input_nodes().at(i);
-    const uint32_t eid = idx.entry_id(nid, 0);
-    const mxnet::TShape& inferred_shape = inferred_shapes[eid];
-    if (!shape_is_known(inferred_shape)) {
-      const std::string& arg_name = idx[nid].source->attrs.name;
-      oss << arg_name << ": " << inferred_shape << ", ";
-      if (--cnt == 0) {
-        oss << "...";
-        break;
-      }
-    }
-  }
-  LOG(FATAL) << "InferShape pass cannot decide shapes for the following arguments "
-                "(-1 means unknown dimensions). Please consider providing them as inputs:\n"
-             << oss.str();
-}
-
-// prints a helpful message after type inference errors in executor.
-inline void HandleInferTypeError(const size_t num_forward_inputs,
-                                 const nnvm::IndexedGraph& idx,
-                                 const nnvm::DTypeVector& inferred_dtypes) {
-  int cnt = 10;
-  std::ostringstream oss;
-  for (size_t i = 0; i < num_forward_inputs; ++i) {
-    const uint32_t nid = idx.input_nodes().at(i);
-    const uint32_t eid = idx.entry_id(nid, 0);
-    const int inferred_dtype = inferred_dtypes[eid];
-    if (inferred_dtype == -1) {
-      const std::string& arg_name = idx[nid].source->attrs.name;
-      oss << arg_name << ": " << inferred_dtype << ", ";
-      if (--cnt == 0) {
-        oss << "...";
-        break;
-      }
-    }
-  }
-  LOG(FATAL) << "InferType pass cannot decide dtypes for the following arguments "
-                "(-1 means unknown dtype). Please consider providing them as inputs:\n"
-             << oss.str();
-}
-
-// prints a helpful message after storage type checking errors in executor.
-inline void HandleInferStorageTypeError(const size_t num_forward_inputs,
-                                        const nnvm::IndexedGraph& idx,
-                                        const StorageTypeVector& inferred_stypes) {
-  int cnt = 10;
-  std::ostringstream oss;
-  for (size_t i = 0; i < num_forward_inputs; ++i) {
-    const uint32_t nid = idx.input_nodes().at(i);
-    const uint32_t eid = idx.entry_id(nid, 0);
-    const int inferred_stype = inferred_stypes[eid];
-    if (inferred_stype == -1) {
-      const std::string& arg_name = idx[nid].source->attrs.name;
-      oss << arg_name << ": " << common::stype_string(inferred_stype) << ", ";
-      if (--cnt == 0) {
-        oss << "...";
-        break;
-      }
-    }
-  }
-  LOG(FATAL) << "InferStorageType pass cannot decide storage type for the following arguments "
-                "(-1 means unknown stype). Please consider providing them as inputs:\n"
-             << oss.str();
-}
-
 /*!
  * \brief If the requested ndarray's shape size is less than
  * the corresponding shared_data_array's shape size and the