You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/02/12 00:21:19 UTC

[GitHub] [tvm] jwfromm commented on a change in pull request #7440: [Relay][Topi] Add max mode to ROI align

jwfromm commented on a change in pull request #7440:
URL: https://github.com/apache/tvm/pull/7440#discussion_r574901161



##########
File path: python/tvm/topi/testing/roi_align_python.py
##########
@@ -76,11 +78,20 @@ def _bilinear(n, c, y, x):
         for c in range(channel):
             for ph in range(pooled_size_h):
                 for pw in range(pooled_size_w):
-                    total = 0.0
-                    for iy in range(roi_bin_grid_h):
-                        for ix in range(roi_bin_grid_w):
-                            y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h
-                            x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w
-                            total += _bilinear(batch_index, c, y, x)
-                    b_np[i, c, ph, pw] = total / count
+                    if avg_mode:
+                        total = 0.0
+                        for iy in range(roi_bin_grid_h):
+                            for ix in range(roi_bin_grid_w):
+                                y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h
+                                x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w
+                                total += _bilinear(batch_index, c, y, x)
+                        b_np[i, c, ph, pw] = total / count
+                    elif max_mode:
+                        total = 0.0
+                        for iy in range(roi_bin_grid_h):
+                            for ix in range(roi_bin_grid_w):
+                                y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h
+                                x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w
+                                total = max(total, _bilinear(batch_index, c, y, x))

Review comment:
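       This section could have less code duplication if the mode check were moved inside the loops. The two branches share the same sampling loops and differ only in how each bilinear sample is accumulated (sum vs. max).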

##########
File path: python/tvm/topi/x86/roi_align.py
##########
@@ -161,47 +167,83 @@ def roi_align_nchw_ir(data, rois, num_rois, w_pc, pos_pc, pooled_size, spatial_s
             for ph in range(pooled_size_h):
                 for pw in range(pooled_size_w):
                     output_val = 0.0
-                    for iy in range(roi_bin_grid_h):
-                        for ix in range(roi_bin_grid_w):
-                            output_val += (
-                                w_pc[n, pre_calc_index, 0]
-                                * data[
-                                    roi_batch_index,
-                                    c,
-                                    pos_pc[n, pre_calc_index, 2],
-                                    pos_pc[n, pre_calc_index, 0],
-                                ]
-                                + w_pc[n, pre_calc_index, 1]
-                                * data[
-                                    roi_batch_index,
-                                    c,
-                                    pos_pc[n, pre_calc_index, 2],
-                                    pos_pc[n, pre_calc_index, 1],
-                                ]
-                                + w_pc[n, pre_calc_index, 2]
-                                * data[
-                                    roi_batch_index,
-                                    c,
-                                    pos_pc[n, pre_calc_index, 3],
-                                    pos_pc[n, pre_calc_index, 0],
-                                ]
-                                + w_pc[n, pre_calc_index, 3]
-                                * data[
-                                    roi_batch_index,
-                                    c,
-                                    pos_pc[n, pre_calc_index, 3],
-                                    pos_pc[n, pre_calc_index, 1],
-                                ]
-                            )
-                            pre_calc_index += 1
-
-                    output_val /= count
-                    output[n, c, ph, pw] = output_val
-
+                    if mode == 0:
+                        for iy in range(roi_bin_grid_h):
+                            for ix in range(roi_bin_grid_w):
+                                output_val += (
+                                    w_pc[n, pre_calc_index, 0]
+                                    * data[
+                                        roi_batch_index,
+                                        c,
+                                        pos_pc[n, pre_calc_index, 2],
+                                        pos_pc[n, pre_calc_index, 0],
+                                    ]
+                                    + w_pc[n, pre_calc_index, 1]
+                                    * data[
+                                        roi_batch_index,
+                                        c,
+                                        pos_pc[n, pre_calc_index, 2],
+                                        pos_pc[n, pre_calc_index, 1],
+                                    ]
+                                    + w_pc[n, pre_calc_index, 2]
+                                    * data[
+                                        roi_batch_index,
+                                        c,
+                                        pos_pc[n, pre_calc_index, 3],
+                                        pos_pc[n, pre_calc_index, 0],
+                                    ]
+                                    + w_pc[n, pre_calc_index, 3]
+                                    * data[
+                                        roi_batch_index,
+                                        c,
+                                        pos_pc[n, pre_calc_index, 3],
+                                        pos_pc[n, pre_calc_index, 1],
+                                    ]
+                                )
+                                pre_calc_index += 1
+
+                        output_val /= count
+                        output[n, c, ph, pw] = output_val
+                    elif mode == 1:
+                        for iy in range(roi_bin_grid_h):
+                            for ix in range(roi_bin_grid_w):
+                                bilinear_val = (
+                                    w_pc[n, pre_calc_index, 0]
+                                    * data[
+                                        roi_batch_index,
+                                        c,
+                                        pos_pc[n, pre_calc_index, 2],
+                                        pos_pc[n, pre_calc_index, 0],
+                                    ]
+                                    + w_pc[n, pre_calc_index, 1]
+                                    * data[
+                                        roi_batch_index,
+                                        c,
+                                        pos_pc[n, pre_calc_index, 2],
+                                        pos_pc[n, pre_calc_index, 1],
+                                    ]
+                                    + w_pc[n, pre_calc_index, 2]
+                                    * data[
+                                        roi_batch_index,
+                                        c,
+                                        pos_pc[n, pre_calc_index, 3],
+                                        pos_pc[n, pre_calc_index, 0],
+                                    ]
+                                    + w_pc[n, pre_calc_index, 3]
+                                    * data[
+                                        roi_batch_index,
+                                        c,
+                                        pos_pc[n, pre_calc_index, 3],
+                                        pos_pc[n, pre_calc_index, 1],
+                                    ]
+                                )
+                                pre_calc_index += 1
+                                output_val = max(output_val, bilinear_val)

Review comment:
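       I think a lot of the code here could also be consolidated. The bilinear-interpolation expression is identical in both branches, so it could be computed once per sample, leaving only the accumulation step (sum vs. max) mode-dependent.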




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org