You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2020/11/23 18:35:46 UTC
[incubator-mxnet] branch master updated: Convert symbol to numpy symbol in Symbol class (#19523)
This is an automated email from the ASF dual-hosted git repository.
lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 221344e Convert symbol to numpy symbol in Symbol class (#19523)
221344e is described below
commit 221344e2824d5356679ea21c9f27d05be6a05dd5
Author: bgawrych <ba...@intel.com>
AuthorDate: Mon Nov 23 19:33:20 2020 +0100
Convert symbol to numpy symbol in Symbol class (#19523)
---
python/mxnet/gluon/block.py | 4 ----
python/mxnet/symbol/numpy/_symbol.py | 8 ++++++++
2 files changed, 8 insertions(+), 4 deletions(-)
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 8ca4b5a..6b280e8 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1064,10 +1064,6 @@ class HybridBlock(Block):
# Partition the graph
out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, input_shapes, **self._backend_opts)
- # convert to numpy symbol if needed
- if _mx_npx.is_np_array():
- out = out.as_np_ndarray()
-
#update cached graph with partitioned graph
if update_graph:
self._cached_graph = data, out
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index c630869..c1df972 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -1071,6 +1071,14 @@ class _Symbol(Symbol):
def broadcast_like(self, *args, **kwargs):
    """Always raise AttributeError: ``broadcast_like`` is not available on ``_Symbol``.

    Positional and keyword arguments are accepted only so that any call form
    reaches the same error; they are otherwise ignored.
    """
    unsupported = '_Symbol object has no attribute broadcast_like'
    raise AttributeError(unsupported)
+ # pylint: disable=too-many-arguments
+ def optimize_for(self, backend, args=None, aux=None, ctx=None,
+ shape_dict=None, type_dict=None, stype_dict=None, skip_infer=False, **kwargs):
+ """Partitions current symbol and optimizes it for a given backend."""
+ new_sym = super().optimize_for(backend, args, aux, ctx, shape_dict, type_dict,
+ stype_dict, skip_infer, **kwargs)
+ new_sym = new_sym.as_np_ndarray()
+ return new_sym
@set_module('mxnet.symbol.numpy')
def zeros(shape, dtype=float, order='C', ctx=None):