xref: /OK3568_Linux_fs/external/rknn-toolkit2/examples/tensorflow/inception_v3_qat/test.py (revision 4882a59341e53eb6f0b4789bf948001014eff981)
1import numpy as np
2import cv2
3import os
4import urllib
5import tarfile
6import shutil
7import traceback
8import time
9import sys
10from rknn.api import RKNN
11
# Model artifacts: the frozen TensorFlow graph that is converted, and the
# RKNN file the conversion is exported to.
PB_FILE = './inception_v3_quant_frozen.pb'
RKNN_MODEL_PATH = './inception_v3_quant_frozen.rknn'

# Tensor names passed to rknn.load_tensorflow() as graph inputs/outputs.
INPUTS = ['input']
OUTPUTS = ['InceptionV3/Logits/SpatialSqueeze']

# Side length of the square input passed to input_size_list, and the sample
# image used for the test inference.
INPUT_SIZE = 299
IMG_PATH = './goldfish_299x299.jpg'
18
19
def show_outputs(outputs, top_k=5):
    """Print (and return) the top-k entries of the first output's score vector.

    Args:
        outputs: sequence whose first element holds the scores;
            ``outputs[0][0]`` is taken as the score vector (flattened to 1-D).
        top_k: number of entries to report (default 5). If the vector has
            fewer elements, all of them are reported.

    Returns:
        The report string that was printed. After a two-line header, each
        line is ``"<index>: <score>"`` in descending score order; tied
        scores are listed by ascending index. Scores that are not positive
        are reported as ``"-1: 0.0"``.
    """
    scores = np.ravel(outputs[0][0])
    # Python's sort is stable even with reverse=True, so tied scores keep
    # ascending-index order (the order np.where would have produced).
    ranked = sorted(range(scores.size), key=lambda k: scores[k], reverse=True)

    lines = ['inception_v3', '-----TOP {}-----'.format(top_k)]
    for idx in ranked[:top_k]:
        value = scores[idx]
        if value > 0:
            lines.append('{}: {}'.format(idx, value))
        else:
            lines.append('-1: 0.0')
    report = '\n'.join(lines) + '\n'
    print(report)
    return report
36
37
def readable_speed(speed):
    """Render a transfer rate given in bytes per second as text.

    The value starts in KB/s and moves up to MB/s, then GB/s, only while it
    is strictly greater than 1024 in the current unit. The result always has
    two decimal places, e.g. ``"1.50 MB/s"``.
    """
    templates = ("{:.2f} KB/s", "{:.2f} MB/s", "{:.2f} GB/s")
    scaled = float(speed) / 1024
    level = 0
    # A plain `>` test keeps NaN at the KB/s level.
    while level < len(templates) - 1 and scaled > 1024:
        scaled /= 1024
        level += 1
    return templates[level].format(scaled)
50
51
def show_progress(blocknum, blocksize, totalsize, start=None):
    """Report download progress; usable as a ``urlretrieve`` reporthook.

    Writes one line per call: the percentage received, a 50-character bar
    and the average speed since the download started.

    Args:
        blocknum: number of blocks transferred so far.
        blocksize: size of one block in bytes.
        totalsize: total size in bytes. ``urlretrieve`` may pass -1 when the
            size is unknown; then "N/A" is shown instead of a percentage.
        start: ``time.time()`` value at which the download began. Defaults
            to the module-level ``start_time`` global set by the caller.
    """
    if start is None:
        start = start_time
    recv_size = blocknum * blocksize
    elapsed = time.time() - start
    # Avoid dividing by zero when no measurable time has elapsed yet.
    speed = recv_size / elapsed if elapsed > 0 else 0.0
    speed_str = " Speed: {}".format(readable_speed(speed))

    if totalsize > 0:
        # The last block usually overshoots totalsize; cap the display at 100%.
        progress = min(recv_size / totalsize, 1.0)
        progress_str = "{:.2f}%".format(progress * 100)
        filled = round(progress * 50)
    else:
        progress_str = "N/A"
        filled = 0
    bar = ('#' * filled).ljust(50, '-')

    out = sys.stdout
    out.write(progress_str.ljust(8, ' ') + '[' + bar + ']' + speed_str)
    out.flush()
    out.write('\r\n')
65
66
def _download_model():
    """Fetch the Inception-v3 quant archive and install its frozen graph as PB_FILE.

    Downloads ``inception_v3_quant.tgz`` and extracts it into a directory with
    the same base name. It then copies the file named like PB_FILE from that
    directory to PB_FILE and removes the archive and the extracted directory.
    Exits the process with status -1 on any failure.
    """
    # show_progress() reads this module-level global to compute the speed.
    global start_time
    # `import urllib` alone does not load the `urllib.request` submodule.
    import urllib.request

    url = ('https://storage.googleapis.com/download.tensorflow.org/models/'
           'tflite_11_05_08/inception_v3_quant.tgz')
    download_file = 'inception_v3_quant.tgz'
    target_dir = os.path.splitext(download_file)[0]

    print('--> Download {}'.format(PB_FILE))
    try:
        start_time = time.time()
        urllib.request.urlretrieve(url, download_file, show_progress)
    except Exception:
        print('Download {} failed.'.format(download_file))
        print(traceback.format_exc())
        sys.exit(-1)

    try:
        os.makedirs(target_dir, exist_ok=True)
        with tarfile.open(download_file) as tar:
            if hasattr(tarfile, 'data_filter'):
                # Archive comes from the network: refuse absolute paths,
                # '..' members and other unsafe entries where supported.
                tar.extractall(target_dir, filter='data')
            else:
                tar.extractall(target_dir)
    except Exception:
        print('Extract {} failed.'.format(download_file))
        print(traceback.format_exc())
        sys.exit(-1)

    pb_name = os.path.basename(PB_FILE)
    pb_file = os.path.join(target_dir, pb_name)
    if not os.path.exists(pb_file):
        print('{} not found in {}.'.format(pb_name, download_file))
        sys.exit(-1)
    shutil.copyfile(pb_file, PB_FILE)
    shutil.rmtree(target_dir)
    os.remove(download_file)
    print('done')


def main():
    """Convert the TensorFlow Inception-v3 QAT graph to RKNN and run one image.

    Steps:
    - fetch the frozen graph if PB_FILE is missing;
    - configure, load, build and export the RKNN model;
    - run inference on IMG_PATH and print the top-5 softmax scores.

    Every failure path exits the process with a non-zero status. The RKNN
    object is released in all cases.
    """
    # Create RKNN object
    rknn = RKNN(verbose=True)
    try:
        # If the frozen graph does not exist locally, download it from:
        # https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/inception_v3_quant.tgz
        if not os.path.exists(PB_FILE):
            _download_model()

        # Pre-process config
        print('--> Config model')
        rknn.config(mean_values=[104, 117, 123], std_values=[128, 128, 128])
        print('done')

        # Load model
        print('--> Loading model')
        ret = rknn.load_tensorflow(tf_pb=PB_FILE,
                                   inputs=INPUTS,
                                   outputs=OUTPUTS,
                                   input_size_list=[[1, INPUT_SIZE, INPUT_SIZE, 3]])
        if ret != 0:
            print('Load model failed!')
            sys.exit(ret)
        print('done')

        # Build model. NOTE(review): quantization is presumably disabled because
        # the QAT graph already carries its quantization parameters — confirm.
        print('--> Building model')
        ret = rknn.build(do_quantization=False)
        if ret != 0:
            print('Build model failed!')
            sys.exit(ret)
        print('done')

        # Export rknn model
        print('--> Export rknn model')
        ret = rknn.export_rknn(RKNN_MODEL_PATH)
        if ret != 0:
            print('Export rknn model failed!')
            sys.exit(ret)
        print('done')

        # Set inputs: OpenCV decodes images as BGR; convert to RGB.
        img = cv2.imread(IMG_PATH)
        if img is None:  # cv2.imread signals a missing/unreadable file with None
            print('Read image {} failed!'.format(IMG_PATH))
            sys.exit(-1)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Init runtime environment
        print('--> Init runtime environment')
        ret = rknn.init_runtime()
        if ret != 0:
            print('Init runtime environment failed!')
            sys.exit(ret)
        print('done')

        # Inference
        print('--> Running model')
        outputs = rknn.inference(inputs=[img])
        # NOTE(review): assumes the RKNN API reports inference failure as None.
        if outputs is None:
            print('Inference failed!')
            sys.exit(-1)
        np.save('./tensorflow_inception_v3_qat_0.npy', outputs[0])
        # Softmax over the whole first output; subtracting the max avoids
        # overflow in np.exp without changing the result.
        logits = outputs[0]
        exp = np.exp(logits - np.max(logits))
        show_outputs([exp / np.sum(exp)])
        print('done')
    finally:
        rknn.release()


if __name__ == '__main__':
    main()
160