You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by th...@apache.org on 2019/02/14 23:42:16 UTC
[incubator-mxnet] branch master updated: Add pixelshuffle layers (#13571)
This is an automated email from the ASF dual-hosted git repository.
thomasdelteil pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new de5cda2 Add pixelshuffle layers (#13571)
de5cda2 is described below
commit de5cda27145d8f5fafda06446d060dcbbb3d3872
Author: Holger Kohr <ho...@zoho.com>
AuthorDate: Fri Feb 15 00:41:52 2019 +0100
Add pixelshuffle layers (#13571)
* Add pixelshuffle layers, closes #13548
* Remove fmt comments
* Use explicit class in super()
* Add axis swapping to pixel shuffling, add tests
* Add documentation to pixel shuffle layers
* Use pixelshuffle layer and fix download in superres example
* Add pixelshuffle layers to API doc page
---
docs/api/python/gluon/contrib.md | 3 +
example/gluon/super_resolution/super_resolution.py | 144 ++++++++++------
python/mxnet/gluon/contrib/nn/__init__.py | 2 +-
python/mxnet/gluon/contrib/nn/basic_layers.py | 181 ++++++++++++++++++++-
tests/python/unittest/test_gluon_contrib.py | 87 +++++++++-
5 files changed, 362 insertions(+), 55 deletions(-)
diff --git a/docs/api/python/gluon/contrib.md b/docs/api/python/gluon/contrib.md
index b893d58..790f6b4 100644
--- a/docs/api/python/gluon/contrib.md
+++ b/docs/api/python/gluon/contrib.md
@@ -54,6 +54,9 @@ In the rest of this document, we list routines provided by the `gluon.contrib` p
Identity
SparseEmbedding
SyncBatchNorm
+ PixelShuffle1D
+ PixelShuffle2D
+ PixelShuffle3D
```
### Recurrent neural network
diff --git a/example/gluon/super_resolution/super_resolution.py b/example/gluon/super_resolution/super_resolution.py
index 0f2f21f..198f6fe 100644
--- a/example/gluon/super_resolution/super_resolution.py
+++ b/example/gluon/super_resolution/super_resolution.py
@@ -16,19 +16,27 @@
# under the License.
from __future__ import print_function
-import argparse, tarfile
+
+import argparse
import math
import os
+import shutil
+import sys
+import zipfile
+from os import path
+
import numpy as np
import mxnet as mx
-import mxnet.ndarray as F
-from mxnet import gluon
+from mxnet import gluon, autograd as ag
from mxnet.gluon import nn
-from mxnet import autograd as ag
-from mxnet.test_utils import download
+from mxnet.gluon.contrib import nn as contrib_nn
from mxnet.image import CenterCropAug, ResizeAug
from mxnet.io import PrefetchingIter
+from mxnet.test_utils import download
+
+this_dir = path.abspath(path.dirname(__file__))
+sys.path.append(path.join(this_dir, path.pardir))
from data import ImagePairIter
@@ -51,19 +59,45 @@ upscale_factor = opt.upscale_factor
batch_size, test_batch_size = opt.batch_size, opt.test_batch_size
color_flag = 0
-# get data
-dataset_path = "dataset"
-dataset_url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
-def get_dataset(prefetch=False):
- image_path = os.path.join(dataset_path, "BSDS300/images")
+# Get data
+datasets_dir = path.expanduser(path.join("~", ".mxnet", "datasets"))
+datasets_tmpdir = path.join(datasets_dir, "tmp")
+dataset_url = "https://github.com/BIDS/BSDS500/archive/master.zip"
+data_dir = path.expanduser(path.join(datasets_dir, "BSDS500"))
+tmp_dir = path.join(data_dir, "tmp")
- if not os.path.exists(image_path):
- os.makedirs(dataset_path)
- file_name = download(dataset_url)
- with tarfile.open(file_name) as tar:
- for item in tar:
- tar.extract(item, dataset_path)
- os.remove(file_name)
+def get_dataset(prefetch=False):
+ """Download the BSDS500 dataset and return train and test iters."""
+
+ if path.exists(data_dir):
+ print(
+ "Directory {} already exists, skipping.\n"
+ "To force download and extraction, delete the directory and re-run."
+ "".format(data_dir),
+ file=sys.stderr,
+ )
+ else:
+ print("Downloading dataset...", file=sys.stderr)
+ downloaded_file = download(dataset_url, dirname=datasets_tmpdir)
+ print("done", file=sys.stderr)
+
+ print("Extracting files...", end="", file=sys.stderr)
+ os.makedirs(data_dir)
+ os.makedirs(tmp_dir)
+ with zipfile.ZipFile(downloaded_file) as archive:
+ archive.extractall(tmp_dir)
+ shutil.rmtree(datasets_tmpdir)
+
+ shutil.copytree(
+ path.join(tmp_dir, "BSDS500-master", "BSDS500", "data", "images"),
+ path.join(data_dir, "images"),
+ )
+ shutil.copytree(
+ path.join(tmp_dir, "BSDS500-master", "BSDS500", "data", "groundTruth"),
+ path.join(data_dir, "groundTruth"),
+ )
+ shutil.rmtree(tmp_dir)
+ print("done", file=sys.stderr)
crop_size = 256
crop_size -= crop_size % upscale_factor
@@ -72,15 +106,26 @@ def get_dataset(prefetch=False):
input_transform = [CenterCropAug((crop_size, crop_size)), ResizeAug(input_crop_size)]
target_transform = [CenterCropAug((crop_size, crop_size))]
- iters = (ImagePairIter(os.path.join(image_path, "train"),
- (input_crop_size, input_crop_size),
- (crop_size, crop_size),
- batch_size, color_flag, input_transform, target_transform),
- ImagePairIter(os.path.join(image_path, "test"),
- (input_crop_size, input_crop_size),
- (crop_size, crop_size),
- test_batch_size, color_flag,
- input_transform, target_transform))
+ iters = (
+ ImagePairIter(
+ path.join(data_dir, "images", "train"),
+ (input_crop_size, input_crop_size),
+ (crop_size, crop_size),
+ batch_size,
+ color_flag,
+ input_transform,
+ target_transform,
+ ),
+ ImagePairIter(
+ path.join(data_dir, "images", "test"),
+ (input_crop_size, input_crop_size),
+ (crop_size, crop_size),
+ test_batch_size,
+ color_flag,
+ input_transform,
+ target_transform,
+ ),
+ )
return [PrefetchingIter(i) for i in iters] if prefetch else iters
@@ -90,33 +135,23 @@ mx.random.seed(opt.seed)
ctx = [mx.gpu(0)] if opt.use_gpu else [mx.cpu()]
-# define model
-def _rearrange(raw, F, upscale_factor):
- # (N, C * r^2, H, W) -> (N, C, r^2, H, W)
- splitted = F.reshape(raw, shape=(0, -4, -1, upscale_factor**2, 0, 0))
- # (N, C, r^2, H, W) -> (N, C, r, r, H, W)
- unflatten = F.reshape(splitted, shape=(0, 0, -4, upscale_factor, upscale_factor, 0, 0))
- # (N, C, r, r, H, W) -> (N, C, H, r, W, r)
- swapped = F.transpose(unflatten, axes=(0, 1, 4, 2, 5, 3))
- # (N, C, H, r, W, r) -> (N, C, H*r, W*r)
- return F.reshape(swapped, shape=(0, 0, -3, -3))
-
-
-class SuperResolutionNet(gluon.Block):
+class SuperResolutionNet(gluon.HybridBlock):
def __init__(self, upscale_factor):
super(SuperResolutionNet, self).__init__()
with self.name_scope():
- self.conv1 = nn.Conv2D(64, (5, 5), strides=(1, 1), padding=(2, 2))
- self.conv2 = nn.Conv2D(64, (3, 3), strides=(1, 1), padding=(1, 1))
- self.conv3 = nn.Conv2D(32, (3, 3), strides=(1, 1), padding=(1, 1))
+ self.conv1 = nn.Conv2D(64, (5, 5), strides=(1, 1), padding=(2, 2), activation='relu')
+ self.conv2 = nn.Conv2D(64, (3, 3), strides=(1, 1), padding=(1, 1), activation='relu')
+ self.conv3 = nn.Conv2D(32, (3, 3), strides=(1, 1), padding=(1, 1), activation='relu')
self.conv4 = nn.Conv2D(upscale_factor ** 2, (3, 3), strides=(1, 1), padding=(1, 1))
- self.upscale_factor = upscale_factor
+ self.pxshuf = contrib_nn.PixelShuffle2D(upscale_factor)
- def forward(self, x):
- x = F.Activation(self.conv1(x), act_type='relu')
- x = F.Activation(self.conv2(x), act_type='relu')
- x = F.Activation(self.conv3(x), act_type='relu')
- return _rearrange(self.conv4(x), F, self.upscale_factor)
+ def hybrid_forward(self, F, x):
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ x = self.conv4(x)
+ x = self.pxshuf(x)
+ return x
net = SuperResolutionNet(upscale_factor)
metric = mx.metric.MSE()
@@ -136,7 +171,7 @@ def test(ctx):
avg_psnr += 10 * math.log10(1/metric.get()[1])
metric.reset()
avg_psnr /= batches
- print('validation avg psnr: %f'%avg_psnr)
+ print('validation avg psnr: %f' % avg_psnr)
def train(epoch, ctx):
@@ -168,13 +203,18 @@ def train(epoch, ctx):
print('training mse at epoch %d: %s=%f'%(i, name, acc))
test(ctx)
- net.save_parameters('superres.params')
+ net.save_parameters(path.join(this_dir, 'superres.params'))
def resolve(ctx):
from PIL import Image
+
if isinstance(ctx, list):
ctx = [ctx[0]]
- net.load_parameters('superres.params', ctx=ctx)
+
+ img_basename = path.splitext(path.basename(opt.resolve_img))[0]
+ img_dirname = path.dirname(opt.resolve_img)
+
+ net.load_parameters(path.join(this_dir, 'superres.params'), ctx=ctx)
img = Image.open(opt.resolve_img).convert('YCbCr')
y, cb, cr = img.split()
data = mx.nd.expand_dims(mx.nd.expand_dims(mx.nd.array(y), axis=0), axis=0)
@@ -186,7 +226,7 @@ def resolve(ctx):
out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB')
- out_img.save('resolved.png')
+ out_img.save(path.join(img_dirname, '{}-resolved.png'.format(img_basename)))
if opt.resolve_img:
resolve(ctx)
diff --git a/python/mxnet/gluon/contrib/nn/__init__.py b/python/mxnet/gluon/contrib/nn/__init__.py
index 62440cd..5eb46f6 100644
--- a/python/mxnet/gluon/contrib/nn/__init__.py
+++ b/python/mxnet/gluon/contrib/nn/__init__.py
@@ -17,7 +17,7 @@
# coding: utf-8
# pylint: disable=wildcard-import
-"""Contrib recurrent neural network module."""
+"""Contributed neural network modules."""
from . import basic_layers
diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py
index 56f0809..ebe136e 100644
--- a/python/mxnet/gluon/contrib/nn/basic_layers.py
+++ b/python/mxnet/gluon/contrib/nn/basic_layers.py
@@ -18,8 +18,10 @@
# coding: utf-8
# pylint: disable= arguments-differ
"""Custom neural network layers in model_zoo."""
+
__all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding',
- 'SyncBatchNorm']
+ 'SyncBatchNorm', 'PixelShuffle1D', 'PixelShuffle2D',
+ 'PixelShuffle3D']
import warnings
from .... import nd, test_utils
@@ -238,3 +240,180 @@ class SyncBatchNorm(BatchNorm):
def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
return F.contrib.SyncBatchNorm(x, gamma, beta, running_mean, running_var,
name='fwd', **self._kwargs)
+
+class PixelShuffle1D(HybridBlock):
+
+ r"""Pixel-shuffle layer for upsampling in 1 dimension.
+
+ Pixel-shuffling is the operation of taking groups of values along
+ the *channel* dimension and regrouping them into blocks of pixels
+ along the ``W`` dimension, thereby effectively multiplying that dimension
+ by a constant factor in size.
+
+ For example, a feature map of shape :math:`(fC, W)` is reshaped
+ into :math:`(C, fW)` by forming little value groups of size :math:`f`
+ and arranging them in a grid of size :math:`W`.
+
+ Parameters
+ ----------
+ factor : int
+ Upsampling factor, applied to the ``W`` dimension.
+
+ Inputs:
+ - **data**: Tensor of shape ``(N, f*C, W)``.
+ Outputs:
+ - **out**: Tensor of shape ``(N, C, W*f)``.
+
+ Examples
+ --------
+ >>> pxshuf = PixelShuffle1D(2)
+ >>> x = mx.nd.zeros((1, 8, 3))
+ >>> pxshuf(x).shape
+ (1, 4, 6)
+ """
+
+ def __init__(self, factor):
+ super(PixelShuffle1D, self).__init__()
+ self._factor = int(factor)
+
+ def hybrid_forward(self, F, x):
+ """Perform pixel-shuffling on the input."""
+ f = self._factor
+ # (N, C*f, W)
+ x = F.reshape(x, (0, -4, -1, f, 0)) # (N, C, f, W)
+ x = F.transpose(x, (0, 1, 3, 2)) # (N, C, W, f)
+ x = F.reshape(x, (0, 0, -3)) # (N, C, W*f)
+ return x
+
+ def __repr__(self):
+ return "{}({})".format(self.__class__.__name__, self._factor)
+
+
+class PixelShuffle2D(HybridBlock):
+
+ r"""Pixel-shuffle layer for upsampling in 2 dimensions.
+
+ Pixel-shuffling is the operation of taking groups of values along
+ the *channel* dimension and regrouping them into blocks of pixels
+ along the ``H`` and ``W`` dimensions, thereby effectively multiplying
+ those dimensions by a constant factor in size.
+
+ For example, a feature map of shape :math:`(f^2 C, H, W)` is reshaped
+ into :math:`(C, fH, fW)` by forming little :math:`f \times f` blocks
+ of pixels and arranging them in an :math:`H \times W` grid.
+
+ Pixel-shuffling together with regular convolution is an alternative,
+ learnable way of upsampling an image by arbitrary factors. It is reported
+ to help overcome checkerboard artifacts that are common in upsampling with
+ transposed convolutions (also called deconvolutions). See the paper
+ `Real-Time Single Image and Video Super-Resolution Using an Efficient
+ Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158>`_
+ for further details.
+
+ Parameters
+ ----------
+ factor : int or 2-tuple of int
+ Upsampling factors, applied to the ``H`` and ``W`` dimensions,
+ in that order.
+
+ Inputs:
+ - **data**: Tensor of shape ``(N, f1*f2*C, H, W)``.
+ Outputs:
+ - **out**: Tensor of shape ``(N, C, H*f1, W*f2)``.
+
+ Examples
+ --------
+ >>> pxshuf = PixelShuffle2D((2, 3))
+ >>> x = mx.nd.zeros((1, 12, 3, 5))
+ >>> pxshuf(x).shape
+ (1, 2, 6, 15)
+ """
+
+ def __init__(self, factor):
+ super(PixelShuffle2D, self).__init__()
+ try:
+ self._factors = (int(factor),) * 2
+ except TypeError:
+ self._factors = tuple(int(fac) for fac in factor)
+ assert len(self._factors) == 2, "wrong length {}".format(len(self._factors))
+
+ def hybrid_forward(self, F, x):
+ """Perform pixel-shuffling on the input."""
+ f1, f2 = self._factors
+ # (N, f1*f2*C, H, W)
+ x = F.reshape(x, (0, -4, -1, f1 * f2, 0, 0)) # (N, C, f1*f2, H, W)
+ x = F.reshape(x, (0, 0, -4, f1, f2, 0, 0)) # (N, C, f1, f2, H, W)
+ x = F.transpose(x, (0, 1, 4, 2, 5, 3)) # (N, C, H, f1, W, f2)
+ x = F.reshape(x, (0, 0, -3, -3)) # (N, C, H*f1, W*f2)
+ return x
+
+ def __repr__(self):
+ return "{}({})".format(self.__class__.__name__, self._factors)
+
+
+class PixelShuffle3D(HybridBlock):
+
+ r"""Pixel-shuffle layer for upsampling in 3 dimensions.
+
+ Pixel-shuffling (or voxel-shuffling in 3D) is the operation of taking
+ groups of values along the *channel* dimension and regrouping them into
+ blocks of voxels along the ``D``, ``H`` and ``W`` dimensions, thereby
+ effectively multiplying those dimensions by a constant factor in size.
+
+ For example, a feature map of shape :math:`(f^3 C, D, H, W)` is reshaped
+ into :math:`(C, fD, fH, fW)` by forming little :math:`f \times f \times f`
+ blocks of voxels and arranging them in a :math:`D \times H \times W` grid.
+
+ Pixel-shuffling together with regular convolution is an alternative,
+ learnable way of upsampling an image by arbitrary factors. It is reported
+ to help overcome checkerboard artifacts that are common in upsampling with
+ transposed convolutions (also called deconvolutions). See the paper
+ `Real-Time Single Image and Video Super-Resolution Using an Efficient
+ Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158>`_
+ for further details.
+
+ Parameters
+ ----------
+ factor : int or 3-tuple of int
+ Upsampling factors, applied to the ``D``, ``H`` and ``W``
+ dimensions, in that order.
+
+ Inputs:
+ - **data**: Tensor of shape ``(N, f1*f2*f3*C, D, H, W)``.
+ Outputs:
+ - **out**: Tensor of shape ``(N, C, D*f1, H*f2, W*f3)``.
+
+ Examples
+ --------
+ >>> pxshuf = PixelShuffle3D((2, 3, 4))
+ >>> x = mx.nd.zeros((1, 48, 3, 5, 7))
+ >>> pxshuf(x).shape
+ (1, 2, 6, 15, 28)
+ """
+
+ def __init__(self, factor):
+ super(PixelShuffle3D, self).__init__()
+ try:
+ self._factors = (int(factor),) * 3
+ except TypeError:
+ self._factors = tuple(int(fac) for fac in factor)
+ assert len(self._factors) == 3, "wrong length {}".format(len(self._factors))
+
+ def hybrid_forward(self, F, x):
+ """Perform pixel-shuffling on the input."""
+ # `transpose` doesn't support 8-D arrays, so shuffle one axis at a time via `swapaxes` and `reshape`
+ f1, f2, f3 = self._factors
+ # (N, C*f1*f2*f3, D, H, W)
+ x = F.reshape(x, (0, -4, -1, f1 * f2 * f3, 0, 0, 0)) # (N, C, f1*f2*f3, D, H, W)
+ x = F.swapaxes(x, 2, 3) # (N, C, D, f1*f2*f3, H, W)
+ x = F.reshape(x, (0, 0, 0, -4, f1, f2*f3, 0, 0)) # (N, C, D, f1, f2*f3, H, W)
+ x = F.reshape(x, (0, 0, -3, 0, 0, 0)) # (N, C, D*f1, f2*f3, H, W)
+ x = F.swapaxes(x, 3, 4) # (N, C, D*f1, H, f2*f3, W)
+ x = F.reshape(x, (0, 0, 0, 0, -4, f2, f3, 0)) # (N, C, D*f1, H, f2, f3, W)
+ x = F.reshape(x, (0, 0, 0, -3, 0, 0)) # (N, C, D*f1, H*f2, f3, W)
+ x = F.swapaxes(x, 4, 5) # (N, C, D*f1, H*f2, W, f3)
+ x = F.reshape(x, (0, 0, 0, 0, -3)) # (N, C, D*f1, H*f2, W*f3)
+ return x
+
+ def __repr__(self):
+ return "{}({})".format(self.__class__.__name__, self._factors)
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index a1cd8ea..6901e8b 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -19,7 +19,9 @@ from __future__ import print_function
import mxnet as mx
from mxnet.gluon import contrib
from mxnet.gluon import nn
-from mxnet.gluon.contrib.nn import Concurrent, HybridConcurrent, Identity, SparseEmbedding
+from mxnet.gluon.contrib.nn import (
+ Concurrent, HybridConcurrent, Identity, SparseEmbedding, PixelShuffle1D,
+ PixelShuffle2D, PixelShuffle3D)
from mxnet.test_utils import almost_equal
from common import setup_module, with_seed, teardown
import numpy as np
@@ -204,6 +206,89 @@ def test_sparse_embedding():
assert (layer.weight.grad().asnumpy()[:5] == 1).all()
assert (layer.weight.grad().asnumpy()[5:] == 0).all()
+def test_pixelshuffle1d():
+ nchan = 2
+ up_x = 2
+ nx = 3
+ shape_before = (1, nchan * up_x, nx)
+ shape_after = (1, nchan, nx * up_x)
+ layer = PixelShuffle1D(up_x)
+ x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before)
+ y = layer(x)
+ assert y.shape == shape_after
+ assert_allclose(
+ y.asnumpy(),
+ [[[0, 3, 1, 4, 2, 5],
+ [6, 9, 7, 10, 8, 11]]]
+ )
+
+def test_pixelshuffle2d():
+ nchan = 2
+ up_x = 2
+ up_y = 3
+ nx = 2
+ ny = 3
+ shape_before = (1, nchan * up_x * up_y, nx, ny)
+ shape_after = (1, nchan, nx * up_x, ny * up_y)
+ layer = PixelShuffle2D((up_x, up_y))
+ x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before)
+ y = layer(x)
+ assert y.shape == shape_after
+ # - Channels are reshaped to form 2x3 blocks
+ # - Within each block, the increment is `nx * ny` when increasing the column
+ # index by 1
+ # - Increasing the block index adds an offset of 1
+ # - Increasing the channel index adds an offset of `nx * up_x * ny * up_y`
+ assert_allclose(
+ y.asnumpy(),
+ [[[[ 0, 6, 12, 1, 7, 13, 2, 8, 14],
+ [18, 24, 30, 19, 25, 31, 20, 26, 32],
+ [ 3, 9, 15, 4, 10, 16, 5, 11, 17],
+ [21, 27, 33, 22, 28, 34, 23, 29, 35]],
+
+ [[36, 42, 48, 37, 43, 49, 38, 44, 50],
+ [54, 60, 66, 55, 61, 67, 56, 62, 68],
+ [39, 45, 51, 40, 46, 52, 41, 47, 53],
+ [57, 63, 69, 58, 64, 70, 59, 65, 71]]]]
+ )
+
+def test_pixelshuffle3d():
+ nchan = 1
+ up_x = 2
+ up_y = 1
+ up_z = 2
+ nx = 2
+ ny = 3
+ nz = 4
+ shape_before = (1, nchan * up_x * up_y * up_z, nx, ny, nz)
+ shape_after = (1, nchan, nx * up_x, ny * up_y, nz * up_z)
+ layer = PixelShuffle3D((up_x, up_y, up_z))
+ x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before)
+ y = layer(x)
+ assert y.shape == shape_after
+ # - Channels are reshaped to form 2x1x2 blocks
+ # - Within each block, the increment is `nx * ny * nz` when increasing the
+ # column index by 1, e.g. the block [[[ 0, 24]], [[48, 72]]]
+ # - Increasing the block index adds an offset of 1
+ assert_allclose(
+ y.asnumpy(),
+ [[[[[ 0, 24, 1, 25, 2, 26, 3, 27],
+ [ 4, 28, 5, 29, 6, 30, 7, 31],
+ [ 8, 32, 9, 33, 10, 34, 11, 35]],
+
+ [[48, 72, 49, 73, 50, 74, 51, 75],
+ [52, 76, 53, 77, 54, 78, 55, 79],
+ [56, 80, 57, 81, 58, 82, 59, 83]],
+
+ [[12, 36, 13, 37, 14, 38, 15, 39],
+ [16, 40, 17, 41, 18, 42, 19, 43],
+ [20, 44, 21, 45, 22, 46, 23, 47]],
+
+ [[60, 84, 61, 85, 62, 86, 63, 87],
+ [64, 88, 65, 89, 66, 90, 67, 91],
+ [68, 92, 69, 93, 70, 94, 71, 95]]]]]
+ )
+
def test_datasets():
wikitext2_train = contrib.data.text.WikiText2(root='data/wikitext-2', segment='train')
wikitext2_val = contrib.data.text.WikiText2(root='data/wikitext-2', segment='validation',