19d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
29d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer#
39d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# Licensed under the Apache License, Version 2.0 (the "License");
49d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# you may not use this file except in compliance with the License.
59d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# You may obtain a copy of the License at
69d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer#
79d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer#     http://www.apache.org/licenses/LICENSE-2.0
89d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer#
99d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# Unless required by applicable law or agreed to in writing, software
109d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# distributed under the License is distributed on an "AS IS" BASIS,
119d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
129d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# See the License for the specific language governing permissions and
139d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# limitations under the License.
149d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# ==============================================================================
159d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer"""Core classes and core ops for LabeledTensor.
169d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
179d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan HoyerCore ops are ops which will eventually be called by LabeledTensor methods,
189d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerand ops which a core op depends upon.
199d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan HoyerFor example, `add` is a core op because we'll eventually support the `+`
209d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyeroperator.
219d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan HoyerNon-core ops should go in `ops.py`.
229d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer"""
239d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom __future__ import absolute_import
249d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom __future__ import division
259d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom __future__ import print_function
269d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
279d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerimport collections
289d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerimport contextlib
299d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerimport numbers
309d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerimport types
319d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
329d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerimport numpy as np
339d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom six import binary_type
349d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom six import string_types
359d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom six import text_type
369d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom six.moves import range  # pylint: disable=redefined-builtin
379d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
389d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
399d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom tensorflow.python.framework import dtypes
409d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom tensorflow.python.framework import ops
419d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom tensorflow.python.framework import tensor_shape
429d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom tensorflow.python.ops import array_ops
439d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom tensorflow.python.ops import math_ops
449d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
# pylint: disable=invalid-name

# Sequence types that may be coerced into Axis.labels. They are enumerated
# explicitly (instead of using collections.Sequence) so that strings, which are
# also sequences, are rejected.
LabelsLike = tc.Union(
    np.ndarray,
    range,
    list,
    tuple,
)

# Types that may be coerced into a tf.Dimension. tc.Optional additionally
# admits None, which Axis turns into a dimension of unknown size.
DimensionLike = tc.Optional(
    tc.Union(tensor_shape.Dimension, int))

# Anything usable as the value of an axis: either tick labels or a size.
AxisValue = tc.Union(LabelsLike, DimensionLike)

# Python scalar types accepted as valid TensorFlow scalar values.
Scalar = tc.Union(
    numbers.Number,
    bool,
    binary_type,
    text_type,
)

# pylint: enable=invalid-name
619d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
629d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
class Axis(object):
  """Size and optional tick-label information for one named axis.

  An Axis stores a tf.Dimension giving its length and, optionally, a tuple of
  coordinate (tick) labels. When labels are supplied the Dimension is derived
  from their count, and the labels must be unique. An axis whose known length
  is zero is always treated as labeled with the empty tuple.
  """

  @tc.accepts(object, string_types, AxisValue)
  def __init__(self, name, value):
    """Construct an Axis.

    Args:
      name: Name of the axis.
      value: One of: a tf.Dimension or int giving the axis size; None for an
        axis of unknown size; or a non-string sequence (list, tuple, range or
        np.ndarray) of coordinate (tick) labels, whose length becomes the size.

    Raises:
      ValueError: If the user provides labels with duplicate values.
    """
    labels = None
    if isinstance(value, tensor_shape.Dimension):
      dimension = value
    elif value is None or isinstance(value, int):
      dimension = tensor_shape.Dimension(value)
    else:
      # Any other accepted value is a sequence of tick labels.
      dimension = tensor_shape.Dimension(len(value))
      labels = tuple(value)

    if dimension.value == 0:
      # A zero-length axis counts as labeled, with no labels at all.
      labels = ()

    if labels is None:
      index = None
    else:
      # Map label -> position. Duplicate labels collapse onto a single key, so
      # a length mismatch means the labels were not unique.
      index = {label: position for position, label in enumerate(labels)}
      if len(index) != len(labels):
        raise ValueError('Tick labels must be unique, but got {}'
                         .format(labels))

    self._name = name  # type: string_types
    self._dimension = dimension  # type: tensor_shape.Dimension
    self._labels = labels  # type: Optional[tuple]
    self._index = index  # type: Optional[Dict[Any, int]]

  @property
  @tc.returns(string_types)
  def name(self):
    """The axis name."""
    return self._name

  @tc.returns(string_types)
  def __repr__(self):
    # Renders e.g. Axis('x', Dimension(2)) or Axis('y', ('a', 'b')).
    # TODO(shoyer): make very long reprs more succinct?
    class_name = type(self).__name__
    return "%s('%s', %r)" % (class_name, self.name, self.value)

  @tc.returns(bool)
  def __eq__(self, other):
    if not isinstance(other, Axis):
      return False
    return (self.name == other.name and self.size == other.size and
            self.labels == other.labels)

  def __hash__(self):
    # Hash exactly the fields that __eq__ compares.
    return hash((self.name, self.size, self.labels))

  @tc.returns(bool)
  def __ne__(self, other):
    return not (self == other)

  @tc.returns(int)
  def __len__(self):
    """Returns the axis size; raises ValueError if the size is unknown."""
    known_size = self.size
    if known_size is None:
      raise ValueError('axis %r has unknown length' % self.name)
    return known_size

  @property
  @tc.returns(tc.Optional(tensor_shape.Dimension))
  def dimension(self):
    """The tf.Dimension describing the axis length."""
    return self._dimension

  @property
  @tc.returns(tc.Optional(int))
  def size(self):
    """The axis length as an int, or None if unknown."""
    return self._dimension.value

  @property
  @tc.returns(tc.Union(tuple, tensor_shape.Dimension))
  def value(self):
    """Returns the tick-label tuple if present, else the tf.Dimension."""
    return self.dimension if self.labels is None else self.labels

  @property
  @tc.returns(tc.Optional(tuple))
  def labels(self):
    """Returns the tuple of coordinate labels, or None if unlabeled."""
    return self._labels

  def index(self, value):
    """Returns the integer position of the tick label `value`.

    Raises:
      ValueError: If this axis has no tick labels.
      KeyError: If `value` is not one of the tick labels.
    """
    positions = self._index
    if positions is None:
      raise ValueError('Axis does not have tick labels')
    return positions[value]
1729d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
1739d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
# Type-check spec for anything that as_axis can coerce into an Axis: either an
# Axis instance or an (axis_name, axis_value) pair.
# pylint: disable=invalid-name
AxisLike = tc.Union(
    Axis,
    tc.Tuple(string_types, AxisValue),
)
# pylint: enable=invalid-name
1789d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
1799d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
@tc.returns(Axis)
@tc.accepts(AxisLike)
def as_axis(axis_data):
  """Coerce an AxisLike value into an Axis.

  Args:
    axis_data: Either an Axis, or an (axis_name, axis_value) tuple that is
      unpacked into the Axis constructor.

  Returns:
    An Axis. Existing Axis instances are returned unchanged (same object).
  """
  return axis_data if isinstance(axis_data, Axis) else Axis(*axis_data)
1969d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
1979d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
try:
  # Python 3.3+ keeps the container ABCs in collections.abc; the deprecated
  # aliases in the top-level collections module were removed in Python 3.10.
  from collections.abc import Mapping as _Mapping  # pylint: disable=g-import-not-at-top
except ImportError:
  # Python 2 only provides the ABCs in the top-level collections module.
  _Mapping = collections.Mapping


class Axes(_Mapping):
  """Axis names and indices for a tensor.

  It is an ordered mapping, with keys given by axis name and values given
  by Axis objects. Duplicate axis names are not allowed.

  Axes does not define __eq__, so equality comes from the Mapping ABC, which
  compares contents as dicts (i.e. regardless of axis order). __hash__ is kept
  consistent with that equality.
  """

  @tc.accepts(object, tc.List(AxisLike))
  def __init__(self, axes):
    """Construct an Axes.

    Args:
      axes: A list of Axis objects or (axis_name, axis_value) tuples.

    Raises:
      ValueError: If the user provides duplicate axis names.
    """
    self._axes = collections.OrderedDict()

    for axis_data in axes:
      axis = as_axis(axis_data)

      name = axis.name
      if name in self._axes:
        raise ValueError('Duplicate axis name: %s' % name)

      self._axes[name] = axis

  def __iter__(self):
    return iter(self._axes)

  @tc.returns(string_types)
  def __repr__(self):
    # Axes([('x', Dimension(2)),
    #       ('y', ('a', 'b', 'c')),
    #       ('z', Dimension(4))])
    cls_name = type(self).__name__
    values = ["('%s', %r)" % (v.name, v.value) for v in self._axes.values()]
    values_repr = (',\n' + ' ' * len(cls_name + '([')).join(values)
    return '%s([%s])' % (cls_name, values_repr)

  @tc.returns(Axis)
  @tc.accepts(object, string_types)
  def __getitem__(self, name):
    return self._axes[name]

  @tc.returns(bool)
  def __contains__(self, name):
    return name in self._axes

  @tc.returns(int)
  def __len__(self):
    return len(self._axes)

  def __hash__(self):
    # The inherited Mapping.__eq__ ignores axis order, so the hash must ignore
    # order too: objects that compare equal are required to hash equal.
    return hash(frozenset(self.items()))

  @tc.accepts(object, string_types)
  def remove(self, axis_name):
    """Creates a new Axes object without the given axis.

    Args:
      axis_name: Name of the axis to drop.

    Returns:
      A new Axes with the remaining axes in their original order.

    Raises:
      KeyError: If `axis_name` is not one of the axes.
    """
    if axis_name not in self:
      raise KeyError(axis_name)
    remaining_axes = [axis for axis in self.values() if axis.name != axis_name]
    return Axes(remaining_axes)
2629d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
2639d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
2649d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerclass LabeledTensor(object):
2659d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """A tensor with annotated axes.
2669d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
2679d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  It has the following invariants:
2689d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    1) The dimensionality of the tensor is equal to the number of elements
2699d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    in axes.
2709d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    2) The number of coordinate values in the ith dimension is equal to the
2719d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    size of the tensor in the ith dimension.
2729d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
2739d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Attributes:
2749d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    tensor: tf.Tensor containing the data.
2759d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    axes: lt.Axes containing axis names and coordinate labels.
2769d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """
2779d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
2783866dd2b20e5b63ad4ea8d0edb7961a2252d906dJonathan Hseu  @tc.accepts(object, ops.Tensor,
2799d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer              tc.Union(Axes, tc.Collection(tc.Union(string_types, AxisLike))))
2809d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __init__(self, tensor, axes):
281e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai    """Construct a LabeledTensor.
2829d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
2839d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    Args:
2849d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      tensor: The underlying tensor containing the data.
2859d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      axes: An Axes object, or a collection of strings, Axis objects or tuples
2869d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        of (name, value) pairs indicating the axes.
2879d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
2889d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    Raises:
2899d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      ValueError: If the provided axes do not satisfy the class invariants.
2909d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    """
2919d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    self._tensor = tensor
2929d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    shape = tensor.get_shape()
2939d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
2949d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    if isinstance(axes, Axes):
2959d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      unvalidated_axes = axes
2969d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    else:
2979d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      mutable_axes = []
2989d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
2999d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      for position, axis_like in enumerate(axes):
3009d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        if isinstance(axis_like, string_types):
3019d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer          # The coordinates for this axes are unlabeled.
3029d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer          # Infer the size of the axis.
3039d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer          value = shape[position]
3049d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer          axis_like = (axis_like, value)
3059d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3069d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        mutable_axes.append(axis_like)
3079d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3089d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      # Construct the Axis object, which will additionally validate the contents
3099d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      # of the object.
3109d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      unvalidated_axes = Axes(mutable_axes)
3119d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3129d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    # Check our invariants.
3139d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3149d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    # First, the rank of the tensor must be equal to the number of axes.
3159d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    if len(shape) != len(unvalidated_axes):
3169d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      raise ValueError('Tensor rank was not equal to the number of axes: %r, %r'
3179d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer                       % (shape, unvalidated_axes))
3189d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3199d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    # Second, the size of each tensor dimension must match the size of the
3209d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    # corresponding indices.
3219d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    for (d, axis) in zip(shape, unvalidated_axes.values()):
3229d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      if d != axis.size:
3239d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        raise ValueError(
3249d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer            'Provided axis size %d does not match tensor dimension size %d' %
3259d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer            (axis.size, d))
3269d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3279d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    self._axes = unvalidated_axes
3289d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3299d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __repr__(self):
3309d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    # <LabeledTensor 'foo' shape=(2, 3, 4) dtype=float32
3319d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    #  axes=[('x', Dimension(2)),
3329d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    #        ('y', ('a', 'b', 'c'),
3339d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    #        ('z', Dimension(4))]>
3349d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    axes = ["('%s', %r)" % (v.name, v.value) for v in self.axes.values()]
3359d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    axes_repr = (',\n' + ' ' * len(' axes=[')).join(axes)
3369d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return ("<%s '%s' shape=%s dtype=%s\n axes=[%s]>" %
3379d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer            (type(self).__name__, self.tensor.name, self.tensor.get_shape(),
3389d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer             self.tensor.dtype.name, axes_repr))
3399d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3409d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  @property
3419d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def tensor(self):
3429d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return self._tensor
3439d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3449d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def _as_graph_element(self):
3459d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    """Support tf.Graph.as_graph_element on LabeledTensor objects.
3469d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3479d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    This allows operations such as tf.name_scope to take labeled tensors.
3489d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3499d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    Returns:
3509d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      self.tensor
3519d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    """
3529d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return self.tensor
3539d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3549d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  @property
3559d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def axes(self):
3569d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return self._axes
3579d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3589d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  # properties/methods directly borrowed from tf.Tensor:
3599d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3609d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  @property
3619d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def dtype(self):
3629d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return self._tensor.dtype
3639d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3649d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  @property
365128572c316e6f2eb6346f920314ef98e88e75069A. Unique TensorFlower  def shape(self):
366128572c316e6f2eb6346f920314ef98e88e75069A. Unique TensorFlower    return self._tensor.shape
367128572c316e6f2eb6346f920314ef98e88e75069A. Unique TensorFlower
368128572c316e6f2eb6346f920314ef98e88e75069A. Unique TensorFlower  @property
3699d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def name(self):
3709d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return self._tensor.name
3719d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3729d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def get_shape(self):
3739d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    """Returns the TensorShape that represents the shape of this tensor.
3749d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3759d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    See tf.Tensor.get_shape().
3769d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3779d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    Returns:
3789d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      A TensorShape representing the shape of this tensor.
3799d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    """
3809d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return self._tensor.get_shape()
3819d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3829d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  # TODO(shoyer): consider how/if to implement .eval(). Maybe it should return
3839d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  # an xarray.DataArray?
3849d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3859d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __getitem__(self, key):
3869d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    # This should work exactly like tf.Tensor.__getitem__, except it preserves
3879d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    # labels.
3889d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    if not isinstance(key, tuple):
3899d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      key = (key,)
3909d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    if len(key) != len(self.axes):
3919d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      raise ValueError('indexer %r must have the same length as the Tensor '
3929d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer                       'rank (%r)' % (key, len(self.axes)))
3939d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    selection = {a: k for a, k in zip(self.axes.keys(), key)}
3949d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return slice_function(self, selection)
3959d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3969d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  # special methods for overloading arithmetic operations:
3979d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
3989d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __abs__(self):
3999d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return abs_function(self)
4009d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4019d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __neg__(self):
4029d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return neg(self)
4039d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4049d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __pos__(self):
4059d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return self
4069d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4079d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __add__(self, other):
4089d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return add(self, other)
4099d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4109d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __radd__(self, other):
4119d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return add(other, self)
4129d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4139d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __sub__(self, other):
4149d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return sub(self, other)
4159d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4169d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __rsub__(self, other):
4179d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return sub(other, self)
4189d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4199d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __mul__(self, other):
4209d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return mul(self, other)
4219d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4229d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __rmul__(self, other):
4239d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return mul(other, self)
4249d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4259d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __truediv__(self, other):
4269d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return div(self, other)
4279d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4289d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  __div__ = __truediv__
4299d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4309d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __rtruediv__(self, other):
4319d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return div(other, self)
4329d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4339d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  __rdiv__ = __rtruediv__
4349d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4359d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __mod__(self, other):
4369d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return mod(self, other)
4379d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4389d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __rmod__(self, other):
4399d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return mod(other, self)
4409d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4419d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __pow__(self, other):
4429d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return pow_function(self, other)
4439d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4449d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __rpow__(self, other):
4459d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return pow_function(other, self)
4469d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4479d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  # logical operations:
4489d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4499d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __invert__(self):
4509d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return logical_not(self)
4519d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4529d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __and__(self, other):
4539d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return logical_and(self, other)
4549d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4559d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __or__(self, other):
4569d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return logical_or(self, other)
4579d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4589d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __xor__(self, other):
4599d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return logical_xor(self, other)
4609d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
  # comparison operations:
4629d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4639d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __lt__(self, other):
4649d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return less(self, other)
4659d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4669d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __le__(self, other):
4679d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return less_equal(self, other)
4689d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4699d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __gt__(self, other):
4709d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return greater(self, other)
4719d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4729d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __ge__(self, other):
4739d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return greater_equal(self, other)
4749d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4759d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __eq__(self, other):
4769d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    # for consistency with tf.Tensor
4779d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    if not isinstance(other, LabeledTensor):
4789d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      return False
4799d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4809d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return self.tensor == other.tensor and self.axes == other.axes
4819d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4829d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __ne__(self, other):
4839d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return not self == other
4849d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4859d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def __hash__(self):
4869d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return hash((self.tensor, self.axes))
4879d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4889d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
# Register typecheck (tc) abbreviations, in this order: first third-party
# TensorFlow types whose reprs are very long, then the core LabeledTensor
# types.
for _abbreviated_type, _abbreviation in (
    (tensor_shape.Dimension, 'tensorflow.Dimension'),
    (ops.Tensor, 'tensorflow.Tensor'),
    (dtypes.DType, 'tensorflow.DType'),
    (Axis, 'labeled_tensor.Axis'),
    (Axes, 'labeled_tensor.Axes'),
    (LabeledTensor, 'labeled_tensor.LabeledTensor'),
):
  tc.register_type_abbreviation(_abbreviated_type, _abbreviation)
# Keep the loop variables out of the module namespace.
del _abbreviated_type, _abbreviation
4989d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
4999d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
@tc.returns(ops.Tensor)
@tc.accepts(LabeledTensor)
def _convert_labeled_tensor_to_tensor(value, *args, **kwargs):
  """Tensor conversion function for LabeledTensor.

  Unwraps `value` and converts its underlying tensor, forwarding any extra
  positional and keyword conversion arguments unchanged.
  """
  # Delegate to ops.internal_convert_to_tensor so the optional conversion
  # arguments are handled appropriately.
  unwrapped = value.tensor
  return ops.internal_convert_to_tensor(unwrapped, *args, **kwargs)
5059d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
5069d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
# Register the unwrapping conversion so that ops.convert_to_tensor accepts
# LabeledTensor values.
ops.register_tensor_conversion_function(
    LabeledTensor, _convert_labeled_tensor_to_tensor)

# tc class for anything that can be coerced into a LabeledTensor by
# convert_to_labeled_tensor.
LabeledTensorLike = tc.Union(  # pylint: disable=invalid-name
    LabeledTensor, ops.Tensor, np.ndarray, Scalar)
5149d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
5159d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
@tc.returns(LabeledTensor)
@tc.accepts(LabeledTensorLike, object, tc.Optional(string_types))
def convert_to_labeled_tensor(value, dtype=None, name=None):
  """Convert `value` into a `LabeledTensor`.

  A `LabeledTensor` input keeps its axes. Unlabeled inputs (`Tensor` objects,
  numpy arrays and Python scalars) are only accepted when they are
  0-dimensional; to label a higher-rank value, call the `LabeledTensor`
  constructor explicitly.

  Args:
    value: Object to convert.
    dtype: Optional element type for the returned tensor. If omitted, it is
      inferred from `value`.
    name: Optional name to use if a new Tensor is created.

  Returns:
    `value` as a `LabeledTensor` object.

  Raises:
    ValueError: If `value` is not already a `LabeledTensor` and its rank is
      greater than 0. Errors from `ops.convert_to_tensor` (for example, for
      an incompatible `dtype`) also propagate.
  """
  # TODO(shoyer): consider extending to accept xarray.DataArray as input.
  is_labeled = isinstance(value, LabeledTensor)
  axes = value.axes.values() if is_labeled else []
  raw_value = value.tensor if is_labeled else value

  # convert_to_tensor runs even for LabeledTensor input because it also
  # verifies that the requested dtype is compatible.
  tensor = ops.convert_to_tensor(raw_value, dtype=dtype, name=name)
  rank = len(tensor.get_shape())
  if rank != len(axes):
    raise ValueError('cannot automatically convert unlabeled arrays or tensors '
                     'with rank>0 into LabeledTensors: %r' % raw_value)
  return LabeledTensor(tensor, axes)
5529d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
5539d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
@tc.returns(Axis)
@tc.accepts(tc.Collection(Axis))
def concat_axes(axes):
  """Concatenate a collection of Axis objects that share one name.

  Args:
    axes: A non-empty collection of Axis objects, all with the same name.

  Returns:
    An Axis with the shared name, constructed with:
      * the concatenation of all labels, if every input axis has labels;
      * otherwise None, if any input axis has an unknown size;
      * otherwise the sum of the input axis sizes.

  Raises:
    ValueError: If `axes` is empty, contains an element that is not an
      Axis, or its elements do not all have the same name.
  """
  if not axes:
    raise ValueError('axes must not be empty')
  for axis in axes:
    if not isinstance(axis, Axis):
      raise ValueError('Expected an Axis, but got %r of type %r' %
                       (axis, type(axis)))

  distinct_names = {axis.name for axis in axes}
  if len(distinct_names) > 1:
    raise ValueError('axes do not all have the same name: %r' % distinct_names)
  (shared_name,) = distinct_names

  if all(axis.labels is not None for axis in axes):
    value = tuple(label for axis in axes for label in axis.labels)
  elif any(axis.size is None for axis in axes):
    value = None
  else:
    value = sum(len(axis) for axis in axes)
  return Axis(shared_name, value)
5929d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
5939d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
@tc.returns(LabeledTensor)
@tc.accepts(LabeledTensorLike, tc.Optional(string_types))
def identity(labeled_tensor, name=None):
  """Labeled analog of tf.identity.

  Args:
    labeled_tensor: The input; anything convert_to_labeled_tensor accepts.
    name: Optional op name.

  Returns:
    A LabeledTensor wrapping tf.identity of the input, with the input's axes.
  """
  with ops.name_scope(name, 'lt_identity', [labeled_tensor]) as scope:
    source = convert_to_labeled_tensor(labeled_tensor)
    copied = array_ops.identity(source.tensor, name=scope)
    return LabeledTensor(copied, source.axes)
6149d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
6159d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
# This is not named `slice` because that would shadow a built-in; it is
# exposed as lt.slice via an alias in __init__.py instead.
@tc.returns(LabeledTensor)
@tc.accepts(LabeledTensorLike,
            tc.Mapping(string_types, tc.Union(int, slice)),
            tc.Optional(string_types))
def slice_function(labeled_tensor, selection, name=None):
  """Slice out a subset of the tensor.

  This is an analog of tf.slice.
  For example:
  >>> tensor = tf.reshape(tf.range(0, 6), [3, 2])
  >>> labeled_tensor = lt.LabeledTensor(tensor, ['a', ('b', ['foo', 'bar'])])
  >>> lt.slice(labeled_tensor, {'a': slice(0, 2), 'b': 1})
  <LabeledTensor 'lt_slice:...' shape=(2,) dtype=int32
   axes=[('a', Dimension(2))]>

  Args:
    labeled_tensor: The input tensor.
    selection: A dictionary of type str -> Union(int, slice of int) mapping
      axis names to sub-selections. Axes missing from `selection` are kept
      whole; an int selection removes that axis from the result. Keys that
      do not name an axis of `labeled_tensor` are ignored.
    name: Optional op name.

  Returns:
    The slice as a `LabeledTensor`.
  """
  with ops.name_scope(name, 'lt_slice', [labeled_tensor]) as scope:
    labeled_tensor = convert_to_labeled_tensor(labeled_tensor)

    # One indexer per axis, in axis order; unselected axes take everything.
    slices = [
        selection[axis_name] if axis_name in selection else slice(None)
        for axis_name in labeled_tensor.axes
    ]
    sliced_tensor = labeled_tensor.tensor[tuple(slices)]

    sliced_axes = []
    for axis, indexer in zip(labeled_tensor.axes.values(), slices):
      if not isinstance(indexer, slice):
        # An integer index collapses this dimension, so the axis is dropped.
        assert isinstance(indexer, int)
        continue
      if axis.labels is None:
        # No coordinate labels are tracked for this axis; keep just its name.
        sliced_axes.append(axis.name)
      else:
        sliced_axes.append((axis.name, axis.labels[indexer]))

    result = array_ops.identity(sliced_tensor, name=scope)
    return LabeledTensor(result, sliced_axes)
6749d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
6759d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
@tc.returns(LabeledTensor)
@tc.accepts(LabeledTensorLike,
            tc.Optional(tc.Collection(string_types)), tc.Optional(string_types))
def transpose(labeled_tensor, axis_order=None, name=None):
  """Permute the axes of a labeled tensor.

  See tf.transpose.

  Args:
    labeled_tensor: The input tensor.
    axis_order: Optional collection of axis names giving the new order. It
      must be a permutation of the existing axis names. Defaults to the
      existing order reversed.
    name: Optional op name.

  Returns:
    A LabeledTensor whose data and axes are both arranged in `axis_order`.

  Raises:
    ValueError: If axis_order isn't a permutation of the existing axes.
  """
  with ops.name_scope(name, 'lt_transpose', [labeled_tensor]) as scope:
    labeled_tensor = convert_to_labeled_tensor(labeled_tensor)

    original_order = list(labeled_tensor.axes.keys())
    if axis_order is None:
      axis_order = original_order[::-1]
    elif sorted(axis_order) != sorted(original_order):
      raise ValueError(
          'The new axis order must have the same names as the original axes, '
          'but the new order is %r while the original order is %r' %
          (axis_order, original_order))

    # Axis names are unique, so each name maps to exactly one position.
    position = {axis_name: i for i, axis_name in enumerate(original_order)}
    permutation = [position[axis_name] for axis_name in axis_order]

    # Note: TensorFlow doesn't copy data for the identity transpose.
    transposed = array_ops.transpose(
        labeled_tensor.tensor, permutation, name=scope)

    permuted_axes = [labeled_tensor.axes[n] for n in axis_order]
    return LabeledTensor(transposed, permuted_axes)
7189d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7199d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7209d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.returns(LabeledTensor)
7216e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower@tc.accepts(
7226e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower    LabeledTensorLike,
7236e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower    tc.Collection(
7246e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower        tc.Union(string_types, tc.Tuple(string_types, collections.Hashable))),
7256e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower    tc.Optional(string_types))
7269d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef expand_dims(labeled_tensor, axes, name=None):
7279d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """Insert dimensions of size 1.
7289d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7299d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  See tf.expand_dims.
7309d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7319d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Args:
7329d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    labeled_tensor: The input tensor.
7339d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    axes: The desired axis names as strings or tuples of (name, label),
7349d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      where `label` is the coordinate name for the new dimension `name`.
7359d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      These must include the existing axis names, and the existing names must
7369d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      appear in the same order in this list as they do in the input tensor.
7379d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    name: Optional op name.
7389d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7399d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Returns:
7409d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    A tensor with an axis for each axis in axes.
7419d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    New axes are created with size 1 and do not have labeled coordinates.
7429d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7439d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Raises:
7449d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    AxisOrderError: If axis names don't appear in the same order in axes
7459d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      and the labeled tensor.
7469d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """
7479d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  with ops.name_scope(name, 'lt_expand_dims', [labeled_tensor]) as scope:
7489d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    labeled_tensor = convert_to_labeled_tensor(labeled_tensor)
7499d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7509d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    axis_names = [a if isinstance(a, string_types) else a[0] for a in axes]
7519d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    check_axis_order(labeled_tensor, axis_names)
7529d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7539d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    reshaped_axes = []
7549d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    shape = []
7559d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    for axis_spec in axes:
7569d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      if axis_spec in labeled_tensor.axes:
7579d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        axis = labeled_tensor.axes[axis_spec]
7589d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        reshaped_axes.append(axis)
7599d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        shape.append(-1 if axis.size is None else axis.size)
7609d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      else:
7619d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        if isinstance(axis_spec, string_types):
7629d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer          reshaped_axes.append((axis_spec, 1))
7639d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        else:
7649d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer          (name, label) = axis_spec
7659d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer          reshaped_axes.append((name, (label,)))
7669d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7679d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        shape.append(1)
7689d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7696e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower    reshaped_tensor = array_ops.reshape(
7706e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower        labeled_tensor.tensor, shape, name=scope)
7719d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7729d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return LabeledTensor(reshaped_tensor, reshaped_axes)
7739d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7746e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower
7759d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# This should only be added to a graph collection once.
7769d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer_AXIS_ORDER_KEY = ('__axis_order',)
7779d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7789d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7799d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.returns(tc.Optional(tc.List(string_types)))
7809d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef get_axis_order():
7819d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """Get the axis_order set by any containing axis_order_scope.
7829d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7839d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Returns:
7849d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    List of strings giving an order to use for axis names, or None, if no axis
7859d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    order is set.
7869d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """
7879d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  # By storing axis_order in the graph, we can ensure that axis_order_scope is
7889d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  # thread-safe.
7899d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  axis_order_list = ops.get_collection(_AXIS_ORDER_KEY)
7909d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  if axis_order_list:
7919d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    axis_order, = axis_order_list
7929d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  else:
7939d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    axis_order = None
7949d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  return axis_order
7959d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7969d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
7979d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.accepts(tc.Optional(tc.List(string_types)))
7989d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef _set_axis_order(axis_order):
7999d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  axis_order_list = ops.get_collection_ref(_AXIS_ORDER_KEY)
8009d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  if axis_order_list:
8019d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    axis_order_list[0] = axis_order
8029d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  else:
8039d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    axis_order_list.append(axis_order)
8049d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8059d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8069d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@contextlib.contextmanager
8079d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.accepts(tc.Optional(tc.List(string_types)))
8089d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef axis_order_scope(axis_order=None):
8099d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """Set axis order for the result of broadcasting operations within a scope.
8109d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8119d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  This allows you to ensure that tensors resulting from arithmetic have a
8129d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  predictable axis order.
8139d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8149d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Example usage:
8159d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8169d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    with lt.axis_order_scope(['x', 'y', 'z']):
817ee112cff56081fb9d0b74c987a8935acc360b05cBenoit Steiner      # result is guaranteed to have the correct axis order
8189d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      result = w + b
8199d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8209d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  You can nest scopes, in which case only the inner-most scope applies, e.g.,
8219d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8229d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    with lt.axis_order(['x', 'y', 'z']):
8239d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      with lt.axis_order():
8249d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        result = w + b  # uses the default (left-most) axis ordering
8259d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8269d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Args:
8279d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    axis_order: optional list of strings providing axis names. By default,
8289d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      creates a scope without axis order.
8299d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8309d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Yields:
8319d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    The provided axis_order or `None`.
8329d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """
8339d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  original_axis_order = get_axis_order()
8349d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  _set_axis_order(axis_order)
8359d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  try:
8369d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    yield axis_order
8379d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  finally:
8389d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    _set_axis_order(original_axis_order)
8399d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8409d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8419d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.returns(tc.List(string_types))
8429d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef _get_valid_axis_order():
8439d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  axis_order = get_axis_order()
8449d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  if axis_order is None:
8459d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    raise AxisOrderError('an explicit axis order must be provided with the '
8469d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer                         'axis_order argument or by using an axis_order_scope')
8479d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  return axis_order
8489d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8499d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8509d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerclass AxisOrderError(ValueError):
8519d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """Error class for cases where there is no valid axis order."""
8529d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8539d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8549d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# TODO(shoyer): should this function accept a list of labeled tensors instead?
8559d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.returns(type(None))
8569d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.accepts(LabeledTensorLike, tc.Optional(tc.Collection(string_types)))
8579d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef check_axis_order(labeled_tensor, axis_order=None):
8589d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """Verify that the given tensor has a consistent axis order.
8599d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8609d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Args:
8619d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    labeled_tensor: The input tensor. All axes on this tensor must appear in
8629d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      axis_order.
8639d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    axis_order: Optional desired axis order, as a list of names. If not
8649d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      provided, defaults to the current axis_order_scope (if set).
8659d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8669d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Raises:
8679d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    AxisOrderError: If the axis_order is unavailable, inconsistent or does not
8689d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      include all existing axes.
8699d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """
8709d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  labeled_tensor = convert_to_labeled_tensor(labeled_tensor)
8719d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8729d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  if axis_order is None:
8739d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    axis_order = _get_valid_axis_order()
8749d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8759d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  relevant_axis_order = [a for a in axis_order if a in labeled_tensor.axes]
8769d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8779d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  if len(relevant_axis_order) < len(labeled_tensor.axes):
8789d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    raise AxisOrderError(
8799d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        'not all axis names appear in the required axis order %r: %r' %
8809d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        (axis_order, labeled_tensor))
8819d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8829d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  if relevant_axis_order != list(labeled_tensor.axes):
8839d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    raise AxisOrderError(
8849d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        'axes on a labeled tensor do not appear in the same order as the '
8859d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        'required axis order %r: %r' % (axis_order, labeled_tensor))
8869d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8879d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8889d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.returns(LabeledTensor)
8896e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower@tc.accepts(LabeledTensorLike,
8906e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower            tc.Optional(tc.Collection(string_types)), tc.Optional(string_types))
8919d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef impose_axis_order(labeled_tensor, axis_order=None, name=None):
8929d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """Impose desired axis order on a labeled tensor.
8939d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
8949d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Args:
8959d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    labeled_tensor: The input tensor.
8969d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    axis_order: Optional desired axis order, as a list of names. If not
8979d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      provided, defaults to the current axis_order_scope (if set).
8989d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    name: Optional op name.
8999d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9009d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Returns:
9019d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    Labeled tensor with possibly transposed axes.
9029d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9039d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Raises:
9049d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    AxisOrderError: If no axis_order is provided or axis_order does not contain
9059d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      all axes on the input tensor.
9069d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """
9079d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  with ops.name_scope(name, 'lt_impose_axis_order', [labeled_tensor]) as scope:
9089d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    labeled_tensor = convert_to_labeled_tensor(labeled_tensor)
9099d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9109d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    if axis_order is None:
9119d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      axis_order = _get_valid_axis_order()
9129d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9139d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    relevant_axis_order = [a for a in axis_order if a in labeled_tensor.axes]
9149d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9159d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return transpose(labeled_tensor, relevant_axis_order, name=scope)
9169d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9179d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9189d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.returns(tc.Optional(list))
9199d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.accepts(list, list)
9209d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef _find_consistent_ordering(a, b):
9219d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """Find the left-most consistent ordering between two lists of unique items.
9229d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9239d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  A consistent ordering combines all elements in both a and b while keeping all
9249d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  elements in their original order in both inputs. The left-most consistent
9259d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  ordering orders elements from `a` not found in `b` before elements in `b` not
9269d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  found in `a`.
9279d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9289d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  For example, given ['x', 'z'] and ['y', 'z'], both ['x', 'y', 'z'] and ['y',
9299d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  'x', 'z'] are consistent orderings because each of the inputs appears in
9309d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  each consistent ordering in the same order, and ['x', 'y', 'z'] is the
9319d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  left-most, because 'x' appears only in `a` and 'y' appears only in `b`. In
9329d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  contrast, there is no consistent ordering between ['x', 'y'] and ['y', 'x'].
9339d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9349d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Args:
9359d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    a: list with unique elements.
9369d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    b: list with unique elements.
9379d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9389d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Returns:
9399d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    List containing all elements in either a or b, or None, if no consistent
9409d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    ordering exists.
9419d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """
9429d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  a_set = set(a)
9439d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  b_set = set(b)
9449d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  i = 0
9459d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  j = 0
9469d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  ordering = []
9479d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  while i < len(a) and j < len(b):
9489d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    if a[i] not in b_set:
9499d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      ordering.append(a[i])
9509d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      i += 1
9519d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    elif b[j] not in a_set:
9529d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      ordering.append(b[j])
9539d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      j += 1
9549d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    elif a[i] == b[j]:
9559d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      ordering.append(a[i])
9569d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      i += 1
9579d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      j += 1
9589d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    else:
9599d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      return None
9609d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9619d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  ordering.extend(a[i:])
9629d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  ordering.extend(b[j:])
9639d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9649d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  return ordering
9659d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9669d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9679d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.returns(LabeledTensor, LabeledTensor, Axes)
9689d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.accepts(LabeledTensorLike, LabeledTensorLike, tc.Optional(string_types))
9699d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef align(labeled_tensor_0, labeled_tensor_1, name=None):
9709d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """Align the axes of two tensors so they may be broadcast to each other.
9719d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9729d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Axes are ordered by the current axis order scope, if present, or by the left-
9739d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  most consistent ordering. An exception is raised if it is impossible to align
9749d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  the tensors without a transpose (align never copies the input data).
9759d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9769d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Example usage:
9779d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9789d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    >>> a = lt.LabeledTensor(tf.ones((2, 4)), ['x', 'z'])
9799d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    >>> b = lt.LabeledTensor(tf.ones((3, 4)), ['y', 'z'])
9809d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    >>> a2, b2, axes = lt.align(a, b)
9819d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    >>> a2
9829d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    <LabeledTensor 'lt_align_1/lt_align_1/0:...' shape=(2, 1, 4) dtype=float32
9839d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer     axes=[('x', Dimension(2)),
9849d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer           ('y', Dimension(1)),
9859d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer           ('z', Dimension(4))]>
9869d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    >>> b2
9879d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    <LabeledTensor 'lt_align_1/lt_align_1/1:...' shape=(1, 3, 4) dtype=float32
9889d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer     axes=[('x', Dimension(1)),
9899d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer           ('y', Dimension(3)),
9909d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer           ('z', Dimension(4))]>
9919d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    >>> axes
9929d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    Axes([('x', Dimension(2)),
9939d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer          ('y', Dimension(3)),
9949d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer          ('z', Dimension(4))])
9959d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
9969d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Args:
9979d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    labeled_tensor_0: An input tensor.
9989d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    labeled_tensor_1: An input tensor.
9999d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    name: Optional op name.
10009d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10019d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Returns:
10029d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    The aligned tensors and the axes the resulting tensor would have if the two
10039d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    aligned tensors were broadcast to each other. The aligned tensors have the
10049d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    same rank but not necessarily the same shape, with axes in the same order.
10059d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10069d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Raises:
10079d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    ValueError: If axes with the same name on the inputs are not equal.
10089d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    AxisOrderError: If there is no way to reshape the input tensors into the
10099d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      output without a transpose.
10109d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """
10119d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  with ops.name_scope(name, 'lt_align',
10129d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer                      [labeled_tensor_0, labeled_tensor_1]) as scope:
10139d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10149d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    labeled_tensor_0 = convert_to_labeled_tensor(labeled_tensor_0)
10159d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    labeled_tensor_1 = convert_to_labeled_tensor(labeled_tensor_1)
10169d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10179d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    axes_0 = labeled_tensor_0.axes
10189d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    axes_1 = labeled_tensor_1.axes
10199d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    for axis_name in axes_0:
10209d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      if axis_name in axes_1:
10219d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        if axes_0[axis_name] != axes_1[axis_name]:
10229d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer          raise ValueError('Mismatched %r axis on input tensors: %r and %r' %
10239d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer                           (axis_name, axes_0[axis_name], axes_1[axis_name]))
10249d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10259d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    axis_scope_order = get_axis_order()
10269d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    if axis_scope_order is not None:
10279d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      # we are in an axis_order_scope
10289d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      axis_names_set = set(axes_0) | set(axes_1)
10299d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      new_axis_names = [a for a in axis_scope_order if a in axis_names_set]
10309d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10319d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      check_axis_order(labeled_tensor_0, axis_scope_order)
10329d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      check_axis_order(labeled_tensor_1, axis_scope_order)
10339d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10349d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    else:
10359d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      # attempt to find a consistent ordering
10369d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      new_axis_names = _find_consistent_ordering(list(axes_0), list(axes_1))
10379d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      if new_axis_names is None:
10389d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        raise AxisOrderError(
10399d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer            'No consistent axis order allows for aligning tensors with axis '
10409d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer            'orders %r and %r without copying data. Use transpose or '
10419d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer            'impose_axis_order to reorder axes on one of more of the inputs.' %
10429d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer            (axes_0.keys(), axes_1.keys()))
10439d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10446e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower    labeled_tensor_0 = expand_dims(
10456e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower        labeled_tensor_0, new_axis_names, name=scope + '0')
10466e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower    labeled_tensor_1 = expand_dims(
10476e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower        labeled_tensor_1, new_axis_names, name=scope + '1')
10489d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10499d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    broadcast_axes = []
10509d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    for axis_name in new_axis_names:
10519d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      if axis_name in axes_0:
10529d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        broadcast_axes.append(axes_0[axis_name])
10539d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      else:
10549d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer        broadcast_axes.append(axes_1[axis_name])
10559d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10569d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    return labeled_tensor_0, labeled_tensor_1, Axes(broadcast_axes)
10579d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10589d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10599d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.returns(types.FunctionType)
10609d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.accepts(string_types, collections.Callable)
10619d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef define_unary_op(op_name, elementwise_function):
10629d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """Define a unary operation for labeled tensors.
10639d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10649d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Args:
10659d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    op_name: string name of the TensorFlow op.
10669d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    elementwise_function: function to call to evaluate the op on a single
10679d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      tf.Tensor object. This function must accept two arguments: a tf.Tensor
10689d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      object, and an optional `name`.
10699d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10709d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  Returns:
10719d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    Function defining the given op that acts on LabeledTensors.
10729d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  """
10739d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10749d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  default_name = 'lt_%s' % op_name
10759d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10769d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  @tc.returns(LabeledTensor)
10779d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  @tc.accepts(LabeledTensorLike, tc.Optional(string_types))
10789d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  def op(labeled_tensor, name=None):
10799d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    """LabeledTensor version of `tf.{op_name}`.
10809d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10819d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    See `tf.{op_name}` for full details.
10829d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10839d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    Args:
10849d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      labeled_tensor: Input tensor.
10859d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      name: Optional op name.
10869d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10879d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    Returns:
10889d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      A LabeledTensor with result of applying `tf.{op_name}` elementwise.
10899d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    """
10909d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer    with ops.name_scope(name, default_name, [labeled_tensor]) as scope:
10919d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      labeled_tensor = convert_to_labeled_tensor(labeled_tensor)
10929d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      result_tensor = elementwise_function(labeled_tensor.tensor, name=scope)
10939d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer      return LabeledTensor(result_tensor, labeled_tensor.axes)
10949d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10959d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  op.__doc__ = op.__doc__.format(op_name=op_name)
10969d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  op.__name__ = op_name
10979d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
10989d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer  return op
10999d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
11009d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
# Elementwise unary ops on LabeledTensors. The string passed to
# `define_unary_op` becomes the generated op's `__name__` and the `tf.<name>`
# referenced in its docstring. `abs` and `round` are bound as `*_function` so
# the Python builtins of the same name are not shadowed.
# NOTE(review): 'neg' is paired with `math_ops.negative`, so the generated
# docstring refers to `tf.neg` rather than `tf.negative`.

# Sign and magnitude.
abs_function = define_unary_op('abs', math_ops.abs)
neg = define_unary_op('neg', math_ops.negative)
sign = define_unary_op('sign', math_ops.sign)

# Rounding.
round_function = define_unary_op('round', math_ops.round)
ceil = define_unary_op('ceil', math_ops.ceil)
floor = define_unary_op('floor', math_ops.floor)

# Powers, roots, exponentials and logarithms.
reciprocal = define_unary_op('reciprocal', math_ops.reciprocal)
square = define_unary_op('square', math_ops.square)
sqrt = define_unary_op('sqrt', math_ops.sqrt)
rsqrt = define_unary_op('rsqrt', math_ops.rsqrt)
exp = define_unary_op('exp', math_ops.exp)
log = define_unary_op('log', math_ops.log)

# Trigonometric functions and their inverses.
cos = define_unary_op('cos', math_ops.cos)
sin = define_unary_op('sin', math_ops.sin)
tan = define_unary_op('tan', math_ops.tan)
acos = define_unary_op('acos', math_ops.acos)
asin = define_unary_op('asin', math_ops.asin)
atan = define_unary_op('atan', math_ops.atan)

# Hyperbolic tangent and logistic sigmoid.
tanh = define_unary_op('tanh', math_ops.tanh)
sigmoid = define_unary_op('sigmoid', math_ops.sigmoid)

# Special functions.
lgamma = define_unary_op('lgamma', math_ops.lgamma)
digamma = define_unary_op('digamma', math_ops.digamma)
erf = define_unary_op('erf', math_ops.erf)
erfc = define_unary_op('erfc', math_ops.erfc)

# Logical.
logical_not = define_unary_op('logical_not', math_ops.logical_not)
11269d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
11279d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
# NOTE(review): `collections.Callable` below is a deprecated alias of
# `collections.abc.Callable` and no longer exists on Python 3.10+; switch to the
# ABC once this module's imports provide it.
@tc.returns(types.FunctionType)
@tc.accepts(string_types, collections.Callable)
def define_binary_op(op_name, elementwise_function):
  """Define a binary operation that broadcasts labeled tensors.

  The returned op aligns its two inputs by axis label (via `align`), applies
  `elementwise_function` to the aligned tf.Tensor objects, and wraps the result
  in a LabeledTensor whose axes are the broadcast axes returned by `align`.

  Args:
    op_name: string name of the TensorFlow op. It becomes the returned
      function's `__name__`, is substituted into its docstring, and forms its
      default name scope (`lt_<op_name>`).
    elementwise_function: function to call to evaluate the op on tf.Tensor
      objects. This function must accept three arguments: two tf.Tensor objects,
      and an optional `name`.

  Returns:
    Function defining the given op that acts on LabeledTensors.
  """

  default_name = 'lt_%s' % op_name

  @tc.returns(LabeledTensor)
  @tc.accepts(LabeledTensorLike, LabeledTensorLike, tc.Optional(string_types))
  def op(labeled_tensor_0, labeled_tensor_1, name=None):
    """LabeledTensor version of `tf.{op_name}` with label based alignment.

    See `tf.{op_name}` for full details.

    Args:
      labeled_tensor_0: Input tensor.
      labeled_tensor_1: Input tensor.
      name: Optional op name.

    Returns:
      A LabeledTensor with result of applying `tf.{op_name}` elementwise.
    """
    with ops.name_scope(name, default_name,
                        [labeled_tensor_0, labeled_tensor_1]) as scope:

      # Align both inputs by axis label; `broadcast_axes` become the axes of
      # the result.
      align_0, align_1, broadcast_axes = align(labeled_tensor_0,
                                               labeled_tensor_1)

      tensor = elementwise_function(align_0.tensor, align_1.tensor, name=scope)

      return LabeledTensor(tensor, broadcast_axes)

  # Under `python -OO` docstrings are stripped and `op.__doc__` is None, so only
  # fill in the op name when a docstring template is actually present.
  if op.__doc__ is not None:
    op.__doc__ = op.__doc__.format(op_name=op_name)
  op.__name__ = op_name

  return op
11749d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
11759d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer
# Elementwise binary ops on LabeledTensors, aligned by axis label. Each string
# passed to `define_binary_op` becomes the op's `__name__`, the `tf.<name>` in
# its docstring, and its default name scope `lt_<name>`. `pow` is bound as
# `pow_function` so the Python builtin is not shadowed.
# NOTE(review): 'sub' and 'mul' are paired with `math_ops.subtract` and
# `math_ops.multiply`, so their generated docstrings refer to `tf.sub` and
# `tf.mul`.

# Arithmetic.
add = define_binary_op('add', math_ops.add)
sub = define_binary_op('sub', math_ops.subtract)
mul = define_binary_op('mul', math_ops.multiply)
div = define_binary_op('div', math_ops.div)
mod = define_binary_op('mod', math_ops.mod)
pow_function = define_binary_op('pow', math_ops.pow)
squared_difference = define_binary_op('squared_difference',
                                      math_ops.squared_difference)

# Elementwise extrema.
maximum = define_binary_op('maximum', math_ops.maximum)
minimum = define_binary_op('minimum', math_ops.minimum)

# Comparisons.
equal = define_binary_op('equal', math_ops.equal)
not_equal = define_binary_op('not_equal', math_ops.not_equal)
less = define_binary_op('less', math_ops.less)
less_equal = define_binary_op('less_equal', math_ops.less_equal)
greater = define_binary_op('greater', math_ops.greater)
greater_equal = define_binary_op('greater_equal', math_ops.greater_equal)

# Logical.
logical_and = define_binary_op('logical_and', math_ops.logical_and)
logical_or = define_binary_op('logical_or', math_ops.logical_or)
logical_xor = define_binary_op('logical_xor', math_ops.logical_xor)

# Special functions.
igamma = define_binary_op('igamma', math_ops.igamma)
igammac = define_binary_op('igammac', math_ops.igammac)
zeta = define_binary_op('zeta', math_ops.zeta)
polygamma = define_binary_op('polygamma', math_ops.polygamma)
1201