You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/01/31 19:54:40 UTC

[GitHub] piiswrong closed pull request #8639: [WIP] Gluon object detection

piiswrong closed pull request #8639: [WIP] Gluon object detection
URL: https://github.com/apache/incubator-mxnet/pull/8639
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/example/object-detection/README.md b/example/object-detection/README.md
new file mode 100644
index 0000000000..0c0e2f864a
--- /dev/null
+++ b/example/object-detection/README.md
@@ -0,0 +1,5 @@
+# Object Detection API
+
+Here is the brand new object detection playground!
+
+ 
diff --git a/example/object-detection/__init__.py b/example/object-detection/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/example/object-detection/block/__init__.py b/example/object-detection/block/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/example/object-detection/block/anchor.py b/example/object-detection/block/anchor.py
new file mode 100644
index 0000000000..4c7f8da863
--- /dev/null
+++ b/example/object-detection/block/anchor.py
@@ -0,0 +1,169 @@
+"""Anchor generators.
+The job of the anchor generator is to create (or load) a collection
+of bounding boxes to be used as anchors.
+Generated anchors are assumed to match some convolutional grid or list of grid
+shapes.  For example, we might want to generate anchors matching an 8x8
+feature map and a 4x4 feature map.  If we place 3 anchors per grid location
+on the first feature map and 6 anchors per grid location on the second feature
+map, then 3*8*8 + 6*4*4 = 288 anchors are generated in total.
+To support fully convolutional settings, feature maps are passed as input;
+however, only their shapes are used to infer the anchors.
+"""
+from __future__ import division
+import math
+from mxnet import gluon
+from mxnet import ndarray as nd
+from .registry import register, alias, create
+
+
+class ShapeExtractor(gluon.Block):
+    """Return selected entries of the input's ``shape`` as an NDArray.
+    ``positions`` lists the axes to read; a ``None`` input yields ``None``.
+    """
+    def __init__(self, positions=(0,)):
+        super(ShapeExtractor, self).__init__()
+        if not isinstance(positions, (list, tuple)):
+            raise ValueError("positions must be list or tuple")
+        self._positions = tuple(positions)  # private copy; callers cannot mutate it
+
+    def forward(self, x, *args):
+        if x is None:
+            return None
+        dims = x.shape
+        return nd.array([dims[axis] for axis in self._positions])
+
+@register
+class GridAnchorGenerator(gluon.Block):
+    """Generate corner-format anchors, normalized by image size, for every cell
+    of a feature map; each (size, ratio) pair yields one anchor per cell.
+    """
+    def __init__(self, size_ratios, strides=None, offsets=None, clip=None,
+                 im_size=(256.0, 256.0), layout='HWC'):
+        super(GridAnchorGenerator, self).__init__()
+        assert (isinstance(size_ratios, list) and size_ratios), (
+            "Invalid size_ratios list.")
+        for sr in size_ratios:
+            assert (isinstance(sr, (list, tuple)) and len(sr) == 2), (
+                "Each size_ratio pair must be length-2 tuple/list.")
+        self._size_ratios = size_ratios
+        if strides is not None:
+            assert len(strides) == 2, "strides must be either None or length-2 vector"
+        self._strides = strides
+        if offsets is not None:
+            assert len(offsets) == 2, "offsets must be either None or length-2 vector"
+        self._offsets = offsets
+        if clip is not None:
+            assert len(clip) == 4, "clip must be either None or length-4 vector"
+        self._clip = clip
+        assert len(im_size) == 2, "im_size must be (height, width)"
+        self._im_size = im_size
+        assert layout == 'HWC' or layout == 'CHW', "layout must be 'HWC' or 'CHW'"
+        self._layout = layout
+        with self.name_scope():
+            self.feat_size_extractor = ShapeExtractor([2, 3])  # shape[2:4] read as (h, w)
+            self.im_size_extractor = ShapeExtractor([2, 3])
+
+    @property
+    def num_depth(self):
+        """Returns the number of anchors per pixel/grid/location.
+        """
+        return len(self._size_ratios)
+
+    def forward(self, x, img=None, *args):
+        """Return (H*W*num_depth, 4) anchors as [xmin, ymin, xmax, ymax]."""
+        im_height, im_width = self._get_im_size(self.im_size_extractor(img))
+        # feature size
+        feat_height, feat_width = self._get_feat_size(self.feat_size_extractor(x))
+        # stride
+        stride_h, stride_w = self._get_strides(feat_height, feat_width, im_height, im_width)
+        # offsets for center
+        offset_h, offset_w = self._get_offsets(feat_height, feat_width, im_height, im_width)
+        # generate anchors for each pixel/grid, as layout [HxWxC, 4]
+        centers = [[(i * stride_w + offset_w) / im_width, (j * stride_h + offset_h) / im_height]
+                   for j in range(feat_height) for i in range(feat_width) for _ in self._size_ratios]
+        shapes = [[s * math.sqrt(r) / im_width, s / math.sqrt(r) / im_height]
+                  for _ in range(feat_height) for _ in range(feat_width) for s, r in self._size_ratios]
+        # convert to ndarray and as corner [xmin, ymin, xmax, ymax]
+        shapes = nd.array(shapes) * 0.5
+        centers = nd.array(centers)
+        anchors = nd.concat(centers - shapes, centers + shapes, dim=1)
+
+        if self._clip is not None:
+            anchors = self._clip_anchors(anchors, self._clip)  # keep the clipped copy
+        # NOTE: anchors are returned flat as (H * W * num_depth, 4), ordered by
+        # row, then column, then size_ratio; ``layout`` is validated but not
+        # applied here.  Reshape to (H, W, num_depth, 4) and transpose in the
+        # caller if a 'CHW' arrangement is needed.
+
+        return anchors
+
+    def _get_im_size(self, im_size):
+        """Get original image size given ndarray shape data."""
+        im_height, im_width = self._im_size
+        if im_size is not None:
+            # infer image shape from data if available
+            assert im_size.size == 2, (
+                "Invalid data shape {}, expected (h, w)".format(im_size.shape))
+            im_height, im_width = im_size.asnumpy().astype('int')
+        return im_height, im_width
+
+    def _get_feat_size(self, feat_size):
+        """Get feature map size given ndarray shape data."""
+        assert feat_size.size == 2, (
+            "Invalid feat shape {}, expected (h, w)".format(feat_size.shape))
+        feat_height, feat_width = feat_size.asnumpy().astype('int')
+        return feat_height, feat_width
+
+    def _get_strides(self, feat_height, feat_width, im_height, im_width):
+        """Wrapping function for default grid strides."""
+        if self._strides is None:
+            stride_h = im_height / feat_height
+            stride_w = im_width / feat_width
+        else:
+            stride_h, stride_w = self._strides
+        return stride_h, stride_w
+
+    def _get_offsets(self, feat_height, feat_width, im_height, im_width):
+        """Wrapping function for default grid offsets."""
+        if self._offsets is None:
+            offset_h = 0.5 * im_height / feat_height
+            offset_w = 0.5 * im_width / feat_width
+        else:
+            offset_h, offset_w = self._offsets
+        return offset_h, offset_w
+
+    def _clip_anchors(self, anchors, clip_window):
+        """Clip all anchors to clip_window area.
+
+        Parameters
+        ----------
+        anchors : NDArray
+            N x 4 array
+        clip_window : list or tuple
+            [xmin, ymin, xmax, ymax] window, compared directly with anchor values
+
+        Returns
+        -------
+        a new NDArray with clipped anchor boxes (the input is not modified)
+        """
+        l, t, r, b = nd.split(anchors, axis=1, num_outputs=4)
+        l = nd.maximum(clip_window[0], nd.minimum(clip_window[2], l))
+        t = nd.maximum(clip_window[1], nd.minimum(clip_window[3], t))
+        r = nd.maximum(clip_window[0], nd.minimum(clip_window[2], r))
+        b = nd.maximum(clip_window[1], nd.minimum(clip_window[3], b))
+        return nd.concat(l, t, r, b, dim=1)
+
+
+@register
+class SSDAnchorGenerator(GridAnchorGenerator):
+    """SSD-style grid anchors: each size with ratios[0], then sizes[0] with
+    each remaining ratio."""
+    def __init__(self, sizes, ratios, strides=None, offsets=None, clip=None,
+                 im_size=(300.0, 300.0), layout='HWC'):
+        assert len(sizes) > 0
+        assert len(ratios) > 0
+        by_size = [(size, ratios[0]) for size in sizes]
+        by_ratio = [(sizes[0], ratio) for ratio in ratios[1:]]
+        super(SSDAnchorGenerator, self).__init__(
+            by_size + by_ratio, strides=strides, offsets=offsets, clip=clip,
+            im_size=im_size, layout=layout)
diff --git a/example/object-detection/block/base.py b/example/object-detection/block/base.py
new file mode 100644
index 0000000000..de4cce70ba
--- /dev/null
+++ b/example/object-detection/block/base.py
@@ -0,0 +1,63 @@
+from mxnet import gluon
+
+
+class CornerToCenterBox(gluon.HybridBlock):
+    """Re-encode boxes from corner form to center form.
+    Corner form is (xmin, ymin, xmax, ymax);
+    center form is (center_x, center_y, width, height).
+
+    Parameters
+    ----------
+    split : bool
+        If True, return the four components separately instead of concatenated.
+
+    Returns
+    -------
+     One BxNx4 array when split is False, otherwise four BxNx1 arrays.
+    """
+    def __init__(self, split=False):
+        super(CornerToCenterBox, self).__init__()
+        self._split = split
+
+    def hybrid_forward(self, F, x, *args, **kwargs):
+        left, top, right, bottom = F.split(x, axis=-1, num_outputs=4)
+        box_w = right - left
+        box_h = bottom - top
+        center_x = left + box_w / 2
+        center_y = top + box_h / 2
+        if self._split:
+            return center_x, center_y, box_w, box_h
+        # concatenation is along axis 2, so the unsplit output needs 3-D input
+        return F.concat(center_x, center_y, box_w, box_h, dim=2)
+
+
+class CenterToCornerBox(gluon.HybridBlock):
+    """Convert center boxes to corner boxes.
+    Corner boxes are encoded as (xmin, ymin, xmax, ymax)
+    Center boxes are encoded as (center_x, center_y, width, height)
+
+    Parameters
+    ----------
+    split : bool
+        Whether split boxes to individual elements after processing.
+
+    Returns
+    -------
+     A BxNx4 NDArray if split is False, or 4 BxNx1 NDArray if split is True.
+    """
+    def __init__(self, split=False):
+        super(CenterToCornerBox, self).__init__()
+        self._split = split
+
+    def hybrid_forward(self, F, x, *args, **kwargs):
+        x, y, w, h = F.split(x, axis=-1, num_outputs=4)
+        half_w = w / 2
+        half_h = h / 2
+        xmin = x - half_w
+        ymin = y - half_h
+        xmax = x + half_w
+        ymax = y + half_h
+        # the flag lives on the instance; a bare ``split`` is an undefined name here
+        if self._split:
+            return xmin, ymin, xmax, ymax
+        return F.concat(xmin, ymin, xmax, ymax, dim=2)
diff --git a/example/object-detection/block/coder.py b/example/object-detection/block/coder.py
new file mode 100644
index 0000000000..ac967f7677
--- /dev/null
+++ b/example/object-detection/block/coder.py
@@ -0,0 +1,145 @@
+"""Encoder and Decoder functions.
+Encoders are used during training to assign training targets.
+Decoders are used during testing/validation to convert predictions back to
+normal boxes, etc.
+"""
+from mxnet import nd
+from mxnet import gluon
+from block.base import CornerToCenterBox
+from block.registry import register, alias, create
+
+
+class BoxEncoder(gluon.Block):
+    """Base class for box encoders; adds no behavior on top of gluon.Block."""
+    def __init__(self):
+        super(BoxEncoder, self).__init__()
+
+
+class HybridBoxEncoder(gluon.HybridBlock):
+    """Base class for hybrid box encoders; adds no behavior on top of gluon.HybridBlock."""
+    def __init__(self):
+        super(HybridBoxEncoder, self).__init__()
+
+
+class BoxDecoder(gluon.Block):
+    """Base class for box decoders; adds no behavior on top of gluon.Block."""
+    def __init__(self):
+        super(BoxDecoder, self).__init__()
+
+
+class HybridBoxDecoder(gluon.HybridBlock):
+    """Base class for hybrid box decoders; adds no behavior on top of gluon.HybridBlock."""
+    def __init__(self):
+        super(HybridBoxDecoder, self).__init__()
+
+
+class ClassEncoder(gluon.Block):
+    """Base class for classification encoders; adds no behavior on top of gluon.Block."""
+    def __init__(self):
+        super(ClassEncoder, self).__init__()
+
+
+class HybridClassEncoder(gluon.HybridBlock):
+    """Base class for hybrid classification encoders; adds nothing to gluon.HybridBlock."""
+    def __init__(self):
+        super(HybridClassEncoder, self).__init__()
+
+
+class ClassDecoder(gluon.Block):
+    """Base class for classification decoders; adds no behavior on top of gluon.Block."""
+    def __init__(self):
+        super(ClassDecoder, self).__init__()
+
+
+class HybridClassDecoder(gluon.HybridBlock):
+    """Base class for hybrid classification decoders; adds nothing to gluon.HybridBlock."""
+    def __init__(self):
+        super(HybridClassDecoder, self).__init__()
+
+
+@register
+@alias('rcnn_box_encoder')
+class NormalizedBoxCenterEncoder(BoxEncoder):
+    """Encode matched reference boxes against anchors as center offsets (scaled
+    by anchor size) and log size ratios, each divided by its ``stds`` entry.
+    """
+    def __init__(self, stds=(0.1, 0.1, 0.2, 0.2)):
+        super(NormalizedBoxCenterEncoder, self).__init__()
+        assert len(stds) == 4, "Box Encoder requires 4 std values."
+        self._stds = stds
+        with self.name_scope():
+            self.corner_to_center = CornerToCenterBox(split=True)
+
+    def forward(self, samples, matches, anchors, refs, *args, **kwargs):
+        # TODO(zhreshold): batch_pick, take multiple elements?
+        # repeat the references matches.shape[1] times along axis 1 before picking
+        tiled = nd.repeat(refs.reshape((0, 1, -1, 4)), axis=1, repeats=matches.shape[1])
+        coords = nd.split(tiled, axis=-1, num_outputs=4, squeeze_axis=True)
+        picked = [nd.pick(c, matches, axis=2).reshape((0, -1, 1)) for c in coords]
+        gx, gy, gw, gh = self.corner_to_center(nd.concat(*picked, dim=2))
+        ax, ay, aw, ah = self.corner_to_center(anchors)
+        sx, sy, sw, sh = self._stds
+        offsets = [(gx - ax) / aw / sx, (gy - ay) / ah / sy,
+                   nd.log(gw / aw) / sw, nd.log(gh / ah) / sh]
+        codecs = nd.concat(*offsets, dim=2)
+        # only entries whose sample flag exceeds 0.5 keep a target and a mask of 1
+        positive = nd.tile(samples.reshape((0, -1, 1)), reps=(1, 1, 4)) > 0.5
+        targets = nd.where(positive, codecs, nd.zeros_like(codecs))
+        masks = nd.where(positive, nd.ones_like(positive), nd.zeros_like(positive))
+        return targets, masks
+
+
+@register
+@alias('rcnn_box_decoder')
+class NormalizedBoxCenterDecoder(HybridBoxDecoder):
+    """Decode predicted offsets (scaled by ``stds``) relative to anchors back
+    into corner boxes [xmin, ymin, xmax, ymax].
+    """
+    def __init__(self, stds=(0.1, 0.1, 0.2, 0.2)):
+        super(NormalizedBoxCenterDecoder, self).__init__()
+        assert len(stds) == 4, "Box Decoder requires 4 std values."
+        self._stds = stds
+        with self.name_scope():
+            self.corner_to_center = CornerToCenterBox(split=True)
+
+    def hybrid_forward(self, F, x, anchors, *args, **kwargs):
+        a = self.corner_to_center(anchors)  # (center_x, center_y, width, height)
+        p = F.split(x, axis=2, num_outputs=4)
+        ox = p[0] * self._stds[0] * a[2] + a[0]
+        oy = p[1] * self._stds[1] * a[3] + a[1]
+        ow = F.exp(p[2] * self._stds[2]) * a[2] / 2  # half of the decoded width
+        oh = F.exp(p[3] * self._stds[3]) * a[3] / 2  # half of the decoded height
+        return F.concat(ox - ow, oy - oh, ox + ow, oy + oh, dim=2)
+
+@register
+@alias('plus1_class_encoder')
+class MultiClassEncoder(ClassEncoder):
+    """Assign class targets: matched class id + 1 where samples > 0.5, 0 where
+    samples < -0.5, and ``ignore_label`` everywhere else.
+    """
+    def __init__(self, ignore_label=-1):
+        super(MultiClassEncoder, self).__init__()
+        self._ignore_label = ignore_label
+
+    def forward(self, samples, matches, refs, *args, **kwargs):
+        tiled = nd.repeat(refs.reshape((0, 1, -1)), axis=1, repeats=matches.shape[1])
+        shifted_ids = nd.pick(tiled, matches, axis=2) + 1  # shift ids up by one
+        ignored = nd.ones_like(shifted_ids) * self._ignore_label
+        labelled = nd.where(samples > 0.5, shifted_ids, ignored)
+        return nd.where(samples < -0.5, nd.zeros_like(labelled), labelled)
+
+@register
+@alias('plus1_class_decoder')
+class MultiClassDecoder(HybridClassDecoder):
+    """Pick the highest-scoring class after skipping index 0 along ``axis``;
+    returned ids are therefore 0-based over the remaining classes.
+    """
+    def __init__(self, axis=-1):
+        super(MultiClassDecoder, self).__init__()
+        self._axis = axis
+
+    def hybrid_forward(self, F, x, *args, **kwargs):
+        # end=None keeps every class after index 0 (end=-1 dropped the last one)
+        pos_x = x.slice_axis(axis=self._axis, begin=1, end=None)
+        cls_id = F.argmax(pos_x, self._axis)
+        return cls_id, F.pick(pos_x, cls_id, axis=self._axis)
diff --git a/example/object-detection/block/feature.py b/example/object-detection/block/feature.py
new file mode 100644
index 0000000000..6b7a78c61f
--- /dev/null
+++ b/example/object-detection/block/feature.py
@@ -0,0 +1,114 @@
+"""Feature extraction blocks.
+Feature or Multi-Feature extraction is a key component in object detection.
+Class predictor/Box predictor are usually applied on feature layer(s).
+A good feature extraction mechanism is critical to performance.
+"""
+import mxnet as mx
+from mxnet.symbol import Symbol
+from mxnet.gluon import HybridBlock, SymbolBlock
+from mxnet.gluon import nn
+from mxnet.gluon.model_zoo import vision
+from mxnet.base import string_types
+
+def parse_network(network, outputs, inputs, pretrained, ctx):
+    """Parse network with specified outputs and other arguments.
+
+    Parameters
+    ----------
+    network : str or HybridBlock or Symbol
+        Logic chain: load from gluon.model_zoo.vision if network is string.
+        Convert to Symbol if network is HybridBlock
+    outputs : str or list of str
+        The name of layers to be extracted as features
+    inputs : str, list of str or list of Symbol
+        Network inputs; names become variables. The argument itself is not modified.
+    pretrained : bool
+        Use pretrained parameters as in gluon.model_zoo
+    ctx : Context
+        The context
+
+    Returns
+    -------
+    inputs : Symbol
+        The single input Symbol, or a Group when several inputs are given
+    outputs : list of Symbol
+        Network output Symbols, usually as features
+    params : ParameterDict
+        Network parameters, or None when ``network`` is already a Symbol.
+    """
+    inputs = [inputs] if isinstance(inputs, string_types) else list(inputs)  # private copy
+    for i, item in enumerate(inputs):
+        if isinstance(item, string_types):
+            inputs[i] = mx.sym.var(item)
+        assert isinstance(inputs[i], Symbol), "Network expects inputs are Symbols."
+    inputs = inputs[0] if len(inputs) == 1 else mx.sym.Group(inputs)
+    params = None
+    if isinstance(network, string_types):
+        network = vision.get_model(network, pretrained=pretrained, ctx=ctx, prefix='')
+    if isinstance(network, HybridBlock):
+        params = network.collect_params()
+        network = network(inputs)
+    assert isinstance(network, Symbol), \
+        "FeatureExtractor requires the network argument to be either " \
+        "str, HybridBlock or Symbol, but got %s"%type(network)
+
+    if isinstance(outputs, string_types):
+        outputs = [outputs]
+    assert len(outputs) > 0, "At least one outputs must be specified."
+    outputs = [out if out.endswith('_output') else out + '_output' for out in outputs]
+    outputs = [network.get_internals()[out] for out in outputs]
+    return inputs, outputs, params
+
+
+class FeatureExtractor(SymbolBlock):
+    """Feature extractor.
+
+    Parameters
+    ----------
+    network : str or HybridBlock or Symbol
+        Model-zoo name (gluon.model_zoo.vision), HybridBlock, or Symbol to extract from.
+    outputs : str or list of str
+        The name of layers to be extracted as features
+    inputs : str, list of str or list of Symbol
+        The inputs of network. Copied before use, so the argument is never modified.
+    pretrained : bool
+        Use pretrained parameters as in gluon.model_zoo
+    ctx : Context
+        The context
+    """
+    def __init__(self, network, outputs, inputs=('data',), pretrained=False, ctx=mx.cpu()):
+        inputs = [inputs] if isinstance(inputs, string_types) else list(inputs)
+        inputs, outputs, params = parse_network(network, outputs, inputs, pretrained, ctx)
+        super(FeatureExtractor, self).__init__(outputs, inputs, params=params)
+
+
+class FeatureExpander(SymbolBlock):
+    """Feature extractor with additional layers to append.
+    This is very common in SSD networks.
+    Each entry of ``num_filters`` appends one stride-2 3x3 conv stage as an extra output.
+    """
+    def __init__(self, network, outputs, num_filters, use_1x1_transition=True,
+                 use_bn=True, reduce_ratio=1.0, min_depth=128, global_pool=False,
+                 pretrained=False, ctx=mx.cpu(), inputs=('data',)):
+        inputs = [inputs] if isinstance(inputs, string_types) else list(inputs)  # fresh list
+        inputs, outputs, params = parse_network(network, outputs, inputs, pretrained, ctx)
+        y = outputs[-1]  # extra layers are chained onto the last extracted feature
+        for i, f in enumerate(num_filters):
+            if use_1x1_transition:
+                num_trans = max(min_depth, int(round(f * reduce_ratio)))
+                y = mx.sym.Convolution(
+                    y, num_filter=num_trans, kernel=(1, 1), no_bias=use_bn,
+                    name='expand_trans_conv{}'.format(i))
+                if use_bn:
+                    y = mx.sym.BatchNorm(y, name='expand_trans_bn{}'.format(i))
+                y = mx.sym.Activation(y, act_type='relu', name='expand_trans_relu{}'.format(i))
+            y = mx.sym.Convolution(
+                y, num_filter=f, kernel=(3, 3), pad=(1, 1), stride=(2, 2),
+                name='expand_conv{}'.format(i))
+            if use_bn:
+                y = mx.sym.BatchNorm(y, name='expand_bn{}'.format(i))
+            y = mx.sym.Activation(y, act_type='relu', name='expand_reu{}'.format(i))
+            outputs.append(y)
+        if global_pool:
+            outputs.append(mx.sym.Pooling(y, pool_type='avg', global_pool=True, kernel=(1, 1)))
+        super(FeatureExpander, self).__init__(outputs, inputs, params)
diff --git a/example/object-detection/block/loss.py b/example/object-detection/block/loss.py
new file mode 100644
index 0000000000..252f33b4e6
--- /dev/null
+++ b/example/object-detection/block/loss.py
@@ -0,0 +1,162 @@
+"""Custom losses for object detection.
+Losses are used to penalize incorrect classification and inaccurate box regression.
+Losses are subclasses of gluon.loss.Loss, which is itself a HybridBlock.
+"""
+from mxnet.gluon import loss
+from mxnet.gluon.loss import _reshape_like, _apply_weighting
+import numpy as np
+
+def find_inf(x, mark='null'):
+    """Debugging aid: print ``mark`` followed by the flat indices where ``x`` is +inf."""
+    print(mark, np.flatnonzero(x.asnumpy().ravel() == np.inf))
+
+
+class SmoothL1Loss(loss.Loss):
+    """Smooth-L1 regression loss applied to ``pred - label``.
+    ``sigma`` is passed to ``F.smooth_l1`` as ``scalar``.
+    The loss is averaged (``size_average=True``) or summed over every axis
+    except ``batch_axis``.
+    """
+    def __init__(self, sigma=1., weight=None, batch_axis=0, size_average=True, **kwargs):
+        super(SmoothL1Loss, self).__init__(weight, batch_axis, **kwargs)
+        self._sigma = sigma
+        self._size_average = size_average
+
+    def hybrid_forward(self, F, pred, label, sample_weight=None):
+        residual = pred - _reshape_like(F, label, pred)
+        penalty = F.smooth_l1(residual, scalar=self._sigma)
+        penalty = _apply_weighting(F, penalty, self._weight, sample_weight)
+        # F.mean and F.sum take the same arguments, so choose the reducer once
+        reduce_fn = F.mean if self._size_average else F.sum
+        return reduce_fn(penalty, axis=self._batch_axis, exclude=True)
+
+
+class SoftmaxCrossEntropyLoss(loss.Loss):
+    r"""Computes the softmax cross entropy loss.
+
+    If `sparse_label` is `True` (default), label should contain integer
+    category indicators:
+
+    .. math::
+
+        \DeclareMathOperator{softmax}{softmax}
+
+        p = \softmax({pred})
+
+        L = -\sum_i \log p_{i,{label}_i}
+
+    `label`'s shape should be `pred`'s shape with the `axis` dimension removed.
+    i.e. for `pred` with shape (1,2,3,4) and `axis = 2`, `label`'s shape should
+    be (1,2,4).
+
+    If `sparse_label` is `False`, `label` should contain probability distribution
+    and `label`'s shape should be the same as `pred`:
+
+    .. math::
+
+        p = \softmax({pred})
+
+        L = -\sum_i \sum_j {label}_j \log p_{ij}
+
+    Parameters
+    ----------
+    axis : int, default -1
+        The axis to sum over when computing softmax and entropy.
+    sparse_label : bool, default True
+        Whether label is an integer array instead of probability distribution.
+    from_logits : bool, default False
+        Whether input is a log probability (usually from log_softmax) instead
+        of unnormalized numbers.
+    weight : float or None
+        Global scalar weight for loss.
+    batch_axis : int, default 0
+        The axis that represents mini-batch.
+    ignore_label : int, default -1
+        The label to be ignored for calculating loss.
+        Only used when `sparse_label` is True; those entries get zero loss.
+    size_average : bool, default True
+        Average (True) or sum (False) the loss over the non-batch axes.
+
+    Inputs:
+        - **pred**: the prediction tensor, where the `batch_axis` dimension
+          ranges over batch size and `axis` dimension ranges over the number
+          of classes.
+        - **label**: the truth tensor. When `sparse_label` is True, `label`'s
+          shape should be `pred`'s shape with the `axis` dimension removed.
+          i.e. for `pred` with shape (1,2,3,4) and `axis = 2`, `label`'s shape
+          should be (1,2,4) and values should be integers between 0 and 2. If
+          `sparse_label` is False, `label`'s shape must be the same as `pred`
+          and values should be floats in the range `[0, 1]`.
+        - **sample_weight**: element-wise weighting tensor. Must be broadcastable
+          to the same shape as label. For example, if label has shape (64, 10)
+          and you want to weigh each sample in the batch separately,
+          sample_weight should have shape (64, 1).
+
+    Outputs:
+        - **loss**: loss tensor with shape (batch_size,). Dimensions other than
+          batch_axis are averaged out, or summed when `size_average` is False.
+    """
+    def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
+                 batch_axis=0, ignore_label=-1, size_average=True, **kwargs):
+        super(SoftmaxCrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
+        self._axis = axis
+        self._sparse_label = sparse_label
+        self._from_logits = from_logits
+        self._ignore_label = ignore_label
+        self._size_average = size_average
+
+    def hybrid_forward(self, F, pred, label, sample_weight=None):
+        log_prob = pred if self._from_logits else F.log_softmax(pred, self._axis)
+        if self._sparse_label:
+            picked = -F.pick(log_prob, label, axis=self._axis, keepdims=True)
+            # positions labelled ``ignore_label`` are zeroed out of the loss
+            ignored = label.expand_dims(axis=self._axis) == self._ignore_label
+            nll = F.where(ignored, F.zeros_like(picked), picked)
+        else:
+            dense = _reshape_like(F, label, log_prob)
+            nll = -F.sum(log_prob * dense, axis=self._axis, keepdims=True)
+        nll = _apply_weighting(F, nll, self._weight, sample_weight)
+        reduce_fn = F.mean if self._size_average else F.sum
+        return reduce_fn(nll, axis=self._batch_axis, exclude=True)
+
+
+
+class FocalLoss(loss.Loss):
+    """Focal loss for imbalanced classification (https://arxiv.org/abs/1708.02002).
+
+    Parameters
+    ----------
+    alpha, gamma : float
+        Weight for positive entries (1 - alpha otherwise) and exponent on (1 - pt).
+    num_class : int
+        Required (>= 1) when ``sparse_label`` is True; depth of the one-hot encoding.
+    """
+    def __init__(self, axis=-1, alpha=0.25, gamma=2, sparse_label=True,
+                 from_logits=False, batch_axis=0, weight=None, num_class=None,
+                 eps=1e-12, size_average=True, **kwargs):
+        super(FocalLoss, self).__init__(weight, batch_axis, **kwargs)
+        self._axis = axis
+        self._alpha = alpha
+        self._gamma = gamma
+        self._sparse_label = sparse_label
+        if sparse_label and (not isinstance(num_class, int) or (num_class < 1)):
+            raise ValueError("Number of class > 0 must be provided if sparse label is used.")
+        self._num_class = num_class
+        self._from_logits = from_logits
+        self._eps = eps
+        self._size_average = size_average
+
+    def hybrid_forward(self, F, output, label, sample_weight=None):
+        prob = output if self._from_logits else F.sigmoid(output)
+        if self._sparse_label:
+            target = F.one_hot(label, self._num_class)
+        else:
+            target = label > 0
+        p_t = F.where(target, prob, 1 - prob)
+        ones = F.ones_like(target)
+        balance = F.where(target, self._alpha * ones, (1 - self._alpha) * ones)
+        log_pt = F.log(F.minimum(p_t + self._eps, 1))
+        focal = -balance * ((1 - p_t) ** self._gamma) * log_pt
+        focal = _apply_weighting(F, focal, self._weight, sample_weight)
+        reduce_fn = F.mean if self._size_average else F.sum
+        return reduce_fn(focal, axis=self._batch_axis, exclude=True)
diff --git a/example/object-detection/block/matcher.py b/example/object-detection/block/matcher.py
new file mode 100644
index 0000000000..d5e6ca48d0
--- /dev/null
+++ b/example/object-detection/block/matcher.py
@@ -0,0 +1,83 @@
+"""Matchers for target assignment.
+Matchers are commonly used in object-detection for anchor-groundtruth matching.
+The matching process is a prerequisite to training target assignment.
+Matching is usually not required during testing.
+"""
+from mxnet import gluon
+from .registry import register, alias, create
+
+
+@register
+class CompositeMatcher(gluon.HybridBlock):
+    """Matcher that chains several matching strategies in priority order.
+
+    Parameters
+    ----------
+    matchers : list of Matcher
+        Blocks/HybridBlocks applied to the same input; earlier ones take priority.
+    """
+    def __init__(self, matchers):
+        super(CompositeMatcher, self).__init__()
+        assert len(matchers) > 0, "At least one matcher required."
+        # every strategy must itself be a gluon block
+        assert all(isinstance(m, (gluon.Block, gluon.HybridBlock)) for m in matchers)
+        self._matchers = matchers
+
+    def hybrid_forward(self, F, x, *args, **kwargs):
+        per_matcher = [strategy(x) for strategy in self._matchers]
+        return self._compose_matches(F, per_matcher)
+
+    def _compose_matches(self, F, matches):
+        """Merge several match results into one.
+        Results are visited in order; an entry is taken from a later result
+        only while it is still unmatched (not greater than -0.5) so far.
+
+        Parameters
+        ----------
+        matches : list of NDArrays
+            Match results, one per matcher, in priority order.
+
+        Returns
+        -------
+         a single match result combining all inputs
+        """
+        merged = matches[0]
+        for fallback in matches[1:]:
+            merged = F.where(merged > -0.5, merged, fallback)
+        return merged
+
+
+@register
+class BipartiteMatcher(gluon.HybridBlock):
+    """Matcher built on the ``contrib.bipartite_matching`` operator.
+
+    Parameters
+    ----------
+    threshold : float
+        Passed to the operator as ``threshold``; used to ignore invalid paddings
+    is_ascend : bool
+        Whether to process matches in ascending order. Default is False.
+    """
+    def __init__(self, threshold=1e-12, is_ascend=False):
+        super(BipartiteMatcher, self).__init__()
+        self._threshold = threshold
+        self._is_ascend = is_ascend
+
+    def hybrid_forward(self, F, x, *args, **kwargs):
+        matcher_op = F.contrib.bipartite_matching
+        outputs = matcher_op(x, threshold=self._threshold, is_ascend=self._is_ascend)
+        return outputs[0]  # only the first output of the operator is used
+
+
+@register
+class MaximumMatcher(gluon.HybridBlock):
+    """Match each row to its arg-max on the last axis, or -1 if that score <= threshold."""
+    def __init__(self, threshold):
+        super(MaximumMatcher, self).__init__()
+        self._threshold = threshold
+
+    def hybrid_forward(self, F, x, *args, **kwargs):
+        best = F.argmax(x, axis=-1)
+        confident = F.pick(x, best, axis=-1) > self._threshold
+        unmatched = F.ones_like(best) * -1
+        return F.where(confident, best, unmatched)
diff --git a/example/object-detection/block/predictor.py b/example/object-detection/block/predictor.py
new file mode 100644
index 0000000000..6b317a3e16
--- /dev/null
+++ b/example/object-detection/block/predictor.py
@@ -0,0 +1,49 @@
+"""Predictor for classification/box prediction.
+
+"""
+from mxnet.gluon import HybridBlock
+from mxnet.gluon import nn
+
+
+class ConvPredictor(HybridBlock):
+    """Convolutional predictor.
+    Convolutional predictor is widely used in object-detection. It can be used
+    to predict classification scores (1 channel per class) or box predictor,
+    which is usually 4 channels per box.
+    The output is of shape (N, num_channel, H, W).
+
+    Parameters
+    ----------
+
+
+    """
+    def __init__(self, num_channel, kernel=(3, 3), pad=(1, 1), stride=(1, 1),
+                 activation=None, use_bias=True, **kwargs):
+        super(ConvPredictor, self).__init__(**kwargs)
+        with self.name_scope():
+            self.predictor = nn.Conv2D(
+                num_channel, kernel, strides=stride, padding=pad,
+                activation=activation, use_bias=use_bias)
+
+    def hybrid_forward(self, F, x, *args, **kwargs):
+        return self.predictor(x)
+
+
+class FCPredictor(HybridBlock):
+    """Fully connected predictor.
+    Fully connected predictor is used to ignore spatial information and will
+    output fixed-sized predictions.
+
+
+    Parameters
+    ----------
+
+    """
+    def __init__(self, num_output, activation=None, use_bias=True, **kwargs):
+        super(FCPredictor, self).__init__(**kwargs)
+        with self.name_scope():
+            self.predictor = nn.Dense(
+                num_output, activation=activation, use_bias=use_bias)
+
+    def hybrid_forward(self, F, x, *args, **kwargs):
+        return self.predictor(x)
diff --git a/example/object-detection/block/registry.py b/example/object-detection/block/registry.py
new file mode 100644
index 0000000000..cfcfbd990b
--- /dev/null
+++ b/example/object-detection/block/registry.py
@@ -0,0 +1,8 @@
+"""Registry wrapper for all detection blocks.
+"""
+from mxnet import registry
+from mxnet import gluon
+
+# Registry helpers, all bound to ``gluon.Block`` under the
+# 'object_detection' nickname: ``register`` is used as a class decorator by
+# the block modules, ``alias`` and ``create`` come from the same registry.
+register = registry.get_register_func(gluon.Block, 'object_detection')
+alias = registry.get_alias_func(gluon.Block, 'object_detection')
+create = registry.get_create_func(gluon.Block, 'object_detection')
diff --git a/example/object-detection/block/sampler.py b/example/object-detection/block/sampler.py
new file mode 100644
index 0000000000..4d06b0b9f4
--- /dev/null
+++ b/example/object-detection/block/sampler.py
@@ -0,0 +1,87 @@
+"""Samplers for positive/negative/ignore sample selections.
+This module is used to select samples during training.
+Based on different strategies, we would like to choose different number of
+samples as positive, negative or ignore(don't care). The purpose is to alleviate
+unbalanced training target in some circumstances.
+The output of sampler is an NDArray of the same shape as the matching results.
+Note: 1 for positive, -1 for negative, 0 for ignore.
+"""
+import numpy as np
+from mxnet import gluon
+from mxnet import nd
+from mxnet import autograd
+from .registry import register, alias, create
+
+
+class Sampler(gluon.Block):
+    """A Base class for standard samplers when hybrid_forward is not available."""
+    def __init__(self):
+        super(Sampler, self).__init__()
+
+
+class HybridSampler(gluon.HybridBlock):
+    """A Base class for hybrid implementation of Samplers."""
+    def __init__(self):
+        super(HybridSampler, self).__init__()
+
+
+@register
+class NaiveSampler(HybridSampler):
+    """A naive sampler that take all existing matching results.
+    There is no ignored sample in this case.
+    """
+    def __init__(self):
+        super(NaiveSampler, self).__init__()
+
+    def hybrid_forward(self, F, x, *args, **kwargs):
+        marker = F.ones_like(x)
+        y = F.where(x >= 0, marker, marker * -1)
+        return y
+
+
+@register
+class OHEMSampler(Sampler):
+    """A sampler implementing Online Hard-negative mining.
+    As described in paper https://arxiv.org/abs/1604.03540.
+
+    Parameters
+    ----------
+    ratio : float
+        Number of negatives to select per positive sample. Must be > 0.
+    min_samples : int
+        Lower bound on the number of negatives selected per batch entry.
+        Default is 0.
+    thresh : float
+        Entries whose IoU is not below this value are excluded from the
+        negative candidates. Default is 0.5.
+    """
+    def __init__(self, ratio, min_samples=0, thresh=0.5):
+        super(OHEMSampler, self).__init__()
+        assert ratio > 0, "OHEMSampler ratio must > 0, {} given".format(ratio)
+        self._ratio = ratio
+        self._min_samples = min_samples
+        self._thresh = thresh
+
+    def forward(self, x, logits, ious, *args):
+        """Select positive and hard negative samples.
+
+        Parameters
+        ----------
+        x : NDArray
+            Match results; entries >= 0 are treated as positive.
+            Presumably of shape (B, N) -- TODO confirm against callers.
+        logits : NDArray
+            Class scores, sliced along axis 2; channel 0 is used as the
+            background score. Presumably (B, N, C) -- TODO confirm.
+        ious : NDArray
+            IoU values; a 3-D input is reduced with max over axis 2.
+
+        Returns
+        -------
+        NDArray with the same shape as `x`, on the context of `x`:
+        1 for positive, -1 for selected negative, 0 otherwise (ignored).
+        """
+        F = nd
+        # per-row count of matched entries
+        num_positive = F.sum(x > -1, axis=1)
+        num_negative = self._ratio * num_positive
+        num_total = x.shape[1]  # scalar
+        # at least min_samples negatives, at most all non-positive entries
+        num_negative = F.minimum(F.maximum(self._min_samples, num_negative),
+                                 num_total - num_positive)
+        # NOTE(review): end=-1 appears to exclude the last class channel; since
+        # `positive` only feeds `maxval` (a stabilising shift that cancels out in
+        # the ratio below) this looks harmless, but end=None may have been
+        # intended -- confirm.
+        positive = logits.slice_axis(axis=2, begin=1, end=-1)
+        background = logits.slice_axis(axis=2, begin=0, end=1).reshape((0, -1))
+        maxval = positive.max(axis=2)
+        esum = F.exp(logits - maxval.reshape((0, 0, 1))).sum(axis=2)
+        # negative log of the softmax probability of the background channel
+        score = -F.log(F.exp(background - maxval) / esum)
+        mask = F.ones_like(score) * -1
+        score = F.where(x < 0, score, mask)  # mask out positive samples
+        if len(ious.shape) == 3:
+            ious = F.max(ious, axis=2)
+        score = F.where(ious < self._thresh, score, mask)  # mask out if iou is large
+        # indices sorted by descending score: hardest negatives first
+        argmaxs = F.argsort(score, axis=1, is_ascend=False)
+
+        # neg number is different in each batch, using dynamic numpy operations.
+        y = np.zeros(x.shape)
+        y[np.where(x.asnumpy() >= 0)] = 1  # assign positive samples
+        argmaxs = argmaxs.asnumpy()
+        for i, num_neg in zip(range(x.shape[0]), num_negative.asnumpy().astype(np.int32)):
+            # NOTE(review): when num_neg exceeds the number of unmasked scores,
+            # masked entries (score -1, which includes positives) can be picked
+            # here and overwritten with -1 -- confirm whether this can occur.
+            indices = argmaxs[i, :num_neg]
+            y[i, indices.astype(np.int32)] = -1  # assign negative samples
+        return F.array(y, ctx=x.context)
diff --git a/example/object-detection/block/target.py b/example/object-detection/block/target.py
new file mode 100644
index 0000000000..ceefe97880
--- /dev/null
+++ b/example/object-detection/block/target.py
@@ -0,0 +1,41 @@
+"""Target generator for training.
+Target generator is used to generate training targets, given anchors, ground-truths,
+match results and sampler.
+"""
+from mxnet import nd
+from mxnet import autograd
+from mxnet.gluon import Block
+from block.matcher import CompositeMatcher, BipartiteMatcher, MaximumMatcher
+from block.sampler import NaiveSampler, OHEMSampler
+from block.coder import MultiClassEncoder, NormalizedBoxCenterEncoder
+import numpy as np
+
+class SSDTargetGenerator(Block):
+    """Training target generator for SSD.
+
+    Chains a composite matcher (bipartite matching, then maximum matching),
+    an OHEM sampler, a class encoder and a box encoder.
+
+    Parameters
+    ----------
+    threshold : float
+        Threshold passed to ``MaximumMatcher``. Default is 0.5.
+        Note that the OHEM sampler uses a hard-coded ratio of 3 and
+        ``thresh=0.5`` regardless of this value.
+    **kwargs
+        Forwarded to the ``Block`` constructor.
+    """
+    def __init__(self, threshold=0.5, **kwargs):
+        super(SSDTargetGenerator, self).__init__(**kwargs)
+        self._matcher = CompositeMatcher([BipartiteMatcher(), MaximumMatcher(threshold)])
+        # self._sampler = NaiveSampler()
+        self._sampler = OHEMSampler(3, thresh=0.5)
+        self._cls_encoder = MultiClassEncoder()
+        self._box_encoder = NormalizedBoxCenterEncoder()
+
+    def forward(self, predictions, labels):
+        """Generate targets from network predictions and ground-truth labels.
+
+        Parameters
+        ----------
+        predictions : list
+            [cls_preds, box_preds, anchors]; only cls_preds and anchors are
+            read here.
+        labels : NDArray
+            Ground truths; along the last axis, index 0 is read as the class
+            id and indices 1..4 as the box.
+
+        Returns
+        -------
+        (cls_targets, box_targets, box_masks)
+        """
+        # predictions: [cls_preds, box_preds, anchors]
+        anchors = predictions[2].reshape((-1, 4))
+        gt_boxes = nd.slice_axis(labels, axis=-1, begin=1, end=5)
+        gt_ids = nd.slice_axis(labels, axis=-1, begin=0, end=1)
+        # presumably moves the batch axis first, giving (B, N, M) -- TODO confirm
+        ious = nd.transpose(nd.contrib.box_iou(anchors, gt_boxes), (1, 0, 2))
+        matches = self._matcher(ious)
+        samples = self._sampler(matches, predictions[0], ious)
+        cls_targets = self._cls_encoder(samples, matches, gt_ids)
+        box_targets, box_masks = self._box_encoder(samples, matches, anchors, gt_boxes)
+        # Debug leftovers comparing against nd.contrib.MultiBoxTarget:
+        # print('box-targets', box_targets[0], 'box-masks', box_masks[0])
+        # ref = nd.contrib.MultiBoxTarget(*[predictions[2], labels, predictions[0].transpose(axes=(0, 2, 1))], negative_mining_ratio=3)
+        # loc_target, loc_mask, ref_cls_target = ref
+        # print('diff', np.sum(np.abs(ref_cls_target.asnumpy().flatten() - cls_targets.asnumpy().flatten())))
+        # print('diff2', np.sum(np.abs(loc_target.asnumpy().flatten() - box_targets.asnumpy().flatten())))
+        # print('diff3', np.sum(np.abs(loc_mask.asnumpy().flatten() - box_masks.asnumpy().flatten())))
+        return cls_targets, box_targets, box_masks
diff --git a/example/object-detection/builder/__init__.py b/example/object-detection/builder/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/example/object-detection/builder/model_builder.py b/example/object-detection/builder/model_builder.py
new file mode 100644
index 0000000000..e946061458
--- /dev/null
+++ b/example/object-detection/builder/model_builder.py
@@ -0,0 +1,2 @@
+"""Build network by configurations."""
+from model_zoo import *
diff --git a/example/object-detection/config/__init__.py b/example/object-detection/config/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/example/object-detection/config/config.py b/example/object-detection/config/config.py
new file mode 100644
index 0000000000..f739b93c90
--- /dev/null
+++ b/example/object-detection/config/config.py
@@ -0,0 +1,55 @@
+from __future__ import absolute_import
+import yaml
+import collections
+import os.path as osp
+import logging
+
+CONFIG = {}
+DEFAULT_CONFIG = osp.join(osp.dirname(__file__), 'default.yml')
+
+
+def get_config():
+    """Grab the config as dict()."""
+    return CONFIG
+
+def load_config(cfg_file):
+    """Update configurations with new config file."""
+    cfg = get_config()
+    with open(cfg_file, 'r') as inf:
+        new_cfg = yaml.load(inf)
+    if new_cfg:
+        cfg.update(new_cfg)
+    else:
+        logging.warning('Nothing loaded from %s', cfg_file)
+    return cfg
+
+def update_config(new_cfg):
+    """Update configs with dict."""
+    def recursive_update(d, u, log=''):
+        for k, v in u.items():
+            if isinstance(d, collections.Mapping):
+                if not k in d:
+                    logging.warning('%s is not in default config, is it on purpose?',
+                                    log + '.' + str(k))
+                if isinstance(v, collections.Mapping):
+                    r = recursive_update(d.get(k, {}), v, log=log + '.' + str(k))
+                    d[k] = r
+                else:
+                    d[k] = u[k]
+            else:
+                logging.warning('%s is not a parent', str(d))
+        return d
+    cfg = get_config()
+    recursive_update(cfg, new_cfg)
+    return cfg
+
+def save_config(filename):
+    """Save current configuration to file."""
+    with open(filename, 'w') as fout:
+        fout.write(yaml.dump(get_config(), default_flow_style=False))
+
+def dump_config():
+    return yaml.dump(get_config(), default_flow_style=False)
+
+# load the default configurations
+# (load_config() updates the dict returned by get_config() in place and
+# returns that same dict, so CONFIG keeps referring to the same object)
+CONFIG = load_config(DEFAULT_CONFIG)
diff --git a/example/object-detection/config/default.yml b/example/object-detection/config/default.yml
new file mode 100644
index 0000000000..4150336a7a
--- /dev/null
+++ b/example/object-detection/config/default.yml
@@ -0,0 +1,107 @@
+# This is the default configuration file for object detection API
+# The structure of the configuration is listed as follows:
+# --- dataset
+#  |- network
+#  |- augmentation
+
+# dataset name, directory, # class, class names, etc...
+dataset:
+  name: pascal_0712_trainval
+  directory: './data/'
+  num_class: 20
+
+network:
+  name: resnet50_v1
+  pretrained: yes
+  feature_extractor:
+    layers:
+      - '_plus15'
+      - ''
+      - ''
+      - ''
+      - ''
+      - ''
+    channels:
+      - 0
+      - 256
+      - 256
+      - 256
+      - 256
+      - 256
+    min_channel: 128
+    use_1x1_conv: yes
+    channel_multiplier: 0.5
+  arch: ssd
+  ssd:
+    anchors:
+      num_layers: 6
+      min_scale: 0.2
+      max_scale: 0.95
+      aspect_ratios:
+        - 1.0
+        - 2.0
+        - 0.5
+        - 3.0
+        - 0.33333
+      reduce_ratio_lowerst_layer: 3  # use 3 ratios in the lowest layer instead of 5
+      clip: no
+      verbose: no
+    matcher:
+      positive_threshold: 0.5
+      negative_threshold: 0.5
+      negative_mining_ratio: 3.0
+      minimum_negative_samples: 0
+      variances:
+        - 0.1
+        - 0.1
+        - 0.2
+        - 0.2
+
+# augmentation is applied to all images in iterators
+augmentation:
+  # training augmentation list
+  train:
+    horizontal_flip: yes
+    resize: 0
+    rand_crop: 0.8
+    rand_pad: 0.8
+    rand_gray: 0.05
+    rand_mirror: yes
+    mean: yes
+    std: yes
+    brightness: 0.05
+    contrast: 0.05
+    saturation: 0.05
+    pca_noise: 0.05
+    hue: 0.05
+    inter_method: 10
+    min_object_covered: 0.75
+    aspect_ratio_range:
+      - 0.5
+      - 2.0
+    area_range:
+      - 0.3
+      - 3.0
+    min_eject_coverage: 0.4
+    max_attempts: 100
+    pad_val:
+      - 127.0
+      - 127.0
+      - 127.0
+
+  # validation augmentation is turned off
+  val:
+    horizontal_flip: no
+    resize: 0
+    rand_crop: 0
+    rand_pad: 0
+    rand_gray: 0
+    rand_mirror: no
+    mean: yes
+    std: yes
+    brightness: 0
+    contrast: 0
+    saturation: 0
+    pca_noise: 0
+    hue: 0
+    inter_method: 2
diff --git a/example/object-detection/data/README.md b/example/object-detection/data/README.md
new file mode 100644
index 0000000000..256b9c7a9f
--- /dev/null
+++ b/example/object-detection/data/README.md
@@ -0,0 +1 @@
+Dataset directory
diff --git a/example/object-detection/dataset/README.md b/example/object-detection/dataset/README.md
new file mode 100644
index 0000000000..d0b58e3b0d
--- /dev/null
+++ b/example/object-detection/dataset/README.md
@@ -0,0 +1 @@
+### Prepare dataset for object Detection
diff --git a/example/object-detection/dataset/__init__.py b/example/object-detection/dataset/__init__.py
new file mode 100644
index 0000000000..a67230ab96
--- /dev/null
+++ b/example/object-detection/dataset/__init__.py
@@ -0,0 +1,2 @@
+"""Dataset gallery."""
+from dataset.voc import VOCDetection
diff --git a/example/object-detection/dataset/base.py b/example/object-detection/dataset/base.py
new file mode 100644
index 0000000000..7ad0cc2bc6
--- /dev/null
+++ b/example/object-detection/dataset/base.py
@@ -0,0 +1,25 @@
+"""Base detection dataset methods."""
+import os
+from mxnet.gluon.data import dataset
+
+
+class DetectionDataset(dataset.Dataset):
+    """Base detection Dataset.
+
+    Parameters
+    ----------
+    name : str
+        The name of dataset, by default, dataset/names/{}.names will be loaded,
+        where names of classes is defined.
+    root : str
+        The root path of xxx.names, by defaut is 'dataset/names/'
+    """
+    def __init__(self, name, root=None):
+        if root is None:
+            root = os.path.join(os.path.dirname(__file__), 'names')
+        else:
+            assert isinstance(root, str), "Provided root must be str"
+        name_path = os.path.join(root, name + '.names')
+        with open(name_path, 'r') as f:
+            self.classes = [line.strip() for line in f.readlines()]
+        self.num_classes = len(self.classes)
diff --git a/example/object-detection/dataset/coco.py b/example/object-detection/dataset/coco.py
new file mode 100644
index 0000000000..4b291f5974
--- /dev/null
+++ b/example/object-detection/dataset/coco.py
@@ -0,0 +1,3 @@
+"""MS COCO dataset."""
+from mxnet.gluon.data import dataset
+import os
diff --git a/example/object-detection/dataset/dataloader.py b/example/object-detection/dataset/dataloader.py
new file mode 100644
index 0000000000..c9fedd80ce
--- /dev/null
+++ b/example/object-detection/dataset/dataloader.py
@@ -0,0 +1,83 @@
+"""Dataset generator."""
+
+import numpy as np
+from mxnet.gluon.data import sampler as _sampler
+from mxnet import nd
+
+
+def _batchify(data):
+    """Collate data into batch."""
+    if isinstance(data[0], nd.NDArray):
+        return nd.stack(*data)
+    elif isinstance(data[0], tuple):
+        data = zip(*data)
+        return [_batchify(i) for i in data]
+    else:
+        data = np.asarray(data)
+        # for l in data:
+        #     print(l)
+        # padding the labels
+        batch_size = len(data)
+        pad = max([l.shape[0] for l in data])
+        buf = np.full((batch_size, pad, data[0].shape[-1]), -1, dtype=data[0].dtype)
+        for i, l in enumerate(data):
+            buf[i][:l.shape[0], :] = l
+        return nd.array(buf, dtype=data[0].dtype)
+
+
+class DataLoader(object):
+    """Loads data from a dataset and returns mini-batches of data.
+
+    Parameters
+    ----------
+    dataset : Dataset
+        Source dataset. Note that numpy and mxnet arrays can be directly used
+        as a Dataset.
+    batch_size : int
+        Size of mini-batch.
+    shuffle : bool
+        Whether to shuffle the samples.
+    sampler : Sampler
+        The sampler to use. Either specify sampler or shuffle, not both.
+    last_batch : {'keep', 'discard', 'rollover'}
+        How to handle the last batch if batch_size does not evenly divide
+        `len(dataset)`.
+
+        keep - A batch with less samples than previous batches is returned.
+        discard - The last batch is discarded if its incomplete.
+        rollover - The remaining samples are rolled over to the next epoch.
+    batch_sampler : Sampler
+        A sampler that returns mini-batches. Do not specify batch_size,
+        shuffle, sampler, and last_batch if batch_sampler is specified.
+    """
+    def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
+                 last_batch=None, batch_sampler=None):
+        self._dataset = dataset
+
+        if batch_sampler is None:
+            if batch_size is None:
+                raise ValueError("batch_size must be specified unless " \
+                                 "batch_sampler is specified")
+            if sampler is None:
+                if shuffle:
+                    sampler = _sampler.RandomSampler(len(dataset))
+                else:
+                    sampler = _sampler.SequentialSampler(len(dataset))
+            elif shuffle:
+                raise ValueError("shuffle must not be specified if sampler is specified")
+
+            batch_sampler = _sampler.BatchSampler(
+                sampler, batch_size, last_batch if last_batch else 'keep')
+        elif batch_size is not None or shuffle or sampler is not None or \
+                last_batch is not None:
+            raise ValueError("batch_size, shuffle, sampler and last_batch must " \
+                             "not be specified if batch_sampler is specified.")
+
+        self._batch_sampler = batch_sampler
+
+    def __iter__(self):
+        for batch in self._batch_sampler:
+            yield _batchify([self._dataset[idx] for idx in batch])
+
+    def __len__(self):
+        return len(self._batch_sampler)
diff --git a/example/object-detection/dataset/names/coco.names b/example/object-detection/dataset/names/coco.names
new file mode 100644
index 0000000000..ca76c80b5b
--- /dev/null
+++ b/example/object-detection/dataset/names/coco.names
@@ -0,0 +1,80 @@
+person
+bicycle
+car
+motorbike
+aeroplane
+bus
+train
+truck
+boat
+traffic light
+fire hydrant
+stop sign
+parking meter
+bench
+bird
+cat
+dog
+horse
+sheep
+cow
+elephant
+bear
+zebra
+giraffe
+backpack
+umbrella
+handbag
+tie
+suitcase
+frisbee
+skis
+snowboard
+sports ball
+kite
+baseball bat
+baseball glove
+skateboard
+surfboard
+tennis racket
+bottle
+wine glass
+cup
+fork
+knife
+spoon
+bowl
+banana
+apple
+sandwich
+orange
+broccoli
+carrot
+hot dog
+pizza
+donut
+cake
+chair
+sofa
+pottedplant
+bed
+diningtable
+toilet
+tvmonitor
+laptop
+mouse
+remote
+keyboard
+cell phone
+microwave
+oven
+toaster
+sink
+refrigerator
+book
+clock
+vase
+scissors
+teddy bear
+hair drier
+toothbrush
diff --git a/example/object-detection/dataset/names/voc.names b/example/object-detection/dataset/names/voc.names
new file mode 100644
index 0000000000..8420ab35ed
--- /dev/null
+++ b/example/object-detection/dataset/names/voc.names
@@ -0,0 +1,20 @@
+aeroplane
+bicycle
+bird
+boat
+bottle
+bus
+car
+cat
+chair
+cow
+diningtable
+dog
+horse
+motorbike
+person
+pottedplant
+sheep
+sofa
+train
+tvmonitor
diff --git a/example/object-detection/dataset/pycocotools/README.md b/example/object-detection/dataset/pycocotools/README.md
new file mode 100644
index 0000000000..d358f53105
--- /dev/null
+++ b/example/object-detection/dataset/pycocotools/README.md
@@ -0,0 +1,2 @@
+This is a modified version of https://github.com/pdollar/coco python API.
+No `make` is required, but this will not support mask functions.
diff --git a/example/object-detection/dataset/pycocotools/__init__.py b/example/object-detection/dataset/pycocotools/__init__.py
new file mode 100644
index 0000000000..3f7d85bba8
--- /dev/null
+++ b/example/object-detection/dataset/pycocotools/__init__.py
@@ -0,0 +1 @@
+__author__ = 'tylin'
diff --git a/example/object-detection/dataset/pycocotools/coco.py b/example/object-detection/dataset/pycocotools/coco.py
new file mode 100644
index 0000000000..a8939f64a3
--- /dev/null
+++ b/example/object-detection/dataset/pycocotools/coco.py
@@ -0,0 +1,435 @@
+__author__ = 'tylin'
+__version__ = '2.0'
+# Interface for accessing the Microsoft COCO dataset.
+
+# Microsoft COCO is a large image dataset designed for object detection,
+# segmentation, and caption generation. pycocotools is a Python API that
+# assists in loading, parsing and visualizing the annotations in COCO.
+# Please visit http://mscoco.org/ for more information on COCO, including
+# for the data, paper, and tutorials. The exact format of the annotations
+# is also described on the COCO website. For example usage of the pycocotools
+# please see pycocotools_demo.ipynb. In addition to this API, please download both
+# the COCO images and annotations in order to run the demo.
+
+# An alternative to using the API is to load the annotations directly
+# into Python dictionary
+# Using the API provides additional utility functions. Note that this API
+# supports both *instance* and *caption* annotations. In the case of
+# captions not all functions are defined (e.g. categories are undefined).
+
+# The following API functions are defined:
+#  COCO       - COCO api class that loads COCO annotation file and prepare data structures.
+#  decodeMask - Decode binary mask M encoded via run-length encoding.
+#  encodeMask - Encode binary mask M using run-length encoding.
+#  getAnnIds  - Get ann ids that satisfy given filter conditions.
+#  getCatIds  - Get cat ids that satisfy given filter conditions.
+#  getImgIds  - Get img ids that satisfy given filter conditions.
+#  loadAnns   - Load anns with the specified ids.
+#  loadCats   - Load cats with the specified ids.
+#  loadImgs   - Load imgs with the specified ids.
+#  annToMask  - Convert segmentation in an annotation to binary mask.
+#  showAnns   - Display the specified annotations.
+#  loadRes    - Load algorithm results and create API for accessing them.
+#  download   - Download COCO images from mscoco.org server.
+# Throughout the API "ann"=annotation, "cat"=category, and "img"=image.
+# Help on each function can be accessed by: "help COCO>function".
+
+# See also COCO>decodeMask,
+# COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds,
+# COCO>getImgIds, COCO>loadAnns, COCO>loadCats,
+# COCO>loadImgs, COCO>annToMask, COCO>showAnns
+
+# Microsoft COCO Toolbox.      version 2.0
+# Data, paper, and tutorials available at:  http://mscoco.org/
+# Code written by Piotr Dollar and Tsung-Yi Lin, 2014.
+# Licensed under the Simplified BSD License [see bsd.txt]
+
+import json
+import time
+import matplotlib.pyplot as plt
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon
+import numpy as np
+import copy
+import itertools
+# from . import mask as maskUtils
+import os
+from collections import defaultdict
+import sys
+PYTHON_VERSION = sys.version_info[0]
+if PYTHON_VERSION == 2:
+    from urllib import urlretrieve
+elif PYTHON_VERSION == 3:
+    from urllib.request import urlretrieve
+
+class COCO:
+    def __init__(self, annotation_file=None):
+        """
+        Constructor of Microsoft COCO helper class for reading and visualizing annotations.
+        :param annotation_file (str): location of annotation file
+        :param image_folder (str): location to the folder that hosts images.
+        :return:
+        """
+        # load dataset
+        self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
+        self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
+        if not annotation_file == None:
+            print('loading annotations into memory...')
+            tic = time.time()
+            dataset = json.load(open(annotation_file, 'r'))
+            assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset))
+            print('Done (t={:0.2f}s)'.format(time.time()- tic))
+            self.dataset = dataset
+            self.createIndex()
+
+    def createIndex(self):
+        # create index
+        print('creating index...')
+        anns, cats, imgs = {}, {}, {}
+        imgToAnns,catToImgs = defaultdict(list),defaultdict(list)
+        if 'annotations' in self.dataset:
+            for ann in self.dataset['annotations']:
+                imgToAnns[ann['image_id']].append(ann)
+                anns[ann['id']] = ann
+
+        if 'images' in self.dataset:
+            for img in self.dataset['images']:
+                imgs[img['id']] = img
+
+        if 'categories' in self.dataset:
+            for cat in self.dataset['categories']:
+                cats[cat['id']] = cat
+
+        if 'annotations' in self.dataset and 'categories' in self.dataset:
+            for ann in self.dataset['annotations']:
+                catToImgs[ann['category_id']].append(ann['image_id'])
+
+        print('index created!')
+
+        # create class members
+        self.anns = anns
+        self.imgToAnns = imgToAnns
+        self.catToImgs = catToImgs
+        self.imgs = imgs
+        self.cats = cats
+
+    def info(self):
+        """
+        Print information about the annotation file.
+        :return:
+        """
+        for key, value in self.dataset['info'].items():
+            print('{}: {}'.format(key, value))
+
+    def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
+        """
+        Get ann ids that satisfy given filter conditions. default skips that filter
+        :param imgIds  (int array)     : get anns for given imgs
+               catIds  (int array)     : get anns for given cats
+               areaRng (float array)   : get anns for given area range (e.g. [0 inf])
+               iscrowd (boolean)       : get anns for given crowd label (False or True)
+        :return: ids (int array)       : integer array of ann ids
+        """
+        imgIds = imgIds if type(imgIds) == list else [imgIds]
+        catIds = catIds if type(catIds) == list else [catIds]
+
+        if len(imgIds) == len(catIds) == len(areaRng) == 0:
+            anns = self.dataset['annotations']
+        else:
+            if not len(imgIds) == 0:
+                lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns]
+                anns = list(itertools.chain.from_iterable(lists))
+            else:
+                anns = self.dataset['annotations']
+            anns = anns if len(catIds)  == 0 else [ann for ann in anns if ann['category_id'] in catIds]
+            anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]]
+        if not iscrowd == None:
+            ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
+        else:
+            ids = [ann['id'] for ann in anns]
+        return ids
+
+    def getCatIds(self, catNms=[], supNms=[], catIds=[]):
+        """
+        filtering parameters. default skips that filter.
+        :param catNms (str array)  : get cats for given cat names
+        :param supNms (str array)  : get cats for given supercategory names
+        :param catIds (int array)  : get cats for given cat ids
+        :return: ids (int array)   : integer array of cat ids
+        """
+        catNms = catNms if type(catNms) == list else [catNms]
+        supNms = supNms if type(supNms) == list else [supNms]
+        catIds = catIds if type(catIds) == list else [catIds]
+
+        if len(catNms) == len(supNms) == len(catIds) == 0:
+            cats = self.dataset['categories']
+        else:
+            cats = self.dataset['categories']
+            cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name']          in catNms]
+            cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms]
+            cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id']            in catIds]
+        ids = [cat['id'] for cat in cats]
+        return ids
+
+    def getImgIds(self, imgIds=[], catIds=[]):
+        '''
+        Get img ids that satisfy given filter conditions.
+        :param imgIds (int array) : get imgs for given ids
+        :param catIds (int array) : get imgs with all given cats
+        :return: ids (int array)  : integer array of img ids
+        '''
+        imgIds = imgIds if type(imgIds) == list else [imgIds]
+        catIds = catIds if type(catIds) == list else [catIds]
+
+        if len(imgIds) == len(catIds) == 0:
+            ids = self.imgs.keys()
+        else:
+            ids = set(imgIds)
+            for i, catId in enumerate(catIds):
+                if i == 0 and len(ids) == 0:
+                    ids = set(self.catToImgs[catId])
+                else:
+                    ids &= set(self.catToImgs[catId])
+        return list(ids)
+
+    def loadAnns(self, ids=[]):
+        """
+        Load anns with the specified ids.
+        :param ids (int array)       : integer ids specifying anns
+        :return: anns (object array) : loaded ann objects
+        """
+        if type(ids) == list:
+            return [self.anns[id] for id in ids]
+        elif type(ids) == int:
+            return [self.anns[ids]]
+
+    def loadCats(self, ids=[]):
+        """
+        Load cats with the specified ids.
+        :param ids (int array)       : integer ids specifying cats
+        :return: cats (object array) : loaded cat objects
+        """
+        if type(ids) == list:
+            return [self.cats[id] for id in ids]
+        elif type(ids) == int:
+            return [self.cats[ids]]
+
+    def loadImgs(self, ids=[]):
+        """
+        Load anns with the specified ids.
+        :param ids (int array)       : integer ids specifying img
+        :return: imgs (object array) : loaded img objects
+        """
+        if type(ids) == list:
+            return [self.imgs[id] for id in ids]
+        elif type(ids) == int:
+            return [self.imgs[ids]]
+
+    def showAnns(self, anns):
+        """
+        Display the specified annotations.
+        Instance annotations (carrying 'segmentation' or 'keypoints') are drawn on the
+        current matplotlib axes; caption annotations are printed. The kind of dataset
+        is decided from the first annotation only.
+        :param anns (array of object): annotations to display
+        :return: None (an empty input returns 0 instead)
+        :raises Exception: when the first annotation is neither an instance nor a caption
+        :raises NotImplementedError: for non-polygon (RLE) segmentations, since maskUtils is disabled
+        """
+        if len(anns) == 0:
+            return 0
+        # only anns[0] is inspected to classify the whole batch
+        if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
+            datasetType = 'instances'
+        elif 'caption' in anns[0]:
+            datasetType = 'captions'
+        else:
+            raise Exception('datasetType not supported')
+        if datasetType == 'instances':
+            ax = plt.gca()
+            ax.set_autoscale_on(False)
+            polygons = []
+            color = []
+            for ann in anns:
+                # one random RGB colour per annotation, each channel in [0.4, 1.0)
+                c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
+                if 'segmentation' in ann:
+                    if type(ann['segmentation']) == list:
+                        # polygon
+                        for seg in ann['segmentation']:
+                            # flat [x0, y0, x1, y1, ...] list -> (num_points, 2) vertex array
+                            poly = np.array(seg).reshape((int(len(seg)/2), 2))
+                            polygons.append(Polygon(poly))
+                            color.append(c)
+                    else:
+                        # mask
+                        # NOTE(review): this branch always raises NotImplementedError (either
+                        # just below or right after `rle` is set), so everything after the
+                        # second raise is unreachable and `m` is never defined.
+                        t = self.imgs[ann['image_id']]
+                        if type(ann['segmentation']['counts']) == list:
+                            # rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width'])
+                            raise NotImplementedError("maskUtils disabled!")
+                        else:
+                            rle = [ann['segmentation']]
+                        # m = maskUtils.decode(rle)
+                        raise NotImplementedError("maskUtils disabled!")
+                        img = np.ones( (m.shape[0], m.shape[1], 3) )
+                        if ann['iscrowd'] == 1:
+                            color_mask = np.array([2.0,166.0,101.0])/255
+                        if ann['iscrowd'] == 0:
+                            color_mask = np.random.random((1, 3)).tolist()[0]
+                        for i in range(3):
+                            img[:,:,i] = color_mask[i]
+                        ax.imshow(np.dstack( (img, m*0.5) ))
+                if 'keypoints' in ann and type(ann['keypoints']) == list:
+                    # turn skeleton into zero-based index
+                    sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1
+                    kp = np.array(ann['keypoints'])
+                    # keypoints are stored as repeating (x, y, v) triplets
+                    x = kp[0::3]
+                    y = kp[1::3]
+                    v = kp[2::3]
+                    for sk in sks:
+                        # draw a skeleton edge only when both of its endpoints have v > 0
+                        if np.all(v[sk]>0):
+                            plt.plot(x[sk],y[sk], linewidth=3, color=c)
+                    # points with v > 0 get a black outline; those with v > 1 are then
+                    # redrawn with an outline in the annotation colour
+                    plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2)
+                    plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
+            # polygons are added twice: a translucent fill, then an opaque outline
+            p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
+            ax.add_collection(p)
+            p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2)
+            ax.add_collection(p)
+        elif datasetType == 'captions':
+            for ann in anns:
+                print(ann['caption'])
+
+    def loadRes(self, resFile):
+        """
+        Load result file and return a result api object.
+        The result annotations are completed in place (ids, and depending on their kind
+        area / bbox / segmentation / iscrowd), so a list passed in by the caller is mutated.
+        :param   resFile (str, numpy.ndarray or list) : file name of a JSON result file,
+                 an [Nx7] array (see loadNumpyAnnotations), or an already loaded list
+        :return: res (obj)         : result api object
+        :raises AssertionError: if the results are not a list or refer to unknown image ids
+        :raises NotImplementedError: for segmentation-only results, since maskUtils is disabled
+        """
+        res = COCO()
+        res.dataset['images'] = [img for img in self.dataset['images']]
+
+        print('Loading and preparing results...')
+        tic = time.time()
+        # NOTE(review): `unicode` only exists on Python 2; on Python 3 this relies on the
+        # name being defined elsewhere in the module whenever resFile is not a str -- confirm.
+        if type(resFile) == str or type(resFile) == unicode:
+            # close the file deterministically instead of leaking the handle
+            with open(resFile) as f:
+                anns = json.load(f)
+        elif type(resFile) == np.ndarray:
+            anns = self.loadNumpyAnnotations(resFile)
+        else:
+            anns = resFile
+        assert type(anns) == list, 'results in not an array of objects'
+        annsImgIds = [ann['image_id'] for ann in anns]
+        assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
+               'Results do not correspond to current coco set'
+        # the kind of result is decided from the first annotation only
+        if 'caption' in anns[0]:
+            # keep only the images that actually received a caption
+            imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
+            res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
+            for id, ann in enumerate(anns):
+                ann['id'] = id+1
+        elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
+            res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
+            for id, ann in enumerate(anns):
+                # bbox is [x, y, width, height]
+                bb = ann['bbox']
+                x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]]
+                if not 'segmentation' in ann:
+                    # synthesise a rectangular polygon from the box corners
+                    ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
+                ann['area'] = bb[2]*bb[3]
+                ann['id'] = id+1
+                ann['iscrowd'] = 0
+        elif 'segmentation' in anns[0]:
+            res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
+            for id, ann in enumerate(anns):
+                # now only support compressed RLE format as segmentation results
+                # ann['area'] = maskUtils.area(ann['segmentation'])
+                # NOTE(review): this raise is unconditional, so the rest of the loop body
+                # is unreachable while maskUtils stays disabled.
+                raise NotImplementedError("maskUtils disabled!")
+                if not 'bbox' in ann:
+                    # ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
+                    raise NotImplementedError("maskUtils disabled!")
+                ann['id'] = id+1
+                ann['iscrowd'] = 0
+        elif 'keypoints' in anns[0]:
+            res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
+            for id, ann in enumerate(anns):
+                # keypoints are repeating (x, y, v) triplets; the box is their extent
+                s = ann['keypoints']
+                x = s[0::3]
+                y = s[1::3]
+                x0,x1,y0,y1 = np.min(x), np.max(x), np.min(y), np.max(y)
+                ann['area'] = (x1-x0)*(y1-y0)
+                ann['id'] = id + 1
+                ann['bbox'] = [x0,y0,x1-x0,y1-y0]
+        print('DONE (t={:0.2f}s)'.format(time.time()- tic))
+
+        res.dataset['annotations'] = anns
+        res.createIndex()
+        return res
+
+    def download(self, tarDir = None, imgIds = [] ):
+        '''
+        Download COCO images from mscoco.org server.
+        Images whose target file already exists are not fetched again.
+        :param tarDir (str): directory the images are written to (created if missing)
+               imgIds (list): images to be downloaded; empty means every image in self.imgs
+        :return: -1 if tarDir is None, otherwise None
+        '''
+        if tarDir is None:
+            print('Please specify target directory')
+            return -1
+        if len(imgIds) == 0:
+            imgs = self.imgs.values()
+        else:
+            imgs = self.loadImgs(imgIds)
+        N = len(imgs)
+        if not os.path.exists(tarDir):
+            os.makedirs(tarDir)
+        # count from 1 so the progress line reads "1/N" ... "N/N" rather than "0/N" ... "N-1/N"
+        for done, img in enumerate(imgs, 1):
+            tic = time.time()
+            fname = os.path.join(tarDir, img['file_name'])
+            if not os.path.exists(fname):
+                urlretrieve(img['coco_url'], fname)
+            print('downloaded {}/{} images (t={:0.1f}s)'.format(done, N, time.time()- tic))
+
+    def loadNumpyAnnotations(self, data):
+        """
+        Turn an [Nx7] result array into a list of annotation dicts.
+        Each row is read as {imageID,x1,y1,w,h,score,class}; image and class ids are cast
+        to int, the remaining values are kept as the array's own scalars.
+        :param  data (numpy.ndarray)
+        :return: annotations (python nested list)
+        """
+        print('Converting ndarray to lists...')
+        assert(type(data) == np.ndarray)
+        print(data.shape)
+        assert(data.shape[1] == 7)
+        total = data.shape[0]
+        records = []
+        for row_idx in range(total):
+            # progress line every million rows (and for the very first row)
+            if row_idx % 1000000 == 0:
+                print('{}/{}'.format(row_idx,total))
+            row = data[row_idx]
+            records.append({
+                'image_id'  : int(row[0]),
+                'bbox'  : [ row[1], row[2], row[3], row[4] ],
+                'score' : row[5],
+                'category_id': int(row[6]),
+                })
+        return records
+
+    def annToRLE(self, ann):
+        """
+        Return the annotation's segmentation in RLE form.
+        Only a segmentation that is a dict whose 'counts' is not a list is supported and is
+        returned unchanged; polygons and uncompressed RLE would need maskUtils, which is
+        disabled in this copy.
+        :param ann (dict) : annotation with 'image_id' and 'segmentation'
+        :return: rle      : the annotation's 'segmentation' entry
+        :raises NotImplementedError: for polygon or uncompressed-RLE segmentations
+        """
+        # image size is only needed by the disabled maskUtils calls; the lookup is kept
+        # so an unknown image id still fails here
+        img_info = self.imgs[ann['image_id']]
+        h, w = img_info['height'], img_info['width']
+        segm = ann['segmentation']
+        # polygon (list of parts) or uncompressed RLE (list-valued counts):
+        #   rles = maskUtils.frPyObjects(segm, h, w); rle = maskUtils.merge(rles)
+        if type(segm) == list or type(segm['counts']) == list:
+            raise NotImplementedError("maskUtils disabled!")
+        return ann['segmentation']
+
+    def annToMask(self, ann):
+        """
+        Convert an annotation (polygons, uncompressed RLE, or RLE) to a binary mask.
+        Decoding needs maskUtils, which is disabled in this copy, so this always raises
+        once the annotation has been passed through annToRLE.
+        :raises NotImplementedError: always
+        """
+        # run the RLE conversion first so its own errors surface exactly as before
+        self.annToRLE(ann)
+        # m = maskUtils.decode(rle)
+        raise NotImplementedError("maskUtils disabled!")
diff --git a/example/object-detection/dataset/transform.py b/example/object-detection/dataset/transform.py
new file mode 100644
index 0000000000..5e7ae29a07
--- /dev/null
+++ b/example/object-detection/dataset/transform.py
@@ -0,0 +1,290 @@
+"""Transform functions for data."""
+from mxnet import ndarray as nd
+from mxnet import image
+import numpy as np
+import random
+
+
+class Compose(object):
+    """Chain several (src, label) transforms into a single callable.
+
+    Parameters
+    -----------
+    transforms : iterable of callable
+        Applied in order; each one is called as ``src, label = t(src, label)``.
+
+    """
+    def __init__(self, transforms):
+        self._transforms = transforms
+
+    def __call__(self, src, label):
+        data, target = src, label
+        for step in self._transforms:
+            # every step must hand back exactly a (data, target) pair
+            data, target = step(data, target)
+        return data, target
+
+class Lambda(object):
+    """Applies a lambda function as a transform.
+
+    Parameters
+    ----------
+    func : callable
+        A callable function that will do::
+
+            src, label = func(src, label)
+
+    Returns
+    -------
+    src : NDArray
+        Image
+    label : numpy.ndarray
+        Label
+    """
+    def __init__(self, func):
+        # validate the argument that was actually passed in (the previous check
+        # referenced an undefined name and raised NameError on every construction)
+        assert callable(func), "Lambda function must be callable"
+        self._lambda = func
+
+    def __call__(self, src, label):
+        return self._lambda(src, label)
+
+
+class Cast(object):
+    """Cast both the image and the label to another type.
+
+    Parameters
+    ----------
+    typ : type, default numpy.float32
+        Target type handed to ``astype`` of the image and of the label.
+
+    """
+    def __init__(self, typ=np.float32):
+        self._type = typ
+
+    def __call__(self, src, label):
+        target_type = self._type
+        return src.astype(target_type), label.astype(target_type)
+
+
+class ToAbsoluteCoords(object):
+    """Scale relative box coordinates up to pixel values.
+
+    Label columns 1 and 3 are multiplied by the image width and columns 2 and 4 by the
+    image height. The label array is modified in place and also returned.
+
+    """
+    def __call__(self, src, label):
+        height, width, _channels = src.shape
+        x_columns, y_columns = [1, 3], [2, 4]
+        label[:, x_columns] *= width
+        label[:, y_columns] *= height
+        return src, label
+
+class ToPercentCoords(object):
+    """Scale pixel box coordinates down to fractions of the image size.
+
+    Label columns 1 and 3 are divided by the image width and columns 2 and 4 by the
+    image height. The label array is modified in place and also returned.
+
+    """
+    def __call__(self, src, label):
+        height, width, _channels = src.shape
+        x_columns, y_columns = [1, 3], [2, 4]
+        label[:, x_columns] /= width
+        label[:, y_columns] /= height
+        return src, label
+
+
+class ForceResize(object):
+    """Resize the image to a fixed data_shape so samples can be batched.
+
+    Only the image is touched; the label is passed through unchanged, which is why box
+    coordinates must already be relative (percent) before this augmentation runs.
+
+    Parameters
+    ----------
+    size : tuple
+        A tuple of (width, height) to be resized to.
+
+    """
+    def __init__(self, size):
+        self._size = size
+
+    def __call__(self, src, label):
+        resized = image.imresize(src, *self._size, interp=1)
+        return resized, label
+
+
+def intersect(box_a, box_b):
+    """Intersection area of every box in box_a with the single box box_b.
+
+    Boxes are corner-encoded as (xmin, ymin, xmax, ymax); box_a has shape
+    [num_boxes, 4], box_b has shape [4]. Returns an array of shape [num_boxes],
+    with 0 where the boxes do not overlap.
+    """
+    upper_left = np.maximum(box_a[:, :2], box_b[:2])
+    lower_right = np.minimum(box_a[:, 2:], box_b[2:])
+    # negative extents mean "no overlap" and are clamped to zero
+    extent = np.clip((lower_right - upper_left), a_min=0, a_max=np.inf)
+    return extent[:, 0] * extent[:, 1]
+
+def jaccard_numpy(box_a, box_b):
+    """Compute the jaccard overlap (intersection over union) between each box
+    in box_a and the single box box_b.
+    Args:
+        box_a: Multiple bounding boxes, Shape: [num_boxes,4]
+        box_b: Single bounding box, Shape: [4]
+    Return:
+        jaccard overlap: Shape: [num_boxes]
+    """
+    overlap_area = intersect(box_a, box_b)
+    widths_a = box_a[:, 2]-box_a[:, 0]
+    heights_a = box_a[:, 3]-box_a[:, 1]
+    area_a = widths_a * heights_a  # one entry per box in box_a
+    area_b = (box_b[2]-box_b[0]) * (box_b[3]-box_b[1])  # scalar
+    union_area = area_a + area_b - overlap_area
+    return overlap_area / union_area
+
+
+class RandomSampleCrop(object):
+    """Randomly crop images and modify labels according to constraints.
+
+    On each call a crop mode is drawn at random: keep the image as is, or sample a
+    patch subject to a (min_iou, max_iou) pair. Ground-truth boxes whose centers fall
+    outside the patch are dropped, the rest are clipped and shifted into patch coordinates.
+
+    Parameters
+    ----------
+    max_attempts : int, default 50
+        Number of patches tried for one mode before another mode is drawn.
+
+    """
+    def __init__(self, max_attempts=50):
+        # each entry is either None (no crop) or a (min_iou, max_iou) pair where None
+        # means "unbounded" on that side
+        self._options = (
+            # using entire original image
+            None,
+            # sample a patch s.t. minimum iou
+            (0.1, None),
+            (0.3, None),
+            (0.5, None),
+            (0.7, None),
+            (0.9, None),
+            # randomly sample a patch
+            (None, None),
+        )
+        self._max_attempts = max_attempts
+
+
+    def __call__(self, src, label):
+        # assumes label columns 1..4 are (xmin, ymin, xmax, ymax) in the same pixel units
+        # as the crop rectangle -- TODO confirm against the pipeline feeding this transform
+        height, width, _ = src.shape
+        while True:
+            # randomly choose a crop mode
+            mode = random.choice(self._options)
+            if mode is None:
+                # return the original intact
+                return src, label
+
+            min_iou, max_iou = mode
+            if min_iou is None:
+                min_iou = float('-inf')
+            if max_iou is None:
+                max_iou = float('inf')
+
+            # max trails
+            for _ in range(self._max_attempts):
+                current_image = src
+                # patch side lengths between 30% and 100% of the image
+                w = random.uniform(0.3 * width , width)
+                h = random.uniform(0.3 * height, height)
+
+                # aspect ratio constraint 0.5
+                if h / w < 0.5 or h / w > 2:
+                    continue
+
+                left = random.uniform(0, width - w)
+                top = random.uniform(0, height - h)
+
+                # convert to integer rect
+                rect = np.array([int(left), int(top), int(left+w), int(top+h)])
+
+                # calculate iou
+                overlap = jaccard_numpy(label[:, 1:5], rect)
+
+                # check min and max iou constraint? if not try again
+                # NOTE(review): max_iou is always +inf with the options above, so the second
+                # operand is never true and this `and` never rejects a patch -- the min_iou
+                # values currently have no effect. Looks like `or` (or a test on
+                # overlap.max() < min_iou) was intended; confirm before changing.
+                if overlap.min() < min_iou and max_iou < overlap.max():
+                    continue
+
+                # crop
+                current_image = current_image[rect[1]:rect[3], rect[0]:rect[2], :]
+
+                # keep overlap with gt box is center in sampled patch
+                centers = (label[:, 1:3] + label[:, 3:5]) / 2.0
+
+                # mask in all gt boxes that above and to the left of centers
+                m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
+                # mask in all gt boxes that under ant to the right of centers
+                m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])
+                mask = m1 * m2
+
+                # check valid masks
+                if not mask.any():
+                    continue
+
+                # take only matching gt
+                current_label = label[mask, :].copy()
+
+                # should we use the box left and top corner of the crop's
+                current_label[:, 1:3] = np.maximum(current_label[:, 1:3], rect[:2])
+                # adjust to crop
+                current_label[:, 1:3] -= rect[:2]
+                current_label[:, 3:5] = np.minimum(current_label[:, 3:5], rect[2:])
+                current_label[:, 3:5] -= rect[:2]
+
+                return current_image, current_label
+        # unreachable: the `while True` loop above only exits through a return
+        return src, label
+
+class Expand(object):
+    """Randomly pad image.
+
+    With probability 1/2 the sample is returned untouched. Otherwise the image is
+    pasted at a random position onto a canvas 1x to 4x larger that is filled with
+    ``mean_pixel``, and the box corners in label columns 1..4 are shifted by the
+    paste offset (the label is copied, not modified in place).
+
+    Parameters
+    ----------
+    mean_pixel : list of numbers
+        Per-channel fill value for the padded area.
+
+    """
+    def __init__(self, mean_pixel):
+        # honour the caller's fill colour (it used to be ignored in favour of a
+        # hard-coded [127, 127, 127])
+        self._mean = mean_pixel
+
+    def __call__(self, src, label):
+        if random.randint(0, 1):
+            return src, label
+
+        height, width, _ = src.shape
+        ratio = random.uniform(1, 4)
+        left = int(random.uniform(0, width * ratio - width))
+        top = int(random.uniform(0, height * ratio - height))
+        new_width = int(width * ratio)
+        new_height = int(height * ratio)
+
+        # default using mean pixels
+        # tile the (1, 1, C) mean over the canvas so every pixel gets the full channel
+        # vector; a flat repeat + reshape only lays channels out correctly when all
+        # channel values are identical
+        mean_pixel = nd.array(self._mean).reshape((1, 1, -1))
+        expand_image = nd.tile(mean_pixel, reps=(new_height, new_width, 1))
+        expand_image[top:top + height, left:left + width, :] = src
+        new_label = label.copy()
+        new_label[:, 1:3] += (left, top)
+        new_label[:, 3:5] += (left, top)
+
+        return expand_image, new_label
+
+class Transpose(object):
+    """Reorder the image axes (default (2, 0, 1)); the label is passed through."""
+    def __init__(self, order=(2, 0, 1)):
+        self._order = order
+
+    def __call__(self, src, label):
+        transposed = nd.transpose(src, axes=self._order)
+        return transposed, label
+
+class SSDAugmentation(object):
+    """Training-time augmentation pipeline for SSD.
+
+    Chains, in order: cast, conversion to pixel coordinates, random expansion, random
+    crop, conversion back to relative coordinates, horizontal flip (p=0.5), resize to
+    ``data_shape``, colour normalisation with ``mean_pixel``/``std_pixel`` and transpose.
+    """
+    def __init__(self, data_shape, mean_pixel=[123, 117, 104], std_pixel=[58, 57, 58]):
+        steps = [
+            Cast(),
+            ToAbsoluteCoords(),
+            Expand(mean_pixel),
+            RandomSampleCrop(),
+            ToPercentCoords(),
+            image.det.DetHorizontalFlipAug(0.5),
+            ForceResize(data_shape),
+            image.det.DetBorrowAug(image.ColorNormalizeAug(mean_pixel, std_pixel)),
+            Transpose(),
+        ]
+        self._augments = Compose(steps)
+
+    def __call__(self, src, label):
+        augmented = self._augments(src, label)
+        return augmented
+
+class SSDValid(object):
+    """Validation-time pipeline for SSD: cast, resize to ``data_shape``, colour
+    normalisation with ``mean_pixel``/``std_pixel`` and transpose (no random steps).
+    """
+    def __init__(self, data_shape, mean_pixel=[123, 117, 104], std_pixel=[58, 57, 58]):
+        steps = [
+            Cast(),
+            ForceResize(data_shape),
+            image.det.DetBorrowAug(image.ColorNormalizeAug(mean_pixel, std_pixel)),
+            Transpose(),
+        ]
+        self._augments = Compose(steps)
+
+    def __call__(self, src, label):
+        transformed = self._augments(src, label)
+        return transformed
+
+class SSDAugmentation2(object):
+    """Alternative training pipeline built from ``image.det.CreateDetAugmenter``,
+    followed by a transpose and a cast.
+
+    ``data_shape`` must be a list, since it is concatenated as ``[3] + data_shape``.
+    """
+    def __init__(self, data_shape):
+        full_shape = [3] + data_shape
+        steps = image.det.CreateDetAugmenter(full_shape, rand_crop=0.8, rand_pad=0.8,
+            rand_mirror=True, mean=True, std=True)
+        steps.append(Transpose())
+        steps.append(Cast())
+        self._augments = Compose(steps)
+
+    def __call__(self, src, label):
+        augmented = self._augments(src, label)
+        return augmented
diff --git a/example/object-detection/dataset/utils.py b/example/object-detection/dataset/utils.py
new file mode 100644
index 0000000000..5583cd9917
--- /dev/null
+++ b/example/object-detection/dataset/utils.py
@@ -0,0 +1,20 @@
+"""Utility functions."""
+import os
+import sys
+
+def mkdirs_p(path):
+    """Make directory recursively if not exists.
+
+    Parameters
+    ----------
+    path : str
+        The destination directory to be created.
+
+    Raises
+    ------
+    AssertionError
+        If ``path`` is not a directory once creation has been attempted.
+    """
+    if sys.version_info[0:2] >= (3, 2):
+        os.makedirs(path, exist_ok=True)
+    else:
+        try:
+            os.makedirs(path)
+        except OSError:
+            # typically "already exists"; only OS-level failures are tolerated here
+            # (not KeyboardInterrupt and friends), and the check below still reports
+            # a directory that could not be created
+            pass
+    assert os.path.isdir(path), "Unable to create directory: {}".format(path)
diff --git a/example/object-detection/dataset/voc.py b/example/object-detection/dataset/voc.py
new file mode 100644
index 0000000000..71a509be18
--- /dev/null
+++ b/example/object-detection/dataset/voc.py
@@ -0,0 +1,360 @@
+"""Pascal VOC dataset."""
+import os
+import logging
+import numpy as np
+try:
+    import xml.etree.cElementTree as ET
+except ImportError:
+    import xml.etree.ElementTree as ET
+try:
+    import cPickle as pickle
+except ImportError:
+    import pickle
+from mxnet import image
+from dataset.base import DetectionDataset
+from dataset.utils import mkdirs_p
+
+
+class VOCDetection(DetectionDataset):
+    """Pascal VOC detection Dataset.
+
+    Parameters
+    ----------
+    root : string
+        Path to VOCdevkit folder.
+    sets : list of tuples
+        List of combinations of (year, name), e.g. [(2007, 'trainval'), (2012, 'train')].
+        For years, candidates can be: 2007, 2012.
+        For names, candidates can be: 'train', 'val', 'trainval', 'test'.
+    flag : {0, 1}, default 1
+        If 0, always convert images to greyscale.
+
+        If 1, always convert images to colored (RGB).
+    transform : callable, optional
+        A function that takes data and label and transforms them::
+
+            transform = lambda data, label: (data.astype(np.float32)/255, label)
+        A transform function for object detection should take label into consideration,
+        because any geometric modification will require label to be modified.
+    index_map : dict, optional
+        If provided as dict, class indecies are mapped by looking up in the dict.
+        Otherwise will use alphabetic indexing for all classes from 0 to 19.
+    preload : bool
+        All labels will be parsed and loaded into memory at initialization.
+        This will allow early check for errors, and will be faster.
+    """
+    def __init__(self, root, sets, flag=1, transform=None, index_map=None, preload=True):
+        # see the class docstring for the meaning of the parameters
+        super(VOCDetection, self).__init__('voc')
+        # idx -> (width, height), filled as a side effect of _load_label
+        self._im_shapes = {}
+        self._root = os.path.expanduser(root)
+        self._flag = flag
+        self._transform = transform
+        self._sets = sets
+        # list of (year root directory, image id) pairs
+        self._items = self._load_items(sets)
+        # both templates are formatted with one entry of self._items
+        self._anno_path = os.path.join('{}', 'Annotations', '{}.xml')
+        self._image_path = os.path.join('{}', 'JPEGImages', '{}.jpg')
+        # default mapping: position of the class name in self.classes
+        self.index_map = index_map or dict(zip(self.classes, range(self.num_classes)))
+        # must come after the path templates and index_map, which _load_label reads
+        self._label_cache = self._preload_labels() if preload else None
+        # prefix used in the names of the result files
+        self._comp = 'comp4'
+        self._cache_dir = os.path.join(os.path.dirname(__file__), '..', 'data', 'cache')
+        self._result_dir = os.path.join(self._root, 'result')
+
+    def __str__(self):
+        """Return e.g. ``VOCDetection(2007trainval,2012train)``."""
+        parts = []
+        for entry in self._sets:
+            parts.append(str(entry[0]) + entry[1])
+        return self.__class__.__name__ + '(' + ','.join(parts) + ')'
+
+    def _load_items(self, sets):
+        """Read the image-set list of every (year, name) pair.
+
+        Returns a list of (year root directory, image id) tuples in file order.
+        """
+        items = []
+        for year, name in sets:
+            year_root = os.path.join(self._root, 'VOC' + str(year))
+            list_path = os.path.join(year_root, 'ImageSets', 'Main', name + '.txt')
+            with open(list_path, 'r') as f:
+                for line in f:
+                    items.append((year_root, line.strip()))
+        return items
+
+    def _load_label(self, idx):
+        """Parse xml file and return labels.
+
+        Returns a numpy array with one row per kept object:
+        [cls_id, xmin, ymin, xmax, ymax, difficult], the box corners being divided by
+        the image width/height read from the annotation. Objects whose class name is
+        not in ``self.classes`` are skipped. As a side effect the image (width, height)
+        is remembered in ``self._im_shapes``.
+
+        Raises RuntimeError when a box fails ``_validator``.
+        """
+        img_id = self._items[idx]
+        anno_path = self._anno_path.format(*img_id)
+        root = ET.parse(anno_path).getroot()
+        size = root.find('size')
+        width = float(size.find('width').text)
+        height = float(size.find('height').text)
+        if idx not in self._im_shapes:
+            # store the shapes for later usage
+            self._im_shapes[idx] = (width, height)
+        label = []
+        for obj in root.iter('object'):
+            difficult = int(obj.find('difficult').text)
+            cls_name = obj.find('name').text.strip().lower()
+            if cls_name not in self.classes:
+                continue
+            cls_id = self.index_map[cls_name]
+            xml_box = obj.find('bndbox')
+            # shift by one before normalising (presumably because the stored pixel
+            # coordinates are 1-based -- confirm)
+            xmin = (float(xml_box.find('xmin').text) - 1) / width
+            ymin = (float(xml_box.find('ymin').text) - 1) / height
+            xmax = (float(xml_box.find('xmax').text) - 1) / width
+            ymax = (float(xml_box.find('ymax').text) - 1) / height
+            try:
+                self._validator(xmin, ymin, xmax, ymax)
+            except AssertionError as e:
+                raise RuntimeError("Invalid label at {}, {}".format(anno_path, e))
+            label.append([cls_id, xmin, ymin, xmax, ymax, difficult])
+        # NOTE(review): with no kept object this is an empty 1-D array, not (0, 6)
+        return np.array(label)
+
+    def _validator(self, xmin, ymin, xmax, ymax):
+        """Validate a box given in relative coordinates.
+
+        Requires 0 <= xmin < 1, 0 <= ymin < 1, xmin < xmax <= 1 and ymin < ymax <= 1.
+        Raises AssertionError (which callers translate) on the first violated bound.
+        """
+        assert xmin >= 0 and xmin < 1.0, "xmin must in [0, 1), given {}".format(xmin)
+        assert ymin >= 0 and ymin < 1.0, "ymin must in [0, 1), given {}".format(ymin)
+        # the upper bound has to be tested on xmax itself (it used to re-test ymin)
+        assert xmax > xmin and xmax <= 1.0, "xmax must in (xmin, 1], given {}".format(xmax)
+        assert ymax > ymin and ymax <= 1.0, "ymax must in (ymin, 1], given {}".format(ymax)
+
+    def _preload_labels(self):
+        """Parse every annotation up front and return the labels as a list indexed like the dataset."""
+        logging.debug("Preloading {} labels into memory...".format(str(self)))
+        labels = []
+        for idx in range(len(self)):
+            labels.append(self._load_label(idx))
+        return labels
+
+    def __len__(self):
+        """Number of images, i.e. of entries collected from all requested image sets."""
+        return len(self._items)
+
+    def __getitem__(self, idx):
+        """Return the (image, label) pair at ``idx``, passed through the transform if one is set."""
+        img_path = self._image_path.format(*self._items[idx])
+        # use the preloaded labels when available, otherwise parse on demand
+        if self._label_cache:
+            label = self._label_cache[idx]
+        else:
+            label = self._load_label(idx)
+        img = image.imread(img_path, self._flag)
+        if self._transform is None:
+            return img, label
+        return self._transform(img, label)
+
+    def eval_results(self, results):
+        """Evaluate results.
+
+        Writes the detections to per-class result files and scores them with
+        ``do_python_eval``.
+
+        Parameters
+        ----------
+        results : numpy.ndarray
+            One entry per image, in dataset order. Each detection row is read by
+            ``_write_results`` as [class id, score, xmin, ymin, xmax, ymax], with the
+            corners scaled there by the image width/height.
+
+        Returns
+        -------
+        tuple
+            ('Mean AP', value) as returned by ``do_python_eval``.
+        """
+        assert len(self._sets) == 1, "concatenated sets are not supposed to be evaluated."
+        assert isinstance(results, np.ndarray), (
+            "np.ndarray expected, given {}".format(type(results)))
+        assert len(self._items) == results.shape[0], (
+            "# image mismatch: {} vs. {}".format(len(self._items), results.shape[0]))
+        self._write_results(results)
+        return self.do_python_eval()
+
+    def _get_filename_template(self):
+        """Create the result directory for the first set if needed and return the
+        per-class result file path, with a ``{:s}`` placeholder for the class name."""
+        year_dir = 'VOC' + str(self._sets[0][0])
+        dir_name = os.path.join(self._result_dir, year_dir, self._comp)
+        mkdirs_p(dir_name)
+        file_name = self._comp + '_det_' + self._sets[0][1] + '_{:s}.txt'
+        return os.path.join(dir_name, file_name)
+
+    def _write_results(self, results):
+        """Write results to disk in compliance with PASCAL formats.
+
+        One file is written per class; each line is
+        ``<image id> <score> <xmin> <ymin> <xmax> <ymax>``. Detection rows are read as
+        [class id, score, xmin, ymin, xmax, ymax] with corners multiplied by the image
+        width/height here.
+        """
+        for cls_name in self.classes:
+            logging.info('Writing {} VOC results file'.format(cls_name))
+            filename = self._get_filename_template().format(cls_name)
+            buf = []
+            for im_ind, index in enumerate(self._items):
+                dets = results[im_ind]
+                if dets.shape[0] < 1:
+                    continue
+                if im_ind not in self._im_shapes:
+                    # parsing the annotation records the image shape as a side effect
+                    self._load_label(im_ind)
+                w, h = self._im_shapes[im_ind]
+                # the VOCdevkit expects 1-based indices
+                for k in range(dets.shape[0]):
+                    # keep only the detections of the class this file is for
+                    if (int(dets[k, 0]) == self.index_map[cls_name]):
+                        buf.append('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
+                            format(index[1], dets[k, 1],
+                                   int(dets[k, 2] * w) + 1, int(dets[k, 3] * h) + 1,
+                                   int(dets[k, 4] * w) + 1, int(dets[k, 5] * h) + 1))
+            whole = ''.join(buf)
+            with open(filename, 'wt') as f:
+                f.write(whole)
+
+    def do_python_eval(self):
+        """Score the previously written result files class by class.
+
+        The 11-point VOC07 metric is used when the year of the first set is before 2010.
+        Returns the tuple ('Mean AP', mean of the per-class average precisions).
+        """
+        year = self._sets[0][0]
+        data_path = os.path.join(self._root, 'VOC' + str(year))
+        annopath = os.path.join(data_path, 'Annotations', '{}.xml')
+        imageset_file = os.path.join(data_path, 'ImageSets', 'Main', self._sets[0][1] + '.txt')
+        use_07_metric = int(year) < 2010
+        logging.info("Use VOC07 metric? " + ('Yes' if use_07_metric else 'No'))
+        aps = []
+        for cls_name in self.classes:
+            filename = self._get_filename_template().format(cls_name)
+            _rec, _prec, ap = self._voc_eval(
+                filename, annopath, imageset_file, cls_name, self._cache_dir,
+                ovthresh=0.5, use_07_metric=use_07_metric)
+            aps.append(ap)
+            logging.info("AP for {} = {:.4f}".format(cls_name, ap))
+        mean_ap = np.mean(aps)
+        logging.info("Mean AP = {:.4f}".format(mean_ap))
+        return 'Mean AP', mean_ap
+
+    def _parse_voc_rec(self, filename):
+        """
+        read every <object> of a pascal voc annotation file
+        :param filename: xml file path
+        :return: list of dict with keys 'name', 'difficult' and 'bbox'
+                 (bbox being [xmin, ymin, xmax, ymax] as integers)
+        """
+        corner_tags = ('xmin', 'ymin', 'xmax', 'ymax')
+        records = []
+        for obj in ET.parse(filename).findall('object'):
+            name = obj.find('name').text
+            difficult = int(obj.find('difficult').text)
+            box_node = obj.find('bndbox')
+            corners = [int(box_node.find(tag).text) for tag in corner_tags]
+            records.append({'name': name, 'difficult': difficult, 'bbox': corners})
+        return records
+
+
+    def _voc_ap(self, rec, prec, use_07_metric=False):
+        """
+        average precision from a precision/recall curve
+        :param rec: recall values (numpy array)
+        :param prec: precision values (numpy array, same length as rec)
+        :param use_07_metric: if True, average the best precision at the 11 recall
+                              thresholds 0.0, 0.1, ..., 1.0; otherwise integrate the
+                              monotonically decreasing precision envelope over recall
+        :return: average precision
+        """
+        if use_07_metric:
+            ap = 0.
+            for t in np.arange(0., 1.1, 0.1):
+                reached = rec >= t
+                # a threshold that is never reached contributes zero precision
+                p = np.max(prec[reached]) if np.sum(reached) != 0 else 0
+                ap += p / 11.
+            return ap
+
+        # append sentinel values at both ends
+        mrec = np.concatenate(([0.], rec, [1.]))
+        mpre = np.concatenate(([0.], prec, [0.]))
+
+        # walk backwards so each precision becomes the max of everything to its right
+        for i in reversed(range(1, mpre.size)):
+            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
+
+        # positions where the recall value changes
+        steps = np.where(mrec[1:] != mrec[:-1])[0]
+
+        # sum (\delta recall) * prec
+        return np.sum((mrec[steps + 1] - mrec[steps]) * mpre[steps + 1])
+
+
+    def _voc_eval(self, detpath, annopath, imageset_file, classname, cache_dir, ovthresh=0.5, use_07_metric=False):
+        """
+        pascal voc evaluation
+        :param detpath: detection results file, used as detpath.format(classname)
+        :param annopath: annotation path template, used as annopath.format(image id)
+        :param imageset_file: text file containing list of images
+        :param classname: category name
+        :param cache_dir: directory holding the pickled annotation cache
+        :param ovthresh: overlap threshold
+        :param use_07_metric: whether to use voc07's 11 point ap computation
+        :return: rec, prec, ap (three 0.0 values when the detection file is empty)
+        """
+        # create the cache directory together with any missing parents
+        mkdirs_p(cache_dir)
+        cache_file = os.path.join(cache_dir, 'annotations.pkl')
+        with open(imageset_file, 'r') as f:
+            lines = f.readlines()
+        image_filenames = [x.strip() for x in lines]
+
+        # load annotations from cache
+        # NOTE(review): the cache file name does not depend on the year or image set, so
+        # a cache written for another set will be reused here -- confirm this is intended.
+        if not os.path.isfile(cache_file):
+            recs = {}
+            for ind, image_filename in enumerate(image_filenames):
+                recs[image_filename] = self._parse_voc_rec(annopath.format(image_filename))
+                if ind % 1000 == 0:
+                    logging.debug('reading annotations for {:d}/{:d}'.format(ind + 1, len(image_filenames)))
+            logging.debug('saving annotations cache to {:s}'.format(cache_file))
+            with open(cache_file, 'wb') as f:
+                pickle.dump(recs, f)
+        else:
+            with open(cache_file, 'rb') as f:
+                recs = pickle.load(f)
+
+        # extract objects in :param classname:
+        class_recs = {}
+        npos = 0
+        for image_filename in image_filenames:
+            objects = [obj for obj in recs[image_filename] if obj['name'] == classname]
+            bbox = np.array([x['bbox'] for x in objects])
+            # builtin bool: the np.bool alias no longer exists in recent NumPy releases
+            difficult = np.array([x['difficult'] for x in objects]).astype(bool)
+            det = [False] * len(objects)  # stand for detected
+            # difficult objects do not count as positives
+            npos = npos + sum(~difficult)
+            class_recs[image_filename] = {'bbox': bbox,
+                                          'difficult': difficult,
+                                          'det': det}
+
+        # read detections
+        detfile = detpath.format(classname)
+        with open(detfile, 'r') as f:
+            lines = f.readlines()
+
+        if not lines:
+            return 0.0, 0.0, 0.0
+
+        splitlines = [x.strip().split(' ') for x in lines]
+        image_ids = [x[0] for x in splitlines]
+        confidence = np.array([float(x[1]) for x in splitlines])
+        bbox = np.array([[float(z) for z in x[2:]] for x in splitlines])
+
+        # sort by confidence, highest first
+        sorted_inds = np.argsort(-confidence)
+        bbox = bbox[sorted_inds, :]
+        image_ids = [image_ids[x] for x in sorted_inds]
+
+        # go down detections and mark true positives and false positives
+        nd = len(image_ids)
+        tp = np.zeros(nd)
+        fp = np.zeros(nd)
+        for d in range(nd):
+            r = class_recs[image_ids[d]]
+            bb = bbox[d, :].astype(float)
+            ovmax = -np.inf
+            bbgt = r['bbox'].astype(float)
+
+            if bbgt.size > 0:
+                # compute overlaps
+                # intersection
+                ixmin = np.maximum(bbgt[:, 0], bb[0])
+                iymin = np.maximum(bbgt[:, 1], bb[1])
+                ixmax = np.minimum(bbgt[:, 2], bb[2])
+                iymax = np.minimum(bbgt[:, 3], bb[3])
+                # the +1 treats coordinates as inclusive pixel indices
+                iw = np.maximum(ixmax - ixmin + 1., 0.)
+                ih = np.maximum(iymax - iymin + 1., 0.)
+                inters = iw * ih
+
+                # union
+                uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
+                       (bbgt[:, 2] - bbgt[:, 0] + 1.) *
+                       (bbgt[:, 3] - bbgt[:, 1] + 1.) - inters)
+
+                overlaps = inters / uni
+                ovmax = np.max(overlaps)
+                jmax = np.argmax(overlaps)
+
+            if ovmax > ovthresh:
+                # matches on difficult objects are ignored: neither tp nor fp
+                if not r['difficult'][jmax]:
+                    if not r['det'][jmax]:
+                        tp[d] = 1.
+                        r['det'][jmax] = 1
+                    else:
+                        # the ground truth was already claimed by a better detection
+                        fp[d] = 1.
+            else:
+                fp[d] = 1.
+
+        # compute precision recall
+        fp = np.cumsum(fp)
+        tp = np.cumsum(tp)
+        rec = tp / float(npos)
+        # avoid division by zero in case first detection matches a difficult ground truth
+        prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
+        ap = self._voc_ap(rec, prec, use_07_metric)
+
+        return rec, prec, ap
diff --git a/example/object-detection/demo.py b/example/object-detection/demo.py
new file mode 100644
index 0000000000..55f7dd891f
--- /dev/null
+++ b/example/object-detection/demo.py
@@ -0,0 +1,42 @@
+"""Train Gluon Object-Detection models."""
+import os
+import argparse
+import mxnet as mx
+from predict import predict_ssd
+
+def parse_args():
+    """Parse the command-line options of the detection demo.
+
+    Returns
+    -------
+    argparse.Namespace
+        The parsed options. Every option has a default value.
+    """
+    parser = argparse.ArgumentParser(description='Run a gluon detection network on images',
+                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--algorithm', dest='algorithm', type=str, default='ssd',
+                        help='which network to use')
+    parser.add_argument('--data-shape', dest='data_shape', type=str, default='512',
+                        help='image data shape, can be int or tuple')
+    parser.add_argument('--model', dest='model', type=str, default='resnet50_v1',
+                        help='base network to use, choices are models from gluon model_zoo')
+    parser.add_argument('--dataset', dest='dataset', type=str, default='voc',
+                        help='which dataset to use')
+    parser.add_argument('--images', dest='images', type=str, default='./data/demo/dog.jpg',
+                        help='run demo with images, use comma to separate multiple images')
+    parser.add_argument('--batch-size', dest='batch_size', type=int, default=32,
+                        help='batch size')
+    parser.add_argument('--pretrained', type=int, default=1,
+                        help='Whether use pretrained models. '
+                        ' 0: from scratch, 1: use base model, 2: use pretrained detection model')
+    parser.add_argument('--prefix', dest='prefix', type=str, help='new model prefix',
+                        default=os.path.join(os.path.dirname(__file__), 'model', 'default'))
+    parser.add_argument('--gpus', dest='gpus', help='GPU devices to run on',
+                        default='0', type=str)
+    return parser.parse_args()
+
+if __name__ == '__main__':
+    args = parse_args()
+    # context list: one GPU context per id given in --gpus, CPU when none is given
+    # NOTE(review): ctx is built here but not forwarded, so predict_net runs on
+    # its own default context -- confirm whether it should receive ctx.
+    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
+    ctx = [mx.cpu()] if not ctx else ctx
+    # choose algorithm
+    if args.algorithm.lower() == 'ssd':
+        # predict_net builds the full model name from the data shape and base model
+        predict_ssd.predict_net(args.images, args.model, args.data_shape, num_class=20)
+    else:
+        raise NotImplementedError("Algorithm {} not supported.".format(args.algorithm))
diff --git a/example/object-detection/evaluation/__init__.py b/example/object-detection/evaluation/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/example/object-detection/evaluation/eval_metric.py b/example/object-detection/evaluation/eval_metric.py
new file mode 100644
index 0000000000..bb2b77b3d5
--- /dev/null
+++ b/example/object-detection/evaluation/eval_metric.py
@@ -0,0 +1,295 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import numpy as np
+
+class MApMetric(mx.metric.EvalMetric):
+    """
+    Calculate mean AP for object detection task
+
+    update() only buffers per-class detection records and ground-truth counts;
+    the per-class AP values and their mean are computed when get() is called.
+
+    Parameters:
+    ---------
+    ovp_thresh : float
+        overlap threshold for TP
+    use_difficult : boolean
+        use difficult ground-truths if applicable, otherwise just ignore
+    class_names : list or tuple of str
+        optional, if provided, get() reports one AP per class followed by mAP
+    pred_idx : int
+        prediction index in network output list
+    """
+    def __init__(self, ovp_thresh=0.5, use_difficult=False, class_names=None, pred_idx=0):
+        super(MApMetric, self).__init__('mAP')
+        if class_names is None:
+            self.num = None
+        else:
+            assert isinstance(class_names, (list, tuple))
+            for name in class_names:
+                assert isinstance(name, str), "must provide names as str"
+            num = len(class_names)
+            # list() so that a tuple of names (accepted above) can be extended
+            self.name = list(class_names) + ['mAP']
+            self.num = num + 1
+        self.reset()
+        self.ovp_thresh = ovp_thresh
+        self.use_difficult = use_difficult
+        self.class_names = class_names
+        self.pred_idx = int(pred_idx)
+
+    def reset(self):
+        """Clear the internal statistics to initial state."""
+        # getattr: tolerates being called before self.num is assigned
+        # (presumably by the base-class constructor -- TODO confirm)
+        if getattr(self, 'num', None) is None:
+            self.num_inst = 0
+            self.sum_metric = 0.0
+        else:
+            self.num_inst = [0] * self.num
+            self.sum_metric = [0.0] * self.num
+        self.records = dict()
+        self.counts = dict()
+
+    def get(self):
+        """Get the current evaluation result.
+
+        Returns
+        -------
+        name : str or list of str
+           Name of the metric; a list (class names, then 'mAP') when
+           class names were provided.
+        value : float or list of float
+           Value of the evaluation, nan for entries without any record.
+        """
+        self._update()  # update metric at this time
+        if self.num is None:
+            if self.num_inst == 0:
+                return (self.name, float('nan'))
+            else:
+                return (self.name, self.sum_metric / self.num_inst)
+        else:
+            names = ['%s'%(self.name[i]) for i in range(self.num)]
+            values = [x / y if y != 0 else float('nan') \
+                for x, y in zip(self.sum_metric, self.num_inst)]
+            return (names, values)
+
+    def update(self, labels, preds):
+        """
+        Update internal records. This function now only update internal buffer,
+        sum_metric and num_inst are updated in _update() function instead when
+        get() is called to return results.
+
+        Params:
+        ----------
+        labels: list of mx.nd.array
+            labels[0] is indexed per image; each image holds a 2-d array
+            (n * 6) or (n * 5) of ground-truths, difficult column is optional,
+            n objects(id-xmin-ymin-xmax-ymax-[difficult]), negative id = padding
+        preds: list of mx.nd.array
+            preds[pred_idx] is indexed per image; each image holds a 2-d array
+            (m * 6) of detections, m objects(id-score-xmin-ymin-xmax-ymax)
+        """
+        def iou(x, ys):
+            """
+            Calculate intersection-over-union overlap
+            Params:
+            ----------
+            x : numpy.array
+                single box [xmin, ymin ,xmax, ymax]
+            ys : numpy.array
+                multiple box [[xmin, ymin, xmax, ymax], [...], ]
+            Returns:
+            -----------
+            numpy.array
+                [iou1, iou2, ...], size == ys.shape[0]
+            """
+            ixmin = np.maximum(ys[:, 0], x[0])
+            iymin = np.maximum(ys[:, 1], x[1])
+            ixmax = np.minimum(ys[:, 2], x[2])
+            iymax = np.minimum(ys[:, 3], x[3])
+            iw = np.maximum(ixmax - ixmin, 0.)
+            ih = np.maximum(iymax - iymin, 0.)
+            inters = iw * ih
+            uni = (x[2] - x[0]) * (x[3] - x[1]) + (ys[:, 2] - ys[:, 0]) * \
+                (ys[:, 3] - ys[:, 1]) - inters
+            ious = inters / uni
+            ious[uni < 1e-12] = 0  # in case bad boxes
+            return ious
+
+        # independent execution for each image
+        for i in range(labels[0].shape[0]):
+            # get as numpy arrays
+            label = labels[0][i].asnumpy()
+            if np.sum(label[:, 0] >= 0) < 1:
+                # no valid ground-truth in this image
+                continue
+            pred = preds[self.pred_idx][i].asnumpy()
+            # calculate for each class, consuming the rows of one class per pass
+            while (pred.shape[0] > 0):
+                cid = int(pred[0, 0])
+                indices = np.where(pred[:, 0].astype(int) == cid)[0]
+                if cid < 0:
+                    pred = np.delete(pred, indices, axis=0)
+                    continue
+                dets = pred[indices]
+                pred = np.delete(pred, indices, axis=0)
+                # sort by score, descending; the sorted copy must be assigned
+                # back so that higher-scored detections claim ground-truths first
+                dets = dets[dets[:, 1].argsort()[::-1]]
+                records = np.hstack((dets[:, 1][:, np.newaxis], np.zeros((dets.shape[0], 1))))
+                # ground-truths
+                label_indices = np.where(label[:, 0].astype(int) == cid)[0]
+                gts = label[label_indices, :]
+                label = np.delete(label, label_indices, axis=0)
+                if gts.size > 0:
+                    found = [False] * gts.shape[0]
+                    for j in range(dets.shape[0]):
+                        # compute overlaps
+                        ious = iou(dets[j, 2:], gts[:, 1:5])
+                        ovargmax = np.argmax(ious)
+                        ovmax = ious[ovargmax]
+                        if ovmax > self.ovp_thresh:
+                            if (not self.use_difficult and
+                                gts.shape[1] >= 6 and
+                                gts[ovargmax, 5] > 0):
+                                # matched a difficult ground-truth: leave unset
+                                pass
+                            else:
+                                if not found[ovargmax]:
+                                    records[j, -1] = 1  # tp
+                                    found[ovargmax] = True
+                                else:
+                                    # duplicate
+                                    records[j, -1] = 2  # fp
+                        else:
+                            records[j, -1] = 2 # fp
+                else:
+                    # no gt, mark all fp
+                    records[:, -1] = 2
+
+                # ground truth count
+                gt_count = self._gt_count(gts)
+
+                # now we push records to buffer
+                # first column: score, second column: tp/fp
+                # 0: not set(matched to difficult or something), 1: tp, 2: fp
+                records = records[np.where(records[:, -1] > 0)[0], :]
+                if records.size > 0:
+                    self._insert(cid, records, gt_count)
+
+            # add missing class if not present in prediction
+            while (label.shape[0] > 0):
+                cid = int(label[0, 0])
+                label_indices = np.where(label[:, 0].astype(int) == cid)[0]
+                gts = label[label_indices, :]
+                label = np.delete(label, label_indices, axis=0)
+                if cid < 0:
+                    continue
+                # same counting rule as for predicted classes, so difficult
+                # objects are not counted as missed when they are ignored
+                gt_count = self._gt_count(gts)
+                # the [0, 0] record is a placeholder that _recall_prec drops
+                self._insert(cid, np.array([[0, 0]]), gt_count)
+
+    def _gt_count(self, gts):
+        """ count ground-truths, leaving out difficult ones unless use_difficult """
+        if (not self.use_difficult and gts.shape[1] >= 6):
+            return np.sum(gts[:, 5] < 1)
+        return gts.shape[0]
+
+    def _update(self):
+        """ update num_inst and sum_metric """
+        aps = []
+        for k, v in self.records.items():
+            recall, prec = self._recall_prec(v, self.counts[k])
+            ap = self._average_precision(recall, prec)
+            aps.append(ap)
+            if self.num is not None and k < (self.num - 1):
+                self.sum_metric[k] = ap
+                self.num_inst[k] = 1
+        # nan (without a numpy warning) when nothing has been recorded yet
+        mean_ap = np.mean(aps) if aps else float('nan')
+        if self.num is None:
+            self.num_inst = 1
+            self.sum_metric = mean_ap
+        else:
+            self.num_inst[-1] = 1
+            self.sum_metric[-1] = mean_ap
+
+    def _recall_prec(self, record, count):
+        """ get recall and precision from internal records """
+        # drop placeholder rows (flag 0), then rank the rest by score
+        record = np.delete(record, np.where(record[:, 1].astype(int) == 0)[0], axis=0)
+        sorted_records = record[record[:,0].argsort()[::-1]]
+        tp = np.cumsum(sorted_records[:, 1].astype(int) == 1)
+        fp = np.cumsum(sorted_records[:, 1].astype(int) == 2)
+        if count <= 0:
+            recall = tp * 0.0
+        else:
+            recall = tp / float(count)
+        prec = tp.astype(float) / (tp + fp)
+        return recall, prec
+
+    def _average_precision(self, rec, prec):
+        """
+        calculate average precision
+
+        Params:
+        ----------
+        rec : numpy.array
+            cumulated recall
+        prec : numpy.array
+            cumulated precision
+        Returns:
+        ----------
+        ap as float
+        """
+        # append sentinel values at both ends
+        mrec = np.concatenate(([0.], rec, [1.]))
+        mpre = np.concatenate(([0.], prec, [0.]))
+
+        # compute precision integration ladder
+        for i in range(mpre.size - 1, 0, -1):
+            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
+
+        # look for recall value changes
+        i = np.where(mrec[1:] != mrec[:-1])[0]
+
+        # sum (\delta recall) * prec
+        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
+        return ap
+
+    def _insert(self, key, records, count):
+        """ Insert records according to key """
+        if key not in self.records:
+            assert key not in self.counts
+            self.records[key] = records
+            self.counts[key] = count
+        else:
+            self.records[key] = np.vstack((self.records[key], records))
+            assert key in self.counts
+            self.counts[key] += count
+
+
+class VOC07MApMetric(MApMetric):
+    """ Mean average precision using the PASCAL VOC 07 11-point AP """
+    def __init__(self, *args, **kwargs):
+        super(VOC07MApMetric, self).__init__(*args, **kwargs)
+
+    def _average_precision(self, rec, prec):
+        """
+        11-point average precision, replacing the default computation:
+        the mean over recall thresholds 0.0, 0.1, ..., 1.0 of the best
+        precision reached at or above each threshold (0 if none reaches it)
+
+        Params:
+        ----------
+        rec : numpy.array
+            cumulated recall
+        prec : numpy.array
+            cumulated precision
+        Returns:
+        ----------
+        ap as float
+        """
+        total = 0.
+        for threshold in np.arange(0., 1.1, 0.1):
+            reached = rec >= threshold
+            best = np.max(prec[reached]) if np.sum(reached) != 0 else 0
+            total += best / 11.
+        return total
diff --git a/example/object-detection/evaluation/voc.py b/example/object-detection/evaluation/voc.py
new file mode 100644
index 0000000000..0da8291228
--- /dev/null
+++ b/example/object-detection/evaluation/voc.py
@@ -0,0 +1,177 @@
+"""Pascal VOC evaluation."""
+from __future__ import print_function
+import numpy as np
+import os
+try:
+    import cPickle as pickle
+except ImportError:
+    import pickle
+
+
+def parse_voc_rec(filename):
+    """
+    read the objects annotated in one pascal voc xml record
+    :param filename: xml file path
+    :return: list of dict, one per object, with keys 'name', 'difficult', 'bbox'
+             where bbox is [xmin, ymin, xmax, ymax] as ints
+    """
+    import xml.etree.ElementTree as ET
+
+    def int_field(node, tag):
+        # integer value stored in the text of a child element
+        return int(node.find(tag).text)
+
+    records = []
+    for node in ET.parse(filename).findall('object'):
+        box = node.find('bndbox')
+        records.append({
+            'name': node.find('name').text,
+            'difficult': int_field(node, 'difficult'),
+            'bbox': [int_field(box, 'xmin'),
+                     int_field(box, 'ymin'),
+                     int_field(box, 'xmax'),
+                     int_field(box, 'ymax')],
+        })
+    return records
+
+
+def voc_ap(rec, prec, use_07_metric=False):
+    """
+    average precision from a precision/recall curve
+    :param rec: recall
+    :param prec: precision
+    :param use_07_metric: if True use the 2007 11-recall-point AP,
+                          otherwise integrate precision over recall
+    :return: average precision
+    """
+    if use_07_metric:
+        total = 0.
+        for threshold in np.arange(0., 1.1, 0.1):
+            reached = rec >= threshold
+            # best precision at or above this recall level, 0 if never reached
+            best = np.max(prec[reached]) if np.sum(reached) != 0 else 0
+            total += best / 11.
+        return total
+
+    # pad both curves with sentinel values
+    padded_rec = np.concatenate(([0.], rec, [1.]))
+    padded_prec = np.concatenate(([0.], prec, [0.]))
+
+    # walk backwards so every precision is at least the one to its right
+    for k in reversed(range(1, padded_prec.size)):
+        padded_prec[k - 1] = np.maximum(padded_prec[k - 1], padded_prec[k])
+
+    # positions where recall changes
+    steps = np.where(padded_rec[1:] != padded_rec[:-1])[0]
+
+    # sum (\delta recall) * prec
+    return np.sum((padded_rec[steps + 1] - padded_rec[steps]) * padded_prec[steps + 1])
+
+
+def voc_eval(detpath, annopath, imageset_file, classname, cache_dir, ovthresh=0.5, use_07_metric=False):
+    """
+    pascal voc evaluation
+    :param detpath: detection results detpath.format(classname)
+    :param annopath: annotations annopath.format(image_filename)
+    :param imageset_file: text file containing list of images
+    :param classname: category name
+    :param cache_dir: caching annotations
+    :param ovthresh: overlap threshold
+    :param use_07_metric: whether to use voc07's 11 point ap computation
+    :return: rec, prec, ap
+    """
+    if not os.path.isdir(cache_dir):
+        # makedirs also creates missing parent directories
+        os.makedirs(cache_dir)
+    cache_file = os.path.join(cache_dir, 'annotations.pkl')
+    with open(imageset_file, 'r') as f:
+        lines = f.readlines()
+    image_filenames = [x.strip() for x in lines]
+
+    # load annotations from cache
+    if not os.path.isfile(cache_file):
+        recs = {}
+        for ind, image_filename in enumerate(image_filenames):
+            recs[image_filename] = parse_voc_rec(annopath.format(image_filename))
+            if ind % 100 == 0:
+                print('reading annotations for {:d}/{:d}'.format(ind + 1, len(image_filenames)))
+        print('saving annotations cache to {:s}'.format(cache_file))
+        with open(cache_file, 'wb') as f:
+            pickle.dump(recs, f)
+    else:
+        # NOTE: unpickling can execute arbitrary code; cache_dir must only
+        # contain cache files written by this function
+        with open(cache_file, 'rb') as f:
+            recs = pickle.load(f)
+
+    # extract objects in :param classname:
+    class_recs = {}
+    npos = 0
+    for image_filename in image_filenames:
+        objects = [obj for obj in recs[image_filename] if obj['name'] == classname]
+        bbox = np.array([x['bbox'] for x in objects])
+        # builtin bool: the np.bool alias no longer exists in current NumPy
+        difficult = np.array([x['difficult'] for x in objects]).astype(bool)
+        det = [False] * len(objects)  # stand for detected
+        # difficult objects do not count as positives
+        npos += int(np.sum(~difficult))
+        class_recs[image_filename] = {'bbox': bbox,
+                                      'difficult': difficult,
+                                      'det': det}
+
+    # read detections, one "image_id score xmin ymin xmax ymax" record per line
+    detfile = detpath.format(classname)
+    with open(detfile, 'r') as f:
+        lines = f.readlines()
+
+    # blank lines carry no detection
+    splitlines = [x.strip().split(' ') for x in lines if x.strip()]
+    image_ids = [x[0] for x in splitlines]
+    confidence = np.array([float(x[1]) for x in splitlines])
+    bbox = np.array([[float(z) for z in x[2:]] for x in splitlines])
+    if not splitlines:
+        # keep the array 2-d so the indexing below works without detections
+        bbox = bbox.reshape((0, 4))
+
+    # sort by confidence, highest first
+    sorted_inds = np.argsort(-confidence)
+    bbox = bbox[sorted_inds, :]
+    image_ids = [image_ids[x] for x in sorted_inds]
+
+    # go down detections and mark true positives and false positives
+    nd = len(image_ids)
+    tp = np.zeros(nd)
+    fp = np.zeros(nd)
+    for d in range(nd):
+        r = class_recs[image_ids[d]]
+        bb = bbox[d, :].astype(float)
+        ovmax = -np.inf
+        bbgt = r['bbox'].astype(float)
+
+        if bbgt.size > 0:
+            # compute overlaps; the + 1. treats coordinates as inclusive pixels
+            # intersection
+            ixmin = np.maximum(bbgt[:, 0], bb[0])
+            iymin = np.maximum(bbgt[:, 1], bb[1])
+            ixmax = np.minimum(bbgt[:, 2], bb[2])
+            iymax = np.minimum(bbgt[:, 3], bb[3])
+            iw = np.maximum(ixmax - ixmin + 1., 0.)
+            ih = np.maximum(iymax - iymin + 1., 0.)
+            inters = iw * ih
+
+            # union
+            uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
+                   (bbgt[:, 2] - bbgt[:, 0] + 1.) *
+                   (bbgt[:, 3] - bbgt[:, 1] + 1.) - inters)
+
+            overlaps = inters / uni
+            ovmax = np.max(overlaps)
+            jmax = np.argmax(overlaps)
+
+        if ovmax > ovthresh:
+            # a match with a difficult object is neither tp nor fp
+            if not r['difficult'][jmax]:
+                if not r['det'][jmax]:
+                    tp[d] = 1.
+                    r['det'][jmax] = True
+                else:
+                    # ground truth already claimed by a higher-scored detection
+                    fp[d] = 1.
+        else:
+            fp[d] = 1.
+
+    # compute precision recall
+    fp = np.cumsum(fp)
+    tp = np.cumsum(tp)
+    # without any positive the recall is defined as zero instead of nan
+    rec = tp / float(npos) if npos > 0 else np.zeros(nd)
+    # avoid division by zero in case first detection matches a difficult ground truth
+    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
+    ap = voc_ap(rec, prec, use_07_metric)
+
+    return rec, prec, ap
diff --git a/example/object-detection/model/README.md b/example/object-detection/model/README.md
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/example/object-detection/model_zoo/__init__.py b/example/object-detection/model_zoo/__init__.py
new file mode 100644
index 0000000000..eb7b81940e
--- /dev/null
+++ b/example/object-detection/model_zoo/__init__.py
@@ -0,0 +1,31 @@
+"""Module for pretrained model for object-detection package.
+"""
+from mxnet.gluon.model_zoo import vision
+from model_zoo.ssd import *
+
+def get_detection_model(name, **kwargs):
+    """Build one of the pre-defined detection models, selected by name
+
+    Parameters
+    ----------
+    name : str
+        Name of the model, matched case-insensitively.
+    pretrained : int
+        Whether to load the pretrained weights for model.
+    classes : int
+        Number of classes for the output layer.
+
+    Returns
+    -------
+    Block
+        The model.
+
+    Raises
+    ------
+    ValueError
+        If no model is registered under the given name.
+    """
+    builders = {
+        'ssd_512_resnet18_v1': ssd_512_resnet18_v1,
+        'ssd_512_resnet50_v1': ssd_512_resnet50_v1,
+    }
+    key = name.lower()
+    if key in builders:
+        # remaining keyword arguments go to the model builder untouched
+        return builders[key](**kwargs)
+    raise ValueError(
+        'Model %s is not supported. Available options are\n\t%s'%(
+            key, '\n\t'.join(sorted(builders.keys()))))
diff --git a/example/object-detection/model_zoo/faster_rcnn.py b/example/object-detection/model_zoo/faster_rcnn.py
new file mode 100644
index 0000000000..ffcdddbfbc
--- /dev/null
+++ b/example/object-detection/model_zoo/faster_rcnn.py
@@ -0,0 +1,15 @@
+"""Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks.
+arXiv:1506.01497v1
+"""
+import mxnet as mx
+from mxnet import ndarray as nd
+from mxnet import gluon
+from mxnet.gluon import nn
+from mxnet.gluon import Block, HybridBlock
+
+
+class FasterRCNN(Block):
+    """Placeholder for the Faster R-CNN detector; no layers are defined yet.
+    """
+    def __init__(self, **kwargs):
+        # run the Block constructor so the (still empty) model is a valid Block
+        super(FasterRCNN, self).__init__(**kwargs)
diff --git a/example/object-detection/model_zoo/mask_rcnn.py b/example/object-detection/model_zoo/mask_rcnn.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/example/object-detection/model_zoo/model_store.py b/example/object-detection/model_zoo/model_store.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/example/object-detection/model_zoo/retinanet.py b/example/object-detection/model_zoo/retinanet.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/example/object-detection/model_zoo/rfcn.py b/example/object-detection/model_zoo/rfcn.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/example/object-detection/model_zoo/ssd.py b/example/object-detection/model_zoo/ssd.py
new file mode 100644
index 0000000000..a575f4e4e2
--- /dev/null
+++ b/example/object-detection/model_zoo/ssd.py
@@ -0,0 +1,119 @@
+"""Single-shot Multi-box Detector.
+"""
+from collections import namedtuple
+import mxnet as mx
+from mxnet import ndarray as nd
+from mxnet import gluon
+from mxnet.gluon import nn
+from mxnet.gluon import Block, HybridBlock
+from mxnet.gluon.model_zoo import vision
+from block.feature import FeatureExpander
+from block.anchor import SSDAnchorGenerator
+from block.predictor import ConvPredictor
+
+class SSDNet(Block):
+    """Single-shot multi-box detection network.
+
+    A FeatureExpander produces one feature map per detection layer; every
+    layer gets its own anchor generator, class predictor and box predictor.
+
+    Parameters
+    ----------
+    network, features, num_filters, use_1x1_transition, use_bn, reduce_ratio,
+    min_depth, global_pool, pretrained, ctx
+        Forwarded to FeatureExpander (features is passed as its outputs).
+        The number of detection layers is
+        len(features) + len(num_filters) + int(global_pool).
+    scale : list or tuple of 2 float
+        (min_scale, max_scale); anchor sizes are spaced evenly between them
+        across the layers and multiplied by base_size.
+    ratios : list or list of list
+        Anchor ratios, either one list shared by all layers or one list per layer.
+    base_size : int
+        Multiplier for the anchor sizes, also used as the (square) image size
+        of the anchor generators.
+    num_classes : int
+        Number of object classes; one extra class is added internally.
+    strides, offsets
+        Accepted but not used by this block.
+    clip
+        Forwarded to SSDAnchorGenerator.
+    """
+    def __init__(self, network, features, num_filters, scale, ratios, base_size,
+                 num_classes, strides=None, offsets=None, clip=None,
+                 use_1x1_transition=True, use_bn=True, reduce_ratio=1.0,
+                 min_depth=128, global_pool=False, pretrained=False,
+                 ctx=mx.cpu(), **kwargs):
+        super(SSDNet, self).__init__(**kwargs)
+        num_layers = len(features) + len(num_filters) + int(global_pool)
+        # checked first: the size computation below needs at least one layer
+        assert num_layers > 0, "SSD require at least one layer, suggest multiple."
+        assert len(scale) == 2, "Must specify scale as (min_scale, max_scale)."
+        min_scale, max_scale = scale
+        # max(..., 1) keeps a single-layer network from dividing by zero
+        num_steps = float(max(num_layers - 1, 1))
+        sizes = [min_scale + (max_scale - min_scale) * i / num_steps
+                 for i in range(num_layers)] + [1.0]
+        sizes = [x * base_size for x in sizes]
+        # each layer gets (its own size, the size of the next layer)
+        sizes = list(zip(sizes[:-1], sizes[1:]))
+        assert isinstance(ratios, list), "Must provide ratios as list or list of list"
+        if ratios and not isinstance(ratios[0], (tuple, list)):
+            # propagate to all layers if use same ratio: one copy of the whole
+            # list per layer, not one long concatenated list
+            ratios = [ratios] * num_layers
+        assert num_layers == len(sizes) == len(ratios), \
+            "Mismatched (number of layers) vs (sizes) vs (ratios): {}, {}, {}".format(
+                num_layers, len(sizes), len(ratios))
+        self._num_layers = num_layers
+        self.num_classes = num_classes + 1
+        self.features = FeatureExpander(
+            network=network, outputs=features, num_filters=num_filters,
+            use_1x1_transition=use_1x1_transition,
+            use_bn=use_bn, reduce_ratio=reduce_ratio, min_depth=min_depth,
+            global_pool=global_pool, pretrained=pretrained, ctx=ctx)
+
+
+        with self.name_scope():
+            self.class_predictors = nn.HybridSequential()
+            self.box_predictors = nn.HybridSequential()
+            self.anchor_generators = nn.Sequential()
+            for s, r in zip(sizes, ratios):
+                self.anchor_generators.add(SSDAnchorGenerator(
+                    s, r, im_size=(base_size, base_size), clip=clip))
+                num_anchors = self.anchor_generators[-1].num_depth
+                self.class_predictors.add(ConvPredictor(num_anchors * self.num_classes))
+                self.box_predictors.add(ConvPredictor(num_anchors * 4))
+
+    def forward(self, x, *args):
+        """Compute class predictions, box predictions and anchors for x.
+
+        Returns
+        -------
+        list
+            [cls_preds, box_preds, anchors], reshaped so that the last axis
+            has size num_classes, 4 and 4 respectively; anchors has a leading
+            axis of size 1.
+        """
+        features = self.features(x)
+        # per layer: move axis 1 of the prediction last, then flatten per sample
+        cls_preds = [nd.flatten(nd.transpose(cp(feat), (0, 2, 3, 1)))
+            for feat, cp in zip(features, self.class_predictors)]
+        box_preds = [nd.flatten(nd.transpose(bp(feat), (0, 2, 3, 1)))
+            for feat, bp in zip(features, self.box_predictors)]
+        anchors = [nd.reshape(ag(feat), shape=(1, -1))
+            for feat, ag in zip(features, self.anchor_generators)]
+        # concat the outputs of all layers
+        cls_preds = nd.concat(*cls_preds, dim=1).reshape((0, -1, self.num_classes))
+        box_preds = nd.concat(*box_preds, dim=1).reshape((0, -1, 4))
+        anchors = nd.concat(*anchors, dim=1).reshape((1, -1, 4))
+        # sync device since anchors are always generated on cpu currently
+        anchors = anchors.as_in_context(cls_preds.context)
+        return [cls_preds, box_preds, anchors]
+
+
+SSDConfig = namedtuple('SSDConfig', 'features num_filters scale ratios')  # settings get_ssd hands to SSDNet
+# keys follow the '<name>_<base_size>' pattern that get_ssd builds for its lookup
+_factory = {
+    'resnet18_v1_512': SSDConfig(
+        ['stage3_activation1', 'stage4_activation1'], [512, 512, 256, 256],  # features, num_filters
+        [0.1, 0.95], [[1, 2, 0.5]] + [[1, 2, 0.5, 3, 1.0/3]] * 5),  # scale (min, max), 6 ratio lists
+    'resnet50_v1_512': SSDConfig(
+        ['stage3_activation5', 'stage4_activation2'], [512, 512, 256, 256],  # features, num_filters
+        [0.1, 0.95], [[1, 2, 0.5]] + [[1, 2, 0.5, 3, 1.0/3]] * 5),  # scale (min, max), 6 ratio lists
+}
+
+def get_ssd(name, base_size, classes, pretrained=0, ctx=mx.cpu(), **kwargs):
+    """Get SSD models.
+
+    Parameters
+    ----------
+    name : str
+        Model name
+    base_size : int
+        Base size; together with name it selects the entry '<name>_<base_size>'
+        of the configuration table.
+    classes : int
+        Number of classes, passed to SSDNet as num_classes.
+    pretrained : int
+        0: no pretrained weights, 1: pretrained flag is passed to SSDNet,
+        larger than 1: pretrained detection model, not supported yet.
+    ctx : Context
+        Passed to SSDNet.
+
+    Returns
+    -------
+    SSDNet
+        The network.
+
+    Raises
+    ------
+    NotImplementedError
+        If the configuration is unknown or pretrained is larger than 1.
+    """
+    key = '{}_{}'.format(name, base_size)
+    if key not in _factory:
+        raise NotImplementedError("{} not defined in model_zoo".format(key))
+    if pretrained > 1:
+        # load trained ssd model: rejected before the network is built, so no
+        # work is spent on a model that cannot be returned
+        raise NotImplementedError("Loading pretrained model for detection is not finished.")
+    c = _factory[key]
+    return SSDNet(name, c.features, c.num_filters, c.scale, c.ratios, base_size,
+                  num_classes=classes, pretrained=pretrained > 0, ctx=ctx, **kwargs)
+
+def ssd_512_resnet18_v1(pretrained=0, classes=20, ctx=mx.cpu(), **kwargs):
+    """SSD with base size 512 on a ResNet v1 18 layers base network.
+
+    All arguments are handed to get_ssd.
+    """
+    return get_ssd(name='resnet18_v1', base_size=512, classes=classes,
+                   pretrained=pretrained, ctx=ctx, **kwargs)
+
+def ssd_512_resnet50_v1(pretrained=0, classes=20, ctx=mx.cpu(), **kwargs):
+    """SSD with base size 512 on a ResNet v1 50 layers base network.
+
+    All arguments are handed to get_ssd.
+    """
+    return get_ssd(name='resnet50_v1', base_size=512, classes=classes,
+                   pretrained=pretrained, ctx=ctx, **kwargs)
diff --git a/example/object-detection/model_zoo/yolo.py b/example/object-detection/model_zoo/yolo.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/example/object-detection/predict/__init__.py b/example/object-detection/predict/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/example/object-detection/predict/predict_ssd.py b/example/object-detection/predict/predict_ssd.py
new file mode 100644
index 0000000000..c8fa0ca0a8
--- /dev/null
+++ b/example/object-detection/predict/predict_ssd.py
@@ -0,0 +1,123 @@
+import argparse
+import os
+import logging
+import time
+import random
+import mxnet as mx
+from mxnet import nd
+from mxnet import gluon
+from mxnet import autograd as ag
+from dataset.dataloader import DataLoader
+from dataset import VOCDetection
+from dataset import transform
+from config import config as cfg
+import model_zoo
+from block.loss import *
+from block.target import *
+from block.coder import MultiClassDecoder, NormalizedBoxCenterDecoder
+from trainer.metric import Accuracy, SmoothL1, LossRecorder, MultiBoxMetric
+from trainer.debugger import super_print, find_abnormal
+from evaluation.eval_metric import VOC07MApMetric, MApMetric
+
+def preprocess(filename, data_shape):
+    """Load an image file and turn it into a single-image network input.
+
+    The image is resized with data_shape[1] and data_shape[0] as the second
+    and third imresize arguments (presumably width and height -- TODO confirm),
+    converted to float32, normalized per channel, then its last axis is moved
+    first and a leading axis of size 1 is added.
+    """
+    channel_mean = mx.nd.array([123, 117, 104])
+    channel_std = mx.nd.array([58, 57, 57])
+    image = mx.image.imread(filename)
+    image = mx.image.imresize(image, data_shape[1], data_shape[0]).astype('float32')
+    image -= channel_mean
+    image /= channel_std
+    return image.transpose((2, 0, 1)).expand_dims(axis=0)
+
+def visualize_detection(img, dets, classes=None, thresh=0.6):
+    """
+    visualize detections in one image
+
+    Parameters:
+    ----------
+    img : numpy.array
+        image, shown as is with plt.imshow (so in the channel order
+        matplotlib expects, i.e. RGB for color images)
+    dets : numpy.array
+        ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...])
+        each row is one object; coordinates are relative to the image size,
+        rows with a negative id are skipped
+    classes : tuple or list of str
+        class names, optional; the class id is printed when no name is available
+    thresh : float
+        score threshold, only detections scoring above it are drawn
+    """
+    import matplotlib.pyplot as plt
+    plt.imshow(img)
+    height = img.shape[0]
+    width = img.shape[1]
+    axes = plt.gca()
+    colors = dict()  # one random color per class id
+    for i in range(dets.shape[0]):
+        cls_id = int(dets[i, 0])
+        score = dets[i, 1]
+        if cls_id < 0 or not score > thresh:
+            continue
+        if cls_id not in colors:
+            colors[cls_id] = (random.random(), random.random(), random.random())
+        # scale the relative coordinates to pixels
+        xmin = int(dets[i, 2] * width)
+        ymin = int(dets[i, 3] * height)
+        xmax = int(dets[i, 4] * width)
+        ymax = int(dets[i, 5] * height)
+        rect = plt.Rectangle((xmin, ymin), xmax - xmin,
+                             ymax - ymin, fill=False,
+                             edgecolor=colors[cls_id],
+                             linewidth=3.5)
+        axes.add_patch(rect)
+        class_name = str(cls_id)
+        if classes and len(classes) > cls_id:
+            class_name = classes[cls_id]
+        axes.text(xmin, ymin - 2,
+                  '{:s} {:.3f}'.format(class_name, score),
+                  bbox=dict(facecolor=colors[cls_id], alpha=0.5),
+                  fontsize=12, color='white')
+    plt.show()
+
+def predict_net(im_path, model, data_shape, num_class,
+              pretrained=0, seed=None, log_file=None, dev=False, ctx=mx.cpu(), **kwargs):
+    """Run an SSD detector on one image and display the detections.
+
+    Parameters
+    ----------
+    im_path : str
+        Path of the image file.
+    model : str
+        Base network name; the detector name is 'ssd_<size>_<model>'.
+    data_shape : str
+        Input size as one integer or as comma separated integers.
+    num_class : int
+        Number of classes of the detector.
+    pretrained : int
+        Passed to the model factory.
+    seed, log_file, dev
+        Accepted but not used by this function.
+    ctx : Context
+        Context for the model and for the loaded parameters.
+    """
+    data_shape = [int(x) for x in data_shape.split(',')]
+    if len(data_shape) == 1:
+        # a single number means a square input
+        data_shape = data_shape * 2
+
+    model = '_'.join(['ssd', str(data_shape[0]), model])
+
+    # spelled out as a list so the names carry no stray whitespace
+    class_names = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
+                   'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
+                   'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
+                   'train', 'tvmonitor']
+
+    net = model_zoo.get_detection_model(model, pretrained=pretrained, classes=num_class, ctx=ctx)
+    # weights are read from model/ssd.params next to this package
+    net.collect_params().load(os.path.join(os.path.dirname(__file__), '..', 'model', 'ssd.params'), ctx=ctx)
+
+    box_decoder = NormalizedBoxCenterDecoder()
+    cls_decoder = MultiClassDecoder()
+
+    x = preprocess(im_path, data_shape)
+    z = net(x)
+    cls_preds, box_preds, anchors = z
+    boxes = box_decoder(box_preds, anchors)
+    boxes = nd.clip(boxes, 0.0, 1.0)
+    cls_ids, scores = cls_decoder(nd.softmax(cls_preds))
+    # one row per anchor: id, score, then the box
+    result = nd.concat(cls_ids.reshape((0, 0, 1)), scores.reshape((0, 0, 1)), boxes, dim=2)
+    out = nd.contrib.box_nms(result, topk=400)
+    visualize_detection(mx.image.imread(im_path).asnumpy(), out[0].asnumpy(), class_names, thresh=0.1)
diff --git a/example/object-detection/proposal.md b/example/object-detection/proposal.md
new file mode 100644
index 0000000000..a7ca1fe9c5
--- /dev/null
+++ b/example/object-detection/proposal.md
@@ -0,0 +1,57 @@
+# Proposals for object-detection modules
+
+### Algorithms
+
+By stage
+- Single stage: SSD, YOLO, RetinaNet
+- Two stage: Faster-RCNN, RFCN, MaskRCNN
+
+By data shape
+- Immutable data_shape: SSD, YOLO
+- Mutable data_shape: Faster-RCNN, RFCN, RetinaNet, MaskRCNN
+
+### Dataset
+- Pascal VOC
+- COCO
+- Imagenet
+- Custom dataset Wrapper
+
+### Network Builder
+- Anchor generator(CustomOp/function): SSD, YOLO, Faster-RCNN, RFCN, RetinaNet
+- Feature extractor(SymbolBlock)
+- Box predictor(Conv2D, Dense, Complex): predict box class/location
+- Target generator(CustomOp/function calling nd internally): SSD, YOLO, Faster-RCNN, RFCN, RetinaNet
+- Loss(gluon.loss): L1, L2, SmoothL1, CE loss, Focal loss
+- Anchor converter(CustomOp, function): anchor + prediction = output
+- Non-maximum-suppression(src/contrib)
+
+### Matcher(performance critical)
+*src/contrib/box_op-inl.h?*
+
+- IOU: overlap between anchors and labels. (How to handle padding?)
+- Areas: box area
+- Intersection: box intersection
+- Clip: clip box to region
+- Sampler: generating positive/negative/other samples
+- OHEM sampler: Hard negative mining
+- More
+
+### Iterator
+- Normal iterator(python iterator): SSD/YOLO iterator
+- Mutable iterator(wrapper for c++ iter?): batch_size=1, take arbitrary data shape. Usually don't need augmentation.
+- Mini-batch iterator: for rcnn variations.
+
+### Trainer
+Apply the implementation details described in each paper
+
+- Initializer (specific init patterns)
+- Reshape input after N epochs (YOLO2)
+- Warm up / refactor learning rate
+
+### Tester
+Allow arbitrary sized input?
+
+### Sugar
+- Model zoo
+- Dataset downloader/loader
+- Configuration loader/saver(yaml)
diff --git a/example/object-detection/train.py b/example/object-detection/train.py
new file mode 100644
index 0000000000..04a7518f58
--- /dev/null
+++ b/example/object-detection/train.py
@@ -0,0 +1,74 @@
+"""Train Gluon Object-Detection models."""
+import os
+import argparse
+import mxnet as mx
+from trainer import train_ssd
+
+def parse_args():
+    """Parse the command line of the training script into an argparse Namespace."""
+    # One (flag, add_argument keywords) entry per option, registered in this order.
+    options = [
+        ('--algorithm', dict(dest='algorithm', type=str, default='ssd',
+                             help='which network to use')),
+        ('--data-shape', dict(dest='data_shape', type=str, default='512',
+                              help='image data shape, can be int or tuple')),
+        ('--model', dict(dest='model', type=str, default='resnet50_v1',
+                         help='base network to use, choices are models from gluon model_zoo')),
+        ('--dataset', dict(dest='dataset', type=str, default='voc', help='which dataset to use')),
+        ('--batch-size', dict(dest='batch_size', type=int, default=32,
+                              help='training batch size')),
+        ('--pretrained', dict(type=int, default=1,
+                              help='Whether use pretrained models. '
+                              ' 0: from scratch, 1: use base model, 2: use pretrained detection model')),
+        ('--resume', dict(dest='resume', type=int, default=-1,
+                          help='resume training from epoch n')),
+        ('--prefix', dict(dest='prefix', type=str, help='new model prefix',
+                          default=os.path.join(os.path.dirname(__file__), 'model', 'default'))),
+        ('--gpus', dict(dest='gpus', help='GPU devices to train with', default='0', type=str)),
+        ('--end-epoch', dict(dest='end_epoch', help='end epoch of training',
+                             default=240, type=int)),
+        ('--interval', dict(dest='log_interval', help='frequency of logging',
+                            default=20, type=int)),
+        ('--lr', dict(dest='learning_rate', type=float, default=0.004, help='learning rate')),
+        ('--momentum', dict(dest='momentum', type=float, default=0.9, help='momentum')),
+        ('--wd', dict(dest='weight_decay', type=float, default=0.0005, help='weight decay')),
+        ('--log', dict(dest='log_file', type=str, default="train.log",
+                       help='Save training log to file')),
+        ('--seed', dict(dest='seed', type=int, default=123,
+                        help="Random seed, -1 to disable fixed seed.")),
+        ('--dev', dict(type=int, default=0,
+                       help="Turn on develop mode with verbose informations.")),
+        ('--lr-steps', dict(dest='lr_steps', type=str, default='80, 160',
+                            help='refactor learning rate at specified epochs')),
+        ('--lr-factor', dict(dest='lr_factor', type=str, default='0.1',
+                             help='ratio to refactor learning rate, can be float or list of floats.')),
+        # ('--freeze', dict(dest='freeze_pattern', type=str, default="^(conv1_|conv2_).*",
+        #                   help='freeze layer pattern')),
+    ]
+    cli = argparse.ArgumentParser(description='Train a gluon detection network',
+                                  formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    for flag, keywords in options:
+        cli.add_argument(flag, **keywords)
+    return cli.parse_args()
+
+if __name__ == '__main__':
+    args = parse_args()
+    # One GPU context per id listed in --gpus; an empty list falls back to the CPU.
+    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()] or [mx.cpu()]
+    if args.algorithm.lower() == 'ssd':
+        # Only the first dimension goes into the model name, so a two-value --data-shape
+        # such as '300,300' cannot put a comma into it; train_net still gets the full string.
+        model = '_'.join([args.algorithm, args.data_shape.split(',')[0].strip(), args.model])
+        train_ssd.train_net(model, args.dataset, args.data_shape, args.batch_size,
+                            args.end_epoch, args.learning_rate, args.momentum,
+                            args.weight_decay,
+                            log_interval=args.log_interval,
+                            seed=args.seed,
+                            pretrained=args.pretrained,
+                            log_file=args.log_file,
+                            lr_steps=[int(x.strip()) for x in args.lr_steps.split(',') if x.strip()],
+                            lr_factor=[float(x.strip()) for x in args.lr_factor.split(',') if x.strip()],
+                            dev=args.dev,
+                            ctx=ctx)
+    else:
+        raise NotImplementedError("Training algorithm {} not supported.".format(args.algorithm))
diff --git a/example/object-detection/trainer/__init__.py b/example/object-detection/trainer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/example/object-detection/trainer/debugger.py b/example/object-detection/trainer/debugger.py
new file mode 100644
index 0000000000..189a2c201b
--- /dev/null
+++ b/example/object-detection/trainer/debugger.py
@@ -0,0 +1,14 @@
+import numpy as np
+
+def super_print(*args):
+    """Print each argument in full, with numpy's array summarization switched off."""
+    # The context manager puts the previous print options back even if print() raises.
+    with np.printoptions(threshold=np.inf):
+        for arg in args:
+            print(arg)
+
+def find_abnormal(arr):
+    """Return the indices (as given by np.where) of non-finite entries of arr, or None."""
+    bad = np.nonzero(~np.isfinite(arr))
+    # bad holds one index array per dimension; an empty first array means nothing was found.
+    return bad if bad[0].size else None
diff --git a/example/object-detection/trainer/metric.py b/example/object-detection/trainer/metric.py
new file mode 100644
index 0000000000..d57b71930a
--- /dev/null
+++ b/example/object-detection/trainer/metric.py
@@ -0,0 +1,143 @@
+import mxnet as mx
+from mxnet import ndarray
+from mxnet.metric import check_label_shapes
+import numpy as np
+
+
+class LossRecorder(mx.metric.EvalMetric):
+    """Metric that accumulates loss arrays which were computed elsewhere.
+
+    Every array handed to update() adds its total to the sum and counts as one instance.
+    """
+    def __init__(self, name):
+        super(LossRecorder, self).__init__(name)
+
+    def update(self, labels, preds=0):
+        """Record the loss arrays in ``labels``; ``preds`` is not used."""
+        for entry in labels:
+            # NDArrays are copied to numpy first; anything else must already provide .sum().
+            values = entry.asnumpy() if isinstance(entry, mx.nd.NDArray) else entry
+            self.sum_metric += values.sum()
+            self.num_inst += 1
+
+class Accuracy(mx.metric.EvalMetric):
+    """Classification accuracy counted over the labels that differ from ``ignore_label``.
+
+    ``axis`` is the class axis that update() reduces with argmax.
+    """
+    def __init__(self, axis=1, name='accuracy',
+                 output_names=None, label_names=None,
+                 ignore_label=-1):
+        super(Accuracy, self).__init__(
+            name, axis=axis, output_names=output_names, label_names=label_names)
+        self.axis = axis
+        self.ignore_label = ignore_label
+
+    def update(self, labels, preds):
+        """Add the hits and the counted labels of each (label, prediction) pair.
+
+        Predictions whose shape differs from the label's are first reduced to class
+        ids with argmax over ``self.axis``.
+        """
+        check_label_shapes(labels, preds)
+        ignored = self.ignore_label
+        for target, guess in zip(labels, preds):
+            if guess.shape != target.shape:
+                guess = ndarray.argmax(guess, axis=self.axis)
+            guess = guess.asnumpy().astype('int32')
+            target = target.asnumpy().astype('int32')
+            check_label_shapes(target, guess)
+
+            # A hit needs matching ids and a prediction that is not the ignore id.
+            same = guess.flat == target.flat
+            kept = guess.flat != ignored
+            self.sum_metric += np.logical_and(same, kept).sum()
+            self.num_inst += np.sum(target != ignored)
+
+
+class SmoothL1(mx.metric.EvalMetric):
+    """Tracks summed smooth-L1 error against the count of positive label entries."""
+    def __init__(self, name='smoothl1', output_names=None, label_names=None):
+        super(SmoothL1, self).__init__(
+            name, output_names=output_names, label_names=label_names)
+
+    def update(self, labels, preds):
+        """Accumulate smooth_l1(label - pred) for every (label, prediction) pair.
+
+        The instance count grows by the number of label entries greater than zero,
+        not by the number of elements compared.
+        """
+        check_label_shapes(labels, preds)
+        for truth, estimate in zip(labels, preds):
+            error = ndarray.smooth_l1(truth - estimate, scalar=1.0).asnumpy()
+            truth = truth.asnumpy()
+            if truth.ndim == 1:
+                # Column-vector form; this does not change the count taken below.
+                truth = truth.reshape(truth.shape[0], 1)
+
+            self.sum_metric += np.sum(error)
+            self.num_inst += np.sum(truth > 0)
+
+
+class MultiBoxMetric(mx.metric.EvalMetric):
+    """Two-part MultiBox training metric: cross-entropy and smooth-L1."""
+    def __init__(self, eps=1e-8):
+        super(MultiBoxMetric, self).__init__('MultiBox')
+        self.eps = eps
+        self.num = 2
+        self.name = ['CrossEntropy', 'SmoothL1']
+        self.reset()
+
+    def reset(self):
+        """Zero the accumulators, using one slot per sub-metric once ``num`` is set.
+
+        ``num`` is read with getattr because __init__ assigns it only after calling the
+        base constructor (which presumably resets the metric itself -- TODO confirm).
+        """
+        slots = getattr(self, 'num', None)
+        if slots is None:
+            self.num_inst, self.sum_metric = 0, 0.0
+        else:
+            self.num_inst, self.sum_metric = [0] * slots, [0.0] * slots
+
+    def update(self, labels, preds):
+        """Accumulate both sub-metrics from ``preds``; ``labels`` is not read.
+
+        ``preds`` supplies, in order, class probabilities, localization loss and class
+        targets. Targets below zero are left out of both instance counts.
+        """
+        cls_prob, loc_loss, cls_label = (preds[i].asnumpy() for i in range(3))
+        valid_count = np.sum(cls_label >= 0)
+
+        # Cross-entropy: probability given to the target class of every valid entry.
+        targets = cls_label.flatten()
+        keep = np.where(targets >= 0)[0]
+        # assumes cls_prob is laid out (batch, classes, anchors) -- TODO confirm
+        per_anchor = cls_prob.transpose((0, 2, 1)).reshape((-1, cls_prob.shape[1]))
+        picked = per_anchor[keep, np.int64(targets[keep])]
+        self.sum_metric[0] += (-np.log(picked + self.eps)).sum()
+        self.num_inst[0] += valid_count
+
+        # Smooth-L1: the localization loss arrives precomputed and is only summed.
+        self.sum_metric[1] += np.sum(loc_loss)
+        self.num_inst[1] += valid_count
+
+    def get(self):
+        """Return the names and current averages of the tracked values.
+
+        Returns
+        -------
+        names : list of str
+           Sub-metric names (the plain name when ``num`` is None).
+        values : list of float
+           Sum divided by instance count; NaN where the count is zero.
+        """
+        if self.num is not None:
+            names = ['%s'%(self.name[i]) for i in range(self.num)]
+            values = [total / count if count != 0 else float('nan')
+                      for total, count in zip(self.sum_metric, self.num_inst)]
+            return (names, values)
+        # Scalar fallback for a metric that tracks a single value.
+        if self.num_inst == 0:
+            return (self.name, float('nan'))
+        return (self.name, self.sum_metric / self.num_inst)
diff --git a/example/object-detection/trainer/train_ssd.py b/example/object-detection/trainer/train_ssd.py
new file mode 100644
index 0000000000..2be5432479
--- /dev/null
+++ b/example/object-detection/trainer/train_ssd.py
@@ -0,0 +1,237 @@
+import argparse
+import os
+import logging
+import time
+import random
+import mxnet as mx
+from mxnet import nd
+from mxnet import gluon
+from mxnet import autograd as ag
+from dataset.dataloader import DataLoader
+from dataset import VOCDetection
+from dataset import transform
+from config import config as cfg
+import model_zoo
+from block.loss import *
+from block.target import *
+from block.coder import MultiClassDecoder, NormalizedBoxCenterDecoder
+from trainer.metric import Accuracy, SmoothL1, LossRecorder, MultiBoxMetric
+from trainer.debugger import super_print, find_abnormal
+from evaluation.eval_metric import VOC07MApMetric, MApMetric
+
+def train_net(model, dataset, data_shape, batch_size, end_epoch, lr, momentum, wd, log_interval=50,
+              lr_steps=[], lr_factor=1.,
+              pretrained=0, seed=None, log_file=None, dev=False, ctx=mx.cpu(), **kwargs):
+    """Wrapper function for entire training phase.
+
+    Builds the data loaders and the detection network, then trains with SGD for
+    ``end_epoch`` epochs. After every epoch the parameters are saved and the VOC07
+    mAP metric is evaluated on the validation set and logged.
+
+    Parameters
+    ----------
+    model : str
+        Network name passed to ``model_zoo.get_detection_model``.
+    dataset : str
+        Dataset to train on; only 'voc' is implemented.
+    data_shape : str
+        Comma-separated integers; a single value is used for both dimensions.
+    batch_size, end_epoch : int
+        Loader batch size and number of epochs to run.
+    lr, momentum, wd : float
+        SGD learning rate, momentum and weight decay.
+    log_interval : int
+        Log training metrics every this many batches; a falsy value disables it.
+    lr_steps : list of int
+        Epochs at which the learning rate is multiplied by the matching ``lr_factor``.
+    lr_factor : float or list of float
+        A single factor is reused for every entry of ``lr_steps``.
+    pretrained : int
+        Forwarded unchanged to ``model_zoo.get_detection_model``.
+    seed : int
+        Seeds Python's ``random`` module when it is a positive int.
+    log_file : str
+        When set, log records are additionally written to this file.
+    dev : bool
+        Enables DEBUG logging and prints the network.
+    ctx : mx.Context or list of mx.Context
+        Device(s) to train on.
+    **kwargs
+        Accepted but not used.
+
+    Raises
+    ------
+    NotImplementedError
+        If ``dataset`` is not 'voc'.
+    """
+    if dev:
+        logging.basicConfig(level=logging.DEBUG)
+    else:
+        logging.basicConfig(level=logging.INFO)
+
+    if log_file:
+        logger = logging.getLogger()
+        fh = logging.FileHandler(log_file)
+        logger.addHandler(fh)
+
+    # A non-positive seed (e.g. -1) leaves the random module unseeded.
+    if isinstance(seed, int) and seed > 0:
+        random.seed(seed)
+
+    data_shape = [int(x) for x in data_shape.split(',')]
+    if len(data_shape) == 1:
+        data_shape = data_shape * 2
+
+    if dataset == 'voc':
+        # dataset
+        num_class = 20
+        t = transform.SSDAugmentation(data_shape)
+        t2 = transform.SSDValid(data_shape)
+        train_dataset = VOCDetection('./data/VOCdevkit', [(2007, 'trainval'), (2012, 'trainval')], transform=t)
+        # train_dataset = VOCDetection('./data/VOCdevkit', [(2007, 'train')], transform=transform)
+        val_dataset = VOCDetection('./data/VOCdevkit', [(2007, 'test')], transform=t2)
+    else:
+        raise NotImplementedError("Dataset {} not supported.".format(dataset))
+
+    train_data = DataLoader(train_dataset, batch_size, True, last_batch='rollover')
+    val_data = DataLoader(val_dataset, batch_size, False, last_batch='keep')
+
+    net = model_zoo.get_detection_model(model, pretrained=pretrained, classes=num_class)
+    if dev:
+        print(net)
+
+    def ctx_as_list(ctx):
+        if isinstance(ctx, mx.Context):
+            ctx = [ctx]
+        return ctx
+
+    if not isinstance(lr_factor, list):
+        lr_factor = [lr_factor]
+    if len(lr_factor) == 1 and len(lr_steps) > 1:
+        # Repeat into a new list so a list passed in by the caller is not modified.
+        lr_factor = lr_factor * len(lr_steps)
+
+    box_decoder = NormalizedBoxCenterDecoder()
+    cls_decoder = MultiClassDecoder()
+
+    def evaluate_voc(net, val_data, ctx):
+        """Update a VOC07MApMetric with every validation batch and return its get()."""
+        ctx = ctx_as_list(ctx)
+        valid_metric = VOC07MApMetric(class_names=val_dataset.classes)
+        for i, batch in enumerate(val_data):
+            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
+            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
+            for x, y in zip(data, label):
+                z = net(x)
+                cls_preds, box_preds, anchors = z
+                # Alternative decode + NMS path, kept for reference:
+                # boxes = box_decoder(box_preds, anchors)
+                # boxes = nd.clip(boxes, 0.0, 1.0)
+                # cls_ids, scores = cls_decoder(nd.softmax(cls_preds))
+                # result = nd.concat(cls_ids.reshape((0, 0, 1)), scores.reshape((0, 0, 1)), boxes, dim=2)
+                # out = nd.contrib.box_nms(result, topk=400)
+                out = mx.nd.contrib.MultiBoxDetection(nd.softmax(cls_preds).transpose((0, 2, 1)), box_preds.reshape((0, -1)), anchors, nms_topk=400)
+                valid_metric.update([y], [out])
+        return valid_metric.get()
+
+
+    # training process
+    def train(net, train_data, val_data, epochs, ctx=mx.cpu()):
+        """Run ``epochs`` epochs of SGD, saving parameters and validating after each one."""
+        ctx = ctx_as_list(ctx)
+        target_generator = SSDTargetGenerator()
+        box_weight = None
+        net.initialize(mx.init.Uniform(), ctx=ctx)
+        net.collect_params().reset_ctx(ctx)
+        net.hybridize()
+        trainer = gluon.Trainer(net.collect_params(), 'sgd',
+            {'learning_rate': lr, 'wd': wd, 'momentum':momentum})
+        # cls_loss = FocalLoss(num_class=(num_class+1), weight=1.0)
+        cls_loss = SoftmaxCrossEntropyLoss(size_average=False)
+        # box_loss = gluon.loss.L1Loss()
+        box_loss = SmoothL1Loss(weight=box_weight, size_average=False)
+        cls_metric = Accuracy(axis=-1, ignore_label=-1)
+        box_metric = SmoothL1()
+        cls_metric1 = LossRecorder('CrossEntropy')
+        box_metric1 = LossRecorder('SmoothL1Loss')
+        # debug_metric = MultiBoxMetric()
+
+        for epoch in range(epochs):
+            if epoch in lr_steps:
+                new_lr = trainer.learning_rate * lr_factor[lr_steps.index(epoch)]
+                trainer.set_learning_rate(new_lr)
+                logging.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
+            tic = time.time()
+            btic = time.time()
+            cls_metric.reset()
+            cls_metric1.reset()
+            box_metric.reset()
+            box_metric1.reset()
+            for i, batch in enumerate(train_data):
+                data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
+                label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
+                outputs = []
+                labels = []
+                box_preds = []
+                box_labels = []
+                losses1 = []
+                losses2 = []
+                Ls = []
+                with ag.record():
+                    for x, y in zip(data, label):
+                        # x = nd.cast(x, dtype)
+                        # y = nd.cast(y, dtype)
+                        z = net(x)
+                        with ag.pause():
+                            cls_targets, box_targets, box_masks = target_generator(z, y)
+                            valid_cls = nd.sum(cls_targets >= 0, axis=0, exclude=True)
+                            valid_cls = nd.maximum(valid_cls, nd.ones_like(valid_cls))
+                            valid_box = nd.sum(box_masks > 0, axis=0, exclude=True)
+                            # Clamp like valid_cls so a sample without positive boxes is not divided by 0.
+                            valid_box = nd.maximum(valid_box, nd.ones_like(valid_box))
+                        loss1 = cls_loss(z[0], cls_targets)
+                        # Normalize each sample's loss by its own count of valid targets.
+                        # loss1 = loss1 * cls_targets.shape[1] / valid_cls
+                        loss1 = loss1 / valid_cls
+                        loss2 = box_loss(z[1] * box_masks, box_targets)
+                        # loss2 = loss2 * box_masks.shape[1] / valid_cls
+                        loss2 = loss2 / valid_box
+                        L = loss1 + loss2
+                        # L = loss1
+                        Ls.append(L)
+                        outputs.append(z[0])
+                        labels.append(cls_targets)
+                        box_preds.append(z[1] * box_masks)
+                        box_labels.append(box_targets)
+                        losses1.append(loss1)
+                        losses2.append(loss2)
+                    ag.backward(Ls)
+                batch_size = batch[0].shape[0]
+                trainer.step(batch_size, ignore_stale_grad=True)
+                cls_metric.update(labels, outputs)
+                # box_metric.update(box_labels, box_preds)
+                cls_metric1.update(losses1)
+                box_metric1.update(losses2)
+                # debug_metric.update(cls_targets, [nd.softmax(z[0]).transpose((0, 2, 1)), nd.smooth_l1((z[1] - box_targets) * box_masks, scalar=1.0), cls_targets])
+                if log_interval and not (i + 1) % log_interval:
+                    name, acc = cls_metric.get()
+                    name1, mae = box_metric.get()
+                    name2, focalloss = cls_metric1.get()
+                    name3, smoothl1loss = box_metric1.get()
+                    logging.info("Epoch [%d] Batch [%d], Speed: %f samples/sec, %s=%f, %s=%f, %s=%f"%(epoch, i, batch_size/(time.time()-btic), name, acc, name2, focalloss, name3, smoothl1loss))
+                    # names, values = debug_metric.get()
+                    # logging.info("%s=%f, %s=%f"%(names[0], values[0], names[1], values[1]))
+                btic = time.time()
+
+            name, acc = cls_metric.get()
+            name1, mae = box_metric1.get()
+            name2, focalloss = cls_metric1.get()
+            net.collect_params().save(os.path.join(os.path.dirname(__file__), '..', 'model', 'ssd.params'))
+            logging.info('[Epoch %d] training: %s=%f, %s=%f, %s=%f'%(epoch, name, acc, name1, mae, name2, focalloss))
+            logging.info('[Epoch %d] time cost: %f'%(epoch, time.time()-tic))
+            map_name, mean_ap = evaluate_voc(net, val_data, ctx)
+            # name, val_acc = test(ctx)
+            for name, ap in zip(map_name, mean_ap):
+                logging.info('[Epoch %d] validation: %s=%f'%(epoch, name, ap))
+
+    train(net, train_data, val_data, end_epoch, ctx=ctx)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services