1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <float.h>
17
18#include "tensorflow/examples/android/jni/object_tracking/config.h"
19#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
20
21namespace tf_tracking {
22
23void FramePair::Init(const int64_t start_time, const int64_t end_time) {
24  start_time_ = start_time;
25  end_time_ = end_time;
26  memset(optical_flow_found_keypoint_, false,
27         sizeof(*optical_flow_found_keypoint_) * kMaxKeypoints);
28  number_of_keypoints_ = 0;
29}
30
31void FramePair::AdjustBox(const BoundingBox box,
32                          float* const translation_x,
33                          float* const translation_y,
34                          float* const scale_x,
35                          float* const scale_y) const {
36  static float weights[kMaxKeypoints];
37  static Point2f deltas[kMaxKeypoints];
38  memset(weights, 0.0f, sizeof(*weights) * kMaxKeypoints);
39
40  BoundingBox resized_box(box);
41  resized_box.Scale(0.4f, 0.4f);
42  FillWeights(resized_box, weights);
43  FillTranslations(deltas);
44
45  const Point2f translation = GetWeightedMedian(weights, deltas);
46
47  *translation_x = translation.x;
48  *translation_y = translation.y;
49
50  const Point2f old_center = box.GetCenter();
51  const int good_scale_points =
52      FillScales(old_center, translation, weights, deltas);
53
54  // Default scale factor is 1 for x and y.
55  *scale_x = 1.0f;
56  *scale_y = 1.0f;
57
58  // The assumption is that all deltas that make it to this stage with a
59  // correspondending optical_flow_found_keypoint_[i] == true are not in
60  // themselves degenerate.
61  //
62  // The degeneracy with scale arose because if the points are too close to the
63  // center of the objects, the scale ratio determination might be incalculable.
64  //
65  // The check for kMinNumInRange is not a degeneracy check, but merely an
66  // attempt to ensure some sort of stability. The actual degeneracy check is in
67  // the comparison to EPSILON in FillScales (which I've updated to return the
68  // number good remaining as well).
69  static const int kMinNumInRange = 5;
70  if (good_scale_points >= kMinNumInRange) {
71    const float scale_factor = GetWeightedMedianScale(weights, deltas);
72
73    if (scale_factor > 0.0f) {
74      *scale_x = scale_factor;
75      *scale_y = scale_factor;
76    }
77  }
78}
79
80int FramePair::FillWeights(const BoundingBox& box,
81                           float* const weights) const {
82  // Compute the max score.
83  float max_score = -FLT_MAX;
84  float min_score = FLT_MAX;
85  for (int i = 0; i < kMaxKeypoints; ++i) {
86    if (optical_flow_found_keypoint_[i]) {
87      max_score = MAX(max_score, frame1_keypoints_[i].score_);
88      min_score = MIN(min_score, frame1_keypoints_[i].score_);
89    }
90  }
91
92  int num_in_range = 0;
93  for (int i = 0; i < kMaxKeypoints; ++i) {
94    if (!optical_flow_found_keypoint_[i]) {
95      weights[i] = 0.0f;
96      continue;
97    }
98
99    const bool in_box = box.Contains(frame1_keypoints_[i].pos_);
100    if (in_box) {
101      ++num_in_range;
102    }
103
104    // The weighting based off distance.  Anything within the bounding box
105    // has a weight of 1, and everything outside of that is within the range
106    // [0, kOutOfBoxMultiplier), falling off with the squared distance ratio.
107    float distance_score = 1.0f;
108    if (!in_box) {
109      const Point2f initial = box.GetCenter();
110      const float sq_x_dist =
111          Square(initial.x - frame1_keypoints_[i].pos_.x);
112      const float sq_y_dist =
113          Square(initial.y - frame1_keypoints_[i].pos_.y);
114      const float squared_half_width = Square(box.GetWidth() / 2.0f);
115      const float squared_half_height = Square(box.GetHeight() / 2.0f);
116
117      static const float kOutOfBoxMultiplier = 0.5f;
118      distance_score = kOutOfBoxMultiplier *
119          MIN(squared_half_height / sq_y_dist, squared_half_width / sq_x_dist);
120    }
121
122    // The weighting based on relative score strength. kBaseScore - 1.0f.
123    float intrinsic_score =  1.0f;
124    if (max_score > min_score) {
125      static const float kBaseScore = 0.5f;
126      intrinsic_score = ((frame1_keypoints_[i].score_ - min_score) /
127         (max_score - min_score)) * (1.0f - kBaseScore) + kBaseScore;
128    }
129
130    // The final score will be in the range [0, 1].
131    weights[i] = distance_score * intrinsic_score;
132  }
133
134  return num_in_range;
135}
136
137void FramePair::FillTranslations(Point2f* const translations) const {
138  for (int i = 0; i < kMaxKeypoints; ++i) {
139    if (!optical_flow_found_keypoint_[i]) {
140      continue;
141    }
142    translations[i].x =
143        frame2_keypoints_[i].pos_.x - frame1_keypoints_[i].pos_.x;
144    translations[i].y =
145        frame2_keypoints_[i].pos_.y - frame1_keypoints_[i].pos_.y;
146  }
147}
148
149int FramePair::FillScales(const Point2f& old_center,
150                          const Point2f& translation,
151                          float* const weights,
152                          Point2f* const scales) const {
153  int num_good = 0;
154  for (int i = 0; i < kMaxKeypoints; ++i) {
155    if (!optical_flow_found_keypoint_[i]) {
156      continue;
157    }
158
159    const Keypoint keypoint1 = frame1_keypoints_[i];
160    const Keypoint keypoint2 = frame2_keypoints_[i];
161
162    const float dist1_x = keypoint1.pos_.x - old_center.x;
163    const float dist1_y = keypoint1.pos_.y - old_center.y;
164
165    const float dist2_x = (keypoint2.pos_.x - translation.x) - old_center.x;
166    const float dist2_y = (keypoint2.pos_.y - translation.y) - old_center.y;
167
168    // Make sure that the scale makes sense; points too close to the center
169    // will result in either NaNs or infinite results for scale due to
170    // limited tracking and floating point resolution.
171    // Also check that the parity of the points is the same with respect to
172    // x and y, as we can't really make sense of data that has flipped.
173    if (((dist2_x > EPSILON && dist1_x > EPSILON) ||
174         (dist2_x < -EPSILON && dist1_x < -EPSILON)) &&
175         ((dist2_y > EPSILON && dist1_y > EPSILON) ||
176          (dist2_y < -EPSILON && dist1_y < -EPSILON))) {
177      scales[i].x = dist2_x / dist1_x;
178      scales[i].y = dist2_y / dist1_y;
179      ++num_good;
180    } else {
181      weights[i] = 0.0f;
182      scales[i].x = 1.0f;
183      scales[i].y = 1.0f;
184    }
185  }
186  return num_good;
187}
188
// A scalar sample (`delta`) together with the weight it carries when a
// weighted median is computed.
struct WeightedDelta {
  float weight;
  float delta;
};

// qsort() comparator for WeightedDelta entries.
//
// Orders by delta only (weight is ignored), largest delta first. Returns a
// positive value when a's delta is smaller than b's, a negative value when it
// is larger, and 0 when neither is true. Returning 0 for equal deltas keeps
// the comparator consistent when its arguments are swapped, which qsort()
// requires. Comparing the values directly, rather than testing the sign of
// their difference, also avoids a NaN from inf - inf.
//
// If either delta is NaN, both comparisons are false and 0 is returned.
inline int WeightedDeltaCompare(const void* const a, const void* const b) {
  const float delta_a = static_cast<const WeightedDelta*>(a)->delta;
  const float delta_b = static_cast<const WeightedDelta*>(b)->delta;
  if (delta_a < delta_b) {
    return 1;
  }
  if (delta_a > delta_b) {
    return -1;
  }
  return 0;
}
199
200// Returns the median delta from a sorted set of weighted deltas.
201static float GetMedian(const int num_items,
202                       const WeightedDelta* const weighted_deltas,
203                       const float sum) {
204  if (num_items == 0 || sum < EPSILON) {
205    return 0.0f;
206  }
207
208  float current_weight = 0.0f;
209  const float target_weight = sum / 2.0f;
210  for (int i = 0; i < num_items; ++i) {
211    if (weighted_deltas[i].weight > 0.0f) {
212      current_weight += weighted_deltas[i].weight;
213      if (current_weight >= target_weight) {
214        return weighted_deltas[i].delta;
215      }
216    }
217  }
218  LOGW("Median not found! %d points, sum of %.2f", num_items, sum);
219  return 0.0f;
220}
221
222Point2f FramePair::GetWeightedMedian(
223    const float* const weights, const Point2f* const deltas) const {
224  Point2f median_delta;
225
226  // TODO(andrewharp): only sort deltas that could possibly have an effect.
227  static WeightedDelta weighted_deltas[kMaxKeypoints];
228
229  // Compute median X value.
230  {
231    float total_weight = 0.0f;
232
233    // Compute weighted mean and deltas.
234    for (int i = 0; i < kMaxKeypoints; ++i) {
235      weighted_deltas[i].delta = deltas[i].x;
236      const float weight = weights[i];
237      weighted_deltas[i].weight = weight;
238      if (weight > 0.0f) {
239        total_weight += weight;
240      }
241    }
242    qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
243          WeightedDeltaCompare);
244    median_delta.x = GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
245  }
246
247  // Compute median Y value.
248  {
249    float total_weight = 0.0f;
250
251    // Compute weighted mean and deltas.
252    for (int i = 0; i < kMaxKeypoints; ++i) {
253      const float weight = weights[i];
254      weighted_deltas[i].weight = weight;
255      weighted_deltas[i].delta = deltas[i].y;
256      if (weight > 0.0f) {
257        total_weight += weight;
258      }
259    }
260    qsort(weighted_deltas, kMaxKeypoints, sizeof(WeightedDelta),
261          WeightedDeltaCompare);
262    median_delta.y = GetMedian(kMaxKeypoints, weighted_deltas, total_weight);
263  }
264
265  return median_delta;
266}
267
268float FramePair::GetWeightedMedianScale(
269    const float* const weights, const Point2f* const deltas) const {
270  float median_delta;
271
272  // TODO(andrewharp): only sort deltas that could possibly have an effect.
273  static WeightedDelta weighted_deltas[kMaxKeypoints * 2];
274
275  // Compute median scale value across x and y.
276  {
277    float total_weight = 0.0f;
278
279    // Add X values.
280    for (int i = 0; i < kMaxKeypoints; ++i) {
281      weighted_deltas[i].delta = deltas[i].x;
282      const float weight = weights[i];
283      weighted_deltas[i].weight = weight;
284      if (weight > 0.0f) {
285        total_weight += weight;
286      }
287    }
288
289    // Add Y values.
290    for (int i = 0; i < kMaxKeypoints; ++i) {
291      weighted_deltas[i + kMaxKeypoints].delta = deltas[i].y;
292      const float weight = weights[i];
293      weighted_deltas[i + kMaxKeypoints].weight = weight;
294      if (weight > 0.0f) {
295        total_weight += weight;
296      }
297    }
298
299    qsort(weighted_deltas, kMaxKeypoints * 2, sizeof(WeightedDelta),
300          WeightedDeltaCompare);
301
302    median_delta = GetMedian(kMaxKeypoints * 2, weighted_deltas, total_weight);
303  }
304
305  return median_delta;
306}
307
308}  // namespace tf_tracking
309