# xref: /OK3568_Linux_fs/external/rknn-toolkit2/examples/tflite/mobilenet_v1/test.py (revision 4882a59341e53eb6f0b4789bf948001014eff981)
1import numpy as np
2import cv2
3from rknn.api import RKNN
4
5
def show_outputs(outputs):
    """Print the (up to) five highest-scoring entries of the first output.

    Args:
        outputs: Sequence whose first element is array-like; element
            ``outputs[0][0]`` is read as the score vector and flattened
            to 1-D before ranking.

    Each printed row is ``<index>: <score>``. A row whose score is not
    strictly positive is printed as the placeholder ``-1: 0.0``. Tied
    scores are listed once each, in ascending index order. When fewer than
    five scores are available, only those rows are printed.
    """
    scores = np.asarray(outputs[0][0]).reshape(-1)

    # Rank indices rather than values: looking a value back up with
    # np.where() yields every tied position at once, which is what produced
    # array-valued indices and duplicated rows before. sorted() is stable
    # even with reverse=True, so ties keep ascending index order.
    ranked = sorted(range(scores.shape[0]), key=scores.__getitem__,
                    reverse=True)[:5]

    rows = ['mobilenet_v1\n-----TOP 5-----\n']
    for idx in ranked:
        value = scores[idx]
        if value > 0:
            rows.append('{}: {}\n'.format(idx, value))
        else:
            rows.append('-1: 0.0\n')
    print(''.join(rows))
22
23
if __name__ == '__main__':
    # Script flow: configure -> load the TFLite model -> build (quantized)
    # -> export the .rknn file -> run one image through the built model and
    # print its top-5 scores.

    # Create RKNN object
    rknn = RKNN(verbose=True)

    # Everything after construction runs under try/finally so the RKNN
    # object is released on the early-exit paths as well as on success.
    try:
        # Pre-process config
        print('--> Config model')
        rknn.config(mean_values=[128, 128, 128], std_values=[128, 128, 128])
        print('done')

        # Load model
        print('--> Loading model')
        ret = rknn.load_tflite(model='mobilenet_v1_1.0_224.tflite')
        if ret != 0:
            print('Load model failed!')
            # SystemExit is what exit() raises, without relying on the
            # interactive-only helper installed by the site module.
            raise SystemExit(ret)
        print('done')

        # Build model
        print('--> Building model')
        ret = rknn.build(do_quantization=True, dataset='./dataset.txt')
        if ret != 0:
            print('Build model failed!')
            raise SystemExit(ret)
        print('done')

        # Export rknn model
        print('--> Export rknn model')
        ret = rknn.export_rknn('./mobilenet_v1.rknn')
        if ret != 0:
            print('Export rknn model failed!')
            raise SystemExit(ret)
        print('done')

        # Set inputs
        img = cv2.imread('./dog_224x224.jpg')
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None
            # instead of raising; fail here with a clear message.
            print('Load image failed!')
            raise SystemExit(1)
        # Reorder channels from BGR to RGB, then add a leading axis of size 1.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = np.expand_dims(img, 0)

        # Init runtime environment
        print('--> Init runtime environment')
        ret = rknn.init_runtime()
        if ret != 0:
            print('Init runtime environment failed!')
            raise SystemExit(ret)
        print('done')

        # Inference
        print('--> Running model')
        outputs = rknn.inference(inputs=[img])
        # Keep the raw first output on disk alongside the printed summary.
        np.save('./tflite_mobilenet_v1_0.npy', outputs[0])
        show_outputs(outputs)
        print('done')
    finally:
        rknn.release()
79