"""Convert a torchvision ResNet-18 to RKNN, run it on one image and print the top-5."""
import os
import sys

import cv2
import numpy as np
import torch
import torchvision.models as models
from rknn.api import RKNN


def export_pytorch_model(path='./resnet18.pt'):
    """Trace an ImageNet-pretrained ResNet-18 and save it as TorchScript.

    Args:
        path: Destination file for the traced model. The default matches the
            path loaded by ``main()``.
    """
    # NOTE(review): ``pretrained=True`` is deprecated since torchvision 0.13 in
    # favour of ``weights=...``; kept so older torchvision releases still work.
    net = models.resnet18(pretrained=True)
    net.eval()
    # torch.Tensor(1, 3, 224, 224) allocates *uninitialized* memory that may
    # hold NaN/Inf; a deterministic zero tensor is enough to record the graph.
    example_input = torch.zeros(1, 3, 224, 224)
    with torch.no_grad():
        trace_model = torch.jit.trace(net, example_input)
    trace_model.save(path)


def show_outputs(output, top_k=5):
    """Print the ``top_k`` highest-scoring class indices and their scores.

    Each line has the form ``"<class index>: <score>"``. A ranked entry whose
    score is not strictly positive, or a missing entry when ``output`` holds
    fewer than ``top_k`` values, is printed as the placeholder ``"-1: 0.0"``.

    Args:
        output: Score vector (array-like); it is flattened before ranking.
        top_k: Number of entries to print (defaults to 5).
    """
    scores = np.asarray(output).ravel()
    # Stable sort on the negated scores: descending by score, and equal scores
    # keep ascending index order, so every index is reported exactly once.
    ranked = np.argsort(-scores, kind='stable')[:top_k]

    lines = ['\n-----TOP {}-----\n'.format(top_k)]
    for idx in ranked:
        value = scores[idx]
        if value > 0:
            lines.append('{}: {}\n'.format(idx, value))
        else:
            lines.append('-1: 0.0\n')
    # Pad instead of raising IndexError when there are fewer than top_k scores.
    lines.extend('-1: 0.0\n' for _ in range(top_k - len(ranked)))
    print(''.join(lines))


def show_perfs(perfs):
    """Print a performance report object prefixed with ``'perfs: '``."""
    perfs = 'perfs: {}\n'.format(perfs)
    print(perfs)


def softmax(x, axis=-1):
    """Return the numerically stable softmax of ``x`` along ``axis``.

    Subtracting the maximum before exponentiating leaves the result unchanged
    mathematically but prevents overflow for large logits.

    Args:
        x: Array-like of logits.
        axis: Axis to normalise over (defaults to the last axis).
    """
    x = np.asarray(x)
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)


def main():
    """Build the RKNN model, run one inference and print the top-5 classes.

    Returns:
        0 on success; otherwise the non-zero status returned by the failing
        RKNN call, or 1 if the test image cannot be read or inference yields
        no outputs. Intended to be passed to ``sys.exit``.
    """
    model = './resnet18.pt'
    if not os.path.exists(model):
        export_pytorch_model(model)

    input_size_list = [[1, 3, 224, 224]]

    # Create RKNN object
    rknn = RKNN(verbose=True)
    try:
        # Pre-process config
        print('--> Config model')
        # NOTE(review): torchvision's ImageNet normalisation uses per-channel
        # std (0.229, 0.224, 0.225) * 255 = [58.395, 57.12, 57.375]; the uniform
        # 58.395 below is kept as-is — confirm whether it is intentional.
        rknn.config(mean_values=[123.675, 116.28, 103.53], std_values=[58.395, 58.395, 58.395])
        print('done')

        # Load model
        print('--> Loading model')
        ret = rknn.load_pytorch(model=model, input_size_list=input_size_list)
        if ret != 0:
            print('Load model failed!')
            return ret
        print('done')

        # Build model
        print('--> Building model')
        ret = rknn.build(do_quantization=True, dataset='./dataset.txt')
        if ret != 0:
            print('Build model failed!')
            return ret
        print('done')

        # Export rknn model
        print('--> Export rknn model')
        ret = rknn.export_rknn('./resnet_18.rknn')
        if ret != 0:
            print('Export rknn model failed!')
            return ret
        print('done')

        # Set inputs
        img = cv2.imread('./space_shuttle_224.jpg')
        # cv2.imread signals a missing/unreadable file by returning None rather
        # than raising, which would otherwise fail obscurely inside cvtColor.
        if img is None:
            print('Read image failed!')
            return 1
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Init runtime environment
        print('--> Init runtime environment')
        ret = rknn.init_runtime()
        if ret != 0:
            print('Init runtime environment failed!')
            return ret
        print('done')

        # Inference
        print('--> Running model')
        outputs = rknn.inference(inputs=[img])
        # NOTE(review): guard in case inference reports failure with an empty
        # or None result — confirm against the RKNN toolkit version in use.
        if not outputs:
            print('Inference failed!')
            return 1
        np.save('./pytorch_resnet18_0.npy', outputs[0])
        show_outputs(softmax(np.array(outputs[0][0])))
        print('done')
        return 0
    finally:
        # Release the RKNN context on every path, including early failures.
        rknn.release()


if __name__ == '__main__':
    sys.exit(main())