19d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 29d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# 39d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# Licensed under the Apache License, Version 2.0 (the "License"); 49d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# you may not use this file except in compliance with the License. 59d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# You may obtain a copy of the License at 69d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# 79d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# http://www.apache.org/licenses/LICENSE-2.0 89d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# 99d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# Unless required by applicable law or agreed to in writing, software 109d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# distributed under the License is distributed on an "AS IS" BASIS, 119d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 129d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# See the License for the specific language governing permissions and 139d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# limitations under the License. 149d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# ============================================================================== 159d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer"""Core classes and core ops for LabeledTensor. 169d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 179d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan HoyerCore ops are ops which will eventually be called by LabeledTensor methods, 189d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerand ops which a core op depends upon. 
199d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan HoyerFor example, `add` is a core op because we'll eventually support the `+` 209d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyeroperator. 219d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan HoyerNon-core ops should go in `ops.py`. 229d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer""" 239d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom __future__ import absolute_import 249d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom __future__ import division 259d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom __future__ import print_function 269d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 279d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerimport collections 289d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerimport contextlib 299d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerimport numbers 309d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerimport types 319d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 329d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerimport numpy as np 339d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom six import binary_type 349d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom six import string_types 359d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom six import text_type 369d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom six.moves import range # pylint: disable=redefined-builtin 379d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 389d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc 399d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom tensorflow.python.framework import dtypes 409d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom tensorflow.python.framework import ops 419d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerfrom tensorflow.python.framework import tensor_shape 429d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan 
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops

# The ABC aliases in `collections` (e.g. `collections.Mapping`) were deprecated
# in Python 3.3 and removed in Python 3.10; on Python 2 the ABCs only exist
# directly in `collections`.
try:
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
except ImportError:  # Python 2
  collections_abc = collections

# pylint: disable=invalid-name

# Types coercible to Axis.labels
# We use this instead of collections.Sequence to exclude strings.
LabelsLike = tc.Union(np.ndarray, range, list, tuple)

# Types coercible to a tf.Dimension
DimensionLike = tc.Optional(tc.Union(tensor_shape.Dimension, int))

# Types usable for axis values
AxisValue = tc.Union(LabelsLike, DimensionLike)

# Valid scalar values for TensorFlow
Scalar = tc.Union(numbers.Number, bool, binary_type, text_type)

# pylint: enable=invalid-name


class Axis(object):
  """Size and label information for an axis.

  Axis contains either a tf.Dimension indicating the size of an axis,
  or a tuple of tick labels for the axis.

  If tick labels are provided, they must be unique.
  """

  @tc.accepts(object, string_types, AxisValue)
  def __init__(self, name, value):
    """Construct an Axis.

    Args:
      name: Name of the axis.
      value: Either None, an int or tf.Dimension giving the size of the axis,
        or a sequence that is not a string additionally providing coordinate
        (tick) labels.

    Raises:
      ValueError: If the user provides labels with duplicate values.
    """
    if isinstance(value, tensor_shape.Dimension):
      dimension = value
      labels = None
    elif isinstance(value, int) or value is None:
      dimension = tensor_shape.Dimension(value)
      labels = None
    else:
      dimension = tensor_shape.Dimension(len(value))
      labels = tuple(value)

    if dimension.value == 0:
      # Treat a zero-length axis as if it has labels.
      labels = ()

    if labels is not None:
      # Map each label to its position. If the dict ends up shorter than the
      # labels tuple, some label occurred more than once.
      index = dict(zip(labels, range(len(labels))))
      if len(index) != len(labels):
        raise ValueError('Tick labels must be unique, but got {}'
                         .format(labels))
    else:
      index = None

    self._name = name  # type: string_types
    self._dimension = dimension  # type: tensor_shape.Dimension
    self._labels = labels  # type: Optional[tuple]
    self._index = index  # type: Optional[Dict[Any, int]]

  @property
  @tc.returns(string_types)
  def name(self):
    """Returns the name of this axis."""
    return self._name

  @tc.returns(string_types)
  def __repr__(self):
    # Axis('x', Dimension(2))
    # TODO(shoyer): make very long reprs more succinct?
    return "%s('%s', %r)" % (type(self).__name__, self.name, self.value)

  @tc.returns(bool)
  def __eq__(self, other):
    return (isinstance(other, Axis) and self.name == other.name and
            self.size == other.size and self.labels == other.labels)

  def __hash__(self):
    # Hashes exactly the fields compared by __eq__.
    return hash((self.name, self.size, self.labels))

  @tc.returns(bool)
  def __ne__(self, other):
    return not self == other

  @tc.returns(int)
  def __len__(self):
    """Returns the axis size.

    Raises:
      ValueError: If the size of the axis is unknown (None).
    """
    size = self.size
    if size is None:
      raise ValueError('axis %r has unknown length' % self.name)
    return size

  @property
  @tc.returns(tc.Optional(tensor_shape.Dimension))
  def dimension(self):
    """Returns the tf.Dimension holding this axis's size."""
    return self._dimension

  @property
  @tc.returns(tc.Optional(int))
  def size(self):
    """Returns the integer size of this axis, or None if it is unknown."""
    return self._dimension.value

  @property
  @tc.returns(tc.Union(tuple, tensor_shape.Dimension))
  def value(self):
    """Returns the tf.Dimension or tuple specifying axis ticks."""
    if self.labels is None:
      return self.dimension
    else:
      return self.labels

  @property
  @tc.returns(tc.Optional(tuple))
  def labels(self):
    """Returns the tuple containing coordinate labels, else None."""
    return self._labels

  def index(self, value):
    """Returns the integer position of the given tick label.

    Raises:
      ValueError: If this axis has no tick labels.
      KeyError: If `value` is not one of this axis's tick labels.
    """
    if self._index is None:
      raise ValueError('Axis does not have tick labels')
    return self._index[value]


# tc class for anything that can be coerced into an Axis
# pylint: disable=invalid-name
AxisLike = tc.Union(Axis, tc.Tuple(string_types, AxisValue))
# pylint: enable=invalid-name


@tc.returns(Axis)
@tc.accepts(AxisLike)
def as_axis(axis_data):
  """Convert an AxisLike object into an Axis.

  Args:
    axis_data: Axis object or tuple (axis_name, axis_value) describing an axis.

  Returns:
    Axis object. This may be the original object if axis_data is an Axis.
  """
  if isinstance(axis_data, Axis):
    axis = axis_data
  else:
    axis = Axis(*axis_data)
  return axis


class Axes(collections_abc.Mapping):
  """Axis names and indices for a tensor.

  It is an ordered mapping, with keys given by axis name and values given
  by Axis objects. Duplicate axis names are not allowed.
  """

  @tc.accepts(object, tc.List(AxisLike))
  def __init__(self, axes):
    """Construct an Axes.

    Args:
      axes: A list of Axis objects or (axis_name, axis_value) tuples.

    Raises:
      ValueError: If the user provides duplicate axis names.
    """
    self._axes = collections.OrderedDict()

    for axis_data in axes:
      axis = as_axis(axis_data)

      name = axis.name
      if name in self._axes:
        raise ValueError('Duplicate axis name: %s' % name)

      self._axes[name] = axis

  def __iter__(self):
    return iter(self._axes)

  @tc.returns(string_types)
  def __repr__(self):
    # Axes([('x', Dimension(2)),
    #       ('y', ('a', 'b', 'c')),
    #       ('z', Dimension(4))])
    cls_name = type(self).__name__
    values = ["('%s', %r)" % (v.name, v.value) for v in self._axes.values()]
    values_repr = (',\n' + ' ' * len(cls_name + '([')).join(values)
    return '%s([%s])' % (cls_name, values_repr)

  @tc.returns(Axis)
  @tc.accepts(object, string_types)
  def __getitem__(self, name):
    return self._axes[name]

  @tc.returns(bool)
  def __contains__(self, name):
    return name in self._axes

  @tc.returns(int)
  def __len__(self):
    return len(self._axes)

  def __hash__(self):
    # Equality is inherited from Mapping.__eq__, which compares
    # dict(self.items()) and therefore ignores axis order. Objects that compare
    # equal must hash equal, so the hash must ignore order as well.
    return hash(frozenset(self.items()))

  @tc.accepts(object, string_types)
  def remove(self, axis_name):
    """Creates a new Axes object without the given axis.

    Raises:
      KeyError: If `axis_name` is not one of the axes.
    """
    if axis_name not in self:
      raise KeyError(axis_name)
    remaining_axes = [axis for axis in self.values() if axis.name != axis_name]
    return Axes(remaining_axes)
2639d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 2649d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerclass LabeledTensor(object): 2659d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """A tensor with annotated axes. 2669d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 2679d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer It has the following invariants: 2689d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 1) The dimensionality of the tensor is equal to the number of elements 2699d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer in axes. 2709d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 2) The number of coordinate values in the ith dimension is equal to the 2719d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer size of the tensor in the ith dimension. 2729d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 2739d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Attributes: 2749d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer tensor: tf.Tensor containing the data. 2759d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axes: lt.Axes containing axis names and coordinate labels. 2769d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """ 2779d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 2783866dd2b20e5b63ad4ea8d0edb7961a2252d906dJonathan Hseu @tc.accepts(object, ops.Tensor, 2799d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer tc.Union(Axes, tc.Collection(tc.Union(string_types, AxisLike)))) 2809d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __init__(self, tensor, axes): 281e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai """Construct a LabeledTensor. 2829d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 2839d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Args: 2849d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer tensor: The underlying tensor containing the data. 
2859d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axes: An Axes object, or a collection of strings, Axis objects or tuples 2869d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer of (name, value) pairs indicating the axes. 2879d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 2889d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Raises: 2899d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer ValueError: If the provided axes do not satisfy the class invariants. 2909d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """ 2919d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer self._tensor = tensor 2929d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer shape = tensor.get_shape() 2939d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 2949d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if isinstance(axes, Axes): 2959d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer unvalidated_axes = axes 2969d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer else: 2979d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer mutable_axes = [] 2989d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 2999d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer for position, axis_like in enumerate(axes): 3009d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if isinstance(axis_like, string_types): 3019d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # The coordinates for this axes are unlabeled. 3029d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # Infer the size of the axis. 
3039d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer value = shape[position] 3049d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axis_like = (axis_like, value) 3059d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3069d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer mutable_axes.append(axis_like) 3079d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3089d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # Construct the Axis object, which will additionally validate the contents 3099d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # of the object. 3109d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer unvalidated_axes = Axes(mutable_axes) 3119d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3129d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # Check our invariants. 3139d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3149d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # First, the rank of the tensor must be equal to the number of axes. 3159d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if len(shape) != len(unvalidated_axes): 3169d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer raise ValueError('Tensor rank was not equal to the number of axes: %r, %r' 3179d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer % (shape, unvalidated_axes)) 3189d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3199d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # Second, the size of each tensor dimension must match the size of the 3209d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # corresponding indices. 
3219d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer for (d, axis) in zip(shape, unvalidated_axes.values()): 3229d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if d != axis.size: 3239d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer raise ValueError( 3249d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 'Provided axis size %d does not match tensor dimension size %d' % 3259d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer (axis.size, d)) 3269d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3279d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer self._axes = unvalidated_axes 3289d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3299d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __repr__(self): 3309d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # <LabeledTensor 'foo' shape=(2, 3, 4) dtype=float32 3319d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # axes=[('x', Dimension(2)), 3329d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # ('y', ('a', 'b', 'c'), 3339d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # ('z', Dimension(4))]> 3349d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axes = ["('%s', %r)" % (v.name, v.value) for v in self.axes.values()] 3359d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axes_repr = (',\n' + ' ' * len(' axes=[')).join(axes) 3369d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return ("<%s '%s' shape=%s dtype=%s\n axes=[%s]>" % 3379d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer (type(self).__name__, self.tensor.name, self.tensor.get_shape(), 3389d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer self.tensor.dtype.name, axes_repr)) 3399d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3409d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer @property 3419d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def tensor(self): 3429d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return self._tensor 3439d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 
3449d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def _as_graph_element(self): 3459d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """Support tf.Graph.as_graph_element on LabeledTensor objects. 3469d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3479d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer This allows operations such as tf.name_scope to take labeled tensors. 3489d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3499d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Returns: 3509d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer self.tensor 3519d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """ 3529d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return self.tensor 3539d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3549d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer @property 3559d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def axes(self): 3569d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return self._axes 3579d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3589d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # properties/methods directly borrowed from tf.Tensor: 3599d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3609d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer @property 3619d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def dtype(self): 3629d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return self._tensor.dtype 3639d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3649d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer @property 365128572c316e6f2eb6346f920314ef98e88e75069A. Unique TensorFlower def shape(self): 366128572c316e6f2eb6346f920314ef98e88e75069A. Unique TensorFlower return self._tensor.shape 367128572c316e6f2eb6346f920314ef98e88e75069A. Unique TensorFlower 368128572c316e6f2eb6346f920314ef98e88e75069A. 
Unique TensorFlower @property 3699d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def name(self): 3709d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return self._tensor.name 3719d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3729d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def get_shape(self): 3739d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """Returns the TensorShape that represents the shape of this tensor. 3749d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3759d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer See tf.Tensor.get_shape(). 3769d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3779d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Returns: 3789d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer A TensorShape representing the shape of this tensor. 3799d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """ 3809d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return self._tensor.get_shape() 3819d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3829d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # TODO(shoyer): consider how/if to implement .eval(). Maybe it should return 3839d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # an xarray.DataArray? 3849d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3859d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __getitem__(self, key): 3869d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # This should work exactly like tf.Tensor.__getitem__, except it preserves 3879d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # labels. 
3889d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if not isinstance(key, tuple): 3899d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer key = (key,) 3909d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if len(key) != len(self.axes): 3919d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer raise ValueError('indexer %r must have the same length as the Tensor ' 3929d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 'rank (%r)' % (key, len(self.axes))) 3939d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer selection = {a: k for a, k in zip(self.axes.keys(), key)} 3949d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return slice_function(self, selection) 3959d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3969d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # special methods for overloading arithmetic operations: 3979d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 3989d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __abs__(self): 3999d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return abs_function(self) 4009d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4019d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __neg__(self): 4029d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return neg(self) 4039d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4049d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __pos__(self): 4059d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return self 4069d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4079d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __add__(self, other): 4089d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return add(self, other) 4099d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4109d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __radd__(self, other): 4119d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return add(other, self) 4129d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 
4139d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __sub__(self, other): 4149d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return sub(self, other) 4159d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4169d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __rsub__(self, other): 4179d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return sub(other, self) 4189d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4199d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __mul__(self, other): 4209d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return mul(self, other) 4219d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4229d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __rmul__(self, other): 4239d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return mul(other, self) 4249d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4259d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __truediv__(self, other): 4269d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return div(self, other) 4279d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4289d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer __div__ = __truediv__ 4299d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4309d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __rtruediv__(self, other): 4319d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return div(other, self) 4329d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4339d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer __rdiv__ = __rtruediv__ 4349d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4359d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __mod__(self, other): 4369d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return mod(self, other) 4379d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4389d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __rmod__(self, other): 4399d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return mod(other, self) 
4409d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4419d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __pow__(self, other): 4429d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return pow_function(self, other) 4439d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4449d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __rpow__(self, other): 4459d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return pow_function(other, self) 4469d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4479d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # logical operations: 4489d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4499d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __invert__(self): 4509d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return logical_not(self) 4519d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4529d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __and__(self, other): 4539d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return logical_and(self, other) 4549d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4559d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __or__(self, other): 4569d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return logical_or(self, other) 4579d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4589d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __xor__(self, other): 4599d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return logical_xor(self, other) 4609d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4619d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # boolean operations: 4629d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4639d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __lt__(self, other): 4649d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return less(self, other) 4659d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4669d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __le__(self, other): 
4679d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return less_equal(self, other) 4689d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4699d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __gt__(self, other): 4709d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return greater(self, other) 4719d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4729d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __ge__(self, other): 4739d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return greater_equal(self, other) 4749d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4759d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __eq__(self, other): 4769d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # for consistency with tf.Tensor 4779d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if not isinstance(other, LabeledTensor): 4789d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return False 4799d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4809d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return self.tensor == other.tensor and self.axes == other.axes 4819d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4829d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __ne__(self, other): 4839d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return not self == other 4849d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4859d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer def __hash__(self): 4869d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return hash((self.tensor, self.axes)) 4879d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4889d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4899d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# typecheck type abbreviations: 4909d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# abbreviations for third-party types with very long reprs 4919d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyertc.register_type_abbreviation(tensor_shape.Dimension, 
'tensorflow.Dimension') 4923866dd2b20e5b63ad4ea8d0edb7961a2252d906dJonathan Hseutc.register_type_abbreviation(ops.Tensor, 'tensorflow.Tensor') 4939d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyertc.register_type_abbreviation(dtypes.DType, 'tensorflow.DType') 4949d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# core LabeledTensor types 4959d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyertc.register_type_abbreviation(Axis, 'labeled_tensor.Axis') 4969d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyertc.register_type_abbreviation(Axes, 'labeled_tensor.Axes') 4979d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyertc.register_type_abbreviation(LabeledTensor, 'labeled_tensor.LabeledTensor') 4989d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 4999d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5003866dd2b20e5b63ad4ea8d0edb7961a2252d906dJonathan Hseu@tc.returns(ops.Tensor) 5019d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.accepts(LabeledTensor) 5029d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef _convert_labeled_tensor_to_tensor(value, *args, **kwargs): 5039d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # call ops.convert_to_tensor to handle optional arguments appropriately 5040672d414381b66dd63cc58ea860f8dd93af0083cA. Unique TensorFlower return ops.internal_convert_to_tensor(value.tensor, *args, **kwargs) 5059d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5069d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5076e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlowerops.register_tensor_conversion_function(LabeledTensor, 5086e9265f74e20e6818651470887846d8292083f66A. 
Unique TensorFlower _convert_labeled_tensor_to_tensor) 5099d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5109d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# tc class for anything that can be coerced into a LabeledTensor 5119d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# pylint: disable=invalid-name 5123866dd2b20e5b63ad4ea8d0edb7961a2252d906dJonathan HseuLabeledTensorLike = tc.Union(LabeledTensor, ops.Tensor, np.ndarray, Scalar) 5139d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# pylint: enable=invalid-name 5149d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5159d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5169d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.returns(LabeledTensor) 5179d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.accepts(LabeledTensorLike, object, tc.Optional(string_types)) 5189d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef convert_to_labeled_tensor(value, dtype=None, name=None): 5199d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """Converts the given `value` to a `LabeledTensor`. 5209d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 52185eeec0d415a1478bbeffc3d4545c795bee64e9fJonathan Hseu This function accepts `LabeledTensor` objects, 0-dimensional `Tensor` objects 5229d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer and numpy arrays, and Python scalars. Higher dimensional unlabeled tensors 5239d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer must use the `LabeledTensor` constructor explicitly. 5249d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5259d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Args: 5269d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer value: Object to convert. 5279d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer dtype: Optional element type for the returned tensor. If missing, the type 5289d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer is inferred from the type of value. 
5299d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer name: Optional name to use if a new Tensor is created. 5309d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5319d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Returns: 5329d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer `value` converted into a `LabeledTensor` object. 5339d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5349d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Raises: 5359d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer ValueError: If the output would have rank>0 but the input was not already a 5369d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer `LabeledTensor`. 5379d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """ 5389d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # TODO(shoyer): consider extending to accept xarray.DataArray as input. 5399d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if isinstance(value, LabeledTensor): 5409d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axes = value.axes.values() 5419d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer value = value.tensor 5429d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer else: 5439d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axes = [] 5449d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5459d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # We call convert_to_tensor even for LabeledTensor input because it also 5469d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # checks to make sure the dtype argument is compatible. 
5479d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer tensor = ops.convert_to_tensor(value, dtype=dtype, name=name) 5489d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if len(tensor.get_shape()) != len(axes): 5499d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer raise ValueError('cannot automatically convert unlabeled arrays or tensors ' 5509d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 'with rank>0 into LabeledTensors: %r' % value) 5519d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return LabeledTensor(tensor, axes) 5529d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5539d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5549d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.returns(Axis) 5559d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.accepts(tc.Collection(Axis)) 5569d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef concat_axes(axes): 5579d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """Concatenate a list of Axes. 5589d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5599d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Args: 5609d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axes: A collection of Axis objects. 5619d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5629d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Returns: 5639d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer The concatenation of the axes. 5649d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer If all axes have labels, the result has the concatenation of the labels. 5659d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Else, the result has no labels, and its size is the sum of the sizes 5669d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer of the axes. 5679d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5689d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Raises: 5699d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer ValueError: If `others` is not a collection of Axes or if it is empty. 
5709d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """ 5719d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if not axes: 5729d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer raise ValueError('axes must not be empty') 5739d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer for a in axes: 5749d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if not isinstance(a, Axis): 5759d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer raise ValueError('Expected an Axis, but got %r of type %r' % (a, type(a))) 5769d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5779d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer names = set(a.name for a in axes) 5789d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if len(names) > 1: 5799d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer raise ValueError('axes do not all have the same name: %r' % names) 5809d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer name, = names 5819d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5829d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer all_have_labels = all(a.labels is not None for a in axes) 5839d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer any_has_unknown_size = any(a.size is None for a in axes) 5849d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5859d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if all_have_labels: 5869d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer value = tuple(label for a in axes for label in a.labels) 5879d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer elif any_has_unknown_size: 5889d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer value = None 5899d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer else: 5909d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer value = sum(len(a) for a in axes) 5919d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return Axis(name, value) 5929d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 5939d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 
@tc.returns(LabeledTensor)
@tc.accepts(LabeledTensorLike, tc.Optional(string_types))
def identity(labeled_tensor, name=None):
  """Labeled counterpart of tf.identity.

  Args:
    labeled_tensor: The input tensor (anything LabeledTensorLike).
    name: Optional op name.

  Returns:
    A LabeledTensor wrapping an identity op on the input, with the same axes.
  """
  with ops.name_scope(name, 'lt_identity', [labeled_tensor]) as scope:
    source = convert_to_labeled_tensor(labeled_tensor)
    passthrough = array_ops.identity(source.tensor, name=scope)
    return LabeledTensor(passthrough, source.axes)


# We don't call this slice because that shadows a built-in. Instead, we alias
# this to lt.slice in __init__.py.
6189d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.returns(LabeledTensor) 6196e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower@tc.accepts(LabeledTensorLike, 6206e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower tc.Mapping(string_types, tc.Union(int, slice)), 6219d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer tc.Optional(string_types)) 6229d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef slice_function(labeled_tensor, selection, name=None): 6239d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """Slice out a subset of the tensor. 6249d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 6251b5235fd897f7ea5cffc715300f67b4dc852fa27Jonathan Hseu This is an analog of tf.slice. 6269d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer For example: 6279d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer >>> tensor = tf.reshape(tf.range(0, 6), [3, 2]) 6289d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer >>> labeled_tensor = lt.LabeledTensor(tensor, ['a', ('b', ['foo', 'bar'])]) 6299d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer >>> lt.slice(labeled_tensor, {'a': slice(0, 2), 'b': 1}) 6309d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer <LabeledTensor 'lt_slice:...' shape=(2,) dtype=int32 6319d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axes=[('a', Dimension(2))]> 6329d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 6339d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Args: 6349d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer labeled_tensor: The input tensor. 6359d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer selection: A dictionary of type str -> Union(int, slice of int) mapping 6369d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axis names to sub-selections. 6379d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer name: Optional op name. 
6389d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 6399d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Returns: 6409d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer The slice as a `LabeledTensor`. 6419d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """ 6429d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer with ops.name_scope(name, 'lt_slice', [labeled_tensor]) as scope: 6439d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer labeled_tensor = convert_to_labeled_tensor(labeled_tensor) 6449d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 6459d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer slices = [] 6469d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 6479d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer for axis_name in labeled_tensor.axes: 6489d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if axis_name not in selection: 6499d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # We're not sub-selecting this axis, so use the full slice. 6509d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer slices.append(slice(None)) 6519d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer else: 6529d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer slices.append(selection[axis_name]) 6539d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 6549d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer sliced_tensor = labeled_tensor.tensor[tuple(slices)] 6559d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 6569d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer sliced_axes = [] 6579d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer for axis, s in zip(labeled_tensor.axes.values(), slices): 6589d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # We sub-select this axis's index with the slice s. 6599d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 6609d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # `s` is either an int or a proper slice. 
6619d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if isinstance(s, slice): 6629d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if axis.labels is None: 6639d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # We're not tracking coordinate names for this axis. 6649d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer sliced_axes.append(axis.name) 6659d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer else: 6669d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer sliced_axes.append((axis.name, axis.labels[s])) 6679d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer else: 6689d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # If the slice is an int this dimension now has size 1, so we remove it. 6699d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer assert isinstance(s, int) 6709d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 6716e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower return LabeledTensor( 6726e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower array_ops.identity( 6736e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower sliced_tensor, name=scope), sliced_axes) 6749d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 6759d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 6769d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.returns(LabeledTensor) 6776e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower@tc.accepts(LabeledTensorLike, 6786e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower tc.Optional(tc.Collection(string_types)), tc.Optional(string_types)) 6799d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef transpose(labeled_tensor, axis_order=None, name=None): 6809d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """Permute a tensor's axes. 6819d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 6829d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer See tf.transpose. 
6839d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 6849d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Args: 6859d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer labeled_tensor: The input tensor. 6869d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axis_order: Optional desired axis order, as a list of names. By default, the 6879d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer order of axes is reversed. 6889d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer name: Optional op name. 6899d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 6909d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Returns: 6919d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer The permuted tensor. 6929d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 6939d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Raises: 6949d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer ValueError: If axis_order isn't a permutation of the existing axes. 6959d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """ 6969d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer with ops.name_scope(name, 'lt_transpose', [labeled_tensor]) as scope: 6979d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer labeled_tensor = convert_to_labeled_tensor(labeled_tensor) 6989d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 6999d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer original_order = list(labeled_tensor.axes.keys()) 7009d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if axis_order is None: 7019d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axis_order = list(reversed(original_order)) 7029d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer elif sorted(axis_order) != sorted(original_order): 7039d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer raise ValueError( 7049d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 'The new axis order must have the same names as the original axes, ' 7059d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 'but the new order is %r 
while the original order is %r' % 7069d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer (axis_order, original_order)) 7079d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7089d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axis_names = list(labeled_tensor.axes.keys()) 7099d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer permutation = [axis_names.index(n) for n in axis_order] 7109d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7111b5235fd897f7ea5cffc715300f67b4dc852fa27Jonathan Hseu # Note: TensorFlow doesn't copy data for the identity transpose. 7126e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower transpose_tensor = array_ops.transpose( 7136e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower labeled_tensor.tensor, permutation, name=scope) 7149d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7159d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer permuted_axes = [labeled_tensor.axes[n] for n in axis_order] 7169d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7179d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return LabeledTensor(transpose_tensor, permuted_axes) 7189d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7199d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7209d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.returns(LabeledTensor) 7216e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower@tc.accepts( 7226e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower LabeledTensorLike, 7236e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower tc.Collection( 7246e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower tc.Union(string_types, tc.Tuple(string_types, collections.Hashable))), 7256e9265f74e20e6818651470887846d8292083f66A. 
Unique TensorFlower tc.Optional(string_types)) 7269d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef expand_dims(labeled_tensor, axes, name=None): 7279d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """Insert dimensions of size 1. 7289d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7299d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer See tf.expand_dims. 7309d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7319d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Args: 7329d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer labeled_tensor: The input tensor. 7339d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axes: The desired axis names as strings or tuples of (name, label), 7349d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer where `label` is the coordinate name for the new dimension `name`. 7359d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer These must include the existing axis names, and the existing names must 7369d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer appear in the same order in this list as they do in the input tensor. 7379d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer name: Optional op name. 7389d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7399d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Returns: 7409d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer A tensor with an axis for each axis in axes. 7419d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer New axes are created with size 1 and do not have labeled coordinates. 7429d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7439d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Raises: 7449d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer AxisOrderError: If axis names don't appear in the same order in axes 7459d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer and the labeled tensor. 
7469d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """ 7479d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer with ops.name_scope(name, 'lt_expand_dims', [labeled_tensor]) as scope: 7489d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer labeled_tensor = convert_to_labeled_tensor(labeled_tensor) 7499d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7509d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axis_names = [a if isinstance(a, string_types) else a[0] for a in axes] 7519d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer check_axis_order(labeled_tensor, axis_names) 7529d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7539d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer reshaped_axes = [] 7549d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer shape = [] 7559d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer for axis_spec in axes: 7569d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if axis_spec in labeled_tensor.axes: 7579d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axis = labeled_tensor.axes[axis_spec] 7589d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer reshaped_axes.append(axis) 7599d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer shape.append(-1 if axis.size is None else axis.size) 7609d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer else: 7619d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if isinstance(axis_spec, string_types): 7629d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer reshaped_axes.append((axis_spec, 1)) 7639d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer else: 7649d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer (name, label) = axis_spec 7659d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer reshaped_axes.append((name, (label,))) 7669d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7679d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer shape.append(1) 7689d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7696e9265f74e20e6818651470887846d8292083f66A. 
Unique TensorFlower reshaped_tensor = array_ops.reshape( 7706e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower labeled_tensor.tensor, shape, name=scope) 7719d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7729d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return LabeledTensor(reshaped_tensor, reshaped_axes) 7739d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7746e9265f74e20e6818651470887846d8292083f66A. Unique TensorFlower 7759d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# This should only be added to a graph collection once. 7769d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer_AXIS_ORDER_KEY = ('__axis_order',) 7779d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7789d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7799d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.returns(tc.Optional(tc.List(string_types))) 7809d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef get_axis_order(): 7819d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """Get the axis_order set by any containing axis_order_scope. 7829d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7839d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Returns: 7849d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer List of strings giving an order to use for axis names, or None, if no axis 7859d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer order is set. 7869d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """ 7879d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # By storing axis_order in the graph, we can ensure that axis_order_scope is 7889d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer # thread-safe. 
7899d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axis_order_list = ops.get_collection(_AXIS_ORDER_KEY) 7909d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if axis_order_list: 7919d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axis_order, = axis_order_list 7929d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer else: 7939d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axis_order = None 7949d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return axis_order 7959d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7969d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 7979d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.accepts(tc.Optional(tc.List(string_types))) 7989d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef _set_axis_order(axis_order): 7999d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axis_order_list = ops.get_collection_ref(_AXIS_ORDER_KEY) 8009d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if axis_order_list: 8019d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axis_order_list[0] = axis_order 8029d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer else: 8039d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axis_order_list.append(axis_order) 8049d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 8059d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 8069d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@contextlib.contextmanager 8079d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.accepts(tc.Optional(tc.List(string_types))) 8089d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef axis_order_scope(axis_order=None): 8099d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """Set axis order for the result of broadcasting operations within a scope. 
8109d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 8119d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer This allows you to ensure that tensors resulting from arithmetic have a 8129d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer predictable axis order. 8139d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 8149d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Example usage: 8159d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 8169d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer with lt.axis_order_scope(['x', 'y', 'z']): 817ee112cff56081fb9d0b74c987a8935acc360b05cBenoit Steiner # result is guaranteed to have the correct axis order 8189d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer result = w + b 8199d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 8209d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer You can nest scopes, in which case only the inner-most scope applies, e.g., 8219d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 8229d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer with lt.axis_order(['x', 'y', 'z']): 8239d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer with lt.axis_order(): 8249d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer result = w + b # uses the default (left-most) axis ordering 8259d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 8269d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Args: 8279d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axis_order: optional list of strings providing axis names. By default, 8289d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer creates a scope without axis order. 8299d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 8309d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer Yields: 8319d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer The provided axis_order or `None`. 
8329d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """ 8339d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer original_axis_order = get_axis_order() 8349d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer _set_axis_order(axis_order) 8359d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer try: 8369d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer yield axis_order 8379d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer finally: 8389d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer _set_axis_order(original_axis_order) 8399d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 8409d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 8419d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer@tc.returns(tc.List(string_types)) 8429d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerdef _get_valid_axis_order(): 8439d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer axis_order = get_axis_order() 8449d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer if axis_order is None: 8459d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer raise AxisOrderError('an explicit axis order must be provided with the ' 8469d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 'axis_order argument or by using an axis_order_scope') 8479d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer return axis_order 8489d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 8499d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 8509d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyerclass AxisOrderError(ValueError): 8519d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer """Error class for cases where there is no valid axis order.""" 8529d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 8539d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer 8549d20f4ea4b0b5792bf88ef886d0143b7aa780522Stephan Hoyer# TODO(shoyer): should this function accept a list of labeled tensors instead? 
@tc.returns(type(None))
@tc.accepts(LabeledTensorLike, tc.Optional(tc.Collection(string_types)))
def check_axis_order(labeled_tensor, axis_order=None):
  """Verify that the given tensor has a consistent axis order.

  Args:
    labeled_tensor: The input tensor. All axes on this tensor must appear in
      axis_order.
    axis_order: Optional desired axis order, as a list of names. If not
      provided, defaults to the current axis_order_scope (if set).

  Raises:
    AxisOrderError: If the axis_order is unavailable, inconsistent or does not
      include all existing axes.
  """
  labeled_tensor = convert_to_labeled_tensor(labeled_tensor)

  if axis_order is None:
    axis_order = _get_valid_axis_order()

  # Restrict the requested order to the axes this tensor actually has.
  relevant_axis_order = [a for a in axis_order if a in labeled_tensor.axes]

  if len(relevant_axis_order) < len(labeled_tensor.axes):
    raise AxisOrderError(
        'not all axis names appear in the required axis order %r: %r' %
        (axis_order, labeled_tensor))

  if relevant_axis_order != list(labeled_tensor.axes):
    raise AxisOrderError(
        'axes on a labeled tensor do not appear in the same order as the '
        'required axis order %r: %r' % (axis_order, labeled_tensor))


@tc.returns(LabeledTensor)
@tc.accepts(LabeledTensorLike,
            tc.Optional(tc.Collection(string_types)), tc.Optional(string_types))
def impose_axis_order(labeled_tensor, axis_order=None, name=None):
  """Impose desired axis order on a labeled tensor.

  Args:
    labeled_tensor: The input tensor.
    axis_order: Optional desired axis order, as a list of names. If not
      provided, defaults to the current axis_order_scope (if set).
    name: Optional op name.

  Returns:
    Labeled tensor with possibly transposed axes.

  Raises:
    AxisOrderError: If no axis_order is provided or axis_order does not contain
      all axes on the input tensor.
  """
  with ops.name_scope(name, 'lt_impose_axis_order', [labeled_tensor]) as scope:
    labeled_tensor = convert_to_labeled_tensor(labeled_tensor)

    if axis_order is None:
      axis_order = _get_valid_axis_order()

    # Raise the documented AxisOrderError up front instead of leaving it to
    # transpose to reject an incomplete axis order.
    requested_axes = set(axis_order)
    missing_axes = [a for a in labeled_tensor.axes if a not in requested_axes]
    if missing_axes:
      raise AxisOrderError(
          'not all axis names appear in the required axis order %r: %r' %
          (axis_order, labeled_tensor))

    relevant_axis_order = [a for a in axis_order if a in labeled_tensor.axes]

    return transpose(labeled_tensor, relevant_axis_order, name=scope)


@tc.returns(tc.Optional(list))
@tc.accepts(list, list)
def _find_consistent_ordering(a, b):
  """Find the left-most consistent ordering between two lists of unique items.

  A consistent ordering combines all elements in both a and b while keeping all
  elements in their original order in both inputs. The left-most consistent
  ordering orders elements from `a` not found in `b` before elements in `b` not
  found in `a`.

  For example, given ['x', 'z'] and ['y', 'z'], both ['x', 'y', 'z'] and ['y',
  'x', 'z'] are consistent orderings because each of the inputs appears in
  each consistent ordering in the same order, and ['x', 'y', 'z'] is the
  left-most, because 'x' appears only in `a` and 'y' appears only in `b`. In
  contrast, there is no consistent ordering between ['x', 'y'] and ['y', 'x'].

  Args:
    a: list with unique elements.
    b: list with unique elements.

  Returns:
    List containing all elements in either a or b, or None, if no consistent
    ordering exists.
  """
  a_set = set(a)
  b_set = set(b)
  i = 0
  j = 0
  ordering = []
  # Two-pointer merge: emit elements unique to `a` first, then those unique to
  # `b`; shared elements must line up, otherwise no consistent order exists.
  while i < len(a) and j < len(b):
    if a[i] not in b_set:
      ordering.append(a[i])
      i += 1
    elif b[j] not in a_set:
      ordering.append(b[j])
      j += 1
    elif a[i] == b[j]:
      ordering.append(a[i])
      i += 1
      j += 1
    else:
      return None

  ordering.extend(a[i:])
  ordering.extend(b[j:])

  return ordering


@tc.returns(LabeledTensor, LabeledTensor, Axes)
@tc.accepts(LabeledTensorLike, LabeledTensorLike, tc.Optional(string_types))
def align(labeled_tensor_0, labeled_tensor_1, name=None):
  """Align the axes of two tensors so they may be broadcast to each other.

  Axes are ordered by the current axis order scope, if present, or by the left-
  most consistent ordering. An exception is raised if it is impossible to align
  the tensors without a transpose (align never copies the input data).

  Example usage:

    >>> a = lt.LabeledTensor(tf.ones((2, 4)), ['x', 'z'])
    >>> b = lt.LabeledTensor(tf.ones((3, 4)), ['y', 'z'])
    >>> a2, b2, axes = lt.align(a, b)
    >>> a2
    <LabeledTensor 'lt_align_1/lt_align_1/0:...' shape=(2, 1, 4) dtype=float32
     axes=[('x', Dimension(2)),
           ('y', Dimension(1)),
           ('z', Dimension(4))]>
    >>> b2
    <LabeledTensor 'lt_align_1/lt_align_1/1:...' shape=(1, 3, 4) dtype=float32
     axes=[('x', Dimension(1)),
           ('y', Dimension(3)),
           ('z', Dimension(4))]>
    >>> axes
    Axes([('x', Dimension(2)),
          ('y', Dimension(3)),
          ('z', Dimension(4))])

  Args:
    labeled_tensor_0: An input tensor.
    labeled_tensor_1: An input tensor.
    name: Optional op name.

  Returns:
    The aligned tensors and the axes the resulting tensor would have if the two
    aligned tensors were broadcast to each other. The aligned tensors have the
    same rank but not necessarily the same shape, with axes in the same order.

  Raises:
    ValueError: If axes with the same name on the inputs are not equal.
    AxisOrderError: If there is no way to reshape the input tensors into the
      output without a transpose.
  """
  with ops.name_scope(name, 'lt_align',
                      [labeled_tensor_0, labeled_tensor_1]) as scope:

    labeled_tensor_0 = convert_to_labeled_tensor(labeled_tensor_0)
    labeled_tensor_1 = convert_to_labeled_tensor(labeled_tensor_1)

    axes_0 = labeled_tensor_0.axes
    axes_1 = labeled_tensor_1.axes
    # An axis name shared by both inputs must refer to equal axes.
    for axis_name in axes_0:
      if axis_name in axes_1:
        if axes_0[axis_name] != axes_1[axis_name]:
          raise ValueError('Mismatched %r axis on input tensors: %r and %r' %
                           (axis_name, axes_0[axis_name], axes_1[axis_name]))

    axis_scope_order = get_axis_order()
    if axis_scope_order is not None:
      # we are in an axis_order_scope
      axis_names_set = set(axes_0) | set(axes_1)
      new_axis_names = [a for a in axis_scope_order if a in axis_names_set]

      check_axis_order(labeled_tensor_0, axis_scope_order)
      check_axis_order(labeled_tensor_1, axis_scope_order)

    else:
      # attempt to find a consistent ordering
      new_axis_names = _find_consistent_ordering(list(axes_0), list(axes_1))
      if new_axis_names is None:
        # Report the axis names as plain lists rather than keys() views.
        raise AxisOrderError(
            'No consistent axis order allows for aligning tensors with axis '
            'orders %r and %r without copying data. Use transpose or '
            'impose_axis_order to reorder axes on one or more of the inputs.' %
            (list(axes_0), list(axes_1)))

    labeled_tensor_0 = expand_dims(
        labeled_tensor_0, new_axis_names, name=scope + '0')
    labeled_tensor_1 = expand_dims(
        labeled_tensor_1, new_axis_names, name=scope + '1')

    # The broadcast result takes each axis from whichever input has it.
    broadcast_axes = []
    for axis_name in new_axis_names:
      if axis_name in axes_0:
        broadcast_axes.append(axes_0[axis_name])
      else:
        broadcast_axes.append(axes_1[axis_name])

    return labeled_tensor_0, labeled_tensor_1, Axes(broadcast_axes)


# collections.Callable was removed in Python 3.10; collections.abc is missing
# only on Python 2, where the old location is still valid.
try:
  from collections.abc import Callable as _Callable  # pylint: disable=g-import-not-at-top
except ImportError:
  from collections import Callable as _Callable  # pylint: disable=g-import-not-at-top


@tc.returns(types.FunctionType)
@tc.accepts(string_types, _Callable)
def define_unary_op(op_name, elementwise_function):
  """Define a unary operation for labeled tensors.

  Args:
    op_name: string name of the TensorFlow op.
    elementwise_function: function to call to evaluate the op on a single
      tf.Tensor object. This function must accept two arguments: a tf.Tensor
      object, and an optional `name`.

  Returns:
    Function defining the given op that acts on LabeledTensors.
  """

  default_name = 'lt_%s' % op_name

  @tc.returns(LabeledTensor)
  @tc.accepts(LabeledTensorLike, tc.Optional(string_types))
  def op(labeled_tensor, name=None):
    """LabeledTensor version of `tf.{op_name}`.

    See `tf.{op_name}` for full details.

    Args:
      labeled_tensor: Input tensor.
      name: Optional op name.

    Returns:
      A LabeledTensor with result of applying `tf.{op_name}` elementwise.
    """
    with ops.name_scope(name, default_name, [labeled_tensor]) as scope:
      labeled_tensor = convert_to_labeled_tensor(labeled_tensor)
      result_tensor = elementwise_function(labeled_tensor.tensor, name=scope)
      return LabeledTensor(result_tensor, labeled_tensor.axes)

  # Docstrings are stripped under `python -OO`, leaving __doc__ as None.
  if op.__doc__ is not None:
    op.__doc__ = op.__doc__.format(op_name=op_name)
  op.__name__ = op_name

  return op


abs_function = define_unary_op('abs', math_ops.abs)
neg = define_unary_op('neg', math_ops.negative)
sign = define_unary_op('sign', math_ops.sign)
reciprocal = define_unary_op('reciprocal', math_ops.reciprocal)
square = define_unary_op('square', math_ops.square)
round_function = define_unary_op('round', math_ops.round)
sqrt = define_unary_op('sqrt', math_ops.sqrt)
rsqrt = define_unary_op('rsqrt', math_ops.rsqrt)
exp = define_unary_op('exp', math_ops.exp)
log = define_unary_op('log', math_ops.log)
ceil = define_unary_op('ceil', math_ops.ceil)
floor = define_unary_op('floor', math_ops.floor)
cos = define_unary_op('cos', math_ops.cos)
sin = define_unary_op('sin', math_ops.sin)
tan = define_unary_op('tan', math_ops.tan)
acos = define_unary_op('acos', math_ops.acos)
asin = define_unary_op('asin', math_ops.asin)
atan = define_unary_op('atan', math_ops.atan)
lgamma = define_unary_op('lgamma', math_ops.lgamma)
digamma = define_unary_op('digamma', math_ops.digamma)
erf = define_unary_op('erf', math_ops.erf)
erfc = define_unary_op('erfc', math_ops.erfc)
logical_not = define_unary_op('logical_not', math_ops.logical_not)
tanh = define_unary_op('tanh', math_ops.tanh)
sigmoid = define_unary_op('sigmoid', math_ops.sigmoid)


@tc.returns(types.FunctionType)
@tc.accepts(string_types, _Callable)
def define_binary_op(op_name, elementwise_function):
  """Define a binary operation that broadcasts labeled tensors.

  Args:
    op_name: string name of the TensorFlow op.
    elementwise_function: function to call to evaluate the op on tf.Tensor
      objects. This function must accept three arguments: two tf.Tensor objects,
      and an optional `name`.

  Returns:
    Function defining the given op that acts on LabeledTensors.
  """

  default_name = 'lt_%s' % op_name

  @tc.returns(LabeledTensor)
  @tc.accepts(LabeledTensorLike, LabeledTensorLike, tc.Optional(string_types))
  def op(labeled_tensor_0, labeled_tensor_1, name=None):
    """LabeledTensor version of `tf.{op_name}` with label based alignment.

    See `tf.{op_name}` for full details.

    Args:
      labeled_tensor_0: Input tensor.
      labeled_tensor_1: Input tensor.
      name: Optional op name.

    Returns:
      A LabeledTensor with result of applying `tf.{op_name}` elementwise.
    """
    with ops.name_scope(name, default_name,
                        [labeled_tensor_0, labeled_tensor_1]) as scope:

      # Align axes by name so the underlying op can broadcast positionally.
      align_0, align_1, broadcast_axes = align(labeled_tensor_0,
                                               labeled_tensor_1)

      tensor = elementwise_function(align_0.tensor, align_1.tensor, name=scope)

      return LabeledTensor(tensor, broadcast_axes)

  # Docstrings are stripped under `python -OO`, leaving __doc__ as None.
  if op.__doc__ is not None:
    op.__doc__ = op.__doc__.format(op_name=op_name)
  op.__name__ = op_name

  return op


# Label-aligned, broadcasting binary ops built from the math_ops kernels.
# The first argument to define_binary_op is the op name used in the generated
# docstring and in the default `lt_<name>` name scope.

# Arithmetic.
add = define_binary_op('add', math_ops.add)
sub = define_binary_op('sub', math_ops.subtract)
mul = define_binary_op('mul', math_ops.multiply)
div = define_binary_op('div', math_ops.div)
mod = define_binary_op('mod', math_ops.mod)
# Presumably not named `pow` to avoid shadowing the builtin; the op itself is
# still reported as 'pow'.
pow_function = define_binary_op('pow', math_ops.pow)
squared_difference = define_binary_op(
    'squared_difference', math_ops.squared_difference)

# Comparisons.
equal = define_binary_op('equal', math_ops.equal)
not_equal = define_binary_op('not_equal', math_ops.not_equal)
less = define_binary_op('less', math_ops.less)
less_equal = define_binary_op('less_equal', math_ops.less_equal)
greater = define_binary_op('greater', math_ops.greater)
greater_equal = define_binary_op('greater_equal', math_ops.greater_equal)

# Boolean logic.
logical_and = define_binary_op('logical_and', math_ops.logical_and)
logical_or = define_binary_op('logical_or', math_ops.logical_or)
logical_xor = define_binary_op('logical_xor', math_ops.logical_xor)

# Elementwise extrema.
minimum = define_binary_op('minimum', math_ops.minimum)
maximum = define_binary_op('maximum', math_ops.maximum)

# Special functions.
igamma = define_binary_op('igamma', math_ops.igamma)
igammac = define_binary_op('igammac', math_ops.igammac)
zeta = define_binary_op('zeta', math_ops.zeta)
polygamma = define_binary_op('polygamma', math_ops.polygamma)