1 package com.rockchip.gpadc.demo; 2 3 import android.content.res.AssetManager; 4 import android.graphics.RectF; 5 6 import com.rockchip.gpadc.demo.yolo.InferenceWrapper; 7 import com.rockchip.gpadc.demo.yolo.PostProcess; 8 import com.rockchip.gpadc.demo.tracker.ObjectTracker; 9 10 import java.io.IOException; 11 import java.util.ArrayList; 12 13 import static com.rockchip.gpadc.demo.yolo.PostProcess.INPUT_SIZE; 14 import static com.rockchip.gpadc.demo.rga.HALDefine.CAMERA_PREVIEW_WIDTH; 15 import static com.rockchip.gpadc.demo.rga.HALDefine.CAMERA_PREVIEW_HEIGHT; 16 import static java.lang.System.arraycopy; 17 18 public class InferenceResult { 19 20 OutputBuffer mOutputBuffer; 21 ArrayList<Recognition> recognitions = null; 22 private boolean mIsVaild = false; //是否需要重新计算 23 PostProcess mPostProcess = new PostProcess(); 24 private ObjectTracker mSSDObjectTracker; 25 init(AssetManager assetManager)26 public void init(AssetManager assetManager) throws IOException { 27 mOutputBuffer = new OutputBuffer(); 28 29 mPostProcess.init(assetManager); 30 31 // mSSDObjectTracker = new ObjectTracker(CAMERA_PREVIEW_WIDTH, CAMERA_PREVIEW_HEIGHT, 3); 32 } 33 reset()34 public void reset() { 35 if (recognitions != null) { 36 recognitions.clear(); 37 mIsVaild = true; 38 } 39 mSSDObjectTracker = new ObjectTracker(CAMERA_PREVIEW_WIDTH, CAMERA_PREVIEW_HEIGHT, 3); 40 } setResult(OutputBuffer outputs)41 public synchronized void setResult(OutputBuffer outputs) { 42 43 if (mOutputBuffer.mGrid0Out == null) { 44 mOutputBuffer.mGrid0Out = outputs.mGrid0Out.clone(); 45 mOutputBuffer.mGrid1Out = outputs.mGrid1Out.clone(); 46 mOutputBuffer.mGrid2Out = outputs.mGrid2Out.clone(); 47 } else { 48 arraycopy(outputs.mGrid0Out, 0, mOutputBuffer.mGrid0Out, 0, 49 outputs.mGrid0Out.length); 50 arraycopy(outputs.mGrid1Out, 0, mOutputBuffer.mGrid1Out, 0, 51 outputs.mGrid1Out.length); 52 arraycopy(outputs.mGrid2Out, 0, mOutputBuffer.mGrid2Out, 0, 53 outputs.mGrid2Out.length); 54 } 55 mIsVaild = false; 56 } 57 getResult(InferenceWrapper mInferenceWrapper)58 public synchronized ArrayList<Recognition> getResult(InferenceWrapper mInferenceWrapper) { 59 if (!mIsVaild) { 60 mIsVaild = true; 61 62 recognitions = mInferenceWrapper.postProcess(mOutputBuffer); 63 64 recognitions = mSSDObjectTracker.tracker(recognitions); 65 } 66 67 return recognitions; 68 } 69 70 public static class OutputBuffer { 71 public byte[] mGrid0Out; 72 public byte[] mGrid1Out; 73 public byte[] mGrid2Out; 74 } 75 76 /** 77 * An immutable result returned by a Classifier describing what was recognized. 78 */ 79 public static class Recognition { 80 81 private int trackId = 0; 82 83 /** 84 * A unique identifier for what has been recognized. Specific to the class, not the instance of 85 * the object. 86 */ 87 private final int id; 88 89 /** 90 * A sortable score for how good the recognition is relative to others. Higher should be better. 91 */ 92 private final Float confidence; 93 94 /** Optional location within the source image for the location of the recognized object. */ 95 private RectF location; 96 Recognition( final int id, final Float confidence, final RectF location)97 public Recognition( 98 final int id, final Float confidence, final RectF location) { 99 this.id = id; 100 this.confidence = confidence; 101 this.location = location; 102 // TODO -- add name field, and show it. 103 } 104 getId()105 public int getId() { 106 return id; 107 } 108 getConfidence()109 public Float getConfidence() { 110 return confidence; 111 } 112 getLocation()113 public RectF getLocation() { 114 return new RectF(location); 115 } 116 setLocation(RectF location)117 public void setLocation(RectF location) { 118 this.location = location; 119 } 120 setTrackId(int trackId)121 public void setTrackId(int trackId) { 122 this.trackId = trackId; 123 } 124 getTrackId()125 public int getTrackId() { 126 return this.trackId; 127 } 128 129 @Override toString()130 public String toString() { 131 String resultString = ""; 132 133 resultString += "[" + id + "] "; 134 135 if (confidence != null) { 136 resultString += String.format("(%.1f%%) ", confidence * 100.0f); 137 } 138 139 if (location != null) { 140 resultString += location + " "; 141 } 142 143 return resultString.trim(); 144 } 145 } 146 147 /** 148 * Detected objects, returned from native yolo_post_process 149 */ 150 public static class DetectResultGroup { 151 /** 152 * detected objects count. 153 */ 154 public int count = 0; 155 156 /** 157 * id for each detected object. 158 */ 159 public int[] ids; 160 161 /** 162 * score for each detected object. 163 */ 164 public float[] scores; 165 166 /** 167 * box for each detected object. 168 */ 169 public float[] boxes; 170 171 // public DetectResultGroup( 172 // int count, int[] ids, float[] scores, float[] boxes 173 // ) { 174 // this.count = count; 175 // this.ids = ids; 176 // this.scores = scores; 177 // this.boxes = boxes; 178 // } 179 // 180 // public int getCount() { 181 // return count; 182 // } 183 // 184 // public void setCount(int count) { 185 // this.count = count; 186 // } 187 // 188 // public int[] getIds() { 189 // return ids; 190 // } 191 // 192 // public void setIds(int[] ids) { 193 // this.ids = ids; 194 // } 195 // 196 // public float[] getScores() { 197 // return scores; 198 // } 199 // 200 // public void setScores(float[] scores) { 201 // this.scores = scores; 202 // } 203 // 204 // public float[] getBoxes() { 205 // return boxes; 206 // } 207 // 208 // public void setBoxes(float[] boxes) { 209 // this.boxes = boxes; 210 // } 211 } 212 } 213