You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2019/07/18 00:00:53 UTC

[incubator-mxnet] 17/42: Fix (#15188)

This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 7bcddcf2c2e9a93a5b0058400fb6e7cc39476408
Author: reminisce <wu...@gmail.com>
AuthorDate: Sun Jun 9 08:56:16 2019 -0700

    Fix (#15188)
---
 example/numpy/numpy_semantics.ipynb        | 308 +++++++++++++++++++++++++++++
 python/mxnet/gluon/data/dataloader.py      |  10 +-
 python/mxnet/gluon/data/vision/datasets.py |   5 +-
 python/mxnet/numpy/multiarray.py           |  19 +-
 python/mxnet/numpy_extension/__init__.py   |   7 +-
 python/mxnet/util.py                       |  47 +++--
 6 files changed, 369 insertions(+), 27 deletions(-)

diff --git a/example/numpy/numpy_semantics.ipynb b/example/numpy/numpy_semantics.ipynb
new file mode 100644
index 0000000..1cec51f
--- /dev/null
+++ b/example/numpy/numpy_semantics.ipynb
@@ -0,0 +1,308 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# How to Use NumPy Semantics in MXNet with `mxnet.numpy` Module\n",
+    "\n",
+    "## NumPy Shape Semantics\n",
+    "\n",
+    "### Example \n",
+    "\n",
+    "| Shape Example  | MXNet (before)  | MXNet/NumPy   |\n",
+    "|:---:|:---:|:---:|\n",
+    "| `()`   | Unknown  | Scalar tensor   |\n",
+    "| `(2, 0, 1)` | Second dimension unknown | Zero-size tensor |\n",
+    "| `None`(Python) | N/A | Unknown |\n",
+    "| `(2, -1, 0)`(C++) | N/A | Second dim unknown|\n",
+    "\n",
+    "### Affected modules\n",
+    "- Shape inference: imperative, symbolic, Gluon\n",
+    "- Legacy operators (not recommended to use)\n",
+    "- MXNet/NumPy operators\n",
+    "\n",
+    "## NumPy Array Semantics\n",
+    "**Definition:** The type of created ndarrays is `mxnet.numpy.ndarray`/`mxnet.symbol.numpy._Symbol`, instead of `mxnet.ndarray.NDArray`/`mxnet.symbol.Symbol` (only affects Gluon modules).\n",
+    "- Block/HybridBlock\n",
+    "    - Parameter creation and initialization.\n",
+    "    - Inputs/outputs (symbol/ndarray) of `__call__`/`forward`/`hybrid_forward`.\n",
+    "    - Computational graph construction.\n",
+    "- Dataloader\n",
+    "\n",
+    "## Dependency of Two Types of Semantics\n",
+    "- It is required to keep NumPy shape semantics active while activating NumPy array semantics.\n",
+    "- Deactivating NumPy shape semantics while NumPy array semantics is still active is not allowed."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import logging\n",
+    "import mxnet as mx\n",
+    "from mxnet import np, npx, gluon\n",
+    "\n",
+    "logging.basicConfig(level=logging.INFO)\n",
+    "\n",
+    "try:\n",
+    "    npx.set_np(shape=False, array=True)\n",
+    "except ValueError as e:\n",
+    "    print(e)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## How to Enable NumPy Shape semantics"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "try:\n",
+    "    a = mx.nd.random.uniform(shape=())\n",
+    "except mx.MXNetError as e:\n",
+    "    print(e)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "try:\n",
+    "    b = mx.nd.random.uniform(shape=(2, 0, 1))\n",
+    "except mx.MXNetError as e:\n",
+    "    print(e)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "try:\n",
+    "    c = np.random.uniform()\n",
+    "except mx.MXNetError as e:\n",
+    "    print(e)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "try:\n",
+    "    d = np.random.uniform(size=(2, 0, 1))\n",
+    "except mx.MXNetError as e:\n",
+    "    print(e)  "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "npx.set_np(shape=True, array=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "a = mx.nd.random.uniform(shape=())\n",
+    "b = mx.nd.random.uniform(shape=(2, 0, 1))\n",
+    "c = np.random.uniform()\n",
+    "d = np.random.uniform(size=(2, 0, 1))\n",
+    "\n",
+    "print('type(a) =', type(a))\n",
+    "print('a.shape = ', a.shape)\n",
+    "print('a.size = ', a.size)\n",
+    "\n",
+    "print('type(b) =', type(b))\n",
+    "print('b.shape = ', b.shape)\n",
+    "print('b.size = ', b.size)\n",
+    "\n",
+    "print('type(c) =', type(c))\n",
+    "print('c.shape = ', c.shape)\n",
+    "print('c.size = ', c.size)\n",
+    "\n",
+    "print('type(d) =', type(d))\n",
+    "print('d.shape = ', d.shape)\n",
+    "print('d.size = ', d.size)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## How to Enable NumPy Array Semantics\n",
+    "\n",
+    "### Parameters"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "npx.reset_np()  # reset two types of semantics to the default state, which is False for both of them\n",
+    "\n",
+    "from mxnet.gluon import nn\n",
+    "class Net(gluon.Block):\n",
+    "    def __init__(self, in_units=0, **kwargs):  # 0 means in_units is unknown and must be inferred at runtime\n",
+    "        super(Net, self).__init__(**kwargs)\n",
+    "        with self.name_scope():\n",
+    "            self.dense0 = nn.Dense(5, in_units=in_units)\n",
+    "            self.dense1 = nn.Dense(5, in_units=in_units)\n",
+    "            \n",
+    "    def forward(self, x):\n",
+    "        return self.dense1(self.dense0(x))\n",
+    "\n",
+    "net1 = Net()\n",
+    "net1.initialize()\n",
+    "net1(mx.nd.zeros((3, 10)))\n",
+    "for k, v in net1.collect_params().items():\n",
+    "    print('parameter {}, type {}'.format(k, str(type(v.data()))))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "npx.set_np()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "net2 = Net()\n",
+    "net2.initialize()\n",
+    "net2(np.zeros((3, 10)))\n",
+    "for k, v in net2.collect_params().items():\n",
+    "    print('parameter {}, type {}'.format(k, str(type(v.data()))))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Dataloader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "import os\n",
+    "from mxnet.gluon import data as gdata\n",
+    "\n",
+    "\n",
+    "npx.reset_np()\n",
+    "\n",
+    "\n",
+    "def load_data_fashion_mnist(batch_size, resize=None, root=os.path.join(\n",
+    "        '~', '.mxnet', 'datasets', 'fashion-mnist')):\n",
+    "    \"\"\"Download the Fashion-MNIST dataset and then load into memory.\"\"\"\n",
+    "    root = os.path.expanduser(root)\n",
+    "    transformer = []\n",
+    "    if resize:\n",
+    "        transformer += [gdata.vision.transforms.Resize(resize)]\n",
+    "    transformer += [gdata.vision.transforms.ToTensor()]\n",
+    "    transformer = gdata.vision.transforms.Compose(transformer)\n",
+    "\n",
+    "    mnist_train = gdata.vision.FashionMNIST(root=root, train=True)\n",
+    "    mnist_test = gdata.vision.FashionMNIST(root=root, train=False)\n",
+    "    num_workers = 0 if sys.platform.startswith('win32') else 4\n",
+    "\n",
+    "    train_iter = gdata.DataLoader(mnist_train.transform_first(transformer),\n",
+    "                                  batch_size, shuffle=True,\n",
+    "                                  num_workers=num_workers)\n",
+    "    test_iter = gdata.DataLoader(mnist_test.transform_first(transformer),\n",
+    "                                 batch_size, shuffle=False,\n",
+    "                                 num_workers=num_workers)\n",
+    "    return train_iter, test_iter"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_iter, test_iter = load_data_fashion_mnist(16)\n",
+    "\n",
+    "for X, y in train_iter:\n",
+    "    print('type(X) = ', type(X))\n",
+    "    print('type(y) = ', type(y))\n",
+    "    break"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "npx.set_np()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_iter, test_iter = load_data_fashion_mnist(16)\n",
+    "\n",
+    "for X, y in train_iter:\n",
+    "    print('type(X) = ', type(X))\n",
+    "    print('type(y) = ', type(y))\n",
+    "    break"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 7e8110c..1923f65 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -39,7 +39,7 @@ except ImportError:
 from . import sampler as _sampler
 from ... import nd, context
 from ...util import is_np_array
-from ... import numpy as _mx_np  #pylint: disable=reimported
+from ... import numpy as _mx_np  # pylint: disable=reimported
 
 if sys.platform == 'darwin' or sys.platform == 'win32':
     def rebuild_ndarray(*args):
@@ -127,6 +127,7 @@ class SimpleQueue(multiprocessing.queues.SimpleQueue):
         self._send = self._writer.send
         self._recv = self._reader.recv
 
+
 def default_batchify_fn(data):
     """Collate data into batch."""
     if isinstance(data[0], nd.NDArray):
@@ -143,10 +144,10 @@ def default_batchify_fn(data):
 def default_mp_batchify_fn(data):
     """Collate data into batch. Use shared memory for stacking."""
     if isinstance(data[0], nd.NDArray):
-        out = nd.empty((len(data),) + data[0].shape, dtype=data[0].dtype,
+        empty_fn = _mx_np.empty if is_np_array() else nd.empty
+        out = empty_fn((len(data),) + data[0].shape, dtype=data[0].dtype,
                        ctx=context.Context('cpu_shared', 0))
         if is_np_array():
-            out = out.as_np_ndarray()
             return _mx_np.stack(data, out=out)
         else:
             return nd.stack(*data, out=out)
@@ -163,8 +164,7 @@ def default_mp_batchify_fn(data):
 def _as_in_context(data, ctx):
     """Move data into new context."""
     if isinstance(data, nd.NDArray):
-        out = data.as_in_context(ctx)
-        return out.as_np_ndarray() if is_np_array() else out
+        return data.as_in_context(ctx)
     elif isinstance(data, (list, tuple)):
         return [_as_in_context(d, ctx) for d in data]
     return data
diff --git a/python/mxnet/gluon/data/vision/datasets.py b/python/mxnet/gluon/data/vision/datasets.py
index 12ef7e1..c580502 100644
--- a/python/mxnet/gluon/data/vision/datasets.py
+++ b/python/mxnet/gluon/data/vision/datasets.py
@@ -31,6 +31,8 @@ import numpy as np
 from .. import dataset
 from ...utils import download, check_sha1, _get_repo_file_url
 from .... import nd, image, recordio, base
+from .... import numpy as _mx_np  # pylint: disable=reimported
+from ....util import is_np_array
 
 
 class MNIST(dataset._DownloadedDataset):
@@ -87,7 +89,8 @@ class MNIST(dataset._DownloadedDataset):
             data = np.frombuffer(fin.read(), dtype=np.uint8)
             data = data.reshape(len(label), 28, 28, 1)
 
-        self._data = nd.array(data, dtype=data.dtype)
+        array_fn = _mx_np.array if is_np_array() else nd.array
+        self._data = array_fn(data, dtype=data.dtype)
         self._label = label
 
 
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 454b562..a4a05af 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -519,7 +519,24 @@ class ndarray(NDArray):
         return _mx_nd_np.argmax(self, axis, out)
 
     def as_in_context(self, context):
-        return super(ndarray, self).as_in_context(context).as_np_ndarray()
+        """Returns an array on the target device with the same value as this array.
+
+        If the target context is the same as ``self.context``, then ``self`` is
+        returned.  Otherwise, a copy is made.
+
+        Parameters
+        ----------
+        context : Context
+            The target context.
+
+        Returns
+        -------
+        ndarray
+            The target array.
+        """
+        if self.context == context:
+            return self
+        return self.copyto(context)
 
     def copy(self, order='C'):  # pylint: disable=arguments-differ
         if order != 'C':
diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy_extension/__init__.py
index a15a1d4..e2ccaa1 100644
--- a/python/mxnet/numpy_extension/__init__.py
+++ b/python/mxnet/numpy_extension/__init__.py
@@ -24,8 +24,9 @@ from . import _op
 from . import _register
 from ._op import *  # pylint: disable=wildcard-import
 from ..context import *  # pylint: disable=wildcard-import
-from ..util import use_np_shape, np_shape, is_np_shape, set_np_shape
-from ..util import use_np_array, np_array, is_np_array, set_np_array
-from ..util import set_np, use_np
+# TODO(junwu): revisit what functions should be exposed to users
+from ..util import use_np_shape, np_shape, is_np_shape
+from ..util import use_np_array, np_array, is_np_array
+from ..util import set_np, use_np, reset_np
 
 __all__ = []
diff --git a/python/mxnet/util.py b/python/mxnet/util.py
index 11ec16e..d4e95e0 100644
--- a/python/mxnet/util.py
+++ b/python/mxnet/util.py
@@ -79,14 +79,17 @@ def set_np_shape(active):
     >>> print(mx.is_np_shape())
     True
     """
-    # TODO(junwu): Consider uncommenting the following lines.
-    # import logging
-    # logging.info('NumPy-shape semantics has been activated in your code global scope. '
-    #              'This is required for using `mxnet.numpy` and `mxnet.numpy_extension` '
-    #              'modules as it enables creating and manipulating scalar and zero-size '
-    #              'tensors, which were not supported in MXNet before, as in the official '
-    #              'NumPy library. Please DO NOT manually deactivate this semantics while '
-    #              'using `mxnet.numpy` and `mxnet.numpy_extension` modules.')
+    if active:
+        import logging
+        logging.info('NumPy-shape semantics has been activated in your code. '
+                     'This is required for creating and manipulating scalar and zero-size '
+                     'tensors, which were not supported in MXNet before, as in the official '
+                     'NumPy library. Please DO NOT manually deactivate this semantics while '
+                     'using `mxnet.numpy` and `mxnet.numpy_extension` modules.')
+    elif is_np_array():
+        raise ValueError('Deactivating NumPy shape semantics while NumPy array semantics is still'
+                         ' active is not allowed. Please consider calling `npx.reset_np()` to'
+                         ' deactivate both of them.')
     prev = ctypes.c_int()
     check_call(_LIB.MXSetIsNumpyShape(ctypes.c_int(active), ctypes.byref(prev)))
     return bool(prev.value)
@@ -552,10 +555,10 @@ def use_np(func):
     Function or class
         A function or class wrapped in the Numpy-shape and NumPy-array scope.
     """
-    return use_np_array(use_np_shape(func))
+    return use_np_shape(use_np_array(func))
 
 
-def set_np_array(active):
+def _set_np_array(active):
     """Turns on/off NumPy array semantics for the current thread in which `mxnet.numpy.ndarray`
     is expected to be created, instead of the legacy `mx.nd.NDArray`.
 
@@ -568,13 +571,20 @@ def set_np_array(active):
     -------
         A bool value indicating the previous state of NumPy array semantics.
     """
+    if active:
+        import logging
+        logging.info('NumPy array semantics has been activated in your code. This allows you'
+                     ' to use operators from MXNet NumPy and NumPy Extension modules as well'
+                     ' as MXNet NumPy `ndarray`s.')
     cur_state = is_np_array()
     _NumpyArrayScope._current.value = _NumpyArrayScope(active)
     return cur_state
 
 
 def set_np(shape=True, array=True):
-    """A convenience function for setting NumPy shape and array semantics at the same time.
+    """Set NumPy shape and array semantics at the same time.
+    It is required to keep NumPy shape semantics active when activating NumPy array semantics.
+    Deactivating NumPy shape semantics while NumPy array semantics is still active is not allowed.
 
     Parameters
     ----------
@@ -582,10 +592,13 @@ def set_np(shape=True, array=True):
         A boolean value indicating whether the NumPy-shape semantics should be turned on or off.
     array : bool
         A boolean value indicating whether the NumPy-array semantics should be turned on or off.
-
-    Returns
-    -------
-        A tuple with elements indicating the previous states of shape and array
-        semantics, respectively.
     """
-    return set_np_shape(shape), set_np_array(array)
+    if not shape and array:
+        raise ValueError('NumPy Shape semantics is required in using NumPy array semantics.')
+    _set_np_array(array)
+    set_np_shape(shape)
+
+
+def reset_np():
+    """Deactivate NumPy shape and array semantics at the same time."""
+    set_np(shape=False, array=False)