1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15// See docs in ../ops/image_ops.cc.
16#include <math.h>
17#include "tensorflow/core/framework/op_kernel.h"
18#include "tensorflow/core/framework/register_types.h"
19#include "tensorflow/core/framework/tensor.h"
20#include "tensorflow/core/framework/types.h"
21#include "tensorflow/core/kernels/bounds_check.h"
22#include "tensorflow/core/lib/random/simple_philox.h"
23#include "tensorflow/core/util/guarded_philox_random.h"
24
25using tensorflow::random::SimplePhilox;
26
27namespace tensorflow {
28namespace {
29
// A minimal axis-aligned rectangle with integer corners (min_x_, min_y_) and
// (max_x_, max_y_). It supplies emptiness, area and intersection queries.
// The corner fields are public and are read and written directly by callers.
class Rectangle {
 public:
  // Builds the degenerate rectangle (0, 0, 0, 0), whose area is zero.
  Rectangle() { Set(0, 0, 0, 0); }
  Rectangle(int xmin, int ymin, int xmax, int ymax) {
    Set(xmin, ymin, xmax, ymax);
  }

  // Overwrites all four corners. No ordering of min/max is enforced.
  void Set(int xmin, int ymin, int xmax, int ymax) {
    min_x_ = xmin;
    min_y_ = ymin;
    max_x_ = xmax;
    max_y_ = ymax;
  }

  // True when either minimum lies strictly beyond its maximum. A rectangle
  // with zero extent (min == max) is NOT reported as empty.
  bool IsEmpty() const { return min_x_ > max_x_ || min_y_ > max_y_; }

  // Returns width * height as a float. Each extent is converted to float
  // before multiplying: the product of two int extents overflows int (which
  // is undefined behavior) once it exceeds INT_MAX, e.g. for a 50000 x 50000
  // rectangle.
  float Area() const {
    return static_cast<float>(max_x_ - min_x_) *
           static_cast<float>(max_y_ - min_y_);
  }

  // Returns the overlap of this rectangle with `r`, or the default
  // (0, 0, 0, 0) rectangle when the two do not overlap.
  Rectangle Intersect(const Rectangle& r) const {
    const int pmin_x = std::max(min_x_, r.min_x_);
    const int pmin_y = std::max(min_y_, r.min_y_);
    const int pmax_x = std::min(max_x_, r.max_x_);
    const int pmax_y = std::min(max_y_, r.max_y_);

    if (pmin_x > pmax_x || pmin_y > pmax_y) {
      return Rectangle();
    } else {
      return Rectangle(pmin_x, pmin_y, pmax_x, pmax_y);
    }
  }

  int min_x_;
  int min_y_;
  int max_x_;
  int max_y_;
};
68
69// Determine if the supplied cropping box covers a sufficient fraction of the
70// the supplied bounding boxes.
71bool SatisfiesOverlapConstraints(const Rectangle& crop,
72                                 float minimum_object_covered,
73                                 const std::vector<Rectangle>& bounding_boxes) {
74  // Reject any bounding box which contains no pixels.
75  const float kMinArea = 1.0;
76  if (crop.Area() < kMinArea) {
77    return false;
78  }
79
80  // Loop through all objects and determine if the proposed cropping box covers
81  // a sufficient fraction of one of the supplied bounding boxes.
82  bool is_object_covered = false;
83  for (const auto& bbox : bounding_boxes) {
84    const float object_area = bbox.Area();
85    if (object_area < kMinArea) {
86      continue;
87    }
88
89    const float object_covered = crop.Intersect(bbox).Area() / object_area;
90
91    if (object_covered >= minimum_object_covered) {
92      is_object_covered = true;
93      break;
94    }
95  }
96  return is_object_covered;
97}
98
99// Generate a random crop within the rectangle
100// (0, 0, original_width, original_height).
101// The minimum area of the crop will be between
102//   min_relative_crop_area * orig_width * orig_height
103// and
104//   max_relative_crop_area * orig_width * orig_height
105// such that its width = round(aspect_ratio * height).
106// The diameter of the generated rectangle will be uniformly distributed between
107// its minimum and maximum size. The center of the rectangle will be distributed
108// uniformly within the source rectangle. The function returns false if the
109// rectangle could not be generated with the given constraints.
110bool GenerateRandomCrop(int original_width, int original_height,
111                        float min_relative_crop_area,
112                        float max_relative_crop_area, float aspect_ratio,
113                        SimplePhilox* random, Rectangle* crop_rect) {
114  if (max_relative_crop_area <= 0.0 || aspect_ratio <= 0.0 ||
115      original_width <= 0 || original_height <= 0 ||
116      min_relative_crop_area > max_relative_crop_area) {
117    return false;
118  }
119
120  const float min_area =
121      min_relative_crop_area * original_width * original_height;
122  const float max_area =
123      max_relative_crop_area * original_width * original_height;
124
125  int height = static_cast<int>(lrintf(sqrt(min_area / aspect_ratio)));
126  int max_height = static_cast<int>(lrintf(sqrt(max_area / aspect_ratio)));
127
128  if (lrintf(max_height * aspect_ratio) > original_width) {
129    // We must find the smallest max_height satisfying
130    // round(max_height * aspect_ratio) <= original_width:
131    const float kEps = 0.0000001;
132    max_height = static_cast<int>((original_width + 0.5 - kEps) / aspect_ratio);
133  }
134
135  if (max_height > original_height) {
136    max_height = original_height;
137  }
138
139  if (height >= max_height) {
140    height = max_height;
141  }
142
143  if (height < max_height) {
144    // We need to generate a random number in the closed range
145    // [0, max_height - height].
146    height += random->Uniform(max_height - height + 1);
147  }
148  int width = static_cast<int>(lrintf(height * aspect_ratio));
149  DCHECK_LE(width, original_width);
150
151  // Let us not fail if rounding error causes the area to be
152  // outside the constraints.
153  // Try first with a slightly bigger rectangle first.
154  float area = static_cast<float>(width * height);
155  if (area < min_area) {
156    height += 1;
157    width = static_cast<int>(lrintf(height * aspect_ratio));
158    area = width * height;
159  }
160
161  // Let us not fail if rounding error causes the area to be
162  // outside the constraints.
163  // Try first with a slightly smaller rectangle first.
164  if (area > max_area) {
165    height -= 1;
166    width = static_cast<int>(lrintf(height * aspect_ratio));
167    area = width * height;
168  }
169
170  // Now, we explored all options to rectify small rounding errors.
171  // It seems the constraints can't be satisfied: return false.
172  if (area < min_area || area > max_area || width > original_width ||
173      height > original_height || width <= 0 || height <= 0) {
174    return false;
175  }
176
177  int y = 0;
178  if (height < original_height) {
179    y = random->Uniform(original_height - height);
180  }
181  int x = 0;
182  if (width < original_width) {
183    x = random->Uniform(original_width - width);
184  }
185
186  crop_rect->min_x_ = x;
187  crop_rect->min_y_ = y;
188  crop_rect->max_x_ = x + width;
189  crop_rect->max_y_ = y + height;
190  return true;
191}
192}  // namespace
193
194template <typename T>
195class SampleDistortedBoundingBoxV2Op : public OpKernel {
196 public:
197  explicit SampleDistortedBoundingBoxV2Op(OpKernelConstruction* context)
198      : OpKernel(context) {
199    OP_REQUIRES_OK(context, generator_.Init(context));
200
201    if (context->num_inputs() == 2) {
202      OP_REQUIRES_OK(context, context->GetAttr("min_object_covered",
203                                               &min_object_covered_));
204      OP_REQUIRES(
205          context, min_object_covered_ >= 0,
206          errors::InvalidArgument("Min object covered must be non-negative: ",
207                                  min_object_covered_));
208    }
209
210    OP_REQUIRES_OK(context, context->GetAttr("use_image_if_no_bounding_boxes",
211                                             &use_image_if_no_bounding_boxes_));
212
213    OP_REQUIRES_OK(
214        context, context->GetAttr("aspect_ratio_range", &aspect_ratio_range_));
215    OP_REQUIRES(context, aspect_ratio_range_.size() == 2,
216                errors::InvalidArgument(
217                    "Aspect ratio range field must specify 2 dimensions"));
218
219    OP_REQUIRES(
220        context, aspect_ratio_range_[0] > 0 && aspect_ratio_range_[1] > 0,
221        errors::InvalidArgument("Aspect ratio range must be non-negative: [",
222                                aspect_ratio_range_[0], ", ",
223                                aspect_ratio_range_[1], "]"));
224
225    OP_REQUIRES_OK(context, context->GetAttr("area_range", &area_range_));
226    OP_REQUIRES(
227        context, area_range_.size() == 2,
228        errors::InvalidArgument("Area range field must specify 2 dimensions"));
229
230    OP_REQUIRES(
231        context, area_range_[0] > 0 && area_range_[1] > 0,
232        errors::InvalidArgument("Area range must be non-negative: [",
233                                area_range_[0], ", ", area_range_[1], "]"));
234
235    OP_REQUIRES(context, area_range_[0] <= 1 && area_range_[1] <= 1,
236                errors::InvalidArgument(
237                    "Area range must be less then or equal to 1.0: [",
238                    area_range_[0], ", ", area_range_[1], "]"));
239
240    OP_REQUIRES_OK(context, context->GetAttr("max_attempts", &max_attempts_));
241    OP_REQUIRES(context, max_attempts_ > 0,
242                errors::InvalidArgument("Max attempts must be non-negative: ",
243                                        max_attempts_));
244  }
245
246  void Compute(OpKernelContext* context) override {
247    const Tensor& image_size = context->input(0);
248
249    OP_REQUIRES(context, image_size.dims() == 1,
250                errors::InvalidArgument("image_size must be 1-dimensional",
251                                        image_size.shape().DebugString()));
252    OP_REQUIRES(context, image_size.dim_size(0) == 3,
253                errors::InvalidArgument("image_size must contain 3 elements",
254                                        image_size.shape().DebugString()));
255
256    // Note image_size_data(2) is the depth and unused.
257    const uint64 height_raw = internal::SubtleMustCopy(image_size.flat<T>()(0));
258    const uint64 width_raw = internal::SubtleMustCopy(image_size.flat<T>()(1));
259    OP_REQUIRES(context,
260                FastBoundsCheck(height_raw, std::numeric_limits<int32>::max()),
261                errors::InvalidArgument("image height cannot be >= int32 max"));
262    OP_REQUIRES(context,
263                FastBoundsCheck(width_raw, std::numeric_limits<int32>::max()),
264                errors::InvalidArgument("image width cannot be >= int32 max"));
265    const int32 height = static_cast<int32>(height_raw);
266    const int32 width = static_cast<int32>(width_raw);
267
268    // Ensure that the supplied bounding boxes are sane and convert them to
269    // Rectangles.
270    const Tensor& input_boxes = context->input(1);
271    OP_REQUIRES(context, input_boxes.dims() == 3,
272                errors::InvalidArgument("input boxes must be 3-dimensional "
273                                        "[batch, num_boxes, coords]: ",
274                                        input_boxes.shape().DebugString()));
275    OP_REQUIRES(context, input_boxes.dim_size(input_boxes.dims() - 1) == 4,
276                errors::InvalidArgument(
277                    "bounding boxes must have shape [4] or [*, 4], got ",
278                    input_boxes.shape().DebugString()));
279
280    float min_object_covered_val = 0.0;
281    if (context->num_inputs() == 3) {
282      const Tensor& min_object_covered = context->input(2);
283
284      OP_REQUIRES(
285          context, TensorShapeUtils::IsScalar(min_object_covered.shape()),
286          errors::InvalidArgument("min_object_covered must be 0-D, got shape ",
287                                  min_object_covered.shape().DebugString()));
288
289      min_object_covered_val = min_object_covered.scalar<float>()();
290
291      OP_REQUIRES(
292          context, min_object_covered_val >= 0,
293          errors::InvalidArgument("Min object covered must be non-negative: ",
294                                  min_object_covered_val));
295    } else {
296      min_object_covered_val = min_object_covered_;
297    }
298
299    std::vector<Rectangle> bounding_boxes;
300    if (input_boxes.NumElements() > 0) {
301      TTypes<float>::ConstMatrix boxes = input_boxes.flat_inner_dims<float>();
302      for (int b = 0; b < boxes.dimension(0); ++b) {
303        for (int i = 0; i < 4; ++i) {
304          OP_REQUIRES(
305              context, boxes(b, i) >= 0.0 && boxes(b, i) <= 1.0,
306              errors::InvalidArgument("All bounding box coordinates must "
307                                      "be in [0.0, 1.0]: ",
308                                      boxes(b, i)));
309        }
310
311        const int32 x_min = static_cast<int32>(boxes(b, 1) * width);
312        const int32 y_min = static_cast<int32>(boxes(b, 0) * height);
313        const int32 x_max = static_cast<int32>(boxes(b, 3) * width);
314        const int32 y_max = static_cast<int32>(boxes(b, 2) * height);
315
316        bounding_boxes.push_back(Rectangle(x_min, y_min, x_max, y_max));
317      }
318    }
319
320    // Insert the entire image if no bounding boxes are supplied.
321    const Rectangle image_rect(0, 0, width, height);
322    if (bounding_boxes.empty()) {
323      OP_REQUIRES(context, use_image_if_no_bounding_boxes_,
324                  errors::InvalidArgument(
325                      "No bounding boxes provided as input. One must "
326                      "enable use_image_if_no_bounding_boxes if you wish "
327                      "to not provide any bounding boxes."));
328      bounding_boxes.push_back(image_rect);
329    }
330
331    const float min_sample_area = area_range_[0];
332    const float max_sample_area = area_range_[1];
333    const float min_sample_aspect_ratio = aspect_ratio_range_[0];
334    const float max_sample_aspect_ratio = aspect_ratio_range_[1];
335
336    auto local_gen = generator_.ReserveSamples32(4 * max_attempts_);
337    random::SimplePhilox random(&local_gen);
338
339    Rectangle crop_rect;
340    bool sample_generated = false;
341    for (int i = 0; i < max_attempts_; ++i) {
342      const float sample_aspect_ratio =
343          random.RandFloat() *
344              (max_sample_aspect_ratio - min_sample_aspect_ratio) +
345          min_sample_aspect_ratio;
346
347      if (GenerateRandomCrop(width, height, min_sample_area, max_sample_area,
348                             sample_aspect_ratio, &random, &crop_rect)) {
349        if (SatisfiesOverlapConstraints(crop_rect, min_object_covered_val,
350                                        bounding_boxes)) {
351          sample_generated = true;
352          break;
353        }
354      }
355    }
356
357    if (!sample_generated) {
358      crop_rect = image_rect;
359    }
360
361    // Determine the cropping parameters from the bounding box.
362    const int target_width = crop_rect.max_x_ - crop_rect.min_x_;
363    const int target_height = crop_rect.max_y_ - crop_rect.min_y_;
364
365    const int offset_width = crop_rect.min_x_;
366    const int offset_height = crop_rect.min_y_;
367
368    // Ensure that the bounding box fits in the image dimensions.
369    OP_REQUIRES(context, width >= target_width + offset_width,
370                errors::FailedPrecondition(
371                    "width must be > target_width + offset_width: ", width,
372                    "vs ", target_width, " + ", offset_width));
373    OP_REQUIRES(context, height >= target_height + offset_height,
374                errors::FailedPrecondition(
375                    "height must be >= target_height: height = ", height, "vs ",
376                    target_height, " + ", offset_height));
377
378    // Create two vectors, each 3 elements, to provide as arguments to Slice.
379    // See Slice() operation for details.
380    Tensor* begin = nullptr;
381    OP_REQUIRES_OK(context,
382                   context->allocate_output(0, TensorShape({3}), &begin));
383    Tensor* size = nullptr;
384    OP_REQUIRES_OK(context,
385                   context->allocate_output(1, TensorShape({3}), &size));
386    Tensor* bboxes = nullptr;
387    OP_REQUIRES_OK(
388        context, context->allocate_output(2, TensorShape({1, 1, 4}), &bboxes));
389
390    typename TTypes<T, 1>::Tensor begin_data(begin->tensor<T, 1>());
391    typename TTypes<T, 1>::Tensor size_data(size->tensor<T, 1>());
392    TTypes<float, 3>::Tensor bboxes_data = bboxes->tensor<float, 3>();
393
394    begin_data(0) = T(offset_height);
395    size_data(0) = T(target_height);
396
397    begin_data(1) = T(offset_width);
398    size_data(1) = T(target_width);
399
400    bboxes_data(0, 0, 0) =
401        static_cast<float>(crop_rect.min_y_) / static_cast<float>(height);
402    bboxes_data(0, 0, 1) =
403        static_cast<float>(crop_rect.min_x_) / static_cast<float>(width);
404    bboxes_data(0, 0, 2) =
405        static_cast<float>(crop_rect.max_y_) / static_cast<float>(height);
406    bboxes_data(0, 0, 3) =
407        static_cast<float>(crop_rect.max_x_) / static_cast<float>(width);
408
409    // Retain all of the channels.
410    begin_data(2) = T(0);
411    size_data(2) = T(-1);
412  }
413
414 private:
415  GuardedPhiloxRandom generator_;
416  int32 max_attempts_;
417  std::vector<float> area_range_;
418  std::vector<float> aspect_ratio_range_;
419  float min_object_covered_;
420  bool use_image_if_no_bounding_boxes_;
421};
422
423#define REGISTER_KERNELS(type)                                  \
424  REGISTER_KERNEL_BUILDER(Name("SampleDistortedBoundingBox")    \
425                              .Device(DEVICE_CPU)               \
426                              .TypeConstraint<type>("T"),       \
427                          SampleDistortedBoundingBoxV2Op<type>) \
428  REGISTER_KERNEL_BUILDER(Name("SampleDistortedBoundingBoxV2")  \
429                              .Device(DEVICE_CPU)               \
430                              .TypeConstraint<type>("T"),       \
431                          SampleDistortedBoundingBoxV2Op<type>)
432
433TF_CALL_INTEGRAL_TYPES(REGISTER_KERNELS);
434#undef REGISTER_KERNELS
435
436}  // namespace tensorflow
437