You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by th...@apache.org on 2019/02/14 23:42:16 UTC

[incubator-mxnet] branch master updated: Add pixelshuffle layers (#13571)

This is an automated email from the ASF dual-hosted git repository.

thomasdelteil pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new de5cda2  Add pixelshuffle layers (#13571)
de5cda2 is described below

commit de5cda27145d8f5fafda06446d060dcbbb3d3872
Author: Holger Kohr <ho...@zoho.com>
AuthorDate: Fri Feb 15 00:41:52 2019 +0100

    Add pixelshuffle layers (#13571)
    
    * Add pixelshuffle layers, closes #13548
    
    * Remove fmt comments
    
    * Use explicit class in super()
    
    * Add axis swapping to pixel shuffling, add tests
    
    * Add documentation to pixel shuffle layers
    
    * Use pixelshuffle layer and fix download in superres example
    
    * Add pixelshuffle layers to API doc page
---
 docs/api/python/gluon/contrib.md                   |   3 +
 example/gluon/super_resolution/super_resolution.py | 144 ++++++++++------
 python/mxnet/gluon/contrib/nn/__init__.py          |   2 +-
 python/mxnet/gluon/contrib/nn/basic_layers.py      | 181 ++++++++++++++++++++-
 tests/python/unittest/test_gluon_contrib.py        |  87 +++++++++-
 5 files changed, 362 insertions(+), 55 deletions(-)

diff --git a/docs/api/python/gluon/contrib.md b/docs/api/python/gluon/contrib.md
index b893d58..790f6b4 100644
--- a/docs/api/python/gluon/contrib.md
+++ b/docs/api/python/gluon/contrib.md
@@ -54,6 +54,9 @@ In the rest of this document, we list routines provided by the `gluon.contrib` p
     Identity
     SparseEmbedding
     SyncBatchNorm
+    PixelShuffle1D
+    PixelShuffle2D
+    PixelShuffle3D
 ```
 
 ### Recurrent neural network
diff --git a/example/gluon/super_resolution/super_resolution.py b/example/gluon/super_resolution/super_resolution.py
index 0f2f21f..198f6fe 100644
--- a/example/gluon/super_resolution/super_resolution.py
+++ b/example/gluon/super_resolution/super_resolution.py
@@ -16,19 +16,27 @@
 # under the License.
 
 from __future__ import print_function
-import argparse, tarfile
+
+import argparse
 import math
 import os
+import shutil
+import sys
+import zipfile
+from os import path
+
 import numpy as np
 
 import mxnet as mx
-import mxnet.ndarray as F
-from mxnet import gluon
+from mxnet import gluon, autograd as ag
 from mxnet.gluon import nn
-from mxnet import autograd as ag
-from mxnet.test_utils import download
+from mxnet.gluon.contrib import nn as contrib_nn
 from mxnet.image import CenterCropAug, ResizeAug
 from mxnet.io import PrefetchingIter
+from mxnet.test_utils import download
+
+this_dir = path.abspath(path.dirname(__file__))
+sys.path.append(path.join(this_dir, path.pardir))
 
 from data import ImagePairIter
 
@@ -51,19 +59,45 @@ upscale_factor = opt.upscale_factor
 batch_size, test_batch_size = opt.batch_size, opt.test_batch_size
 color_flag = 0
 
-# get data
-dataset_path = "dataset"
-dataset_url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
-def get_dataset(prefetch=False):
-    image_path = os.path.join(dataset_path, "BSDS300/images")
+# Dataset locations: the archive is downloaded into datasets_tmpdir, extracted into tmp_dir, and BSDS500 images/groundTruth are copied into data_dir
+datasets_dir = path.expanduser(path.join("~", ".mxnet", "datasets"))
+datasets_tmpdir = path.join(datasets_dir, "tmp")  # download target; deleted with shutil.rmtree after extraction
+dataset_url = "https://github.com/BIDS/BSDS500/archive/master.zip"
+data_dir = path.expanduser(path.join(datasets_dir, "BSDS500"))  # expanduser is a no-op here: datasets_dir is already expanded
+tmp_dir = path.join(data_dir, "tmp")  # extraction scratch dir; deleted once images/groundTruth are copied out
 
-    if not os.path.exists(image_path):
-        os.makedirs(dataset_path)
-        file_name = download(dataset_url)
-        with tarfile.open(file_name) as tar:
-            for item in tar:
-                tar.extract(item, dataset_path)
-        os.remove(file_name)
+def get_dataset(prefetch=False):
+    """Download the BSDS500 dataset and return train and test iters."""
+
+    if path.exists(data_dir):
+        print(
+            "Directory {} already exists, skipping.\n"
+            "To force download and extraction, delete the directory and re-run."
+            "".format(data_dir),
+            file=sys.stderr,
+        )
+    else:
+        print("Downloading dataset...", file=sys.stderr)
+        downloaded_file = download(dataset_url, dirname=datasets_tmpdir)
+        print("done", file=sys.stderr)
+
+        print("Extracting files...", end="", file=sys.stderr)
+        os.makedirs(data_dir)
+        os.makedirs(tmp_dir)
+        with zipfile.ZipFile(downloaded_file) as archive:
+            archive.extractall(tmp_dir)
+        shutil.rmtree(datasets_tmpdir)
+
+        shutil.copytree(
+            path.join(tmp_dir, "BSDS500-master", "BSDS500", "data", "images"),
+            path.join(data_dir, "images"),
+        )
+        shutil.copytree(
+            path.join(tmp_dir, "BSDS500-master", "BSDS500", "data", "groundTruth"),
+            path.join(data_dir, "groundTruth"),
+        )
+        shutil.rmtree(tmp_dir)
+        print("done", file=sys.stderr)
 
     crop_size = 256
     crop_size -= crop_size % upscale_factor
@@ -72,15 +106,26 @@ def get_dataset(prefetch=False):
     input_transform = [CenterCropAug((crop_size, crop_size)), ResizeAug(input_crop_size)]
     target_transform = [CenterCropAug((crop_size, crop_size))]
 
-    iters = (ImagePairIter(os.path.join(image_path, "train"),
-                           (input_crop_size, input_crop_size),
-                           (crop_size, crop_size),
-                           batch_size, color_flag, input_transform, target_transform),
-             ImagePairIter(os.path.join(image_path, "test"),
-                           (input_crop_size, input_crop_size),
-                           (crop_size, crop_size),
-                           test_batch_size, color_flag,
-                           input_transform, target_transform))
+    iters = (
+        ImagePairIter(
+            path.join(data_dir, "images", "train"),
+            (input_crop_size, input_crop_size),
+            (crop_size, crop_size),
+            batch_size,
+            color_flag,
+            input_transform,
+            target_transform,
+        ),
+        ImagePairIter(
+            path.join(data_dir, "images", "test"),
+            (input_crop_size, input_crop_size),
+            (crop_size, crop_size),
+            test_batch_size,
+            color_flag,
+            input_transform,
+            target_transform,
+        ),
+    )
 
     return [PrefetchingIter(i) for i in iters] if prefetch else iters
 
@@ -90,33 +135,23 @@ mx.random.seed(opt.seed)
 ctx = [mx.gpu(0)] if opt.use_gpu else [mx.cpu()]
 
 
-# define model
-def _rearrange(raw, F, upscale_factor):
-    # (N, C * r^2, H, W) -> (N, C, r^2, H, W)
-    splitted = F.reshape(raw, shape=(0, -4, -1, upscale_factor**2, 0, 0))
-    # (N, C, r^2, H, W) -> (N, C, r, r, H, W)
-    unflatten = F.reshape(splitted, shape=(0, 0, -4, upscale_factor, upscale_factor, 0, 0))
-    # (N, C, r, r, H, W) -> (N, C, H, r, W, r)
-    swapped = F.transpose(unflatten, axes=(0, 1, 4, 2, 5, 3))
-    # (N, C, H, r, W, r) -> (N, C, H*r, W*r)
-    return F.reshape(swapped, shape=(0, 0, -3, -3))
-
-
-class SuperResolutionNet(gluon.Block):
+class SuperResolutionNet(gluon.HybridBlock):
     def __init__(self, upscale_factor):
         super(SuperResolutionNet, self).__init__()
         with self.name_scope():
-            self.conv1 = nn.Conv2D(64, (5, 5), strides=(1, 1), padding=(2, 2))
-            self.conv2 = nn.Conv2D(64, (3, 3), strides=(1, 1), padding=(1, 1))
-            self.conv3 = nn.Conv2D(32, (3, 3), strides=(1, 1), padding=(1, 1))
+            self.conv1 = nn.Conv2D(64, (5, 5), strides=(1, 1), padding=(2, 2), activation='relu')  # ReLU fused into the conv layer, replacing the removed F.Activation calls
+            self.conv2 = nn.Conv2D(64, (3, 3), strides=(1, 1), padding=(1, 1), activation='relu')
+            self.conv3 = nn.Conv2D(32, (3, 3), strides=(1, 1), padding=(1, 1), activation='relu')
             self.conv4 = nn.Conv2D(upscale_factor ** 2, (3, 3), strides=(1, 1), padding=(1, 1))
-        self.upscale_factor = upscale_factor
+            self.pxshuf = contrib_nn.PixelShuffle2D(upscale_factor)  # (N, r**2, H, W) -> (N, 1, r*H, r*W) with r = upscale_factor
 
-    def forward(self, x):
-        x = F.Activation(self.conv1(x), act_type='relu')
-        x = F.Activation(self.conv2(x), act_type='relu')
-        x = F.Activation(self.conv3(x), act_type='relu')
-        return _rearrange(self.conv4(x), F, self.upscale_factor)
+    def hybrid_forward(self, F, x):  # F is not used directly: every op runs through a child block
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = self.conv4(x)
+        x = self.pxshuf(x)
+        return x
 
 net = SuperResolutionNet(upscale_factor)
 metric = mx.metric.MSE()
@@ -136,7 +171,7 @@ def test(ctx):
         avg_psnr += 10 * math.log10(1/metric.get()[1])
         metric.reset()
     avg_psnr /= batches
-    print('validation avg psnr: %f'%avg_psnr)
+    print('validation avg psnr: %f' % avg_psnr)
 
 
 def train(epoch, ctx):
@@ -168,13 +203,18 @@ def train(epoch, ctx):
         print('training mse at epoch %d: %s=%f'%(i, name, acc))
         test(ctx)
 
-    net.save_parameters('superres.params')
+    net.save_parameters(path.join(this_dir, 'superres.params'))
 
 def resolve(ctx):
     from PIL import Image
+
     if isinstance(ctx, list):
         ctx = [ctx[0]]
-    net.load_parameters('superres.params', ctx=ctx)
+
+    img_basename = path.splitext(path.basename(opt.resolve_img))[0]
+    img_dirname = path.dirname(opt.resolve_img)
+
+    net.load_parameters(path.join(this_dir, 'superres.params'), ctx=ctx)
     img = Image.open(opt.resolve_img).convert('YCbCr')
     y, cb, cr = img.split()
     data = mx.nd.expand_dims(mx.nd.expand_dims(mx.nd.array(y), axis=0), axis=0)
@@ -186,7 +226,7 @@ def resolve(ctx):
     out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
     out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB')
 
-    out_img.save('resolved.png')
+    out_img.save(path.join(img_dirname, '{}-resolved.png'.format(img_basename)))
 
 if opt.resolve_img:
     resolve(ctx)
diff --git a/python/mxnet/gluon/contrib/nn/__init__.py b/python/mxnet/gluon/contrib/nn/__init__.py
index 62440cd..5eb46f6 100644
--- a/python/mxnet/gluon/contrib/nn/__init__.py
+++ b/python/mxnet/gluon/contrib/nn/__init__.py
@@ -17,7 +17,7 @@
 
 # coding: utf-8
 # pylint: disable=wildcard-import
-"""Contrib recurrent neural network module."""
+"""Contributed neural network modules."""
 
 from . import basic_layers
 
diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py
index 56f0809..ebe136e 100644
--- a/python/mxnet/gluon/contrib/nn/basic_layers.py
+++ b/python/mxnet/gluon/contrib/nn/basic_layers.py
@@ -18,8 +18,10 @@
 # coding: utf-8
 # pylint: disable= arguments-differ
 """Custom neural network layers in model_zoo."""
+
 __all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding',
-           'SyncBatchNorm']
+           'SyncBatchNorm', 'PixelShuffle1D', 'PixelShuffle2D',
+           'PixelShuffle3D']
 
 import warnings
 from .... import nd, test_utils
@@ -238,3 +240,180 @@ class SyncBatchNorm(BatchNorm):
     def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
         return F.contrib.SyncBatchNorm(x, gamma, beta, running_mean, running_var,
                                        name='fwd', **self._kwargs)
+
+class PixelShuffle1D(HybridBlock):
+
+    r"""Pixel-shuffle layer for upsampling in 1 dimension.
+
+    Pixel-shuffling is the operation of taking groups of values along
+    the *channel* dimension and regrouping them into blocks of pixels
+    along the ``W`` dimension, thereby effectively multiplying that dimension
+    by a constant factor in size.
+
+    For example, a feature map of shape :math:`(fC, W)` is reshaped
+    into :math:`(C, fW)` by forming little value groups of size :math:`f`
+    and arranging them in a grid of size :math:`W`.
+
+    Parameters
+    ----------
+    factor : int
+        Upsampling factor, applied to the ``W`` dimension. It is passed through ``int()``, so (unlike PixelShuffle2D/3D) a tuple is not accepted.
+
+    Inputs:
+        - **data**: Tensor of shape ``(N, f*C, W)``.
+    Outputs:
+        - **out**: Tensor of shape ``(N, C, W*f)``.
+
+    Examples
+    --------
+    >>> pxshuf = PixelShuffle1D(2)
+    >>> x = mx.nd.zeros((1, 8, 3))
+    >>> pxshuf(x).shape
+    (1, 4, 6)
+    """
+
+    def __init__(self, factor):
+        super(PixelShuffle1D, self).__init__()
+        self._factor = int(factor)  # NOTE(review): int((2,)) raises TypeError; add 1-tuple handling if parity with 2D/3D is wanted
+
+    def hybrid_forward(self, F, x):
+        """Perform pixel-shuffling on the input."""
+        f = self._factor
+                                             # (N, C*f, W)
+        x = F.reshape(x, (0, -4, -1, f, 0))  # (N, C, f, W)
+        x = F.transpose(x, (0, 1, 3, 2))     # (N, C, W, f)
+        x = F.reshape(x, (0, 0, -3))         # (N, C, W*f)
+        return x
+
+    def __repr__(self):
+        return "{}({})".format(self.__class__.__name__, self._factor)
+
+
+class PixelShuffle2D(HybridBlock):
+
+    r"""Pixel-shuffle layer for upsampling in 2 dimensions.
+
+    Pixel-shuffling is the operation of taking groups of values along
+    the *channel* dimension and regrouping them into blocks of pixels
+    along the ``H`` and ``W`` dimensions, thereby effectively multiplying
+    those dimensions by a constant factor in size.
+
+    For example, a feature map of shape :math:`(f^2 C, H, W)` is reshaped
+    into :math:`(C, fH, fW)` by forming little :math:`f \times f` blocks
+    of pixels and arranging them in an :math:`H \times W` grid.
+
+    Pixel-shuffling together with regular convolution is an alternative,
+    learnable way of upsampling an image by arbitrary factors. It is reported
+    to help overcome checkerboard artifacts that are common in upsampling with
+    transposed convolutions (also called deconvolutions). See the paper
+    `Real-Time Single Image and Video Super-Resolution Using an Efficient
+    Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158>`_
+    for further details.
+
+    Parameters
+    ----------
+    factor : int or 2-tuple of int
+        Upsampling factors, applied to the ``H`` and ``W`` dimensions,
+        in that order.
+
+    Inputs:
+        - **data**: Tensor of shape ``(N, f1*f2*C, H, W)``.
+    Outputs:
+        - **out**: Tensor of shape ``(N, C, H*f1, W*f2)``.
+
+    Examples
+    --------
+    >>> pxshuf = PixelShuffle2D((2, 3))
+    >>> x = mx.nd.zeros((1, 12, 3, 5))
+    >>> pxshuf(x).shape
+    (1, 2, 6, 15)
+    """
+
+    def __init__(self, factor):
+        super(PixelShuffle2D, self).__init__()
+        try:
+            self._factors = (int(factor),) * 2  # scalar factor: same value for H and W
+        except TypeError:  # int() rejected the type (e.g. a tuple): treat it as per-axis factors
+            self._factors = tuple(int(fac) for fac in factor)
+            assert len(self._factors) == 2, "wrong length {}".format(len(self._factors))  # NOTE(review): asserts are stripped under python -O; a ValueError would always enforce this
+
+    def hybrid_forward(self, F, x):
+        """Perform pixel-shuffling on the input."""
+        f1, f2 = self._factors
+                                                      # (N, C*f1*f2, H, W)
+        x = F.reshape(x, (0, -4, -1, f1 * f2, 0, 0))  # (N, C, f1*f2, H, W)
+        x = F.reshape(x, (0, 0, -4, f1, f2, 0, 0))    # (N, C, f1, f2, H, W)
+        x = F.transpose(x, (0, 1, 4, 2, 5, 3))        # (N, C, H, f1, W, f2)
+        x = F.reshape(x, (0, 0, -3, -3))              # (N, C, H*f1, W*f2)
+        return x
+
+    def __repr__(self):
+        return "{}({})".format(self.__class__.__name__, self._factors)
+
+
+class PixelShuffle3D(HybridBlock):
+
+    r"""Pixel-shuffle layer for upsampling in 3 dimensions.
+
+    Pixel-shuffling (or voxel-shuffling in 3D) is the operation of taking
+    groups of values along the *channel* dimension and regrouping them into
+    blocks of voxels along the ``D``, ``H`` and ``W`` dimensions, thereby
+    effectively multiplying those dimensions by a constant factor in size.
+
+    For example, a feature map of shape :math:`(f^3 C, D, H, W)` is reshaped
+    into :math:`(C, fD, fH, fW)` by forming little :math:`f \times f \times f`
+    blocks of voxels and arranging them in a :math:`D \times H \times W` grid.
+
+    Pixel-shuffling together with regular convolution is an alternative,
+    learnable way of upsampling an image by arbitrary factors. It is reported
+    to help overcome checkerboard artifacts that are common in upsampling with
+    transposed convolutions (also called deconvolutions). See the paper
+    `Real-Time Single Image and Video Super-Resolution Using an Efficient
+    Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158>`_
+    for further details.
+
+    Parameters
+    ----------
+    factor : int or 3-tuple of int
+        Upsampling factors, applied to the ``D``, ``H`` and ``W``
+        dimensions, in that order.
+
+    Inputs:
+        - **data**: Tensor of shape ``(N, f1*f2*f3*C, D, H, W)``.
+    Outputs:
+        - **out**: Tensor of shape ``(N, C, D*f1, H*f2, W*f3)``.
+
+    Examples
+    --------
+    >>> pxshuf = PixelShuffle3D((2, 3, 4))
+    >>> x = mx.nd.zeros((1, 48, 3, 5, 7))
+    >>> pxshuf(x).shape
+    (1, 2, 6, 15, 28)
+    """
+
+    def __init__(self, factor):
+        super(PixelShuffle3D, self).__init__()
+        try:
+            self._factors = (int(factor),) * 3  # scalar factor: same value for D, H and W
+        except TypeError:  # int() rejected the type (e.g. a tuple): treat it as per-axis factors
+            self._factors = tuple(int(fac) for fac in factor)
+            assert len(self._factors) == 3, "wrong length {}".format(len(self._factors))  # NOTE(review): asserts are stripped under python -O; a ValueError would always enforce this
+
+    def hybrid_forward(self, F, x):
+        """Perform pixel-shuffling on the input."""
+        # A single transpose would need 8 axes (N, C, f1, f2, f3, D, H, W), which F.transpose reportedly does not support; instead each factor is moved next to its spatial axis with a swapaxes + reshape pair
+        f1, f2, f3 = self._factors
+                                                              # (N, C*f1*f2*f3, D, H, W)
+        x = F.reshape(x, (0, -4, -1, f1 * f2 * f3, 0, 0, 0))  # (N, C, f1*f2*f3, D, H, W)
+        x = F.swapaxes(x, 2, 3)                               # (N, C, D, f1*f2*f3, H, W)
+        x = F.reshape(x, (0, 0, 0, -4, f1, f2*f3, 0, 0))      # (N, C, D, f1, f2*f3, H, W)
+        x = F.reshape(x, (0, 0, -3, 0, 0, 0))                 # (N, C, D*f1, f2*f3, H, W)
+        x = F.swapaxes(x, 3, 4)                               # (N, C, D*f1, H, f2*f3, W)
+        x = F.reshape(x, (0, 0, 0, 0, -4, f2, f3, 0))         # (N, C, D*f1, H, f2, f3, W)
+        x = F.reshape(x, (0, 0, 0, -3, 0, 0))                 # (N, C, D*f1, H*f2, f3, W)
+        x = F.swapaxes(x, 4, 5)                               # (N, C, D*f1, H*f2, W, f3)
+        x = F.reshape(x, (0, 0, 0, 0, -3))                    # (N, C, D*f1, H*f2, W*f3)
+        return x
+
+    def __repr__(self):
+        return "{}({})".format(self.__class__.__name__, self._factors)
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index a1cd8ea..6901e8b 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -19,7 +19,9 @@ from __future__ import print_function
 import mxnet as mx
 from mxnet.gluon import contrib
 from mxnet.gluon import nn
-from mxnet.gluon.contrib.nn import Concurrent, HybridConcurrent, Identity, SparseEmbedding
+from mxnet.gluon.contrib.nn import (
+    Concurrent, HybridConcurrent, Identity, SparseEmbedding, PixelShuffle1D,
+    PixelShuffle2D, PixelShuffle3D)
 from mxnet.test_utils import almost_equal
 from common import setup_module, with_seed, teardown
 import numpy as np
@@ -204,6 +206,89 @@ def test_sparse_embedding():
     assert (layer.weight.grad().asnumpy()[:5] == 1).all()
     assert (layer.weight.grad().asnumpy()[5:] == 0).all()
 
+def test_pixelshuffle1d():
+    nchan = 2
+    up_x = 2
+    nx = 3
+    shape_before = (1, nchan * up_x, nx)
+    shape_after = (1, nchan, nx * up_x)
+    layer = PixelShuffle1D(up_x)
+    x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before)  # distinct values 0..11 reveal where each input element lands
+    y = layer(x)
+    assert y.shape == shape_after
+    assert_allclose(  # expected: out[c, w*up_x + k] == in[c*up_x + k, w]
+        y.asnumpy(),
+        [[[0, 3, 1, 4, 2, 5],
+          [6, 9, 7, 10, 8, 11]]]
+    )
+
+def test_pixelshuffle2d():
+    nchan = 2
+    up_x = 2
+    up_y = 3
+    nx = 2
+    ny = 3
+    shape_before = (1, nchan * up_x * up_y, nx, ny)
+    shape_after = (1, nchan, nx * up_x, ny * up_y)
+    layer = PixelShuffle2D((up_x, up_y))
+    x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before)
+    y = layer(x)
+    assert y.shape == shape_after
+    # - Each input pixel's up_x * up_y = 6 channel values form a 2x3 output block
+    # - Within each block, the value grows by `nx * ny` per column step and by
+    #   `up_y * nx * ny` per row step (each block element comes from one input channel)
+    # - Moving to the next block in row-major order adds an offset of 1
+    # - Increasing the output channel index adds an offset of `nx * up_x * ny * up_y`
+    assert_allclose(
+        y.asnumpy(),
+        [[[[ 0,  6, 12,  1,  7, 13,  2,  8, 14],
+           [18, 24, 30, 19, 25, 31, 20, 26, 32],
+           [ 3,  9, 15,  4, 10, 16,  5, 11, 17],
+           [21, 27, 33, 22, 28, 34, 23, 29, 35]],
+
+          [[36, 42, 48, 37, 43, 49, 38, 44, 50],
+           [54, 60, 66, 55, 61, 67, 56, 62, 68],
+           [39, 45, 51, 40, 46, 52, 41, 47, 53],
+           [57, 63, 69, 58, 64, 70, 59, 65, 71]]]]
+    )
+
+def test_pixelshuffle3d():
+    nchan = 1
+    up_x = 2
+    up_y = 1
+    up_z = 2
+    nx = 2
+    ny = 3
+    nz = 4
+    shape_before = (1, nchan * up_x * up_y * up_z, nx, ny, nz)
+    shape_after = (1, nchan, nx * up_x, ny * up_y, nz * up_z)
+    layer = PixelShuffle3D((up_x, up_y, up_z))
+    x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before)
+    y = layer(x)
+    assert y.shape == shape_after
+    # - Each input voxel's up_x * up_y * up_z = 4 channel values form a 2x1x2 output block
+    # - Within each block, the value grows by `nx * ny * nz` per step along the last
+    #   axis and by `up_z * nx * ny * nz` along the first, e.g. the block [[[ 0, 24]], [[48, 72]]]
+    # - Moving to the next block in row-major order adds an offset of 1
+    assert_allclose(
+        y.asnumpy(),
+        [[[[[ 0, 24,  1, 25,  2, 26,  3, 27],
+            [ 4, 28,  5, 29,  6, 30,  7, 31],
+            [ 8, 32,  9, 33, 10, 34, 11, 35]],
+
+           [[48, 72, 49, 73, 50, 74, 51, 75],
+            [52, 76, 53, 77, 54, 78, 55, 79],
+            [56, 80, 57, 81, 58, 82, 59, 83]],
+
+           [[12, 36, 13, 37, 14, 38, 15, 39],
+            [16, 40, 17, 41, 18, 42, 19, 43],
+            [20, 44, 21, 45, 22, 46, 23, 47]],
+
+           [[60, 84, 61, 85, 62, 86, 63, 87],
+            [64, 88, 65, 89, 66, 90, 67, 91],
+            [68, 92, 69, 93, 70, 94, 71, 95]]]]]
+    )
+
 def test_datasets():
     wikitext2_train = contrib.data.text.WikiText2(root='data/wikitext-2', segment='train')
     wikitext2_val = contrib.data.text.WikiText2(root='data/wikitext-2', segment='validation',