You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2019/11/25 18:24:03 UTC

[GitHub] [incubator-tvm] jwfromm commented on a change in pull request #4417: [TOPI][RELAY][OP] add op crop_and_resize

jwfromm commented on a change in pull request #4417: [TOPI][RELAY][OP] add op crop_and_resize
URL: https://github.com/apache/incubator-tvm/pull/4417#discussion_r350346119
 
 

 ##########
 File path: topi/python/topi/image/resize.py
 ##########
 @@ -210,3 +210,185 @@ def _bicubic(*indices):
         raise ValueError('%s method is not supported.' % method)
 
     return tvm.compute(output_shape, compute_func, name='resize', tag=tag.INJECTIVE)
+
+
+def crop_and_resize(data, boxes, box_indices, crop_size, layout="NCHW",
+                    method="bilinear", extrapolation_value=0, out_dtype=None):
+    """Perform crop and resize operation on the data.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        inputs is a 4-D tensor with shape
+        [batch, channel, in_height, in_width]
+        or  [batch, in_height, in_width, channel]
+
+    boxes : tvm.Tensor
+        A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies
+        the coordinates of a box.
+
+    box_indices : tvm.Tensor
+        A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that
+        the i-th box refers to.
+
+    crop_size : Tuple
+        The target size of each box.
+
+    layout : string, optional
+        "NCHW", "NHWC"
+
+    method : {"bilinear", "nearest_neighbor"}
+        Method to be used for resizing.
+
+    extrapolation_value: float, optional
+        Value used for extrapolation, when applicable.
+
+    out_dtype : string, optional
+        Type to return. If left None will be same as input type.
+
+    Returns
+    -------
+    output : tvm.Tensor
+        4-D with shape [num_boxes, channel, crop_height, crop_width]
+        or [num_boxes, crop_height, crop_width, channel]
+    """
+    method = method.lower()
+    target_h = crop_size[0]
+    target_w = crop_size[1]
+
+    if layout == 'NHWC':
+        output_shape = [box_indices.shape[0], crop_size[0], crop_size[1], data.shape[3]]
+        image_height = data.shape[1]
+        image_width = data.shape[2]
+    elif layout == 'NCHW':
+        output_shape = [box_indices.shape[0], data.shape[1], crop_size[0], crop_size[1]]
+        image_height = data.shape[2]
+        image_width = data.shape[3]
+    # Otherwise layout must be NCHWxc
+    else:
+        output_shape = [box_indices.shape[0], data.shape[1],
+                        crop_size[0], crop_size[1], data.shape[4]]
+        image_height = data.shape[2]
+        image_width = data.shape[3]
+
+    def _get_pixel(n, c, y, x, cc):
+        if layout.lower() == 'nhwc':
+            return data(n, y.astype("int32"), x.astype("int32"), c).astype('float')
+        if layout.lower() == 'nchw':
+            return data(n, c, y.astype("int32"), x.astype("int32")).astype('float')
+        # else must be NCHWxc
+        return data(n, c, y.astype("int32"), x.astype("int32"), cc).astype('float')
+
+    def _get_indices(*indices):
+        if layout == 'NHWC':
+            n, y, x, c = indices
+            cc = None
+        elif layout == 'NCHW':
+            n, c, y, x = indices
+            cc = None
+        else:
+            n, c, y, x, cc = indices
+
+        return n, c, y, x, cc
+
+    def _cast_output(value):
+        if out_dtype:
+            dtype = out_dtype
+        else:
+            dtype = data.dtype
+        return value.astype(dtype)
+
+    # Nearest neighbor computation
+    def _nearest_neighbor(*indices):
+        n, c, y, x, cc = _get_indices(*indices)
+        box_idx = box_indices(n)
+
+        y1, x1 = boxes(n, 0), boxes(n, 1)
+        y2, x2 = boxes(n, 2), boxes(n, 3)
+
+        in_h = (image_height - 1) * (y2 - y1)
+        in_w = (image_width - 1) * (x2 - x1)
+        h_scale = tvm.div(in_h, target_h - 1)
+        w_scale = tvm.div(in_w, target_w - 1)
+
+        in_y = y1 * (image_height - 1) + h_scale * y
+        in_x = x1 * (image_width - 1) + w_scale * x
+        closest_x_index = tvm.round(in_x)
+        closest_y_index = tvm.round(in_y)
+
+        value = _get_pixel(box_idx, c, closest_y_index, closest_x_index, cc)
+        out_y = tvm.if_then_else(in_y < 0,
+                                 extrapolation_value,
+                                 tvm.if_then_else(in_y > image_height - 1,
+                                                  extrapolation_value,
+                                                  value))
+
+        # use extrapolation_value if in_x is out of boundary
+        out = tvm.if_then_else(in_x < 0,
+                               extrapolation_value,
+                               tvm.if_then_else(in_x > image_width - 1,
+                                                extrapolation_value,
+                                                out_y))
+        return _cast_output(out)
+
+
+    # Bilinear helper functions and computation.
+    def _lerp(A, B, t):
+        return A * (1.0 - t) + B * t
+
+    def _bilinear(*indices):
+        n, c, y, x, cc = _get_indices(*indices)
+        box_idx = box_indices(n)
+
+        y1, x1 = boxes(n, 0), boxes(n, 1)
+        y2, x2 = boxes(n, 2), boxes(n, 3)
+
+        in_h = (image_height - 1) * (y2 - y1)
+        in_w = (image_width - 1) * (x2 - x1)
+        h_scale = tvm.div(in_h, target_h - 1)
+        w_scale = tvm.div(in_w, target_w - 1)
+
+        in_y = y1 * (image_height - 1) + h_scale * y
+        in_x = x1 * (image_width - 1) + w_scale * x
+
+        top_y_index = tvm.floor(in_y).astype('int32')
+        bottom_y_index = tvm.ceil(in_y).astype('int32')
+        y_lerp = in_y - top_y_index
+
+        left_x_index = tvm.floor(in_x)
+        right_x_index = tvm.ceil(in_x)
+        x_lerp = in_x - left_x_index
+
+        top_left = _get_pixel(box_idx, c, top_y_index, left_x_index, cc)
+        top_right = _get_pixel(box_idx, c, top_y_index, right_x_index, cc)
+        bottom_left = _get_pixel(box_idx, c, bottom_y_index, left_x_index, cc)
+        bottom_right = _get_pixel(box_idx, c, bottom_y_index, right_x_index, cc)
+
+        top = _lerp(top_left, top_right, x_lerp)
+        bottom = _lerp(bottom_left, bottom_right, x_lerp)
+        value = _lerp(top, bottom, y_lerp)
+
+        # use extrapolation_value if in_y is out of boundary
+        out_y = tvm.if_then_else(in_y < 0,
+                                 extrapolation_value,
+                                 tvm.if_then_else(in_y > image_height - 1,
+                                                  extrapolation_value,
+                                                  value))
+
+        # use extrapolation_value if in_x is out of boundary
+        out = tvm.if_then_else(in_x < 0,
+                               extrapolation_value,
+                               tvm.if_then_else(in_x > image_width - 1,
+                                                extrapolation_value,
+                                                out_y))
+        return _cast_output(out)
+
+    # Determine which interpolation method to use then run it.
+    if method == "nearest_neighbor":
+        compute_func = _nearest_neighbor
+    elif method == "bilinear":
+        compute_func = _bilinear
+    else:
+        raise ValueError('%s method is not supported.' % method)
+
+    return tvm.compute(output_shape, compute_func, name='crop_and_resize', tag=tag.INJECTIVE)
 
 Review comment:
   A lot of this code is heavily duplicated from the regular resize function. It would make maintenance much easier if we refactored a little so that both resize and crop_and_resize called into the same nearest_neighbor, bilinear, and bicubic kernels. Maybe we can define more general helper functions that will work in both cases.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services