Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/03/27 06:03:50 UTC

[GitHub] reminisce closed pull request #10139: [REQUEST FOR REVIEW] [MXNET-109] Logging APIs for Visualizing MXNet Data in TensorBoard

URL: https://github.com/apache/incubator-mxnet/pull/10139
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/docs/api/python/contrib/summary.md b/docs/api/python/contrib/summary.md
new file mode 100644
index 00000000000..478286490ef
--- /dev/null
+++ b/docs/api/python/contrib/summary.md
@@ -0,0 +1,258 @@
+# Logging MXNet Data for Visualization in TensorBoard
+
+## Overview
+
+The module `mxnet.contrib.summary` enables MXNet users to visualize data in
+[TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard). 
+Please note that this module only provides the APIs for data logging. For visualization,
+users still need to install TensorBoard.
+
+### How to install TensorBoard
+To launch TensorBoard for visualization, make sure you have the
+[official release of TensorBoard](https://pypi.python.org/pypi/tensorboard) installed.
+To install it, run `pip install tensorboard` on your machine.
+
+### How to launch TensorBoard
+After installing the TensorBoard Python package, run the following command in a terminal
+to launch TensorBoard:
+```
+tensorboard --logdir=/path/to/your/log/dir --host=your_host_ip --port=your_port_number
+```
+For example, to view the data in a browser on your local machine, run
+```
+tensorboard --logdir=/path/to/your/log/dir --host=127.0.0.1 --port=8888
+```
+Then open the address `127.0.0.1:8888` in the browser. If port `8888` is already in use
+by another application, TensorBoard may fail to launch; in that case, choose a different
+port that is available.
+
+
+### How to use TensorBoard GUI for data visualization
+See the tutorials on the
+[TensorFlow website](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard) for details.
+
+### Other packages required by the MXNet logging APIs
+Make sure the following Python packages are installed before using
+the MXNet logging APIs:
+- [protobuf](https://pypi.python.org/pypi/protobuf) (version 3)
+- [six](https://pypi.python.org/pypi/six)
+- [pillow](https://pypi.python.org/pypi/Pillow)
+
+
+### Data types supported by the MXNet logging APIs
+The following data types, as labeled in the TensorBoard GUI, are currently supported:
+- SCALARS
+- IMAGES
+- HISTOGRAMS
+- PROJECTOR ([EMBEDDINGS VISUALIZATION](https://www.tensorflow.org/programmers_guide/embedding))
+- AUDIO
+- TEXT
+- PR CURVES
+
+```eval_rst
+.. warning:: This package contains experimental APIs and may change in the near future.
+```
+
+The `summary` module provides the logging APIs through the `SummaryWriter` class.
+
+```eval_rst
+.. autosummary::
+    :nosignatures:
+
+    mxnet.contrib.summary.SummaryWriter
+    mxnet.contrib.summary.SummaryWriter.add_audio
+    mxnet.contrib.summary.SummaryWriter.add_embedding
+    mxnet.contrib.summary.SummaryWriter.add_histogram
+    mxnet.contrib.summary.SummaryWriter.add_image
+    mxnet.contrib.summary.SummaryWriter.add_pr_curve
+    mxnet.contrib.summary.SummaryWriter.add_scalar
+    mxnet.contrib.summary.SummaryWriter.add_text
+    mxnet.contrib.summary.SummaryWriter.close
+    mxnet.contrib.summary.SummaryWriter.flush
+    mxnet.contrib.summary.SummaryWriter.get_logdir
+    mxnet.contrib.summary.SummaryWriter.reopen
+```
+
+## Examples
+Let's take a look at several simple examples demonstrating how to use the MXNet logging APIs.
+
+### Scalar
+Scalar values are often plotted as curves, such as training accuracy over time. Here is an
+example that plots `y=sin(x)` for `x` in `[0, 2*pi)`, logging each sample at integer step `x*100`.
+```python
+import numpy as np
+from mxnet.contrib.summary import SummaryWriter
+
+x_vals = np.arange(start=0, stop=2 * np.pi, step=0.01)
+y_vals = np.sin(x_vals)
+with SummaryWriter(logdir='./logs') as sw:
+    for step, y in enumerate(y_vals):
+        sw.add_scalar(tag='sin_function_curve', value=y, global_step=step)
+```
+![png](https://github.com/reminisce/web-data/blob/tensorboard_doc/mxnet/tensorboard/doc/summary_scalar_sin.png)
+
+
+### Histogram
+We can visualize the value distributions of tensors by logging `NDArray`s as histograms.
+The following code snippet generates a series of normal distributions with decreasing standard deviations.
+```python
+import mxnet as mx
+from mxnet.contrib.summary import SummaryWriter
+
+
+with SummaryWriter(logdir='./logs') as sw:
+    for i in range(10):
+        data = mx.nd.normal(loc=0, scale=10.0/(i+1), shape=(10, 3, 8, 8))
+        sw.add_histogram(tag='normal_dist', values=data, bins=200, global_step=i)
+```
+![png](https://github.com/reminisce/web-data/blob/tensorboard_doc/mxnet/tensorboard/doc/summary_histogram_norm.png)
+
+
+### Image
+The image logging API accepts an MXNet `NDArray` or a `numpy.ndarray` with 2 to 4 dimensions.
+It preprocesses the input image and writes the processed image to the event file.
+When the input image data is 2D or 3D, it represents a single image.
+When the input image data is a 4D tensor, which represents a batch of images, the logging
+API makes a grid of those images by stitching them together before writing the grid
+to the event file. The following code snippet saves 15 copies of the same image
+for visualization in TensorBoard.
+```python
+import mxnet as mx
+import numpy as np
+from mxnet.contrib.summary import SummaryWriter
+from scipy import misc
+
+face = misc.face().transpose((2, 0, 1))
+face = face.reshape((1,) + face.shape)
+faces = [face] * 15
+faces = np.concatenate(faces, axis=0)
+
+img = mx.nd.array(faces, dtype=faces.dtype)
+with SummaryWriter(logdir='./logs') as sw:
+    sw.add_image(tag='faces', image=img)
+```
+![png](https://github.com/reminisce/web-data/blob/tensorboard_doc/mxnet/tensorboard/doc/summary_image_faces.png)
+
+
+### Embedding
+Embedding visualization gives an intuition for how data is clustered
+in 2D or 3D space. The following code takes 2,560 images of handwritten digits
+from the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) and logs them
+as embedding vectors together with their labels and original images.
+```python
+import numpy as np
+import mxnet as mx
+from mxnet import gluon
+from mxnet.contrib.summary import SummaryWriter
+
+
+batch_size = 128
+
+
+def transformer(data, label):
+    data = data.reshape((-1,)).astype(np.float32)/255
+    return data, label
+
+
+train_data = gluon.data.DataLoader(
+    gluon.data.vision.MNIST('./data', train=True, transform=transformer),
+    batch_size=batch_size, shuffle=True, last_batch='discard')
+
+initialized = False
+embedding = None
+labels = None
+images = None
+
+for i, (data, label) in enumerate(train_data):
+    if i >= 20:
+        break
+    if initialized:
+        embedding = mx.nd.concat(*(embedding, data), dim=0)
+        labels = mx.nd.concat(*(labels, label), dim=0)
+        images = mx.nd.concat(*(images, data.reshape(batch_size, 1, 28, 28)), dim=0)
+    else:
+        embedding = data
+        labels = label
+        images = data.reshape(batch_size, 1, 28, 28)
+        initialized = True
+
+with SummaryWriter(logdir='./logs') as sw:
+    sw.add_embedding(tag='mnist', embedding=embedding, labels=labels, images=images)
+```
+![png](https://github.com/reminisce/web-data/blob/tensorboard_doc/mxnet/tensorboard/doc/summary_embedding_mnist.png)
+
+
+### Audio
+The following code generates audio data uniformly sampled from the range `[-1, 1]`
+and writes the data to the event file for TensorBoard to play back.
+```python
+import mxnet as mx
+from mxnet.contrib.summary import SummaryWriter
+
+
+num_samples = 44100
+# 44100 random samples between -1 and 1
+data = mx.random.uniform(low=-1, high=1, shape=(num_samples,))
+max_abs_val = data.abs().max()
+# normalize so that the peak absolute amplitude is exactly 1
+data = data / max_abs_val
+with SummaryWriter(logdir='./logs') as sw:
+    sw.add_audio(tag='uniform_audio', audio=data, global_step=0)
+```
+![png](https://github.com/reminisce/web-data/blob/tensorboard_doc/mxnet/tensorboard/doc/summary_audio_uniform.png)
+
+
+### Text
+TensorBoard is able to render plain text as well as text in the markdown format.
+The following code demonstrates these two use cases.
+```python
+from mxnet.contrib.summary import SummaryWriter
+
+
+def simple_example(sw, step):
+    greeting = 'Hello MXNet from step {}'.format(str(step))
+    sw.add_text(tag='simple_example', text=greeting, global_step=step)
+
+
+def markdown_table(sw):
+    header_row = 'Hello | MXNet,\n'
+    delimiter = '----- | -----\n'
+    table_body = 'This | is\n' + 'so | awesome!'
+    sw.add_text(tag='markdown_table', text=header_row+delimiter+table_body)
+
+
+with SummaryWriter(logdir='./logs') as sw:
+    simple_example(sw, 100)
+    markdown_table(sw)
+```
+![png](https://github.com/reminisce/web-data/blob/tensorboard_doc/mxnet/tensorboard/doc/summary_text.png)
+
+
+### PR Curve
+Precision and recall are useful measures of prediction quality when the classes are imbalanced.
+The trade-off between them can be visualized as a precision-recall curve.
+The following code snippet logs predictions and labels for visualizing
+the precision-recall curve in TensorBoard. It generates 100 numbers uniformly distributed in the range `[0, 1]` as
+the predictions for 100 examples. The labels are also generated randomly as either 0 or 1.
+```python
+import mxnet as mx
+import numpy as np
+from mxnet.contrib.summary import SummaryWriter
+
+with SummaryWriter(logdir='./logs') as sw:
+    predictions = mx.nd.uniform(low=0, high=1, shape=(100,), dtype=np.float32)
+    labels = mx.nd.uniform(low=0, high=2, shape=(100,), dtype=np.float32).astype(np.int32)
+    sw.add_pr_curve(tag='pseudo_pr_curve', predictions=predictions, labels=labels, num_thresholds=120)
+```
+![png](https://github.com/reminisce/web-data/blob/tensorboard_doc/mxnet/tensorboard/doc/summary_pr_curve_uniform.png)
+
+
+## API Reference
+
+<script type="text/javascript" src='../../_static/js/auto_module_index.js'></script>
+
+```eval_rst
+.. autoclass:: mxnet.contrib.summary.SummaryWriter
+    :members:
+```
+<script>auto_index("api-reference");</script>
\ No newline at end of file
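The PR Curve section of `summary.md` above describes precision and recall only in words. As a minimal sketch in plain NumPy (independent of `mxnet.contrib.summary`; the helper name `pr_curve_points`, the evenly spaced thresholds in `[0, 1]` and the zero-denominator convention are assumptions, not necessarily what `add_pr_curve` does internally), the points of such a curve can be computed by thresholding the predictions:

```python
import numpy as np


def pr_curve_points(predictions, labels, num_thresholds=120):
    """Hypothetical helper: precision and recall at evenly spaced thresholds in [0, 1].

    A prediction counts as positive when it is >= the threshold.
    """
    predictions = np.asarray(predictions, dtype=np.float64)
    labels = np.asarray(labels).astype(bool)
    thresholds = np.linspace(0.0, 1.0, num_thresholds)
    precision = np.zeros(num_thresholds)
    recall = np.zeros(num_thresholds)
    for k, threshold in enumerate(thresholds):
        predicted_pos = predictions >= threshold
        tp = np.sum(predicted_pos & labels)    # true positives
        fp = np.sum(predicted_pos & ~labels)   # false positives
        fn = np.sum(~predicted_pos & labels)   # false negatives
        # max(..., 1) avoids division by zero; the value is reported as 0 in that case
        precision[k] = float(tp) / max(tp + fp, 1)
        recall[k] = float(tp) / max(tp + fn, 1)
    return thresholds, precision, recall


# Random data of the same kind as in the PR Curve example of summary.md
rng = np.random.RandomState(0)
preds = rng.uniform(0, 1, size=100)
labs = rng.randint(0, 2, size=100)
thresholds, precision, recall = pr_curve_points(preds, labs)
```

Raising the threshold can only shrink the set of predicted positives, so recall never increases along the curve, while precision tends to rise; that trade-off is what the PR curve plots.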
diff --git a/docs/api/python/index.md b/docs/api/python/index.md
index f65d3abfb15..140de352828 100644
--- a/docs/api/python/index.md
+++ b/docs/api/python/index.md
@@ -151,4 +151,5 @@ imported by running:
 
    contrib/contrib.md
    contrib/text.md
+   contrib/summary.md
 ```
diff --git a/example/visualization/tensorboard/README.md b/example/visualization/tensorboard/README.md
new file mode 100644
index 00000000000..0982022e03b
--- /dev/null
+++ b/example/visualization/tensorboard/README.md
@@ -0,0 +1,14 @@
+Visualizing the Training of an MNIST Model
+==========================================
+
+This folder contains an example of logging MXNet data for visualization in TensorBoard
+while training an MNIST model with the Gluon interface. To run the example,
+type `python mnist.py` in the terminal. While the training program is running, launch
+TensorBoard by typing the following command in the same directory:
+```bash
+tensorboard --logdir=./logs --host=127.0.0.1 --port=8888
+```
+Then open the browser and enter the address `127.0.0.1:8888`.
+You should see the training and validation accuracy curves,
+histograms of the gradients of all the parameters evolving over the epochs, and the training images
+of the first mini-batch of each epoch.
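Both this README and `summary.md` launch TensorBoard on port `8888`, which may already be taken. A small sketch for Python 3 follows; the helper `find_free_port` is hypothetical and not part of MXNet or TensorBoard. It checks whether the preferred port can be bound and otherwise asks the operating system for a free one by binding to port 0:

```python
import socket


def find_free_port(host="127.0.0.1", preferred=8888):
    """Hypothetical helper: return `preferred` if it can be bound on `host`,
    otherwise return a free port chosen by the operating system."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind((host, preferred))
            return preferred
        except OSError:
            pass  # port already in use (or not permitted); fall back below
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))  # port 0 lets the OS pick any free port
        return sock.getsockname()[1]


if __name__ == "__main__":
    port = find_free_port()
    print("tensorboard --logdir=./logs --host=127.0.0.1 --port=%d" % port)
```

The check is only advisory: another process can still claim the port between this check and TensorBoard starting.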
diff --git a/example/visualization/tensorboard/mnist.py b/example/visualization/tensorboard/mnist.py
new file mode 100644
index 00000000000..8c36ab1583d
--- /dev/null
+++ b/example/visualization/tensorboard/mnist.py
@@ -0,0 +1,160 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+from __future__ import print_function
+
+import argparse
+import logging
+import numpy as np
+import mxnet as mx
+from mxnet import gluon, autograd
+from mxnet.gluon import nn
+from mxnet.contrib.summary import SummaryWriter
+
+logging.basicConfig(level=logging.DEBUG)
+
+# Parse CLI arguments
+
+parser = argparse.ArgumentParser(description='MXNet Gluon MNIST Example')
+parser.add_argument('--batch-size', type=int, default=100,
+                    help='batch size for training and testing (default: 100)')
+parser.add_argument('--epochs', type=int, default=10,
+                    help='number of epochs to train (default: 10)')
+parser.add_argument('--lr', type=float, default=0.1,
+                    help='learning rate (default: 0.1)')
+parser.add_argument('--momentum', type=float, default=0.9,
+                    help='SGD momentum (default: 0.9)')
+parser.add_argument('--cuda', action='store_true', default=False,
+                    help='Train on GPU with CUDA')
+parser.add_argument('--log-interval', type=int, default=100, metavar='N',
+                    help='how many batches to wait before logging training status')
+opt = parser.parse_args()
+
+# define network
+
+net = nn.Sequential()
+with net.name_scope():
+    net.add(nn.Dense(128, activation='relu'))
+    net.add(nn.Dense(64, activation='relu'))
+    net.add(nn.Dense(10))
+
+
+# data
+
+def transformer(data, label):
+    data = data.reshape((-1,)).astype(np.float32) / 255
+    return data, label
+
+
+train_data = gluon.data.DataLoader(
+    gluon.data.vision.MNIST('./data', train=True, transform=transformer),
+    batch_size=opt.batch_size, shuffle=True, last_batch='discard')
+
+val_data = gluon.data.DataLoader(
+    gluon.data.vision.MNIST('./data', train=False, transform=transformer),
+    batch_size=opt.batch_size, shuffle=False)
+
+
+def test(ctx):
+    metric = mx.metric.Accuracy()
+    for data, label in val_data:
+        data = data.as_in_context(ctx)
+        label = label.as_in_context(ctx)
+        output = net(data)
+        metric.update([label], [output])
+
+    return metric.get()
+
+
+def train(epochs, ctx):
+    # Collect all parameters from net and its children, then initialize them.
+    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
+
+    # Trainer is for updating parameters with gradient.
+    trainer = gluon.Trainer(net.collect_params(), 'sgd',
+                            {'learning_rate': opt.lr, 'momentum': opt.momentum})
+    metric = mx.metric.Accuracy()
+    loss = gluon.loss.SoftmaxCrossEntropyLoss()
+
+    # collect parameter names for logging the gradients of parameters in each epoch
+    params = net.collect_params()
+    param_names = params.keys()
+
+    # define the number of bins for logging histograms
+    num_bins = 1000
+
+    # define a summary writer that logs data and flushes to the file every 5 seconds
+    sw = SummaryWriter(logdir='logs', flush_secs=5)
+
+    for epoch in range(epochs):
+        # reset the metric at the beginning of each epoch
+        metric.reset()
+        for i, (data, label) in enumerate(train_data):
+            # Copy data to ctx if necessary
+            data = data.as_in_context(ctx)
+            label = label.as_in_context(ctx)
+            # Start recording computation graph with record() section.
+            # Recorded graphs can then be differentiated with backward.
+            with autograd.record():
+                output = net(data)
+                L = loss(output, label)
+            L.backward()
+
+            # take a gradient step with batch_size equal to data.shape[0]
+            trainer.step(data.shape[0])
+            # update metric at last.
+            metric.update([label], [output])
+
+            if i % opt.log_interval == 0 and i > 0:
+                name, acc = metric.get()
+                print('[Epoch %d Batch %d] Training: %s=%f' % (epoch, i, name, acc))
+
+            # log the first mini-batch of images in each epoch to verify that we are
+            # training on the correct examples and dataset, and that the dataset
+            # is really shuffled in each epoch.
+            if i == 0:
+                sw.add_image(('epoch%d_minibatch%d' % (epoch, i)),
+                             data.reshape((opt.batch_size, 1, 28, 28)), epoch)
+
+        grads = [p.grad() for p in params.values()]
+        assert len(grads) == len(param_names)
+        # log the gradients of the parameters (from the last mini-batch of the epoch) to monitor convergence
+        for name, grad in zip(param_names, grads):
+            sw.add_histogram(tag=name, values=grad, global_step=epoch, bins=num_bins)
+
+        name, acc = metric.get()
+        print('[Epoch %d] Training: %s=%f' % (epoch, name, acc))
+        # logging training accuracy for visualizing the training accuracy curve
+        sw.add_scalar(tag='train_acc', value=acc, global_step=epoch)
+
+        name, val_acc = test(ctx)
+        # logging the validation accuracy for visualizing the validation accuracy curve
+        print('[Epoch %d] Validation: %s=%f' % (epoch, name, val_acc))
+        sw.add_scalar(tag='valid_acc', value=val_acc, global_step=epoch)
+
+    net.save_params('mnist.params')
+    sw.close()
+
+
+if __name__ == '__main__':
+    if opt.cuda:
+        ctx = mx.gpu(0)
+    else:
+        ctx = mx.cpu()
+    train(opt.epochs, ctx)
+    print('finished training')
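`mnist.py` passes a `(batch_size, 1, 28, 28)` batch to `add_image`, and `summary.md` states that 4D input is stitched into a grid before being written. The PR's own implementation of that step is not part of this excerpt, so the following is only a minimal NumPy sketch of such stitching; the near-square layout, the 2-pixel padding and the name `make_image_grid` are assumptions:

```python
import numpy as np


def make_image_grid(batch, padding=2):
    """Hypothetical helper: tile an (N, C, H, W) batch into one (C, grid_H, grid_W) image."""
    batch = np.asarray(batch)
    if batch.ndim != 4:
        raise ValueError("expected an (N, C, H, W) batch, got shape %s" % (batch.shape,))
    n, c, h, w = batch.shape
    # choose a roughly square layout: ncols = ceil(sqrt(n)), just enough rows for n images
    ncols = int(np.ceil(np.sqrt(n)))
    nrows = int(np.ceil(n / float(ncols)))
    # zero-filled canvas; the padding strips between and around tiles stay zero
    grid = np.zeros((c, nrows * (h + padding) + padding,
                     ncols * (w + padding) + padding), dtype=batch.dtype)
    for idx in range(n):
        row, col = divmod(idx, ncols)  # fill the grid row by row
        top = padding + row * (h + padding)
        left = padding + col * (w + padding)
        grid[:, top:top + h, left:left + w] = batch[idx]
    return grid


# A batch shaped like the one logged in mnist.py (100 images of 1x28x28)
grid = make_image_grid(np.random.rand(100, 1, 28, 28))
# grid.shape == (1, 302, 302): a 10 x 10 layout of 28-pixel tiles with 2-pixel padding
```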
diff --git a/python/mxnet/contrib/summary/__init__.py b/python/mxnet/contrib/summary/__init__.py
new file mode 100644
index 00000000000..eb02d62b45c
--- /dev/null
+++ b/python/mxnet/contrib/summary/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Summary module for logging MXNet data for visualization in TensorBoard"""
+
+from .writer import SummaryWriter
diff --git a/python/mxnet/contrib/summary/crc32c.py b/python/mxnet/contrib/summary/crc32c.py
new file mode 100644
index 00000000000..db26bdfe83c
--- /dev/null
+++ b/python/mxnet/contrib/summary/crc32c.py
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""CRC-32C checksums for encoding serialized protobuf strings in the TFRecord format. This file is copied from
+https://github.com/TeamHG-Memex/tensorboard_logger/blob/master/tensorboard_logger/crc32c.py"""
+import array
+
+# CRC table copied from table0_ in
+# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/lib/hash/crc32c.cc
+CRC_TABLE = (
+    0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4,
+    0xc79a971f, 0x35f1141c, 0x26a1e7e8, 0xd4ca64eb,
+    0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b,
+    0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24,
+    0x105ec76f, 0xe235446c, 0xf165b798, 0x030e349b,
+    0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384,
+    0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54,
+    0x5d1d08bf, 0xaf768bbc, 0xbc267848, 0x4e4dfb4b,
+    0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a,
+    0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35,
+    0xaa64d611, 0x580f5512, 0x4b5fa6e6, 0xb93425e5,
+    0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa,
+    0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45,
+    0xf779deae, 0x05125dad, 0x1642ae59, 0xe4292d5a,
+    0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a,
+    0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595,
+    0x417b1dbc, 0xb3109ebf, 0xa0406d4b, 0x522bee48,
+    0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957,
+    0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687,
+    0x0c38d26c, 0xfe53516f, 0xed03a29b, 0x1f682198,
+    0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927,
+    0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38,
+    0xdbfc821c, 0x2997011f, 0x3ac7f2eb, 0xc8ac71e8,
+    0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7,
+    0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096,
+    0xa65c047d, 0x5437877e, 0x4767748a, 0xb50cf789,
+    0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859,
+    0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46,
+    0x7198540d, 0x83f3d70e, 0x90a324fa, 0x62c8a7f9,
+    0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6,
+    0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36,
+    0x3cdb9bdd, 0xceb018de, 0xdde0eb2a, 0x2f8b6829,
+    0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c,
+    0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93,
+    0x082f63b7, 0xfa44e0b4, 0xe9141340, 0x1b7f9043,
+    0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c,
+    0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3,
+    0x55326b08, 0xa759e80b, 0xb4091bff, 0x466298fc,
+    0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c,
+    0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033,
+    0xa24bb5a6, 0x502036a5, 0x4370c551, 0xb11b4652,
+    0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d,
+    0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d,
+    0xef087a76, 0x1d63f975, 0x0e330a81, 0xfc588982,
+    0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d,
+    0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622,
+    0x38cc2a06, 0xcaa7a905, 0xd9f75af1, 0x2b9cd9f2,
+    0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed,
+    0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530,
+    0x0417b1db, 0xf67c32d8, 0xe52cc12c, 0x1747422f,
+    0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff,
+    0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0,
+    0xd3d3e1ab, 0x21b862a8, 0x32e8915c, 0xc083125f,
+    0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540,
+    0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90,
+    0x9e902e7b, 0x6cfbad78, 0x7fab5e8c, 0x8dc0dd8f,
+    0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee,
+    0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1,
+    0x69e9f0d5, 0x9b8273d6, 0x88d28022, 0x7ab90321,
+    0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e,
+    0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81,
+    0x34f4f86a, 0xc69f7b69, 0xd5cf889d, 0x27a40b9e,
+    0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e,
+    0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351,
+)
+
+
+_CRC_INIT = 0
+
+_MASK = 0xFFFFFFFF
+
+
+def crc_update(crc, data):
+    """Updates CRC-32C checksum with data. Copied from
+    https://github.com/TeamHG-Memex/tensorboard_logger/blob/master/tensorboard_logger/crc32c.py
+
+    Parameters
+    ----------
+    crc : int
+        32-bit checksum to update.
+    data : byte array
+        String or iterable over bytes.
+
+    Returns
+    -------
+    int
+        32-bit updated CRC-32C as long.
+    """
+
+    if not isinstance(data, array.array) or data.itemsize != 1:
+        buf = array.array("B", data)
+    else:
+        buf = data
+
+    crc ^= _MASK
+    for b in buf:
+        table_index = (crc ^ b) & 0xff
+        crc = (CRC_TABLE[table_index] ^ (crc >> 8)) & _MASK
+    return crc ^ _MASK
+
+
+def crc_finalize(crc):
+    """Finalizes CRC-32C checksum. Copied from
+    https://github.com/TeamHG-Memex/tensorboard_logger/blob/master/tensorboard_logger/crc32c.py
+    This function should be called as the last step of the CRC calculation.
+
+    Parameters
+    ----------
+    crc : int
+        32-bit checksum to finalize.
+
+    Returns
+    -------
+    int
+        finalized 32-bit checksum as long
+    """
+    return crc & _MASK
+
+
+def crc32c(data):
+    """Compute CRC-32C checksum of the data. Copied from
+    https://github.com/TeamHG-Memex/tensorboard_logger/blob/master/tensorboard_logger/crc32c.py
+
+    Parameters
+    ----------
+    data : byte array
+        String or iterable over bytes.
+
+    Returns
+    -------
+    int
+        32-bit CRC-32C checksum of data as long.
+    """
+    return crc_finalize(crc_update(_CRC_INIT, data))
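The table-driven `crc32c` above can be cross-checked against a plain bit-by-bit CRC-32C (bit-reflected polynomial `0x82F63B78`, which is also entry 128 of `CRC_TABLE`). The sketch below does that and, as an assumption about how the record writer (not shown in this excerpt) uses the checksum, frames a payload the way TFRecord files are laid out: a little-endian `uint64` length, the masked CRC of the length bytes, the payload, and the masked CRC of the payload. The masking follows TensorFlow's `crc32c::Mask`; the helper names are hypothetical:

```python
import struct

_POLY = 0x82F63B78        # CRC-32C (Castagnoli) polynomial, bit-reflected
_MASK_DELTA = 0xA282EAD8  # constant used by TensorFlow's masked CRC


def crc32c_bitwise(data):
    """Bit-by-bit CRC-32C; slow, but easy to check against the table version."""
    crc = 0xFFFFFFFF
    for byte in bytearray(data):
        crc ^= byte
        for _ in range(8):
            # shift right; XOR in the polynomial when the dropped bit was 1
            crc = (crc >> 1) ^ (_POLY if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF


def masked_crc32c(data):
    """Rotate the CRC right by 15 bits and add a constant, as TFRecord does."""
    crc = crc32c_bitwise(data)
    return (((crc >> 15) | (crc << 17)) + _MASK_DELTA) & 0xFFFFFFFF


def frame_record(payload):
    """Hypothetical helper: wrap `payload` bytes in TFRecord framing."""
    length = struct.pack("<Q", len(payload))
    return (length + struct.pack("<I", masked_crc32c(length))
            + payload + struct.pack("<I", masked_crc32c(payload)))


# Standard CRC-32C check value for the ASCII string "123456789"
assert crc32c_bitwise(b"123456789") == 0xE3069283
```

The table-based `crc32c(b"123456789")` from the module above is expected to return the same check value, since both implement the same reflected CRC-32C.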
diff --git a/python/mxnet/contrib/summary/event_file_writer.py b/python/mxnet/contrib/summary/event_file_writer.py
new file mode 100644
index 00000000000..b377a15b032
--- /dev/null
+++ b/python/mxnet/contrib/summary/event_file_writer.py
@@ -0,0 +1,211 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Writes events to disk in a logdir."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import os.path
+import socket
+import threading
+import time
+
+import six
+
+from .proto import event_pb2
+from .record_writer import RecordWriter
+
+
+class EventsWriter(object):
+    """Writes `Event` protocol buffers to an event file. This class is ported from
+    EventsWriter defined in
+    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/events_writer.cc"""
+    def __init__(self, file_prefix):
+        """
+        Events files have a name of the form
+        '/file/path/events.out.tfevents.[timestamp].[hostname][file_suffix]'
+        """
+        self._file_prefix = file_prefix
+        self._file_suffix = ''
+        self._filename = None
+        self._recordio_writer = None
+        self._num_outstanding_events = 0
+
+    def __del__(self):
+        self.close()
+
+    def _init_if_needed(self):
+        if self._recordio_writer is not None:
+            return
+        self._filename = self._file_prefix + ".out.tfevents." + str(time.time())[:10] \
+                         + "." + socket.gethostname() + self._file_suffix
+        self._recordio_writer = RecordWriter(self._filename)
+        # log records must not be redirected to self._filename: text in the binary event file corrupts it
+        logging.info('Successfully opened events file: %s', self._filename)
+        event = event_pb2.Event()
+        event.wall_time = time.time()
+        self.write_event(event)
+        self.flush()  # flush the first event
+
+    def init_with_suffix(self, file_suffix):
+        """Initializes the events writer with file_suffix"""
+        self._file_suffix = file_suffix
+        self._init_if_needed()
+
+    def write_event(self, event):
+        """Appends event to the file."""
+        # Check if event is of type event_pb2.Event proto.
+        if not isinstance(event, event_pb2.Event):
+            raise TypeError("Expected an event_pb2.Event proto, "
+                            "but got %s" % type(event))
+        return self._write_serialized_event(event.SerializeToString())
+
+    def _write_serialized_event(self, event_str):
+        if self._recordio_writer is None:
+            self._init_if_needed()
+        self._num_outstanding_events += 1
+        self._recordio_writer.write_record(event_str)
+
+    def flush(self):
+        """Flushes the event file to disk."""
+        if self._num_outstanding_events == 0 or self._recordio_writer is None:
+            return
+        self._recordio_writer.flush()
+        if self._num_outstanding_events != 1:
+            logging.info('Wrote %d events to disk', self._num_outstanding_events)
+        else:
+            logging.info('Wrote %d event to disk', self._num_outstanding_events)
+        self._num_outstanding_events = 0
+
+    def close(self):
+        """Flushes the pending events and closes the writer after it is done."""
+        self.flush()
+        if self._recordio_writer is not None:
+            self._recordio_writer.close()
+            self._recordio_writer = None
+
+
+class EventFileWriter(object):
+    """This class is adapted from EventFileWriter in TensorFlow:
+    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/summary/writer/event_file_writer.py
+    Writes `Event` protocol buffers to an event file.
+    The `EventFileWriter` class creates an event file in the specified directory,
+    and asynchronously writes Event protocol buffers to the file. The Event file
+    is encoded using the tfrecord format, which is similar to RecordIO.
+    """
+
+    def __init__(self, logdir, max_queue=10, flush_secs=120, filename_suffix=None):
+        """Creates an `EventFileWriter` for writing events to `logdir`.
+        The event file is created when the first event is written (or immediately if
+        `filename_suffix` is given); it will contain `Event` protocol buffers added via `add_event`.
+        The remaining arguments control the asynchronous writes: `max_queue` bounds the queue
+        of pending events, `flush_secs` is how often (in seconds) pending events are flushed
+        to disk, and `filename_suffix` is appended to the event file name.
+        """
+        self._logdir = logdir
+        if not os.path.exists(self._logdir):
+            os.makedirs(self._logdir)
+        self._event_queue = six.moves.queue.Queue(max_queue)
+        self._ev_writer = EventsWriter(os.path.join(self._logdir, "events"))
+        self._flush_secs = flush_secs
+        self._sentinel_event = self._get_sentinel_event()
+        if filename_suffix is not None:
+            self._ev_writer.init_with_suffix(filename_suffix)
+        self._closed = False
+        self._worker = _EventLoggerThread(self._event_queue, self._ev_writer,
+                                          self._flush_secs, self._sentinel_event)
+
+        self._worker.start()
+
+    def _get_sentinel_event(self):
+        """Generates a sentinel event used to stop the worker thread."""
+        return event_pb2.Event()
+
+    def get_logdir(self):
+        """Returns the directory where the event file will be written."""
+        return self._logdir
+
+    def reopen(self):
+        """Reopens the EventFileWriter.
+        Can be called after `close()` to add more events in the same directory.
+        The events will go into a new events file.
+        Does nothing if the `EventFileWriter` was not closed.
+        """
+        if self._closed:
+            self._worker = _EventLoggerThread(self._event_queue, self._ev_writer,
+                                              self._flush_secs, self._sentinel_event)
+            self._worker.start()
+            self._closed = False
+
+    def add_event(self, event):
+        """Adds an event to the event file."""
+        if not self._closed:
+            self._event_queue.put(event)
+
+    def flush(self):
+        """Flushes the event file to disk.
+        Call this method to make sure that all pending events have been written to disk.
+        """
+        self._event_queue.join()
+        self._ev_writer.flush()
+
+    def close(self):
+        """Flushes the event file to disk and closes the file.
+        Call this method when you do not need the summary writer anymore.
+        """
+        if not self._closed:
+            self.add_event(self._sentinel_event)
+            self.flush()
+            self._worker.join()
+            self._ev_writer.close()
+            self._closed = True
+
+
+class _EventLoggerThread(threading.Thread):
+    """Thread that logs events. Copied from
+    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/summary/writer/event_file_writer.py#L133"""
+
+    def __init__(self, queue, ev_writer, flush_secs, sentinel_event):
+        """Creates an _EventLoggerThread."""
+        threading.Thread.__init__(self)
+        self.daemon = True
+        self._queue = queue
+        self._ev_writer = ev_writer
+        self._flush_secs = flush_secs
+        # The first event will be flushed immediately.
+        self._next_event_flush_time = 0
+        self._sentinel_event = sentinel_event
+
+    def run(self):
+        while True:
+            event = self._queue.get()
+            if event is self._sentinel_event:
+                self._queue.task_done()
+                break
+            try:
+                self._ev_writer.write_event(event)
+                # Flush the event writer every so often.
+                now = time.time()
+                if now > self._next_event_flush_time:
+                    self._ev_writer.flush()
+                    # Flush again once `flush_secs` seconds have passed.
+                    self._next_event_flush_time = now + self._flush_secs
+            finally:
+                self._queue.task_done()
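The `EventFileWriter`/`_EventLoggerThread` pair above hands events to a daemon thread through a bounded queue, stops it by enqueuing a sentinel object, and relies on `task_done()`/`join()` so that `flush()` can wait for pending events. A minimal, self-contained sketch of that pattern follows; `ListSink` and `LoggerThread` are hypothetical names, an in-memory list stands in for `EventsWriter`, and the Python 3 `queue` module replaces `six.moves.queue`.

```python
import queue
import threading
import time


class ListSink(object):
    """Hypothetical stand-in for EventsWriter: collects events in memory."""

    def __init__(self):
        self.events = []
        self.flush_count = 0

    def write_event(self, event):
        self.events.append(event)

    def flush(self):
        self.flush_count += 1


class LoggerThread(threading.Thread):
    """Drains a queue into a sink until it sees the sentinel object."""

    def __init__(self, q, sink, flush_secs, sentinel):
        threading.Thread.__init__(self)
        self.daemon = True
        self._queue = q
        self._sink = sink
        self._flush_secs = flush_secs
        self._sentinel = sentinel
        self._next_flush_time = 0  # so the first event is flushed immediately

    def run(self):
        while True:
            event = self._queue.get()
            if event is self._sentinel:  # identity check, not equality
                self._queue.task_done()
                break
            try:
                self._sink.write_event(event)
                now = time.time()
                if now > self._next_flush_time:
                    self._sink.flush()
                    self._next_flush_time = now + self._flush_secs
            finally:
                self._queue.task_done()


sentinel = object()
q = queue.Queue(10)
sink = ListSink()
worker = LoggerThread(q, sink, flush_secs=120, sentinel=sentinel)
worker.start()
for step in range(5):
    q.put({'step': step})
q.put(sentinel)  # ask the worker to stop after the pending events
q.join()         # wait until every queued item has been marked done
worker.join()
print([e['step'] for e in sink.events])  # [0, 1, 2, 3, 4]
```

Comparing with `is` matters: in the PR the sentinel is an empty `event_pb2.Event()`, and an identity check keeps an ordinary empty event from being mistaken for the stop signal.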
diff --git a/python/mxnet/contrib/summary/proto/__init__.py b/python/mxnet/contrib/summary/proto/__init__.py
new file mode 100644
index 00000000000..908bb2d60b9
--- /dev/null
+++ b/python/mxnet/contrib/summary/proto/__init__.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Protobuf definitions for the logging data"""
diff --git a/python/mxnet/contrib/summary/proto/attr_value.proto b/python/mxnet/contrib/summary/proto/attr_value.proto
new file mode 100644
index 00000000000..79c9d893686
--- /dev/null
+++ b/python/mxnet/contrib/summary/proto/attr_value.proto
@@ -0,0 +1,62 @@
+syntax = "proto3";
+
+package tensorboard;
+option cc_enable_arenas = true;
+option java_outer_classname = "AttrValueProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "mxnet/contrib/summary/proto/tensor.proto";
+import "mxnet/contrib/summary/proto/tensor_shape.proto";
+import "mxnet/contrib/summary/proto/types.proto";
+
+// Protocol buffer representing the value for an attr used to configure an Op.
+// Comment indicates the corresponding attr type.  Only the field matching the
+// attr type may be filled.
+message AttrValue {
+  // LINT.IfChange
+  message ListValue {
+    repeated bytes s = 2;                        // "list(string)"
+    repeated int64 i = 3 [packed = true];        // "list(int)"
+    repeated float f = 4 [packed = true];        // "list(float)"
+    repeated bool b = 5 [packed = true];         // "list(bool)"
+    repeated DataType type = 6 [packed = true];  // "list(type)"
+    repeated TensorShapeProto shape = 7;         // "list(shape)"
+    repeated TensorProto tensor = 8;             // "list(tensor)"
+    repeated NameAttrList func = 9;              // "list(attr)"
+  }
+  // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc)
+
+  oneof value {
+    bytes s = 2;                 // "string"
+    int64 i = 3;                 // "int"
+    float f = 4;                 // "float"
+    bool b = 5;                  // "bool"
+    DataType type = 6;           // "type"
+    TensorShapeProto shape = 7;  // "shape"
+    TensorProto tensor = 8;      // "tensor"
+    ListValue list = 1;          // any "list(...)"
+
+    // "func" represents a function. func.name is a function's name or
+    // a primitive op's name. func.attr.first is the name of an attr
+    // defined for that function. func.attr.second is the value for
+    // that attr in the instantiation.
+    NameAttrList func = 10;
+
+    // This is a placeholder only used in nodes defined inside a
+    // function.  It indicates the attr value will be supplied when
+    // the function is instantiated.  For example, let us suppose a
+    // node "N" in function "FN". "N" has an attr "A" with value
+    // placeholder = "foo". When FN is instantiated with attr "foo"
+    // set to "bar", the instantiated node N's attr A will have been
+    // given the value "bar".
+    string placeholder = 9;
+  }
+}
+
+// A list of attr names and their values. The whole list is attached
+// with a string name.  E.g., MatMul[T=float].
+message NameAttrList {
+  string name = 1;
+  map<string, AttrValue> attr = 2;
+}
diff --git a/python/mxnet/contrib/summary/proto/event.proto b/python/mxnet/contrib/summary/proto/event.proto
new file mode 100644
index 00000000000..f682331fd8d
--- /dev/null
+++ b/python/mxnet/contrib/summary/proto/event.proto
@@ -0,0 +1,78 @@
+syntax = "proto3";
+
+package tensorboard;
+option cc_enable_arenas = true;
+option java_outer_classname = "EventProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.util";
+
+import "mxnet/contrib/summary/proto/summary.proto";
+
+// Protocol buffer representing an event that happened during
+// the execution of a Brain model.
+message Event {
+  // Timestamp of the event.
+  double wall_time = 1;
+
+  // Global step of the event.
+  int64 step = 2;
+
+  oneof what {
+    // An event file was started, with the specified version.
+    // This is used to identify the contents of the record IO files
+    // easily.  Current version is "brain.Event:2".  All versions
+    // start with "brain.Event:".
+    string file_version = 3;
+    // An encoded version of a GraphDef.
+    bytes graph_def = 4;
+    // A summary was generated.
+    Summary summary = 5;
+    // The user output a log message. Not all messages are logged, only ones
+    // generated via the Python tensorboard_logging module.
+    LogMessage log_message = 6;
+    // The state of the session which can be used for restarting after crashes.
+    SessionLog session_log = 7;
+    // The metadata returned by running a session.run() call.
+    TaggedRunMetadata tagged_run_metadata = 8;
+    // An encoded version of a MetaGraphDef.
+    bytes meta_graph_def = 9;
+  }
+}
+
+// Protocol buffer used for logging messages to the events file.
+message LogMessage {
+  enum Level {
+    UNKNOWN = 0;
+    DEBUG = 10;
+    INFO = 20;
+    WARN = 30;
+    ERROR = 40;
+    FATAL = 50;
+  }
+  Level level = 1;
+  string message = 2;
+}
+
+// Protocol buffer used for logging session state.
+message SessionLog {
+  enum SessionStatus {
+    STATUS_UNSPECIFIED = 0;
+    START = 1;
+    STOP = 2;
+    CHECKPOINT = 3;
+  }
+
+  SessionStatus status = 1;
+  // This checkpoint_path contains both the path and filename.
+  string checkpoint_path = 2;
+  string msg = 3;
+}
+
+// For logging the metadata output for a single session.run() call.
+message TaggedRunMetadata {
+  // Tag name associated with this metadata.
+  string tag = 1;
+  // Byte-encoded version of the `RunMetadata` proto in order to allow lazy
+  // deserialization.
+  bytes run_metadata = 2;
+}
diff --git a/python/mxnet/contrib/summary/proto/graph.proto b/python/mxnet/contrib/summary/proto/graph.proto
new file mode 100644
index 00000000000..6d25f53bc86
--- /dev/null
+++ b/python/mxnet/contrib/summary/proto/graph.proto
@@ -0,0 +1,55 @@
+syntax = "proto3";
+
+package tensorboard;
+option cc_enable_arenas = true;
+option java_outer_classname = "GraphProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "mxnet/contrib/summary/proto/node_def.proto";
+import "mxnet/contrib/summary/proto/versions.proto";
+
+// Represents the graph of operations
+message GraphDef {
+  repeated NodeDef node = 1;
+
+  // Compatibility versions of the graph.  See core/public/version.h for version
+  // history.  The GraphDef version is distinct from the TensorFlow version, and
+  // each release of TensorFlow will support a range of GraphDef versions.
+  VersionDef versions = 4;
+
+  // Deprecated single version field; use versions above instead.  Since all
+  // GraphDef changes before "versions" was introduced were forward
+  // compatible, this field is entirely ignored.
+  int32 version = 3 [deprecated = true];
+
+  // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
+  //
+  // "library" provides user-defined functions.
+  //
+  // Naming:
+  //   * library.function.name are in a flat namespace.
+  //     NOTE: We may need to change it to be hierarchical to support
+  //     different orgs. E.g.,
+  //     { "/google/nn", { ... }},
+  //     { "/google/vision", { ... }}
+  //     { "/org_foo/module_bar", { ... }}
+  //     map<string, FunctionDefLib> named_lib;
+  //   * If node[i].op is the name of one function in "library",
+  //     node[i] is deemed as a function call. Otherwise, node[i].op
+  //     must be a primitive operation supported by the runtime.
+  //
+  //
+  // Function call semantics:
+  //
+  //   * The callee may start execution as soon as some of its inputs
+  //     are ready. The caller may want to use Tuple() mechanism to
+  //     ensure all inputs are ready in the same time.
+  //
+  //   * The consumer of return values may start executing as soon as
+  //     the return values the consumer depends on are ready.  The
+  //     consumer may want to use Tuple() mechanism to ensure the
+  //     consumer does not start until all return values of the callee
+  //     function are ready.
+  //FunctionDefLibrary library = 2;
+};
diff --git a/python/mxnet/contrib/summary/proto/node_def.proto b/python/mxnet/contrib/summary/proto/node_def.proto
new file mode 100644
index 00000000000..6335883bfd0
--- /dev/null
+++ b/python/mxnet/contrib/summary/proto/node_def.proto
@@ -0,0 +1,63 @@
+syntax = "proto3";
+
+package tensorboard;
+option cc_enable_arenas = true;
+option java_outer_classname = "NodeProto";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "mxnet/contrib/summary/proto/attr_value.proto";
+
+message NodeDef {
+  // The name given to this operator. Used for naming inputs,
+  // logging, visualization, etc.  Unique within a single GraphDef.
+  // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*".
+  string name = 1;
+
+  // The operation name.  There may be custom parameters in attrs.
+  // Op names starting with an underscore are reserved for internal use.
+  string op = 2;
+
+  // Each input is "node:src_output" with "node" being a string name and
+  // "src_output" indicating which output tensor to use from "node". If
+  // "src_output" is 0 the ":0" suffix can be omitted.  Regular inputs
+  // may optionally be followed by control inputs that have the format
+  // "^node".
+  repeated string input = 3;
+
+  // A (possibly partial) specification for the device on which this
+  // node should be placed.
+  // The expected syntax for this string is as follows:
+  //
+  // DEVICE_SPEC ::= PARTIAL_SPEC
+  //
+  // PARTIAL_SPEC ::= ("/" CONSTRAINT) *
+  // CONSTRAINT ::= ("job:" JOB_NAME)
+  //              | ("replica:" ("0" | [1-9][0-9]*))
+  //              | ("task:" ("0" | [1-9][0-9]*))
+  //              | ( ("gpu" | "cpu") ":" ("0" | [1-9][0-9]* | "*") )
+  //
+  // Valid values for this string include:
+  // * "/job:worker/replica:0/task:1/gpu:3"  (full specification)
+  // * "/job:worker/gpu:3"                   (partial specification)
+  // * ""                                    (no specification)
+  //
+  // If the constraints do not resolve to a single device (or if this
+  // field is empty or not present), the runtime will attempt to
+  // choose a device automatically.
+  string device = 4;
+
+  // Operation-specific graph-construction-time configuration.
+  // Note that this should include all attrs defined in the
+  // corresponding OpDef, including those with a value matching
+  // the default -- this allows the default to change and makes
+  // NodeDefs easier to interpret on their own.  However, if
+  // an attr with a default is not specified in this list, the
+  // default will be used.
+  // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and
+  // one of the names from the corresponding OpDef's attr field).
+  // The values must have a type matching the corresponding OpDef
+  // attr's type field.
+  // TODO(josh11b): Add some examples here showing best practices.
+  map<string, AttrValue> attr = 5;
+};
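The `DEVICE_SPEC` grammar documented in `NodeDef.device` is small enough to check with a regular expression. The sketch below is hypothetical (nothing in this PR parses device strings); it assumes `JOB_NAME` is `[A-Za-z0-9_]+`, since the grammar leaves it unspecified, and it accepts index `0` so that the documented example `/job:worker/replica:0/task:1/gpu:3` parses.

```python
import re

# One "/" CONSTRAINT from the DEVICE_SPEC grammar. JOB_NAME is assumed to be
# [A-Za-z0-9_]+; indices accept 0 so that "replica:0" from the examples is valid.
_CONSTRAINT = re.compile(
    r"/(?:job:(?P<job>[A-Za-z0-9_]+)"
    r"|replica:(?P<replica>0|[1-9][0-9]*)"
    r"|task:(?P<task>0|[1-9][0-9]*)"
    r"|(?P<device_type>gpu|cpu):(?P<device_index>0|[1-9][0-9]*|\*))"
)


def parse_device_spec(spec):
    """Split a (possibly partial) device spec into a dict of its constraints."""
    parsed = {}
    pos = 0
    while pos < len(spec):
        match = _CONSTRAINT.match(spec, pos)
        if match is None:
            raise ValueError('invalid device spec: %r' % spec)
        # Keep only the named groups that took part in this match.
        parsed.update({k: v for k, v in match.groupdict().items() if v is not None})
        pos = match.end()
    return parsed


print(parse_device_spec('/job:worker/gpu:3'))
```

An empty string yields an empty dict, which corresponds to "no specification" and leaves device placement to the runtime.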
diff --git a/python/mxnet/contrib/summary/proto/plugin_pr_curve.proto b/python/mxnet/contrib/summary/proto/plugin_pr_curve.proto
new file mode 100644
index 00000000000..33e0f91641f
--- /dev/null
+++ b/python/mxnet/contrib/summary/proto/plugin_pr_curve.proto
@@ -0,0 +1,25 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorboard;
+
+message PrCurvePluginData {
+  // Version `0` is the only supported version.
+  int32 version = 1;
+
+  uint32 num_thresholds = 2;
+}
diff --git a/python/mxnet/contrib/summary/proto/resource_handle.proto b/python/mxnet/contrib/summary/proto/resource_handle.proto
new file mode 100644
index 00000000000..17826e0085e
--- /dev/null
+++ b/python/mxnet/contrib/summary/proto/resource_handle.proto
@@ -0,0 +1,29 @@
+syntax = "proto3";
+
+package tensorboard;
+option cc_enable_arenas = true;
+option java_outer_classname = "ResourceHandle";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+// Protocol buffer representing a handle to a tensorflow resource. Handles are
+// not valid across executions, but can be serialized back and forth from within
+// a single run.
+message ResourceHandleProto {
+  // Unique name for the device containing the resource.
+  string device = 1;
+
+  // Container in which this resource is placed.
+  string container = 2;
+
+  // Unique name of this resource.
+  string name = 3;
+
+  // Hash code for the type of the resource. Is only valid in the same device
+  // and in the same execution.
+  uint64 hash_code = 4;
+
+  // For debug-only, the name of the type pointed to by this handle, if
+  // available.
+  string maybe_type_name = 5;
+};
diff --git a/python/mxnet/contrib/summary/proto/summary.proto b/python/mxnet/contrib/summary/proto/summary.proto
new file mode 100644
index 00000000000..6603500632b
--- /dev/null
+++ b/python/mxnet/contrib/summary/proto/summary.proto
@@ -0,0 +1,123 @@
+syntax = "proto3";
+
+package tensorboard;
+option cc_enable_arenas = true;
+option java_outer_classname = "SummaryProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "mxnet/contrib/summary/proto/tensor.proto";
+
+// Metadata associated with a series of Summary data
+message SummaryDescription {
+  // Hint on how plugins should process the data in this series.
+  // Supported values include "scalar", "histogram", "image", "audio"
+  string type_hint = 1;
+}
+
+// Serialization format for histogram module in
+// core/lib/histogram/histogram.h
+message HistogramProto {
+  double min = 1;
+  double max = 2;
+  double num = 3;
+  double sum = 4;
+  double sum_squares = 5;
+
+  // Parallel arrays encoding the bucket boundaries and the bucket values.
+  // bucket(i) is the count for the bucket i.  The range for
+  // a bucket is:
+  //   i == 0:  -DBL_MAX .. bucket_limit(0)
+  //   i != 0:  bucket_limit(i-1) .. bucket_limit(i)
+  repeated double bucket_limit = 6 [packed = true];
+  repeated double bucket = 7 [packed = true];
+};
+
+// A SummaryMetadata encapsulates information on which plugins are able to make
+// use of a certain summary value.
+message SummaryMetadata {
+  message PluginData {
+    // The name of the plugin this data pertains to.
+    string plugin_name = 1;
+
+    // The content to store for the plugin. The best practice is for this JSON
+    // string to be the canonical JSON serialization of a protocol buffer
+    // defined by the plugin. Converting that protobuf to and from JSON is the
+    // responsibility of the plugin code, and is not enforced by
+    // TensorFlow/TensorBoard.
+    string content = 2;
+  }
+
+  // A list of plugin data. A single summary value instance may be used by more
+  // than 1 plugin.
+  repeated PluginData plugin_data = 1;
+};
+
+// A Summary is a set of named values to be displayed by the
+// visualizer.
+//
+// Summaries are produced regularly during training, as controlled by
+// the "summary_interval_secs" attribute of the training operation.
+// Summaries are also produced at the end of an evaluation.
+message Summary {
+  message Image {
+    // Dimensions of the image.
+    int32 height = 1;
+    int32 width = 2;
+    // Valid colorspace values are
+    //   1 - grayscale
+    //   2 - grayscale + alpha
+    //   3 - RGB
+    //   4 - RGBA
+    //   5 - DIGITAL_YUV
+    //   6 - BGRA
+    int32 colorspace = 3;
+    // Image data in encoded format.  All image formats supported by
+    // image_codec::CoderUtil can be stored here.
+    bytes encoded_image_string = 4;
+  }
+
+  message Audio {
+    // Sample rate of the audio in Hz.
+    float sample_rate = 1;
+    // Number of channels of audio.
+    int64 num_channels = 2;
+    // Length of the audio in frames (samples per channel).
+    int64 length_frames = 3;
+    // Encoded audio data and its associated RFC 2045 content type (e.g.
+    // "audio/wav").
+    bytes encoded_audio_string = 4;
+    string content_type = 5;
+  }
+
+  message Value {
+    // Name of the node that output this summary; in general, the name of a
+    // TensorSummary node. If the node in question has multiple outputs, then
+    // a ":\d+" suffix will be appended, like "some_op:13".
+    // Might not be set for legacy summaries (i.e. those not using the tensor
+    // value field)
+    string node_name = 7;
+
+    // Tag name for the data.  Will only be used by legacy summaries
+    // (i.e., those not using the tensor value field)
+    // For legacy summaries, will be used as the title of the graph
+    // in the visualizer.
+    //
+    // Tag is usually "op_name:value_name", where "op_name" itself can have
+    // structure to indicate grouping.
+    string tag = 1;
+    SummaryMetadata metadata = 9;
+    // Value associated with the tag.
+    oneof value {
+      float simple_value = 2;
+      bytes obsolete_old_style_histogram = 3;
+      Image image = 4;
+      HistogramProto histo = 5;
+      Audio audio = 6;
+      TensorProto tensor = 8;
+    }
+  }
+
+  // Set of values for the summary.
+  repeated Value value = 1;
+}
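The parallel `bucket_limit`/`bucket` arrays of `HistogramProto` can be filled from raw values with a binary search over the limits. A minimal pure-Python sketch: the helper name is hypothetical, it returns a plain dict rather than the generated protobuf class, and it assumes a value equal to a limit belongs to the next bucket up (the proto comment does not say which end of a range is inclusive).

```python
import bisect
import sys


def histogram_fields(values, limits):
    """Compute HistogramProto-style fields for a non-empty list of `values`.

    Bucket 0 covers (-DBL_MAX, limits[0]) and bucket i covers
    [limits[i-1], limits[i]). The last limit is forced to DBL_MAX so that
    every finite value lands in some bucket.
    """
    limits = list(limits)
    if not limits or limits[-1] < sys.float_info.max:
        limits.append(sys.float_info.max)
    counts = [0.0] * len(limits)
    for v in values:
        # bisect_right returns the first bucket whose limit is > v;
        # clamp so that v >= DBL_MAX still falls into the last bucket.
        counts[min(bisect.bisect_right(limits, v), len(limits) - 1)] += 1
    return {
        'min': min(values),
        'max': max(values),
        'num': float(len(values)),
        'sum': float(sum(values)),
        'sum_squares': float(sum(v * v for v in values)),
        'bucket_limit': limits,
        'bucket': counts,
    }
```

For example, `histogram_fields([0.5, 1.0, 1.5, 3.0], [1.0, 2.0])` places one value below 1.0, two in [1.0, 2.0), and one at or above 2.0.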
diff --git a/python/mxnet/contrib/summary/proto/tensor.proto b/python/mxnet/contrib/summary/proto/tensor.proto
new file mode 100644
index 00000000000..6312943e66b
--- /dev/null
+++ b/python/mxnet/contrib/summary/proto/tensor.proto
@@ -0,0 +1,75 @@
+syntax = "proto3";
+
+package tensorboard;
+option cc_enable_arenas = true;
+option java_outer_classname = "TensorProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "mxnet/contrib/summary/proto/resource_handle.proto";
+import "mxnet/contrib/summary/proto/tensor_shape.proto";
+import "mxnet/contrib/summary/proto/types.proto";
+
+// Protocol buffer representing a tensor.
+message TensorProto {
+  DataType dtype = 1;
+
+  // Shape of the tensor.  TODO(touts): sort out the 0-rank issues.
+  TensorShapeProto tensor_shape = 2;
+
+  // Only one of the representations below is set, one of "tensor_contents" and
+  // the "xxx_val" attributes.  We are not using oneof because oneofs cannot
+  // contain repeated fields, so it would require an extra set of messages.
+
+  // Version number.
+  //
+  // In version 0, if the "repeated xxx" representations contain only one
+  // element, that element is repeated to fill the shape.  This makes it easy
+  // to represent a constant Tensor with a single value.
+  int32 version_number = 3;
+
+  // Serialized raw tensor content from either Tensor::AsProtoTensorContent or
+  // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
+  // can be used for all tensor types. The purpose of this representation is to
+  // reduce serialization overhead during RPC call by avoiding serialization of
+  // many repeated small items.
+  bytes tensor_content = 4;
+
+  // Type specific representations that make it easy to create tensor protos in
+  // all languages.  Only the representation corresponding to "dtype" can
+  // be set.  The values hold the flattened representation of the tensor in
+  // row major order.
+
+  // DT_HALF. Note that since protobuf has no int16 type, we'll have some
+  // pointless zero padding for each value here.
+  repeated int32 half_val = 13 [packed = true];
+
+  // DT_FLOAT.
+  repeated float float_val = 5 [packed = true];
+
+  // DT_DOUBLE.
+  repeated double double_val = 6 [packed = true];
+
+  // DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
+  repeated int32 int_val = 7 [packed = true];
+
+  // DT_STRING
+  repeated bytes string_val = 8;
+
+  // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
+  // and imaginary parts of i-th single precision complex.
+  repeated float scomplex_val = 9 [packed = true];
+
+  // DT_INT64
+  repeated int64 int64_val = 10 [packed = true];
+
+  // DT_BOOL
+  repeated bool bool_val = 11 [packed = true];
+
+  // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real
+  // and imaginary parts of i-th double precision complex.
+  repeated double dcomplex_val = 12 [packed = true];
+
+  // DT_RESOURCE
+  repeated ResourceHandleProto resource_handle_val = 14;
+};
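Two details of `TensorProto` are easy to get wrong when filling it by hand: `half_val` stores each float16 as its raw 16-bit pattern inside an `int32` (hence the "pointless zero padding"), and in version 0 a single repeated value stands for a constant tensor. A hypothetical pure-Python sketch of both rules; the helper names are illustrative and not part of this PR, and `struct`'s `'e'` format needs Python 3.6 or later.

```python
import struct


def half_val_from_floats(values):
    """Encode floats as IEEE 754 half-precision bit patterns, one int per value."""
    # '<e' packs a binary16 float; '<H' reads the same 2 bytes back as an
    # unsigned integer, which fits in the low 16 bits of an int32 field.
    return [struct.unpack('<H', struct.pack('<e', v))[0] for v in values]


def expand_repeated(values, shape):
    """Apply the version-0 rule: a single element is repeated to fill the shape."""
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    if len(values) == 1:
        return list(values) * num_elements
    if len(values) != num_elements:
        raise ValueError('expected %d values, got %d' % (num_elements, len(values)))
    return list(values)


print(half_val_from_floats([1.0]))  # [15360], i.e. 0x3C00
```

Values outside the float16 range make `struct.pack('<e', ...)` raise `OverflowError`, so real code would need to clamp or reject them first.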
diff --git a/python/mxnet/contrib/summary/proto/tensor_shape.proto b/python/mxnet/contrib/summary/proto/tensor_shape.proto
new file mode 100644
index 00000000000..7c7474387e1
--- /dev/null
+++ b/python/mxnet/contrib/summary/proto/tensor_shape.proto
@@ -0,0 +1,45 @@
+// Protocol buffer representing the shape of tensors.
+
+syntax = "proto3";
+option cc_enable_arenas = true;
+option java_outer_classname = "TensorShapeProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+package tensorboard;
+
+// Dimensions of a tensor.
+message TensorShapeProto {
+  // One dimension of the tensor.
+  message Dim {
+    // Size of the tensor in that dimension.
+    // This value must be >= -1, but values of -1 are reserved for "unknown"
+    // shapes (values of -1 mean "unknown" dimension).  Certain wrappers
+    // that work with TensorShapeProto may fail at runtime when deserializing
+    // a TensorShapeProto containing a dim value of -1.
+    int64 size = 1;
+
+    // Optional name of the tensor dimension.
+    string name = 2;
+  };
+
+  // Dimensions of the tensor, such as {"input", 30}, {"output", 40}
+  // for a 30 x 40 2D tensor.  If an entry has size -1, this
+  // corresponds to a dimension of unknown size. The names are
+  // optional.
+  //
+  // The order of entries in "dim" matters: It indicates the layout of the
+  // values in the tensor in-memory representation.
+  //
+  // The first entry in "dim" is the outermost dimension used to layout the
+  // values, the last entry is the innermost dimension.  This matches the
+  // in-memory layout of RowMajor Eigen tensors.
+  //
+  // If "dim.size()" > 0, "unknown_rank" must be false.
+  repeated Dim dim = 2;
+
+  // If true, the number of dimensions in the shape is unknown.
+  //
+  // If true, "dim.size()" must be 0.
+  bool unknown_rank = 3;
+};
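The `dim` comment states that the first entry is the outermost dimension, matching a row-major (C-order) layout. Under that reading, the position of an element in the flattened value list is the dot product of its index with the strides below. A minimal sketch with hypothetical helper names; it assumes every size is known (no `-1` entries).

```python
def row_major_strides(sizes):
    """Element strides for a row-major layout; the last dimension varies fastest."""
    if any(size < 0 for size in sizes):
        raise ValueError('strides are undefined for unknown (-1) dimensions')
    strides = [1] * len(sizes)
    for i in range(len(sizes) - 2, -1, -1):
        strides[i] = strides[i + 1] * sizes[i + 1]
    return strides


def flat_offset(index, sizes):
    """Position of the element at `index` in the flattened value list."""
    return sum(i * s for i, s in zip(index, row_major_strides(sizes)))


# For the 30 x 40 tensor in the comment, element (2, 5) is at 2 * 40 + 5.
print(flat_offset((2, 5), [30, 40]))  # 85
```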
diff --git a/python/mxnet/contrib/summary/proto/types.proto b/python/mxnet/contrib/summary/proto/types.proto
new file mode 100644
index 00000000000..71e8ed7800e
--- /dev/null
+++ b/python/mxnet/contrib/summary/proto/types.proto
@@ -0,0 +1,64 @@
+syntax = "proto3";
+
+package tensorboard;
+option cc_enable_arenas = true;
+option java_outer_classname = "TypesProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+// LINT.IfChange
+enum DataType {
+  // Not a legal value for DataType.  Used to indicate a DataType field
+  // has not been set.
+  DT_INVALID = 0;
+
+  // Data types that all computation devices are expected to be
+  // capable of supporting.
+  DT_FLOAT = 1;
+  DT_DOUBLE = 2;
+  DT_INT32 = 3;
+  DT_UINT8 = 4;
+  DT_INT16 = 5;
+  DT_INT8 = 6;
+  DT_STRING = 7;
+  DT_COMPLEX64 = 8;  // Single-precision complex
+  DT_INT64 = 9;
+  DT_BOOL = 10;
+  DT_QINT8 = 11;     // Quantized int8
+  DT_QUINT8 = 12;    // Quantized uint8
+  DT_QINT32 = 13;    // Quantized int32
+  DT_BFLOAT16 = 14;  // Float32 truncated to 16 bits.  Only for cast ops.
+  DT_QINT16 = 15;    // Quantized int16
+  DT_QUINT16 = 16;   // Quantized uint16
+  DT_UINT16 = 17;
+  DT_COMPLEX128 = 18;  // Double-precision complex
+  DT_HALF = 19;
+  DT_RESOURCE = 20;
+
+  // TODO(josh11b): DT_GENERIC_PROTO = ??;
+  // TODO(jeff,josh11b): DT_UINT64?  DT_UINT32?
+
+  // Do not use!  These are only for parameters.  Every enum above
+  // should have a corresponding value below (verified by types_test).
+  DT_FLOAT_REF = 101;
+  DT_DOUBLE_REF = 102;
+  DT_INT32_REF = 103;
+  DT_UINT8_REF = 104;
+  DT_INT16_REF = 105;
+  DT_INT8_REF = 106;
+  DT_STRING_REF = 107;
+  DT_COMPLEX64_REF = 108;
+  DT_INT64_REF = 109;
+  DT_BOOL_REF = 110;
+  DT_QINT8_REF = 111;
+  DT_QUINT8_REF = 112;
+  DT_QINT32_REF = 113;
+  DT_BFLOAT16_REF = 114;
+  DT_QINT16_REF = 115;
+  DT_QUINT16_REF = 116;
+  DT_UINT16_REF = 117;
+  DT_COMPLEX128_REF = 118;
+  DT_HALF_REF = 119;
+  DT_RESOURCE_REF = 120;
+}
+// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go)
diff --git a/python/mxnet/contrib/summary/proto/versions.proto b/python/mxnet/contrib/summary/proto/versions.proto
new file mode 100644
index 00000000000..fb001426ab3
--- /dev/null
+++ b/python/mxnet/contrib/summary/proto/versions.proto
@@ -0,0 +1,31 @@
+syntax = "proto3";
+
+package tensorboard;
+option cc_enable_arenas = true;
+option java_outer_classname = "VersionsProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+// Version information for a piece of serialized data
+//
+// There are different types of versions for each type of data
+// (GraphDef, etc.), but they all have the same common shape
+// described here.
+//
+// Each consumer has "consumer" and "min_producer" versions (specified
+// elsewhere).  A consumer is allowed to consume this data if
+//
+//   producer >= min_producer
+//   consumer >= min_consumer
+//   consumer not in bad_consumers
+//
+message VersionDef {
+  // The version of the code that produced this data.
+  int32 producer = 1;
+
+  // Any consumer below this version is not allowed to consume this data.
+  int32 min_consumer = 2;
+
+  // Specific consumer versions which are disallowed (e.g. due to bugs).
+  repeated int32 bad_consumers = 3;
+};
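The three compatibility conditions listed in the `VersionDef` comment translate directly into a predicate. A minimal sketch; the function and argument names are hypothetical, and plain integers stand in for the `VersionDef` fields and for the consumer's own versions.

```python
def can_consume(consumer, min_producer, producer, min_consumer, bad_consumers=()):
    """Return True if a consumer at version `consumer` that requires at least
    `min_producer` may read data whose VersionDef is
    (`producer`, `min_consumer`, `bad_consumers`)."""
    return (producer >= min_producer
            and consumer >= min_consumer
            and consumer not in bad_consumers)
```

All three checks must pass; a single entry in `bad_consumers` is enough to reject an otherwise compatible consumer.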
diff --git a/python/mxnet/contrib/summary/record_writer.py b/python/mxnet/contrib/summary/record_writer.py
new file mode 100644
index 00000000000..a73646603e9
--- /dev/null
+++ b/python/mxnet/contrib/summary/record_writer.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Writer for writing events to the event file."""
+
+import struct
+from .crc32c import crc32c
+
+
+class RecordWriter(object):
+    """Write records in the following format for a single record event_str:
+    uint64 len(event_str)
+    uint32 masked crc of len(event_str)
+    byte event_str
+    uint32 masked crc of event_str
+    The implementation is ported from
+    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/lib/io/record_writer.cc
+    Unlike TensorFlow, no separate `_dest` buffer is used: records are written directly to
+    the Python file object `self._writer`, so flushing and closing are controlled entirely
+    by this class. In TensorFlow, `_dest` is an instance of ZlibOutputBuffer (C++), which
+    has its own flush and close mechanism."""
+    def __init__(self, path):
+        self._writer = None
+        try:
+            self._writer = open(path, 'wb')
+        except (OSError, IOError) as e:
+            raise ValueError('failed to open file {}: {}'.format(path, str(e)))
+
+    def __del__(self):
+        self.close()
+
+    def write_record(self, event_str):
+        header = struct.pack('<Q', len(event_str))
+        header += struct.pack('<I', masked_crc32c(header))
+        footer = struct.pack('<I', masked_crc32c(event_str))
+        self._writer.write(header + event_str + footer)
+
+    def flush(self):
+        assert self._writer is not None
+        self._writer.flush()
+
+    def close(self):
+        if self._writer is not None:
+            self.flush()
+            self._writer.close()
+            self._writer = None
+
+
+def masked_crc32c(data):
+    """Copied from
+    https://github.com/TeamHG-Memex/tensorboard_logger/blob/master/tensorboard_logger/tensorboard_logger.py"""
+    x = u32(crc32c(data))
+    return u32(((x >> 15) | u32(x << 17)) + 0xa282ead8)
+
+
+def u32(x):
+    """Copied from
+    https://github.com/TeamHG-Memex/tensorboard_logger/blob/master/tensorboard_logger/tensorboard_logger.py"""
+    return x & 0xffffffff
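As a sanity check of the record framing that `RecordWriter.write_record` produces, the following standalone sketch re-implements CRC32C in pure Python (rather than importing the PR's `crc32c` module), frames payloads with explicit little-endian `struct` formats, and parses them back while verifying both masked checksums. The names `frame_record` and `read_records` are hypothetical helpers for illustration only; they are not part of `mxnet.contrib.summary`.

```python
import struct

# Reflected form of the Castagnoli polynomial used by CRC32C.
_POLY = 0x82F63B78


def _make_table():
    table = []
    for byte in range(256):
        crc = byte
        for _ in range(8):
            # Shift right one bit; XOR in the polynomial if the dropped bit was 1.
            crc = ((crc >> 1) ^ _POLY) if crc & 1 else (crc >> 1)
        table.append(crc)
    return table


_TABLE = _make_table()


def crc32c(data):
    """Table-driven CRC32C of a bytes object (init and final XOR 0xFFFFFFFF)."""
    crc = 0xFFFFFFFF
    for b in bytearray(data):
        crc = _TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc32c(data):
    """Rotate the CRC right by 15 bits and add a constant, modulo 2**32."""
    x = crc32c(data)
    return (((x >> 15) | (x << 17)) + 0xa282ead8) & 0xFFFFFFFF


def frame_record(payload):
    """Hypothetical helper: length, masked CRC of length, payload, masked CRC of payload."""
    header = struct.pack('<Q', len(payload))
    return (header + struct.pack('<I', masked_crc32c(header)) +
            payload + struct.pack('<I', masked_crc32c(payload)))


def read_records(buf):
    """Hypothetical helper: yield payloads from a byte string, checking both CRCs."""
    offset = 0
    while offset < len(buf):
        header = buf[offset:offset + 8]
        (length,) = struct.unpack('<Q', header)
        (header_crc,) = struct.unpack('<I', buf[offset + 8:offset + 12])
        if header_crc != masked_crc32c(header):
            raise ValueError('corrupted length field at offset %d' % offset)
        start = offset + 12
        payload = buf[start:start + length]
        (payload_crc,) = struct.unpack('<I', buf[start + length:start + length + 4])
        if payload_crc != masked_crc32c(payload):
            raise ValueError('corrupted payload at offset %d' % offset)
        yield payload
        offset = start + length + 4
```

The masking step (rotate right by 15 bits, then add `0xa282ead8`) mirrors `masked_crc32c` above; the byte-at-a-time CRC is slow in pure Python and is only meant for small test buffers.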
diff --git a/python/mxnet/contrib/summary/summary.py b/python/mxnet/contrib/summary/summary.py
new file mode 100644
index 00000000000..4456726c0a7
--- /dev/null
+++ b/python/mxnet/contrib/summary/summary.py
@@ -0,0 +1,338 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Functions for generating summary protocol buffers. Adapted from
+https://github.com/lanpa/tensorboard-pytorch/blob/master/tensorboardX/summary.py"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import io
+import wave
+import struct
+import re as _re
+import numpy as np
+from .proto.summary_pb2 import Summary
+from .proto.summary_pb2 import HistogramProto
+from .proto.summary_pb2 import SummaryMetadata
+from .proto.tensor_pb2 import TensorProto
+from .proto.tensor_shape_pb2 import TensorShapeProto
+from .proto.plugin_pr_curve_pb2 import PrCurvePluginData
+from .utils import _make_numpy_array, _prepare_image
+from ...ndarray import NDArray
+try:
+    from PIL import Image
+except ImportError:
+    Image = None
+
+
+_INVALID_TAG_CHARACTERS = _re.compile(r'[^-/\w\.]')
+
+
+def _clean_tag(name):
+    """Cleans a tag: replaces illegal characters with underscores and strips leading slashes.
+    Adapted from the TensorFlow function `clean_tag()` at
+    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/summary_op_util.py
+
+    Parameters
+    ----------
+        name : str
+            The original tag name to be processed.
+
+    Returns
+    -------
+        The cleaned tag name.
+    """
+    # In the past, the first argument to summary ops was a tag, which allowed
+    # arbitrary characters. Now we are changing the first argument to be the node
+    # name. This has a number of advantages (users of summary ops now can
+    # take advantage of the tf name scope system) but risks breaking existing
+    # usage, because a much smaller set of characters are allowed in node names.
+    # This function replaces all illegal characters with _s, and logs a warning.
+    # It also strips leading slashes from the name.
+    if name is not None:
+        new_name = _INVALID_TAG_CHARACTERS.sub('_', name)
+        new_name = new_name.lstrip('/')  # Remove leading slashes
+        if new_name != name:
+            logging.warning('Summary name %s is illegal; using %s instead.', name, new_name)
+            name = new_name
+    return name
+
+
+def scalar_summary(tag, scalar):
+    """Outputs a `Summary` protocol buffer containing a single scalar value.
+    The generated Summary has a Tensor.proto containing the input Tensor.
+    Adapted from the TensorFlow function `scalar()` at
+    https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/python/summary/summary.py
+
+    Parameters
+    ----------
+      tag : str
+          A name for the generated summary. Will also serve as the series name in TensorBoard.
+      scalar : int, float, MXNet `NDArray`, or `numpy.ndarray`
+          A scalar value or an array containing exactly one element, e.g. of shape (1,).
+
+    Returns
+    -------
+      A `Summary` protobuf of the `scalar` value.
+
+    Raises
+    ------
+      ValueError: If the scalar has the wrong shape or type.
+    """
+    tag = _clean_tag(tag)
+    scalar = _make_numpy_array(scalar)
+    if scalar.squeeze().ndim != 0:
+        raise ValueError('scalar should be 0D, received shape {}'.format(scalar.shape))
+    scalar = float(scalar.squeeze())
+    return Summary(value=[Summary.Value(tag=tag, simple_value=scalar)])
+
+
+def histogram_summary(tag, values, bins):
+    """Outputs a `Summary` protocol buffer with a histogram.
+    Adding a histogram summary makes it possible to visualize the data's distribution in
+    TensorBoard. See detailed explanation of the TensorBoard histogram dashboard at
+    https://www.tensorflow.org/get_started/tensorboard_histograms
+    All values must be finite; otherwise `numpy.histogram` may raise a `ValueError`.
+    Adapted from the TensorFlow function `histogram()` at
+    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/summary/summary.py
+
+    Parameters
+    ----------
+        tag : str
+            A name for the summary of the histogram. Will also serve as a series name in
+            TensorBoard.
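+        values : MXNet `NDArray` or `numpy.ndarray`
+            Values for building the histogram.
+        bins : int or sequence of scalars or str
+            Passed to `numpy.histogram`: the number of equal-width bins, the bin edges, or
+            the name of a binning strategy such as 'auto'.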
+
+    Returns
+    -------
+        A `Summary` protobuf of the histogram.
+    """
+    tag = _clean_tag(tag)
+    values = _make_numpy_array(values)
+    hist = _make_histogram(values.astype(float), bins)
+    return Summary(value=[Summary.Value(tag=tag, histo=hist)])
+
+
+def _make_histogram(values, bins):
+    """Converts values into a histogram proto using logic from
+    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/lib/histogram/histogram.cc"""
+    values = values.reshape(-1)
+    counts, limits = np.histogram(values, bins=bins)
+    limits = limits[1:]
+
+    sum_sq = values.dot(values)
+    return HistogramProto(min=values.min(),
+                          max=values.max(),
+                          num=len(values),
+                          sum=values.sum(),
+                          sum_squares=sum_sq,
+                          bucket_limit=limits,
+                          bucket=counts)
+
+
+def image_summary(tag, image):
+    """Outputs a `Summary` protocol buffer with image(s).
+
+    Parameters
+    ----------
+        tag : str
+            A name for the generated summary. Will also serve as a series name in TensorBoard.
+        image : MXNet `NDArray`
+            Image data in one of the following layouts: (H, W), (C, H, W), or (N, C, H, W).
+            `uint8` pixel values are used as-is. Floating-point pixel values are rescaled to the
+            range [0, 255] and cast to `np.uint8` before the image protobuf is created. A batch
+            of images is tiled into a single grid image.
+
+    Returns
+    -------
+        A `Summary` protobuf of the image.
+    """
+    tag = _clean_tag(tag)
+    image = _prepare_image(image)
+    image = _make_image(image)
+    return Summary(value=[Summary.Value(tag=tag, image=image)])
+
+
+def _make_image(tensor):
+    """Converts an NDArray type image to Image protobuf"""
+    assert isinstance(tensor, NDArray)
+    if Image is None:
+        raise ImportError('need to install PIL for visualizing images')
+    height, width, channel = tensor.shape
+    tensor = _make_numpy_array(tensor)
+    image = Image.fromarray(tensor)
+    output = io.BytesIO()
+    image.save(output, format='PNG')
+    image_string = output.getvalue()
+    output.close()
+    return Summary.Image(height=height, width=width, colorspace=channel,
+                         encoded_image_string=image_string)
+
+
+def audio_summary(tag, audio, sample_rate=44100):
+    """Outputs a `Summary` protocol buffer with audio data.
+
+    Parameters
+    ----------
+        tag : str
+            A name for the generated summary. Will also serve as a series name in TensorBoard.
+        audio : MXNet `NDArray` or `numpy.ndarray`
+            Audio data that can be squeezed into a 1D array, with values in the range [-1, 1].
+        sample_rate : int
+            Sampling frequency in Hz. 44,100 Hz is a common sampling frequency.
+
+    Returns
+    -------
+        A `Summary` protobuf of the audio data.
+    """
+    tag = _clean_tag(tag)
+    audio = audio.squeeze()
+    if audio.ndim != 1:
+        raise ValueError('input audio must be squeezable to 1D, input audio squeezed '
+                         'shape is {}'.format(audio.shape))
+    audio = _make_numpy_array(audio)
+    # Convert samples in [-1, 1] to 16-bit signed PCM values.
+    tensor_list = [int(32767.0 * x) for x in audio]
+    fio = io.BytesIO()
+    wave_writer = wave.open(fio, 'wb')
+    wave_writer.setnchannels(1)
+    wave_writer.setsampwidth(2)
+    wave_writer.setframerate(sample_rate)
+    # Pack all samples in one call as little-endian int16 values.
+    tensor_enc = struct.pack('<{}h'.format(len(tensor_list)), *tensor_list)
+    wave_writer.writeframes(tensor_enc)
+    wave_writer.close()
+    audio_string = fio.getvalue()
+    fio.close()
+    audio = Summary.Audio(sample_rate=sample_rate,
+                          num_channels=1,
+                          length_frames=len(tensor_list),
+                          encoded_audio_string=audio_string,
+                          content_type='audio/wav')
+    return Summary(value=[Summary.Value(tag=tag, audio=audio)])
+
+
+def text_summary(tag, text):
+    """Outputs a `Summary` protocol buffer with text data.
+
+    Parameters
+    ----------
+        tag : str
+            A name for the generated summary. Will also serve as a series name in TensorBoard.
+        text : str
+            Text data.
+
+    Returns
+    -------
+        A `Summary` protobuf of the text data.
+    """
+    plugin_data = [SummaryMetadata.PluginData(plugin_name='text')]
+    smd = SummaryMetadata(plugin_data=plugin_data)
+    tensor = TensorProto(dtype='DT_STRING',
+                         string_val=[text.encode(encoding='utf_8')],
+                         tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]))
+    return Summary(value=[Summary.Value(node_name=tag, metadata=smd, tensor=tensor)])
+
+
+def pr_curve_summary(tag, labels, predictions, num_thresholds, weights=None):
+    """Outputs a precision-recall curve `Summary` protocol buffer.
+
+    Parameters
+    ----------
+        tag : str
+            A tag attached to the summary. Used by TensorBoard for organization.
+        labels : MXNet `NDArray` or `numpy.ndarray`
+            The ground truth values. A tensor of 0/1 values with arbitrary shape.
+        predictions : MXNet `NDArray` or `numpy.ndarray`
+            A float32 tensor whose values are in the range `[0, 1]`. Dimensions must
+            match those of `labels`.
+        num_thresholds : int
+            Number of thresholds, evenly distributed in `[0, 1]`, to compute PR metrics for.
+            Should be `>= 2`. This value should be a constant integer value, not a tensor
+            that stores an integer. Values greater than 127 are clipped to 127.
+            The thresholds for computing the PR curves are calculated as follows:
+            `width = 1.0 / (num_thresholds - 1)`,
+            `thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]`.
+        weights : MXNet `NDArray` or `numpy.ndarray`
+            Optional float32 tensor. Individual counts are multiplied by this value.
+            This tensor must be either the same shape as or broadcastable to the `labels` tensor.
+
+    Returns
+    -------
+        A `Summary` protobuf of the pr_curve.
+    """
+    # Creating the protobuf fails when num_thresholds > 127,
+    # possibly because of a bug in protobuf.
+    if num_thresholds > 127:
+        logging.warning('num_thresholds>127 would result in failure of creating pr_curve protobuf,'
+                        ' clipping it at 127')
+        num_thresholds = 127
+    tag = _clean_tag(tag)
+    labels = _make_numpy_array(labels)
+    predictions = _make_numpy_array(predictions)
+    if weights is not None:
+        weights = _make_numpy_array(weights)
+    data = _compute_curve(labels, predictions, num_thresholds=num_thresholds, weights=weights)
+    pr_curve_plugin_data = PrCurvePluginData(version=0,
+                                             num_thresholds=num_thresholds).SerializeToString()
+    plugin_data = [SummaryMetadata.PluginData(plugin_name='pr_curves',
+                                              content=pr_curve_plugin_data)]
+    smd = SummaryMetadata(plugin_data=plugin_data)
+    tensor = TensorProto(dtype='DT_FLOAT',
+                         float_val=data.reshape(-1).tolist(),
+                         tensor_shape=TensorShapeProto(
+                             dim=[TensorShapeProto.Dim(size=data.shape[0]),
+                                  TensorShapeProto.Dim(size=data.shape[1])]))
+    return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
+
+
+# A value that we use as the minimum value during division of counts to prevent
+# division by 0. 1.0 does not work: Certain weights could cause counts below 1.
+_MINIMUM_COUNT = 1e-7
+
+
+def _compute_curve(labels, predictions, num_thresholds, weights=None):
+    """Computes the rows tp, fp, tn, fn, precision, and recall of a PR curve with NumPy.
+    This is a reimplementation of the logic in
+    https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py"""
+
+    if weights is None:
+        weights = 1.0
+
+    # Compute bins of true positives and false positives.
+    bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
+    float_labels = labels.astype(np.float64)
+    histogram_range = (0, num_thresholds - 1)
+    tp_buckets, _ = np.histogram(
+        bucket_indices,
+        bins=num_thresholds,
+        range=histogram_range,
+        weights=float_labels * weights)
+    fp_buckets, _ = np.histogram(
+        bucket_indices,
+        bins=num_thresholds,
+        range=histogram_range,
+        weights=(1.0 - float_labels) * weights)
+
+    # Obtain the reverse cumulative sum.
+    tp = np.cumsum(tp_buckets[::-1])[::-1]
+    fp = np.cumsum(fp_buckets[::-1])[::-1]
+    tn = fp[0] - fp
+    fn = tp[0] - tp
+    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
+    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
+    return np.stack((tp, fp, tn, fn, precision, recall))
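The threshold bucketing in `_compute_curve` is easier to follow on a tiny input. The sketch below is a standalone NumPy restatement of the same steps; the function name `compute_pr_curve` is hypothetical, it accepts plain Python lists, and it uses `np.float64` rather than the `np.float` alias, which NumPy removed in 1.24. With `num_thresholds=3` the thresholds are 0.0, 0.5, and 1.0, and column `i` of each row counts predictions at or above threshold `i`.

```python
import numpy as np

# Same guard value as `_MINIMUM_COUNT` above, to avoid division by zero.
_MINIMUM_COUNT = 1e-7


def compute_pr_curve(labels, predictions, num_thresholds, weights=1.0):
    """Hypothetical helper returning a (6, num_thresholds) array whose rows are
    tp, fp, tn, fn, precision, recall for thresholds 0, 1/(n-1), ..., 1."""
    labels = np.asarray(labels, dtype=np.float64)
    predictions = np.asarray(predictions, dtype=np.float64)
    # Bucket i collects predictions p with i/(n-1) <= p < (i+1)/(n-1); p == 1.0 lands in the last.
    bucket_indices = np.floor(predictions * (num_thresholds - 1)).astype(np.int32)
    # With `num_thresholds` bins over [0, n-1], integer index i falls into bin i.
    hist_range = (0, num_thresholds - 1)
    tp_buckets, _ = np.histogram(bucket_indices, bins=num_thresholds,
                                 range=hist_range, weights=labels * weights)
    fp_buckets, _ = np.histogram(bucket_indices, bins=num_thresholds,
                                 range=hist_range, weights=(1.0 - labels) * weights)
    # Reverse cumulative sums: weighted counts of predictions at or above each threshold.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    tn = fp[0] - fp
    fn = tp[0] - tp
    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
    return np.stack((tp, fp, tn, fn, precision, recall))


data = compute_pr_curve([1, 0, 1, 0], [0.9, 0.8, 0.3, 0.1], num_thresholds=3)
```

For this toy input, at threshold 0.5 the predictions 0.9 (positive) and 0.8 (negative) pass, giving one true positive and one false positive, so precision and recall are both 0.5. At threshold 1.0 nothing passes, and the `_MINIMUM_COUNT` guard turns precision into 0 instead of a division by zero.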
diff --git a/python/mxnet/contrib/summary/utils.py b/python/mxnet/contrib/summary/utils.py
new file mode 100644
index 00000000000..22d1adf8588
--- /dev/null
+++ b/python/mxnet/contrib/summary/utils.py
@@ -0,0 +1,276 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Util functions for writing summaries."""
+
+import os
+import logging
+import numpy as np
+
+try:
+    from PIL import Image
+except ImportError:
+    Image = None
+from ...ndarray import NDArray
+from ...ndarray import ndarray as nd
+from ...ndarray import op
+
+
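+def _make_numpy_array(x):
+    """Converts a `numpy.ndarray`, a scalar, or an MXNet `NDArray` into a `numpy.ndarray`.
+    Scalars become arrays of shape (1,)."""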
+    if isinstance(x, np.ndarray):
+        return x
+    elif np.isscalar(x):
+        return np.array([x])
+    elif isinstance(x, NDArray):
+        return x.asnumpy()
+    else:
+        raise TypeError('_make_numpy_array only accepts input types of numpy.ndarray, scalar,'
+                        ' and MXNet NDArray, while received type {}'.format(str(type(x))))
+
+
+def make_image_grid(tensor, nrow=8, padding=2, normalize=False, norm_range=None,
+                    scale_each=False, pad_value=0):
+    """Make a grid of images. This is an MXNet version of torchvision.utils.make_grid
+    Ref: https://github.com/pytorch/vision/blob/master/torchvision/utils.py
+
+    Parameters
+    ----------
+        tensor : `NDArray` or list of `NDArray`s
+            Input image(s) in the format of HW, CHW, or NCHW.
+        nrow : int
+            Number of images displayed in each row of the grid. The Final grid size is
+            (batch_size / `nrow`, `nrow`).
+        padding : int
+            Padding value for each image in the grid.
+        normalize : bool
+            If True, shift the image to the range (0, 1), by subtracting the
+            minimum and dividing by the maximum pixel value.
+        norm_range : tuple
+            Tuple of (min, max) where min and max are numbers. These numbers are used
+            to normalize the image. By default, `min` and `max` are computed from the `tensor`.
+        scale_each : bool
+            If True, scale each image in the batch of images separately rather than the
+            `(min, max)` over all images.
+        pad_value : float
+            Value for the padded pixels.
+
+    Returns
+    -------
+    NDArray
+        An image grid in the format CHW made of the input images.
+    """
+    if not (isinstance(tensor, NDArray) or
+            (isinstance(tensor, list) and all(isinstance(t, NDArray) for t in tensor))):
+        raise TypeError('MXNet NDArray or list of NDArrays expected, got {}'.format(
+            str(type(tensor))))
+
+    # if list of tensors, convert to a 4D mini-batch Tensor
+    if isinstance(tensor, list):
+        tensor = op.stack(tensor, axis=0)
+
+    if tensor.ndim <= 1 or tensor.ndim > 4:
+        raise ValueError('expected 2D, 3D, or 4D NDArrays, while received ndim={}'.format(
+            tensor.ndim))
+
+    if tensor.ndim == 2:  # single image H x W
+        tensor = tensor.reshape(((1,) + tensor.shape))
+    if tensor.ndim == 3:  # single image
+        if tensor.shape[0] == 1:  # if single-channel, convert to 3-channel
+            tensor = op.concat(*(tensor, tensor, tensor), dim=0)
+        tensor = tensor.reshape((1,) + tensor.shape)
+    if tensor.ndim == 4 and tensor.shape[1] == 1:  # single-channel images
+        tensor = op.concat(*(tensor, tensor, tensor), dim=1)
+
+    if normalize is True:
+        tensor = tensor.copy()  # avoid modifying tensor in-place
+        if norm_range is not None:
+            assert isinstance(norm_range, tuple) and len(norm_range) == 2, \
+                "norm_range has to be a tuple (min, max) if specified. min and max are numbers"
+
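+        def norm_ip(img, val_min, val_max):
+            # Clip to [val_min, val_max], then rescale to [0, 1] in place.
+            op.clip(img, a_min=val_min, a_max=val_max, out=img)
+            img -= val_min
+            img /= max(val_max - val_min, 1e-5)  # guard against constant images
+
+        def norm_range_helper(t, val_range):
+            if val_range is not None:
+                norm_ip(t, val_range[0], val_range[1])
+            else:
+                # op.clip expects Python numbers, so convert the NDArray results.
+                norm_ip(t, t.min().asscalar(), t.max().asscalar())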
+
+        if scale_each is True:
+            for t in tensor:  # loop over mini-batch dimension
+                norm_range_helper(t, norm_range)
+        else:
+            norm_range_helper(tensor, norm_range)
+
+    # if single image, just return
+    if tensor.shape[0] == 1:
+        return tensor.squeeze(axis=0)
+
+    # make the batch of images into a grid
+    nmaps = tensor.shape[0]
+    xmaps = min(nrow, nmaps)
+    ymaps = int(np.ceil(float(nmaps) / xmaps))
+    height, width = int(tensor.shape[2] + padding), int(tensor.shape[3] + padding)
+    grid = nd.empty(shape=(3, height * ymaps + padding, width * xmaps + padding),
+                    dtype=tensor.dtype, ctx=tensor.context)
+    grid[:] = pad_value
+    k = 0
+    for y in range(ymaps):
+        for x in range(xmaps):
+            if k >= nmaps:
+                break
+            start1 = y * height + padding
+            end1 = start1 + height - padding
+            start2 = x * width + padding
+            end2 = start2 + width - padding
+            grid[:, start1:end1, start2:end2] = tensor[k]
+            k = k + 1
+    return grid
+
+
+def _save_image(image, filename, nrow=8, padding=2):
+    """Saves a given NDArray to an image file. If the input NDArray contains multiple images,
+    a grid of images will be saved.
+
+    Parameters
+    ----------
+        image : `NDArray`
+            Input image(s) in the format of HW, CHW, or NCHW.
+        filename : str
+            Filename of the saved image(s).
+        nrow : int
+            Number of images displayed in each row of the grid. The grid has
+            `min(nrow, batch_size)` columns and `ceil(batch_size / nrow)` rows.
+        padding : int
+            Number of padding pixels placed around each image in the grid.
+    """
+    if not isinstance(image, NDArray):
+        raise TypeError('MXNet NDArray expected, received {}'.format(str(type(image))))
+    if Image is None:
+        raise ImportError('saving image failed because PIL is not found')
+    image = _prepare_image(image, nrow=nrow, padding=padding)
+    im = Image.fromarray(image.asnumpy())
+    im.save(filename)
+
+
+def _prepare_image(img, nrow=8, padding=2):
+    """Given an image of format HW, CHW, or NCHW, returns an image of format HWC.
+    If the input is a batch of images, a grid of images is made by stitching them together.
+    For float input data types, the values are normalized to fit in the range `[0, 255]`
+    using the minimum and maximum over the whole input. `uint8` values are unchanged.
+    The following two normalization rules are used for different conditions:
+    1. If the input values are all non-negative, they are rescaled so that 0.0 maps to 0
+       and the largest value maps to 255.
+    2. If any input value is negative, they are rescaled so that the smallest value maps
+       to 0 and the largest value maps to 255.
+    This logic is adapted from the `image()` function in
+    https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/python/summary/summary.py
+    It returns the image as an `NDArray` with the color channel as the last dimension.
+    """
+    assert img.ndim == 2 or img.ndim == 3 or img.ndim == 4
+    if isinstance(img, NDArray):
+        if img.dtype == np.uint8:
+            return make_image_grid(img, nrow=nrow, padding=padding).transpose((1, 2, 0))
+        elif img.dtype == np.float16 or img.dtype == np.float32 or img.dtype == np.float64:
+            min_val = img.min().asscalar()
+            max_val = img.max().asscalar()
+            if min_val >= 0:
+                min_val = 0.0
+            else:
+                min_val += 127.0
+                max_val += 127.0
+                img = img + 127.0
+            return (make_image_grid(img, nrow=nrow, padding=padding, normalize=True,
+                                    norm_range=(min_val, max_val),
+                                    scale_each=True) * 255.0).astype(np.uint8).transpose((1, 2, 0))
+        else:
+            raise ValueError('expected input image dtype is one of uint8, float16, float32, '
+                             'and float64, received dtype {}'.format(str(img.dtype)))
+    else:
+        raise TypeError('expected MXNet NDArray, while received type {}'.format(str(type(img))))
+
+
+def _make_metadata_tsv(metadata, save_path):
+    """Given an `NDArray` or a `numpy.ndarray` as metadata e.g. labels, save the flattened array
+    into the file metadata.tsv under the path provided by the user. Made to satisfy the requirement
+    in the following link:
+    https://www.tensorflow.org/programmers_guide/embedding#metadata"""
+    if isinstance(metadata, NDArray):
+        metadata = metadata.asnumpy().flatten()
+    elif isinstance(metadata, np.ndarray):
+        metadata = metadata.flatten()
+    else:
+        raise TypeError('expected NDArray or np.ndarray, while received '
+                        'type {}'.format(str(type(metadata))))
+    metadata = [str(x) for x in metadata]
+    with open(os.path.join(save_path, 'metadata.tsv'), 'w') as f:
+        for x in metadata:
+            f.write(x + '\n')
+
+
+def _make_sprite_image(images, save_path):
+    """Given an NDArray containing a batch of images, make a sprite image out of it following
+    the rule defined in
+    https://www.tensorflow.org/programmers_guide/embedding
+    and save it as sprite.png under the path provided by the user."""
+    assert isinstance(images, NDArray)
+    shape = images.shape
+    nrow = int(np.ceil(np.sqrt(shape[0])))
+    _save_image(images, os.path.join(save_path, 'sprite.png'), nrow=nrow, padding=0)
+
+
+def _add_embedding_config(file_path, global_step, has_metadata=False,
+                          label_img_shape=None, tag='default'):
+    """Creates a config file used by the embedding projector.
+    Adapted from the TensorFlow function `visualize_embeddings()` at
+    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/tensorboard/plugins/projector/__init__.py"""
+    with open(os.path.join(file_path, 'projector_config.pbtxt'), 'a') as f:
+        s = 'embeddings {\n'
+        s += 'tensor_name: "{}:{}"\n'.format(tag, global_step)
+        s += 'tensor_path: "{}"\n'.format(os.path.join(global_step, 'tensors.tsv'))
+        if has_metadata:
+            s += 'metadata_path: "{}"\n'.format(os.path.join(global_step, 'metadata.tsv'))
+        if label_img_shape is not None:
+            if len(label_img_shape) != 4:
+                logging.warning('expected 4D sprite image in the format NCHW, while received image'
+                                ' ndim=%d, skipping saving sprite'
+                                ' image info', len(label_img_shape))
+            else:
+                s += 'sprite {\n'
+                s += 'image_path: "{}"\n'.format(os.path.join(global_step, 'sprite.png'))
+                s += 'single_image_dim: {}\n'.format(label_img_shape[3])
+                s += 'single_image_dim: {}\n'.format(label_img_shape[2])
+                s += '}\n'
+        s += '}\n'
+        f.write(s)
+
+
+def _save_embedding_tsv(data, file_path):
+    """Given a 2D `NDArray` or a `numpy.ndarray` as embedding,
+    save it in tensors.tsv under the path provided by the user."""
+    if isinstance(data, np.ndarray):
+        data_list = data.tolist()
+    elif isinstance(data, NDArray):
+        data_list = data.asnumpy().tolist()
+    else:
+        raise TypeError('expected NDArray or np.ndarray, while received type {}'.format(
+            str(type(data))))
+    with open(os.path.join(file_path, 'tensors.tsv'), 'w') as f:
+        for x in data_list:
+            x = [str(i) for i in x]
+            f.write('\t'.join(x) + '\n')
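The index arithmetic in `make_image_grid` (cells of size `(H + padding) x (W + padding)`, a `padding`-pixel border on every side of each image, row-major placement) can be exercised without MXNet. The following is a NumPy-only sketch of that layout; `make_grid_np` is a hypothetical helper, it assumes the input is already a 4D `(N, C, H, W)` array, and it skips the normalization and single-channel expansion steps of the original.

```python
import numpy as np


def make_grid_np(images, nrow=8, padding=2, pad_value=0.0):
    """Hypothetical helper: tile an (N, C, H, W) array into a
    (C, rows * (H + padding) + padding, cols * (W + padding) + padding) grid."""
    nmaps, channels, h, w = images.shape
    xmaps = min(nrow, nmaps)                      # images per row (columns)
    ymaps = int(np.ceil(float(nmaps) / xmaps))    # number of rows
    cell_h, cell_w = h + padding, w + padding
    grid = np.full((channels, cell_h * ymaps + padding, cell_w * xmaps + padding),
                   pad_value, dtype=images.dtype)
    for k in range(nmaps):
        y, x = divmod(k, xmaps)                   # row-major placement, as in the loop above
        top = y * cell_h + padding
        left = x * cell_w + padding
        grid[:, top:top + h, left:left + w] = images[k]
    return grid


# Five 3x4x4 images filled with the values 1..5, two per row, 1 pixel of padding.
imgs = np.stack([np.full((3, 4, 4), k + 1.0) for k in range(5)])
g = make_grid_np(imgs, nrow=2, padding=1)
```

Cells left over in the last row keep `pad_value`. With `padding=0` and `nrow = ceil(sqrt(N))` the same layout gives the square-ish sprite sheet that `_make_sprite_image` requests.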
diff --git a/python/mxnet/contrib/summary/writer.py b/python/mxnet/contrib/summary/writer.py
new file mode 100644
index 00000000000..e33dd5da212
--- /dev/null
+++ b/python/mxnet/contrib/summary/writer.py
@@ -0,0 +1,477 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""APIs for logging data in the event file."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+import json
+import os
+import logging
+from .proto import event_pb2
+from .proto import summary_pb2
+from .event_file_writer import EventFileWriter
+from .summary import scalar_summary, histogram_summary, image_summary, audio_summary
+from .summary import text_summary, pr_curve_summary
+from .utils import _save_embedding_tsv, _make_sprite_image, _make_metadata_tsv
+from .utils import _add_embedding_config, _make_numpy_array
+
+
+class SummaryToEventTransformer(object):
+    """This class is adapted with minor modifications from
+    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/summary/writer/writer.py#L125
+    Users should not use this class directly for logging MXNet data.
+    This class implements the `add_summary` method of the SummaryWriter API.
+    `add_summary` wraps a Summary object in an Event protobuf and passes the
+    Event protobuf to `_event_writer`, which is of type EventFileWriter, for logging.
+    """
+    # TODO(junwu): Need to check its compatibility with using ONNX for visualizing MXNet graphs.
+    def __init__(self, event_writer):
+        """Initializes the _event_writer with the passed-in value.
+
+        Parameters
+        ----------
+          event_writer: EventFileWriter
+              An event file writer writing events to the files in the path `logdir`.
+        """
+        self._event_writer = event_writer
+        # This set contains tags of Summary Values that have been encountered
+        # already. The motivation here is that the SummaryWriter only keeps the
+        # metadata property (which is a SummaryMetadata proto) of the first Summary
+        # Value encountered for each tag. The SummaryWriter strips away the
+        # SummaryMetadata for all subsequent Summary Values with tags seen
+        # previously. This saves space.
+        self._seen_summary_tags = set()
+
+    def add_summary(self, summary, global_step=None):
+        """Adds a `Summary` protocol buffer to the event file.
+        This method wraps the provided summary in an `Event` protocol buffer and adds it
+        to the event file.
+
+        Parameters
+        ----------
+          summary : A `Summary` protocol buffer
+              Optionally serialized as a string.
+          global_step: Number
+              Optional global step value to record with the summary.
+        """
+        if isinstance(summary, bytes):
+            summ = summary_pb2.Summary()
+            summ.ParseFromString(summary)
+            summary = summ
+
+        # We strip metadata from values with tags that we have seen before in order
+        # to save space - we just store the metadata on the first value with a
+        # specific tag.
+        for value in summary.value:
+            if not value.metadata:
+                continue
+
+            if value.tag in self._seen_summary_tags:
+                # This tag has been encountered before. Strip the metadata.
+                value.ClearField("metadata")
+                continue
+
+            # We encounter a value with a tag we have not encountered previously. And
+            # it has metadata. Remember to strip metadata from future values with this
+            # tag string.
+            self._seen_summary_tags.add(value.tag)
+
+        event = event_pb2.Event(summary=summary)
+        self._add_event(event, global_step)
+
+    def _add_event(self, event, step):
+        event.wall_time = time.time()
+        if step is not None:
+            event.step = int(step)
+        self._event_writer.add_event(event)
+
+
+class FileWriter(SummaryToEventTransformer):
+    """This class is adapted from
+    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/summary/writer/writer.py.
+    Even though this class provides user-level APIs in TensorFlow, it is recommended to use the
+    interfaces defined in the class `SummaryWriter` (see below) for logging in MXNet as they are
+    directly compatible with the MXNet NDArray type.
+    This class writes `Summary` protocol buffers to event files. The `FileWriter` class provides
+    a mechanism to create an event file in a given directory and add summaries and events to it.
+    The class updates the file contents asynchronously.
+    """
+    def __init__(self, logdir, max_queue=10, flush_secs=120, filename_suffix=None):
+        """Creates a `FileWriter` and an event file.
+        On construction the summary writer creates a new event file in `logdir`.
+        This event file will contain `Event` protocol buffers constructed when you
+        call one of the following functions: `add_summary()` or `add_event()`.
+
+        Parameters
+        ----------
+            logdir : str
+                Directory where event file will be written.
+            max_queue : int
+                Size of the queue for pending events and summaries.
+            flush_secs: Number
+                How often, in seconds, to flush the pending events and summaries to disk.
+            filename_suffix : str
+                Every event file's name is suffixed with `filename_suffix` if provided.
+        """
+        event_writer = EventFileWriter(logdir, max_queue, flush_secs, filename_suffix)
+        super(FileWriter, self).__init__(event_writer)
+
+    def __enter__(self):
+        """Make usable with "with" statement."""
+        return self
+
+    def __exit__(self, unused_type, unused_value, unused_traceback):
+        """Make usable with "with" statement."""
+        self.close()
+
+    def get_logdir(self):
+        """Returns the directory where event file will be written."""
+        return self._event_writer.get_logdir()
+
+    def add_event(self, event):
+        """Adds an event to the event file.
+
+        Parameters
+        ----------
+            event : An `Event` protocol buffer.
+        """
+        self._event_writer.add_event(event)
+
+    def flush(self):
+        """Flushes the event file to disk.
+        Call this method to make sure that all pending events have been written to disk.
+        """
+        self._event_writer.flush()
+
+    def close(self):
+        """Flushes the event file to disk and closes the file.
+        Call this method when you do not need the summary writer anymore.
+        """
+        self._event_writer.close()
+
+    def reopen(self):
+        """Reopens the EventFileWriter.
+        Can be called after `close()` to add more events in the same directory.
+        The events will go into a new events file. Does nothing if the EventFileWriter
+        was not closed.
+        """
+        self._event_writer.reopen()
+
+
+class SummaryWriter(object):
+    """This class is adapted with modifications in support of the MXNet NDArray types from
+    https://github.com/lanpa/tensorboard-pytorch/blob/master/tensorboardX/writer.py.
+    The `SummaryWriter` class provides a high-level API to create an event file in a
+    given directory and add summaries and events to it. This class writes data to the
+    event file asynchronously.
+    This class is a wrapper of the FileWriter class. It's recommended that users use
+    the APIs of this class to log MXNet data for visualization as they are directly compatible with
+    the MXNet data types.
+
+    Examples
+    --------
+    >>> data = mx.nd.random.uniform(size=(10, 10))
+    >>> with SummaryWriter(logdir='logs') as sw:
+    ...     sw.add_histogram(tag='my_hist', values=data, global_step=0, bins=100)
+    """
+    def __init__(self, logdir, max_queue=10, flush_secs=120, filename_suffix=None):
+        """
+        Creates a `SummaryWriter` and an event file.
+        On construction the summary writer creates a new event file in `logdir`.
+        This event file will contain `Event` protocol buffers constructed when you
+        call one of the following functions: `add_audio()`, `add_embedding()`,
+        `add_histogram()`, `add_image()`, `add_pr_curve()`, `add_scalar()`, and `add_text()`.
+        Please make sure that the `logdir` used here for initializing `SummaryWriter`
+        matches the `--logdir` parameter you passed to the `tensorboard` binary in the command line
+        for launching TensorBoard.
+
+        Parameters
+        ----------
+            logdir : str
+                Directory where event file will be written.
+            max_queue : int
+                Size of the queue for pending events and summaries.
+            flush_secs: Number
+                How often, in seconds, to flush the pending events and summaries to disk.
+            filename_suffix : str
+                Every event file's name is suffixed with `filename_suffix` if provided.
+        """
+        self._file_writer = FileWriter(logdir=logdir, max_queue=max_queue,
+                                       flush_secs=flush_secs, filename_suffix=filename_suffix)
+        self._default_bins = None
+        self._text_tags = []
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+    def _get_default_bins(self):
+        """Ported from the C++ function InitDefaultBucketsInner() in the following file.
+        https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/lib/histogram/histogram.cc
+        See the following tutorial for more details on how TensorFlow initializes its bins.
+        https://www.tensorflow.org/programmers_guide/tensorboard_histograms"""
+        if self._default_bins is None:
+            v = 1E-12
+            buckets = []
+            neg_buckets = []
+            while v < 1E20:
+                buckets.append(v)
+                neg_buckets.append(-v)
+                v *= 1.1
+            self._default_bins = neg_buckets[::-1] + [0] + buckets
+        return self._default_bins
+
+    def get_logdir(self):
+        """Returns the logging directory associated with this `SummaryWriter`."""
+        return self._file_writer.get_logdir()
+
+    def add_scalar(self, tag, value, global_step=None):
+        """Adds scalar data to the event file.
+
+        Parameters
+        ----------
+            tag : str
+                Name for the `value`.
+            value : float
+                Value to be saved.
+            global_step : int
+                Global step value to record.
+        """
+        self._file_writer.add_summary(scalar_summary(tag, value), global_step)
+
+    def add_histogram(self, tag, values, global_step=None, bins='default'):
+        """Add histogram data to the event file.
+
+        Note: This function internally calls `asnumpy()` if `values` is an MXNet NDArray.
+        Since `asnumpy()` is a blocking function call, this function would block the main
+        thread till it returns. It may consequently affect the performance of async execution
+        of the MXNet engine.
+
+        Parameters
+        ----------
+            tag : str
+                Name for the `values`.
+            values : MXNet `NDArray` or `numpy.ndarray`
+                Values for building histogram.
+            global_step : int
+                Global step value to record.
+            bins : int or sequence of scalars or str
+                If `bins` is an int, it defines the number of equal-width bins in the range
+                `(values.min(), values.max())`.
+                If `bins` is a sequence, it defines the bin edges, including the rightmost edge,
+                allowing for non-uniform bin width.
+                If `bins` is a str equal to 'default', it will use the bin distribution
+                defined in TensorFlow for building histogram.
+                Ref: https://www.tensorflow.org/programmers_guide/tensorboard_histograms
+                Other strings accepted by `numpy.histogram`, such as 'auto', 'fd', 'doane',
+                'scott', 'rice', 'sturges', and 'sqrt', are also supported. See its documentation
+                for detailed definitions of those strings.
+                https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
+        """
+        if bins == 'default':
+            bins = self._get_default_bins()
+        self._file_writer.add_summary(histogram_summary(tag, values, bins), global_step)
+
+    def add_image(self, tag, image, global_step=None):
+        """Add image data to the event file.
+        This function supports input as a 2D, 3D, or 4D image.
+        If the input image is 2D, a channel axis is prepended as the first dimension
+        and the image is replicated three times and concatenated along the channel axis.
+        If the input image is 3D with a single channel, it is replicated three times and
+        concatenated along the channel axis. If the input image is 4D, i.e., a batch of images,
+        all the images are tiled into one large square grid image for display.
+
+        Note: This function requires the ``pillow`` package.
+
+        Note: This function internally calls `asnumpy()` if `image` is an MXNet NDArray.
+        Since `asnumpy()` is a blocking function call, this function would block the main
+        thread till it returns. It may consequently affect the performance of async execution
+        of the MXNet engine.
+
+        Parameters
+        ----------
+            tag : str
+                Name for the `image`.
+            image : MXNet `NDArray` or `numpy.ndarray`
+                Image is one of the following formats: (H, W), (C, H, W), (N, C, H, W).
+                For float image data types, the values are normalized one image at a time to fit
+                in the range `[0, 255]`. `uint8` values are unchanged. The following two
+                normalization algorithms are used for different conditions:
+                1. If the input values are all positive, they are rescaled so that the largest one
+                is 255.
+                2. If any input value is negative, the values are shifted so that the input value
+                0.0 is at 127.
+                They are then rescaled so that either the smallest value is 0, or the largest
+                one is 255.
+            global_step : int
+                Global step value to record.
+        """
+        self._file_writer.add_summary(image_summary(tag, image), global_step)
+
+    def add_audio(self, tag, audio, sample_rate=44100, global_step=None):
+        """Add audio data to the event file.
+        Note: This function internally calls `asnumpy()` if `audio` is an MXNet NDArray.
+        Since `asnumpy()` is a blocking function call, this function would block the main
+        thread till it returns. It may consequently affect the performance of async execution
+        of the MXNet engine.
+
+        Parameters
+        ----------
+            tag : str
+                Name for the `audio`.
+            audio : MXNet `NDArray` or `numpy.ndarray`
+                Audio data squeezable to a 1D tensor. The values of the tensor are in the range
+                `[-1, 1]`.
+            sample_rate : int
+                Sample rate in Hz.
+            global_step : int
+                Global step value to record.
+        """
+        self._file_writer.add_summary(audio_summary(tag, audio, sample_rate=sample_rate),
+                                      global_step)
+
+    def add_text(self, tag, text, global_step=None):
+        """Add text data to the event file.
+
+        Parameters
+        ----------
+            tag : str
+                Name for the `text`.
+            text : str
+                Text to be saved to the event file.
+            global_step : int
+                Global step value to record.
+        """
+        self._file_writer.add_summary(text_summary(tag, text), global_step)
+        if tag not in self._text_tags:
+            self._text_tags.append(tag)
+            extension_dir = os.path.join(self.get_logdir(), 'plugins', 'tensorboard_text')
+            if not os.path.exists(extension_dir):
+                os.makedirs(extension_dir)
+            with open(os.path.join(extension_dir, 'tensors.json'), 'w') as fp:
+                json.dump(self._text_tags, fp)
+
+    def add_embedding(self, tag, embedding, labels=None, images=None, global_step=None):
+        """Adds embedding projector data to the event file. It will also create a config file
+        used by the embedding projector in TensorBoard.
+        See the following reference for the meanings of labels and images.
+        Ref: https://www.tensorflow.org/versions/r1.2/get_started/embedding_viz
+
+        Note: This function internally calls `asnumpy()` on any MXNet NDArray inputs.
+        Since `asnumpy()` is a blocking function call, this function would block the main
+        thread till it returns. It may consequently affect the performance of async execution
+        of the MXNet engine.
+
+        Parameters
+        ----------
+            tag : str
+                Name for the `embedding`.
+            embedding : MXNet `NDArray` or `numpy.ndarray`
+                A 2D matrix in which each row is the feature vector of a data point.
+            labels : MXNet `NDArray` or `numpy.ndarray`
+                A 1D array of labels, convertible to strings, corresponding to the data points
+                in the `embedding`.
+            images : MXNet `NDArray` or `numpy.ndarray`
+                Images of format NCHW corresponding to the data points in the `embedding`.
+            global_step : int
+                Global step value to record.
+        """
+        embedding_shape = embedding.shape
+        if len(embedding_shape) != 2:
+            raise ValueError('expected a 2D array as embedding data, but received an array with'
+                             ' ndim=%d' % len(embedding_shape))
+        if global_step is None:
+            global_step = 0
+        save_path = os.path.join(self.get_logdir(), str(global_step).zfill(5))
+        try:
+            os.makedirs(save_path)
+        except OSError:
+            logging.warning('embedding dir exists, did you set global_step for add_embedding()?')
+        if labels is not None:
+            if labels.ndim != 1:
+                raise ValueError('expected 1D ndarray as labels')
+            if embedding_shape[0] != len(labels):
+                raise ValueError('first dim of embedding and length of labels must match,'
+                                 ' but received %d and %d respectively'
+                                 % (embedding_shape[0], len(labels)))
+            _make_metadata_tsv(labels, save_path)
+        if images is not None:
+            img_labels_shape = images.shape
+            if embedding_shape[0] != img_labels_shape[0]:
+                raise ValueError('first dim of embedding and images must match,'
+                                 ' but received %d and %d respectively'
+                                 % (embedding_shape[0], img_labels_shape[0]))
+            _make_sprite_image(images, save_path)
+        _save_embedding_tsv(embedding, save_path)
+        _add_embedding_config(self.get_logdir(), str(global_step).zfill(5), labels is not None,
+                              images.shape if images is not None else None, tag)
+
+    def add_pr_curve(self, tag, labels, predictions, num_thresholds,
+                     global_step=None, weights=None):
+        """Adds a precision-recall curve.
+
+        Note: This function internally calls `asnumpy()` on any MXNet NDArray inputs.
+        Since `asnumpy()` is a blocking function call, this function would block the main
+        thread till it returns. It may consequently affect the performance of async execution
+        of the MXNet engine.
+
+        Parameters
+        ----------
+            tag : str
+                A tag attached to the summary. Used by TensorBoard for organization.
+            labels : MXNet `NDArray` or `numpy.ndarray`.
+                The ground truth values. A tensor of 0/1 values with arbitrary shape.
+            predictions : MXNet `NDArray` or `numpy.ndarray`.
+                A float32 tensor whose values are in the range `[0, 1]`. Dimensions must match
+                those of `labels`.
+            num_thresholds : int
+                Number of thresholds, evenly distributed in `[0, 1]`, to compute PR metrics for.
+                Should be `>= 2`. This value should be a constant integer value, not a tensor
+                that stores an integer.
+                The thresholds for computing the pr curves are calculated in the following way:
+                `width = 1.0 / (num_thresholds - 1),
+                thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]`.
+            global_step : int
+                Global step value to record.
+            weights : MXNet `NDArray` or `numpy.ndarray`.
+                Optional float32 tensor. Individual counts are multiplied by this value.
+                This tensor must be either the same shape as or broadcastable to the `labels`
+                tensor.
+        """
+        if num_thresholds < 2:
+            raise ValueError('num_thresholds must be >= 2')
+        labels = _make_numpy_array(labels)
+        predictions = _make_numpy_array(predictions)
+        if weights is not None:
+            weights = _make_numpy_array(weights)
+        self._file_writer.add_summary(pr_curve_summary(tag, labels, predictions,
+                                                       num_thresholds, weights), global_step)
+
+    def flush(self):
+        """Flushes pending events to the file."""
+        self._file_writer.flush()
+
+    def close(self):
+        """Closes the event file for writing."""
+        self._file_writer.close()
+
+    def reopen(self):
+        """Reopens the event file for writing."""
+        self._file_writer.reopen()
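The two numeric conventions documented in `SummaryWriter`, the default histogram bin edges built by `_get_default_bins()` and the evenly spaced thresholds described for `add_pr_curve()`, can be reproduced in plain Python for inspection. This is an illustrative sketch only: `default_bin_edges` and `pr_thresholds` are hypothetical helper names and are not part of `mxnet.contrib.summary`.

```python
def default_bin_edges(start=1e-12, stop=1e20, factor=1.1):
    """Mirror of the bucket scheme in `_get_default_bins()` (hypothetical helper):
    positive edges grow geometrically by `factor` from `start` while they stay
    below `stop`; they are mirrored to the negative side and joined around 0,
    giving one ascending list of edges."""
    positive = []
    v = start
    while v < stop:
        positive.append(v)
        v *= factor
    negative = [-e for e in reversed(positive)]
    return negative + [0] + positive


def pr_thresholds(num_thresholds):
    """Thresholds as described in the `add_pr_curve()` docstring (hypothetical helper):
    width = 1 / (num_thresholds - 1), thresholds = [0, width, 2*width, ..., 1]."""
    if num_thresholds < 2:
        raise ValueError('num_thresholds must be >= 2')
    width = 1.0 / (num_thresholds - 1)
    return [i * width for i in range(num_thresholds)]


edges = default_bin_edges()
print(len(edges), edges[0], edges[-1])
print(pr_thresholds(5))  # [0.0, 0.25, 0.5, 0.75, 1.0]
```

With the default arguments the edge list has roughly 1,500 entries spanning about ±1e20, which suggests why passing an integer `bins` (as in the class docstring example) can be the more compact choice for data confined to a narrow range.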
diff --git a/python/setup.py b/python/setup.py
index cf94adf982d..d8d5f204ea4 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -99,6 +99,19 @@ def config_cython():
         return []
 
 
+def compile_summary_protobuf():
+    proto_path = 'mxnet/contrib/summary/proto'
+    proto_files = os.path.join(proto_path, '*.proto')
+    cmd = 'protoc ' + proto_files + ' --python_out=.'
+    return os.system(cmd)
+
+
+if compile_summary_protobuf() != 0:
+    print('WARNING: Compiling summary protocol buffers failed. You will not be '
+          'able to use the summary logging APIs for visualizing data in TensorBoard. '
+          'Please make sure that you have installed protobuf3 compiler and runtime correctly.')
+
+
 setup(name='mxnet',
       version=__version__,
       description=open(os.path.join(CURRENT_DIR, 'README.md')).read(),
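`compile_summary_protobuf()` in this `setup.py` change passes a single command string to `os.system`, so expanding `*.proto` is left to whatever shell runs it, and that behavior can vary across platforms. A hedged alternative, assuming Python 3 (for `shutil.which`) and a `protoc` binary on `PATH`, expands the glob in Python and calls `protoc` with an argument list. It keeps the original return convention (0 on success), so the existing `!= 0` warning check would still apply; this is a sketch, not what the PR ships.

```python
import glob
import os
import shutil
import subprocess


def compile_summary_protobuf(proto_path='mxnet/contrib/summary/proto', out_dir='.'):
    """Compile every .proto file in `proto_path` with protoc.

    Returns 0 on success and a nonzero value on failure, mirroring the
    os.system()-based version so callers can keep checking `!= 0`.
    """
    protoc = shutil.which('protoc')
    if protoc is None:
        return 1  # protoc is not installed or not on PATH
    proto_files = sorted(glob.glob(os.path.join(proto_path, '*.proto')))
    if not proto_files:
        return 1  # nothing to compile; treat as failure so the warning is printed
    # An argument list avoids shell quoting and wildcard-expansion differences.
    return subprocess.call([protoc, '--python_out=' + out_dir] + proto_files)
```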
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index ef025e4b938..978bf1c4c73 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -45,6 +45,7 @@
 from test_sparse_ndarray import test_sparse_nd_setitem, test_sparse_nd_binary_scalar_op
 from test_sparse_operator import *
 from test_ndarray import *
+from test_summary import *
 
 set_default_context(mx.gpu(0))
 del test_support_vector_machine_l1_svm
diff --git a/tests/python/unittest/test_summary.py b/tests/python/unittest/test_summary.py
new file mode 100644
index 00000000000..df815aaabcc
--- /dev/null
+++ b/tests/python/unittest/test_summary.py
@@ -0,0 +1,353 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import shutil
+from mxnet.contrib.summary import SummaryWriter
+from mxnet.contrib.summary.utils import _make_metadata_tsv, make_image_grid, _make_sprite_image
+from mxnet.contrib.summary.utils import _add_embedding_config, _save_embedding_tsv
+from mxnet.test_utils import *
+from common import with_seed
+
+# DO NOT CHANGE THESE NAMES AS THEY FOLLOW THE DEFINITIONS IN TENSORBOARD
+_LOGDIR = './logs_for_tensorboard'
+_METADATA_TSV = 'metadata.tsv'
+_SPRITE_PNG = 'sprite.png'
+_PROJECTOR_CONFIG_PBTXT = 'projector_config.pbtxt'
+_TENSORS_TSV = 'tensors.tsv'
+_EVENT_FILE_PREFIX = 'events.out.tfevents'
+_PLUGINS = 'plugins'
+_TENSORBOARD_TEXT = 'tensorboard_text'
+_TENSORS_JSON = 'tensors.json'
+
+
+def make_logdir():
+    if not os.path.exists(_LOGDIR):
+        try:
+            os.mkdir(_LOGDIR)
+        except:
+            raise OSError('failed to make dir at {}'.format(_LOGDIR))
+
+
+def safe_remove_file(file_path):
+    if file_exists(file_path):
+        try:
+            os.remove(file_path)
+        except:
+            raise OSError('failed to remove file at {}'.format(file_path))
+
+
+def remove_logdir():
+    if dir_exists(_LOGDIR):
+        try:
+            shutil.rmtree(_LOGDIR)
+        except:
+            raise OSError('failed to remove dir {}'.format(_LOGDIR))
+
+
+def safe_remove_dir(dir_path):
+    if dir_empty(dir_path):
+        try:
+            shutil.rmtree(dir_path)
+        except:
+            raise OSError('failed to remove dir {}'.format(dir_path))
+
+
+def safe_remove_logdir():
+    safe_remove_dir(_LOGDIR)
+
+
+def file_exists(file_path):
+    return os.path.exists(file_path) and os.path.isfile(file_path)
+
+
+def dir_exists(dir_path):
+    return os.path.exists(dir_path) and os.path.isdir(dir_path)
+
+
+def file_exists_with_prefix(file_path, prefix=None):
+    """Returns True if `file_path` is an existing file whose name starts with `prefix`."""
+    if prefix is None:
+        return file_exists(file_path)
+    return file_exists(file_path) and os.path.basename(file_path).startswith(prefix)
+
+
+def dir_empty(dir_path):
+    for _, dirnames, files in os.walk(dir_path):
+        if len(dirnames) != 0 or len(files) != 0:
+            return False
+    return True
+
+
+def logdir_empty():
+    return dir_empty(_LOGDIR)
+
+
+@with_seed()
+def test_make_metadata_tsv():
+    make_logdir()
+    shape = rand_shape_nd(num_dim=4, dim=10)
+    data = rand_ndarray(shape=shape, stype='default')
+    _make_metadata_tsv(data, _LOGDIR)
+    file_path = os.path.join(_LOGDIR, _METADATA_TSV)
+    data_loaded = np.loadtxt(file_path, dtype=data.dtype)
+    assert same(data.asnumpy(), data_loaded.reshape(data.shape))
+    safe_remove_file(file_path)
+    safe_remove_logdir()
+
+
+@with_seed()
+def test_make_image_grid():
+    def test_2d_input():
+        shape = rand_shape_2d()
+        data = rand_ndarray(shape, 'default')
+        grid = make_image_grid(data)
+        assert grid.ndim == 3
+        assert grid.shape[0] == 3
+        assert grid.shape[1:] == data.shape
+        assert same(grid[0].asnumpy(), grid[1].asnumpy())
+        assert same(grid[0].asnumpy(), grid[2].asnumpy())
+        assert same(grid[0].asnumpy(), data.asnumpy())
+
+    def test_3d_single_channel_input():
+        shape = rand_shape_3d(dim0=1)
+        data = rand_ndarray(shape, 'default')
+        assert data.shape[0] == 1  # single channel
+        grid = make_image_grid(data)
+        assert grid.ndim == 3
+        assert grid.shape[0] == 3
+        assert same(grid[0].asnumpy(), grid[1].asnumpy())
+        assert same(grid[0].asnumpy(), grid[2].asnumpy())
+        assert same(grid[0:1].asnumpy(), data.asnumpy())
+
+    def test_3d_three_channel_input():
+        shape = rand_shape_3d()
+        shape = (3,) + shape[1:]
+        data = rand_ndarray(shape, 'default')
+        grid = make_image_grid(data)
+        assert grid.ndim == 3
+        assert grid.shape[0] == 3
+        assert same(grid.asnumpy(), data.asnumpy())
+
+    def test_4d_single_batch_single_channel_input():
+        shape = list(rand_shape_nd(4))
+        shape[0] = 1
+        shape[1] = 1
+        shape = tuple(shape)
+        data = rand_ndarray(shape, 'default')
+        grid = make_image_grid(data)
+        assert grid.ndim == 3
+        assert grid.shape[0] == 3
+        assert same(grid[0].asnumpy(), grid[1].asnumpy())
+        assert same(grid[0].asnumpy(), grid[2].asnumpy())
+        assert same(grid[0].reshape(data.shape).asnumpy(), data.asnumpy())
+
+    def test_4d_multiple_batch_input():
+        shape_list = list(rand_shape_nd(4))
+        shape_list[0] = 10
+        num_channels = [1, 3]
+        for c in num_channels:
+            shape_list[1] = c
+            shape = tuple(shape_list)
+            data = rand_ndarray(shape, 'default')
+            grid = make_image_grid(data)
+            assert grid.ndim == 3
+            assert grid.shape[0] == 3
+
+    test_2d_input()
+    test_3d_single_channel_input()
+    test_3d_three_channel_input()
+    test_4d_single_batch_single_channel_input()
+    test_4d_multiple_batch_input()
+
+
+def test_make_sprite_image():
+    dtypes = [np.uint8, np.float32, np.float64]
+    ndims = [2, 3, 4]
+    for dtype in dtypes:
+        for ndim in ndims:
+            shape_list = list(rand_shape_nd(num_dim=ndim))
+            if ndim == 3:
+                shape_list[0] = 3
+            elif ndim == 4:
+                shape_list[1] = 3
+            data = rand_ndarray(tuple(shape_list), 'default', dtype=dtype)
+            make_logdir()
+            _make_sprite_image(data, _LOGDIR)
+            file_path = os.path.join(_LOGDIR, _SPRITE_PNG)
+            assert file_exists(file_path)
+            safe_remove_file(file_path)
+            safe_remove_logdir()
+
+
+def test_add_embedding_config():
+    make_logdir()
+    _add_embedding_config(_LOGDIR, str(10000), True, (4, 3, 5, 5))
+    file_path = os.path.join(_LOGDIR, _PROJECTOR_CONFIG_PBTXT)
+    assert file_exists(file_path)
+    safe_remove_file(file_path)
+    safe_remove_logdir()
+
+
+def test_save_ndarray_tsv():
+    dtypes = [np.uint8, np.float32, np.float64]
+    ndims = [2, 3, 4]
+    for dtype in dtypes:
+        for ndim in ndims:
+            shape = rand_shape_nd(ndim)
+            data = rand_ndarray(shape, 'default', dtype=dtype)
+            make_logdir()
+            _save_embedding_tsv(data, _LOGDIR)
+            file_path = os.path.join(_LOGDIR, _TENSORS_TSV)
+            assert file_exists(file_path)
+            safe_remove_file(file_path)
+            safe_remove_logdir()
+
+
+def check_event_file_and_remove_logdir():
+    """Check whether the event file exists and then remove the logdir."""
+    files = os.listdir(_LOGDIR)
+    assert len(files) == 1
+    file_path = os.path.join(_LOGDIR, files[0])
+    assert file_exists_with_prefix(file_path, _EVENT_FILE_PREFIX)
+    safe_remove_file(file_path)
+    safe_remove_logdir()
+
+
+@with_seed()
+def test_add_scalar():
+    sw = SummaryWriter(logdir=_LOGDIR)
+    sw.add_scalar(tag='test_add_scalar', value=np.random.uniform(), global_step=0)
+    sw.close()
+    check_event_file_and_remove_logdir()
+
+
+@with_seed()
+def test_add_histogram():
+    shape = rand_shape_nd(4)
+    sw = SummaryWriter(logdir=_LOGDIR)
+    sw.add_histogram(tag='test_add_histogram', values=mx.nd.random.normal(shape=shape),
+                     global_step=0, bins=100)
+    sw.close()
+    check_event_file_and_remove_logdir()
+
+
+@with_seed()
+def test_add_image():
+    shape = list(rand_shape_nd(4))
+    shape[1] = 3
+    shape = tuple(shape)
+    sw = SummaryWriter(logdir=_LOGDIR)
+    sw.add_image(tag='test_add_image', image=mx.nd.random.normal(shape=shape), global_step=0)
+    sw.close()
+    check_event_file_and_remove_logdir()
+
+
+@with_seed()
+def test_add_audio():
+    shape = (100,)
+    data = mx.nd.random.uniform(-1, 1, shape=shape)
+    sw = SummaryWriter(logdir=_LOGDIR)
+    sw.add_audio(tag='test_add_audio', audio=data)
+    sw.close()
+    check_event_file_and_remove_logdir()
+
+
+def check_and_remove_logdir_for_text():
+    """1. verify that tensors.json exists under _LOGDIR/plugins/tensorboard_text.
+    2. verify that the event file exists and remove it."""
+    # step 1
+    plugins_path = os.path.join(_LOGDIR, _PLUGINS)
+    tensorboard_text_path = os.path.join(plugins_path, _TENSORBOARD_TEXT)
+    file_path = os.path.join(tensorboard_text_path, _TENSORS_JSON)
+    assert file_exists(file_path)
+    safe_remove_file(file_path)
+    safe_remove_dir(tensorboard_text_path)
+    safe_remove_dir(plugins_path)
+    # step 2
+    event_files = os.listdir(_LOGDIR)
+    assert len(event_files) == 1
+    event_file_path = os.path.join(_LOGDIR, event_files[0])
+    assert file_exists_with_prefix(event_file_path, _EVENT_FILE_PREFIX)
+    safe_remove_file(event_file_path)
+    # remove logdir
+    safe_remove_logdir()
+
+
+def test_add_text():
+    # this will generate an event file under _LOGDIR and
+    # a json file called tensors.json under _LOGDIR/plugins/tensorboard_text/tensors.json
+    sw = SummaryWriter(logdir=_LOGDIR)
+    sw.add_text(tag='test_add_text', text='Hello MXNet!')
+    sw.close()
+    check_and_remove_logdir_for_text()
+
+
+def check_and_remove_for_embedding(global_step):
+    """1. verify projector_config.pbtxt exists under _LOGDIR.
+    2. verify folder str(global_step).zfill(5) exists under _LOGDIR.
+    3. verify metadata.tsv exists under _LOGDIR/str(global_step).zfill(5).
+    4. verify sprite.png exists under _LOGDIR/str(global_step).zfill(5).
+    5. verify tensors.tsv exists under _LOGDIR/str(global_step).zfill(5).
+    6. remove all of them and _LOGDIR."""
+    # step 1
+    projector_file_path = os.path.join(_LOGDIR, _PROJECTOR_CONFIG_PBTXT)
+    assert file_exists(projector_file_path)
+    global_step_dir_path = os.path.join(_LOGDIR, str(global_step).zfill(5))
+    assert dir_exists(global_step_dir_path)
+    metadata_tsv_path = os.path.join(global_step_dir_path, _METADATA_TSV)
+    assert file_exists(metadata_tsv_path)
+    sprite_png_path = os.path.join(global_step_dir_path, _SPRITE_PNG)
+    assert file_exists(sprite_png_path)
+    tensors_tsv_path = os.path.join(global_step_dir_path, _TENSORS_TSV)
+    assert file_exists(tensors_tsv_path)
+
+    safe_remove_file(projector_file_path)
+    safe_remove_file(metadata_tsv_path)
+    safe_remove_file(sprite_png_path)
+    safe_remove_file(tensors_tsv_path)
+    safe_remove_dir(global_step_dir_path)
+    safe_remove_logdir()
+
+
+@with_seed()
+def test_add_embedding():
+    batch_size = 10
+    embedding = mx.nd.uniform(shape=(batch_size, 20))
+    labels = mx.nd.uniform(low=1, high=2, shape=(batch_size,)).astype('int32')
+    images = mx.nd.uniform(shape=(batch_size, 3, 10, 10))
+    global_step = np.random.randint(low=0, high=999999)
+    with SummaryWriter(logdir=_LOGDIR) as sw:
+        sw.add_embedding(tag='test_add_embedding', embedding=embedding, labels=labels,
+                         images=images, global_step=global_step)
+    check_and_remove_for_embedding(global_step)
+
+
+@with_seed()
+def test_add_pr_curve():
+    shape = (100,)
+    predictions = mx.nd.uniform(low=0.0, high=1.0, shape=shape)
+    labels = mx.nd.uniform(low=0, high=2, shape=shape).astype('int32')
+    num_thresholds = 100
+    with SummaryWriter(_LOGDIR) as sw:
+        sw.add_pr_curve(tag='test_add_pr_curve', labels=labels, predictions=predictions,
+                        num_thresholds=num_thresholds)
+    check_event_file_and_remove_logdir()
+
+
+if __name__ == '__main__':
+    remove_logdir()
+    import nose
+    nose.runmodule()
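The test helpers above create and delete the fixed `./logs_for_tensorboard` directory by hand, so a failing assertion can leave files behind for the next test. A sketch of an alternative, assuming Python 3's `tempfile.TemporaryDirectory`, gives each check its own throwaway log directory that is removed even on failure. The event file is faked here with a hypothetical name, standing in for what `SummaryWriter(logdir=...)` would write; `has_single_event_file` and `run_check` are illustrative names, not part of the test suite.

```python
import os
import tempfile

_EVENT_FILE_PREFIX = 'events.out.tfevents'


def has_single_event_file(logdir, prefix=_EVENT_FILE_PREFIX):
    """True if `logdir` contains exactly one regular file whose name starts with `prefix`."""
    files = os.listdir(logdir)
    if len(files) != 1:
        return False
    path = os.path.join(logdir, files[0])
    return os.path.isfile(path) and files[0].startswith(prefix)


def run_check():
    # The directory and everything in it are removed when the block exits,
    # even if an assertion inside it fails.
    with tempfile.TemporaryDirectory() as logdir:
        # Stand-in for: with SummaryWriter(logdir=logdir) as sw: sw.add_scalar(...)
        fake_event = os.path.join(logdir, _EVENT_FILE_PREFIX + '.hypothetical-host')
        open(fake_event, 'wb').close()
        found = has_single_event_file(logdir)
    return found, os.path.exists(logdir)


found, still_exists = run_check()
```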


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services