# Source: /OK3568_Linux_fs/external/rknn-toolkit2/examples/pytorch/resnet18/test.py (revision 4882a59341e53eb6f0b4789bf948001014eff981)
import os
import sys

import cv2
import numpy as np
import torch
import torchvision.models as models
from rknn.api import RKNN
7
8
def export_pytorch_model(path='./resnet18.pt'):
    """Export an ImageNet-pretrained torchvision ResNet-18 as TorchScript.

    The network is put in eval mode, traced with one dummy input of shape
    (1, 3, 224, 224) and the traced module is saved to ``path``.

    Args:
        path: Destination file for the traced model. Defaults to
            ``'./resnet18.pt'``, the file the ``__main__`` script loads.
    """
    net = models.resnet18(pretrained=True)
    net.eval()
    # torch.Tensor(1, 3, 224, 224) is the legacy constructor and returns
    # uninitialized memory (may contain NaN/Inf). Trace with defined values
    # instead; for a model without data-dependent control flow the example
    # values do not change the recorded graph.
    example_input = torch.rand(1, 3, 224, 224)
    trace_model = torch.jit.trace(net, example_input)
    trace_model.save(path)
14
15
def show_outputs(output, top_k=5):
    """Print the ``top_k`` highest scores in ``output`` with their indices.

    Each entry is printed as ``<index>: <score>`` in descending score order.
    Tied scores are listed once each, lowest index first. Entries whose score
    is not positive are printed as the placeholder ``-1: 0.0``. If
    ``output`` has fewer than ``top_k`` elements, all of them are printed.

    Args:
        output: Array-like of scores (e.g. softmax probabilities). It is
            flattened, so printed indices are positions in the flat array.
        top_k: Number of entries to print. Defaults to 5.
    """
    scores = np.asarray(output).ravel()
    # A stable sort of the negated scores yields descending order while
    # keeping tied scores in ascending index order, so every tied index
    # appears exactly once (np.where on each value repeated them).
    order = np.argsort(-scores, kind='stable')[:top_k]
    lines = ['\n-----TOP {}-----\n'.format(top_k)]
    for idx in order:
        value = scores[idx]
        if value > 0:
            lines.append('{}: {}\n'.format(int(idx), value))
        else:
            lines.append('-1: 0.0\n')
    print(''.join(lines))
31
32
def show_perfs(perfs):
    """Print ``perfs`` after a ``perfs:`` label, followed by an empty line."""
    print('perfs: {}\n'.format(perfs))
36
37
def softmax(x, axis=0):
    """Return the softmax of ``x`` along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating. This
    leaves the result mathematically unchanged but stops ``np.exp`` from
    overflowing to ``inf`` (which made the quotient ``nan``) for large inputs.

    Args:
        x: Array-like of scores (logits).
        axis: Axis to normalise over. Defaults to 0, matching the previous
            builtin ``sum(...)`` reduction over the first axis; for a 1-D
            vector it is the only axis.

    Returns:
        ndarray with the shape of ``x`` whose entries along ``axis`` are
        non-negative and sum to 1.
    """
    x = np.asarray(x)
    # Compute exp once and reuse it for both numerator and denominator.
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exps / np.sum(exps, axis=axis, keepdims=True)
40
41
if __name__ == '__main__':

    model = './resnet18.pt'
    # Produce the TorchScript file on first run.
    if not os.path.exists(model):
        export_pytorch_model()

    input_size_list = [[1, 3, 224, 224]]

    # Create RKNN object
    rknn = RKNN(verbose=True)

    # Every failure path below calls sys.exit(); SystemExit still runs the
    # finally clause, so the RKNN object is released on success and failure
    # (previously release() was skipped on every early exit).
    try:
        # Pre-process config. Assumes rknn.config normalises each input
        # channel as (x - mean) / std (see the RKNN-Toolkit2 user guide).
        # These are torchvision's ImageNet constants scaled to 0-255:
        # mean = [0.485, 0.456, 0.406] * 255, std = [0.229, 0.224, 0.225] * 255.
        # The std list previously repeated the first channel's value.
        print('--> Config model')
        rknn.config(mean_values=[123.675, 116.28, 103.53],
                    std_values=[58.395, 57.12, 57.375])
        print('done')

        # Load model
        print('--> Loading model')
        ret = rknn.load_pytorch(model=model, input_size_list=input_size_list)
        if ret != 0:
            print('Load model failed!')
            sys.exit(ret)
        print('done')

        # Build model (dataset.txt presumably lists quantization calibration
        # inputs -- confirm against the rknn.build documentation)
        print('--> Building model')
        ret = rknn.build(do_quantization=True, dataset='./dataset.txt')
        if ret != 0:
            print('Build model failed!')
            sys.exit(ret)
        print('done')

        # Export rknn model
        print('--> Export rknn model')
        ret = rknn.export_rknn('./resnet_18.rknn')
        if ret != 0:
            print('Export rknn model failed!')
            sys.exit(ret)
        print('done')

        # Set inputs. OpenCV loads images as BGR; convert to RGB, the channel
        # order torchvision's pretrained models were trained with.
        img = cv2.imread('./space_shuttle_224.jpg')
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None,
            # which would otherwise fail later inside cvtColor.
            print('Read image failed!')
            sys.exit(1)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Init runtime environment
        print('--> Init runtime environment')
        ret = rknn.init_runtime()
        if ret != 0:
            print('Init runtime environment failed!')
            sys.exit(ret)
        print('done')

        # Inference
        print('--> Running model')
        outputs = rknn.inference(inputs=[img])
        np.save('./pytorch_resnet18_0.npy', outputs[0])
        # outputs[0][0]: first batch row of the first output -- presumably
        # the class logits; confirm against the model's output shape.
        show_outputs(softmax(np.array(outputs[0][0])))
        print('done')
    finally:
        rknn.release()
102