import os
# ``import urllib`` alone does not guarantee the ``urllib.request`` submodule is
# loaded; importing it explicitly still binds the name ``urllib``.
import urllib.request
import traceback
import time
import sys
import numpy as np
import cv2
from rknn.api import RKNN

ONNX_MODEL = 'resnet50v2.onnx'
RKNN_MODEL = 'resnet50v2.rknn'
ONNX_MODEL_URL = 'https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.onnx'


def show_outputs(outputs, top_k=5):
    """Print the ``top_k`` highest-scoring class indices of the first output.

    Args:
        outputs: sequence where ``outputs[0][0]`` holds the per-class scores of
            the first image in the batch (in this script: softmax
            probabilities). Extra trailing dimensions are flattened.
        top_k: number of entries to print; defaults to 5.

    Each entry is printed as ``"<class index>: <score>"``; an entry whose score
    is not positive is printed as ``"-1: 0.0"``.
    """
    output = np.asarray(outputs[0][0]).ravel()
    # Stable descending sort: every class index appears exactly once and ties
    # are listed in ascending index order. (Looking indices up with
    # np.where(output == value) printed arrays such as "[208]" and repeated
    # tied indices under several ranks.)
    order = np.argsort(-output.astype(np.float64), kind='stable')[:top_k]
    top_str = 'resnet50v2\n-----TOP {}-----\n'.format(top_k)
    for idx in order:
        value = output[idx]
        if value > 0:
            top_str += '{}: {}\n'.format(int(idx), value)
        else:
            top_str += '-1: 0.0\n'
    print(top_str)


def readable_speed(speed):
    """Format a transfer rate in bytes/second as a KB/s, MB/s or GB/s string.

    A unit is only promoted once the value strictly exceeds 1024 of the
    smaller unit; the result always has two decimal places.
    """
    speed_kbytes = float(speed) / 1024
    if speed_kbytes <= 1024:
        return "{:.2f} KB/s".format(speed_kbytes)
    speed_mbytes = speed_kbytes / 1024
    if speed_mbytes <= 1024:
        return "{:.2f} MB/s".format(speed_mbytes)
    return "{:.2f} GB/s".format(speed_mbytes / 1024)


def show_progress(blocknum, blocksize, totalsize):
    """``reporthook`` for ``urllib.request.urlretrieve``: redraw a progress bar.

    Args:
        blocknum: number of blocks transferred so far (0 on the first call).
        blocksize: block size in bytes.
        totalsize: total size in bytes, or -1 when the server did not send a
            Content-Length.

    Reads the module-level ``start_time`` that the main section sets right
    before starting the download. The line is redrawn in place with a leading
    carriage return; the caller is expected to print a newline when done.
    """
    recv_size = blocknum * blocksize
    # The first callback (blocknum == 0) can arrive before the clock has
    # visibly advanced; avoid dividing by zero.
    elapsed = max(time.time() - start_time, 1e-6)
    speed_str = " Speed: {}".format(readable_speed(recv_size / elapsed))

    if totalsize > 0:
        # The final block is usually partial, so blocknum * blocksize can
        # overshoot totalsize; clamp so the bar never exceeds its 50 columns.
        progress = min(recv_size / totalsize, 1.0)
        progress_str = "{:.2f}%".format(progress * 100)
        bar = '[' + ('#' * round(progress * 50)).ljust(50, '-') + ']'
    else:
        # Unknown total size: report the amount received instead of a percentage.
        progress_str = "{:.2f} MB".format(recv_size / 1024 / 1024)
        bar = ''

    f = sys.stdout
    f.write('\r' + progress_str.ljust(8, ' ') + bar + speed_str)
    f.flush()


if __name__ == '__main__':

    # Create RKNN object
    rknn = RKNN(verbose=True)

    # Release the RKNN object on every exit path, including the sys.exit()
    # calls below (SystemExit still runs the ``finally`` clause).
    try:
        # If resnet50v2 does not exist, download it from ONNX_MODEL_URL.
        if not os.path.exists(ONNX_MODEL):
            print('--> Download {}'.format(ONNX_MODEL))
            download_file = ONNX_MODEL
            # Download under a temporary name and rename only on success, so a
            # failed or interrupted download never leaves a truncated model that
            # the os.path.exists() check above would accept on the next run.
            partial_file = download_file + '.part'
            try:
                start_time = time.time()  # read by show_progress()
                urllib.request.urlretrieve(ONNX_MODEL_URL, partial_file, show_progress)
                os.replace(partial_file, download_file)
            except Exception:
                sys.stdout.write('\n')  # end the in-place progress line
                print('Download {} failed.'.format(download_file))
                print(traceback.format_exc())
                if os.path.exists(partial_file):
                    os.remove(partial_file)
                sys.exit(-1)
            sys.stdout.write('\n')  # end the in-place progress line
            print('done')

        # pre-process config
        print('--> config model')
        rknn.config(mean_values=[123.675, 116.28, 103.53], std_values=[58.82, 58.82, 58.82])
        print('done')

        # Load model
        print('--> Loading model')
        ret = rknn.load_onnx(model=ONNX_MODEL)
        if ret != 0:
            print('Load model failed!')
            sys.exit(ret)
        print('done')

        # Build model (quantized; './dataset.txt' is presumably the calibration
        # input list used by the quantizer -- confirm against RKNN docs)
        print('--> Building model')
        ret = rknn.build(do_quantization=True, dataset='./dataset.txt')
        if ret != 0:
            print('Build model failed!')
            sys.exit(ret)
        print('done')

        # Export rknn model
        print('--> Export rknn model')
        ret = rknn.export_rknn(RKNN_MODEL)
        if ret != 0:
            print('Export rknn model failed!')
            sys.exit(ret)
        print('done')

        # Set inputs: OpenCV loads BGR, convert to RGB.
        img = cv2.imread('./dog_224x224.jpg')
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            print('Read image ./dog_224x224.jpg failed!')
            sys.exit(-1)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Init runtime environment
        print('--> Init runtime environment')
        ret = rknn.init_runtime()
        if ret != 0:
            print('Init runtime environment failed!')
            sys.exit(ret)
        print('done')

        # Inference
        print('--> Running model')
        outputs = rknn.inference(inputs=[img])
        # Save the raw (pre-softmax) first output.
        np.save('./onnx_resnet50v2_0.npy', outputs[0])
        x = outputs[0]
        # Softmax over the whole output; subtracting the max is mathematically
        # identical but avoids overflow in np.exp for large logits.
        exp_x = np.exp(x - np.max(x))
        output = exp_x / np.sum(exp_x)
        outputs = [output]
        show_outputs(outputs)
        print('done')
    finally:
        rknn.release()