// optical_flow.cc revision a30b9926bd7d5276d6ff35af9428dee3e77b7dcb
1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <math.h>
17
18#include "tensorflow/examples/android/jni/object_tracking/geom.h"
19#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
20#include "tensorflow/examples/android/jni/object_tracking/image.h"
21#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
22#include "tensorflow/examples/android/jni/object_tracking/utils.h"
23
24#include "tensorflow/examples/android/jni/object_tracking/config.h"
25#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
26#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
27#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
28#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
29#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
30#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"
31
32namespace tf_tracking {
33
// Constructs an optical-flow computer over images of the size given in
// |config|.  Both frame pointers start out NULL; they are populated by
// successive calls to NextFrame().  The config object is borrowed, not
// owned — it must outlive this instance.
OpticalFlow::OpticalFlow(const OpticalFlowConfig* const config)
    : config_(config),
      frame1_(NULL),
      frame2_(NULL),
      working_size_(config->image_size) {}
39
40
41void OpticalFlow::NextFrame(const ImageData* const image_data) {
42  // Special case for the first frame: make sure the image ends up in
43  // frame1_ so that keypoint detection can be done on it if desired.
44  frame1_ = (frame1_ == NULL) ? image_data : frame2_;
45  frame2_ = image_data;
46}
47
48
// Static heart of the optical flow computation.
// Lucas Kanade algorithm.
//
// Iteratively refines the flow guess (*out_g_x, *out_g_y) for the point
// (p_x, p_y), tracking from img_I to img_J using the precomputed spatial
// gradients I_x / I_y of the source image.  Returns false if the 2x2
// spatial gradient matrix about the point is not invertible (i.e. the
// patch lacks texture in some direction), otherwise true with the refined
// flow written back through the out parameters.
//
// Compile-time switches: USE_FIXED_POINT_FLOW selects 16:16 fixed-point
// subpixel sampling over float sampling; NORMALIZE enables mean/std-dev
// brightness normalization between the two patches.
bool OpticalFlow::FindFlowAtPoint_LK(const Image<uint8>& img_I,
                                     const Image<uint8>& img_J,
                                     const Image<int32>& I_x,
                                     const Image<int32>& I_y,
                                     const float p_x,
                                     const float p_y,
                                     float* out_g_x,
                                     float* out_g_y) {
  // Start from the caller-provided initial guess (typically seeded from a
  // coarser pyramid level).
  float g_x = *out_g_x;
  float g_y = *out_g_y;
  // Get values for frame 1.  They remain constant through the inner
  // iteration loop.
  float vals_I[kFlowArraySize];
  float vals_I_x[kFlowArraySize];
  float vals_I_y[kFlowArraySize];

  // kFlowArraySize is presumably kPatchSize * kPatchSize — TODO confirm
  // against optical_flow.h.
  const int kPatchSize = 2 * kFlowIntegrationWindowSize + 1;
  const float kWindowSizeFloat = static_cast<float>(kFlowIntegrationWindowSize);

  // Maximum sample coordinates, kept strictly inside the image so that
  // bilinear interpolation can always read the pixel one past the sample.
  // NOTE(review): the fixed-point path derives its bounds from img_I while
  // the float path uses I_x — presumably the image and its gradients share
  // dimensions; verify if these can ever differ.
#if USE_FIXED_POINT_FLOW
  const int fixed_x_max = RealToFixed1616(img_I.width_less_one_) - 1;
  const int fixed_y_max = RealToFixed1616(img_I.height_less_one_) - 1;
#else
  const float real_x_max = I_x.width_less_one_ - EPSILON;
  const float real_y_max = I_x.height_less_one_ - EPSILON;
#endif

  // Get the window around the original point.
  const float src_left_real = p_x - kWindowSizeFloat;
  const float src_top_real = p_y - kWindowSizeFloat;
  float* vals_I_ptr = vals_I;
  float* vals_I_x_ptr = vals_I_x;
  float* vals_I_y_ptr = vals_I_y;
#if USE_FIXED_POINT_FLOW
  // Source integer coordinates.
  const int src_left_fixed = RealToFixed1616(src_left_real);
  const int src_top_fixed = RealToFixed1616(src_top_real);

  // Sample the source patch and its gradients once; coordinates are clamped
  // to the image, so a patch partially outside the border repeats edge
  // samples rather than failing.
  for (int y = 0; y < kPatchSize; ++y) {
    // (y << 16) adds y whole pixels in 16:16 fixed point.
    const int fp_y = Clip(src_top_fixed + (y << 16), 0, fixed_y_max);

    for (int x = 0; x < kPatchSize; ++x) {
      const int fp_x = Clip(src_left_fixed + (x << 16), 0, fixed_x_max);

      *vals_I_ptr++ = img_I.GetPixelInterpFixed1616(fp_x, fp_y);
      *vals_I_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y);
      *vals_I_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y);
    }
  }
#else
  // Float path: same clamped sampling with subpixel interpolation.
  for (int y = 0; y < kPatchSize; ++y) {
    const float y_pos = Clip(src_top_real + y, 0.0f, real_y_max);

    for (int x = 0; x < kPatchSize; ++x) {
      const float x_pos = Clip(src_left_real + x, 0.0f, real_x_max);

      *vals_I_ptr++ = img_I.GetPixelInterp(x_pos, y_pos);
      *vals_I_x_ptr++ = I_x.GetPixelInterp(x_pos, y_pos);
      *vals_I_y_ptr++ = I_y.GetPixelInterp(x_pos, y_pos);
    }
  }
#endif

  // Compute the spatial gradient matrix about point p.
  float G[] = { 0, 0, 0, 0 };
  CalculateG(vals_I_x, vals_I_y, kFlowArraySize, G);

  // Find the inverse of G.
  // Failure means the patch is degenerate (e.g. flat or a pure edge) and
  // the flow is not recoverable at this point.
  float G_inv[4];
  if (!Invert2x2(G, G_inv)) {
    return false;
  }

#if NORMALIZE
  // Patch statistics used to compensate for brightness/contrast differences
  // between the two frames.
  const float mean_I = ComputeMean(vals_I, kFlowArraySize);
  const float std_dev_I = ComputeStdDev(vals_I, kFlowArraySize, mean_I);
#endif

  // Iterate kNumIterations times or until we converge.
  for (int iteration = 0; iteration < kNumIterations; ++iteration) {
    // Get values for frame 2.
    float vals_J[kFlowArraySize];

    // Get the window around the destination point.
    const float left_real = p_x + g_x - kWindowSizeFloat;
    const float top_real  = p_y + g_y - kWindowSizeFloat;
    float* vals_J_ptr = vals_J;
#if USE_FIXED_POINT_FLOW
    // The top-left sub-pixel is set for the current iteration (in 16:16
    // fixed). This is constant over one iteration.
    const int left_fixed = RealToFixed1616(left_real);
    const int top_fixed  = RealToFixed1616(top_real);

    for (int win_y = 0; win_y < kPatchSize; ++win_y) {
      const int fp_y = Clip(top_fixed + (win_y << 16), 0, fixed_y_max);
      for (int win_x = 0; win_x < kPatchSize; ++win_x) {
        const int fp_x = Clip(left_fixed + (win_x << 16), 0, fixed_x_max);
        *vals_J_ptr++ = img_J.GetPixelInterpFixed1616(fp_x, fp_y);
      }
    }
#else
    for (int win_y = 0; win_y < kPatchSize; ++win_y) {
      const float y_pos = Clip(top_real + win_y, 0.0f, real_y_max);
      for (int win_x = 0; win_x < kPatchSize; ++win_x) {
        const float x_pos = Clip(left_real + win_x, 0.0f, real_x_max);
        *vals_J_ptr++ = img_J.GetPixelInterp(x_pos, y_pos);
      }
    }
#endif

#if NORMALIZE
    const float mean_J = ComputeMean(vals_J, kFlowArraySize);
    const float std_dev_J = ComputeStdDev(vals_J, kFlowArraySize, mean_J);

    // TODO(andrewharp): Probably better to completely detect and handle the
    // "corner case" where the patch is fully outside the image diagonally.
    // A fully-clamped (constant) patch has zero std dev; fall back to a
    // ratio of 1 to avoid dividing by zero.
    const float std_dev_ratio = std_dev_J > 0.0f ? std_dev_I / std_dev_J : 1.0f;
#endif

    // Compute image mismatch vector.
    float b_x = 0.0f;
    float b_y = 0.0f;

    // Reset the read cursors; the source-patch arrays are re-walked every
    // iteration in lockstep with the freshly sampled target patch.
    vals_I_ptr = vals_I;
    vals_J_ptr = vals_J;
    vals_I_x_ptr = vals_I_x;
    vals_I_y_ptr = vals_I_y;

    for (int win_y = 0; win_y < kPatchSize; ++win_y) {
      for (int win_x = 0; win_x < kPatchSize; ++win_x) {
#if NORMALIZE
        // Normalized Image difference.
        const float dI =
            (*vals_I_ptr++ - mean_I) - (*vals_J_ptr++ - mean_J) * std_dev_ratio;
#else
        const float dI = *vals_I_ptr++ - *vals_J_ptr++;
#endif
        b_x += dI * *vals_I_x_ptr++;
        b_y += dI * *vals_I_y_ptr++;
      }
    }

    // Optical flow... solve n = G^-1 * b
    const float n_x = (G_inv[0] * b_x) + (G_inv[1] * b_y);
    const float n_y = (G_inv[2] * b_x) + (G_inv[3] * b_y);

    // Update best guess with residual displacement from this level and
    // iteration.
    g_x += n_x;
    g_y += n_y;

    // LOGV("Iteration %d: delta (%.3f, %.3f)", iteration, n_x, n_y);

    // Abort early if we're already below the threshold.
    if (Square(n_x) + Square(n_y) < Square(kTrackingAbortThreshold)) {
      break;
    }
  }  // Iteration.

  // Copy value back into output.
  *out_g_x = g_x;
  *out_g_y = g_y;
  return true;
}
215
216
// Pointwise flow using translational 2dof ESM.
//
// Efficient Second-order Minimization variant of the patch tracker: like
// Lucas-Kanade, but the Jacobian used at each step averages the source and
// target image gradients ((I' + J')/2), which typically converges in fewer
// iterations.  When kDoBrightnessNormalize is set, an additive brightness
// offset between the two patches is also estimated and updated each step.
//
// Works entirely in 16:16 fixed-point subpixel coordinates.  Returns false
// only if the template patch (or its gradients) cannot be extracted from
// the source image; drifting out of bounds during iteration merely stops
// early and still returns true with the best estimate so far in
// (*out_g_x, *out_g_y).
bool OpticalFlow::FindFlowAtPoint_ESM(const Image<uint8>& img_I,
                                      const Image<uint8>& img_J,
                                      const Image<int32>& I_x,
                                      const Image<int32>& I_y,
                                      const Image<int32>& J_x,
                                      const Image<int32>& J_y,
                                      const float p_x,
                                      const float p_y,
                                      float* out_g_x,
                                      float* out_g_y) {
  // Start from the caller-provided initial guess.
  float g_x = *out_g_x;
  float g_y = *out_g_y;
  // Reciprocal of the patch area, used to turn sums into means.
  const float area_inv = 1.0f / static_cast<float>(kFlowArraySize);

  // Get values for frame 1. They remain constant through the inner
  // iteration loop.
  uint8 vals_I[kFlowArraySize];
  uint8 vals_J[kFlowArraySize];
  int16 src_gradient_x[kFlowArraySize];
  int16 src_gradient_y[kFlowArraySize];

  // TODO(rspring): try out the IntegerPatchAlign() method once
  // the code for that is in ../common.
  const float wsize_float = static_cast<float>(kFlowIntegrationWindowSize);
  const int src_left_fixed = RealToFixed1616(p_x - wsize_float);
  const int src_top_fixed = RealToFixed1616(p_y - wsize_float);
  const int patch_size = 2 * kFlowIntegrationWindowSize + 1;

  // Create the keypoint template patch from a subpixel location.
  // Unlike the LK path, extraction can fail outright (presumably when the
  // patch falls outside the image) — in that case the point is untrackable.
  if (!img_I.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
                                             patch_size, patch_size, vals_I) ||
      !I_x.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
                                           patch_size, patch_size,
                                           src_gradient_x) ||
      !I_y.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
                                           patch_size, patch_size,
                                           src_gradient_y)) {
    return false;
  }

  int bright_offset = 0;
  int sum_diff = 0;

  // The top-left sub-pixel is set for the current iteration (in 16:16 fixed).
  // This is constant over one iteration.
  int left_fixed = RealToFixed1616(p_x + g_x - wsize_float);
  int top_fixed  = RealToFixed1616(p_y + g_y - wsize_float);

  // The truncated version gives the most top-left pixel that is used.
  int left_trunc = left_fixed >> 16;
  int top_trunc = top_fixed >> 16;

  // Compute an initial brightness offset.
  // Only done when the whole target patch is safely inside img_J, using a
  // fast integer-pixel (non-interpolated) read of the target rows.
  if (kDoBrightnessNormalize &&
      left_trunc >= 0 && top_trunc >= 0 &&
      (left_trunc + patch_size) < img_J.width_less_one_ &&
      (top_trunc + patch_size) < img_J.height_less_one_) {
    int templ_index = 0;
    const uint8* j_row = img_J[top_trunc] + left_trunc;

    const int j_stride = img_J.stride();

    for (int y = 0; y < patch_size; ++y, j_row += j_stride) {
      for (int x = 0; x < patch_size; ++x) {
        sum_diff += static_cast<int>(j_row[x]) - vals_I[templ_index++];
      }
    }

    // Mean intensity difference J - I over the patch.
    bright_offset = static_cast<int>(static_cast<float>(sum_diff) * area_inv);
  }

  // Iterate kNumIterations times or until we go out of image.
  for (int iteration = 0; iteration < kNumIterations; ++iteration) {
    // jtj accumulates the (symmetric) 2x2 normal matrix J^T*J as
    // [dxdx, dxdy, dydy]; jtr accumulates J^T * residual.
    int jtj[3] = { 0, 0, 0 };
    int jtr[2] = { 0, 0 };
    sum_diff = 0;

    // Extract the target image values.
    // Extract the gradient from the target image patch and accumulate to
    // the gradient of the source image patch.
    if (!img_J.ExtractPatchAtSubpixelFixed1616(left_fixed, top_fixed,
                                               patch_size, patch_size,
                                               vals_J)) {
      break;
    }

    const uint8* templ_row = vals_I;
    const uint8* extract_row = vals_J;
    const int16* src_dx_row = src_gradient_x;
    const int16* src_dy_row = src_gradient_y;

    for (int y = 0; y < patch_size; ++y, templ_row += patch_size,
         src_dx_row += patch_size, src_dy_row += patch_size,
         extract_row += patch_size) {
      const int fp_y = top_fixed + (y << 16);
      for (int x = 0; x < patch_size; ++x) {
        const int fp_x = left_fixed + (x << 16);
        int32 target_dx = J_x.GetPixelInterpFixed1616(fp_x, fp_y);
        int32 target_dy = J_y.GetPixelInterpFixed1616(fp_x, fp_y);

        // Combine the two Jacobians.
        // Right-shift by one to account for the fact that we add
        // two Jacobians.
        int32 dx = (src_dx_row[x] + target_dx) >> 1;
        int32 dy = (src_dy_row[x] + target_dy) >> 1;

        // The current residual b - h(q) == extracted - (template + offset)
        int32 diff = static_cast<int32>(extract_row[x]) -
                     static_cast<int32>(templ_row[x]) -
                     bright_offset;

        jtj[0] += dx * dx;
        jtj[1] += dx * dy;
        jtj[2] += dy * dy;

        jtr[0] += dx * diff;
        jtr[1] += dy * diff;

        sum_diff += diff;
      }
    }

    const float jtr1_float = static_cast<float>(jtr[0]);
    const float jtr2_float = static_cast<float>(jtr[1]);

    // Add some baseline stability to the system.
    jtj[0] += kEsmRegularizer;
    jtj[2] += kEsmRegularizer;

    // 64-bit products to avoid overflow in the determinant.
    const int64 prod1 = static_cast<int64>(jtj[0]) * jtj[2];
    const int64 prod2 = static_cast<int64>(jtj[1]) * jtj[1];

    // One ESM step.
    // jtj_1 is the adjugate of the 2x2 normal matrix; multiplying by
    // det_inv yields its inverse.
    // NOTE(review): assumes kEsmRegularizer > 0 so that, by Cauchy-Schwarz,
    // prod1 > prod2 and the determinant is strictly positive — otherwise
    // det_inv divides by zero.  Confirm against config.h.
    const float jtj_1[4] = { static_cast<float>(jtj[2]),
                             static_cast<float>(-jtj[1]),
                             static_cast<float>(-jtj[1]),
                             static_cast<float>(jtj[0]) };
    const double det_inv = 1.0 / static_cast<double>(prod1 - prod2);

    g_x -= det_inv * (jtj_1[0] * jtr1_float + jtj_1[1] * jtr2_float);
    g_y -= det_inv * (jtj_1[2] * jtr1_float + jtj_1[3] * jtr2_float);

    if (kDoBrightnessNormalize) {
      // Nudge the brightness offset by the mean residual (rounded).
      bright_offset +=
          static_cast<int>(area_inv * static_cast<float>(sum_diff) + 0.5f);
    }

    // Update top left position.
    left_fixed = RealToFixed1616(p_x + g_x - wsize_float);
    top_fixed  = RealToFixed1616(p_y + g_y - wsize_float);

    left_trunc = left_fixed >> 16;
    top_trunc = top_fixed >> 16;

    // Abort iterations if we go out of borders.
    // NOTE(review): bounds come from J_x (width) and J_y (height) here but
    // from img_J in the pre-loop check above — presumably all three share
    // dimensions; verify.
    if (left_trunc < 0 || top_trunc < 0 ||
        (left_trunc + patch_size) >= J_x.width_less_one_ ||
        (top_trunc + patch_size) >= J_y.height_less_one_) {
      break;
    }
  }  // Iteration.

  // Copy value back into output.
  *out_g_x = g_x;
  *out_g_y = g_y;
  return true;
}
385
386
387bool OpticalFlow::FindFlowAtPointReversible(
388    const int level, const float u_x, const float u_y,
389    const bool reverse_flow,
390    float* flow_x, float* flow_y) const {
391  const ImageData& frame_a = reverse_flow ? *frame2_ : *frame1_;
392  const ImageData& frame_b = reverse_flow ? *frame1_ : *frame2_;
393
394  // Images I (prev) and J (next).
395  const Image<uint8>& img_I = *frame_a.GetPyramidSqrt2Level(level * 2);
396  const Image<uint8>& img_J = *frame_b.GetPyramidSqrt2Level(level * 2);
397
398  // Computed gradients.
399  const Image<int32>& I_x = *frame_a.GetSpatialX(level);
400  const Image<int32>& I_y = *frame_a.GetSpatialY(level);
401  const Image<int32>& J_x = *frame_b.GetSpatialX(level);
402  const Image<int32>& J_y = *frame_b.GetSpatialY(level);
403
404  // Shrink factor from original.
405  const float shrink_factor = (1 << level);
406
407  // Image position vector (p := u^l), scaled for this level.
408  const float scaled_p_x = u_x / shrink_factor;
409  const float scaled_p_y = u_y / shrink_factor;
410
411  float scaled_flow_x = *flow_x / shrink_factor;
412  float scaled_flow_y = *flow_y / shrink_factor;
413
414  // LOGE("FindFlowAtPoint level %d: %5.2f, %5.2f (%5.2f, %5.2f)", level,
415  //     scaled_p_x, scaled_p_y, &scaled_flow_x, &scaled_flow_y);
416
417  const bool success = kUseEsm ?
418    FindFlowAtPoint_ESM(img_I, img_J, I_x, I_y, J_x, J_y,
419                        scaled_p_x, scaled_p_y,
420                        &scaled_flow_x, &scaled_flow_y) :
421    FindFlowAtPoint_LK(img_I, img_J, I_x, I_y,
422                       scaled_p_x, scaled_p_y,
423                       &scaled_flow_x, &scaled_flow_y);
424
425  *flow_x = scaled_flow_x * shrink_factor;
426  *flow_y = scaled_flow_y * shrink_factor;
427
428  return success;
429}
430
431
432bool OpticalFlow::FindFlowAtPointSingleLevel(
433    const int level,
434    const float u_x, const float u_y,
435    const bool filter_by_fb_error,
436    float* flow_x, float* flow_y) const {
437  if (!FindFlowAtPointReversible(level, u_x, u_y, false, flow_x, flow_y)) {
438    return false;
439  }
440
441  if (filter_by_fb_error) {
442    const float new_position_x = u_x + *flow_x;
443    const float new_position_y = u_y + *flow_y;
444
445    float reverse_flow_x = 0.0f;
446    float reverse_flow_y = 0.0f;
447
448    // Now find the backwards flow and confirm it lines up with the original
449    // starting point.
450    if (!FindFlowAtPointReversible(level, new_position_x, new_position_y,
451                                   true,
452                                   &reverse_flow_x, &reverse_flow_y)) {
453      LOGE("Backward error!");
454      return false;
455    }
456
457    const float discrepancy_length =
458        sqrtf(Square(*flow_x + reverse_flow_x) +
459              Square(*flow_y + reverse_flow_y));
460
461    const float flow_length = sqrtf(Square(*flow_x) + Square(*flow_y));
462
463    return discrepancy_length <
464        (kMaxForwardBackwardErrorAllowed * flow_length);
465  }
466
467  return true;
468}
469
470
471// An implementation of the Pyramidal Lucas-Kanade Optical Flow algorithm.
472// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for details.
473bool OpticalFlow::FindFlowAtPointPyramidal(const float u_x, const float u_y,
474                                           const bool filter_by_fb_error,
475                                           float* flow_x, float* flow_y) const {
476  const int max_level = MAX(kMinNumPyramidLevelsToUseForAdjustment,
477                            kNumPyramidLevels - kNumCacheLevels);
478
479  // For every level in the pyramid, update the coordinates of the best match.
480  for (int l = max_level - 1; l >= 0; --l) {
481    if (!FindFlowAtPointSingleLevel(l, u_x, u_y,
482                                    filter_by_fb_error, flow_x, flow_y)) {
483      return false;
484    }
485  }
486
487  return true;
488}
489
490}  // namespace tf_tracking
491