xref: /OK3568_Linux_fs/external/rknn-toolkit2/examples/functions/hybrid_quant/ssd_post_process.py (revision 4882a59341e53eb6f0b4789bf948001014eff981)
1import numpy as np
2import cv2
3from rknn.api import RKNN
4import math
5import PIL.Image as Image
6import PIL.ImageDraw as ImageDraw
7import PIL.ImageFont as ImageFont
8import re
9
10np.set_printoptions(threshold=np.inf)
11
12CLASSES = ('__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
13           'traffic light', 'fire hydrant', '???', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
14           'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', '???', 'backpack', 'umbrella', '???', '???',
15           'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
16           'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', '???', 'wine glass', 'cup', 'fork',
17           'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
18           'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', '???', 'dining table', '???', '???', 'toilet',
19           '???', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
20           'refrigerator', '???', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
21
22NUM_CLS = 91
23
24CONF_THRESH = 0.5
25NMS_THRESH = 0.45
26TOP_BOXES = 100
27max_boxes_to_draw = 100
28
29Y_SCALE = 10.0
30X_SCALE = 10.0
31H_SCALE = 5.0
32W_SCALE = 5.0
33
# Text file containing the prior-box values as plain decimal numbers.
prior_file = './box_priors.txt'

# Matches one decimal number: group 0 is the mantissa (optionally signed
# integer with optional fraction, or a bare ".digits"), group 1 the fraction
# part (unused), group 2 an optional exponent.  Written as a raw string so the
# backslash escapes reach the regex engine without invalid-escape warnings.
_FLOAT_PATTERN = re.compile(r'([-+]?\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?')

# Flat list of every number found in the prior file, in file order.
box_priors_ = []
# The context manager closes the file even if a value fails to parse.
with open(prior_file, 'r') as fp:
    for line in fp:
        box_priors_.extend(
            float(mantissa + exponent)
            for mantissa, _fraction, exponent in _FLOAT_PATTERN.findall(line)
        )
45
46
def softmax(x):
    """Return the softmax of ``x`` computed along axis 0.

    The maximum along axis 0 is subtracted before exponentiating.  This does
    not change the mathematical result but keeps ``np.exp`` from overflowing
    to ``inf`` (and the quotient from becoming NaN) for large inputs.

    Args:
        x: Array-like of scores.

    Returns:
        A numpy array with the same shape as ``x`` whose entries along
        axis 0 sum to 1.  An empty input yields an empty array.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; np.max would raise on an empty reduction.
        return np.exp(x)
    exps = np.exp(x - np.max(x, axis=0))
    return exps / np.sum(exps, axis=0)
49
50
def IntersectBBox(box1, box2):
    """Return the intersection-over-union of two boxes.

    Both boxes are indexable as ``[x1, y1, x2, y2]`` (only the first four
    entries are read).  Boxes that do not touch at all yield the integer 0;
    otherwise the overlap ratio is returned, with a small epsilon in the
    denominator to avoid dividing by zero.
    """
    a_x1, a_y1, a_x2, a_y2 = box1[0], box1[1], box1[2], box1[3]
    b_x1, b_y1, b_x2, b_y2 = box2[0], box2[1], box2[2], box2[3]

    # Guard clause: completely separated boxes have no overlap.
    if a_x1 > b_x2 or a_x2 < b_x1 or a_y1 > b_y2 or a_y2 < b_y1:
        return 0

    # Extent of the overlapping rectangle (zero when the boxes only touch).
    overlap_w = max(0, min(a_x2, b_x2) - max(a_x1, b_x1))
    overlap_h = max(0, min(a_y2, b_y2) - max(a_y1, b_y1))

    area_a = (a_x2 - a_x1) * (a_y2 - a_y1)
    area_b = (b_x2 - b_x1) * (b_y2 - b_y1)

    return overlap_w * overlap_h / (area_a + area_b - overlap_w * overlap_h + 0.000001)
68
69
def _clamp01(value):
    """Clamp ``value`` to the closed interval [0, 1]."""
    return max(min(value, 1), 0)


def _decode_box(loc_data, prior_bboxes, i):
    """Decode the regression output of prior ``i`` into ``[x1, y1, x2, y2]``.

    ``loc_data`` holds four consecutive values per prior (y offset, x offset,
    log-height, log-width); ``prior_bboxes`` has four rows (centre y,
    centre x, height, width).  The returned corner coordinates are clamped
    to [0, 1].
    """
    center_x = loc_data[4 * i + 1] / X_SCALE * prior_bboxes[3][i] + prior_bboxes[1][i]
    center_y = loc_data[4 * i + 0] / Y_SCALE * prior_bboxes[2][i] + prior_bboxes[0][i]
    width = math.exp(loc_data[4 * i + 3] / W_SCALE) * prior_bboxes[3][i]
    height = math.exp(loc_data[4 * i + 2] / H_SCALE) * prior_bboxes[2][i]
    return [
        _clamp01(center_x - width / 2.),
        _clamp01(center_y - height / 2.),
        _clamp01(center_x + width / 2.),
        _clamp01(center_y + height / 2.),
    ]


def _suppress_overlaps(candidates, boxes):
    """Greedy per-class non-maximum suppression.

    ``candidates`` are ``(prior_idx, class_idx, score)`` tuples sorted by
    descending score and ``boxes`` maps ``prior_idx`` to its decoded box.
    A candidate is dropped when an already kept candidate of the same class
    overlaps it by more than ``NMS_THRESH``.
    """
    kept = []
    for candidate in candidates:
        prior_idx, class_idx, _score = candidate
        # The class comparison comes first so the overlap is only computed
        # for boxes that could actually suppress each other.
        suppressed = any(
            other[1] == class_idx
            and IntersectBBox(boxes[prior_idx], boxes[other[0]]) > NMS_THRESH
            for other in kept
        )
        if not suppressed:
            kept.append(candidate)
    return kept


def _text_size(font, text):
    """Return ``(width, height)`` of ``text`` rendered with ``font``.

    ``ImageFont.getsize`` was removed in Pillow 10, so ``getbbox`` is used
    when the font provides it; its right/bottom values give the extent
    measured from the drawing origin.  Older Pillow falls back to ``getsize``.
    """
    if hasattr(font, 'getbbox'):
        bbox = font.getbbox(text)
        return bbox[2], bbox[3]
    return font.getsize(text)


def ssd_post_process(conf_data, loc_data, imgpath, output_dir='.'):
    """Turn raw SSD outputs into labelled boxes drawn onto the input image.

    Steps: softmax the class scores of every prior, keep priors whose best
    non-background score exceeds ``CONF_THRESH`` (at most ``TOP_BOXES`` of
    them, best first), decode their boxes against the priors in
    ``box_priors_``, apply per-class non-maximum suppression, draw up to
    ``max_boxes_to_draw`` results and write ``<image name>_quant.jpg`` into
    ``output_dir``.

    Args:
        conf_data: Numpy array of class scores, ``NUM_CLS`` values per prior.
            It is copied, so the caller's array is left untouched.
        loc_data: Flat sequence of box regression values, four per prior.
        imgpath: Path of the image the detections are drawn on.
        output_dir: Directory that receives the annotated image.

    Raises:
        ValueError: if ``conf_data`` has fewer rows than there are priors, or
            the number of loaded prior values does not match ``loc_data``.
        IOError: if the image at ``imgpath`` cannot be read.
    """
    import os  # local import: only needed for the path handling below

    prior_num = len(loc_data) // 4  # num prior boxes
    prior_bboxes = np.array(box_priors_).reshape(4, prior_num)

    # Work on a float copy: reshape() alone returns a view, so normalising in
    # place would overwrite the caller's array and break a second call.
    scores = np.array(conf_data, dtype=np.float64).reshape(-1, NUM_CLS)
    if scores.shape[0] < prior_num:
        raise ValueError('conf_data holds {} priors but loc_data describes {}'.format(
            scores.shape[0], prior_num))

    # conf: best non-background class per prior, filtered by CONF_THRESH
    candidates = []
    for prior_idx in range(prior_num):
        probs = softmax(scores[prior_idx])
        probs[0] = 0  # the background class never counts as a detection
        class_idx = int(np.argmax(probs))
        score = float(probs[class_idx])
        if score > CONF_THRESH:
            candidates.append((prior_idx, class_idx, score))

    # sorted() is stable, so equal scores keep their prior order.
    candidates.sort(key=lambda item: item[2], reverse=True)
    candidates = candidates[:TOP_BOXES]

    # boxes: only the surviving candidates need decoding
    boxes = {idx: _decode_box(loc_data, prior_bboxes, idx) for idx, _, _ in candidates}

    # nms
    kept = _suppress_overlaps(candidates, boxes)[:max_boxes_to_draw]
    detections = [boxes[idx] + [class_idx, score] for idx, class_idx, score in kept]

    img = cv2.imread(imgpath)
    if img is None:
        raise IOError('cannot read image: {}'.format(imgpath))
    img_h, img_w = img.shape[0], img.shape[1]

    # The array is only wrapped for drawing and copied back afterwards, so
    # the channel order of the loaded image is left as it is.
    img_pil = Image.fromarray(img)
    draw = ImageDraw.Draw(img_pil)
    font = ImageFont.load_default()

    for x_min, y_min, x_max, y_max, class_idx, score in detections:
        # Box corners are normalised; scale them to pixel coordinates.
        x1 = x_min * img_w
        y1 = y_min * img_h
        x2 = x_max * img_w
        y2 = y_max * img_h

        # draw rect; class ids above 20 would push the middle channel past
        # the 8-bit range, so it is capped at 255.
        color = (0, min(255, int(class_idx / 20.0 * 255)), 255)
        draw.line([(x1, y1), (x1, y2), (x2, y2),
                   (x2, y1), (x1, y1)], width=2, fill=color)

        display_str = CLASSES[class_idx] + ":" + str('%.2f' % score)
        text_width, text_height = _text_size(font, display_str)
        display_str_height = np.ceil((1 + 2 * 0.05) * text_height) + 1

        # Put the label above the box when there is room, otherwise inside it.
        if y1 > display_str_height:
            text_bottom = y1
        else:
            text_bottom = y1 + display_str_height

        margin = np.ceil(0.05 * text_height)
        draw.rectangle([(x1, text_bottom - text_height - 2 * margin), (x1 + text_width, text_bottom)], fill=color)
        draw.text((x1 + margin, text_bottom - text_height - margin), display_str, fill='black', font=font)

    # Strip directory and extension (whatever its length) from the input
    # name and join with the output directory, so both '.' and './' work.
    name = os.path.splitext(os.path.basename(imgpath))[0]
    out_path = os.path.join(output_dir, '{}_quant.jpg'.format(name))

    print('write output image: {}'.format(out_path))
    np.copyto(img, np.array(img_pil))
    cv2.imwrite(out_path, img)
    print('write output image finished.')
173