1 package com.rockchip.gpadc.demo;
2 
import android.content.res.AssetManager;
import android.graphics.RectF;

import com.rockchip.gpadc.demo.yolo.InferenceWrapper;
import com.rockchip.gpadc.demo.yolo.PostProcess;
import com.rockchip.gpadc.demo.tracker.ObjectTracker;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Locale;

import static com.rockchip.gpadc.demo.yolo.PostProcess.INPUT_SIZE;
import static com.rockchip.gpadc.demo.rga.HALDefine.CAMERA_PREVIEW_WIDTH;
import static com.rockchip.gpadc.demo.rga.HALDefine.CAMERA_PREVIEW_HEIGHT;
import static java.lang.System.arraycopy;
17 
18 public class InferenceResult {
19 
20     OutputBuffer mOutputBuffer;
21     ArrayList<Recognition> recognitions = null;
22     private boolean mIsVaild = false;   //是否需要重新计算
23     PostProcess mPostProcess = new PostProcess();
24     private ObjectTracker mSSDObjectTracker;
25 
init(AssetManager assetManager)26     public void init(AssetManager assetManager) throws IOException {
27         mOutputBuffer = new OutputBuffer();
28 
29         mPostProcess.init(assetManager);
30 
31 //        mSSDObjectTracker = new ObjectTracker(CAMERA_PREVIEW_WIDTH, CAMERA_PREVIEW_HEIGHT, 3);
32     }
33 
reset()34     public void reset() {
35         if (recognitions != null) {
36             recognitions.clear();
37             mIsVaild = true;
38         }
39         mSSDObjectTracker = new ObjectTracker(CAMERA_PREVIEW_WIDTH, CAMERA_PREVIEW_HEIGHT, 3);
40     }
setResult(OutputBuffer outputs)41     public synchronized void setResult(OutputBuffer outputs) {
42 
43         if (mOutputBuffer.mGrid0Out == null) {
44             mOutputBuffer.mGrid0Out = outputs.mGrid0Out.clone();
45             mOutputBuffer.mGrid1Out = outputs.mGrid1Out.clone();
46             mOutputBuffer.mGrid2Out = outputs.mGrid2Out.clone();
47         } else {
48             arraycopy(outputs.mGrid0Out, 0, mOutputBuffer.mGrid0Out, 0,
49                     outputs.mGrid0Out.length);
50             arraycopy(outputs.mGrid1Out, 0, mOutputBuffer.mGrid1Out, 0,
51                     outputs.mGrid1Out.length);
52             arraycopy(outputs.mGrid2Out, 0, mOutputBuffer.mGrid2Out, 0,
53                     outputs.mGrid2Out.length);
54         }
55         mIsVaild = false;
56     }
57 
getResult(InferenceWrapper mInferenceWrapper)58     public synchronized ArrayList<Recognition> getResult(InferenceWrapper mInferenceWrapper) {
59         if (!mIsVaild) {
60             mIsVaild = true;
61 
62             recognitions = mInferenceWrapper.postProcess(mOutputBuffer);
63 
64             recognitions = mSSDObjectTracker.tracker(recognitions);
65         }
66 
67         return recognitions;
68     }
69 
70     public static class OutputBuffer {
71         public byte[] mGrid0Out;
72         public byte[] mGrid1Out;
73         public byte[] mGrid2Out;
74     }
75 
76     /**
77      * An immutable result returned by a Classifier describing what was recognized.
78      */
79     public static class Recognition {
80 
81         private int trackId = 0;
82 
83         /**
84          * A unique identifier for what has been recognized. Specific to the class, not the instance of
85          * the object.
86          */
87         private final int id;
88 
89         /**
90          * A sortable score for how good the recognition is relative to others. Higher should be better.
91          */
92         private final Float confidence;
93 
94         /** Optional location within the source image for the location of the recognized object. */
95         private RectF location;
96 
Recognition( final int id, final Float confidence, final RectF location)97         public Recognition(
98                 final int id, final Float confidence, final RectF location) {
99             this.id = id;
100             this.confidence = confidence;
101             this.location = location;
102             // TODO -- add name field, and show it.
103         }
104 
getId()105         public int getId() {
106             return id;
107         }
108 
getConfidence()109         public Float getConfidence() {
110             return confidence;
111         }
112 
getLocation()113         public RectF getLocation() {
114             return new RectF(location);
115         }
116 
setLocation(RectF location)117         public void setLocation(RectF location) {
118             this.location = location;
119         }
120 
setTrackId(int trackId)121         public void setTrackId(int trackId) {
122             this.trackId = trackId;
123         }
124 
getTrackId()125         public int getTrackId() {
126             return this.trackId;
127         }
128 
129         @Override
toString()130         public String toString() {
131             String resultString = "";
132 
133             resultString += "[" + id + "] ";
134 
135             if (confidence != null) {
136                 resultString += String.format("(%.1f%%) ", confidence * 100.0f);
137             }
138 
139             if (location != null) {
140                 resultString += location + " ";
141             }
142 
143             return resultString.trim();
144         }
145     }
146 
147     /**
148      * Detected objects, returned from native yolo_post_process
149      */
150     public static class DetectResultGroup {
151         /**
152          * detected objects count.
153          */
154         public int count = 0;
155 
156         /**
157          * id for each detected object.
158          */
159         public int[] ids;
160 
161         /**
162          * score for each detected object.
163          */
164         public float[] scores;
165 
166         /**
167          * box for each detected object.
168          */
169         public float[] boxes;
170 
171 //        public DetectResultGroup(
172 //                int count, int[] ids, float[] scores, float[] boxes
173 //        ) {
174 //            this.count = count;
175 //            this.ids = ids;
176 //            this.scores = scores;
177 //            this.boxes = boxes;
178 //        }
179 //
180 //        public int getCount() {
181 //            return count;
182 //        }
183 //
184 //        public void setCount(int count) {
185 //            this.count = count;
186 //        }
187 //
188 //        public int[] getIds() {
189 //            return ids;
190 //        }
191 //
192 //        public void setIds(int[] ids) {
193 //            this.ids = ids;
194 //        }
195 //
196 //        public float[] getScores() {
197 //            return scores;
198 //        }
199 //
200 //        public void setScores(float[] scores) {
201 //            this.scores = scores;
202 //        }
203 //
204 //        public float[] getBoxes() {
205 //            return boxes;
206 //        }
207 //
208 //        public void setBoxes(float[] boxes) {
209 //            this.boxes = boxes;
210 //        }
211     }
212 }
213