You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2020/11/23 18:35:46 UTC

[incubator-mxnet] branch master updated: Convert symbol to numpy symbol in Symbol class (#19523)

This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 221344e  Convert symbol to numpy symbol in Symbol class (#19523)
221344e is described below

commit 221344e2824d5356679ea21c9f27d05be6a05dd5
Author: bgawrych <ba...@intel.com>
AuthorDate: Mon Nov 23 19:33:20 2020 +0100

    Convert symbol to numpy symbol in Symbol class (#19523)
---
 python/mxnet/gluon/block.py          | 4 ----
 python/mxnet/symbol/numpy/_symbol.py | 8 ++++++++
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 8ca4b5a..6b280e8 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1064,10 +1064,6 @@ class HybridBlock(Block):
             # Partition the graph
             out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, input_shapes, **self._backend_opts)
 
-            # convert to numpy symbol if needed
-            if _mx_npx.is_np_array():
-                out = out.as_np_ndarray()
-
             #update cached graph with partitioned graph
             if update_graph:
                 self._cached_graph = data, out
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index c630869..c1df972 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -1071,6 +1071,14 @@ class _Symbol(Symbol):
     def broadcast_like(self, *args, **kwargs):
         raise AttributeError('_Symbol object has no attribute broadcast_like')
 
+    # pylint: disable=too-many-arguments
+    def optimize_for(self, backend, args=None, aux=None, ctx=None,
+                     shape_dict=None, type_dict=None, stype_dict=None, skip_infer=False, **kwargs):
+        """Partitions current symbol and optimizes it for a given backend."""
+        new_sym = super().optimize_for(backend, args, aux, ctx, shape_dict, type_dict,
+                                       stype_dict, skip_infer, **kwargs)
+        new_sym = new_sym.as_np_ndarray()
+        return new_sym
 
 @set_module('mxnet.symbol.numpy')
 def zeros(shape, dtype=float, order='C', ctx=None):