10cf9ed3a719c0782695154d5a0bca260001cec15A. Unique TensorFlower# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower#
3d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License");
4d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# you may not use this file except in compliance with the License.
5d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# You may obtain a copy of the License at
6d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower#
7d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower#     http://www.apache.org/licenses/LICENSE-2.0
8d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower#
9d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software
10d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS,
11d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# See the License for the specific language governing permissions and
13d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# limitations under the License.
14d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# =============================================================================
15d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower"""Tests for functions."""
16d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
17d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlowerfrom __future__ import absolute_import
18d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlowerfrom __future__ import division
19d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlowerfrom __future__ import print_function
20d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
21c0f6357c4a080edb10dd089151dd523834ea80fcA. Unique TensorFlowerimport re
2220765b3e1ae3b718699592c98aa9805cb874b6d1Patrick Nguyenimport sys
231804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murrayimport time
247760ce56fc3ab4ab8cdc408e29d8ad8b539c417eJosh Levenberg
25d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlowerimport numpy as np
26d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
2700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milnefrom tensorflow.core.framework import function_pb2
2858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.core.protobuf import config_pb2
295e3ed99469d32e494869b8d044620b2ef8e96a40A. Unique TensorFlowerfrom tensorflow.core.protobuf import rewriter_config_pb2
3058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.client import session
3158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.framework import constant_op
3258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.framework import dtypes
3358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.framework import errors_impl
34d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlowerfrom tensorflow.python.framework import function
35b876065afe85ef7b8e7334e506af2f84b3a6add1Alexandre Passosfrom tensorflow.python.framework import graph_to_function_def
3658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.framework import ops
3758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.framework import tensor_shape
389624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichevfrom tensorflow.python.framework import test_util
3958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import array_ops
4058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import clip_ops
4158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import control_flow_ops
420a43371c76c2cbf47a93186acececc576b71c06bA. Unique TensorFlowerfrom tensorflow.python.ops import functional_ops
43f61c9d02f769bdf701d94cc09ecc94f21027f2aeA. Unique TensorFlowerfrom tensorflow.python.ops import gen_logging_ops
4458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import gradients_impl
4558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import init_ops
46c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passosfrom tensorflow.python.ops import linalg_ops
4758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import logging_ops
4858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import math_ops
4958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import nn_ops
5058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import random_ops
5158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import variable_scope
5258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import variables
5358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.platform import test
5458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.platform import tf_logging
55d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
56d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
def _OptimizerOptions():
  """Yields one ConfigProto per on/off combination of three optimizer flags.

  Every yielded config runs at optimizer level L0 and independently toggles
  common subexpression elimination, function inlining and constant folding,
  for 2 * 2 * 2 = 8 configurations in total.
  """
  flag_values = [False, True]
  for use_cse in flag_values:
    for use_inlining in flag_values:
      for use_cfold in flag_values:
        optimizer = config_pb2.OptimizerOptions(
            opt_level=config_pb2.OptimizerOptions.L0,
            do_common_subexpression_elimination=use_cse,
            do_function_inlining=use_inlining,
            do_constant_folding=use_cfold)
        graph_options = config_pb2.GraphOptions(optimizer_options=optimizer)
        yield config_pb2.ConfigProto(graph_options=graph_options)
67f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower
68f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower
692cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@test_util.with_c_api
702cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichevclass FunctionTest(test.TestCase):
719624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev  """Test methods for verifying Function support.
729624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev
739624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev  These test methods are used as mix-ins in two test cases: with
749624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev  and without C API support.
759624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev  """
769624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev
779624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev  def testIdentity(self):
789624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev
799624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev    @function.Defun(dtypes.float32, func_name="MyIdentity")
809624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev    def MyIdentityFunc(a):
819624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev      return a
829624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev
839624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev    with ops.Graph().as_default():
849624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev      call = MyIdentityFunc([18.0])
859624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev      self.assertEqual("MyIdentity", call.op.name)
869624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev      with session.Session() as sess:
879624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev        self.assertAllEqual([18.0], sess.run(call))
889624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev
891804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray  def testIdentityImplicitDeref(self):
901804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray
911804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray    @function.Defun(dtypes.float32, func_name="MyIdentity")
921804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray    def MyIdentityFunc(a):
931804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray      return a
941804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray
951804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray    with ops.Graph().as_default():
961804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray      var = variables.Variable([18.0])
971804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray      call = MyIdentityFunc(var._ref())  # pylint: disable=protected-access
981804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray      self.assertEqual("MyIdentity", call.op.name)
991804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray      for cfg in _OptimizerOptions():
1001804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray        with session.Session(config=cfg) as sess:
1011804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray          sess.run(var.initializer)
1021804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray          self.assertAllEqual([18.0], sess.run(call))
1031804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray
1049624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev  def testIdentityOutputName(self):
1059624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev
1069624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev    @function.Defun(
1079624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev        dtypes.float32, func_name="MyIdentity", out_names=["my_result_name"])
1089624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev    def MyIdentityFunc(a):
1099624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev      return a
1109624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev
1119624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev    with ops.Graph().as_default():
1129624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev      call = MyIdentityFunc([18.0])
1139624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev      self.assertEqual("MyIdentity", call.op.name)
1149624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev      with session.Session() as sess:
1159624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev        self.assertAllEqual([18.0], sess.run(call))
1169624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev
1179624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev  def testTooManyOutputNames(self):
1189624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev
1199624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev    @function.Defun(
1209624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev        dtypes.float32, func_name="MyIdentity",
1219624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev        out_names=["my_result1", "my_result2"])
1229624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev    def MyIdentityFunc(a):
1239624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev      return a
1249624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev
1259624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev    with ops.Graph().as_default():
1269624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev      with self.assertRaisesRegexp(
1272198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev          errors_impl.InvalidArgumentError,
1282198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev          (r"output names must be either empty or equal in size to outputs. "
1292198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev           "output names size = 2 outputs size = 1")):
1309624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev        MyIdentityFunc([18.0])
131d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
132d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower  def testDefineFunction2Args(self):
133d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
13458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.float32, dtypes.float32, func_name="APlus2B")
135d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower    def APlus2B(a, b):
136d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      return a + b * 2
137d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
13858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    with ops.Graph().as_default():
139fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower      call = APlus2B([1.0], [2.0])
1407a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      self.assertEqual("APlus2B", call.op.name)
14158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with session.Session() as sess:
142d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower        self.assertAllEqual([5.0], sess.run(call))
143d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
1442198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev  def testFunctionWithNoOutput(self):
1459624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev
1469624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev    @function.Defun(dtypes.float32, dtypes.float32)
1479624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev    def APlus2B(a, b):
1482198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev      c = a + b * 2  # Create some ops to have nodes in the body
1492198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev      print(c)  # Using 'print' to make lint happy
1509624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev
1519624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev    with ops.Graph().as_default():
1522198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev      # Call function. There should be no exceptions.
1532198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev      APlus2B([1.0], [2.0])
1549624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev
1559624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev  def testDefineFunction2ArgsOutputName(self):
1569624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev
1579624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev    @function.Defun(
1589624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev        dtypes.float32,
1599624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev        dtypes.float32,
1609624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev        func_name="APlus2B",
1619624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev        out_names=["my_result_name"])
1629624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev    def APlus2B(a, b):
1639624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev      return a + b * 2
1649624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev
1659624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev    with ops.Graph().as_default():
1669624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev      call = APlus2B([1.0], [2.0])
1679624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev      self.assertEqual("APlus2B", call.op.name)
1689624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev      with session.Session() as sess:
1699624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev        self.assertAllEqual([5.0], sess.run(call))
1709624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev
1712d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower  def testDefineFunctionDuplicateOutputs(self):
1722d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower
17358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.float32, func_name="Duplicate")
1742d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower    def Duplicate(a):
1752d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower      b = a + 1.0
1762d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower      return b, b
1772d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower
17858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
1792d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower    with g.as_default():
1802d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower      Duplicate([3.0])
1812d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower      func_sig = g.as_graph_def().library.function[0].signature
1822d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower      # The names given to both outputs should be different
1832d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower      # even though the same tensor is emitted to both.
1842d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower      out_names = [a.name for a in func_sig.output_arg]
1852d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower      self.assertEqual(2, len(out_names))
1862d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower      self.assertNotEqual(out_names[0], out_names[1])
1872d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower
188d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower  def testGradientFunc(self):
189d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
19058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.float32, func_name="XSquarePlusOneFn")
191d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower    def XSquarePlusOne(x):
192d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      return x * x + 1.0
193d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
19458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.float32, dtypes.float32)
195d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower    def XSquarePlusOneGrad(x, dy):
1966bf83b7061df648d4751bc443782348fc8ea5c17Eugene Brevdo      dx = functional_ops._symbolic_gradient(
19758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney          input=[x, dy], Tout=[dtypes.float32], f="XSquarePlusOneFn", name="dx")
198d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      return dx
199d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
20058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
201d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower    with g.as_default():
202fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower      call_f = XSquarePlusOne([2.0])
203fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower      call_g = XSquarePlusOneGrad([2.0], [0.1])
204d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
20558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with session.Session() as sess:
206d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower        self.assertAllClose([5.0], sess.run(call_f))
207d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower        self.assertAllClose([0.4], sess.run(call_g))
208d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
20984f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower  def testTanhSymGrad(self):
210ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower
21158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.float32)
2121d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    def Forward(x):
21358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      return math_ops.reduce_sum(math_ops.tanh(x))
2141d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower
21558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
21684f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower    with g.as_default():
21758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      x = array_ops.placeholder(dtypes.float32)
21884f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower      y = Forward(x)
21958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      dx = gradients_impl.gradients([y], [x])
22084f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower
22184f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower    inp = np.array([-1, 1, 2, -2], dtype=np.float32)
22284f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower    feed = {x: inp}
22358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
22458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        optimizer_options=config_pb2.OptimizerOptions(
22558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney            opt_level=config_pb2.OptimizerOptions.L1,
22658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney            do_function_inlining=True)))
22758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    with session.Session(graph=g, config=cfg) as sess:
22884f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower      out, = sess.run(dx, feed)
22984f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower    self.assertAllClose(1 - np.square(np.tanh(inp)), out)
23084f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower
231f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower  def testCustomGradient(self):
23258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    dtype = dtypes.float32
233f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower
2341d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    @function.Defun(dtype, dtype, dtype)
2351d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    def XentLossGrad(logits, labels, dloss):
23658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      dlogits = array_ops.reshape(dloss, [-1, 1]) * (
23758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney          nn_ops.softmax(logits) - labels)
23858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      dlabels = array_ops.zeros_like(labels)
2391d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      # Takes exp(dlogits) to differentiate it from the "correct" gradient.
24058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      return math_ops.exp(dlogits), dlabels
241f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower
2421d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    @function.Defun(dtype, dtype, grad_func=XentLossGrad)
2431d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    def XentLoss(logits, labels):
24458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      return math_ops.reduce_sum(labels * math_ops.log(nn_ops.softmax(logits)),
24558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney                                 1)
246f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower
24758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
2481d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    with g.as_default():
24958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      logits = array_ops.placeholder(dtype)
25058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      labels = array_ops.placeholder(dtype)
251f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower      loss = XentLoss(logits, labels)
25258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      dlogits = gradients_impl.gradients([loss], [logits])
253f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower
254f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower    x = np.random.uniform(-10., 10., size=(4, 9)).astype(np.float32)
255f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower    prob = np.exp(x) / np.sum(np.exp(x), 1, keepdims=1)
256f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower    y = np.random.uniform(-10., 10., size=(4, 9)).astype(np.float32)
257f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower    for cfg in _OptimizerOptions():
25858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      tf_logging.info("cfg = %s", cfg)
25958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with session.Session(graph=g, config=cfg) as sess:
260f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower        out, = sess.run(dlogits, {logits: x, labels: y})
261f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower      self.assertAllClose(out, np.exp(prob - y))
262f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower
263f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower  def testCustomGradientError(self):
26458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    dtype = dtypes.float32
265f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower
2661d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    @function.Defun(dtype, dtype, dtype)
2671d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    def Grad(x, dy, dz):
2681d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      # Should have returned 1 result.
2691d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      return x, dy + dz
270f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower
2711d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    @function.Defun(dtype, grad_func=Grad)
2721d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    def Forward(x):
2731d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      return x, x
274f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower
27558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
2761d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    with g.as_default():
27758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      inp = array_ops.placeholder(dtype)
27858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      out = math_ops.add_n(Forward(inp))
27958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      dinp = gradients_impl.gradients(out, [inp])
280f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower
281f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower    x = np.random.uniform(-10., 10., size=(4, 9)).astype(np.float32)
28258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    with session.Session(graph=g) as sess:
283f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower      with self.assertRaisesRegexp(
28458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney          errors_impl.InvalidArgumentError,
285f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower          "SymGrad expects to return 1.*but get 2.*instead"):
286f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower        _ = sess.run(dinp, {inp: x})
287f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower
288d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower  def testSymGradShape(self):
28958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
290d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower    with g.as_default():
29158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      x = array_ops.placeholder(dtypes.float32, [25, 4])
29258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      y = array_ops.placeholder(dtypes.float32, [200, 100])
29358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      dz = array_ops.placeholder(dtypes.float32, [1])
294d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      # We assume Foo is a function of (x, y) -> (z) Then, Foo's
295d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      # gradient function is (x, y, dz) -> (dx, dy).  dx's shape
296d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      # should be the same as x's; and dy's shape should be the same
297d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      # as y's.
298ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower      dx, dy = functional_ops._symbolic_gradient(
29958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney          input=[x, y, dz], Tout=[dtypes.float32] * 2, f="Foo")
3007a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      self.assertEqual(x.get_shape(), dx.get_shape())
3017a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      self.assertEqual(y.get_shape(), dy.get_shape())
302d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
303f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower  def testSymGradAttr(self):
304aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov
305f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower    @function.Defun(noinline=True)
306f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower    def Foo(x):
307f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower      return x * 2
308f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower
309aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov    self.assertTrue(
31058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        Foo.instantiate([dtypes.float32]).definition.attr["_noinline"].b)
311aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov
31258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
313f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower    with g.as_default():
31458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      x = constant_op.constant(3.0)
315f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower      y = Foo(x)
31658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      dx, = gradients_impl.gradients(y, [x])
317f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower
31858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
31958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        optimizer_options=config_pb2.OptimizerOptions(
32058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney            opt_level=config_pb2.OptimizerOptions.L0,
321f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower            do_common_subexpression_elimination=True,
322f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower            do_function_inlining=True,
323f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower            do_constant_folding=True)))
324f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower
325f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower    with self.test_session(graph=g, config=cfg):
326f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower      self.assertAllClose(y.eval(), 6.)
327f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower      self.assertAllClose(dx.eval(), 2.)
328f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower
3292eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower  def _testZNoDepOnY(self, use_const_grad_ys):
33058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.float32, dtypes.float32)
331bb5764ca8da329b2c5c81f6c0219c260efe288e4A. Unique TensorFlower    def Foo(x, y):  # pylint: disable=unused-argument
3321d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      return x * 2
3331d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower
33458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    with ops.Graph().as_default():
335449ecb561f6b480a6043d23160be00f35b524aa9A. Unique TensorFlower      # z = Foo(x, y). z doe
33658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      x = constant_op.constant(1.0)
33758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      y = constant_op.constant(2.0)
338449ecb561f6b480a6043d23160be00f35b524aa9A. Unique TensorFlower      z = Foo(x, y)
3392eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower      if use_const_grad_ys:
3402eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower        dx, dy = gradients_impl.gradients([z], [x, y], grad_ys=[1.0])
3412eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower      else:
3422eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower        dx, dy = gradients_impl.gradients([z], [x, y])
34358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with session.Session() as sess:
344449ecb561f6b480a6043d23160be00f35b524aa9A. Unique TensorFlower        dx_val, dy_val = sess.run([dx, dy])
3457a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower        self.assertEqual([2.0], dx_val)
3467a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower        self.assertEqual([0.0], dy_val)
347449ecb561f6b480a6043d23160be00f35b524aa9A. Unique TensorFlower
3482eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower  def testZNoDepOnY(self):
3492eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower    self._testZNoDepOnY(False)
3502eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower
3512eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower  def testZNoDepOnYConstGradYs(self):
3522eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower    # Tests for constant folding of grad_ys
3532eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower    self._testZNoDepOnY(True)
3542eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower
355d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower  def testDefineFunctionNoArgs(self):
356d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
3577a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower    @function.Defun(func_name="AConstant")
358d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower    def AConstant():
35958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      return constant_op.constant([42])
360d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
36158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    with ops.Graph().as_default():
362e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower
363e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower      call = AConstant()
3647a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      self.assertEqual("AConstant", call.op.name)
36558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with session.Session() as sess:
366d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower        self.assertAllEqual([42], sess.run(call))
367d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
368d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower  def testDefineFunctionNames(self):
369d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
37058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.float32, func_name="Foo")
371d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower    def Foo(a):
372d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      return a + 1
373d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
37458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    with ops.Graph().as_default():
375fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower      call1 = Foo([1.0])
3767a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      self.assertEqual("Foo", call1.op.name)
377fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower      call2 = Foo([1.0])
3787a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      self.assertEqual("Foo_1", call2.op.name)
379bb5764ca8da329b2c5c81f6c0219c260efe288e4A. Unique TensorFlower      # pylint: disable=unexpected-keyword-arg
380fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower      call3 = Foo([1.0], name="mine")
3817a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      self.assertEqual("mine", call3.op.name)
38258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with ops.name_scope("my"):
383fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower        call4 = Foo([1.0], name="precious")
3847a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower        self.assertEqual("my/precious", call4.op.name)
385d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
386f61c9d02f769bdf701d94cc09ecc94f21027f2aeA. Unique TensorFlower  def testNoOp(self):
387fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower
38858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.float32)
389f61c9d02f769bdf701d94cc09ecc94f21027f2aeA. Unique TensorFlower    def Foo(x):
390f28ae398cc5b875b936ad6e5cd4d280928c38409Olivia Nordquist      y = logging_ops.Print(x, [], "Hello")
39158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with ops.control_dependencies([y]):
39258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        z = control_flow_ops.no_op()
39358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with ops.control_dependencies([z]):
394f61c9d02f769bdf701d94cc09ecc94f21027f2aeA. Unique TensorFlower        return x * 2
395f61c9d02f769bdf701d94cc09ecc94f21027f2aeA. Unique TensorFlower
39658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    with ops.Graph().as_default(), self.test_session():
39758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      z = Foo(constant_op.constant(3.0))
398f61c9d02f769bdf701d94cc09ecc94f21027f2aeA. Unique TensorFlower      self.assertAllEqual(z.eval(), 6.0)
399f61c9d02f769bdf701d94cc09ecc94f21027f2aeA. Unique TensorFlower
40023c01e814b1eb227a165b548f9038592316678f8A. Unique TensorFlower  def testAssertOp(self):
401fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower
40258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.float32)
403f61c9d02f769bdf701d94cc09ecc94f21027f2aeA. Unique TensorFlower    def Foo(x):
40458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      check = gen_logging_ops._assert(math_ops.greater(x, 0), [x])
40558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with ops.control_dependencies([check]):
406f61c9d02f769bdf701d94cc09ecc94f21027f2aeA. Unique TensorFlower        return x * 2
407f61c9d02f769bdf701d94cc09ecc94f21027f2aeA. Unique TensorFlower
40858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
409f61c9d02f769bdf701d94cc09ecc94f21027f2aeA. Unique TensorFlower    with g.as_default(), self.test_session():
41058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      self.assertAllEqual(Foo(constant_op.constant(3.0)).eval(), 6.0)
41158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
412f61c9d02f769bdf701d94cc09ecc94f21027f2aeA. Unique TensorFlower                                   "assertion failed.*-3"):
41358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        self.assertAllEqual(Foo(constant_op.constant(-3.0)).eval(), 6.0)
414f61c9d02f769bdf701d94cc09ecc94f21027f2aeA. Unique TensorFlower
4159624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev  @test_util.disable_c_api   # Op._add_control_inputs doesn't work with C API
41623c01e814b1eb227a165b548f9038592316678f8A. Unique TensorFlower  def testAssertWrapper(self):
4177a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
41858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.float32)
41923c01e814b1eb227a165b548f9038592316678f8A. Unique TensorFlower    def MyFn(x):
42058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with ops.control_dependencies(
42158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney          [control_flow_ops.Assert(math_ops.less_equal(x, 10.0), [x])]):
42258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        return array_ops.identity(x)
42323c01e814b1eb227a165b548f9038592316678f8A. Unique TensorFlower
42423c01e814b1eb227a165b548f9038592316678f8A. Unique TensorFlower    with self.test_session():
4257a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      self.assertEqual(1.0, MyFn(1.0).eval())
42658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
42758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney                                   "assertion"):
42823c01e814b1eb227a165b548f9038592316678f8A. Unique TensorFlower        _ = MyFn(100.0).eval()
42923c01e814b1eb227a165b548f9038592316678f8A. Unique TensorFlower
4309624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev  @test_util.disable_c_api   # Op._add_control_inputs doesn't work with C API
431689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower  def testWhileLoopCallsFunc(self):
432689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower    with self.test_session(use_gpu=True) as sess:
433689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
434689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower      @function.Defun(dtypes.float32)
435689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower      def Times2(x):
436689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower        constant_two = constant_op.constant(2, dtypes.int32)
437689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower        two_on_gpu = math_ops.cast(constant_two, dtypes.float32)
438689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower        return x * two_on_gpu
439689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
440689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower      def Body(x):
441689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower        x2 = Times2(x)
442689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower        x2.set_shape([])
443689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower        return x2
444689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
445689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower      loop = control_flow_ops.while_loop(lambda x: x < 1e5, Body, [1.0])
446689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
447689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower      ans = sess.run(loop)
448689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower      self.assertAllClose(ans, 131072.)
449689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
4509624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev  @test_util.disable_c_api   # Op._add_control_inputs doesn't work with C API
451e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins  def testControlFlowStrictness(self):
452e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins    """Inlined functions must not execute in a untaken control flow branch."""
453e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins
454e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins    @function.Defun(dtypes.int32)
455e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins    def AssertFail(x):
456e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins      # Assertion that always fails and does not have a data dependency on `x`.
457e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins      assert_false = control_flow_ops.Assert(False, [42])
458e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins      with ops.control_dependencies([assert_false]):
459e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins        return array_ops.identity(x)
460e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins
461e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins    with ops.device("CPU"):
462e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins      pred = array_ops.placeholder(dtypes.bool)
463e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins      x = array_ops.placeholder(dtypes.int32)
464e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins      cond = control_flow_ops.cond(pred, lambda: x + 1, lambda: AssertFail(x))
465e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins      # pylint: disable=unnecessary-lambda
466e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins      loop = control_flow_ops.while_loop(lambda y: pred,
467e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins                                         lambda y: AssertFail(y), [x])
468e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins      # pylint: enable=unnecessary-lambda
469e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins
4705e3ed99469d32e494869b8d044620b2ef8e96a40A. Unique TensorFlower    rewriter_config = rewriter_config_pb2.RewriterConfig(
4715e3ed99469d32e494869b8d044620b2ef8e96a40A. Unique TensorFlower        dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
472e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins    # Enables inlining.
4735e3ed99469d32e494869b8d044620b2ef8e96a40A. Unique TensorFlower    config = config_pb2.ConfigProto(
4745e3ed99469d32e494869b8d044620b2ef8e96a40A. Unique TensorFlower        graph_options=config_pb2.GraphOptions(
4755e3ed99469d32e494869b8d044620b2ef8e96a40A. Unique TensorFlower            optimizer_options=config_pb2.OptimizerOptions(
4765e3ed99469d32e494869b8d044620b2ef8e96a40A. Unique TensorFlower                opt_level=config_pb2.OptimizerOptions.L0,
4775e3ed99469d32e494869b8d044620b2ef8e96a40A. Unique TensorFlower                do_common_subexpression_elimination=True,
4785e3ed99469d32e494869b8d044620b2ef8e96a40A. Unique TensorFlower                do_function_inlining=True,
4795e3ed99469d32e494869b8d044620b2ef8e96a40A. Unique TensorFlower                do_constant_folding=True),
4805e3ed99469d32e494869b8d044620b2ef8e96a40A. Unique TensorFlower            rewrite_options=rewriter_config))
481e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins
482e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins    with session.Session(config=config) as sess:
483e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins      # Since the 'False' branch is not taken, the assertion should not fire.
484e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins      self.assertEqual(4, sess.run(cond, {pred: True, x: 3}))
485e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins
486e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins      # The assertion should still fire if the False branch is taken.
487e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
488e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins                                   "assertion"):
489e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins        sess.run(cond, {pred: False, x: 3})
490e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins
491e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins      # Similarly for loops.
492e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins      self.assertEqual(3, sess.run(loop, {pred: False, x: 3}))
493e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins      with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
494e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins                                   "assertion"):
495e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins        sess.run(loop, {pred: True, x: 3})
496e321ae5e32859b459879706fea931f4b352be7f6Peter Hawkins
497fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower  def testVar(self):
498fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower
49958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.float32)
500fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower    def Foo(x):
501fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower      return x * x + 1
502fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower
50358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
504fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower    with g.as_default():
50558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      v = variables.Variable(constant_op.constant(10.0))
506fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower      z = Foo(v)
507fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower
508fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower    with self.test_session(graph=g):
50958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      variables.global_variables_initializer().run()
510fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower      self.assertAllEqual(z.eval(), 101.)
511fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower
5124d35ecb853f5e9c7d41df3629a31dcbc7b4032caA. Unique TensorFlower  def testResourceVarAsImplicitInput(self):
5134d35ecb853f5e9c7d41df3629a31dcbc7b4032caA. Unique TensorFlower    g = ops.Graph()
5140a667240282f6215b36811475245509254d0127eAlexandre Passos    with g.as_default(), ops.device("cpu:0"):
5154d35ecb853f5e9c7d41df3629a31dcbc7b4032caA. Unique TensorFlower      v = variable_scope.get_variable(
5164d35ecb853f5e9c7d41df3629a31dcbc7b4032caA. Unique TensorFlower          "var", (4, 4), dtypes.float32, use_resource=True)
5174d35ecb853f5e9c7d41df3629a31dcbc7b4032caA. Unique TensorFlower
5184d35ecb853f5e9c7d41df3629a31dcbc7b4032caA. Unique TensorFlower      @function.Defun()
5194d35ecb853f5e9c7d41df3629a31dcbc7b4032caA. Unique TensorFlower      def Foo():
5204d35ecb853f5e9c7d41df3629a31dcbc7b4032caA. Unique TensorFlower        return array_ops.identity(v)
5214d35ecb853f5e9c7d41df3629a31dcbc7b4032caA. Unique TensorFlower
5224d35ecb853f5e9c7d41df3629a31dcbc7b4032caA. Unique TensorFlower      y = v.value()
5234d35ecb853f5e9c7d41df3629a31dcbc7b4032caA. Unique TensorFlower      z = Foo()
5244d35ecb853f5e9c7d41df3629a31dcbc7b4032caA. Unique TensorFlower
5254d35ecb853f5e9c7d41df3629a31dcbc7b4032caA. Unique TensorFlower    with self.test_session(graph=g):
5264d35ecb853f5e9c7d41df3629a31dcbc7b4032caA. Unique TensorFlower      v.initializer.run()
5274d35ecb853f5e9c7d41df3629a31dcbc7b4032caA. Unique TensorFlower      self.assertAllEqual(y.eval(), z.eval())
5284d35ecb853f5e9c7d41df3629a31dcbc7b4032caA. Unique TensorFlower
529d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower  def testDefineErrors(self):
53058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    with ops.Graph().as_default():
5317a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      with self.assertRaisesRegexp(ValueError, "can not return None"):
532d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
533e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower        @function.Defun()
5347a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower        def TwoNone():
5357a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower          return None, None
5367a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
5377a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower        _ = TwoNone.definition
5387a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
539e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower      with self.assertRaisesRegexp(ValueError, "are not supported"):
540d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
541e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower        @function.Defun()
542e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower        def DefaultArg(unused_a=12):
54358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney          return constant_op.constant([1])
544d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
545e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower        _ = DefaultArg.definition
54619b7fd80780d02372f76076bc8eb40d55a89a301A. Unique TensorFlower      with self.assertRaisesRegexp(ValueError, "are not supported"):
547e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower
548e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower        @function.Defun()
549e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower        def KwArgs(**unused_kwargs):
55058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney          return constant_op.constant([1])
551e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower
552e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower        _ = KwArgs.definition
553d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      with self.assertRaisesRegexp(ValueError, "specified input types"):
554e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower
55558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        @function.Defun(dtypes.float32)
556e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower        def PlusMinusV2(a, b):
557e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower          return a + b, b - a
558e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower
559e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower        _ = PlusMinusV2.definition
560a3c34f649d6f3d6c188cc59fe884facf4aad117eA. Unique TensorFlower      with self.assertRaisesRegexp(ValueError, "specified input types"):
561e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower
56258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        @function.Defun(dtypes.float32, dtypes.float32, dtypes.float32)
563e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower        def PlusMinusV3(a, b):
564e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower          return a + b, b - a
565e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower
566e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower        _ = PlusMinusV3.definition
567d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
568d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower  def testCallErrors(self):
569d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
570e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower    @function.Defun()
571d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower    def Const():
57258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      return constant_op.constant(1)
573d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
57458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.int32)
575d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower    def PlusOne(a):
576d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      return a + 1
577d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
57858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.int32, dtypes.int32)
579d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower    def PlusMinus(a, b):
580d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      return a + b, b - a
581d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
58258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    with ops.Graph().as_default():
583d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
584e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower      _ = Const()
585bb5764ca8da329b2c5c81f6c0219c260efe288e4A. Unique TensorFlower      # pylint: disable=too-many-function-args
586bb5764ca8da329b2c5c81f6c0219c260efe288e4A. Unique TensorFlower      # pylint: disable=unexpected-keyword-arg
587bb5764ca8da329b2c5c81f6c0219c260efe288e4A. Unique TensorFlower      # pylint: disable=no-value-for-parameter
588d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      with self.assertRaisesRegexp(ValueError, "arguments: 0"):
589fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower        _ = Const(1)
590d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      with self.assertRaisesRegexp(ValueError, "arguments: 0"):
591fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower        _ = Const(1, 2)
592d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
593d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      with self.assertRaisesRegexp(ValueError, "arguments: 1"):
594e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower        _ = PlusOne()
595fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower      _ = PlusOne(1)
596d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      with self.assertRaisesRegexp(ValueError, "arguments: 1"):
597fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower        _ = PlusOne(1, 2)
598d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
599d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      with self.assertRaisesRegexp(ValueError, "arguments: 2"):
600e5b7e1e846f3e35d90a6bb260284b041d0036059A. Unique TensorFlower        _ = PlusMinus()
601d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      with self.assertRaisesRegexp(ValueError, "arguments: 2"):
602fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower        _ = PlusMinus(1)
603fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower      _ = PlusMinus(1, 2)
604d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
605fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower      _ = PlusOne(1, name="p1")
606d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      with self.assertRaisesRegexp(ValueError, "Unknown keyword arguments"):
60728ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower        _ = PlusOne(1, device="/device:GPU:0")
608d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
6091d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower  def testFunctionDecorator(self):
610ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower
61158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.float32, func_name="Minus1")
6121d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    def Minus1(b):
6131d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      return b - 1.0
614a84a81a7379507f8fcdd0d6118afc2d5044d159eA. Unique TensorFlower
61558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    with ops.Graph().as_default():
616fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower      call1 = Minus1([2.])
6176bf83b7061df648d4751bc443782348fc8ea5c17Eugene Brevdo      self.assertTrue(isinstance(Minus1, function._DefinedFunction))
6186bf83b7061df648d4751bc443782348fc8ea5c17Eugene Brevdo      self.assertEqual(Minus1.name, "Minus1")
619d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      # pylint: disable=unexpected-keyword-arg
620d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower      call2 = Minus1(call1, name="next")
6211d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      # pylint: enable=unexpected-keyword-arg
6227a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      self.assertEqual("next", call2.op.name)
62358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with session.Session() as sess:
624d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower        self.assertAllEqual([1], sess.run(call1))
625d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower        self.assertAllEqual([0], sess.run(call2))
626d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
627a84a81a7379507f8fcdd0d6118afc2d5044d159eA. Unique TensorFlower  def testNestedFunction(self):
628ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower
62958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.float32)
6301d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    def Cube(x):
6311d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      return x * x * x
6321d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower
63358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.float32, dtypes.float32)
6341d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    def CubeXPlusY(x, y):
6351d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      return Cube(x) + y
6361d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower
63758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    with ops.Graph().as_default():
638fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower      z = CubeXPlusY(3.0, -2.0)
6391d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      with self.test_session():
6401d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower        self.assertAllEqual(z.eval(), 25.0)
6413f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower
6421d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower  def testNestedDefinedFunction(self):
643ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower
64458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.float32, dtypes.float32)
6451d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    def CubeXPlusY(x, y):
646ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower
64758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      @function.Defun(dtypes.float32)
648a84a81a7379507f8fcdd0d6118afc2d5044d159eA. Unique TensorFlower      def Cube(x):
649a84a81a7379507f8fcdd0d6118afc2d5044d159eA. Unique TensorFlower        return x * x * x
650ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower
6511d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      return Cube(x) + y
652ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower
65358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    with ops.Graph().as_default():
654fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower      z = CubeXPlusY(3.0, -2.0)
655a84a81a7379507f8fcdd0d6118afc2d5044d159eA. Unique TensorFlower      with self.test_session():
656a84a81a7379507f8fcdd0d6118afc2d5044d159eA. Unique TensorFlower        self.assertAllEqual(z.eval(), 25.0)
657a84a81a7379507f8fcdd0d6118afc2d5044d159eA. Unique TensorFlower
6581d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower  def testUnusedFunction(self):
6591d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    invoked = False
6601d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    # pylint: disable=unused-variable
6611d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    @function.Defun()
6621d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    def Unused():
6631d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      invoked = True
66458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      return constant_op.constant(42.)
665ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower
6661d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    self.assertFalse(invoked)
66758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
6681d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    with g.as_default():
669ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower
6701d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      @function.Defun()
6711d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      def Unused2():
6721d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower        invoked = True
67358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        return constant_op.constant(7.)
674ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower
67558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      constant_op.constant(3.)
6761d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    # pylint: enable=unused-variable
6771d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    self.assertFalse(invoked)
6781d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    gdef = g.as_graph_def()
6797a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower    self.assertEqual(0, len(gdef.library.function))
6801d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower
6812c46a9ed5f6faaf9f3a407582dd657efa91c24eaA. Unique TensorFlower  def testReduction(self):
68258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
6832c46a9ed5f6faaf9f3a407582dd657efa91c24eaA. Unique TensorFlower
6842c46a9ed5f6faaf9f3a407582dd657efa91c24eaA. Unique TensorFlower    # BN0 is computing batch normed matrix along rows.
6852c46a9ed5f6faaf9f3a407582dd657efa91c24eaA. Unique TensorFlower    def BN0(x):
68658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      mean = math_ops.reduce_mean(x, [0])
68758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      var = math_ops.reduce_mean(math_ops.square(x - mean))  # biased var
68858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      rstd = math_ops.rsqrt(var + 1e-8)
6892c46a9ed5f6faaf9f3a407582dd657efa91c24eaA. Unique TensorFlower      return (x - mean) * rstd
6902c46a9ed5f6faaf9f3a407582dd657efa91c24eaA. Unique TensorFlower
6911d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    # Wraps BatchNorm in a tf function.
69258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    @function.Defun(dtypes.float32)
6931d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    def BN1(x):
6941d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      return BN0(x)
6951d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower
6961d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower    with g.as_default():
69758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      x = array_ops.placeholder(dtypes.float32)
6982c46a9ed5f6faaf9f3a407582dd657efa91c24eaA. Unique TensorFlower      y0 = BN0(x)  # A plain graph
6992c46a9ed5f6faaf9f3a407582dd657efa91c24eaA. Unique TensorFlower      y1 = BN1(x)  # A tf function
70058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      dx0, = gradients_impl.gradients([y0], [x])
70158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      dx1, = gradients_impl.gradients([y1], [x])
7021d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower
7032c46a9ed5f6faaf9f3a407582dd657efa91c24eaA. Unique TensorFlower    # Both should produce the same result and gradient.
7042c46a9ed5f6faaf9f3a407582dd657efa91c24eaA. Unique TensorFlower    with self.test_session(graph=g) as sess:
7052c46a9ed5f6faaf9f3a407582dd657efa91c24eaA. Unique TensorFlower      vals = sess.run([y0, y1, dx0, dx1], {x: np.random.uniform(size=(3, 7))})
7062c46a9ed5f6faaf9f3a407582dd657efa91c24eaA. Unique TensorFlower      self.assertAllClose(vals[0], vals[1])
7072c46a9ed5f6faaf9f3a407582dd657efa91c24eaA. Unique TensorFlower      self.assertAllClose(vals[2], vals[3])
7082c46a9ed5f6faaf9f3a407582dd657efa91c24eaA. Unique TensorFlower
709b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower  def testCapture(self):
71058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
711b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower    with g.as_default():
71258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      w = variables.Variable(constant_op.constant([[1.0]]))
71358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      b = variables.Variable(constant_op.constant([2.0]))
714b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower
715b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower      # Foo() captures w and b.
71658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      @function.Defun(dtypes.float32)
717b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower      def Foo(x):
718b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower
719b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower        # Plus() captures b.
72058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        @function.Defun(dtypes.float32)
721b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower        def Plus(y):
722b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower          return y + b
723b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower
72458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        return Plus(math_ops.matmul(w, x))
725b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower
72658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      y = Foo(constant_op.constant([[10.]]))
727b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower
728b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower    with self.test_session(graph=g):
72958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      variables.global_variables_initializer().run()
730b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower      self.assertAllEqual(y.eval(), [[12.0]])
731b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower
732b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower  def testCaptureControls(self):
73358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
734b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower    with g.as_default():
73558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      x = constant_op.constant([10.0])
73658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      x = logging_ops.Print(x, [x], "outer")
737b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower
73858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      @function.Defun(dtypes.float32)
739b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower      def Foo(y):
74058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        with ops.control_dependencies([x]):
74158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney          y = logging_ops.Print(y, [y], "inner")
742b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower        return y
743b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower
744b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower      with self.assertRaisesRegexp(ValueError, "not an element of this graph."):
745b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower        # NOTE: We still do not support capturing control deps.
746b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower        _ = Foo(x)
747b95814208554ff6dac745c3f58a93929214a5363A. Unique TensorFlower
7481c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne  def testCaptureInWhileLoop(self):
7491c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne    g = ops.Graph()
7501c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne    with g.as_default():
7511c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne      x = constant_op.constant(1)
7521c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne
7531c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne      @function.Defun()
7541c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne      def Foo():
7551c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne        return control_flow_ops.while_loop(lambda i: i < 10,
7561c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne                                           lambda i: i + x,
7571c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne                                           [0])
7581c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne      y = Foo()
7591c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne
7601c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne    with self.test_session(graph=g) as sess:
7611c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne      self.assertEqual(sess.run(y), 10)
7621c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne
7631c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne  def testCaptureInCond(self):
7641c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne    g = ops.Graph()
7651c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne    with g.as_default():
7661c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne      x = constant_op.constant(1)
7671c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne
7681c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne      @function.Defun(dtypes.bool)
7691c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne      def Foo(pred):
7701c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne        return control_flow_ops.cond(pred,
7711c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne                                     lambda: x,
7721c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne                                     lambda: x + 1)
7731c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne      y = Foo(True)
7741c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne      z = Foo(False)
7751c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne
7761c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne    with self.test_session(graph=g) as sess:
7771c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne      self.assertEqual(sess.run(y), 1)
7781c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne      self.assertEqual(sess.run(z), 2)
7791c4810141e71289d71bfd94a74434bd09ee6b20fSkye Wanderman-Milne
780b4cc63c5e59a922ad8040aa5da06d128733d5516A. Unique TensorFlower  def testStableName(self):
781b4cc63c5e59a922ad8040aa5da06d128733d5516A. Unique TensorFlower
782b4cc63c5e59a922ad8040aa5da06d128733d5516A. Unique TensorFlower    @function.Defun()
783b4cc63c5e59a922ad8040aa5da06d128733d5516A. Unique TensorFlower    def Foo(x, y, z):
78458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      return math_ops.tanh(math_ops.matmul(x, y) + z)
785b4cc63c5e59a922ad8040aa5da06d128733d5516A. Unique TensorFlower
7862198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev    # We added more randomness to function names in C API.
7872198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev    # TODO(iga): Remove this if statement when we switch to C API.
7882198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev    if ops._USE_C_API:  # pylint: disable=protected-access
7891804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray      if sys.byteorder == "big":
79020765b3e1ae3b718699592c98aa9805cb874b6d1Patrick Nguyen        self.assertEqual("Foo_kEdkAG8SJvg",
79120765b3e1ae3b718699592c98aa9805cb874b6d1Patrick Nguyen                         Foo.instantiate([dtypes.float32] * 3).name)
79220765b3e1ae3b718699592c98aa9805cb874b6d1Patrick Nguyen      else:
79320765b3e1ae3b718699592c98aa9805cb874b6d1Patrick Nguyen        self.assertEqual("Foo_aCYSbwBkR5A",
79420765b3e1ae3b718699592c98aa9805cb874b6d1Patrick Nguyen                         Foo.instantiate([dtypes.float32] * 3).name)
7952198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev    else:
7962198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev      self.assertEqual("Foo_d643acf7",
7972198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev                       Foo.instantiate([dtypes.float32] * 3).name)
7982eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower
7992eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower  def testSignatureHash(self):
8002eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower    # Foo.Inner and Bar.Inner have identical function body but have
8012eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower    # different signatures. They should be treated as two different functions.
8022eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower
8032eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower    @function.Defun()
8042eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower    def Foo(x):
8052eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower
8062eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower      @function.Defun()
8072eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower      def Inner(x):
8082eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower        return x + 10.
8092eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower
8102eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower      return Inner(x)
8112eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower
8122eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower    @function.Defun()
8132eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower    def Bar(x):
8142eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower
8152eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower      @function.Defun()
8162eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower      def Inner(x, unused_y, unused_z):
8172eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower        return x + 10.
8182eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower
8192eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower      return Inner(x, 2., 3.)
8202eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower
82158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
8222eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower    with g.as_default():
82358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      x = constant_op.constant(10.0)
8242eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower      y = Foo(x)
8252eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower      z = Bar(x)
8262eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower
8272eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower    with self.test_session(graph=g) as sess:
8282eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower      v0, v1 = sess.run([y, z])
8292eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower      self.assertAllEqual(v0, 20.)
8302eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower      self.assertAllEqual(v1, 20.)
831b4cc63c5e59a922ad8040aa5da06d128733d5516A. Unique TensorFlower
8328343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower  def testShapeFunction(self):
833689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
834689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower    @function.Defun(
835689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower        dtypes.float32, shape_func=lambda op: [op.inputs[0].get_shape()])
8368343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower    def Foo(x):
8378343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower      return x + 1.0
8388343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower
8398343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower    @function.Defun(
8408343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower        shape_func=lambda op: [[1] + op.inputs[0].get_shape().as_list()])
8418343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower    def Bar(x):
8428343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower      return array_ops.stack([x])
8438343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower
8448343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower    g = ops.Graph()
8458343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower    with g.as_default():
8468343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower      x = Foo([1.0, 2.0])
8478343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower      self.assertEqual(x.get_shape().as_list(), [2])
8488343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower      y = Bar(array_ops.zeros([1, 2, 3]))
8498343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower      self.assertAllEqual(y.get_shape().as_list(), [1, 1, 2, 3])
8508343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower
851e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower  def testVariableReuse(self):
852689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
853e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    def LinearWithReuse(input_tensor, reuse=None):
854e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower      size = input_tensor.shape.dims[1]
855e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower      with variable_scope.variable_scope("linear", reuse=reuse):
856689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower        w = variable_scope.get_variable(
857689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower            "w", shape=[size, size], dtype=input_tensor.dtype)
858e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower      return math_ops.matmul(input_tensor, w)
859e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower
860e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    @function.Defun(dtypes.float32)
861e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    def Foo(inputs):
862e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower      inputs = array_ops.reshape(inputs, [32, 100])
863e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower      hidden = LinearWithReuse(inputs)
864e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower      return LinearWithReuse(hidden, reuse=True)
865e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower
866e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    input_op = array_ops.placeholder(shape=[32, 100], dtype=dtypes.float32)
867e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    output_op = Foo(input_op)
868e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower
869e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    global_vars = variables.global_variables()
870e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    self.assertEqual(len(global_vars), 1)
871e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    self.assertEqual(global_vars[0].name, "linear/w:0")
872e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower
873e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    with session.Session() as sess:
874e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower      sess.run(variables.global_variables_initializer())
875689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower      output_val = sess.run(
876689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower          output_op, feed_dict={input_op: np.random.rand(32, 100)})
877e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower      self.assertEqual(output_val.shape, (32, 100))
878e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower
879e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower  def testFunctionCallInDifferentVariableScopes(self):
880689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
881e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    @function.Defun(dtypes.float32)
882e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    def Foo(inputs):
883689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower      var = variable_scope.get_variable(
884689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower          "var",
885689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower          shape=[10],
886689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower          dtype=dtypes.float32,
887689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower          initializer=init_ops.ones_initializer())
888e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower      return inputs + var
889e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower
890e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    input_op = array_ops.placeholder(shape=[10], dtype=dtypes.float32)
891e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    with variable_scope.variable_scope("vs1"):
892e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower      out1_op = Foo(input_op)
893e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower
894e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    with variable_scope.variable_scope("vs2"):
895e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower      out2_op = Foo(input_op)
896e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower
897e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    global_vars = variables.global_variables()
898e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    self.assertEqual(len(global_vars), 1)
899e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    self.assertEqual(global_vars[0].name, "vs1/var:0")
900e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower
901e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower    with session.Session() as sess:
902e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower      sess.run(variables.global_variables_initializer())
903689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower      out1, out2 = sess.run(
904689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower          [out1_op, out2_op], feed_dict={input_op: np.linspace(1, 10, 10)})
905e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower      self.assertAllEqual(out1, np.linspace(2, 11, 10))
906e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower      self.assertAllEqual(out2, np.linspace(2, 11, 10))
907e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower
908c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos  def testTwoInputsSameOp(self):
909c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos    g = ops.Graph()
910c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos    with g.as_default():
911c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos      m = array_ops.placeholder(dtypes.float32)
912c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos      s, u, v = linalg_ops.svd(m)
913c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos      ss = math_ops.reduce_sum(s)
914c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos      uu = math_ops.reduce_sum(u)
915c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos      vv = math_ops.reduce_sum(v)
916c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos      result = ss + uu + vv
917b876065afe85ef7b8e7334e506af2f84b3a6add1Alexandre Passos    f = graph_to_function_def.graph_to_function_def(
918c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos        g,
919c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos        g.get_operations()[1:],  # skip the placeholder
920c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos        [s, u, v],
921c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos        [result])
922c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos    self.assertEqual(len(f.signature.input_arg), 3)
923c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos
9244723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan  def testGradientWithIntegerFunctionArgument(self):
9254723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan    @function.Defun(dtypes.int32, dtypes.float32)
9264723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan    def Foo(t, x):
9274723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan      return x[t]
9284723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan
9294723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan    g = ops.Graph()
9304723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan    with g.as_default():
9314723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan      inp = array_ops.placeholder(dtypes.float32)
9324723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan      t = constant_op.constant(0, dtypes.int32)
9334723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan      out = Foo(t, inp)
9344723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan      dinp, = gradients_impl.gradients(out, [inp])
9354723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan
9364723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan    x = np.zeros((2,)).astype(np.float32)
9374723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan    with session.Session(graph=g) as sess:
9384723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan      self.assertAllClose(
9394723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan          np.array([1.0, 0.0]).astype(np.float32),
9404723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan          sess.run(dinp, {inp: x}))
9414723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan
9421c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray  def testFunctionMarkedStateful(self):
9431c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray
9441c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray    @function.Defun(dtypes.int32, dtypes.float32)
9451c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray    def Foo(t, x):
9461c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray      return x[t]
9471c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray
9481c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray    @function.Defun(dtypes.int64)
9491c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray    def Bar(x):
9501c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray      return x
9511c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray
9521c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray    # NOTE(mrry): All functions are currently considered stateless by the
9531c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray    # runtime, so we simulate a "stateful" function.
9541c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray    # TODO(b/70565970): Remove this hack when we are able to build stateful
9551c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray    # functions using the API.
9561c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray    # pylint: disable=protected-access
9571c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray    Foo._signature.is_stateful = True
9581c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray    Bar._signature.is_stateful = True
9591c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray    # pylint: enable=protected-access
9601c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray
9611c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray    result_1 = Foo(3, [1.0, 2.0, 3.0, 4.0])
9621c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray    result_2 = Bar(constant_op.constant(100, dtype=dtypes.int64))
9631c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray
9641c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray    with session.Session() as sess:
9651c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray      self.assertEqual(4.0, sess.run(result_1))
9661c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray      self.assertEqual(100, sess.run(result_2))
9671c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray      self.assertEqual((4.0, 100), sess.run((result_1, result_2)))
9681c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray
969185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray  def testStatefulFunction(self):
970185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray
971185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    @function.Defun()
972185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    def FunctionWithStatelessOp():
973185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray      return constant_op.constant(42.0)
974185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray
975185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    @function.Defun()
976185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    def FunctionWithStatefulOp():
977185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray      return random_ops.random_uniform([100], maxval=10, dtype=dtypes.int32)
978185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray
979185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    @function.Defun()
980185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    def FunctionWithStatelessFunctionCall():
981185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray      return FunctionWithStatelessOp()
982185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray
983185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    @function.Defun()
984185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    def FunctionWithStatefulFunctionCall():
985185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray      return FunctionWithStatefulOp()
986185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray
987185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    # Test that the `is_stateful` bit is propagated.
988185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    self.assertFalse(FunctionWithStatelessOp.definition.signature.is_stateful)
989185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    self.assertTrue(FunctionWithStatefulOp.definition.signature.is_stateful)
990185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    self.assertFalse(
991185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray        FunctionWithStatelessFunctionCall.definition.signature.is_stateful)
992185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    self.assertTrue(
993185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray        FunctionWithStatefulFunctionCall.definition.signature.is_stateful)
994185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray
995185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    # Ensure that two invocations of the same random-number-generating
996185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    # function produce different results.
997185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    result1 = FunctionWithStatefulFunctionCall()
998185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    result2 = FunctionWithStatefulFunctionCall()
999185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray
1000185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    # Statefulness affects how the function is treated by the various
1001185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    # optimization passes, so run the test in each optimizer
1002185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    # configuration.
1003185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray    for config in _OptimizerOptions():
1004185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray      with session.Session(config=config) as sess:
1005185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray        val1, val2 = sess.run((result1, result2))
1006185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray        self.assertFalse(all(val1 == val2))
1007185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray        val3, val4 = sess.run((result1, result2))
1008185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray        self.assertFalse(all(val3 == val1))
1009185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray        self.assertFalse(all(val4 == val2))
1010185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray
101132d138db751c541e951d1958cac4918214e9644eDerek Murray  def testSameFunctionOnTwoDevices(self):
101232d138db751c541e951d1958cac4918214e9644eDerek Murray
101332d138db751c541e951d1958cac4918214e9644eDerek Murray    @function.Defun(dtypes.float32)
101432d138db751c541e951d1958cac4918214e9644eDerek Murray    def AddOne(x):
101532d138db751c541e951d1958cac4918214e9644eDerek Murray      return x + 1.0
101632d138db751c541e951d1958cac4918214e9644eDerek Murray
101732d138db751c541e951d1958cac4918214e9644eDerek Murray    with ops.device("/cpu:0"):
101832d138db751c541e951d1958cac4918214e9644eDerek Murray      f_0 = AddOne(41.0)
101932d138db751c541e951d1958cac4918214e9644eDerek Murray
102032d138db751c541e951d1958cac4918214e9644eDerek Murray    with ops.device("/cpu:1"):
102132d138db751c541e951d1958cac4918214e9644eDerek Murray      f_1 = AddOne(43.0)
102232d138db751c541e951d1958cac4918214e9644eDerek Murray
102332d138db751c541e951d1958cac4918214e9644eDerek Murray    for config in _OptimizerOptions():
102432d138db751c541e951d1958cac4918214e9644eDerek Murray      config.device_count["CPU"] = 2
102532d138db751c541e951d1958cac4918214e9644eDerek Murray      with session.Session(config=config) as sess:
102632d138db751c541e951d1958cac4918214e9644eDerek Murray        self.assertEqual(42.0, sess.run(f_0))
102732d138db751c541e951d1958cac4918214e9644eDerek Murray        self.assertEqual(44.0, sess.run(f_1))
102832d138db751c541e951d1958cac4918214e9644eDerek Murray        self.assertEqual((42.0, 44.0), sess.run((f_0, f_1)))
102932d138db751c541e951d1958cac4918214e9644eDerek Murray
10303f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower
10312cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@test_util.with_c_api
103200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milneclass FunctionsFromProtos(test.TestCase):
103300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
103400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne  def expectFunctionsEqual(self, func, grad_func=None, new_func=None):
103500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    if new_func is None:
103600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      # Make a copy of func.definition to avoid any bugs masked by using the
103700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      # same object
103800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      serialized_fdef = func.definition.SerializeToString()
103900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      # Serialize and then deserialize `func` to create `new_func`
104000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      fdef = function_pb2.FunctionDef.FromString(serialized_fdef)
104100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      new_func = function._from_definition(fdef, grad_func=grad_func)
104200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    self.assertEqual(func.name, new_func.name)
104300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    self.assertEqual(func.definition, new_func.definition)
104400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    self.assertEqual(func.grad_func_name, new_func.grad_func_name)
104500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    self.assertEqual(func.declared_input_types, new_func.declared_input_types)
104600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    self.assertEqual(func.captured_inputs, new_func.captured_inputs)
104700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
104800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne  def testBasic(self):
1049689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
105000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    @function.Defun(dtypes.float32, dtypes.float32)
105100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    def Foo(x, y):
105200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      return x + y
1053689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
105400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    self.expectFunctionsEqual(Foo)
105500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
105600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne  def testGradFunc(self):
1057689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
105800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    @function.Defun(dtypes.float32, dtypes.float32)
105900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    def G(x, dy):
106000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      return x * dy
106100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
106200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    @function.Defun(dtypes.float32, grad_func=G)
106300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    def F(x):
106400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      return math_ops.exp(x) - math_ops.exp(-x)
1065689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
106600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    self.expectFunctionsEqual(F, grad_func=G)
106700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
106800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne  def testCapturedInputs(self):
106900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    c = constant_op.constant(10, dtypes.int64)
1070689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
107100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    @function.Defun(dtypes.int64)
107200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    def Foo(x):
107300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      return x + c
107400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
107500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    new_func = function._from_definition(Foo.definition)
107600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
107700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    self.assertEqual(Foo.name, new_func.name)
107800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    self.assertEqual(Foo.definition, new_func.definition)
107900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    self.assertEqual(Foo.grad_func_name, new_func.grad_func_name)
108000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
108100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    # Captured inputs are added as regular inputs to the function definition
108200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    self.assertEqual(new_func.declared_input_types,
108300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne                     Foo.declared_input_types + (dtypes.int64,))
108400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    self.assertEqual(len(new_func.captured_inputs), 0)
108500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
108600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne  def testNestedFunctions(self):
1087689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
108800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    @function.Defun(dtypes.float32)
108900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    def Outer(x):
109000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
109100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      @function.Defun(dtypes.float32)
109200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      def Inner(y):
109300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne        return y + 1
109400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
109500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      return Inner(Inner(x))
109600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
109700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    self.expectFunctionsEqual(Outer)
109800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
109900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne  def testFromLibrary(self):
110000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    # Define some functions with different gradient functions. Note that many of
110100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    # the below functions are identical since function bodies don't matter for
110200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    # this test.
110300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
110400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    @function.Defun(dtypes.float32, dtypes.float32)
110500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    def G1(x, dy):
110600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      return x * dy
110700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
110800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    @function.Defun(dtypes.float32, dtypes.float32)
110900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    def G2(x, dy):
111000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      return x * dy
111100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
111200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    # F1 and F2 have the same gradient function
111300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    @function.Defun(dtypes.float32, grad_func=G1)
111400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    def F1(x):
111500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      return math_ops.exp(x) - math_ops.exp(-x)
111600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
111700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    @function.Defun(dtypes.float32, grad_func=G1)
111800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    def F2(x):
111900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      return math_ops.exp(x) - math_ops.exp(-x)
112000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
112100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    # F3 has a different gradient function
112200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    @function.Defun(dtypes.float32, grad_func=G2)
112300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    def F3(x):
112400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      return math_ops.exp(x) - math_ops.exp(-x)
112500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
112600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    # F4 has no gradient function
112700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    @function.Defun(dtypes.float32)
112800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    def F4(x):
112900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      return math_ops.exp(x) - math_ops.exp(-x)
113000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
113100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    # Instantiate all functions
113200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    g = ops.Graph()
113300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    with g.as_default():
113400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      c = constant_op.constant(1.0, dtypes.float32)
113500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      f1 = F1(c)
113600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      f2 = F2(c)
113700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      f3 = F3(c)
113800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      f4 = F4(c)
113900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      gradients_impl.gradients([f1, f2, f3, f4], c)
114000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
114100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    library = g.as_graph_def().library
114200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    new_funcs = function._from_library(library)
114300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
114400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    def CheckNewFunc(func):
114500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      new_func = [f for f in new_funcs if f.name == func.name]
114600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      self.assertEqual(len(new_func), 1)
114700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      self.expectFunctionsEqual(func, new_func=new_func[0])
114800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
114900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    CheckNewFunc(G1)
115000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    CheckNewFunc(G2)
115100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    CheckNewFunc(F1)
115200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    CheckNewFunc(F2)
115300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    CheckNewFunc(F3)
115400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    CheckNewFunc(F4)
115500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
115600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne  def testFromLibraryEmptyLib(self):
115700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    library = function_pb2.FunctionDefLibrary()
115800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    self.assertEqual(len(function._from_library(library)), 0)
115900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
116000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne  def testFromLibraryMissingFuncDef(self):
1161689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
116200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    @function.Defun(dtypes.float32, dtypes.float32)
116300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    def G1(x, dy):
116400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      return x * dy
116500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
116600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    @function.Defun(dtypes.float32)
116700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    def F1(x):
116800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      return math_ops.exp(x) - math_ops.exp(-x)
116900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
117000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    gradient = function_pb2.GradientDef()
117100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    gradient.function_name = F1.name
117200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    gradient.gradient_func = G1.name
117300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
117400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    # Create invalid function def that is missing G1 function def
117500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    library = function_pb2.FunctionDefLibrary()
117600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    library.gradient.extend([gradient])
117700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    library.function.extend([F1.definition])
117800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
117900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    with self.assertRaisesRegexp(
11802198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev        ValueError,
11812198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev        "FunctionDefLibrary missing 'G1_[0-9a-zA-Z]{8,11}' FunctionDef"):
118200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      function._from_library(library)
118300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
118400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    # Create invalid function def that is missing F1 function def
118500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    library = function_pb2.FunctionDefLibrary()
118600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    library.gradient.extend([gradient])
118700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    library.function.extend([G1.definition])
118800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
118900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    with self.assertRaisesRegexp(
11902198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev        ValueError,
11912198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev        "FunctionDefLibrary missing 'F1_[0-9a-zA-Z]{8,11}' FunctionDef"):
119200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      function._from_library(library)
119300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
119400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne  def testFromLibraryCyclicGradFuncs(self):
1195689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
119600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    @function.Defun(dtypes.float32)
119700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    def F1(x):
119800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      return math_ops.exp(x) - math_ops.exp(-x)
119900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
120000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    @function.Defun(dtypes.float32)
120100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    def F2(x):
120200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      return math_ops.exp(x) - math_ops.exp(-x)
120300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
120400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    # Create invalid function def library where F1 has gradient function F2 and
120500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    # F2 has gradient function F1
120600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    library = function_pb2.FunctionDefLibrary()
120700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    library.function.extend([F1.definition, F2.definition])
120800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
120900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    gradient1 = function_pb2.GradientDef()
121000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    gradient1.function_name = F1.name
121100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    gradient1.gradient_func = F2.name
121200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
121300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    gradient2 = function_pb2.GradientDef()
121400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    gradient2.function_name = F2.name
121500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    gradient2.gradient_func = F1.name
121600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
121700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    library.gradient.extend([gradient1, gradient2])
121800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
121900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne    with self.assertRaisesRegexp(
122000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne        ValueError, "FunctionDefLibrary contains cyclic gradient functions!"):
122100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne      function._from_library(library)
122200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
122300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne
12242cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@test_util.with_c_api
122558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyclass FunctionOverloadTest(test.TestCase):
12267a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
12277a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower  def testBasic(self):
12287a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
12297a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower    @function.Defun()
12307a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower    def Sinh(x):
123158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      return 1 / 2. * (math_ops.exp(x) - math_ops.exp(-x))
12327a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
123358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
12347a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower    with g.as_default():
123558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      x = Sinh(constant_op.constant(0.25, dtypes.float32))
123658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      y = Sinh(constant_op.constant(0.25, dtypes.float64))
12377a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
12387a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower    with self.test_session(graph=g):
12397a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      self.assertAllClose(x.eval(), np.sinh(0.25))
12407a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      self.assertAllClose(y.eval(), np.sinh(0.25))
12417a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
12427a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower  def testGradient(self):
12437a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
12447a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower    @function.Defun(func_name="Spec")
12457a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower    def G(x, dy):
12467a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      return x * dy
12477a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
12487a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower    @function.Defun(grad_func=G)
12497a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower    def F(x):
125058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      return math_ops.exp(x) - math_ops.exp(-x)
12517a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
125258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    for dtype in [dtypes.float32, dtypes.float64]:
125358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      g = ops.Graph()
12547a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      with g.as_default():
125558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        x = constant_op.constant(0.25, dtype)
12567a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower        y = F(x)
125758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        dx, = gradients_impl.gradients(y, x)
12587a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
12597a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower        with self.test_session(graph=g):
12607a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower          self.assertAllClose(dx.eval(), 0.25)
12617a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
12627a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower  def testDocString(self):
12637a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
12647a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower    @function.Defun()
12657a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower    def Foo(x):
12667a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      """Successor of x."""
12677a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      return x + 1
12687a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
126958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
12707a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower    with g.as_default():
12717a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower      _ = Foo(1)
12727a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
12737a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower    self.assertEqual(g.as_graph_def().library.function[0].signature.description,
12747a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower                     "Successor of x.")
12757a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
12767a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower
12772cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@test_util.with_c_api
1278b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlowerclass FunctionCaptureByValueTest(test.TestCase):
1279b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower
1280b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower  def testCaptureByValue(self):
1281b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower    g = ops.Graph()
1282b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower    with g.as_default():
1283b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower      w = constant_op.constant([[1.0]])
1284b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower      b = constant_op.constant([2.0])
1285b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower
1286b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower      # Foo() captures w and b.
1287b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower      @function.Defun(dtypes.float32, capture_by_value=True)
1288b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower      def Foo(x):
1289b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower
1290b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower        # Plus() captures b.
1291b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower        @function.Defun(dtypes.float32, capture_by_value=True)
1292b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower        def Plus(y):
1293b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower          return y + b
1294b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower
1295b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower        self.assertEqual(0, len(Plus.captured_inputs))
1296b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower
1297b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower        return Plus(math_ops.matmul(w, x))
1298b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower
1299b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower      y = Foo(constant_op.constant([[10.]]))
1300b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower
1301b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower    self.assertEqual(0, len(Foo.captured_inputs))
1302b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower
1303b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower    with self.test_session(graph=g):
1304b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower      self.assertAllEqual(y.eval(), [[12.0]])
1305b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower
1306b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower
13072cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@test_util.with_c_api
130858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyclass UnrollLSTMTest(test.TestCase):
13093f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower  BATCH_SIZE = 16
13103f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower  LSTM_DIMS = 32
1311e27da590fec8eed886ecd1cf3d0c2575dffeaa09A. Unique TensorFlower  NUM_UNROLL = 20
13123f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower
13133f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower  def _Weights(self):
13143f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower    dims = self.LSTM_DIMS
131558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    return random_ops.random_uniform([2 * dims, 4 * dims], -1, 1, seed=123456)
13163f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower
13173f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower  def _Input(self):
131858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    return random_ops.random_uniform(
1319ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower        [self.NUM_UNROLL, self.BATCH_SIZE, self.LSTM_DIMS], seed=654321)
13203f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower
1321084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  # Helper to construct a LSTM cell graph.
1322084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  @classmethod
1323084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  def LSTMCell(cls, x, mprev, cprev, weights):
13240e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower    xm = array_ops.concat([x, mprev], 1)
132558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    i_i, i_g, f_g, o_g = array_ops.split(
132658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        value=math_ops.matmul(xm, weights), num_or_size_splits=4, axis=1)
132758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    new_c = math_ops.sigmoid(f_g) * cprev + math_ops.sigmoid(
132858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        i_g) * math_ops.tanh(i_i)
132958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    new_c = clip_ops.clip_by_value(new_c, -50.0, 50.0)
133058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    new_m = math_ops.sigmoid(o_g) * math_ops.tanh(new_c)
1331084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower    return new_m, new_c
1332084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
13333f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower  def _BuildForward(self, weights, inp, mode="cell"):
13343f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower
13353f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower    def Loop(cell, w, i):
133658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      x = array_ops.unstack(i, self.NUM_UNROLL)
133758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      m = array_ops.zeros_like(x[0])
133858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      c = array_ops.zeros_like(x[0])
13393f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower      for i in range(self.NUM_UNROLL):
13403f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower        m, c = cell(x[i], m, c, w)
13413f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower      return m
13423f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower
13433f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower    cell = UnrollLSTMTest.LSTMCell
13443f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower    if mode == "complete":
13453f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower      # Constructs the complete graph in python.
13463f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower      return Loop(cell, weights, inp)
13473f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower
134858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    cell = function.Defun(dtypes.float32, dtypes.float32, dtypes.float32,
134958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney                          dtypes.float32)(cell)
13503f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower    if mode == "cell":
13513f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower      # Just represent the LSTM as a function.
13523f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower      return Loop(cell, weights, inp)
13533f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower
13543f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower    if mode == "loop":
13553f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower      # Wraps the whole loop as a function.
135658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      @function.Defun(dtypes.float32, dtypes.float32)
13573f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower      def LSTMLoop(w, i):
13583f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower        return Loop(cell, w, i)
13593f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower
13603f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower      return LSTMLoop(weights, inp)
13613f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower
13623f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower    if mode == "loop10":
13633f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower      # Wraps 10 lstm steps into one function, and the whole loop
13643f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower      # into another calling the formers.
13653f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower
136619b7fd80780d02372f76076bc8eb40d55a89a301A. Unique TensorFlower      # Groups 10 steps at a time.
136758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      @function.Defun(dtypes.float32, dtypes.float32, dtypes.float32,
136858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney                      *([dtypes.float32] * 10))
136919b7fd80780d02372f76076bc8eb40d55a89a301A. Unique TensorFlower      def Loop10(w, m, c, *args):
137019b7fd80780d02372f76076bc8eb40d55a89a301A. Unique TensorFlower        for x in args:
13713f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower          m, c = cell(x, m, c, w)
13723f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower        return m, c
13733f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower
137458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      @function.Defun(dtypes.float32, dtypes.float32)
13753f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower      def LSTMLoop10(weights, inp):
137658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        x = array_ops.unstack(inp, self.NUM_UNROLL)
137758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        m = array_ops.zeros_like(x[0])
137858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        c = array_ops.zeros_like(x[0])
13793f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower        assert self.NUM_UNROLL % 10 == 0
13803f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower        for i in range(0, self.NUM_UNROLL, 10):
13813f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower          m, c = Loop10(weights, m, c, *x[i:i + 10])
13823f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower        return m
13833f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower
13843f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower      return LSTMLoop10(weights, inp)
138566c21d60d4becef8c72162015c66492ba975495aA. Unique TensorFlower
1386084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  def testUnrollLSTM(self):
138766c21d60d4becef8c72162015c66492ba975495aA. Unique TensorFlower    # Run one step of the unrolled lstm graph.
1388e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower    def RunForward(mode, cfg=None):
138958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      tf_logging.info("mode = %s", mode)
139058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      g = ops.Graph()
139166c21d60d4becef8c72162015c66492ba975495aA. Unique TensorFlower      start = time.time()
139266c21d60d4becef8c72162015c66492ba975495aA. Unique TensorFlower      with g.as_default():
13933f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower        weights = self._Weights()
13943f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower        inp = self._Input()
13953f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower        m = self._BuildForward(weights, inp, mode)
139666c21d60d4becef8c72162015c66492ba975495aA. Unique TensorFlower      gdef = g.as_graph_def()
139766c21d60d4becef8c72162015c66492ba975495aA. Unique TensorFlower      finish = time.time()
139858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      tf_logging.info("time: %f txt size: %d gdef bin size: %d", finish - start,
1399d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower                      len(str(gdef)), len(gdef.SerializeToString()))
140058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with g.as_default(), session.Session(config=cfg) as sess:
14013f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower        return sess.run(m)
14023f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower
14033f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower    mv0 = RunForward("complete")
1404f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower    for cfg in _OptimizerOptions():
140558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      tf_logging.info("cfg = %s", cfg)
1406e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower      mv1 = RunForward("cell", cfg)
1407e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower      mv2 = RunForward("loop", cfg)
1408e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower      mv3 = RunForward("loop10", cfg)
1409e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower      self.assertAllClose(mv0, mv1, rtol=1e-4)
1410e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower      self.assertAllClose(mv0, mv2, rtol=1e-4)
1411e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower      self.assertAllClose(mv0, mv3, rtol=1e-4)
141266c21d60d4becef8c72162015c66492ba975495aA. Unique TensorFlower
1413084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower  def testUnrollLSTMGrad(self):
1414084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower    # Run one step of the unrolled lstm graph.
1415e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower    def RunForwardBackward(mode, cfg=None):
141658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      tf_logging.info("mode = %s", mode)
141758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      g = ops.Graph()
1418084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower      start = time.time()
14190c6c4b93dedf6a2c654e012a4ffe1df642834419A. Unique TensorFlower      with g.as_default():
14203f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower        weights = self._Weights()
14213f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower        inp = self._Input()
14223f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower        m = self._BuildForward(weights, inp, mode)
142358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        loss = math_ops.reduce_sum(math_ops.square(m))
142458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        dw = gradients_impl.gradients([loss], [weights])
1425084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower      gdef = g.as_graph_def()
1426084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower      finish = time.time()
142758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      tf_logging.info("time: %f txt size: %d gdef bin size: %d", finish - start,
1428d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower                      len(str(gdef)), len(gdef.SerializeToString()))
142958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with g.as_default(), session.Session(config=cfg) as sess:
14303f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower        return sess.run(dw)
14313f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower
14323f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower    d0 = RunForwardBackward("complete")
1433f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower    for cfg in _OptimizerOptions():
143458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      tf_logging.info("cfg = %s", cfg)
1435e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower      d1 = RunForwardBackward("cell", cfg)
1436e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower      d2 = RunForwardBackward("loop", cfg)
1437e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower      d3 = RunForwardBackward("loop10", cfg)
1438439d5bd72e879e40308f7197a33a1a20625a673bGunhan Gulsoy      self.assertAllClose(d0, d1, rtol=1e-4, atol=1e-4)
1439439d5bd72e879e40308f7197a33a1a20625a673bGunhan Gulsoy      self.assertAllClose(d0, d2, rtol=1e-4, atol=1e-4)
1440439d5bd72e879e40308f7197a33a1a20625a673bGunhan Gulsoy      self.assertAllClose(d0, d3, rtol=1e-4, atol=1e-4)
1441084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower
1442d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower
14432cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@test_util.with_c_api
144458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyclass FunctionInlineControlTest(test.TestCase):
14450cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower
14460cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower  def testFoo(self):
144758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    dtype = dtypes.float32
144858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
144958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        optimizer_options=config_pb2.OptimizerOptions(
145058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney            opt_level=config_pb2.OptimizerOptions.L0,
14510cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower            do_common_subexpression_elimination=True,
14520cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower            do_function_inlining=True,
14530cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower            do_constant_folding=True)))
1454c0f6357c4a080edb10dd089151dd523834ea80fcA. Unique TensorFlower    cell_func_call_pattern = re.compile(r"Cell[^/]*\(")
14550cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower    for noinline in [False, True]:
1456ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower
1457adc15a46c56070ef92b890b65b6de0147abeff80A. Unique TensorFlower      @function.Defun(dtype, noinline=noinline)
14581d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      def Cell(v):
14591d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower        # If v is a vector [n, 1], x is a big square matrix.
146058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        x = math_ops.tanh(v + array_ops.transpose(v, [1, 0]))
1461b9f548d041ba8d66102c6d195e645051f1bee52fYukun Chen        return math_ops.reduce_sum(x, 1, keepdims=True)
14620cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower
14631d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      @function.Defun(dtype)
14641d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      def Forward(x):
14651d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower        for _ in range(10):
1466bb5764ca8da329b2c5c81f6c0219c260efe288e4A. Unique TensorFlower          # pylint: disable=cell-var-from-loop
146775b07290907abf759f05d8ea087f4623de316db0A. Unique TensorFlower          x = Cell(x)
146858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        return math_ops.reduce_sum(x, [0, 1])
14690cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower
1470aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov      self.assertEqual(noinline, Cell.definition.attr["_noinline"].b)
1471aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov
147258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      g = ops.Graph()
14731d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      with g.as_default():
147458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        x = array_ops.placeholder(dtype)
14750cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower        y = Forward(x)
147658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney        dx, = gradients_impl.gradients([y], [x])
14770cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower
147822cfbd1cebda28c3b55c64519f4af23c462b0000Craig Citro      np.random.seed(321)
147922cfbd1cebda28c3b55c64519f4af23c462b0000Craig Citro      inp = np.random.uniform(-1, 1, [16, 1]).astype(np.float32)
148058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      run_metadata = config_pb2.RunMetadata()
148158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with session.Session(graph=g, config=cfg) as sess:
1482689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower        ans = sess.run(
1483689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower            [y, dx], {x: inp},
1484689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower            run_metadata=run_metadata,
1485689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower            options=config_pb2.RunOptions(
1486689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower                trace_level=config_pb2.RunOptions.FULL_TRACE))
148722cfbd1cebda28c3b55c64519f4af23c462b0000Craig Citro        print(ans[0], np.sum(ans[1]))
148822cfbd1cebda28c3b55c64519f4af23c462b0000Craig Citro        self.assertAllClose(ans[0], 255.971, rtol=1e-3)
148922cfbd1cebda28c3b55c64519f4af23c462b0000Craig Citro        self.assertAllClose(np.sum(ans[1]), 13.0408, rtol=1e-3)
14900cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower
1491aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov      def MetadataHasCell(run_metadata):
1492aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov        for dev_stats in run_metadata.step_stats.dev_stats:
1493aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov          for node_stats in dev_stats.node_stats:
1494c0f6357c4a080edb10dd089151dd523834ea80fcA. Unique TensorFlower            if cell_func_call_pattern.search(node_stats.timeline_label):
1495aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov              return True
1496aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov        return False
1497aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov
1498aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov      self.assertEqual(MetadataHasCell(run_metadata), noinline)
1499aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov
15000cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower
150158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney@function.Defun(*[dtypes.float32] * 3)
15021d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlowerdef Linear(w, b, x):
150358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney  return nn_ops.relu(math_ops.matmul(x, w) + b)
15041d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower
15051d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower
150658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney@function.Defun(*[dtypes.float32] * 5)
15071d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlowerdef Linear2(w1, b1, w2, b2, x):
15081d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower  return Linear(w2, b2, Linear(w1, b1, x))
15091d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower
15101d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower
15112cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev# Set C API before defining module level functions
15122cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichevops._USE_C_API = True
15132cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev
15142cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev
15152cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@function.Defun(*[dtypes.float32] * 3)
15162cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichevdef LinearWithCApi(w, b, x):
15172cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev  return nn_ops.relu(math_ops.matmul(x, w) + b)
15182cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev
15192cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev
15202cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@function.Defun(*[dtypes.float32] * 5)
15212cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichevdef Linear2WithCApi(w1, b1, w2, b2, x):
15222cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev  return LinearWithCApi(w2, b2, LinearWithCApi(w1, b1, x))
15232cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev
15242cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev
15252cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev# Unset C API after defining module level functions
15262cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichevops._USE_C_API = False
15272cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev
15282cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev
152958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyclass ModuleFunctionTest(test.TestCase):
15301d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower
15311d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower  def testBasic(self):
153258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    with ops.Graph().as_default():
153358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      a, b, c, d, e = [
1534689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower          constant_op.constant([[_]], dtype=dtypes.float32) for _ in range(5)
153558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      ]
15361d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      y = Linear(a, b, c)
15371d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower      z = Linear2(a, b, c, d, e)
153858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with session.Session() as sess:
15391d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower        self.assertAllEqual([[1]], sess.run(y))
15401d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower        self.assertAllEqual([[5]], sess.run(z))
15411d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower
15422cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev  @test_util.enable_c_api
15432cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev  def testBasicWithCApi(self):
15442cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev    with ops.Graph().as_default():
15452cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev      a, b, c, d, e = [
15462cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev          constant_op.constant([[_]], dtype=dtypes.float32) for _ in range(5)
15472cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev      ]
15482cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev      y = LinearWithCApi(a, b, c)
15492cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev      z = Linear2WithCApi(a, b, c, d, e)
15502cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev      with session.Session() as sess:
15512cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev        self.assertAllEqual([[1]], sess.run(y))
15522cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev        self.assertAllEqual([[5]], sess.run(z))
15532cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev
1554ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower
15552cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@test_util.with_c_api
155658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyclass VariableHoistingTest(test.TestCase):
1557d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower
15586237c439bdc3c3f882c1532ef85708956410e6f6Alexandre Passos  def _testSimpleModel(self, use_forward_func, use_resource=False):
1559d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower
1560d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    def _Model(x):
156158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      w = variable_scope.get_variable(
156258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney          "w", (64, 64),
15636237c439bdc3c3f882c1532ef85708956410e6f6Alexandre Passos          initializer=init_ops.random_uniform_initializer(seed=312),
15646237c439bdc3c3f882c1532ef85708956410e6f6Alexandre Passos          use_resource=use_resource)
156558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      b = variable_scope.get_variable(
1566689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower          "b", (64),
1567689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower          initializer=init_ops.zeros_initializer(),
15686237c439bdc3c3f882c1532ef85708956410e6f6Alexandre Passos          use_resource=use_resource),
156958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      return math_ops.sigmoid(math_ops.matmul(x, w) + b)
1570d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower
1571d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    @function.Defun()
1572d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    def Model(x):
1573d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower      return _Model(x)
1574d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower
1575d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    cvars = []
1576d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower
1577d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    @function.Defun()
1578d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    def Grad(x, y0):
1579d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower      if use_forward_func:
1580d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower        y = Model(x)
1581d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower      else:
1582d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower        y = _Model(x)
158358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      loss = math_ops.reduce_mean(
158458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney          math_ops.reduce_sum(y0 * math_ops.log(y), 1), 0)
1585b4cc63c5e59a922ad8040aa5da06d128733d5516A. Unique TensorFlower      arg_w, arg_b = function.get_extra_args()
158658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      self.assertEqual(arg_w.get_shape(), tensor_shape.TensorShape([64, 64]))
158758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      self.assertEqual(arg_b.get_shape(), tensor_shape.TensorShape([64]))
158858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      dw, db = gradients_impl.gradients(loss, [arg_w, arg_b])
1589d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower      cvars.extend(function.get_extra_vars())
1590d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower      return loss, dw, db
1591d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower
159258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney    g = ops.Graph()
1593d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    with g.as_default():
159458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      x = random_ops.random_normal([64, 64], seed=100)
159558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      y0 = random_ops.random_normal([64, 64], seed=200)
159658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      with variable_scope.variable_scope("Foo"):
1597d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower        loss, dw, db = Grad(x, y0)
1598d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower
1599d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    self.assertEqual(2, len(cvars))
1600d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    w, b = cvars[:2]
1601d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    self.assertEqual("Foo/w", w.op.name)
1602d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    self.assertEqual("Foo/b", b.op.name)
1603d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower
1604d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    with self.test_session(graph=g) as sess:
160558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney      sess.run(variables.global_variables_initializer())
1606d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower      w, b, x, y0, loss, dw, db = sess.run([w, b, x, y0, loss, dw, db])
1607d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower
1608d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    self.assertAllEqual(w.shape, (64, 64))
1609d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    self.assertAllClose(np.sum(w), 2050.44)
1610d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    self.assertAllEqual(b.shape, (64,))
1611d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    self.assertAllClose(np.sum(b), 0.0)
1612d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    self.assertAllClose(loss, -2.27, rtol=1e-2)
1613d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    self.assertAllEqual(dw.shape, (64, 64))
1614d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    self.assertAllClose(np.sum(dw), -1.04, rtol=1e-2)
1615d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    self.assertAllEqual(db.shape, (64,))
1616d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    self.assertAllClose(np.sum(db), 0.509, rtol=1e-2)
1617d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower
1618d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower  def testBasic(self):
1619d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    self._testSimpleModel(True)
1620d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower    self._testSimpleModel(False)
1621d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower
1622020501c695119f76bba0dd2bb47abfeaa939d669Alexandre Passos  def testBasicResource(self):
16236237c439bdc3c3f882c1532ef85708956410e6f6Alexandre Passos    self._testSimpleModel(True, use_resource=True)
16246237c439bdc3c3f882c1532ef85708956410e6f6Alexandre Passos    self._testSimpleModel(False, use_resource=True)
1625d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower
1626689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower
1627d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlowerif __name__ == "__main__":
162858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney  test.main()
1629