package com.rockchip.gpadc.demo.yolo;

import android.graphics.RectF;
import android.util.Log;

import com.rockchip.gpadc.demo.InferenceResult;
import com.rockchip.gpadc.demo.InferenceResult.OutputBuffer;
import com.rockchip.gpadc.demo.InferenceResult.Recognition;
import com.rockchip.gpadc.demo.InferenceResult.DetectResultGroup;

import java.io.IOException;
import java.util.ArrayList;

/**
 * Java-side wrapper around the native {@code rknn4j} library.
 *
 * <p>Typical call order is {@link #initModel}, then repeated {@link #run} /
 * {@link #postProcess} calls, then {@link #deinit}. The class performs no
 * synchronization of its own.
 *
 * Created by randall on 18-4-18.
 */
public class InferenceWrapper {
    private static final String TAG = "rkyolo.InferenceWrapper";

    static {
        System.loadLibrary("rknn4j");
    }

    /** Output buffers filled by {@link #run}; {@code null} before init and after deinit. */
    OutputBuffer mOutputs;
    ArrayList<Recognition> mRecognitions = new ArrayList<Recognition>();
    /** Raw results of the most recent {@link #postProcess} call. */
    DetectResultGroup mDetectResults;

    /** Capacity of the id/score/box arrays handed to the native post-processing step. */
    public int OBJ_NUMB_MAX_SIZE = 64;

    public InferenceWrapper() {
    }

    /**
     * Allocates the three output buffers and initializes the native model.
     *
     * @param im_height  input image height, forwarded to the native init call
     * @param im_width   input image width, forwarded to the native init call
     * @param im_channel input image channel count, forwarded to the native init call
     * @param modelPath  path of the model, forwarded to the native init call
     * @return always {@code 0} on success
     * @throws IOException if the native init call returns a non-zero status
     */
    public int initModel(int im_height, int im_width, int im_channel, String modelPath) throws Exception {
        // NOTE(review): sizes look like 255 channels over 80x80 / 40x40 / 20x20 grids;
        // confirm they match the output layout of the model being loaded.
        OutputBuffer outputs = new InferenceResult.OutputBuffer();
        outputs.mGrid0Out = new byte[255 * 80 * 80];
        outputs.mGrid1Out = new byte[255 * 40 * 40];
        outputs.mGrid2Out = new byte[255 * 20 * 20];
        mOutputs = outputs;

        // The misspelled name is kept on purpose: the native library binds to it by name.
        if (navite_init(im_height, im_width, im_channel, modelPath) != 0) {
            throw new IOException("rknn init fail!");
        }
        return 0;
    }

    /**
     * Releases the native side and drops the output buffers.
     *
     * <p>Safe to call when the buffers were never allocated or were already
     * released; in that case only the native deinit call is made.
     */
    public void deinit() {
        native_deinit();

        OutputBuffer outputs = mOutputs;
        if (outputs != null) {
            outputs.mGrid0Out = null;
            outputs.mGrid1Out = null;
            outputs.mGrid2Out = null;
        }
        mOutputs = null;
    }

    /**
     * Runs the native inference on {@code inData}, writing into the internal output buffers.
     *
     * @param inData input bytes passed unchanged to the native run call
     * @return the internal output buffer (shared across calls, not a copy), or
     *         {@code null} when no buffers are allocated; {@link #postProcess}
     *         accepts {@code null}
     */
    public InferenceResult.OutputBuffer run(byte[] inData) {
        OutputBuffer outputs = mOutputs;
        if (outputs == null) {
            // Called before initModel() or after deinit(): nothing to write into.
            Log.w(TAG, "run called without initialized output buffers.");
            return null;
        }

        // NOTE(review): the native return value is ignored, as before; its meaning
        // is not visible from this file.
        native_run(inData, outputs.mGrid0Out, outputs.mGrid1Out, outputs.mGrid2Out);
        return outputs;
    }

    /**
     * Converts raw network outputs into a list of recognitions via the native post-process step.
     *
     * <p>{@link #mDetectResults} is replaced with a fresh group on every call.
     *
     * @param outputs buffers produced by {@link #run}; may be {@code null}
     * @return the recognitions found; empty when {@code outputs} or any of its
     *         grids is {@code null}, or when the native call reports a negative count
     */
    public ArrayList<InferenceResult.Recognition> postProcess(InferenceResult.OutputBuffer outputs) {
        // Read the public, mutable capacity once so all three arrays agree.
        final int capacity = OBJ_NUMB_MAX_SIZE;

        DetectResultGroup results = new DetectResultGroup();
        results.count = 0;
        results.ids = new int[capacity];
        results.scores = new float[capacity];
        results.boxes = new float[4 * capacity];
        mDetectResults = results;

        if (null == outputs || null == outputs.mGrid0Out || null == outputs.mGrid1Out
                || null == outputs.mGrid2Out) {
            return new ArrayList<Recognition>();
        }

        int count = native_post_process(outputs.mGrid0Out, outputs.mGrid1Out, outputs.mGrid2Out,
                results.ids, results.scores, results.boxes);
        if (count < 0) {
            Log.w(TAG, "post_process may fail.");
            count = 0;
        } else if (count > capacity) {
            // Never index past the arrays that were handed to native code.
            count = capacity;
        }
        results.count = count;

        ArrayList<Recognition> recognitions = new ArrayList<Recognition>(count);
        for (int i = 0; i < count; ++i) {
            // Each box occupies four consecutive floats: left, top, right, bottom.
            final int base = i * 4;
            RectF rect = new RectF(results.boxes[base], results.boxes[base + 1],
                    results.boxes[base + 2], results.boxes[base + 3]);
            recognitions.add(new InferenceResult.Recognition(results.ids[i], results.scores[i], rect));
        }
        return recognitions;
    }

    private native int navite_init(int im_height, int im_width, int im_channel, String modelPath);

    private native void native_deinit();

    private native int native_run(byte[] inData, byte[] grid0Out, byte[] grid1Out, byte[] grid2Out);

    private native int native_post_process(byte[] grid0Out, byte[] grid1Out, byte[] grid2Out,
                                           int[] ids, float[] scores, float[] boxes);

}