import os
import urllib
import traceback
import time
import sys
import numpy as np
import cv2
from rknn.api import RKNN
import urllib.request

ONNX_MODEL = 'resnet50v2.onnx'
RKNN_MODEL = 'resnet50v2.rknn'


def show_outputs(outputs):
    """Print the five highest-scoring class indices of a classifier output.

    Args:
        outputs: array-like of per-class scores. It is flattened first, so
            both ``(N,)`` and ``(1, N)`` inputs are accepted.

    A non-positive score is printed as ``-1: 0.0`` (the original output
    format). Fewer than five lines are printed if ``outputs`` has fewer than
    five entries.
    """
    scores = np.asarray(outputs).ravel()
    # Best-first ordering. A stable sort of the negated scores keeps tied
    # classes in ascending index order, and each class is printed once.
    # (np.where(...) returned a tuple of arrays, so indexing it printed whole
    # arrays like "[155]" instead of class ids.)
    top_indices = np.argsort(-scores, kind='stable')[:5]
    top5_str = 'resnet50v2\n-----TOP 5-----\n'
    for index in top_indices:
        value = scores[index]
        if value > 0:
            top5_str += '{}: {}\n'.format(index, value)
        else:
            top5_str += '-1: 0.0\n'
    print(top5_str)


def readable_speed(speed):
    """Format a transfer rate given in bytes/second as KB/s, MB/s or GB/s."""
    speed_kbytes = float(speed) / 1024
    if speed_kbytes <= 1024:
        return "{:.2f} KB/s".format(speed_kbytes)
    speed_mbytes = speed_kbytes / 1024
    if speed_mbytes <= 1024:
        return "{:.2f} MB/s".format(speed_mbytes)
    return "{:.2f} GB/s".format(speed_mbytes / 1024)


def show_progress(blocknum, blocksize, totalsize):
    """Reporthook for ``urllib.request.urlretrieve`` that prints a progress bar.

    Args:
        blocknum: number of blocks transferred so far.
        blocksize: block size in bytes.
        totalsize: total file size in bytes, or -1 when it is unknown.

    Reads the module-level ``start_time`` (set just before the download
    starts) to compute the average speed.
    """
    recv_size = blocknum * blocksize
    # The first callback can arrive within the clock's resolution of
    # start_time; avoid dividing by zero.
    elapsed = max(time.time() - start_time, 1e-6)
    speed_str = " Speed: {}".format(readable_speed(recv_size / elapsed))

    f = sys.stdout
    if totalsize > 0:
        # The final block usually overshoots the file size; clamp to 100%.
        progress = min(recv_size / totalsize, 1.0)
        progress_str = "{:.2f}%".format(progress * 100)
        bar = ('#' * round(progress * 50)).ljust(50, '-')
        f.write(progress_str.ljust(8, ' ') + '[' + bar + ']' + speed_str)
    else:
        # Size unknown: a percentage would be meaningless, so show bytes.
        f.write('Received {} bytes'.format(recv_size) + speed_str)
    f.write('\r\n')
    f.flush()


if __name__ == '__main__':

    # Create RKNN object
    rknn = RKNN(verbose=True)

    # Release the RKNN object on every exit path, including the
    # sys.exit() calls on failure below.
    try:
        # If resnet50v2 does not exist, download it.
        # Download address:
        # https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.onnx
        if not os.path.exists(ONNX_MODEL):
            print('--> Download {}'.format(ONNX_MODEL))
            url = 'https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.onnx'
            download_file = ONNX_MODEL
            download_ok = False
            try:
                start_time = time.time()
                urllib.request.urlretrieve(url, download_file, show_progress)
                download_ok = True
            except Exception:
                print('Download {} failed.'.format(download_file))
                print(traceback.format_exc())
            finally:
                # Remove a partially written file (also on Ctrl-C). Otherwise
                # the next run sees it, skips the download and tries to load
                # a truncated model.
                if not download_ok and os.path.exists(download_file):
                    os.remove(download_file)
            if not download_ok:
                sys.exit(-1)
            print('done')

        # Pre-process config
        print('--> Config model')
        rknn.config(mean_values=[123.68, 116.28, 103.53], std_values=[57.38, 57.38, 57.38])
        print('done')

        # Load model
        print('--> Loading model')
        ret = rknn.load_onnx(model=ONNX_MODEL)
        if ret != 0:
            print('Load model failed!')
            sys.exit(ret)
        print('done')

        # Build model
        print('--> Building model')
        ret = rknn.build(do_quantization=True, dataset='./dataset.txt')
        if ret != 0:
            print('Build model failed!')
            sys.exit(ret)
        print('done')

        # Accuracy analysis
        print('--> Accuracy analysis')
        ret = rknn.accuracy_analysis(inputs=['./dog_224x224.jpg'], output_dir='./snapshot')
        if ret != 0:
            print('Accuracy analysis failed!')
            sys.exit(ret)
        print('done')

        print('float32:')
        output = np.genfromtxt('./snapshot/golden/resnetv24_dense0_fwd.txt')
        show_outputs(output)

        print('quantized:')
        output = np.genfromtxt('./snapshot/simulator/resnetv24_dense0_fwd.txt')
        show_outputs(output)
    finally:
        rknn.release()