# xref: /OK3568_Linux_fs/external/rknn-toolkit2/examples/caffe/vgg-ssd/test.py (revision 4882a59341e53eb6f0b4789bf948001014eff981)
1import os
2import math
3import numpy as np
4import cv2
5from rknn.api import RKNN
6
7import PIL.Image as Image
8import PIL.ImageDraw as ImageDraw
9import PIL.ImageFont as ImageFont
10
# Print NumPy arrays in full instead of summarising long ones with "...".
np.set_printoptions(threshold=np.inf)

# Class labels indexed by the model's class id. Index 0 is the background
# class, which the post-processing never reports as a detection.
# NOTE(review): presumably the PASCAL VOC label set, judging by the
# VOC0712 model name used below - confirm against the training config.
CLASSES = ('__background__',
           'aeroplane', 'bicycle', 'bird', 'boat',
           'bottle', 'bus', 'car', 'cat', 'chair',
           'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant',
           'sheep', 'sofa', 'train', 'tvmonitor')

# Number of scores per prior box; derived from CLASSES so the two cannot
# drift apart (evaluates to 21).
NUM_CLS = len(CLASSES)

# Minimum class confidence for a prior box to become a detection candidate.
CONF_THRESH = 0.5
# Overlap (IoU) above which two same-class boxes are treated as duplicates.
NMS_THRESH = 0.45
24
25
def IntersectBBox(box1, box2):
    """Return the intersection-over-union (IoU) of two axis-aligned boxes.

    Args:
        box1: Indexable ``[x1, y1, x2, y2]`` with ``x1 <= x2`` and ``y1 <= y2``.
        box2: Second box in the same layout.

    Returns:
        The overlap ratio ``intersection / union``. ``0`` is returned when
        the boxes are disjoint and also when the union has no area (two
        coincident zero-area boxes), which would otherwise divide by zero.
    """
    # Fast rejection: one box lies entirely to one side of the other.
    if box1[0] > box2[2] or box1[2] < box2[0] or box1[1] > box2[3] or box1[3] < box2[1]:
        return 0

    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

    # Extent of the overlapping rectangle, clamped so that boxes that only
    # touch along an edge contribute an empty intersection.
    w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))

    inter = w * h
    union = area1 + area2 - inter
    if union <= 0:
        # Degenerate boxes: there is no area to overlap.
        return 0
    return inter / union
43
44
def _decode_bboxes(loc_data, prior_bboxes, prior_variances):
    """Decode location regressions into corner boxes clipped to [0, 1].

    All three arguments are flat arrays of equal length holding four values
    per prior box. Returns an ``(N, 4)`` array of ``[x1, y1, x2, y2]`` rows.
    """
    loc = loc_data.reshape(-1, 4)
    priors = prior_bboxes.reshape(-1, 4)
    variances = prior_variances.reshape(-1, 4)

    prior_size = priors[:, 2:] - priors[:, :2]
    prior_center = (priors[:, 2:] + priors[:, :2]) / 2

    # Centre offsets are scaled by the prior size; width/height are
    # predicted in log space, hence the exponential.
    center = variances[:, :2] * loc[:, :2] * prior_size + prior_center
    half_size = np.exp(variances[:, 2:] * loc[:, 2:]) * prior_size / 2

    return np.clip(np.hstack((center - half_size, center + half_size)), 0, 1)


def _select_candidates(conf_data, prior_num):
    """Return ``[prior_idx, class_idx, score]`` for confident non-background priors."""
    scores = conf_data.reshape(-1, NUM_CLS)[:prior_num]
    best_cls = np.argmax(scores, axis=1)
    best_conf = np.max(scores, axis=1)
    # Class 0 is the background entry of CLASSES and is never a detection.
    selected = np.flatnonzero((best_conf > CONF_THRESH) & (best_cls != 0))
    return [[int(i), best_cls[i], best_conf[i]] for i in selected]


def _nms(candidates, bboxes):
    """Greedy per-class non-maximum suppression, strongest candidates first."""
    kept = []
    # Visiting candidates by descending score means a box is only ever
    # compared against stronger boxes that were already accepted, so two
    # same-class boxes overlapping by more than NMS_THRESH are never both
    # kept (which an insertion-order scan cannot guarantee).
    for cand in sorted(candidates, key=lambda c: c[2], reverse=True):
        box = bboxes[cand[0]]
        suppressed = any(
            other[1] == cand[1]
            and IntersectBBox(box, bboxes[other[0]]) > NMS_THRESH
            for other in kept)
        if not suppressed:
            kept.append(cand)
    return kept


def _text_size(font, text):
    """Return ``(width, height)`` of *text*, across Pillow versions."""
    if hasattr(font, 'getbbox'):
        # Newer Pillow releases dropped getsize(); use the right/bottom
        # edges of the bounding box as the text extent instead.
        bbox = font.getbbox(text)
        return bbox[2], bbox[3]
    return font.getsize(text)


def ssd_post_process(conf_data, loc_data):
    """Decode raw SSD outputs, suppress duplicates and save an annotated image.

    Args:
        conf_data: Class scores; interpreted as ``NUM_CLS`` values per prior.
        loc_data: Box regressions; interpreted as four values per prior.

    Side effects:
        Reads ``mbox_priorbox_97.txt`` and ``./road_300x300.jpg`` from the
        working directory, prints the number of candidates and the kept
        detections, and writes the annotated picture to ``result.jpg``.

    Raises:
        ValueError: If the prior-box file holds fewer values than needed.
        FileNotFoundError: If the input image cannot be read.
    """
    # Flatten so element counts (not the shape of the leading axis) drive
    # the slicing below, and so scalar maths never sees 1-element arrays.
    loc_data = np.asarray(loc_data).reshape(-1)
    conf_data = np.asarray(conf_data)

    # NOTE(review): assumes the file lists all prior box coordinates first,
    # followed by the same number of variances - confirm against the file.
    prior_data = np.loadtxt('mbox_priorbox_97.txt', dtype=np.float32).reshape(-1)

    loc_len = loc_data.size
    prior_num = loc_len // 4  # 8732
    used = 4 * prior_num
    if prior_data.size < loc_len + used:
        raise ValueError('mbox_priorbox_97.txt holds %d values, expected at least %d'
                         % (prior_data.size, loc_len + used))

    prior_bboxes = prior_data[:loc_len]
    prior_variances = prior_data[loc_len:]

    # conf
    idx_class_conf = _select_candidates(conf_data, prior_num)

    # boxes
    bboxes = _decode_bboxes(loc_data[:used], prior_bboxes[:used], prior_variances[:used])

    print(len(idx_class_conf))

    # nms
    idx_class_conf_ = _nms(idx_class_conf, bboxes)

    print(idx_class_conf_)

    # One record per detection: x1, y1, x2, y2 (normalised), class, score.
    box_class_score = [(*bboxes[idx], cls, score) for idx, cls, score in idx_class_conf_]

    img = cv2.imread('./road_300x300.jpg')
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('./road_300x300.jpg')
    img_height, img_width = img.shape[:2]

    img_pil = Image.fromarray(img)
    draw = ImageDraw.Draw(img_pil)

    font = ImageFont.load_default()

    for nx1, ny1, nx2, ny2, cls, score in box_class_score:
        # Scale normalised coordinates to pixels.
        x1 = int(nx1 * img_width)
        y1 = int(ny1 * img_height)
        x2 = int(nx2 * img_width)
        y2 = int(ny2 * img_height)

        # Spread the foreground class ids over the middle colour channel.
        color = (0, int(cls / (NUM_CLS - 1) * 255), 255)
        draw.line([(x1, y1), (x1, y2), (x2, y2),
                   (x2, y1), (x1, y1)], width=2, fill=color)

        display_str = CLASSES[cls] + ":" + str(score)
        text_width, text_height = _text_size(font, display_str)
        display_str_height = np.ceil((1 + 2 * 0.05) * text_height) + 1

        # Put the label above the box when it fits, otherwise inside it.
        if y1 > display_str_height:
            text_bottom = y1
        else:
            text_bottom = y1 + display_str_height

        margin = np.ceil(0.05 * text_height)
        draw.rectangle([(x1, text_bottom - text_height - 2 * margin),
                        (x1 + text_width, text_bottom)], fill=color)
        draw.text((x1 + margin, text_bottom - text_height - margin),
                  display_str, fill='black', font=font)

    cv2.imwrite("result.jpg", np.array(img_pil))
148
149
if __name__ == '__main__':
    # Local import: sys.exit() replaces the interactive-only exit() helper.
    import sys

    if not os.path.exists('./VGG_VOC0712_SSD_300x300_iter_120000.caffemodel'):
        print('!!! Missing VGG_VOC0712_SSD_300x300_iter_120000.caffemodel !!!\n'
              '1. Download models_VGGNet_VOC0712_SSD_300x300.tar.gz from https://drive.google.com/file/d/0BzKzrI_SkD1_WVVTSmQxU0dVRzA/view\n'
              '2. Extract the VGG_VOC0712_SSD_300x300_iter_120000.caffemodel from models_VGGNet_VOC0712_SSD_300x300.tar.gz\n'
              '3. Or you can also download caffemodel from https://eyun.baidu.com/s/3jJhPRzo , password is rknn\n')
        sys.exit(-1)

    # Create RKNN object
    rknn = RKNN(verbose=True)

    # Everything after construction runs under try/finally so the RKNN
    # object is released on the failure exits as well as on success.
    try:
        # Pre-process config
        print('--> Config model')
        rknn.config(mean_values=[103.94, 116.78, 123.68], std_values=[1, 1, 1], quant_img_RGB2BGR=True)
        print('done')

        # Load model
        print('--> Loading model')
        ret = rknn.load_caffe(model='./deploy_rm_detection_output.prototxt',
                              blobs='./VGG_VOC0712_SSD_300x300_iter_120000.caffemodel')
        if ret != 0:
            print('Load model failed!')
            sys.exit(ret)
        print('done')

        # Build model
        print('--> Building model')
        ret = rknn.build(do_quantization=True, dataset='./dataset.txt')
        if ret != 0:
            print('Build model failed!')
            sys.exit(ret)
        print('done')

        # Export rknn model
        print('--> Export rknn model')
        ret = rknn.export_rknn('./deploy_rm_detection_output.rknn')
        if ret != 0:
            print('Export rknn model failed!')
            sys.exit(ret)
        print('done')

        # Set inputs
        img = cv2.imread('./road_300x300.jpg')
        if img is None:
            # cv2.imread returns None rather than raising when it cannot
            # read the file; stop before handing None to the runtime.
            print('Load image failed!')
            sys.exit(-1)

        # Init runtime environment
        print('--> Init runtime environment')
        ret = rknn.init_runtime()
        if ret != 0:
            print('Init runtime environment failed!')
            sys.exit(ret)
        print('done')

        # Inference
        print('--> Running model')
        outputs = rknn.inference(inputs=[img])
        print('done')

        # Flatten both outputs to column vectors and keep a copy on disk.
        outputs[0] = outputs[0].reshape((-1, 1))
        outputs[1] = outputs[1].reshape((-1, 1))
        np.save('./caffe_vgg-ssd_0.npy', outputs[0])
        np.save('./caffe_vgg-ssd_1.npy', outputs[1])
        # Output 1 is passed as the class scores, output 0 as the box
        # regressions.
        ssd_post_process(outputs[1], outputs[0])
    finally:
        rknn.release()
215