1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementations of different data feeders to provide data for TF trainer."""
16
17# TODO(ipolosukhin): Replace this module with feed-dict queue runners & queues.
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import itertools
24import math
25
26import numpy as np
27import six
28from six.moves import xrange  # pylint: disable=redefined-builtin
29
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import tensor_util
32from tensorflow.python.ops import array_ops
33from tensorflow.python.platform import tf_logging as logging
34
35# pylint: disable=g-multiple-import,g-bad-import-order
36from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels
37from .dask_io import HAS_DASK, extract_dask_data, extract_dask_labels
38
39# pylint: enable=g-multiple-import,g-bad-import-order
40
41
42def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size=None):
43  """Returns shape for input and output of the data feeder."""
44  x_is_dict, y_is_dict = isinstance(
45      x_shape, dict), y_shape is not None and isinstance(y_shape, dict)
46  if y_is_dict and n_classes is not None:
47    assert isinstance(n_classes, dict)
48
49  if batch_size is None:
50    batch_size = list(x_shape.values())[0][0] if x_is_dict else x_shape[0]
51  elif batch_size <= 0:
52    raise ValueError('Invalid batch_size %d.' % batch_size)
53
54  if x_is_dict:
55    input_shape = {}
56    for k, v in list(x_shape.items()):
57      input_shape[k] = [batch_size] + (list(v[1:]) if len(v) > 1 else [1])
58  else:
59    x_shape = list(x_shape[1:]) if len(x_shape) > 1 else [1]
60    input_shape = [batch_size] + x_shape
61
62  if y_shape is None:
63    return input_shape, None, batch_size
64
65  def out_el_shape(out_shape, num_classes):
66    out_shape = list(out_shape[1:]) if len(out_shape) > 1 else []
67    # Skip first dimension if it is 1.
68    if out_shape and out_shape[0] == 1:
69      out_shape = out_shape[1:]
70    if num_classes is not None and num_classes > 1:
71      return [batch_size] + out_shape + [num_classes]
72    else:
73      return [batch_size] + out_shape
74
75  if not y_is_dict:
76    output_shape = out_el_shape(y_shape, n_classes)
77  else:
78    output_shape = dict([(k,
79                          out_el_shape(v, n_classes[k]
80                                       if n_classes is not None and
81                                       k in n_classes else None))
82                         for k, v in list(y_shape.items())])
83
84  return input_shape, output_shape, batch_size
85
86
def _data_type_filter(x, y):
  """Runs `x` and `y` through the Dask and pandas extractors when available.

  Args:
    x: Feature data.
    y: Label data, or `None`; `None` is passed through untouched.

  Returns:
    Tuple `(x, y)` after the enabled extractors have been applied, Dask first
    and pandas second.
  """
  if HAS_DASK:
    x = extract_dask_data(x)
    y = extract_dask_labels(y) if y is not None else y
  if HAS_PANDAS:
    x = extract_pandas_data(x)
    y = extract_pandas_labels(y) if y is not None else y
  return x, y
98
99
100def _is_iterable(x):
101  return hasattr(x, 'next') or hasattr(x, '__next__')
102
103
def setup_train_data_feeder(x,
                            y,
                            n_classes,
                            batch_size=None,
                            shuffle=True,
                            epochs=None):
  """Creates the data feeder that matches the type of `x` and `y`.

  Iterator inputs get a `StreamingDataFeeder`; Dask series / data frames get a
  `DaskDataFeeder` (only when Dask is available); everything else gets a
  `DataFeeder`.

  Args:
    x: numpy, pandas or Dask matrix or dictionary of aforementioned. Also
      supports iterables.
    y: numpy, pandas or Dask array or dictionary of aforementioned. Also
      supports iterables.
    n_classes: number of classes. Must be None or same type as y. In case `y`
      is a `dict` (or an iterable which returns dicts), `n_classes[key]` is the
      number of classes for `y[key]`.
    batch_size: size to split data into parts. Must be >= 1.
    shuffle: Whether to shuffle the inputs. Not forwarded to the streaming
      feeder.
    epochs: Number of epochs to run. Not forwarded to the streaming feeder.

  Returns:
    DataFeeder object that returns training data.

  Raises:
    ValueError: if `x` is an iterator and `y` is neither `None` nor an
      iterator.
  """
  x, y = _data_type_filter(x, y)

  data_feeder_cls = DataFeeder
  if HAS_DASK:
    # pylint: disable=g-import-not-at-top
    import dask.dataframe as dd
    dask_types = (dd.Series, dd.DataFrame)
    x_is_dask = isinstance(x, dask_types)
    if x_is_dask and (y is None or isinstance(y, dask_types)):
      data_feeder_cls = DaskDataFeeder

  if not _is_iterable(x):
    return data_feeder_cls(
        x, y, n_classes, batch_size, shuffle=shuffle, epochs=epochs)

  if y is not None and not _is_iterable(y):
    raise ValueError('Both x and y should be iterators for '
                     'streaming learning to work.')
  return StreamingDataFeeder(x, y, n_classes, batch_size)
153
154
155def _batch_data(x, batch_size=None):
156  if (batch_size is not None) and (batch_size <= 0):
157    raise ValueError('Invalid batch_size %d.' % batch_size)
158
159  x_first_el = six.next(x)
160  x = itertools.chain([x_first_el], x)
161
162  chunk = dict([(k, []) for k in list(x_first_el.keys())]) if isinstance(
163      x_first_el, dict) else []
164  chunk_filled = False
165  for data in x:
166    if isinstance(data, dict):
167      for k, v in list(data.items()):
168        chunk[k].append(v)
169        if (batch_size is not None) and (len(chunk[k]) >= batch_size):
170          chunk[k] = np.matrix(chunk[k])
171          chunk_filled = True
172      if chunk_filled:
173        yield chunk
174        chunk = dict([(k, []) for k in list(x_first_el.keys())]) if isinstance(
175            x_first_el, dict) else []
176        chunk_filled = False
177    else:
178      chunk.append(data)
179      if (batch_size is not None) and (len(chunk) >= batch_size):
180        yield np.matrix(chunk)
181        chunk = []
182
183  if isinstance(x_first_el, dict):
184    for k, v in list(data.items()):
185      chunk[k] = np.matrix(chunk[k])
186    yield chunk
187  else:
188    yield np.matrix(chunk)
189
190
def setup_predict_data_feeder(x, batch_size=None):
  """Returns an iterable for feeding into predict step.

  Args:
    x: numpy, pandas or Dask array. Also supports iterators, whose elements
      are batched lazily.
    batch_size: Size of batches to split data into. If `None`, returns one
      batch of full size.

  Returns:
    For iterator inputs, the result of `_batch_data`; otherwise a list of
    consecutive slices of `x`, each holding at most `batch_size` rows (a
    single-element list `[x]` when `batch_size` is `None`). A 1-D `x` is
    reshaped into a single column first.

  Raises:
    ValueError: if `batch_size` <= 0.
  """
  # Validate up front so that an invalid `batch_size` is reported at call time
  # for every kind of input, including iterators whose batches are only
  # produced lazily.
  if batch_size is not None and batch_size <= 0:
    raise ValueError('Invalid batch_size %d.' % batch_size)

  if HAS_DASK:
    x = extract_dask_data(x)
  if HAS_PANDAS:
    x = extract_pandas_data(x)
  if _is_iterable(x):
    return _batch_data(x, batch_size)
  if len(x.shape) == 1:
    x = np.reshape(x, (-1, 1))
  if batch_size is None:
    return [x]
  n_batches = int(math.ceil(float(len(x)) / batch_size))
  return [x[i * batch_size:(i + 1) * batch_size] for i in xrange(n_batches)]
220
221
def setup_processor_data_feeder(x):
  """Sets up processor iterable.

  Args:
    x: numpy, pandas or iterable.

  Returns:
    `x` passed through `extract_pandas_matrix` when pandas is available,
    otherwise `x` unchanged.
  """
  return extract_pandas_matrix(x) if HAS_PANDAS else x
234
235
def check_array(array, dtype):
  """Checks array on dtype and converts it if different.

  Args:
    array: Input array.
    dtype: Expected dtype.

  Returns:
    For `np.ndarray` and `list` inputs, an `np.ndarray` of the requested
    `dtype`; the very same object is returned when it already is an ndarray of
    that dtype. Any other input is returned unchanged.
  """
  # skip check if array is instance of other classes, e.g. h5py.Dataset
  # to avoid copying array and loading whole data into memory
  if isinstance(array, (np.ndarray, list)):
    # `np.asarray` copies only when required. `np.array(..., copy=False)` must
    # not be used here: since NumPy 2.0 it raises whenever a copy is needed,
    # which is always the case for lists and for dtype changes.
    array = np.asarray(array, dtype=dtype)
  return array
251
252
def _access(data, iloc):
  """Reads element(s) of a collection by integer position.

  Args:
    data: array-like. The collection to access.
    iloc: `int` or `list` of `int`s. Location(s) to access in `data`.

  Returns:
    `data.iloc[iloc]` for pandas `Series` / `DataFrame` objects (only checked
    when pandas is available), `data[iloc]` for everything else.
  """
  if not HAS_PANDAS:
    return data[iloc]
  import pandas as pd  # pylint: disable=g-import-not-at-top
  if isinstance(data, (pd.Series, pd.DataFrame)):
    return data.iloc[iloc]
  return data[iloc]
268
269
def _check_dtype(dtype):
  """Logs a warning when `dtype` is float64 and returns `dtype` unchanged."""
  is_float64 = dtypes.as_dtype(dtype) == dtypes.float64
  if is_float64:
    logging.warn(
        'float64 is not supported by many models, consider casting to float32.')
  return dtype
275
276
class DataFeeder(object):
  """Data feeder is an example class to sample data for TF trainer.

  Holds in-memory features `x` (and optionally labels `y`) and hands them out
  in mini-batches through the function returned by `get_feed_dict_fn`.
  """

  def __init__(self,
               x,
               y,
               n_classes,
               batch_size=None,
               shuffle=True,
               random_state=None,
               epochs=None):
    """Initializes a DataFeeder instance.

    Args:
      x: One feature sample which can either Nd numpy matrix of shape
        `[n_samples, n_features, ...]` or dictionary of Nd numpy matrix.
      y: label vector, either floats for regression or class id for
        classification. If matrix, will consider as a sequence of labels.
        Can be `None` for unsupervised setting. Also supports dictionary of
        labels.
      n_classes: Number of classes, 0 and 1 are considered regression, `None`
        will pass through the input labels without one-hot conversion. Also, if
        `y` is `dict`, then `n_classes` must be `dict` such that
        `n_classes[key] = n_classes for label y[key]`, `None` otherwise.
      batch_size: Mini-batch size to accumulate samples in one mini batch.
      shuffle: Whether to shuffle `x`.
      random_state: Numpy `RandomState` object to reproduce sampling. A
        `RandomState` seeded with 42 is created when `None`.
      epochs: Number of times to iterate over input data before raising
        `StopIteration` exception.

    Attributes:
      x: Input features (ndarray or dictionary of ndarrays).
      y: Input label (ndarray or dictionary of ndarrays).
      n_classes: Number of classes (if `None`, pass through indices without
        one-hot conversion).
      batch_size: Mini-batch size to accumulate.
      input_shape: Shape of the input (or dictionary of shapes).
      output_shape: Shape of the output (or dictionary of shapes).
      input_dtype: DType of input (or dictionary of dtypes).
      output_dtype: DType of output (or dictionary of dtypes).
    """
    x_is_dict = isinstance(x, dict)
    y_is_dict = y is not None and isinstance(y, dict)
    if isinstance(y, list):
      y = np.array(y)

    if x_is_dict:
      self._x = dict(
          [(k, check_array(v, v.dtype)) for k, v in list(x.items())])
    else:
      self._x = check_array(x, x.dtype)

    if y is None:
      self._y = None
    elif y_is_dict:
      self._y = dict(
          [(k, check_array(v, v.dtype)) for k, v in list(y.items())])
    else:
      self._y = check_array(y, y.dtype)

    # self.n_classes is not None means we're converting raw target indices
    # to one-hot (n_classes > 1) or treating the labels as regression targets.
    if n_classes is not None and not y_is_dict:
      y_dtype = np.int64 if n_classes > 1 else np.float32
      self._y = None if y is None else check_array(y, dtype=y_dtype)

    self.n_classes = n_classes
    self.max_epochs = epochs

    if x_is_dict:
      x_shape = dict([(k, v.shape) for k, v in list(self._x.items())])
    else:
      x_shape = self._x.shape
    if y_is_dict:
      y_shape = dict([(k, v.shape) for k, v in list(self._y.items())])
    elif y is None:
      y_shape = None
    else:
      y_shape = self._y.shape

    self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
        x_shape, y_shape, n_classes, batch_size)

    # Input dtype matches dtype of x.
    if x_is_dict:
      self._input_dtype = dict(
          [(k, _check_dtype(v.dtype)) for k, v in list(self._x.items())])
    else:
      self._input_dtype = _check_dtype(self._x.dtype)

    # self._output_dtype == np.float32 when y is None
    if y_is_dict:
      self._output_dtype = dict(
          [(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())])
    elif y is not None:
      self._output_dtype = _check_dtype(self._y.dtype)
    else:
      self._output_dtype = np.float32

    # self.n_classes is None means we're passing in raw target indices;
    # labels that do have a class count are emitted as float32.
    if n_classes is not None and y_is_dict:
      for key in list(n_classes.keys()):
        if key in self._output_dtype:
          self._output_dtype[key] = np.float32

    self._shuffle = shuffle
    self.random_state = np.random.RandomState(
        42) if random_state is None else random_state

    if x_is_dict:
      num_samples = list(self._x.values())[0].shape[0]
    elif tensor_util.is_tensor(self._x):
      num_samples = self._x.shape[
          0].value  # shape will be a Dimension, extract an int
    else:
      num_samples = self._x.shape[0]

    if self._shuffle:
      self.indices = self.random_state.permutation(num_samples)
    else:
      self.indices = np.arange(num_samples)
    self.offset = 0
    self.epoch = 0
    self._epoch_placeholder = None
    # Placeholders are attached later by `input_builder` or
    # `set_placeholders`; initialise them so that feeding too early trips the
    # assertion in the feed function instead of an AttributeError.
    self._input_placeholder = None
    self._output_placeholder = None

  @property
  def x(self):
    return self._x

  @property
  def y(self):
    return self._y

  @property
  def shuffle(self):
    return self._shuffle

  @property
  def input_dtype(self):
    return self._input_dtype

  @property
  def output_dtype(self):
    return self._output_dtype

  @property
  def batch_size(self):
    return self._batch_size

  def make_epoch_variable(self):
    """Adds a placeholder variable for the epoch to the graph.

    Returns:
      The epoch placeholder.
    """
    self._epoch_placeholder = array_ops.placeholder(
        dtypes.int32, [1], name='epoch')
    return self._epoch_placeholder

  def input_builder(self):
    """Builds inputs in the graph.

    Returns:
      Two placeholders for inputs and outputs. Each is a `dict` of
      placeholders when the corresponding shape is a `dict`; the output
      placeholder is `None` when there are no labels.
    """

    def get_placeholder(shape, dtype, name_prepend):
      if shape is None:
        return None
      if isinstance(shape, dict):
        placeholder = {}
        for key in list(shape.keys()):
          # The batch dimension is left open (None).
          placeholder[key] = array_ops.placeholder(
              dtypes.as_dtype(dtype[key]), [None] + shape[key][1:],
              name=name_prepend + '_' + key)
      else:
        placeholder = array_ops.placeholder(
            dtypes.as_dtype(dtype), [None] + shape[1:], name=name_prepend)
      return placeholder

    self._input_placeholder = get_placeholder(self.input_shape,
                                              self._input_dtype, 'input')
    self._output_placeholder = get_placeholder(self.output_shape,
                                               self._output_dtype, 'output')
    return self._input_placeholder, self._output_placeholder

  def set_placeholders(self, input_placeholder, output_placeholder):
    """Sets placeholders for this data feeder.

    Args:
      input_placeholder: Placeholder for `x` variable. Should match shape
        of the examples in the x dataset.
      output_placeholder: Placeholder for `y` variable. Should match
        shape of the examples in the y dataset. Can be `None`.
    """
    self._input_placeholder = input_placeholder
    self._output_placeholder = output_placeholder

  def get_feed_params(self):
    """Function returns a `dict` with data feed params while training.

    Returns:
      A `dict` with data feed params while training.
    """
    return {
        'epoch': self.epoch,
        'offset': self.offset,
        'batch_size': self._batch_size
    }

  def get_feed_dict_fn(self):
    """Returns a function that samples data into given placeholders.

    Returns:
      A function that when called returns the feed dict for the next batch of
      `x` and `y`, and raises `StopIteration` once the configured number of
      epochs has been served.
    """
    x_is_dict, y_is_dict = isinstance(
        self._x, dict), self._y is not None and isinstance(self._y, dict)

    # Assign input features from random indices.
    def extract(data, indices):
      # 1-D data is returned as a single column.
      return (np.array(_access(data, indices)).reshape((indices.shape[0], 1))
              if len(data.shape) == 1 else _access(data, indices))

    # assign labels from random indices
    def assign_label(data, shape, dtype, n_classes, indices):
      # Build a fresh shape for this batch. `shape` is the feeder's own
      # `output_shape` list and must not be modified in place.
      batch_shape = [indices.shape[0]] + list(shape[1:])
      out = np.zeros(batch_shape, dtype=dtype)
      for i in xrange(out.shape[0]):
        sample = indices[i]
        # n_classes is None means we're passing in raw target indices;
        # n_classes <= 1 means regression. Both copy the label as is.
        if n_classes is None or n_classes <= 1:
          out[i] = _access(data, sample)
        elif len(batch_shape) == 2:
          # One label per sample: one-hot encode it.
          out[i, int(_access(data, sample))] = 1.0
        else:
          # Sequence of labels per sample: one-hot encode every position.
          for idx, value in enumerate(_access(data, sample)):
            out[i, idx, value] = 1.0
      return out

    def _feed_dict_fn():
      """Function that samples data into given placeholders."""
      if self.max_epochs is not None and self.epoch + 1 > self.max_epochs:
        raise StopIteration
      assert self._input_placeholder is not None
      feed_dict = {}
      if self._epoch_placeholder is not None:
        feed_dict[self._epoch_placeholder.name] = [self.epoch]

      # Take next batch of indices.
      x_len = list(
          self._x.values())[0].shape[0] if x_is_dict else self._x.shape[0]
      end = min(x_len, self.offset + self._batch_size)
      batch_indices = self.indices[self.offset:end]

      # adding input placeholder
      if x_is_dict:
        for k, v in list(self._x.items()):
          feed_dict[self._input_placeholder[k].name] = extract(
              v, batch_indices)
      else:
        feed_dict[self._input_placeholder.name] = extract(
            self._x, batch_indices)

      # move offset and reset it if necessary
      self.offset += self._batch_size
      if self.offset >= x_len:
        self.indices = self.random_state.permutation(
            x_len) if self._shuffle else np.arange(x_len)
        self.offset = 0
        self.epoch += 1

      # return early if there are no labels
      if self._output_placeholder is None:
        return feed_dict

      # adding output placeholders
      if y_is_dict:
        for k, v in list(self._y.items()):
          n_classes = (self.n_classes[k] if k in self.n_classes else
                       None) if self.n_classes is not None else None
          shape, dtype = self.output_shape[k], self._output_dtype[k]
          feed_dict[self._output_placeholder[k].name] = assign_label(
              v, shape, dtype, n_classes, batch_indices)
      else:
        feed_dict[self._output_placeholder.name] = assign_label(
            self._y, self.output_shape, self._output_dtype, self.n_classes,
            batch_indices)

      return feed_dict

    return _feed_dict_fn
561
562
class StreamingDataFeeder(DataFeeder):
  """Data feeder for TF trainer that reads data from iterator.

  Streaming data feeder allows to read data as it comes in from disk or
  somewhere else. It's customary to have these iterators rotate infinitely over
  the dataset, to allow control of how much to learn on the trainer side.
  """

  def __init__(self, x, y, n_classes, batch_size):
    """Initializes a StreamingDataFeeder instance.

    The first element of `x` (and of `y`) is read right away to infer shapes
    and dtypes; it is put back in front of the stream, so no sample is lost.

    Args:
      x: iterator each element of which returns one feature sample. Sample can
        be a Nd numpy matrix or dictionary of Nd numpy matrices.
      y: iterator each element of which returns one label sample. Sample can be
        a Nd numpy matrix or dictionary of Nd numpy matrices with 1 or many
        classes regression values.
      n_classes: indicator of how many classes the corresponding label sample
        has for the purposes of one-hot conversion of label. In case where `y`
        is a dictionary, `n_classes` must be dictionary (with same keys as `y`)
        of how many classes there are in each label in `y`. If key is
        present in `y` and missing in `n_classes`, the value is assumed `None`
        and no one-hot conversion will be applied to the label with that key.
      batch_size: Mini batch size to accumulate samples in one batch. If set
        `None`, then assumes that iterator to return already batched element.

    Attributes:
      x: input features (or dictionary of input features).
      y: input label (or dictionary of output features).
      n_classes: number of classes.
      batch_size: mini batch size to accumulate.
      input_shape: shape of the input (can be dictionary depending on `x`).
      output_shape: shape of the output (can be dictionary depending on `y`).
      input_dtype: dtype of input (can be dictionary depending on `x`).
      output_dtype: dtype of output (can be dictionary depending on `y`).
    """
    # pylint: disable=invalid-name,super-init-not-called
    x_first_el = six.next(x)
    self._x = itertools.chain([x_first_el], x)
    if y is not None:
      y_first_el = six.next(y)
      self._y = itertools.chain([y_first_el], y)
    else:
      y_first_el = None
      self._y = None
    self.n_classes = n_classes

    x_is_dict = isinstance(x_first_el, dict)
    y_is_dict = y is not None and isinstance(y_first_el, dict)
    if y_is_dict and n_classes is not None:
      assert isinstance(n_classes, dict)

    # extract shapes for first_elements; a leading sample dimension of 1 is
    # prepended so they look like `[n_samples, ...]` shapes.
    if x_is_dict:
      x_first_el_shape = dict(
          [(k, [1] + list(v.shape)) for k, v in list(x_first_el.items())])
    else:
      x_first_el_shape = [1] + list(x_first_el.shape)

    if y_is_dict:
      y_first_el_shape = dict(
          [(k, [1] + list(v.shape)) for k, v in list(y_first_el.items())])
    elif y is None:
      y_first_el_shape = None
    else:
      y_first_el_shape = (
          [1] + list(y_first_el[0].shape
                     if isinstance(y_first_el, list) else y_first_el.shape))

    self.input_shape, self.output_shape, self._batch_size = _get_in_out_shape(
        x_first_el_shape, y_first_el_shape, n_classes, batch_size)

    # Input dtype of x_first_el.
    if x_is_dict:
      self._input_dtype = dict(
          [(k, _check_dtype(v.dtype)) for k, v in list(x_first_el.items())])
    else:
      self._input_dtype = _check_dtype(x_first_el.dtype)

    # Output dtype of y_first_el.
    def check_y_dtype(el):
      if isinstance(el, np.ndarray):
        return el.dtype
      elif isinstance(el, list):
        return check_y_dtype(el[0])
      else:
        return _check_dtype(np.dtype(type(el)))

    # Output types are floats, due to both softmaxes and regression req.
    if n_classes is not None and (y is None or not y_is_dict) and n_classes > 0:
      self._output_dtype = np.float32
    elif y_is_dict:
      self._output_dtype = dict(
          [(k, check_y_dtype(v)) for k, v in list(y_first_el.items())])
    elif y is None:
      self._output_dtype = None
    else:
      self._output_dtype = check_y_dtype(y_first_el)

  def get_feed_params(self):
    """Function returns a `dict` with data feed params while training.

    Returns:
      A `dict` with data feed params while training.
    """
    return {'batch_size': self._batch_size}

  def get_feed_dict_fn(self):
    """Returns a function, that will sample data and provide it to placeholders.

    Returns:
      A function that when called reads up to `batch_size` samples from the
      `x` and `y` iterators and returns them as a feed dict. Once `x` is
      exhausted the final, possibly shorter, batch is returned and every
      later call raises `StopIteration`.
    """
    self.stopped = False

    def init_array(shape, dtype):
      """Initialize array of given shape or dict of shapes and dtype."""
      if shape is None:
        return None
      elif isinstance(shape, dict):
        return dict(
            [(k, np.zeros(shape[k], dtype[k])) for k in list(shape.keys())])
      else:
        return np.zeros(shape, dtype=dtype)

    def put_data_array(dest, index, source=None, n_classes=None):
      """Puts data array into container."""
      if source is None:
        # End of the stream: keep only the rows filled so far.
        return dest[:index]
      if n_classes is not None and n_classes > 1:
        # Decide by the rank of the destination array itself:
        # `self.output_shape` is a dict for dict labels, so its length would
        # be the number of label keys rather than a rank.
        if len(dest.shape) == 2:
          dest[index, source] = 1.0
        else:
          for idx, value in enumerate(source):
            dest[index, idx, value] = 1.0
      elif len(dest.shape) > 1:
        dest[index, :] = source
      else:
        dest[index] = source[0] if isinstance(source, list) else source
      return dest

    def put_data_array_or_dict(holder, index, data=None, n_classes=None):
      """Puts data array or data dictionary into container."""
      if holder is None:
        return None
      if isinstance(holder, dict):
        if data is None:
          data = {k: None for k in holder.keys()}
        assert isinstance(data, dict)
        for k in holder.keys():
          num_classes = n_classes[k] if (n_classes is not None and
                                         k in n_classes) else None
          holder[k] = put_data_array(holder[k], index, data[k], num_classes)
      else:
        holder = put_data_array(holder, index, data, n_classes)
      return holder

    def _feed_dict_fn():
      """Samples data and provides it to placeholders.

      Returns:
        `dict` of input and output tensors.
      """
      if self.stopped:
        raise StopIteration

      inp = init_array(self.input_shape, self._input_dtype)
      out = init_array(self.output_shape, self._output_dtype)

      for i in xrange(self._batch_size):
        # Add handling when queue ends.
        try:
          next_inp = six.next(self._x)
          inp = put_data_array_or_dict(inp, i, next_inp, None)
        except StopIteration:
          self.stopped = True
          if i == 0:
            # Nothing was read for this batch: propagate the end of stream.
            raise
          # Trim both buffers to the `i` samples that were actually read.
          inp = put_data_array_or_dict(inp, i, None, None)
          out = put_data_array_or_dict(out, i, None, None)
          break

        if self._y is not None:
          next_out = six.next(self._y)
          out = put_data_array_or_dict(out, i, next_out, self.n_classes)

      # creating feed_dict
      if isinstance(inp, dict):
        feed_dict = dict([(self._input_placeholder[k].name, inp[k])
                          for k in list(self._input_placeholder.keys())])
      else:
        feed_dict = {self._input_placeholder.name: inp}
      if self._y is not None:
        if isinstance(out, dict):
          feed_dict.update(
              dict([(self._output_placeholder[k].name, out[k])
                    for k in list(self._output_placeholder.keys())]))
        else:
          feed_dict.update({self._output_placeholder.name: out})

      return feed_dict

    return _feed_dict_fn
769
770
class DaskDataFeeder(object):
  """Data feeder that reads data from dask.Series and dask.DataFrame.

  Numpy arrays can be serialized to disk and it's possible to do random seeks
  into them. DaskDataFeeder will remove requirement to have full dataset in the
  memory and still do random seeks for sampling of batches.
  """

  def __init__(self,
               x,
               y,
               n_classes,
               batch_size,
               shuffle=True,
               random_state=None,
               epochs=None):
    """Initializes a DaskDataFeeder instance.

    Args:
      x: Dask object holding the features; its `columns` are recorded.
      y: Dask object holding the labels (1 or many classes / regression
        values); its `columns` are read as well.
      n_classes: indicator of how many classes the label has.
      batch_size: Mini batch size to accumulate.
      shuffle: Whether to shuffle the inputs. Only stored for now.
      random_state: random state for RNG, forwarded to `random_split` on
        every batch. Defaults to 66 when `None`. Note that it will mutate so
        use a int value for this if you want consistent sized batches.
      epochs: Number of epochs to run. Only stored for now.

    Attributes:
      x: input features.
      y: input label.
      n_classes: number of classes.
      batch_size: mini batch size to accumulate.
      input_shape: shape of the input.
      output_shape: shape of the output.
      input_dtype: dtype of input.
      output_dtype: dtype of output.

    Raises:
      ValueError: if `x` or `y` are `dict`, as they are not supported currently.
    """
    if any(isinstance(data, dict) for data in (x, y)):
      raise ValueError(
          'DaskDataFeeder does not support dictionaries at the moment.')

    # pylint: disable=invalid-name,super-init-not-called
    import dask.dataframe as dd  # pylint: disable=g-import-not-at-top
    # TODO(terrytangyuan): check x and y dtypes in dask_io like pandas
    self._x = x
    self._y = y
    # save column names
    self._x_columns = list(x.columns)
    first_label_column = y.columns[0]
    if isinstance(first_label_column, str):
      self._y_columns = list(y.columns)
    else:
      # deal with cases where two DFs have overlapped default numeric colnames
      self._y_columns = len(self._x_columns) + 1
      self._y = self._y.rename(columns={first_label_column: self._y_columns})

    # TODO(terrytangyuan): deal with unsupervised cases
    # combine into a data frame
    self.df = dd.multi.concat([self._x, self._y], axis=1)
    self.n_classes = n_classes

    x_count = x.count().compute()[0]
    x_shape = (x_count, len(self._x.columns))
    y_shape = (x_count, len(self._y.columns))
    # TODO(terrytangyuan): Add support for shuffle and epochs.
    self._shuffle = shuffle
    self.epochs = epochs
    shapes = _get_in_out_shape(x_shape, y_shape, n_classes, batch_size)
    self.input_shape, self.output_shape, self._batch_size = shapes
    # Fraction of the data frame that corresponds to one batch.
    self.sample_fraction = self._batch_size / float(x_count)
    self._input_dtype = _check_dtype(self._x.dtypes[0])
    self._output_dtype = _check_dtype(self._y.dtypes[self._y_columns])
    self.random_state = 66 if random_state is None else random_state

  def get_feed_params(self):
    """Function returns a `dict` with data feed params while training.

    Returns:
      A `dict` with data feed params while training.
    """
    return {'batch_size': self._batch_size}

  def get_feed_dict_fn(self, input_placeholder, output_placeholder):
    """Returns a function, that will sample data and provide it to placeholders.

    Args:
      input_placeholder: tf.placeholder for input features mini batch.
      output_placeholder: tf.placeholder for output labels.

    Returns:
      A function that when called samples a random subset of batch size
      from x and y and returns it as a feed dict with one-hot encoded labels.
    """

    def _feed_dict_fn():
      """Samples data and provides it to placeholders."""
      # TODO(ipolosukhin): option for with/without replacement (dev version of
      # dask)
      fractions = [self.sample_fraction, 1 - self.sample_fraction]
      sample = self.df.random_split(fractions, random_state=self.random_state)
      # The first split holds roughly one batch worth of rows.
      batch = sample[0]
      inp = extract_pandas_matrix(batch[self._x_columns].compute()).tolist()
      out = extract_pandas_matrix(batch[self._y_columns].compute())
      # convert to correct dtype
      inp = np.array(inp, dtype=self._input_dtype)
      # one-hot encode out for each class for cross entropy loss
      if HAS_PANDAS:
        import pandas as pd  # pylint: disable=g-import-not-at-top
        if not isinstance(out, pd.Series):
          out = out.flatten()
      out_max = self._y.max().compute().values[0]
      n_rows = out.size
      encoded_out = np.zeros((n_rows, out_max + 1), dtype=self._output_dtype)
      encoded_out[np.arange(n_rows), out] = 1
      return {input_placeholder.name: inp, output_placeholder.name: encoded_out}

    return _feed_dict_fn
895