"""YOLOv3 post-processing helpers: box decoding, filtering, NMS and drawing."""

import os
import sys
import urllib.error
import urllib.request

import numpy as np

try:
    import cv2
except ImportError:
    # OpenCV is only needed by draw(); keep the NumPy-only post-processing
    # importable on machines without it.
    cv2 = None

NUM_CLS = 80
MAX_BOXES = 500
OBJ_THRESH = 0.5
NMS_THRESH = 0.6

# NOTE: several labels carry a trailing space in the original table; they are
# kept byte-for-byte because they are printed and drawn at runtime.
CLASSES = (
    "person", "bicycle", "car", "motorbike ", "aeroplane ", "bus ", "train",
    "truck ", "boat", "traffic light",
    "fire hydrant", "stop sign ", "parking meter", "bench", "bird", "cat",
    "dog ", "horse ", "sheep", "cow", "elephant",
    "bear", "zebra ", "giraffe", "backpack", "umbrella", "handbag", "tie",
    "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife ",
    "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli",
    "carrot", "hot dog", "pizza ", "donut", "cake", "chair", "sofa",
    "pottedplant", "bed", "diningtable", "toilet ", "tvmonitor", "laptop ",
    "mouse ", "remote ", "keyboard ", "cell phone", "microwave ",
    "oven ", "toaster", "sink", "refrigerator ", "book", "clock", "vase",
    "scissors ", "teddy bear ", "hair drier", "toothbrush ",
)

# Default yolov3 configuration (one mask per output feature map).
# For yolov3-tiny pass instead:
#   masks = [[3, 4, 5], [0, 1, 2]]
#   anchors = [[10, 14], [23, 27], [37, 58], [81, 82], [135, 169], [344, 319]]
_YOLOV3_MASKS = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
_YOLOV3_ANCHORS = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                   [59, 119], [116, 90], [156, 198], [373, 326]]


def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)) of x.

    Computed as exp(-log(1 + exp(-x))) via np.logaddexp so that large
    negative inputs underflow to 0 instead of overflowing inside np.exp.
    """
    return np.exp(-np.logaddexp(0, -x))


def process(input, mask, anchors, input_size=(416, 416)):
    """Decode one raw output feature map into boxes, confidences and class probs.

    # Arguments
        input: ndarray indexed as (grid_h, grid_w, anchor, channel); channels
            0:2 are xy, 2:4 are wh, 4 is objectness, 5: are class logits.
        mask: indices into `anchors` selecting the anchors of this map.
        anchors: sequence of (width, height) anchor sizes.
        input_size: (width, height) used to normalise box sizes.

    # Returns
        box: ndarray (..., 4) holding x, y (top-left corner), w, h after
            division by the grid size / input_size.
        box_confidence: ndarray (..., 1), sigmoid of the objectness channel.
        box_class_probs: ndarray, sigmoid of the class channels.
    """
    selected_anchors = np.asarray([anchors[i] for i in mask], dtype=float)
    grid_h, grid_w = map(int, input.shape[0:2])

    box_confidence = np.expand_dims(sigmoid(input[..., 4]), axis=-1)
    box_class_probs = sigmoid(input[..., 5:])

    box_xy = sigmoid(input[..., :2])
    box_wh = np.exp(input[..., 2:4]) * selected_anchors

    # Per-cell (col, row) offsets, shaped (grid_h, grid_w, 1, 2) so they
    # broadcast over any number of anchors and work for non-square grids.
    col, row = np.meshgrid(np.arange(grid_w), np.arange(grid_h))
    grid = np.stack((col, row), axis=-1)[:, :, np.newaxis, :]

    box_xy += grid
    box_xy /= (grid_w, grid_h)
    box_wh /= tuple(input_size)
    # Convert the box centre into its top-left corner.
    box_xy -= (box_wh / 2.)
    box = np.concatenate((box_xy, box_wh), axis=-1)

    return box, box_confidence, box_class_probs


def filter_boxes(boxes, box_confidences, box_class_probs, obj_thresh=None):
    """Filter boxes with object threshold.

    # Arguments
        boxes: ndarray, boxes of objects.
        box_confidences: ndarray, confidences of objects.
        box_class_probs: ndarray, class_probs of objects.
        obj_thresh: minimum score to keep a box; defaults to OBJ_THRESH.

    # Returns
        boxes: ndarray, filtered boxes.
        classes: ndarray, classes for boxes.
        scores: ndarray, scores for boxes.
    """
    if obj_thresh is None:
        obj_thresh = OBJ_THRESH

    box_scores = box_confidences * box_class_probs
    best_class = np.argmax(box_scores, axis=-1)
    best_score = np.max(box_scores, axis=-1)
    keep = np.where(best_score >= obj_thresh)

    return boxes[keep], best_class[keep], best_score[keep]


def nms_boxes(boxes, scores, nms_thresh=None):
    """Suppress non-maximal boxes.

    # Arguments
        boxes: ndarray (N, 4) of x, y, w, h.
        scores: ndarray (N,), scores of objects.
        nms_thresh: overlap above which a lower-scored box is dropped;
            defaults to NMS_THRESH.

    # Returns
        keep: integer ndarray, index of effective boxes (empty for N == 0).
    """
    if nms_thresh is None:
        nms_thresh = NMS_THRESH

    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    w = boxes[:, 2]
    h = boxes[:, 3]
    x2 = x1 + w
    y2 = y1 + h

    areas = w * h
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        best = order[0]
        rest = order[1:]
        keep.append(best)

        xx1 = np.maximum(x1[best], x1[rest])
        yy1 = np.maximum(y1[best], y1[rest])
        xx2 = np.minimum(x2[best], x2[rest])
        yy2 = np.minimum(y2[best], y2[rest])

        inter_w = np.maximum(0.0, xx2 - xx1 + 0.00001)
        inter_h = np.maximum(0.0, yy2 - yy1 + 0.00001)
        inter = inter_w * inter_h

        overlap = inter / (areas[best] + areas[rest] - inter)
        order = rest[np.where(overlap <= nms_thresh)[0]]

    # Explicit integer dtype: an empty list would otherwise become a float
    # array, which cannot be used as an index.
    return np.array(keep, dtype=np.intp)


def yolov3_post_process(input_data, masks=None, anchors=None,
                        input_size=(416, 416)):
    """Turn raw YOLO output feature maps into final detections.

    # Arguments
        input_data: sequence of feature maps, one per entry of `masks`,
            each in the layout expected by process().
        masks: anchor index groups, one per feature map; defaults to the
            yolov3 masks.
        anchors: (width, height) anchors; defaults to the yolov3 anchors.
        input_size: (width, height) used to normalise box sizes.

    # Returns
        boxes, classes, scores: ndarrays of the detections kept after
            thresholding and per-class NMS, or (None, None, None) when
            nothing is detected.
    """
    if masks is None:
        masks = _YOLOV3_MASKS
    if anchors is None:
        anchors = _YOLOV3_ANCHORS

    boxes, classes, scores = [], [], []
    for feature_map, mask in zip(input_data, masks):
        b, c, s = process(feature_map, mask, anchors, input_size)
        b, c, s = filter_boxes(b, c, s)
        boxes.append(b)
        classes.append(c)
        scores.append(s)

    # np.concatenate rejects an empty sequence.
    if not boxes:
        return None, None, None

    boxes = np.concatenate(boxes)
    classes = np.concatenate(classes)
    scores = np.concatenate(scores)

    nboxes, nclasses, nscores = [], [], []
    # NMS is run independently per class; np.unique gives a stable order.
    for cls in np.unique(classes):
        inds = np.where(classes == cls)
        cls_boxes = boxes[inds]
        cls_ids = classes[inds]
        cls_scores = scores[inds]

        keep = nms_boxes(cls_boxes, cls_scores)

        nboxes.append(cls_boxes[keep])
        nclasses.append(cls_ids[keep])
        nscores.append(cls_scores[keep])

    if not nclasses and not nscores:
        return None, None, None

    return (np.concatenate(nboxes),
            np.concatenate(nclasses),
            np.concatenate(nscores))


def draw(image, boxes, scores, classes):
    """Draw the boxes on the image (in place) and print each detection.

    # Arguments
        image: original image, indexed (height, width, ...).
        boxes: ndarray of x, y, w, h as fractions of the image size.
        scores: ndarray, scores of objects.
        classes: ndarray, indices into CLASSES.

    # Raises
        ImportError: if OpenCV (cv2) is not installed.
    """
    if cv2 is None:
        raise ImportError('OpenCV (cv2) is required by draw()')
    # yolov3_post_process returns None when nothing was detected.
    if boxes is None or scores is None or classes is None:
        return

    img_h, img_w = image.shape[0], image.shape[1]
    for box, score, cl in zip(boxes, scores, classes):
        x, y, w, h = box
        print('class: {}, score: {}'.format(CLASSES[cl], score))
        print('box coordinate left,top,right,down: [{}, {}, {}, {}]'.format(x, y, x+w, y+h))
        x = x * img_w
        y = y * img_h
        w = w * img_w
        h = h * img_h
        # Round to the nearest pixel, clamp to the image, and hand OpenCV
        # plain Python ints.
        left = max(0, int(np.floor(x + 0.5)))
        top = max(0, int(np.floor(y + 0.5)))
        right = min(img_w, int(np.floor(x + w + 0.5)))
        bottom = min(img_h, int(np.floor(y + h + 0.5)))

        cv2.rectangle(image, (left, top), (right, bottom), (255, 0, 0), 2)
        cv2.putText(image, '{0} {1:.2f}'.format(CLASSES[cl], score),
                    (left, top - 6),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.6, (0, 0, 255), 2)


def _remove_partial(path):
    """Delete a partially downloaded file, ignoring a missing one."""
    try:
        os.remove(path)
    except OSError:
        pass


def download_yolov3_weight(dst_path):
    """Download yolov3.weights to dst_path unless that path already exists.

    The data is fetched into '<dst_path>.part' and moved into place only on
    success, so an interrupted download is never mistaken for a finished
    one on the next run. Exits the process with status -1 on a download
    error.
    """
    if os.path.exists(dst_path):
        print('yolov3.weight exist.')
        return
    print('Downloading yolov3.weights...')
    url = 'https://pjreddie.com/media/files/yolov3.weights'
    tmp_path = os.fspath(dst_path) + '.part'
    try:
        urllib.request.urlretrieve(url, tmp_path)
    except urllib.error.HTTPError as e:
        _remove_partial(tmp_path)
        print('HTTPError code: ', e.code)
        print('HTTPError reason: ', e.reason)
        sys.exit(-1)
    except urllib.error.URLError as e:
        _remove_partial(tmp_path)
        print('URLError reason: ', e.reason)
        # Without the weights the caller cannot continue; fail like HTTPError.
        sys.exit(-1)
    else:
        os.replace(tmp_path, dst_path)
        print('Download yolov3.weight success.')