1 package com.rockchip.gpadc.demo.tracker;
2 
3 import android.graphics.RectF;
4 import android.util.Log;
5 
6 import com.rockchip.gpadc.demo.InferenceResult;
7 
8 import java.util.ArrayList;
9 
10 public class ObjectTracker {
11 
12     private final String TAG = "rkyolo.ObjectTracker";
13 
14     private long mHandle;
15 
16     private int mMaxTrackLifetime = 3;
17 
18     private int mWidth;
19 
20     private int mHeight;
21 
22     private static int MAX_TRACKED_NUM = 64;
23 
24     private static float[] track_input_location = new float[MAX_TRACKED_NUM *4];
25     private static int[] track_input_class = new int[MAX_TRACKED_NUM];
26     private static float[] track_input_score = new float[MAX_TRACKED_NUM];
27     private static float[] track_output_location = new float[MAX_TRACKED_NUM *4];
28     private static int[] track_output_class = new int[MAX_TRACKED_NUM];
29     private static int[] track_output_id = new int[MAX_TRACKED_NUM];
30     private static float[] track_output_score = new float[MAX_TRACKED_NUM];
31 
32 //    public int track_count = 0;
33 //    public long track_time = 0;
34 
ObjectTracker(int width, int height, int maxTrackLifetime)35     public ObjectTracker(int width, int height, int maxTrackLifetime) {
36         mWidth = width;
37         mHeight = height;
38         mMaxTrackLifetime = maxTrackLifetime;
39         mHandle = native_create();
40     }
41 
finalize()42     protected void finalize() {
43         native_destroy(mHandle);
44     }
45 
tracker(ArrayList<InferenceResult.Recognition> recognitions)46     public ArrayList<InferenceResult.Recognition> tracker(ArrayList<InferenceResult.Recognition> recognitions) {
47 //        long startTime = System.currentTimeMillis();
48 //        long endTime;
49         int track_input_num = 0;
50         ArrayList<InferenceResult.Recognition> tracked_recognitions = new ArrayList<>();
51 
52         for (int i = 0; i < recognitions.size(); ++i) {
53 
54             track_input_location[4*track_input_num +0] = recognitions.get(i).getLocation().left;
55             track_input_location[4*track_input_num +1] = recognitions.get(i).getLocation().top;
56             track_input_location[4*track_input_num +2] = recognitions.get(i).getLocation().right;
57             track_input_location[4*track_input_num +3] = recognitions.get(i).getLocation().bottom;
58             track_input_class[track_input_num] = recognitions.get(i).getId();
59             track_input_score[track_input_num] = recognitions.get(i).getConfidence();
60             //Log.i(TAG, track_input_num +" javain class:" +topClassScoreIndex +" P:" +track_input_score[track_input_num] +" score:" +expit(track_input_score[track_input_num]));
61             track_input_num++;
62             if (track_input_num >= MAX_TRACKED_NUM){
63                 break;
64             }
65         }
66 
67         int[] track_output_num = new int[1];
68 
69         native_track(mHandle, mMaxTrackLifetime,
70                 track_input_num, track_input_location, track_input_class, track_input_score,
71                 track_output_num, track_output_location, track_output_class, track_output_score,
72                 track_output_id, mWidth, mHeight);
73 
74         for (int i = 0; i < track_output_num[0]; ++i) {
75 
76             RectF detection = new RectF(
77                             track_output_location[i * 4 + 0]/mWidth,
78                             track_output_location[i * 4 + 1]/mHeight,
79                             track_output_location[i * 4 + 2]/mWidth,
80                             track_output_location[i * 4 + 3]/mHeight);
81             float exp_score =  track_output_score[i];
82             if (track_output_score[i] == -10000){
83                 exp_score = 0;
84             }
85             InferenceResult.Recognition recog = new InferenceResult.Recognition(
86                     track_output_class[i],
87                     exp_score,
88                     detection);
89             recog.setTrackId(track_output_id[i]);
90             //Log.i(TAG, "javaout"+i +" class:" +topClassScoreIndex +" P:" +track_output_score[i] +" score:" +exp_score);
91             tracked_recognitions.add(recog);
92         }
93 //        endTime = System.currentTimeMillis();
94 //        this.track_count += 1;
95 //        this.track_time += (endTime - startTime);
96 //        if (this.track_count >= 100) {
97 //            float track_avg = this.track_time * 1.0f / this.track_count;
98 //            Log.i(TAG, String.format("track cost time avg: %.5f", track_avg));
99 //            this.track_count = 0;
100 //            this.track_time = 0;
101 //        }
102         return tracked_recognitions;
103     }
104 
native_create()105     private native long native_create();
native_destroy(long handle)106     private native void native_destroy(long handle);
native_track(long hanle, int maxTrackLifetime, int track_input_num, float[] track_input_locations, int[] track_input_class, float[] track_input_score, int[] track_output_num, float[] track_output_locations, int[] track_output_class, float[] track_output_score, int[] track_output_id, int width, int height)107     private native void native_track(long hanle, int maxTrackLifetime,
108                                      int track_input_num, float[] track_input_locations, int[] track_input_class, float[] track_input_score,
109                                     int[] track_output_num, float[] track_output_locations, int[] track_output_class, float[] track_output_score,
110                                      int[] track_output_id, int width, int height);
111 
112     static {
113         System.loadLibrary("rknn4j");
114     }
115 }
116