1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/lib/io/two_level_iterator.h"
17
18#include "tensorflow/core/lib/io/block.h"
19#include "tensorflow/core/lib/io/format.h"
20#include "tensorflow/core/lib/io/iterator.h"
21#include "tensorflow/core/lib/io/table.h"
22
23namespace tensorflow {
24namespace table {
25
26namespace {
27
28typedef Iterator* (*BlockFunction)(void*, const StringPiece&);
29
30class TwoLevelIterator : public Iterator {
31 public:
32  TwoLevelIterator(Iterator* index_iter, BlockFunction block_function,
33                   void* arg);
34
35  ~TwoLevelIterator() override;
36
37  void Seek(const StringPiece& target) override;
38  void SeekToFirst() override;
39  void Next() override;
40
41  bool Valid() const override {
42    return (data_iter_ == nullptr) ? false : data_iter_->Valid();
43  }
44  StringPiece key() const override {
45    assert(Valid());
46    return data_iter_->key();
47  }
48  StringPiece value() const override {
49    assert(Valid());
50    return data_iter_->value();
51  }
52  Status status() const override {
53    // It'd be nice if status() returned a const Status& instead of a
54    // Status
55    if (!index_iter_->status().ok()) {
56      return index_iter_->status();
57    } else if (data_iter_ != nullptr && !data_iter_->status().ok()) {
58      return data_iter_->status();
59    } else {
60      return status_;
61    }
62  }
63
64 private:
65  void SaveError(const Status& s) {
66    if (status_.ok() && !s.ok()) status_ = s;
67  }
68  void SkipEmptyDataBlocksForward();
69  void SetDataIterator(Iterator* data_iter);
70  void InitDataBlock();
71
72  BlockFunction block_function_;
73  void* arg_;
74  Status status_;
75  Iterator* index_iter_;
76  Iterator* data_iter_;  // May be NULL
77  // If data_iter_ is non-NULL, then "data_block_handle_" holds the
78  // "index_value" passed to block_function_ to create the data_iter_.
79  string data_block_handle_;
80};
81
82TwoLevelIterator::TwoLevelIterator(Iterator* index_iter,
83                                   BlockFunction block_function, void* arg)
84    : block_function_(block_function),
85      arg_(arg),
86      index_iter_(index_iter),
87      data_iter_(nullptr) {}
88
89TwoLevelIterator::~TwoLevelIterator() {
90  delete index_iter_;
91  delete data_iter_;
92}
93
94void TwoLevelIterator::Seek(const StringPiece& target) {
95  index_iter_->Seek(target);
96  InitDataBlock();
97  if (data_iter_ != nullptr) data_iter_->Seek(target);
98  SkipEmptyDataBlocksForward();
99}
100
101void TwoLevelIterator::SeekToFirst() {
102  index_iter_->SeekToFirst();
103  InitDataBlock();
104  if (data_iter_ != nullptr) data_iter_->SeekToFirst();
105  SkipEmptyDataBlocksForward();
106}
107
108void TwoLevelIterator::Next() {
109  assert(Valid());
110  data_iter_->Next();
111  SkipEmptyDataBlocksForward();
112}
113
114void TwoLevelIterator::SkipEmptyDataBlocksForward() {
115  while (data_iter_ == nullptr || !data_iter_->Valid()) {
116    // Move to next block
117    if (!index_iter_->Valid()) {
118      SetDataIterator(nullptr);
119      return;
120    }
121    index_iter_->Next();
122    InitDataBlock();
123    if (data_iter_ != nullptr) data_iter_->SeekToFirst();
124  }
125}
126
127void TwoLevelIterator::SetDataIterator(Iterator* data_iter) {
128  if (data_iter_ != nullptr) {
129    SaveError(data_iter_->status());
130    delete data_iter_;
131  }
132  data_iter_ = data_iter;
133}
134
135void TwoLevelIterator::InitDataBlock() {
136  if (!index_iter_->Valid()) {
137    SetDataIterator(nullptr);
138  } else {
139    StringPiece handle = index_iter_->value();
140    if (data_iter_ != nullptr && handle.compare(data_block_handle_) == 0) {
141      // data_iter_ is already constructed with this iterator, so
142      // no need to change anything
143    } else {
144      Iterator* iter = (*block_function_)(arg_, handle);
145      data_block_handle_.assign(handle.data(), handle.size());
146      SetDataIterator(iter);
147    }
148  }
149}
150
151}  // namespace
152
153Iterator* NewTwoLevelIterator(Iterator* index_iter,
154                              BlockFunction block_function, void* arg) {
155  return new TwoLevelIterator(index_iter, block_function, arg);
156}
157
158}  // namespace table
159}  // namespace tensorflow
160