jwfromm commented on a change in pull request #7440:
URL: https://github.com/apache/tvm/pull/7440#discussion_r574901161



##########
File path: python/tvm/topi/testing/roi_align_python.py
##########
@@ -76,11 +78,20 @@ def _bilinear(n, c, y, x):
         for c in range(channel):
             for ph in range(pooled_size_h):
                 for pw in range(pooled_size_w):
-                    total = 0.0
-                    for iy in range(roi_bin_grid_h):
-                        for ix in range(roi_bin_grid_w):
-                            y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h 
/ roi_bin_grid_h
-                            x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w 
/ roi_bin_grid_w
-                            total += _bilinear(batch_index, c, y, x)
-                    b_np[i, c, ph, pw] = total / count
+                    if avg_mode:
+                        total = 0.0
+                        for iy in range(roi_bin_grid_h):
+                            for ix in range(roi_bin_grid_w):
+                                y = roi_start_h + ph * bin_h + (iy + 0.5) * 
bin_h / roi_bin_grid_h
+                                x = roi_start_w + pw * bin_w + (ix + 0.5) * 
bin_w / roi_bin_grid_w
+                                total += _bilinear(batch_index, c, y, x)
+                        b_np[i, c, ph, pw] = total / count
+                    elif max_mode:
+                        total = 0.0
+                        for iy in range(roi_bin_grid_h):
+                            for ix in range(roi_bin_grid_w):
+                                y = roi_start_h + ph * bin_h + (iy + 0.5) * 
bin_h / roi_bin_grid_h
+                                x = roi_start_w + pw * bin_w + (ix + 0.5) * 
bin_w / roi_bin_grid_w
+                                total = max(total, _bilinear(batch_index, c, 
y, x))

Review comment:
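       This section could have less code duplication if the mode check were moved
inside the sampling loop. Both branches iterate over the same (iy, ix) grid and
compute the same sample point (y, x), so only the accumulation step (running sum
vs. running max) needs to differ.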

##########
File path: python/tvm/topi/x86/roi_align.py
##########
@@ -161,47 +167,83 @@ def roi_align_nchw_ir(data, rois, num_rois, w_pc, pos_pc, pooled_size, spatial_s
             for ph in range(pooled_size_h):
                 for pw in range(pooled_size_w):
                     output_val = 0.0
-                    for iy in range(roi_bin_grid_h):
-                        for ix in range(roi_bin_grid_w):
-                            output_val += (
-                                w_pc[n, pre_calc_index, 0]
-                                * data[
-                                    roi_batch_index,
-                                    c,
-                                    pos_pc[n, pre_calc_index, 2],
-                                    pos_pc[n, pre_calc_index, 0],
-                                ]
-                                + w_pc[n, pre_calc_index, 1]
-                                * data[
-                                    roi_batch_index,
-                                    c,
-                                    pos_pc[n, pre_calc_index, 2],
-                                    pos_pc[n, pre_calc_index, 1],
-                                ]
-                                + w_pc[n, pre_calc_index, 2]
-                                * data[
-                                    roi_batch_index,
-                                    c,
-                                    pos_pc[n, pre_calc_index, 3],
-                                    pos_pc[n, pre_calc_index, 0],
-                                ]
-                                + w_pc[n, pre_calc_index, 3]
-                                * data[
-                                    roi_batch_index,
-                                    c,
-                                    pos_pc[n, pre_calc_index, 3],
-                                    pos_pc[n, pre_calc_index, 1],
-                                ]
-                            )
-                            pre_calc_index += 1
-
-                    output_val /= count
-                    output[n, c, ph, pw] = output_val
-
+                    if mode == 0:
+                        for iy in range(roi_bin_grid_h):
+                            for ix in range(roi_bin_grid_w):
+                                output_val += (
+                                    w_pc[n, pre_calc_index, 0]
+                                    * data[
+                                        roi_batch_index,
+                                        c,
+                                        pos_pc[n, pre_calc_index, 2],
+                                        pos_pc[n, pre_calc_index, 0],
+                                    ]
+                                    + w_pc[n, pre_calc_index, 1]
+                                    * data[
+                                        roi_batch_index,
+                                        c,
+                                        pos_pc[n, pre_calc_index, 2],
+                                        pos_pc[n, pre_calc_index, 1],
+                                    ]
+                                    + w_pc[n, pre_calc_index, 2]
+                                    * data[
+                                        roi_batch_index,
+                                        c,
+                                        pos_pc[n, pre_calc_index, 3],
+                                        pos_pc[n, pre_calc_index, 0],
+                                    ]
+                                    + w_pc[n, pre_calc_index, 3]
+                                    * data[
+                                        roi_batch_index,
+                                        c,
+                                        pos_pc[n, pre_calc_index, 3],
+                                        pos_pc[n, pre_calc_index, 1],
+                                    ]
+                                )
+                                pre_calc_index += 1
+
+                        output_val /= count
+                        output[n, c, ph, pw] = output_val
+                    elif mode == 1:
+                        for iy in range(roi_bin_grid_h):
+                            for ix in range(roi_bin_grid_w):
+                                bilinear_val = (
+                                    w_pc[n, pre_calc_index, 0]
+                                    * data[
+                                        roi_batch_index,
+                                        c,
+                                        pos_pc[n, pre_calc_index, 2],
+                                        pos_pc[n, pre_calc_index, 0],
+                                    ]
+                                    + w_pc[n, pre_calc_index, 1]
+                                    * data[
+                                        roi_batch_index,
+                                        c,
+                                        pos_pc[n, pre_calc_index, 2],
+                                        pos_pc[n, pre_calc_index, 1],
+                                    ]
+                                    + w_pc[n, pre_calc_index, 2]
+                                    * data[
+                                        roi_batch_index,
+                                        c,
+                                        pos_pc[n, pre_calc_index, 3],
+                                        pos_pc[n, pre_calc_index, 0],
+                                    ]
+                                    + w_pc[n, pre_calc_index, 3]
+                                    * data[
+                                        roi_batch_index,
+                                        c,
+                                        pos_pc[n, pre_calc_index, 3],
+                                        pos_pc[n, pre_calc_index, 1],
+                                    ]
+                                )
+                                pre_calc_index += 1
+                                output_val = max(output_val, bilinear_val)

Review comment:
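       I think a lot of the code here could also be consolidated. The bilinear-interpolation
expression is identical in both branches, so it can be computed once per sample. Only
the accumulation (summing and then dividing by count, vs. taking the running max)
needs to branch on `mode`.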




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to