1 package com.rockchip.gpadc.demo.yolo;
2 
3 import android.graphics.RectF;
4 import android.util.Log;
5 
6 import com.rockchip.gpadc.demo.InferenceResult;
7 import com.rockchip.gpadc.demo.InferenceResult.OutputBuffer;
8 import com.rockchip.gpadc.demo.InferenceResult.Recognition;
9 import com.rockchip.gpadc.demo.InferenceResult.DetectResultGroup;
10 
11 import java.io.IOException;
12 import java.util.ArrayList;
13 
14 /**
15  * Created by randall on 18-4-18.
16  */
17 
18 public class InferenceWrapper {
19     private final String TAG = "rkyolo.InferenceWrapper";
20 
21     static {
22         System.loadLibrary("rknn4j");
23     }
24 
25     OutputBuffer mOutputs;
26     ArrayList<Recognition> mRecognitions = new ArrayList<Recognition>();
27     DetectResultGroup mDetectResults;
28 
29     public int OBJ_NUMB_MAX_SIZE = 64;
30 //    public int inf_count = 0;
31 //    public int post_count = 0;
32 //    public long inf_time = 0;
33 //    public long post_time = 0;
34 
35 
InferenceWrapper()36     public InferenceWrapper() {
37 
38     }
39 
initModel(int im_height, int im_width, int im_channel, String modelPath)40     public int initModel(int im_height, int im_width, int im_channel, String modelPath) throws Exception {
41         mOutputs = new InferenceResult.OutputBuffer();
42         mOutputs.mGrid0Out = new byte[255 * 80 * 80];
43         mOutputs.mGrid1Out = new byte[255 * 40 * 40];
44         mOutputs.mGrid2Out = new byte[255 * 20 * 20];
45         if (navite_init(im_height, im_width, im_channel, modelPath) != 0) {
46             throw new IOException("rknn init fail!");
47         }
48         return 0;
49     }
50 
51 
deinit()52     public void deinit() {
53         native_deinit();
54         mOutputs.mGrid0Out = null;
55         mOutputs.mGrid1Out = null;
56         mOutputs.mGrid2Out = null;
57         mOutputs = null;
58 
59     }
60 
run(byte[] inData)61     public InferenceResult.OutputBuffer run(byte[] inData) {
62 //        long startTime = System.currentTimeMillis();
63 //        long endTime;
64         native_run(inData, mOutputs.mGrid0Out, mOutputs.mGrid1Out, mOutputs.mGrid2Out);
65 //        this.inf_count += 1;
66 //        endTime = System.currentTimeMillis();
67 //        this.inf_time += (endTime - startTime);
68 //        if (this.inf_count >= 100) {
69 //            float inf_avg = this.inf_time * 1.0f / this.inf_count;
70 //            Log.w(TAG, String.format("inference avg cost: %.5f ms", inf_avg));
71 //            this.inf_count = 0;
72 //            this.inf_time = 0;
73 //        }
74 //        Log.i(TAG, String.format("inference count: %d", this.inf_count));
75         return  mOutputs;
76     }
77 
postProcess(InferenceResult.OutputBuffer outputs)78     public ArrayList<InferenceResult.Recognition> postProcess(InferenceResult.OutputBuffer outputs) {
79         ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
80 
81         mDetectResults = new DetectResultGroup();
82         mDetectResults.count = 0;
83         mDetectResults.ids = new int[OBJ_NUMB_MAX_SIZE];
84         mDetectResults.scores = new float[OBJ_NUMB_MAX_SIZE];
85         mDetectResults.boxes = new float[4 * OBJ_NUMB_MAX_SIZE];
86 
87         if (null == outputs || null == outputs.mGrid0Out || null == outputs.mGrid1Out
88                 || null == outputs.mGrid2Out) {
89             return recognitions;
90         }
91 
92 //        long startTime = System.currentTimeMillis();
93 //        long endTime;
94         int count = native_post_process(outputs.mGrid0Out, outputs.mGrid1Out, outputs.mGrid2Out,
95                 mDetectResults.ids, mDetectResults.scores, mDetectResults.boxes);
96         if (count < 0) {
97             Log.w(TAG, "post_process may fail.");
98             mDetectResults.count = 0;
99         } else {
100             mDetectResults.count = count;
101         }
102 //        Log.i(TAG, String.format("Detected %d objects", count));
103 //        this.post_count += 1;
104 //        Log.i(TAG, String.format("post count: %d", this.post_count));
105 
106         for (int i = 0; i < count; ++i) {
107             RectF rect = new RectF();
108             rect.left = mDetectResults.boxes[i*4+0];
109             rect.top = mDetectResults.boxes[i*4+1];
110             rect.right = mDetectResults.boxes[i*4+2];
111             rect.bottom = mDetectResults.boxes[i*4+3];
112 
113             Recognition recog = new InferenceResult.Recognition(mDetectResults.ids[i],
114                     mDetectResults.scores[i], rect);
115             recognitions.add(recog);
116         }
117 //        endTime = System.currentTimeMillis();
118 //        this.post_time += (endTime - startTime);
119 //        if (this.post_count >= 100) {
120 //            float post_avg = this.post_time * 1.0f / this.post_count;
121 //            Log.w(TAG, String.format("post process avg cost: %.5f ms", post_avg));
122 //            this.post_time = 0;
123 //            this.post_count = 0;
124 //        }
125 
126         return recognitions;
127     }
128 
navite_init(int im_height, int im_width, int im_channel, String modelPath)129     private native int navite_init(int im_height, int im_width, int im_channel, String modelPath);
native_deinit()130     private native void native_deinit();
native_run(byte[] inData, byte[] grid0Out, byte[] grid1Out, byte[] grid2Out)131     private native int native_run(byte[] inData, byte[] grid0Out, byte[] grid1Out, byte[] grid2Out);
native_post_process(byte[] grid0Out, byte[] grid1Out, byte[] grid2Out, int[] ids, float[] scores, float[] boxes)132     private native int native_post_process(byte[] grid0Out, byte[] grid1Out, byte[] grid2Out,
133                                            int[] ids, float[] scores, float[] boxes);
134 
135 }