huajsj commented on a change in pull request #6367:
URL: https://github.com/apache/incubator-tvm/pull/6367#discussion_r493024473



##########
File path: python/tvm/relay/testing/yolo_detection.py
##########
@@ -196,41 +196,91 @@ def do_nms_sort(dets, classes, thresh):
                     dets[j]["prob"][k] = 0
 
 
+def get_detections(im, det, thresh, names, classes):
+    "Draw the markings around the detected region"
+    labelstr = []
+    category = -1
+    detection = None
+    valid = False
+    for j in range(classes):
+        if det["prob"][j] > thresh:
+            if category == -1:
+                category = j
+            labelstr.append(names[j] + " " + str(round(det["prob"][j], 4)))
+
+    if category > -1:
+        valid = True
+        imc, imh, imw = im.shape
+        width = int(imh * 0.006)
+        offset = category * 123457 % classes
+        red = _get_color(2, offset, classes)
+        green = _get_color(1, offset, classes)
+        blue = _get_color(0, offset, classes)
+        rgb = [red, green, blue]
+        b = det["bbox"]
+        left = int((b.x - b.w / 2.0) * imw)
+        right = int((b.x + b.w / 2.0) * imw)
+        top = int((b.y - b.h / 2.0) * imh)
+        bot = int((b.y + b.h / 2.0) * imh)
+
+        if left < 0:
+            left = 0
+        if right > imw - 1:
+            right = imw - 1
+        if top < 0:
+            top = 0
+        if bot > imh - 1:
+            bot = imh - 1
+
+        detection = {
+            "category": category,
+            "labelstr": labelstr,
+            "left": left,
+            "top": top,
+            "right": right,
+            "bot": bot,
+            "width": width,
+            "rgb": rgb,
+        }
+
+    return valid, detection
+
+
 def draw_detections(font_path, im, dets, thresh, names, classes):
     "Draw the markings around the detected region"
     for det in dets:
-        labelstr = []
-        category = -1
-        for j in range(classes):
-            if det["prob"][j] > thresh:
-                if category == -1:
-                    category = j
-                labelstr.append(names[j] + " " + str(round(det["prob"][j], 4)))
-        if category > -1:
-            imc, imh, imw = im.shape
-            width = int(imh * 0.006)
-            offset = category * 123457 % classes
-            red = _get_color(2, offset, classes)
-            green = _get_color(1, offset, classes)
-            blue = _get_color(0, offset, classes)
-            rgb = [red, green, blue]
-            b = det["bbox"]
-            left = int((b.x - b.w / 2.0) * imw)
-            right = int((b.x + b.w / 2.0) * imw)
-            top = int((b.y - b.h / 2.0) * imh)
-            bot = int((b.y + b.h / 2.0) * imh)
-
-            if left < 0:
-                left = 0
-            if right > imw - 1:
-                right = imw - 1
-            if top < 0:
-                top = 0
-            if bot > imh - 1:
-                bot = imh - 1
-            _draw_box_width(im, left, top, right, bot, width, red, green, blue)
-            label = _get_label(font_path, "".join(labelstr), rgb)
-            _draw_label(im, top + width, left, label, rgb)
+        valid, detection = get_detections(im, det, thresh, names, classes)
+        if valid:
+            rgb = detection["rgb"]
+            label = _get_label(font_path, "".join(detection["labelstr"]), rgb)
+            _draw_box_width(
+                im,
+                detection["left"],
+                detection["top"],
+                detection["right"],
+                detection["bot"],
+                detection["width"],
+                rgb[0],
+                rgb[1],
+                rgb[2],
+            )
+            _draw_label(im, detection["top"] + detection["width"], detection["left"], label, rgb)
+
+
+def show_detections(im, dets, thresh, names, classes):

Review comment:
       fixed.

##########
File path: python/tvm/relay/testing/yolo_detection.py
##########
@@ -196,41 +196,91 @@ def do_nms_sort(dets, classes, thresh):
                     dets[j]["prob"][k] = 0
 
 
+def get_detections(im, det, thresh, names, classes):
+    "Draw the markings around the detected region"
+    labelstr = []
+    category = -1
+    detection = None
+    valid = False
+    for j in range(classes):
+        if det["prob"][j] > thresh:
+            if category == -1:
+                category = j
+            labelstr.append(names[j] + " " + str(round(det["prob"][j], 4)))
+
+    if category > -1:
+        valid = True
+        imc, imh, imw = im.shape
+        width = int(imh * 0.006)
+        offset = category * 123457 % classes
+        red = _get_color(2, offset, classes)
+        green = _get_color(1, offset, classes)
+        blue = _get_color(0, offset, classes)
+        rgb = [red, green, blue]
+        b = det["bbox"]
+        left = int((b.x - b.w / 2.0) * imw)
+        right = int((b.x + b.w / 2.0) * imw)
+        top = int((b.y - b.h / 2.0) * imh)
+        bot = int((b.y + b.h / 2.0) * imh)
+
+        if left < 0:
+            left = 0
+        if right > imw - 1:
+            right = imw - 1
+        if top < 0:
+            top = 0
+        if bot > imh - 1:
+            bot = imh - 1
+
+        detection = {
+            "category": category,
+            "labelstr": labelstr,
+            "left": left,
+            "top": top,
+            "right": right,
+            "bot": bot,
+            "width": width,
+            "rgb": rgb,
+        }
+
+    return valid, detection
+
+
 def draw_detections(font_path, im, dets, thresh, names, classes):
     "Draw the markings around the detected region"
     for det in dets:
-        labelstr = []
-        category = -1
-        for j in range(classes):
-            if det["prob"][j] > thresh:
-                if category == -1:
-                    category = j
-                labelstr.append(names[j] + " " + str(round(det["prob"][j], 4)))
-        if category > -1:
-            imc, imh, imw = im.shape
-            width = int(imh * 0.006)
-            offset = category * 123457 % classes
-            red = _get_color(2, offset, classes)
-            green = _get_color(1, offset, classes)
-            blue = _get_color(0, offset, classes)
-            rgb = [red, green, blue]
-            b = det["bbox"]
-            left = int((b.x - b.w / 2.0) * imw)
-            right = int((b.x + b.w / 2.0) * imw)
-            top = int((b.y - b.h / 2.0) * imh)
-            bot = int((b.y + b.h / 2.0) * imh)
-
-            if left < 0:
-                left = 0
-            if right > imw - 1:
-                right = imw - 1
-            if top < 0:
-                top = 0
-            if bot > imh - 1:
-                bot = imh - 1
-            _draw_box_width(im, left, top, right, bot, width, red, green, blue)
-            label = _get_label(font_path, "".join(labelstr), rgb)
-            _draw_label(im, top + width, left, label, rgb)
+        valid, detection = get_detections(im, det, thresh, names, classes)
+        if valid:
+            rgb = detection["rgb"]
+            label = _get_label(font_path, "".join(detection["labelstr"]), rgb)
+            _draw_box_width(
+                im,
+                detection["left"],
+                detection["top"],
+                detection["right"],
+                detection["bot"],
+                detection["width"],
+                rgb[0],
+                rgb[1],
+                rgb[2],
+            )
+            _draw_label(im, detection["top"] + detection["width"], detection["left"], label, rgb)
+
+
+def show_detections(im, dets, thresh, names, classes):
+    "Draw the markings around the detected region"

Review comment:
       fixed.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to