optical_flow.cc revision a30b9926bd7d5276d6ff35af9428dee3e77b7dcb
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <math.h>

#include "tensorflow/examples/android/jni/object_tracking/geom.h"
#include "tensorflow/examples/android/jni/object_tracking/image-inl.h"
#include "tensorflow/examples/android/jni/object_tracking/image.h"
#include "tensorflow/examples/android/jni/object_tracking/time_log.h"
#include "tensorflow/examples/android/jni/object_tracking/utils.h"

#include "tensorflow/examples/android/jni/object_tracking/config.h"
#include "tensorflow/examples/android/jni/object_tracking/flow_cache.h"
#include "tensorflow/examples/android/jni/object_tracking/frame_pair.h"
#include "tensorflow/examples/android/jni/object_tracking/image_data.h"
#include "tensorflow/examples/android/jni/object_tracking/keypoint.h"
#include "tensorflow/examples/android/jni/object_tracking/keypoint_detector.h"
#include "tensorflow/examples/android/jni/object_tracking/optical_flow.h"

namespace tf_tracking {

// Binds this flow computer to its configuration. No frames are held yet:
// NextFrame() must be called (twice) before a frame pair is available.
OpticalFlow::OpticalFlow(const OpticalFlowConfig* const config)
    : config_(config),
      frame1_(NULL),
      frame2_(NULL),
      working_size_(config->image_size) {}


// Advances the frame pair: the previous "next" frame (frame2_) becomes the
// new "previous" frame (frame1_), and image_data becomes the new frame2_.
// Note: this class does not take ownership of image_data; the caller must
// keep it alive while it is the current pair.
void OpticalFlow::NextFrame(const ImageData* const image_data) {
  // Special case for the first frame: make sure the image ends up in
  // frame1_ so that keypoint detection can be done on it if desired.
  frame1_ = (frame1_ == NULL) ? image_data : frame2_;
  frame2_ = image_data;
}


// Static heart of the optical flow computation.
// Lucas Kanade algorithm.
//
// Refines the displacement of the point (p_x, p_y) from image I to image J
// by iteratively solving the 2x2 normal equations n = G^-1 * b over a
// (2 * kFlowIntegrationWindowSize + 1)^2 patch.
// out_g_x/out_g_y are in/out parameters: they carry the initial flow guess
// in and the refined flow out (in this image's pixel coordinates).
// Returns false if the spatial gradient matrix G is not invertible;
// otherwise true (even if the iteration hit kNumIterations without
// reaching the convergence threshold).
bool OpticalFlow::FindFlowAtPoint_LK(const Image<uint8>& img_I,
                                     const Image<uint8>& img_J,
                                     const Image<int32>& I_x,
                                     const Image<int32>& I_y,
                                     const float p_x,
                                     const float p_y,
                                     float* out_g_x,
                                     float* out_g_y) {
  float g_x = *out_g_x;
  float g_y = *out_g_y;
  // Get values for frame 1. They remain constant through the inner
  // iteration loop.
  // NOTE(review): this assumes kFlowArraySize == kPatchSize * kPatchSize;
  // verify against the constants in optical_flow.h.
  float vals_I[kFlowArraySize];
  float vals_I_x[kFlowArraySize];
  float vals_I_y[kFlowArraySize];

  const int kPatchSize = 2 * kFlowIntegrationWindowSize + 1;
  const float kWindowSizeFloat = static_cast<float>(kFlowIntegrationWindowSize);

#if USE_FIXED_POINT_FLOW
  // Clip bounds in 16:16 fixed point. The -1 keeps coordinates strictly
  // below the last pixel so bilinear interpolation has a valid neighbor.
  const int fixed_x_max = RealToFixed1616(img_I.width_less_one_) - 1;
  const int fixed_y_max = RealToFixed1616(img_I.height_less_one_) - 1;
#else
  // Same idea in float: EPSILON keeps interpolation inside the image.
  const float real_x_max = I_x.width_less_one_ - EPSILON;
  const float real_y_max = I_x.height_less_one_ - EPSILON;
#endif

  // Get the window around the original point.
  const float src_left_real = p_x - kWindowSizeFloat;
  const float src_top_real = p_y - kWindowSizeFloat;
  float* vals_I_ptr = vals_I;
  float* vals_I_x_ptr = vals_I_x;
  float* vals_I_y_ptr = vals_I_y;
#if USE_FIXED_POINT_FLOW
  // Source integer coordinates.
  const int src_left_fixed = RealToFixed1616(src_left_real);
  const int src_top_fixed = RealToFixed1616(src_top_real);

  // Sample intensity and gradients of frame I over the patch; coordinates
  // are clamped to the image, so out-of-bounds samples repeat border pixels.
  for (int y = 0; y < kPatchSize; ++y) {
    const int fp_y = Clip(src_top_fixed + (y << 16), 0, fixed_y_max);

    for (int x = 0; x < kPatchSize; ++x) {
      const int fp_x = Clip(src_left_fixed + (x << 16), 0, fixed_x_max);

      *vals_I_ptr++ = img_I.GetPixelInterpFixed1616(fp_x, fp_y);
      *vals_I_x_ptr++ = I_x.GetPixelInterpFixed1616(fp_x, fp_y);
      *vals_I_y_ptr++ = I_y.GetPixelInterpFixed1616(fp_x, fp_y);
    }
  }
#else
  for (int y = 0; y < kPatchSize; ++y) {
    const float y_pos = Clip(src_top_real + y, 0.0f, real_y_max);

    for (int x = 0; x < kPatchSize; ++x) {
      const float x_pos = Clip(src_left_real + x, 0.0f, real_x_max);

      *vals_I_ptr++ = img_I.GetPixelInterp(x_pos, y_pos);
      *vals_I_x_ptr++ = I_x.GetPixelInterp(x_pos, y_pos);
      *vals_I_y_ptr++ = I_y.GetPixelInterp(x_pos, y_pos);
    }
  }
#endif

  // Compute the spatial gradient matrix about point p.
  float G[] = { 0, 0, 0, 0 };
  CalculateG(vals_I_x, vals_I_y, kFlowArraySize, G);

  // Find the inverse of G. If G is singular (textureless patch), this
  // point cannot be tracked reliably, so bail out.
  float G_inv[4];
  if (!Invert2x2(G, G_inv)) {
    return false;
  }

#if NORMALIZE
  // Frame-I patch statistics, used below to compensate for brightness and
  // contrast differences between the two frames.
  const float mean_I = ComputeMean(vals_I, kFlowArraySize);
  const float std_dev_I = ComputeStdDev(vals_I, kFlowArraySize, mean_I);
#endif

  // Iterate kNumIterations times or until we converge.
  for (int iteration = 0; iteration < kNumIterations; ++iteration) {
    // Get values for frame 2.
    float vals_J[kFlowArraySize];

    // Get the window around the destination point.
    const float left_real = p_x + g_x - kWindowSizeFloat;
    const float top_real = p_y + g_y - kWindowSizeFloat;
    float* vals_J_ptr = vals_J;
#if USE_FIXED_POINT_FLOW
    // The top-left sub-pixel is set for the current iteration (in 16:16
    // fixed). This is constant over one iteration.
    const int left_fixed = RealToFixed1616(left_real);
    const int top_fixed = RealToFixed1616(top_real);

    for (int win_y = 0; win_y < kPatchSize; ++win_y) {
      const int fp_y = Clip(top_fixed + (win_y << 16), 0, fixed_y_max);
      for (int win_x = 0; win_x < kPatchSize; ++win_x) {
        const int fp_x = Clip(left_fixed + (win_x << 16), 0, fixed_x_max);
        *vals_J_ptr++ = img_J.GetPixelInterpFixed1616(fp_x, fp_y);
      }
    }
#else
    for (int win_y = 0; win_y < kPatchSize; ++win_y) {
      const float y_pos = Clip(top_real + win_y, 0.0f, real_y_max);
      for (int win_x = 0; win_x < kPatchSize; ++win_x) {
        const float x_pos = Clip(left_real + win_x, 0.0f, real_x_max);
        *vals_J_ptr++ = img_J.GetPixelInterp(x_pos, y_pos);
      }
    }
#endif

#if NORMALIZE
    const float mean_J = ComputeMean(vals_J, kFlowArraySize);
    const float std_dev_J = ComputeStdDev(vals_J, kFlowArraySize, mean_J);

    // TODO(andrewharp): Probably better to completely detect and handle the
    // "corner case" where the patch is fully outside the image diagonally.
    // A zero std dev (e.g. fully clamped patch) falls back to ratio 1.
    const float std_dev_ratio = std_dev_J > 0.0f ? std_dev_I / std_dev_J : 1.0f;
#endif

    // Compute image mismatch vector b = sum(dI * grad_I) over the patch.
    float b_x = 0.0f;
    float b_y = 0.0f;

    vals_I_ptr = vals_I;
    vals_J_ptr = vals_J;
    vals_I_x_ptr = vals_I_x;
    vals_I_y_ptr = vals_I_y;

    for (int win_y = 0; win_y < kPatchSize; ++win_y) {
      for (int win_x = 0; win_x < kPatchSize; ++win_x) {
#if NORMALIZE
        // Normalized Image difference.
        const float dI =
            (*vals_I_ptr++ - mean_I) - (*vals_J_ptr++ - mean_J) * std_dev_ratio;
#else
        const float dI = *vals_I_ptr++ - *vals_J_ptr++;
#endif
        b_x += dI * *vals_I_x_ptr++;
        b_y += dI * *vals_I_y_ptr++;
      }
    }

    // Optical flow... solve n = G^-1 * b
    const float n_x = (G_inv[0] * b_x) + (G_inv[1] * b_y);
    const float n_y = (G_inv[2] * b_x) + (G_inv[3] * b_y);

    // Update best guess with residual displacement from this level and
    // iteration.
    g_x += n_x;
    g_y += n_y;

    // LOGV("Iteration %d: delta (%.3f, %.3f)", iteration, n_x, n_y);

    // Abort early if we're already below the threshold.
    if (Square(n_x) + Square(n_y) < Square(kTrackingAbortThreshold)) {
      break;
    }
  }  // Iteration.

  // Copy value back into output.
  *out_g_x = g_x;
  *out_g_y = g_y;
  return true;
}


// Pointwise flow using translational 2dof ESM.
//
// Like FindFlowAtPoint_LK, but uses gradients from BOTH frames (averaged
// Jacobian, integer arithmetic) and an additive brightness-offset model.
// out_g_x/out_g_y are in/out: initial guess in, refined flow out.
// Returns false only if the template patch around (p_x, p_y) cannot be
// extracted from frame I; running out of bounds during iteration merely
// stops refinement and still returns true.
bool OpticalFlow::FindFlowAtPoint_ESM(const Image<uint8>& img_I,
                                      const Image<uint8>& img_J,
                                      const Image<int32>& I_x,
                                      const Image<int32>& I_y,
                                      const Image<int32>& J_x,
                                      const Image<int32>& J_y,
                                      const float p_x,
                                      const float p_y,
                                      float* out_g_x,
                                      float* out_g_y) {
  float g_x = *out_g_x;
  float g_y = *out_g_y;
  // 1 / patch area, for averaging the brightness difference over the patch.
  const float area_inv = 1.0f / static_cast<float>(kFlowArraySize);

  // Get values for frame 1. They remain constant through the inner
  // iteration loop.
  uint8 vals_I[kFlowArraySize];
  uint8 vals_J[kFlowArraySize];
  int16 src_gradient_x[kFlowArraySize];
  int16 src_gradient_y[kFlowArraySize];

  // TODO(rspring): try out the IntegerPatchAlign() method once
  // the code for that is in ../common.
  const float wsize_float = static_cast<float>(kFlowIntegrationWindowSize);
  const int src_left_fixed = RealToFixed1616(p_x - wsize_float);
  const int src_top_fixed = RealToFixed1616(p_y - wsize_float);
  const int patch_size = 2 * kFlowIntegrationWindowSize + 1;

  // Create the keypoint template patch from a subpixel location.
  if (!img_I.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
                                             patch_size, patch_size, vals_I) ||
      !I_x.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
                                           patch_size, patch_size,
                                           src_gradient_x) ||
      !I_y.ExtractPatchAtSubpixelFixed1616(src_left_fixed, src_top_fixed,
                                           patch_size, patch_size,
                                           src_gradient_y)) {
    return false;
  }

  int bright_offset = 0;
  int sum_diff = 0;

  // The top-left sub-pixel is set for the current iteration (in 16:16 fixed).
  // This is constant over one iteration.
  int left_fixed = RealToFixed1616(p_x + g_x - wsize_float);
  int top_fixed = RealToFixed1616(p_y + g_y - wsize_float);

  // The truncated version gives the most top-left pixel that is used.
  int left_trunc = left_fixed >> 16;
  int top_trunc = top_fixed >> 16;

  // Compute an initial brightness offset: the mean intensity difference
  // between the target window in J and the template from I. Only done when
  // the whole (integer-aligned) window is safely inside J.
  if (kDoBrightnessNormalize &&
      left_trunc >= 0 && top_trunc >= 0 &&
      (left_trunc + patch_size) < img_J.width_less_one_ &&
      (top_trunc + patch_size) < img_J.height_less_one_) {
    int templ_index = 0;
    const uint8* j_row = img_J[top_trunc] + left_trunc;

    const int j_stride = img_J.stride();

    for (int y = 0; y < patch_size; ++y, j_row += j_stride) {
      for (int x = 0; x < patch_size; ++x) {
        sum_diff += static_cast<int>(j_row[x]) - vals_I[templ_index++];
      }
    }

    bright_offset = static_cast<int>(static_cast<float>(sum_diff) * area_inv);
  }

  // Iterate kNumIterations times or until we go out of image.
  for (int iteration = 0; iteration < kNumIterations; ++iteration) {
    // jtj accumulates the symmetric 2x2 J^T*J (xx, xy, yy); jtr is J^T*r.
    int jtj[3] = { 0, 0, 0 };
    int jtr[2] = { 0, 0 };
    sum_diff = 0;

    // Extract the target image values.
    // Extract the gradient from the target image patch and accumulate to
    // the gradient of the source image patch.
    if (!img_J.ExtractPatchAtSubpixelFixed1616(left_fixed, top_fixed,
                                               patch_size, patch_size,
                                               vals_J)) {
      break;
    }

    const uint8* templ_row = vals_I;
    const uint8* extract_row = vals_J;
    const int16* src_dx_row = src_gradient_x;
    const int16* src_dy_row = src_gradient_y;

    for (int y = 0; y < patch_size; ++y, templ_row += patch_size,
         src_dx_row += patch_size, src_dy_row += patch_size,
         extract_row += patch_size) {
      const int fp_y = top_fixed + (y << 16);
      for (int x = 0; x < patch_size; ++x) {
        const int fp_x = left_fixed + (x << 16);
        int32 target_dx = J_x.GetPixelInterpFixed1616(fp_x, fp_y);
        int32 target_dy = J_y.GetPixelInterpFixed1616(fp_x, fp_y);

        // Combine the two Jacobians.
        // Right-shift by one to account for the fact that we add
        // two Jacobians.
        int32 dx = (src_dx_row[x] + target_dx) >> 1;
        int32 dy = (src_dy_row[x] + target_dy) >> 1;

        // The current residual b - h(q) == extracted - (template + offset)
        int32 diff = static_cast<int32>(extract_row[x]) -
                     static_cast<int32>(templ_row[x]) -
                     bright_offset;

        jtj[0] += dx * dx;
        jtj[1] += dx * dy;
        jtj[2] += dy * dy;

        jtr[0] += dx * diff;
        jtr[1] += dy * diff;

        sum_diff += diff;
      }
    }

    const float jtr1_float = static_cast<float>(jtr[0]);
    const float jtr2_float = static_cast<float>(jtr[1]);

    // Add some baseline stability to the system.
    // (This also guarantees a strictly positive determinant below, since
    // by Cauchy-Schwarz jtj[1]^2 <= jtj[0] * jtj[2] before regularizing.)
    jtj[0] += kEsmRegularizer;
    jtj[2] += kEsmRegularizer;

    // 64-bit products to avoid overflowing the determinant.
    const int64 prod1 = static_cast<int64>(jtj[0]) * jtj[2];
    const int64 prod2 = static_cast<int64>(jtj[1]) * jtj[1];

    // One ESM step: g -= (J^T*J)^-1 * J^T*r, with the adjugate in jtj_1
    // and 1/det applied separately.
    const float jtj_1[4] = { static_cast<float>(jtj[2]),
                             static_cast<float>(-jtj[1]),
                             static_cast<float>(-jtj[1]),
                             static_cast<float>(jtj[0]) };
    const double det_inv = 1.0 / static_cast<double>(prod1 - prod2);

    g_x -= det_inv * (jtj_1[0] * jtr1_float + jtj_1[1] * jtr2_float);
    g_y -= det_inv * (jtj_1[2] * jtr1_float + jtj_1[3] * jtr2_float);

    if (kDoBrightnessNormalize) {
      // Refine the brightness offset with the mean residual (+0.5f rounds).
      bright_offset +=
          static_cast<int>(area_inv * static_cast<float>(sum_diff) + 0.5f);
    }

    // Update top left position.
    left_fixed = RealToFixed1616(p_x + g_x - wsize_float);
    top_fixed = RealToFixed1616(p_y + g_y - wsize_float);

    left_trunc = left_fixed >> 16;
    top_trunc = top_fixed >> 16;

    // Abort iterations if we go out of borders.
    if (left_trunc < 0 || top_trunc < 0 ||
        (left_trunc + patch_size) >= J_x.width_less_one_ ||
        (top_trunc + patch_size) >= J_y.height_less_one_) {
      break;
    }
  }  // Iteration.

  // Copy value back into output.
  *out_g_x = g_x;
  *out_g_y = g_y;
  return true;
}


// Computes flow for the point (u_x, u_y) at a single pyramid level, with
// the roles of the two frames optionally swapped (reverse_flow == true
// tracks from frame2_ back to frame1_) so one code path serves both the
// forward pass and the forward-backward verification pass.
// flow_x/flow_y are in/out in full-resolution coordinates; they are scaled
// down to the level's resolution for the solver and scaled back up after.
bool OpticalFlow::FindFlowAtPointReversible(
    const int level, const float u_x, const float u_y,
    const bool reverse_flow,
    float* flow_x, float* flow_y) const {
  const ImageData& frame_a = reverse_flow ? *frame2_ : *frame1_;
  const ImageData& frame_b = reverse_flow ? *frame1_ : *frame2_;

  // Images I (prev) and J (next).
  // NOTE(review): level * 2 maps this octave level into the sqrt(2)-spaced
  // pyramid (two sqrt(2) steps per octave) -- confirm against ImageData.
  const Image<uint8>& img_I = *frame_a.GetPyramidSqrt2Level(level * 2);
  const Image<uint8>& img_J = *frame_b.GetPyramidSqrt2Level(level * 2);

  // Computed gradients.
  const Image<int32>& I_x = *frame_a.GetSpatialX(level);
  const Image<int32>& I_y = *frame_a.GetSpatialY(level);
  const Image<int32>& J_x = *frame_b.GetSpatialX(level);
  const Image<int32>& J_y = *frame_b.GetSpatialY(level);

  // Shrink factor from original.
  const float shrink_factor = (1 << level);

  // Image position vector (p := u^l), scaled for this level.
  const float scaled_p_x = u_x / shrink_factor;
  const float scaled_p_y = u_y / shrink_factor;

  float scaled_flow_x = *flow_x / shrink_factor;
  float scaled_flow_y = *flow_y / shrink_factor;

  // LOGE("FindFlowAtPoint level %d: %5.2f, %5.2f (%5.2f, %5.2f)", level,
  //      scaled_p_x, scaled_p_y, &scaled_flow_x, &scaled_flow_y);

  // kUseEsm selects between the ESM solver (needs gradients of both
  // frames) and plain Lucas-Kanade at compile/config time.
  const bool success = kUseEsm ?
      FindFlowAtPoint_ESM(img_I, img_J, I_x, I_y, J_x, J_y,
                          scaled_p_x, scaled_p_y,
                          &scaled_flow_x, &scaled_flow_y) :
      FindFlowAtPoint_LK(img_I, img_J, I_x, I_y,
                         scaled_p_x, scaled_p_y,
                         &scaled_flow_x, &scaled_flow_y);

  // Convert the refined flow back to full-resolution coordinates.
  *flow_x = scaled_flow_x * shrink_factor;
  *flow_y = scaled_flow_y * shrink_factor;

  return success;
}


// Computes forward flow at one level and, if filter_by_fb_error is set,
// validates it by tracking the result point backwards: the flow is only
// accepted when forward + backward flow nearly cancel (discrepancy below
// kMaxForwardBackwardErrorAllowed * flow length).
// flow_x/flow_y are in/out; on a false return their contents are whatever
// the forward pass left in them and should not be trusted.
bool OpticalFlow::FindFlowAtPointSingleLevel(
    const int level,
    const float u_x, const float u_y,
    const bool filter_by_fb_error,
    float* flow_x, float* flow_y) const {
  if (!FindFlowAtPointReversible(level, u_x, u_y, false, flow_x, flow_y)) {
    return false;
  }

  if (filter_by_fb_error) {
    const float new_position_x = u_x + *flow_x;
    const float new_position_y = u_y + *flow_y;

    float reverse_flow_x = 0.0f;
    float reverse_flow_y = 0.0f;

    // Now find the backwards flow and confirm it lines up with the original
    // starting point.
    if (!FindFlowAtPointReversible(level, new_position_x, new_position_y,
                                   true,
                                   &reverse_flow_x, &reverse_flow_y)) {
      LOGE("Backward error!");
      return false;
    }

    // Ideal reverse flow is exactly -forward flow, so the sum measures the
    // forward-backward discrepancy.
    const float discrepancy_length =
        sqrtf(Square(*flow_x + reverse_flow_x) +
              Square(*flow_y + reverse_flow_y));

    const float flow_length = sqrtf(Square(*flow_x) + Square(*flow_y));

    // Relative threshold: allowed discrepancy scales with the flow length.
    return discrepancy_length <
        (kMaxForwardBackwardErrorAllowed * flow_length);
  }

  return true;
}


// An implementation of the Pyramidal Lucas-Kanade Optical Flow algorithm.
// See http://robots.stanford.edu/cs223b04/algo_tracking.pdf for details.
// Runs coarse-to-fine: starts at the coarsest usable level and refines the
// same full-resolution flow estimate (flow_x/flow_y, in/out) at each finer
// level. Fails fast if any level fails.
bool OpticalFlow::FindFlowAtPointPyramidal(const float u_x, const float u_y,
                                           const bool filter_by_fb_error,
                                           float* flow_x, float* flow_y) const {
  // Don't descend below the levels the flow cache already covers, but use
  // at least kMinNumPyramidLevelsToUseForAdjustment levels.
  const int max_level = MAX(kMinNumPyramidLevelsToUseForAdjustment,
                            kNumPyramidLevels - kNumCacheLevels);

  // For every level in the pyramid, update the coordinates of the best match.
  for (int l = max_level - 1; l >= 0; --l) {
    if (!FindFlowAtPointSingleLevel(l, u_x, u_y,
                                    filter_by_fb_error, flow_x, flow_y)) {
      return false;
    }
  }

  return true;
}

}  // namespace tf_tracking