1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Contains code for the DataProvider.
16
A DataProvider is a class which provides some predefined data types from some
source (TFRecord, etc.). The most basic function of a
data provider is the `get` operation, where one requests one or more types of
data, or 'items':
21
22  provider.get(items=['image', 'sentence', 'class'])
23
More concretely, a data provider (a subclass of DataProvider) returns a
single tensor for each requested item (data type):
26
27  provider = MyDataProvider(...)
28  image, sentence, clazz = provider.get(['image', 'sentence', 'class'])
29
30In this example, the provider `MyDataProvider` must know how to load each item.
31A data provider may be written in a way that the logic necessary to map from
32each item to tensor is completely encapsulated within the data_provider itself.
33"""
34
35from __future__ import absolute_import
36from __future__ import division
37from __future__ import print_function
38
39import abc
40
41
class DataProvider(object):
  """Maps a list of requested data items to tensors from a data source.

  Subclasses build a dictionary mapping item names to tensors and pass it,
  together with the dataset size, to this constructor; `get` then looks the
  requested items up in that dictionary. No assumption is made about the
  source of the data nor the mechanism for providing it.
  """
  # NOTE: the `__metaclass__` attribute only takes effect under Python 2;
  # Python 3 ignores it. No abstract methods are declared either way, so the
  # class is directly instantiable.
  __metaclass__ = abc.ABCMeta

  def __init__(self, items_to_tensors, num_samples):
    """Constructs the Data Provider.

    Args:
      items_to_tensors: a dictionary of names to tensors.
      num_samples: the number of samples in the dataset being provided.
    """
    self._items_to_tensors = items_to_tensors
    self._num_samples = num_samples

  def get(self, items):
    """Returns a list of tensors specified by the given list of items.

    The set of valid items is arbitrary: different data providers satisfy
    different lists of items. For example, a Pascal VOC provider might accept
    the items 'image' and 'semantics', whereas an NYUDepthV2 provider might
    accept 'image', 'depths' and 'normals'.

    Args:
      items: a list or tuple of strings, each of which indicates a particular
        data type.

    Returns:
      a list of tensors, whose length matches the length of `items`, where each
      tensor corresponds to the item at the same position.

    Raises:
      ValueError: if `items` is not a list or tuple, or if any of the items
        cannot be satisfied.
    """
    self._validate_items(items)
    return [self._items_to_tensors[item] for item in items]

  def list_items(self):
    """Returns the list of item names that can be provided by the data provider.

    Returns:
      a new list of item names (in dictionary order) that can be passed to
      `get`.
    """
    # Materialize as a list: on Python 3 `.keys()` is a live view, which
    # contradicts the documented return type and renders as "dict_keys([...])"
    # in error messages.
    return list(self._items_to_tensors.keys())

  def num_samples(self):
    """Returns the number of data samples in the dataset.

    Returns:
      the `num_samples` value given to the constructor (expected to be a
      positive whole number; it is not validated here).
    """
    return self._num_samples

  def _validate_items(self, items):
    """Verifies that each given item is a member of the list from `list_items`.

    Args:
      items: a list or tuple of strings.

    Raises:
      ValueError: if `items` is not a tuple or list or if any of the elements of
        `items` is not found in the list provided by `self.list_items()`.
    """
    if not isinstance(items, (list, tuple)):
      raise ValueError('items must be a list or tuple')

    # Go through `list_items()` (not the dict directly) so subclasses that
    # override it can restrict or extend the valid set.
    valid_items = self.list_items()
    for item in items:
      if item not in valid_items:
        raise ValueError('Item [%s] is invalid. Valid entries include: %s' %
                         (item, valid_items))
116