# xref: /OK3568_Linux_fs/external/rknn-toolkit2/rknn_toolkit_lite2/examples/inference_with_lite/test.py (revision 4882a59341e53eb6f0b4789bf948001014eff981)
import platform
import sys

import cv2
import numpy as np
from rknnlite.api import RKNNLite

# Device-tree node listing the board/SoC "compatible" strings; read by
# get_host() to tell rk356x / rk3562 / rk3588 devices apart.
DEVICE_COMPATIBLE_NODE = '/proc/device-tree/compatible'

def get_host():
    """Identify the host this script is running on.

    On Linux/aarch64 the device-tree ``compatible`` node is read and mapped
    to 'RK3588' or 'RK3562' when that SoC name appears in it, and to
    'RK3566_RK3568' otherwise.  On every other platform the string
    '<system>-<machine>' (for example 'Linux-x86_64') is returned.

    If the device-tree node cannot be read, an error message is printed and
    the process exits with status -1.
    """
    os_machine = platform.system() + '-' + platform.machine()
    if os_machine != 'Linux-aarch64':
        return os_machine

    # Keep the try body limited to the read so only I/O errors are caught.
    try:
        with open(DEVICE_COMPATIBLE_NODE) as f:
            device_compatible_str = f.read()
    except IOError:
        print('Read device node {} failed.'.format(DEVICE_COMPATIBLE_NODE))
        # sys.exit rather than the exit() helper injected by the site
        # module, which is absent under ``python -S`` or in frozen builds.
        sys.exit(-1)

    if 'rk3588' in device_compatible_str:
        return 'RK3588'
    if 'rk3562' in device_compatible_str:
        return 'RK3562'
    # Every other aarch64 Linux device is treated as RK3566/RK3568.
    return 'RK3566_RK3568'

# Edge length in pixels of the (square) model input; the bundled test image
# is named accordingly (space_shuttle_224.jpg).
INPUT_SIZE = 224

# Pre-converted ResNet18 model files, one per supported SoC family.
RK3566_RK3568_RKNN_MODEL = 'resnet18_for_rk3566_rk3568.rknn'
RK3588_RKNN_MODEL = 'resnet18_for_rk3588.rknn'
RK3562_RKNN_MODEL = 'resnet18_for_rk3562.rknn'


def show_top5(result):
    """Print the (up to) five most probable classes of a model output.

    Args:
        result: Sequence whose first element is an array-like of raw class
            scores; it is flattened to one dimension before use.

    The scores are turned into probabilities with a softmax and printed,
    highest first, one ``<class index>: <probability>`` line per class under
    a ``resnet18 / TOP 5`` header.  Classes with equal probability are listed
    separately, lowest index first.  A class whose probability is not
    greater than zero is printed as the placeholder ``-1: 0.0``.  Fewer than
    five lines are printed when there are fewer than five classes.
    """
    logits = np.asarray(result[0]).reshape(-1)
    if not np.issubdtype(logits.dtype, np.floating):
        # Integer scores: promote so the shift below cannot wrap around.
        logits = logits.astype(np.float64)

    top5_str = 'resnet18\n-----TOP 5-----\n'
    if logits.size:
        # Softmax. Subtracting the maximum leaves the result unchanged
        # mathematically but keeps exp() from overflowing to inf (-> nan).
        exps = np.exp(logits - logits.max())
        probs = exps / exps.sum()

        # Stable sort on the negated values: descending probability, ties
        # resolved by ascending class index.
        for idx in np.argsort(-probs, kind='stable')[:5]:
            value = probs[idx]
            if value > 0:
                top5_str += '{}: {}\n'.format(int(idx), value)
            else:
                top5_str += '-1: 0.0\n'
    print(top5_str)


if __name__ == '__main__':
    # Pick the model file that was converted for the SoC we are running on.
    host_name = get_host()
    if host_name == 'RK3566_RK3568':
        rknn_model = RK3566_RK3568_RKNN_MODEL
    elif host_name == 'RK3562':
        rknn_model = RK3562_RKNN_MODEL
    elif host_name == 'RK3588':
        rknn_model = RK3588_RKNN_MODEL
    else:
        print("This demo cannot run on the current platform: {}".format(host_name))
        sys.exit(-1)

    rknn_lite = RKNNLite()
    # try/finally so the RKNNLite context is released on every exit path,
    # including the sys.exit() calls below (SystemExit runs finally blocks).
    try:
        # load RKNN model
        print('--> Load RKNN model')
        ret = rknn_lite.load_rknn(rknn_model)
        if ret != 0:
            print('Load RKNN model failed')
            sys.exit(ret)
        print('done')

        img_path = './space_shuttle_224.jpg'
        ori_img = cv2.imread(img_path)
        # cv2.imread signals a missing/unreadable file by returning None
        # instead of raising; fail here with a clear message.
        if ori_img is None:
            print('Read image {} failed'.format(img_path))
            sys.exit(-1)
        img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)

        # init runtime environment
        print('--> Init runtime environment')
        # run on RK356x/RK3588 with Debian OS, do not need specify target.
        if host_name == 'RK3588':
            ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        else:
            ret = rknn_lite.init_runtime()
        if ret != 0:
            print('Init runtime environment failed')
            sys.exit(ret)
        print('done')

        # Inference
        print('--> Running model')
        outputs = rknn_lite.inference(inputs=[img])
        # NOTE(review): inference() presumably returns None on failure --
        # confirm against the RKNNLite API; the guard is harmless otherwise.
        if outputs is None:
            print('Inference failed')
            sys.exit(-1)
        show_top5(outputs)
        print('done')
    finally:
        rknn_lite.release()
103