1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Methods to allow dict of numpy arrays."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23import numpy as np
24from six import string_types
25
26from tensorflow.python.estimator.inputs.queues import feeding_functions
27from tensorflow.python.util.tf_export import tf_export
28
# Base key under which a non-dict target `y` is temporarily stored inside the
# `features` dict before feeding. `_get_unique_target_key` extends it with
# suffixes when a feature already uses this name.
_TARGET_KEY = '__target_key__'
32
33
def _get_unique_target_key(features):
  """Chooses a key under which the target can be stored inside `features`.

  Callers of `input_fn` supply `features` (a dict of numpy arrays) and a
  separate `target`, but the underlying feeding module consumes one single dict
  of numpy arrays. The target is therefore packed into the features dict before
  feeding and unpacked again afterwards; this helper picks a key for it that is
  not already used by `features`.

  Args:
    features: OrderedDict of numpy arrays; only membership of keys is checked.

  Returns:
    `_TARGET_KEY` with as many '_n' suffixes appended as needed for the result
    to be absent from `features`.
  """
  candidate = _TARGET_KEY
  while True:
    if candidate not in features:
      return candidate
    # Collision with a real feature name: lengthen the key and try again.
    candidate = candidate + '_n'
55
56
def _validate_and_convert_features(x):
  """Checks the type of `x` and returns its contents as a new OrderedDict.

  The returned dict is a shallow copy: the arrays themselves are shared with
  the caller, but adding keys to the result leaves the caller's `x` untouched.

  Args:
    x: numpy array object or dict of numpy array objects. If an array,
      the array will be treated as a single feature.

  Returns:
    An OrderedDict. For a dict `x` it holds the same items ordered by key; for
    an array `x` it holds that array under the key '__direct_np_input__'.

  Raises:
    ValueError: if x is an empty dict or an empty array.
    TypeError: if x is neither a dict nor a numpy array.
  """
  if isinstance(x, dict):
    if not x:
      raise ValueError('x cannot be an empty dict')
    # Ordering by key makes iteration deterministic no matter how the caller
    # built the dict.
    return collections.OrderedDict(
        sorted(x.items(), key=lambda pair: pair[0]))

  if isinstance(x, np.ndarray):
    if x.size == 0:
      raise ValueError('x cannot be an empty array')
    # Wrap the lone array so downstream code can handle it like a feature dict.
    return collections.OrderedDict([('__direct_np_input__', x)])

  raise TypeError(
      'x must be a dict or array; got {}'.format(type(x).__name__))
88
89
@tf_export('estimator.inputs.numpy_input_fn')
def numpy_input_fn(x,
                   y=None,
                   batch_size=128,
                   num_epochs=1,
                   shuffle=None,
                   queue_capacity=1000,
                   num_threads=1):
  """Returns input function that would feed dict of numpy arrays into the model.

  This returns a function outputting `features` and `targets` based on the dict
  of numpy arrays. The dict `features` has the same keys as the `x`. The dict
  `targets` has the same keys as the `y` if `y` is a dict.

  Example:

  ```python
  age = np.arange(4) * 1.0
  height = np.arange(32, 36)
  x = {'age': age, 'height': height}
  y = np.arange(-32, -28)

  input_fn = tf.estimator.inputs.numpy_input_fn(
      x, y, batch_size=2, shuffle=False, num_epochs=1)
  ```

  Args:
    x: numpy array object or dict of numpy array objects. If an array,
      the array will be treated as a single feature.
    y: numpy array object or dict of numpy array object. `None` if absent.
    batch_size: Integer, size of batches to return.
    num_epochs: Integer, number of epochs to iterate over data. If `None` will
      run forever.
    shuffle: Boolean, if True shuffles the queue. Avoid shuffle at prediction
      time.
    queue_capacity: Integer, size of queue to accumulate.
    num_threads: Integer, number of threads used for reading and enqueueing. In
      order to have predictable and repeatable order of reading and enqueueing,
      such as in prediction and evaluation mode, `num_threads` should be 1.

  Returns:
    Function, that has signature of ()->(dict of `features`, `targets`). If `x`
    is an array, `features` is a single `Tensor` rather than a dict. If `y` is
    a dict, `targets` is a dict with the same keys. If `y` is `None`, the
    function returns only `features`.

  Raises:
    TypeError: if `shuffle` is not bool (raised immediately).

    The following are raised when the returned input function is called:

    ValueError: if the arrays in `x` and `y` do not all have the same length
      (size of the first dimension), or if any of them is zero-dimensional.
    ValueError: if duplicate keys are in both `x` and `y` when `y` is a dict.
    ValueError: if `x` is an empty dict or an empty array, or `y` is an empty
      dict.
    TypeError: if `x` is not a dict or array.
  """
  if not isinstance(shuffle, bool):
    raise TypeError('shuffle must be explicitly set as boolean; '
                    'got {}'.format(shuffle))

  def input_fn():
    """Numpy input function."""

    # Record the kind of input once. `x` itself should not be consulted after
    # conversion to ordered_dict_data, because it may be a dict or an array.
    x_is_array = isinstance(x, np.ndarray)
    ordered_dict_data = _validate_and_convert_features(x)

    # Snapshot the keys as a list: in Python 3 `keys()` is a live view, and the
    # target arrays are added to the same dict below.
    feature_keys = list(ordered_dict_data.keys())

    if y is None:
      target_keys = None
    elif isinstance(y, dict):
      if not y:
        raise ValueError('y cannot be empty dict, use None instead.')

      ordered_dict_y = collections.OrderedDict(
          sorted(y.items(), key=lambda t: t[0]))
      target_keys = list(ordered_dict_y.keys())

      duplicate_keys = set(feature_keys).intersection(set(target_keys))
      if duplicate_keys:
        raise ValueError('{} duplicate keys are found in both x and y: '
                         '{}'.format(len(duplicate_keys), duplicate_keys))

      ordered_dict_data.update(ordered_dict_y)
    else:
      # A single target array is stored under a key that cannot collide with
      # any feature key; `target_keys` is then that one string key.
      target_keys = _get_unique_target_key(ordered_dict_data)
      ordered_dict_data[target_keys] = y

    def describe_shapes():
      """Returns (shapes of x, shapes of y) for use in error messages."""
      shape_dict_of_x = {k: ordered_dict_data[k].shape for k in feature_keys}
      if target_keys is None:
        shape_of_y = None
      elif isinstance(target_keys, string_types):
        shape_of_y = y.shape
      else:
        shape_of_y = {k: ordered_dict_data[k].shape for k in target_keys}
      return shape_dict_of_x, shape_of_y

    # Every array is indexed along its first axis below, so a 0-d array would
    # otherwise fail with an opaque IndexError on `shape[0]`.
    if any(not v.shape for v in ordered_dict_data.values()):
      raise ValueError('All elements in x and y must be arrays with at least '
                       'one dimension.\n'
                       'Shapes in x: {}\n'
                       'Shapes in y: {}\n'.format(*describe_shapes()))

    if len(set(v.shape[0] for v in ordered_dict_data.values())) != 1:
      raise ValueError('Length of tensors in x and y is mismatched. All '
                       'elements in x and y must have the same length.\n'
                       'Shapes in x: {}\n'
                       'Shapes in y: {}\n'.format(*describe_shapes()))

    queue = feeding_functions._enqueue_data(  # pylint: disable=protected-access
        ordered_dict_data,
        queue_capacity,
        shuffle=shuffle,
        num_threads=num_threads,
        enqueue_size=batch_size,
        num_epochs=num_epochs)

    # With a finite `num_epochs` the data runs out. `dequeue_up_to` can then
    # return a final partial batch, whereas `dequeue_many` always requires
    # exactly `batch_size` elements.
    batch = (
        queue.dequeue_many(batch_size)
        if num_epochs is None else queue.dequeue_up_to(batch_size))

    # Remove the first `Tensor` in `batch`, which is the row number.
    if batch:
      batch.pop(0)

    if x_is_array:
      # Return as the same type as original array.
      features = batch[0]
    else:
      # Return as the original dict type; features come first in the queue.
      features = dict(zip(feature_keys, batch[:len(feature_keys)]))

    if target_keys is None:
      # TODO(martinwicke), return consistent result
      return features
    elif isinstance(target_keys, string_types):
      # The single target was appended last to ordered_dict_data.
      target = batch[-1]
      return features, target
    else:
      target = dict(zip(target_keys, batch[-len(target_keys):]))
      return features, target

  return input_fn
224