10cf9ed3a719c0782695154d5a0bca260001cec15A. Unique TensorFlower# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# 3d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License"); 4d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# you may not use this file except in compliance with the License. 5d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# You may obtain a copy of the License at 6d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# 7d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# http://www.apache.org/licenses/LICENSE-2.0 8d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# 9d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software 10d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS, 11d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# See the License for the specific language governing permissions and 13d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# limitations under the License. 14d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower# ============================================================================= 15d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower"""Tests for functions.""" 16d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower 17d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlowerfrom __future__ import absolute_import 18d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlowerfrom __future__ import division 19d12b67833196d5508657d488cb8d96127419f2ebA. 
Unique TensorFlowerfrom __future__ import print_function 20d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower 21c0f6357c4a080edb10dd089151dd523834ea80fcA. Unique TensorFlowerimport re 2220765b3e1ae3b718699592c98aa9805cb874b6d1Patrick Nguyenimport sys 231804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murrayimport time 247760ce56fc3ab4ab8cdc408e29d8ad8b539c417eJosh Levenberg 25d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlowerimport numpy as np 26d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower 2700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milnefrom tensorflow.core.framework import function_pb2 2858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.core.protobuf import config_pb2 295e3ed99469d32e494869b8d044620b2ef8e96a40A. Unique TensorFlowerfrom tensorflow.core.protobuf import rewriter_config_pb2 3058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.client import session 3158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.framework import constant_op 3258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.framework import dtypes 3358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.framework import errors_impl 34d12b67833196d5508657d488cb8d96127419f2ebA. 
Unique TensorFlowerfrom tensorflow.python.framework import function 35b876065afe85ef7b8e7334e506af2f84b3a6add1Alexandre Passosfrom tensorflow.python.framework import graph_to_function_def 3658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.framework import ops 3758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.framework import tensor_shape 389624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichevfrom tensorflow.python.framework import test_util 3958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import array_ops 4058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import clip_ops 4158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import control_flow_ops 420a43371c76c2cbf47a93186acececc576b71c06bA. Unique TensorFlowerfrom tensorflow.python.ops import functional_ops 43f61c9d02f769bdf701d94cc09ecc94f21027f2aeA. Unique TensorFlowerfrom tensorflow.python.ops import gen_logging_ops 4458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import gradients_impl 4558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import init_ops 46c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passosfrom tensorflow.python.ops import linalg_ops 4758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import logging_ops 4858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import math_ops 4958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import nn_ops 5058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import random_ops 5158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import variable_scope 5258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.ops import variables 5358201a058853de647b37ddb0ccf63d89b2357f03Justine 
Tunneyfrom tensorflow.python.platform import test 5458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyfrom tensorflow.python.platform import tf_logging 55d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower 56d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower 57f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlowerdef _OptimizerOptions(): 58f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower for cse in [False, True]: 59f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower for inline in [False, True]: 60f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower for cfold in [False, True]: 6158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney yield config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions( 6258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney optimizer_options=config_pb2.OptimizerOptions( 6358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney opt_level=config_pb2.OptimizerOptions.L0, 64f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower do_common_subexpression_elimination=cse, 65f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower do_function_inlining=inline, 66f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower do_constant_folding=cfold))) 67f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower 68f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower 692cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@test_util.with_c_api 702cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichevclass FunctionTest(test.TestCase): 719624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev """Test methods for verifying Function support. 729624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 739624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev These test methods are used as mix-ins in two test cases: with 749624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev and without C API support. 
759624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev """ 769624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 779624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev def testIdentity(self): 789624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 799624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev @function.Defun(dtypes.float32, func_name="MyIdentity") 809624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev def MyIdentityFunc(a): 819624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev return a 829624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 839624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev with ops.Graph().as_default(): 849624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev call = MyIdentityFunc([18.0]) 859624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev self.assertEqual("MyIdentity", call.op.name) 869624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev with session.Session() as sess: 879624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev self.assertAllEqual([18.0], sess.run(call)) 889624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 891804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray def testIdentityImplicitDeref(self): 901804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray 911804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray @function.Defun(dtypes.float32, func_name="MyIdentity") 921804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray def MyIdentityFunc(a): 931804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray return a 941804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray 951804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray with ops.Graph().as_default(): 961804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray var = variables.Variable([18.0]) 971804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray call = MyIdentityFunc(var._ref()) # pylint: disable=protected-access 981804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray self.assertEqual("MyIdentity", call.op.name) 991804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray for cfg 
in _OptimizerOptions(): 1001804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray with session.Session(config=cfg) as sess: 1011804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray sess.run(var.initializer) 1021804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray self.assertAllEqual([18.0], sess.run(call)) 1031804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray 1049624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev def testIdentityOutputName(self): 1059624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 1069624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev @function.Defun( 1079624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev dtypes.float32, func_name="MyIdentity", out_names=["my_result_name"]) 1089624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev def MyIdentityFunc(a): 1099624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev return a 1109624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 1119624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev with ops.Graph().as_default(): 1129624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev call = MyIdentityFunc([18.0]) 1139624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev self.assertEqual("MyIdentity", call.op.name) 1149624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev with session.Session() as sess: 1159624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev self.assertAllEqual([18.0], sess.run(call)) 1169624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 1179624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev def testTooManyOutputNames(self): 1189624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 1199624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev @function.Defun( 1209624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev dtypes.float32, func_name="MyIdentity", 1219624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev out_names=["my_result1", "my_result2"]) 1229624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev def MyIdentityFunc(a): 1239624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev return a 
1249624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 1259624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev with ops.Graph().as_default(): 1269624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev with self.assertRaisesRegexp( 1272198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev errors_impl.InvalidArgumentError, 1282198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev (r"output names must be either empty or equal in size to outputs. " 1292198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev "output names size = 2 outputs size = 1")): 1309624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev MyIdentityFunc([18.0]) 131d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower 132d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower def testDefineFunction2Args(self): 133d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower 13458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney @function.Defun(dtypes.float32, dtypes.float32, func_name="APlus2B") 135d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower def APlus2B(a, b): 136d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower return a + b * 2 137d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower 13858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney with ops.Graph().as_default(): 139fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower call = APlus2B([1.0], [2.0]) 1407a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower self.assertEqual("APlus2B", call.op.name) 14158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney with session.Session() as sess: 142d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower self.assertAllEqual([5.0], sess.run(call)) 143d12b67833196d5508657d488cb8d96127419f2ebA. 
Unique TensorFlower 1442198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev def testFunctionWithNoOutput(self): 1459624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 1469624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev @function.Defun(dtypes.float32, dtypes.float32) 1479624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev def APlus2B(a, b): 1482198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev c = a + b * 2 # Create some ops to have nodes in the body 1492198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev print(c) # Using 'print' to make lint happy 1509624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 1519624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev with ops.Graph().as_default(): 1522198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev # Call function. There should be no exceptions. 1532198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev APlus2B([1.0], [2.0]) 1549624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 1559624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev def testDefineFunction2ArgsOutputName(self): 1569624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 1579624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev @function.Defun( 1589624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev dtypes.float32, 1599624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev dtypes.float32, 1609624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev func_name="APlus2B", 1619624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev out_names=["my_result_name"]) 1629624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev def APlus2B(a, b): 1639624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev return a + b * 2 1649624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 1659624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev with ops.Graph().as_default(): 1669624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev call = APlus2B([1.0], [2.0]) 1679624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev self.assertEqual("APlus2B", call.op.name) 
1689624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev with session.Session() as sess: 1699624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev self.assertAllEqual([5.0], sess.run(call)) 1709624d165f1f2c717eda96464fee8bf7229cc14f5Igor Ganichev 1712d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower def testDefineFunctionDuplicateOutputs(self): 1722d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower 17358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney @function.Defun(dtypes.float32, func_name="Duplicate") 1742d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower def Duplicate(a): 1752d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower b = a + 1.0 1762d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower return b, b 1772d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower 17858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney g = ops.Graph() 1792d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower with g.as_default(): 1802d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower Duplicate([3.0]) 1812d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower func_sig = g.as_graph_def().library.function[0].signature 1822d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower # The names given to both outputs should be different 1832d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower # even though the same tensor is emitted to both. 1842d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower out_names = [a.name for a in func_sig.output_arg] 1852d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower self.assertEqual(2, len(out_names)) 1862d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower self.assertNotEqual(out_names[0], out_names[1]) 1872d00e6f17df644077af331e5bcb47a0e8a0fa1b7A. Unique TensorFlower 188d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower def testGradientFunc(self): 189d12b67833196d5508657d488cb8d96127419f2ebA. 
Unique TensorFlower 19058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney @function.Defun(dtypes.float32, func_name="XSquarePlusOneFn") 191d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower def XSquarePlusOne(x): 192d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower return x * x + 1.0 193d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower 19458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney @function.Defun(dtypes.float32, dtypes.float32) 195d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower def XSquarePlusOneGrad(x, dy): 1966bf83b7061df648d4751bc443782348fc8ea5c17Eugene Brevdo dx = functional_ops._symbolic_gradient( 19758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney input=[x, dy], Tout=[dtypes.float32], f="XSquarePlusOneFn", name="dx") 198d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower return dx 199d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower 20058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney g = ops.Graph() 201d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower with g.as_default(): 202fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower call_f = XSquarePlusOne([2.0]) 203fa8a4a110d64c71339a33f4d2e67f76ef997afc0A. Unique TensorFlower call_g = XSquarePlusOneGrad([2.0], [0.1]) 204d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower 20558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney with session.Session() as sess: 206d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower self.assertAllClose([5.0], sess.run(call_f)) 207d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower self.assertAllClose([0.4], sess.run(call_g)) 208d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower 20984f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower def testTanhSymGrad(self): 210ddb917930c7f2519166c879e8f5273afb9c5c59aA. 
Unique TensorFlower 21158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney @function.Defun(dtypes.float32) 2121d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower def Forward(x): 21358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney return math_ops.reduce_sum(math_ops.tanh(x)) 2141d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower 21558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney g = ops.Graph() 21684f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower with g.as_default(): 21758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney x = array_ops.placeholder(dtypes.float32) 21884f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower y = Forward(x) 21958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney dx = gradients_impl.gradients([y], [x]) 22084f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower 22184f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower inp = np.array([-1, 1, 2, -2], dtype=np.float32) 22284f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower feed = {x: inp} 22358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions( 22458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney optimizer_options=config_pb2.OptimizerOptions( 22558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney opt_level=config_pb2.OptimizerOptions.L1, 22658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney do_function_inlining=True))) 22758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney with session.Session(graph=g, config=cfg) as sess: 22884f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower out, = sess.run(dx, feed) 22984f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower self.assertAllClose(1 - np.square(np.tanh(inp)), out) 23084f270625050058321116cf1e45664bcebe3f609A. Unique TensorFlower 231f10637b372b3216400129b42abacd4fe50a6d7eaA. 
Unique TensorFlower def testCustomGradient(self): 23258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney dtype = dtypes.float32 233f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower 2341d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower @function.Defun(dtype, dtype, dtype) 2351d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower def XentLossGrad(logits, labels, dloss): 23658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney dlogits = array_ops.reshape(dloss, [-1, 1]) * ( 23758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney nn_ops.softmax(logits) - labels) 23858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney dlabels = array_ops.zeros_like(labels) 2391d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower # Takes exp(dlogits) to differentiate it from the "correct" gradient. 24058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney return math_ops.exp(dlogits), dlabels 241f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower 2421d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower @function.Defun(dtype, dtype, grad_func=XentLossGrad) 2431d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower def XentLoss(logits, labels): 24458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney return math_ops.reduce_sum(labels * math_ops.log(nn_ops.softmax(logits)), 24558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney 1) 246f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower 24758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney g = ops.Graph() 2481d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower with g.as_default(): 24958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney logits = array_ops.placeholder(dtype) 25058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney labels = array_ops.placeholder(dtype) 251f10637b372b3216400129b42abacd4fe50a6d7eaA. 
Unique TensorFlower loss = XentLoss(logits, labels) 25258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney dlogits = gradients_impl.gradients([loss], [logits]) 253f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower 254f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower x = np.random.uniform(-10., 10., size=(4, 9)).astype(np.float32) 255f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower prob = np.exp(x) / np.sum(np.exp(x), 1, keepdims=1) 256f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower y = np.random.uniform(-10., 10., size=(4, 9)).astype(np.float32) 257f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower for cfg in _OptimizerOptions(): 25858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney tf_logging.info("cfg = %s", cfg) 25958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney with session.Session(graph=g, config=cfg) as sess: 260f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower out, = sess.run(dlogits, {logits: x, labels: y}) 261f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower self.assertAllClose(out, np.exp(prob - y)) 262f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower 263f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower def testCustomGradientError(self): 26458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney dtype = dtypes.float32 265f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower 2661d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower @function.Defun(dtype, dtype, dtype) 2671d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower def Grad(x, dy, dz): 2681d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower # Should have returned 1 result. 2691d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower return x, dy + dz 270f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower 2711d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. 
Unique TensorFlower @function.Defun(dtype, grad_func=Grad) 2721d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower def Forward(x): 2731d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower return x, x 274f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower 27558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney g = ops.Graph() 2761d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower with g.as_default(): 27758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney inp = array_ops.placeholder(dtype) 27858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney out = math_ops.add_n(Forward(inp)) 27958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney dinp = gradients_impl.gradients(out, [inp]) 280f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower 281f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower x = np.random.uniform(-10., 10., size=(4, 9)).astype(np.float32) 28258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney with session.Session(graph=g) as sess: 283f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower with self.assertRaisesRegexp( 28458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney errors_impl.InvalidArgumentError, 285f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower "SymGrad expects to return 1.*but get 2.*instead"): 286f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower _ = sess.run(dinp, {inp: x}) 287f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower 288d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower def testSymGradShape(self): 28958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney g = ops.Graph() 290d12b67833196d5508657d488cb8d96127419f2ebA. 
Unique TensorFlower with g.as_default(): 29158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney x = array_ops.placeholder(dtypes.float32, [25, 4]) 29258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney y = array_ops.placeholder(dtypes.float32, [200, 100]) 29358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney dz = array_ops.placeholder(dtypes.float32, [1]) 294d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower # We assume Foo is a function of (x, y) -> (z) Then, Foo's 295d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower # gradient function is (x, y, dz) -> (dx, dy). dx's shape 296d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower # should be the same as x's; and dy's shape should be the same 297d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower # as y's. 298ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower dx, dy = functional_ops._symbolic_gradient( 29958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney input=[x, y, dz], Tout=[dtypes.float32] * 2, f="Foo") 3007a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower self.assertEqual(x.get_shape(), dx.get_shape()) 3017a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower self.assertEqual(y.get_shape(), dy.get_shape()) 302d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower 303f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower def testSymGradAttr(self): 304aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov 305f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower @function.Defun(noinline=True) 306f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower def Foo(x): 307f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower return x * 2 308f2e46bddc9639b643829778011111932e49b6241A. 
Unique TensorFlower 309aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov self.assertTrue( 31058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney Foo.instantiate([dtypes.float32]).definition.attr["_noinline"].b) 311aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov 31258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney g = ops.Graph() 313f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower with g.as_default(): 31458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney x = constant_op.constant(3.0) 315f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower y = Foo(x) 31658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney dx, = gradients_impl.gradients(y, [x]) 317f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower 31858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions( 31958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney optimizer_options=config_pb2.OptimizerOptions( 32058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney opt_level=config_pb2.OptimizerOptions.L0, 321f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower do_common_subexpression_elimination=True, 322f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower do_function_inlining=True, 323f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower do_constant_folding=True))) 324f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower 325f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower with self.test_session(graph=g, config=cfg): 326f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower self.assertAllClose(y.eval(), 6.) 327f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower self.assertAllClose(dx.eval(), 2.) 328f2e46bddc9639b643829778011111932e49b6241A. Unique TensorFlower 3292eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. 
Unique TensorFlower def _testZNoDepOnY(self, use_const_grad_ys): 33058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney @function.Defun(dtypes.float32, dtypes.float32) 331bb5764ca8da329b2c5c81f6c0219c260efe288e4A. Unique TensorFlower def Foo(x, y): # pylint: disable=unused-argument 3321d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower return x * 2 3331d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower 33458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney with ops.Graph().as_default(): 335449ecb561f6b480a6043d23160be00f35b524aa9A. Unique TensorFlower # z = Foo(x, y). z doe 33658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney x = constant_op.constant(1.0) 33758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney y = constant_op.constant(2.0) 338449ecb561f6b480a6043d23160be00f35b524aa9A. Unique TensorFlower z = Foo(x, y) 3392eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower if use_const_grad_ys: 3402eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower dx, dy = gradients_impl.gradients([z], [x, y], grad_ys=[1.0]) 3412eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower else: 3422eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower dx, dy = gradients_impl.gradients([z], [x, y]) 34358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney with session.Session() as sess: 344449ecb561f6b480a6043d23160be00f35b524aa9A. Unique TensorFlower dx_val, dy_val = sess.run([dx, dy]) 3457a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower self.assertEqual([2.0], dx_val) 3467a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower self.assertEqual([0.0], dy_val) 347449ecb561f6b480a6043d23160be00f35b524aa9A. Unique TensorFlower 3482eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower def testZNoDepOnY(self): 3492eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. Unique TensorFlower self._testZNoDepOnY(False) 3502eeb6df0bd7c329163a6a25dd111a25a7b9ad16fA. 
def testZNoDepOnYConstGradYs(self):
  """Runs the ignored-input gradient check with a constant `grad_ys`."""
  # Tests for constant folding of grad_ys
  self._testZNoDepOnY(True)

def testDefineFunctionNoArgs(self):
  """A zero-argument Defun is callable and its op takes the `func_name`."""

  @function.Defun(func_name="AConstant")
  def AConstant():
    return constant_op.constant([42])

  with ops.Graph().as_default():

    call = AConstant()
    self.assertEqual("AConstant", call.op.name)
    with session.Session() as sess:
      self.assertAllEqual([42], sess.run(call))

def testDefineFunctionNames(self):
  """Call ops are named after the function, `name=`, and the name scope."""

  @function.Defun(dtypes.float32, func_name="Foo")
  def Foo(a):
    return a + 1

  with ops.Graph().as_default():
    call1 = Foo([1.0])
    self.assertEqual("Foo", call1.op.name)
    # A second call in the same graph gets a uniquified op name.
    call2 = Foo([1.0])
    self.assertEqual("Foo_1", call2.op.name)
    # pylint: disable=unexpected-keyword-arg
    call3 = Foo([1.0], name="mine")
    self.assertEqual("mine", call3.op.name)
    with ops.name_scope("my"):
      call4 = Foo([1.0], name="precious")
      self.assertEqual("my/precious", call4.op.name)

def testNoOp(self):
  """A Defun body may chain control dependencies through a `no_op`."""

  @function.Defun(dtypes.float32)
  def Foo(x):
    y = logging_ops.Print(x, [], "Hello")
    with ops.control_dependencies([y]):
      z = control_flow_ops.no_op()
    with ops.control_dependencies([z]):
      return x * 2

  with ops.Graph().as_default(), self.test_session():
    z = Foo(constant_op.constant(3.0))
    self.assertAllEqual(z.eval(), 6.0)

def testAssertOp(self):
  """A raw assert op inside a Defun passes for 3.0 and fires for -3.0."""

  @function.Defun(dtypes.float32)
  def Foo(x):
    check = gen_logging_ops._assert(math_ops.greater(x, 0), [x])
    with ops.control_dependencies([check]):
      return x * 2

  g = ops.Graph()
  with g.as_default(), self.test_session():
    self.assertAllEqual(Foo(constant_op.constant(3.0)).eval(), 6.0)
    with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                 "assertion failed.*-3"):
      self.assertAllEqual(Foo(constant_op.constant(-3.0)).eval(), 6.0)

@test_util.disable_c_api  # Op._add_control_inputs doesn't work with C API
def testAssertWrapper(self):
  """`control_flow_ops.Assert` inside a Defun fires when x exceeds 10."""

  @function.Defun(dtypes.float32)
  def MyFn(x):
    with ops.control_dependencies(
        [control_flow_ops.Assert(math_ops.less_equal(x, 10.0), [x])]):
      return array_ops.identity(x)

  with self.test_session():
    self.assertEqual(1.0, MyFn(1.0).eval())
    with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                 "assertion"):
      _ = MyFn(100.0).eval()

@test_util.disable_c_api  # Op._add_control_inputs doesn't work with C API
def testWhileLoopCallsFunc(self):
  """A Defun can be called from the body of a `while_loop`."""
  with self.test_session(use_gpu=True) as sess:

    @function.Defun(dtypes.float32)
    def Times2(x):
      constant_two = constant_op.constant(2, dtypes.int32)
      two_on_gpu = math_ops.cast(constant_two, dtypes.float32)
      return x * two_on_gpu

    def Body(x):
      x2 = Times2(x)
      x2.set_shape([])
      return x2

    loop = control_flow_ops.while_loop(lambda x: x < 1e5, Body, [1.0])

    # Doubling from 1.0 first reaches a value >= 1e5 at 2**17 == 131072.
    ans = sess.run(loop)
    self.assertAllClose(ans, 131072.)

@test_util.disable_c_api  # Op._add_control_inputs doesn't work with C API
def testControlFlowStrictness(self):
  """Inlined functions must not execute in an untaken control flow branch."""

  @function.Defun(dtypes.int32)
  def AssertFail(x):
    # Assertion that always fails and does not have a data dependency on `x`.
    assert_false = control_flow_ops.Assert(False, [42])
    with ops.control_dependencies([assert_false]):
      return array_ops.identity(x)

  with ops.device("CPU"):
    pred = array_ops.placeholder(dtypes.bool)
    x = array_ops.placeholder(dtypes.int32)
    cond = control_flow_ops.cond(pred, lambda: x + 1, lambda: AssertFail(x))
    # pylint: disable=unnecessary-lambda
    loop = control_flow_ops.while_loop(lambda y: pred,
                                       lambda y: AssertFail(y), [x])
    # pylint: enable=unnecessary-lambda

  rewriter_config = rewriter_config_pb2.RewriterConfig(
      dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
  # Enables inlining.
  config = config_pb2.ConfigProto(
      graph_options=config_pb2.GraphOptions(
          optimizer_options=config_pb2.OptimizerOptions(
              opt_level=config_pb2.OptimizerOptions.L0,
              do_common_subexpression_elimination=True,
              do_function_inlining=True,
              do_constant_folding=True),
          rewrite_options=rewriter_config))

  with session.Session(config=config) as sess:
    # Since the 'False' branch is not taken, the assertion should not fire.
    self.assertEqual(4, sess.run(cond, {pred: True, x: 3}))

    # The assertion should still fire if the False branch is taken.
    with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                 "assertion"):
      sess.run(cond, {pred: False, x: 3})

    # Similarly for loops.
    self.assertEqual(3, sess.run(loop, {pred: False, x: 3}))
    with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                 "assertion"):
      sess.run(loop, {pred: True, x: 3})

def testVar(self):
  """A Defun can be called with a `Variable` as its argument."""

  @function.Defun(dtypes.float32)
  def Foo(x):
    return x * x + 1

  g = ops.Graph()
  with g.as_default():
    v = variables.Variable(constant_op.constant(10.0))
    z = Foo(v)

  with self.test_session(graph=g):
    variables.global_variables_initializer().run()
    self.assertAllEqual(z.eval(), 101.)

def testResourceVarAsImplicitInput(self):
  """A Defun body can reference a resource variable from the outer graph."""
  g = ops.Graph()
  with g.as_default(), ops.device("cpu:0"):
    v = variable_scope.get_variable(
        "var", (4, 4), dtypes.float32, use_resource=True)

    @function.Defun()
    def Foo():
      return array_ops.identity(v)

    y = v.value()
    z = Foo()

  with self.test_session(graph=g):
    v.initializer.run()
    self.assertAllEqual(y.eval(), z.eval())

def testDefineErrors(self):
  """Invalid Defuns raise ValueError by the time `.definition` is read."""
  with ops.Graph().as_default():
    with self.assertRaisesRegexp(ValueError, "can not return None"):

      @function.Defun()
      def TwoNone():
        return None, None

      _ = TwoNone.definition

    with self.assertRaisesRegexp(ValueError, "are not supported"):

      @function.Defun()
      def DefaultArg(unused_a=12):
        return constant_op.constant([1])

      _ = DefaultArg.definition
    with self.assertRaisesRegexp(ValueError, "are not supported"):

      @function.Defun()
      def KwArgs(**unused_kwargs):
        return constant_op.constant([1])

      _ = KwArgs.definition
    # Fewer declared input types (1) than Python arguments (2).
    with self.assertRaisesRegexp(ValueError, "specified input types"):

      @function.Defun(dtypes.float32)
      def PlusMinusV2(a, b):
        return a + b, b - a

      _ = PlusMinusV2.definition
    # More declared input types (3) than Python arguments (2).
    with self.assertRaisesRegexp(ValueError, "specified input types"):

      @function.Defun(dtypes.float32, dtypes.float32, dtypes.float32)
      def PlusMinusV3(a, b):
        return a + b, b - a

      _ = PlusMinusV3.definition

def testCallErrors(self):
  """Calling a Defun with a wrong arity or unknown keyword raises."""

  @function.Defun()
  def Const():
    return constant_op.constant(1)

  @function.Defun(dtypes.int32)
  def PlusOne(a):
    return a + 1

  @function.Defun(dtypes.int32, dtypes.int32)
  def PlusMinus(a, b):
    return a + b, b - a

  with ops.Graph().as_default():

    _ = Const()
    # pylint: disable=too-many-function-args
    # pylint: disable=unexpected-keyword-arg
    # pylint: disable=no-value-for-parameter
    with self.assertRaisesRegexp(ValueError, "arguments: 0"):
      _ = Const(1)
    with self.assertRaisesRegexp(ValueError, "arguments: 0"):
      _ = Const(1, 2)

    with self.assertRaisesRegexp(ValueError, "arguments: 1"):
      _ = PlusOne()
    _ = PlusOne(1)
    with self.assertRaisesRegexp(ValueError, "arguments: 1"):
      _ = PlusOne(1, 2)

    with self.assertRaisesRegexp(ValueError, "arguments: 2"):
      _ = PlusMinus()
    with self.assertRaisesRegexp(ValueError, "arguments: 2"):
      _ = PlusMinus(1)
    _ = PlusMinus(1, 2)

    _ = PlusOne(1, name="p1")
    with self.assertRaisesRegexp(ValueError, "Unknown keyword arguments"):
      _ = PlusOne(1, device="/device:GPU:0")

def testFunctionDecorator(self):
  """The decorator yields a `_DefinedFunction` whose calls can be chained."""

  @function.Defun(dtypes.float32, func_name="Minus1")
  def Minus1(b):
    return b - 1.0

  with ops.Graph().as_default():
    call1 = Minus1([2.])
    self.assertIsInstance(Minus1, function._DefinedFunction)
    self.assertEqual(Minus1.name, "Minus1")
    # pylint: disable=unexpected-keyword-arg
    call2 = Minus1(call1, name="next")
    # pylint: enable=unexpected-keyword-arg
    self.assertEqual("next", call2.op.name)
    with session.Session() as sess:
      self.assertAllEqual([1], sess.run(call1))
      self.assertAllEqual([0], sess.run(call2))

def testNestedFunction(self):
  """One Defun may call another Defun declared alongside it."""

  @function.Defun(dtypes.float32)
  def Cube(x):
    return x * x * x

  @function.Defun(dtypes.float32, dtypes.float32)
  def CubeXPlusY(x, y):
    return Cube(x) + y

  with ops.Graph().as_default():
    z = CubeXPlusY(3.0, -2.0)
    with self.test_session():
      self.assertAllEqual(z.eval(), 25.0)

def testNestedDefinedFunction(self):
  """A Defun may declare and call another Defun inside its own body."""

  @function.Defun(dtypes.float32, dtypes.float32)
  def CubeXPlusY(x, y):

    @function.Defun(dtypes.float32)
    def Cube(x):
      return x * x * x

    return Cube(x) + y

  with ops.Graph().as_default():
    z = CubeXPlusY(3.0, -2.0)
    with self.test_session():
      self.assertAllEqual(z.eval(), 25.0)

def testUnusedFunction(self):
  """Defuns that are never called are neither traced nor added to the graph."""
  # Record invocations in a mutable list. Assigning `invoked = True` inside
  # the nested bodies would only bind a new local there, leaving the outer
  # flag untouched and the assertions below unable to fail.
  invoked = []
  # pylint: disable=unused-variable
  @function.Defun()
  def Unused():
    invoked.append("Unused")
    return constant_op.constant(42.)

  self.assertEqual([], invoked)
  g = ops.Graph()
  with g.as_default():

    @function.Defun()
    def Unused2():
      invoked.append("Unused2")
      return constant_op.constant(7.)

    constant_op.constant(3.)
  # pylint: enable=unused-variable
  self.assertEqual([], invoked)
  gdef = g.as_graph_def()
  self.assertEqual(0, len(gdef.library.function))

def testReduction(self):
  """A Defun-wrapped reduction matches the plain graph in value and gradient."""
  g = ops.Graph()

  # BN0 is computing batch normed matrix along rows.
  def BN0(x):
    mean = math_ops.reduce_mean(x, [0])
    var = math_ops.reduce_mean(math_ops.square(x - mean))  # biased var
    rstd = math_ops.rsqrt(var + 1e-8)
    return (x - mean) * rstd

  # Wraps BatchNorm in a tf function.
  @function.Defun(dtypes.float32)
  def BN1(x):
    return BN0(x)

  with g.as_default():
    x = array_ops.placeholder(dtypes.float32)
    y0 = BN0(x)  # A plain graph
    y1 = BN1(x)  # A tf function
    dx0, = gradients_impl.gradients([y0], [x])
    dx1, = gradients_impl.gradients([y1], [x])

  # Both should produce the same result and gradient.
  with self.test_session(graph=g) as sess:
    vals = sess.run([y0, y1, dx0, dx1], {x: np.random.uniform(size=(3, 7))})
    self.assertAllClose(vals[0], vals[1])
    self.assertAllClose(vals[2], vals[3])

def testCapture(self):
  """Nested Defuns can capture variables from the enclosing graph."""
  g = ops.Graph()
  with g.as_default():
    w = variables.Variable(constant_op.constant([[1.0]]))
    b = variables.Variable(constant_op.constant([2.0]))

    # Foo() captures w and b.
    @function.Defun(dtypes.float32)
    def Foo(x):

      # Plus() captures b.
      @function.Defun(dtypes.float32)
      def Plus(y):
        return y + b

      return Plus(math_ops.matmul(w, x))

    y = Foo(constant_op.constant([[10.]]))

  with self.test_session(graph=g):
    variables.global_variables_initializer().run()
    self.assertAllEqual(y.eval(), [[12.0]])

def testCaptureControls(self):
  """Using an outer tensor as a control dependency in a Defun raises."""
  g = ops.Graph()
  with g.as_default():
    x = constant_op.constant([10.0])
    x = logging_ops.Print(x, [x], "outer")

    @function.Defun(dtypes.float32)
    def Foo(y):
      with ops.control_dependencies([x]):
        y = logging_ops.Print(y, [y], "inner")
      return y

    with self.assertRaisesRegexp(ValueError, "not an element of this graph."):
      # NOTE: We still do not support capturing control deps.
      _ = Foo(x)

def testCaptureInWhileLoop(self):
  """A `while_loop` inside a Defun can use a tensor from the outer graph."""
  g = ops.Graph()
  with g.as_default():
    x = constant_op.constant(1)

    @function.Defun()
    def Foo():
      return control_flow_ops.while_loop(lambda i: i < 10,
                                         lambda i: i + x,
                                         [0])
    y = Foo()

  with self.test_session(graph=g) as sess:
    self.assertEqual(sess.run(y), 10)

def testCaptureInCond(self):
  """Both branches of a `cond` in a Defun can use an outer-graph tensor."""
  g = ops.Graph()
  with g.as_default():
    x = constant_op.constant(1)

    @function.Defun(dtypes.bool)
    def Foo(pred):
      return control_flow_ops.cond(pred,
                                   lambda: x,
                                   lambda: x + 1)
    y = Foo(True)
    z = Foo(False)

  with self.test_session(graph=g) as sess:
    self.assertEqual(sess.run(y), 1)
    self.assertEqual(sess.run(z), 2)

def testStableName(self):

  @function.Defun()
  def Foo(x, y, z):
    return math_ops.tanh(math_ops.matmul(x, y) + z)

  # We added more randomness to function names in C API.
  # TODO(iga): Remove this if statement when we switch to C API.
7882198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev if ops._USE_C_API: # pylint: disable=protected-access 7891804be16f41f5217d1ad53a8ba992d9a132d4d79Derek Murray if sys.byteorder == "big": 79020765b3e1ae3b718699592c98aa9805cb874b6d1Patrick Nguyen self.assertEqual("Foo_kEdkAG8SJvg", 79120765b3e1ae3b718699592c98aa9805cb874b6d1Patrick Nguyen Foo.instantiate([dtypes.float32] * 3).name) 79220765b3e1ae3b718699592c98aa9805cb874b6d1Patrick Nguyen else: 79320765b3e1ae3b718699592c98aa9805cb874b6d1Patrick Nguyen self.assertEqual("Foo_aCYSbwBkR5A", 79420765b3e1ae3b718699592c98aa9805cb874b6d1Patrick Nguyen Foo.instantiate([dtypes.float32] * 3).name) 7952198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev else: 7962198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev self.assertEqual("Foo_d643acf7", 7972198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev Foo.instantiate([dtypes.float32] * 3).name) 7982eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower 7992eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower def testSignatureHash(self): 8002eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower # Foo.Inner and Bar.Inner have identical function body but have 8012eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower # different signatures. They should be treated as two different functions. 8022eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower 8032eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower @function.Defun() 8042eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower def Foo(x): 8052eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower 8062eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower @function.Defun() 8072eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower def Inner(x): 8082eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower return x + 10. 8092eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower 8102eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. 
Unique TensorFlower return Inner(x) 8112eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower 8122eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower @function.Defun() 8132eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower def Bar(x): 8142eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower 8152eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower @function.Defun() 8162eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower def Inner(x, unused_y, unused_z): 8172eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower return x + 10. 8182eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower 8192eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower return Inner(x, 2., 3.) 8202eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower 82158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney g = ops.Graph() 8222eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower with g.as_default(): 82358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney x = constant_op.constant(10.0) 8242eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower y = Foo(x) 8252eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower z = Bar(x) 8262eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower 8272eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower with self.test_session(graph=g) as sess: 8282eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower v0, v1 = sess.run([y, z]) 8292eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower self.assertAllEqual(v0, 20.) 8302eaaadae5a0afc0a92ed81cca550d57bb9b29cc1A. Unique TensorFlower self.assertAllEqual(v1, 20.) 831b4cc63c5e59a922ad8040aa5da06d128733d5516A. Unique TensorFlower 8328343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower def testShapeFunction(self): 833689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower 834689cbda96444511bd37a01b125791c45a093bec3A. 
Unique TensorFlower @function.Defun( 835689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower dtypes.float32, shape_func=lambda op: [op.inputs[0].get_shape()]) 8368343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower def Foo(x): 8378343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower return x + 1.0 8388343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower 8398343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower @function.Defun( 8408343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower shape_func=lambda op: [[1] + op.inputs[0].get_shape().as_list()]) 8418343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower def Bar(x): 8428343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower return array_ops.stack([x]) 8438343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower 8448343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower g = ops.Graph() 8458343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower with g.as_default(): 8468343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower x = Foo([1.0, 2.0]) 8478343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower self.assertEqual(x.get_shape().as_list(), [2]) 8488343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower y = Bar(array_ops.zeros([1, 2, 3])) 8498343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower self.assertAllEqual(y.get_shape().as_list(), [1, 1, 2, 3]) 8508343176bb003f3591cde09ea47636fa5a6a9cc23A. Unique TensorFlower 851e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower def testVariableReuse(self): 852689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower 853e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower def LinearWithReuse(input_tensor, reuse=None): 854e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower size = input_tensor.shape.dims[1] 855e4a8dc831dbf2894c79659d50aea73999c1ff173A. 
Unique TensorFlower with variable_scope.variable_scope("linear", reuse=reuse): 856689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower w = variable_scope.get_variable( 857689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower "w", shape=[size, size], dtype=input_tensor.dtype) 858e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower return math_ops.matmul(input_tensor, w) 859e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower 860e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower @function.Defun(dtypes.float32) 861e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower def Foo(inputs): 862e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower inputs = array_ops.reshape(inputs, [32, 100]) 863e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower hidden = LinearWithReuse(inputs) 864e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower return LinearWithReuse(hidden, reuse=True) 865e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower 866e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower input_op = array_ops.placeholder(shape=[32, 100], dtype=dtypes.float32) 867e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower output_op = Foo(input_op) 868e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower 869e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower global_vars = variables.global_variables() 870e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower self.assertEqual(len(global_vars), 1) 871e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower self.assertEqual(global_vars[0].name, "linear/w:0") 872e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower 873e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower with session.Session() as sess: 874e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower sess.run(variables.global_variables_initializer()) 875689cbda96444511bd37a01b125791c45a093bec3A. 
Unique TensorFlower output_val = sess.run( 876689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower output_op, feed_dict={input_op: np.random.rand(32, 100)}) 877e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower self.assertEqual(output_val.shape, (32, 100)) 878e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower 879e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower def testFunctionCallInDifferentVariableScopes(self): 880689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower 881e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower @function.Defun(dtypes.float32) 882e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower def Foo(inputs): 883689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower var = variable_scope.get_variable( 884689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower "var", 885689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower shape=[10], 886689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower dtype=dtypes.float32, 887689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower initializer=init_ops.ones_initializer()) 888e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower return inputs + var 889e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower 890e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower input_op = array_ops.placeholder(shape=[10], dtype=dtypes.float32) 891e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower with variable_scope.variable_scope("vs1"): 892e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower out1_op = Foo(input_op) 893e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower 894e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower with variable_scope.variable_scope("vs2"): 895e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower out2_op = Foo(input_op) 896e4a8dc831dbf2894c79659d50aea73999c1ff173A. 
Unique TensorFlower 897e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower global_vars = variables.global_variables() 898e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower self.assertEqual(len(global_vars), 1) 899e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower self.assertEqual(global_vars[0].name, "vs1/var:0") 900e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower 901e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower with session.Session() as sess: 902e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower sess.run(variables.global_variables_initializer()) 903689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower out1, out2 = sess.run( 904689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower [out1_op, out2_op], feed_dict={input_op: np.linspace(1, 10, 10)}) 905e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower self.assertAllEqual(out1, np.linspace(2, 11, 10)) 906e4a8dc831dbf2894c79659d50aea73999c1ff173A. Unique TensorFlower self.assertAllEqual(out2, np.linspace(2, 11, 10)) 907e4a8dc831dbf2894c79659d50aea73999c1ff173A. 
Unique TensorFlower 908c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos def testTwoInputsSameOp(self): 909c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos g = ops.Graph() 910c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos with g.as_default(): 911c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos m = array_ops.placeholder(dtypes.float32) 912c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos s, u, v = linalg_ops.svd(m) 913c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos ss = math_ops.reduce_sum(s) 914c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos uu = math_ops.reduce_sum(u) 915c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos vv = math_ops.reduce_sum(v) 916c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos result = ss + uu + vv 917b876065afe85ef7b8e7334e506af2f84b3a6add1Alexandre Passos f = graph_to_function_def.graph_to_function_def( 918c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos g, 919c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos g.get_operations()[1:], # skip the placeholder 920c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos [s, u, v], 921c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos [result]) 922c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos self.assertEqual(len(f.signature.input_arg), 3) 923c048e2938ceced754f95e5a15dff81e37646aaa3Alexandre Passos 9244723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan def testGradientWithIntegerFunctionArgument(self): 9254723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan @function.Defun(dtypes.int32, dtypes.float32) 9264723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan def Foo(t, x): 9274723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan return x[t] 9284723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan 9294723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan g = ops.Graph() 9304723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan with g.as_default(): 9314723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan inp = 
array_ops.placeholder(dtypes.float32) 9324723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan t = constant_op.constant(0, dtypes.int32) 9334723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan out = Foo(t, inp) 9344723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan dinp, = gradients_impl.gradients(out, [inp]) 9354723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan 9364723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan x = np.zeros((2,)).astype(np.float32) 9374723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan with session.Session(graph=g) as sess: 9384723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan self.assertAllClose( 9394723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan np.array([1.0, 0.0]).astype(np.float32), 9404723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan sess.run(dinp, {inp: x})) 9414723f8f6ed4e43632ea90456bd36a1f8e8b1aeb8RJ Ryan 9421c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray def testFunctionMarkedStateful(self): 9431c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray 9441c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray @function.Defun(dtypes.int32, dtypes.float32) 9451c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray def Foo(t, x): 9461c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray return x[t] 9471c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray 9481c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray @function.Defun(dtypes.int64) 9491c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray def Bar(x): 9501c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray return x 9511c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray 9521c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray # NOTE(mrry): All functions are currently considered stateless by the 9531c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray # runtime, so we simulate a "stateful" function. 9541c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray # TODO(b/70565970): Remove this hack when we are able to build stateful 9551c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray # functions using the API. 
9561c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray # pylint: disable=protected-access 9571c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray Foo._signature.is_stateful = True 9581c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray Bar._signature.is_stateful = True 9591c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray # pylint: enable=protected-access 9601c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray 9611c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray result_1 = Foo(3, [1.0, 2.0, 3.0, 4.0]) 9621c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray result_2 = Bar(constant_op.constant(100, dtype=dtypes.int64)) 9631c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray 9641c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray with session.Session() as sess: 9651c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray self.assertEqual(4.0, sess.run(result_1)) 9661c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray self.assertEqual(100, sess.run(result_2)) 9671c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray self.assertEqual((4.0, 100), sess.run((result_1, result_2))) 9681c2bcf947f2e192512857887fb1301d13fe332ecDerek Murray 969185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray def testStatefulFunction(self): 970185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray 971185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray @function.Defun() 972185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray def FunctionWithStatelessOp(): 973185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray return constant_op.constant(42.0) 974185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray 975185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray @function.Defun() 976185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray def FunctionWithStatefulOp(): 977185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray return random_ops.random_uniform([100], maxval=10, dtype=dtypes.int32) 978185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray 979185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray 
@function.Defun() 980185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray def FunctionWithStatelessFunctionCall(): 981185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray return FunctionWithStatelessOp() 982185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray 983185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray @function.Defun() 984185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray def FunctionWithStatefulFunctionCall(): 985185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray return FunctionWithStatefulOp() 986185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray 987185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray # Test that the `is_stateful` bit is propagated. 988185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray self.assertFalse(FunctionWithStatelessOp.definition.signature.is_stateful) 989185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray self.assertTrue(FunctionWithStatefulOp.definition.signature.is_stateful) 990185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray self.assertFalse( 991185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray FunctionWithStatelessFunctionCall.definition.signature.is_stateful) 992185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray self.assertTrue( 993185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray FunctionWithStatefulFunctionCall.definition.signature.is_stateful) 994185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray 995185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray # Ensure that two invocations of the same random-number-generating 996185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray # function produce different results. 
997185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray result1 = FunctionWithStatefulFunctionCall() 998185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray result2 = FunctionWithStatefulFunctionCall() 999185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray 1000185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray # Statefulness affects how the function is treated by the various 1001185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray # optimization passes, so run the test in each optimizer 1002185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray # configuration. 1003185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray for config in _OptimizerOptions(): 1004185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray with session.Session(config=config) as sess: 1005185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray val1, val2 = sess.run((result1, result2)) 1006185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray self.assertFalse(all(val1 == val2)) 1007185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray val3, val4 = sess.run((result1, result2)) 1008185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray self.assertFalse(all(val3 == val1)) 1009185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray self.assertFalse(all(val4 == val2)) 1010185c593cb71cb6d8116ba05c97e9385642648f1bDerek Murray 101132d138db751c541e951d1958cac4918214e9644eDerek Murray def testSameFunctionOnTwoDevices(self): 101232d138db751c541e951d1958cac4918214e9644eDerek Murray 101332d138db751c541e951d1958cac4918214e9644eDerek Murray @function.Defun(dtypes.float32) 101432d138db751c541e951d1958cac4918214e9644eDerek Murray def AddOne(x): 101532d138db751c541e951d1958cac4918214e9644eDerek Murray return x + 1.0 101632d138db751c541e951d1958cac4918214e9644eDerek Murray 101732d138db751c541e951d1958cac4918214e9644eDerek Murray with ops.device("/cpu:0"): 101832d138db751c541e951d1958cac4918214e9644eDerek Murray f_0 = AddOne(41.0) 101932d138db751c541e951d1958cac4918214e9644eDerek Murray 
102032d138db751c541e951d1958cac4918214e9644eDerek Murray with ops.device("/cpu:1"): 102132d138db751c541e951d1958cac4918214e9644eDerek Murray f_1 = AddOne(43.0) 102232d138db751c541e951d1958cac4918214e9644eDerek Murray 102332d138db751c541e951d1958cac4918214e9644eDerek Murray for config in _OptimizerOptions(): 102432d138db751c541e951d1958cac4918214e9644eDerek Murray config.device_count["CPU"] = 2 102532d138db751c541e951d1958cac4918214e9644eDerek Murray with session.Session(config=config) as sess: 102632d138db751c541e951d1958cac4918214e9644eDerek Murray self.assertEqual(42.0, sess.run(f_0)) 102732d138db751c541e951d1958cac4918214e9644eDerek Murray self.assertEqual(44.0, sess.run(f_1)) 102832d138db751c541e951d1958cac4918214e9644eDerek Murray self.assertEqual((42.0, 44.0), sess.run((f_0, f_1))) 102932d138db751c541e951d1958cac4918214e9644eDerek Murray 10303f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower 10312cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@test_util.with_c_api 103200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milneclass FunctionsFromProtos(test.TestCase): 103300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 103400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def expectFunctionsEqual(self, func, grad_func=None, new_func=None): 103500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne if new_func is None: 103600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne # Make a copy of func.definition to avoid any bugs masked by using the 103700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne # same object 103800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne serialized_fdef = func.definition.SerializeToString() 103900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne # Serialize and then deserialize `func` to create `new_func` 104000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne fdef = function_pb2.FunctionDef.FromString(serialized_fdef) 
104100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne new_func = function._from_definition(fdef, grad_func=grad_func) 104200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne self.assertEqual(func.name, new_func.name) 104300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne self.assertEqual(func.definition, new_func.definition) 104400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne self.assertEqual(func.grad_func_name, new_func.grad_func_name) 104500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne self.assertEqual(func.declared_input_types, new_func.declared_input_types) 104600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne self.assertEqual(func.captured_inputs, new_func.captured_inputs) 104700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 104800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def testBasic(self): 1049689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower 105000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne @function.Defun(dtypes.float32, dtypes.float32) 105100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def Foo(x, y): 105200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne return x + y 1053689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower 105400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne self.expectFunctionsEqual(Foo) 105500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 105600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def testGradFunc(self): 1057689cbda96444511bd37a01b125791c45a093bec3A. 
Unique TensorFlower 105800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne @function.Defun(dtypes.float32, dtypes.float32) 105900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def G(x, dy): 106000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne return x * dy 106100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 106200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne @function.Defun(dtypes.float32, grad_func=G) 106300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def F(x): 106400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne return math_ops.exp(x) - math_ops.exp(-x) 1065689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower 106600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne self.expectFunctionsEqual(F, grad_func=G) 106700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 106800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def testCapturedInputs(self): 106900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne c = constant_op.constant(10, dtypes.int64) 1070689cbda96444511bd37a01b125791c45a093bec3A. 
Unique TensorFlower 107100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne @function.Defun(dtypes.int64) 107200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def Foo(x): 107300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne return x + c 107400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 107500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne new_func = function._from_definition(Foo.definition) 107600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 107700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne self.assertEqual(Foo.name, new_func.name) 107800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne self.assertEqual(Foo.definition, new_func.definition) 107900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne self.assertEqual(Foo.grad_func_name, new_func.grad_func_name) 108000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 108100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne # Captured inputs are added as regular inputs to the function definition 108200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne self.assertEqual(new_func.declared_input_types, 108300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne Foo.declared_input_types + (dtypes.int64,)) 108400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne self.assertEqual(len(new_func.captured_inputs), 0) 108500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 108600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def testNestedFunctions(self): 1087689cbda96444511bd37a01b125791c45a093bec3A. 
Unique TensorFlower 108800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne @function.Defun(dtypes.float32) 108900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def Outer(x): 109000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 109100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne @function.Defun(dtypes.float32) 109200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def Inner(y): 109300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne return y + 1 109400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 109500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne return Inner(Inner(x)) 109600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 109700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne self.expectFunctionsEqual(Outer) 109800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 109900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def testFromLibrary(self): 110000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne # Define some functions with different gradient functions. Note that many of 110100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne # the below functions are identical since function bodies don't matter for 110200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne # this test. 
110300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 110400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne @function.Defun(dtypes.float32, dtypes.float32) 110500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def G1(x, dy): 110600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne return x * dy 110700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 110800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne @function.Defun(dtypes.float32, dtypes.float32) 110900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def G2(x, dy): 111000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne return x * dy 111100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 111200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne # F1 and F2 have the same gradient function 111300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne @function.Defun(dtypes.float32, grad_func=G1) 111400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def F1(x): 111500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne return math_ops.exp(x) - math_ops.exp(-x) 111600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 111700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne @function.Defun(dtypes.float32, grad_func=G1) 111800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def F2(x): 111900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne return math_ops.exp(x) - math_ops.exp(-x) 112000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 112100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne # F3 has a different gradient function 112200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne @function.Defun(dtypes.float32, grad_func=G2) 112300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def F3(x): 112400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne return math_ops.exp(x) - math_ops.exp(-x) 
112500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 112600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne # F4 has no gradient function 112700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne @function.Defun(dtypes.float32) 112800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def F4(x): 112900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne return math_ops.exp(x) - math_ops.exp(-x) 113000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 113100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne # Instantiate all functions 113200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne g = ops.Graph() 113300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne with g.as_default(): 113400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne c = constant_op.constant(1.0, dtypes.float32) 113500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne f1 = F1(c) 113600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne f2 = F2(c) 113700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne f3 = F3(c) 113800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne f4 = F4(c) 113900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne gradients_impl.gradients([f1, f2, f3, f4], c) 114000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 114100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne library = g.as_graph_def().library 114200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne new_funcs = function._from_library(library) 114300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 114400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def CheckNewFunc(func): 114500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne new_func = [f for f in new_funcs if f.name == func.name] 114600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne self.assertEqual(len(new_func), 1) 114700a294d68e0f36c44fefcf2e07bf40068250a884Skye 
Wanderman-Milne self.expectFunctionsEqual(func, new_func=new_func[0]) 114800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 114900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne CheckNewFunc(G1) 115000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne CheckNewFunc(G2) 115100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne CheckNewFunc(F1) 115200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne CheckNewFunc(F2) 115300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne CheckNewFunc(F3) 115400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne CheckNewFunc(F4) 115500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 115600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def testFromLibraryEmptyLib(self): 115700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne library = function_pb2.FunctionDefLibrary() 115800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne self.assertEqual(len(function._from_library(library)), 0) 115900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 116000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def testFromLibraryMissingFuncDef(self): 1161689cbda96444511bd37a01b125791c45a093bec3A. 
Unique TensorFlower 116200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne @function.Defun(dtypes.float32, dtypes.float32) 116300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def G1(x, dy): 116400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne return x * dy 116500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 116600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne @function.Defun(dtypes.float32) 116700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def F1(x): 116800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne return math_ops.exp(x) - math_ops.exp(-x) 116900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 117000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne gradient = function_pb2.GradientDef() 117100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne gradient.function_name = F1.name 117200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne gradient.gradient_func = G1.name 117300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 117400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne # Create invalid function def that is missing G1 function def 117500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne library = function_pb2.FunctionDefLibrary() 117600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne library.gradient.extend([gradient]) 117700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne library.function.extend([F1.definition]) 117800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 117900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne with self.assertRaisesRegexp( 11802198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev ValueError, 11812198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev "FunctionDefLibrary missing 'G1_[0-9a-zA-Z]{8,11}' FunctionDef"): 118200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne function._from_library(library) 
118300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 118400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne # Create invalid function def that is missing F1 function def 118500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne library = function_pb2.FunctionDefLibrary() 118600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne library.gradient.extend([gradient]) 118700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne library.function.extend([G1.definition]) 118800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 118900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne with self.assertRaisesRegexp( 11902198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev ValueError, 11912198b8cfe8acb5af7bb5a1dac54c18ff72c98002Igor Ganichev "FunctionDefLibrary missing 'F1_[0-9a-zA-Z]{8,11}' FunctionDef"): 119200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne function._from_library(library) 119300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 119400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def testFromLibraryCyclicGradFuncs(self): 1195689cbda96444511bd37a01b125791c45a093bec3A. 
Unique TensorFlower 119600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne @function.Defun(dtypes.float32) 119700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def F1(x): 119800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne return math_ops.exp(x) - math_ops.exp(-x) 119900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 120000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne @function.Defun(dtypes.float32) 120100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne def F2(x): 120200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne return math_ops.exp(x) - math_ops.exp(-x) 120300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 120400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne # Create invalid function def library where F1 has gradient function F2 and 120500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne # F2 has gradient function F1 120600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne library = function_pb2.FunctionDefLibrary() 120700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne library.function.extend([F1.definition, F2.definition]) 120800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 120900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne gradient1 = function_pb2.GradientDef() 121000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne gradient1.function_name = F1.name 121100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne gradient1.gradient_func = F2.name 121200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 121300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne gradient2 = function_pb2.GradientDef() 121400a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne gradient2.function_name = F2.name 121500a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne gradient2.gradient_func = F1.name 121600a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 
121700a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne library.gradient.extend([gradient1, gradient2]) 121800a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 121900a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne with self.assertRaisesRegexp( 122000a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne ValueError, "FunctionDefLibrary contains cyclic gradient functions!"): 122100a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne function._from_library(library) 122200a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 122300a294d68e0f36c44fefcf2e07bf40068250a884Skye Wanderman-Milne 12242cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@test_util.with_c_api 122558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyclass FunctionOverloadTest(test.TestCase): 12267a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower 12277a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower def testBasic(self): 12287a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower 12297a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower @function.Defun() 12307a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower def Sinh(x): 123158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney return 1 / 2. * (math_ops.exp(x) - math_ops.exp(-x)) 12327a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower 123358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney g = ops.Graph() 12347a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower with g.as_default(): 123558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney x = Sinh(constant_op.constant(0.25, dtypes.float32)) 123658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney y = Sinh(constant_op.constant(0.25, dtypes.float64)) 12377a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower 12387a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower with self.test_session(graph=g): 12397a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. 
Unique TensorFlower self.assertAllClose(x.eval(), np.sinh(0.25)) 12407a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower self.assertAllClose(y.eval(), np.sinh(0.25)) 12417a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower 12427a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower def testGradient(self): 12437a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower 12447a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower @function.Defun(func_name="Spec") 12457a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower def G(x, dy): 12467a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower return x * dy 12477a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower 12487a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower @function.Defun(grad_func=G) 12497a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower def F(x): 125058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney return math_ops.exp(x) - math_ops.exp(-x) 12517a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower 125258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney for dtype in [dtypes.float32, dtypes.float64]: 125358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney g = ops.Graph() 12547a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower with g.as_default(): 125558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney x = constant_op.constant(0.25, dtype) 12567a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower y = F(x) 125758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney dx, = gradients_impl.gradients(y, x) 12587a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower 12597a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower with self.test_session(graph=g): 12607a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower self.assertAllClose(dx.eval(), 0.25) 12617a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower 12627a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. 
Unique TensorFlower def testDocString(self): 12637a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower 12647a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower @function.Defun() 12657a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower def Foo(x): 12667a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower """Successor of x.""" 12677a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower return x + 1 12687a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower 126958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney g = ops.Graph() 12707a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower with g.as_default(): 12717a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower _ = Foo(1) 12727a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower 12737a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower self.assertEqual(g.as_graph_def().library.function[0].signature.description, 12747a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower "Successor of x.") 12757a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower 12767a0cde1dd7b9656da3be8aebc7bfe9ec65e0b6b0A. Unique TensorFlower 12772cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@test_util.with_c_api 1278b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlowerclass FunctionCaptureByValueTest(test.TestCase): 1279b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower 1280b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower def testCaptureByValue(self): 1281b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower g = ops.Graph() 1282b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower with g.as_default(): 1283b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower w = constant_op.constant([[1.0]]) 1284b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower b = constant_op.constant([2.0]) 1285b57bbdecb24afa18d1716403dd05a08566e3e516A. 
Unique TensorFlower 1286b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower # Foo() captures w and b. 1287b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower @function.Defun(dtypes.float32, capture_by_value=True) 1288b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower def Foo(x): 1289b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower 1290b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower # Plus() captures b. 1291b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower @function.Defun(dtypes.float32, capture_by_value=True) 1292b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower def Plus(y): 1293b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower return y + b 1294b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower 1295b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower self.assertEqual(0, len(Plus.captured_inputs)) 1296b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower 1297b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower return Plus(math_ops.matmul(w, x)) 1298b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower 1299b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower y = Foo(constant_op.constant([[10.]])) 1300b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower 1301b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower self.assertEqual(0, len(Foo.captured_inputs)) 1302b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower 1303b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower with self.test_session(graph=g): 1304b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower self.assertAllEqual(y.eval(), [[12.0]]) 1305b57bbdecb24afa18d1716403dd05a08566e3e516A. Unique TensorFlower 1306b57bbdecb24afa18d1716403dd05a08566e3e516A. 
Unique TensorFlower 13072cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@test_util.with_c_api 130858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyclass UnrollLSTMTest(test.TestCase): 13093f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower BATCH_SIZE = 16 13103f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower LSTM_DIMS = 32 1311e27da590fec8eed886ecd1cf3d0c2575dffeaa09A. Unique TensorFlower NUM_UNROLL = 20 13123f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower 13133f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower def _Weights(self): 13143f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower dims = self.LSTM_DIMS 131558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney return random_ops.random_uniform([2 * dims, 4 * dims], -1, 1, seed=123456) 13163f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower 13173f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower def _Input(self): 131858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney return random_ops.random_uniform( 1319ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower [self.NUM_UNROLL, self.BATCH_SIZE, self.LSTM_DIMS], seed=654321) 13203f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower 1321084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower # Helper to construct a LSTM cell graph. 1322084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower @classmethod 1323084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower def LSTMCell(cls, x, mprev, cprev, weights): 13240e226af7eed5e2764aa8acb825af4cd3e06d2452A. 
Unique TensorFlower xm = array_ops.concat([x, mprev], 1) 132558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney i_i, i_g, f_g, o_g = array_ops.split( 132658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney value=math_ops.matmul(xm, weights), num_or_size_splits=4, axis=1) 132758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney new_c = math_ops.sigmoid(f_g) * cprev + math_ops.sigmoid( 132858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney i_g) * math_ops.tanh(i_i) 132958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney new_c = clip_ops.clip_by_value(new_c, -50.0, 50.0) 133058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney new_m = math_ops.sigmoid(o_g) * math_ops.tanh(new_c) 1331084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower return new_m, new_c 1332084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower 13333f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower def _BuildForward(self, weights, inp, mode="cell"): 13343f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower 13353f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower def Loop(cell, w, i): 133658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney x = array_ops.unstack(i, self.NUM_UNROLL) 133758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney m = array_ops.zeros_like(x[0]) 133858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney c = array_ops.zeros_like(x[0]) 13393f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower for i in range(self.NUM_UNROLL): 13403f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower m, c = cell(x[i], m, c, w) 13413f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower return m 13423f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower 13433f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower cell = UnrollLSTMTest.LSTMCell 13443f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower if mode == "complete": 13453f9101e95e3d1359cb81863657b51c5909f89baaA. 
Unique TensorFlower # Constructs the complete graph in python. 13463f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower return Loop(cell, weights, inp) 13473f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower 134858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney cell = function.Defun(dtypes.float32, dtypes.float32, dtypes.float32, 134958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney dtypes.float32)(cell) 13503f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower if mode == "cell": 13513f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower # Just represent the LSTM as a function. 13523f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower return Loop(cell, weights, inp) 13533f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower 13543f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower if mode == "loop": 13553f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower # Wraps the whole loop as a function. 135658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney @function.Defun(dtypes.float32, dtypes.float32) 13573f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower def LSTMLoop(w, i): 13583f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower return Loop(cell, w, i) 13593f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower 13603f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower return LSTMLoop(weights, inp) 13613f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower 13623f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower if mode == "loop10": 13633f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower # Wraps 10 lstm steps into one function, and the whole loop 13643f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower # into another calling the formers. 13653f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower 136619b7fd80780d02372f76076bc8eb40d55a89a301A. Unique TensorFlower # Groups 10 steps at a time. 
136758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney @function.Defun(dtypes.float32, dtypes.float32, dtypes.float32, 136858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney *([dtypes.float32] * 10)) 136919b7fd80780d02372f76076bc8eb40d55a89a301A. Unique TensorFlower def Loop10(w, m, c, *args): 137019b7fd80780d02372f76076bc8eb40d55a89a301A. Unique TensorFlower for x in args: 13713f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower m, c = cell(x, m, c, w) 13723f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower return m, c 13733f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower 137458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney @function.Defun(dtypes.float32, dtypes.float32) 13753f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower def LSTMLoop10(weights, inp): 137658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney x = array_ops.unstack(inp, self.NUM_UNROLL) 137758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney m = array_ops.zeros_like(x[0]) 137858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney c = array_ops.zeros_like(x[0]) 13793f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower assert self.NUM_UNROLL % 10 == 0 13803f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower for i in range(0, self.NUM_UNROLL, 10): 13813f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower m, c = Loop10(weights, m, c, *x[i:i + 10]) 13823f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower return m 13833f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower 13843f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower return LSTMLoop10(weights, inp) 138566c21d60d4becef8c72162015c66492ba975495aA. Unique TensorFlower 1386084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower def testUnrollLSTM(self): 138766c21d60d4becef8c72162015c66492ba975495aA. Unique TensorFlower # Run one step of the unrolled lstm graph. 1388e830638148e203a2d9cf491e4693d35661a360d1A. 
Unique TensorFlower def RunForward(mode, cfg=None): 138958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney tf_logging.info("mode = %s", mode) 139058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney g = ops.Graph() 139166c21d60d4becef8c72162015c66492ba975495aA. Unique TensorFlower start = time.time() 139266c21d60d4becef8c72162015c66492ba975495aA. Unique TensorFlower with g.as_default(): 13933f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower weights = self._Weights() 13943f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower inp = self._Input() 13953f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower m = self._BuildForward(weights, inp, mode) 139666c21d60d4becef8c72162015c66492ba975495aA. Unique TensorFlower gdef = g.as_graph_def() 139766c21d60d4becef8c72162015c66492ba975495aA. Unique TensorFlower finish = time.time() 139858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney tf_logging.info("time: %f txt size: %d gdef bin size: %d", finish - start, 1399d680b19889fa7cb6c25a3a32006199e542b3d411A. Unique TensorFlower len(str(gdef)), len(gdef.SerializeToString())) 140058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney with g.as_default(), session.Session(config=cfg) as sess: 14013f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower return sess.run(m) 14023f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower 14033f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower mv0 = RunForward("complete") 1404f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower for cfg in _OptimizerOptions(): 140558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney tf_logging.info("cfg = %s", cfg) 1406e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower mv1 = RunForward("cell", cfg) 1407e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower mv2 = RunForward("loop", cfg) 1408e830638148e203a2d9cf491e4693d35661a360d1A. 
Unique TensorFlower mv3 = RunForward("loop10", cfg) 1409e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower self.assertAllClose(mv0, mv1, rtol=1e-4) 1410e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower self.assertAllClose(mv0, mv2, rtol=1e-4) 1411e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower self.assertAllClose(mv0, mv3, rtol=1e-4) 141266c21d60d4becef8c72162015c66492ba975495aA. Unique TensorFlower 1413084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower def testUnrollLSTMGrad(self): 1414084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower # Run one step of the unrolled lstm graph. 1415e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower def RunForwardBackward(mode, cfg=None): 141658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney tf_logging.info("mode = %s", mode) 141758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney g = ops.Graph() 1418084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower start = time.time() 14190c6c4b93dedf6a2c654e012a4ffe1df642834419A. Unique TensorFlower with g.as_default(): 14203f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower weights = self._Weights() 14213f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower inp = self._Input() 14223f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower m = self._BuildForward(weights, inp, mode) 142358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney loss = math_ops.reduce_sum(math_ops.square(m)) 142458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney dw = gradients_impl.gradients([loss], [weights]) 1425084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower gdef = g.as_graph_def() 1426084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower finish = time.time() 142758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney tf_logging.info("time: %f txt size: %d gdef bin size: %d", finish - start, 1428d680b19889fa7cb6c25a3a32006199e542b3d411A. 
Unique TensorFlower len(str(gdef)), len(gdef.SerializeToString())) 142958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney with g.as_default(), session.Session(config=cfg) as sess: 14303f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower return sess.run(dw) 14313f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower 14323f9101e95e3d1359cb81863657b51c5909f89baaA. Unique TensorFlower d0 = RunForwardBackward("complete") 1433f10637b372b3216400129b42abacd4fe50a6d7eaA. Unique TensorFlower for cfg in _OptimizerOptions(): 143458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney tf_logging.info("cfg = %s", cfg) 1435e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower d1 = RunForwardBackward("cell", cfg) 1436e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower d2 = RunForwardBackward("loop", cfg) 1437e830638148e203a2d9cf491e4693d35661a360d1A. Unique TensorFlower d3 = RunForwardBackward("loop10", cfg) 1438439d5bd72e879e40308f7197a33a1a20625a673bGunhan Gulsoy self.assertAllClose(d0, d1, rtol=1e-4, atol=1e-4) 1439439d5bd72e879e40308f7197a33a1a20625a673bGunhan Gulsoy self.assertAllClose(d0, d2, rtol=1e-4, atol=1e-4) 1440439d5bd72e879e40308f7197a33a1a20625a673bGunhan Gulsoy self.assertAllClose(d0, d3, rtol=1e-4, atol=1e-4) 1441084075e4e967677fc0be9fe921a0157126e69617A. Unique TensorFlower 1442d12b67833196d5508657d488cb8d96127419f2ebA. Unique TensorFlower 14432cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@test_util.with_c_api 144458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyclass FunctionInlineControlTest(test.TestCase): 14450cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower 14460cc25f0b95af9a99dd228c75d411670831fd6cecA. 
Unique TensorFlower def testFoo(self): 144758201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney dtype = dtypes.float32 144858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions( 144958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney optimizer_options=config_pb2.OptimizerOptions( 145058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney opt_level=config_pb2.OptimizerOptions.L0, 14510cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower do_common_subexpression_elimination=True, 14520cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower do_function_inlining=True, 14530cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower do_constant_folding=True))) 1454c0f6357c4a080edb10dd089151dd523834ea80fcA. Unique TensorFlower cell_func_call_pattern = re.compile(r"Cell[^/]*\(") 14550cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower for noinline in [False, True]: 1456ddb917930c7f2519166c879e8f5273afb9c5c59aA. Unique TensorFlower 1457adc15a46c56070ef92b890b65b6de0147abeff80A. Unique TensorFlower @function.Defun(dtype, noinline=noinline) 14581d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower def Cell(v): 14591d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower # If v is a vector [n, 1], x is a big square matrix. 146058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney x = math_ops.tanh(v + array_ops.transpose(v, [1, 0])) 1461b9f548d041ba8d66102c6d195e645051f1bee52fYukun Chen return math_ops.reduce_sum(x, 1, keepdims=True) 14620cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower 14631d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower @function.Defun(dtype) 14641d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower def Forward(x): 14651d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower for _ in range(10): 1466bb5764ca8da329b2c5c81f6c0219c260efe288e4A. 
Unique TensorFlower # pylint: disable=cell-var-from-loop 146775b07290907abf759f05d8ea087f4623de316db0A. Unique TensorFlower x = Cell(x) 146858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney return math_ops.reduce_sum(x, [0, 1]) 14690cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower 1470aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov self.assertEqual(noinline, Cell.definition.attr["_noinline"].b) 1471aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov 147258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney g = ops.Graph() 14731d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower with g.as_default(): 147458201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney x = array_ops.placeholder(dtype) 14750cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower y = Forward(x) 147658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney dx, = gradients_impl.gradients([y], [x]) 14770cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower 147822cfbd1cebda28c3b55c64519f4af23c462b0000Craig Citro np.random.seed(321) 147922cfbd1cebda28c3b55c64519f4af23c462b0000Craig Citro inp = np.random.uniform(-1, 1, [16, 1]).astype(np.float32) 148058201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney run_metadata = config_pb2.RunMetadata() 148158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney with session.Session(graph=g, config=cfg) as sess: 1482689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower ans = sess.run( 1483689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower [y, dx], {x: inp}, 1484689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower run_metadata=run_metadata, 1485689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower options=config_pb2.RunOptions( 1486689cbda96444511bd37a01b125791c45a093bec3A. 
Unique TensorFlower trace_level=config_pb2.RunOptions.FULL_TRACE)) 148722cfbd1cebda28c3b55c64519f4af23c462b0000Craig Citro print(ans[0], np.sum(ans[1])) 148822cfbd1cebda28c3b55c64519f4af23c462b0000Craig Citro self.assertAllClose(ans[0], 255.971, rtol=1e-3) 148922cfbd1cebda28c3b55c64519f4af23c462b0000Craig Citro self.assertAllClose(np.sum(ans[1]), 13.0408, rtol=1e-3) 14900cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower 1491aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov def MetadataHasCell(run_metadata): 1492aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov for dev_stats in run_metadata.step_stats.dev_stats: 1493aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov for node_stats in dev_stats.node_stats: 1494c0f6357c4a080edb10dd089151dd523834ea80fcA. Unique TensorFlower if cell_func_call_pattern.search(node_stats.timeline_label): 1495aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov return True 1496aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov return False 1497aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov 1498aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov self.assertEqual(MetadataHasCell(run_metadata), noinline) 1499aac685b7209b03ffd356ea6860366467b335d402Dan Smilkov 15000cc25f0b95af9a99dd228c75d411670831fd6cecA. Unique TensorFlower 150158201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney@function.Defun(*[dtypes.float32] * 3) 15021d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlowerdef Linear(w, b, x): 150358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney return nn_ops.relu(math_ops.matmul(x, w) + b) 15041d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower 15051d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower 150658201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney@function.Defun(*[dtypes.float32] * 5) 15071d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlowerdef Linear2(w1, b1, w2, b2, x): 15081d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. 
Unique TensorFlower return Linear(w2, b2, Linear(w1, b1, x)) 15091d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower 15101d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower 15112cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev# Set C API before defining module level functions 15122cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichevops._USE_C_API = True 15132cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev 15142cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev 15152cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@function.Defun(*[dtypes.float32] * 3) 15162cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichevdef LinearWithCApi(w, b, x): 15172cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev return nn_ops.relu(math_ops.matmul(x, w) + b) 15182cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev 15192cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev 15202cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev@function.Defun(*[dtypes.float32] * 5) 15212cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichevdef Linear2WithCApi(w1, b1, w2, b2, x): 15222cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev return LinearWithCApi(w2, b2, LinearWithCApi(w1, b1, x)) 15232cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev 15242cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev 15252cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev# Unset C API after defining module level functions 15262cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichevops._USE_C_API = False 15272cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev 15282cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev 152958201a058853de647b37ddb0ccf63d89b2357f03Justine Tunneyclass ModuleFunctionTest(test.TestCase): 15301d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower 15311d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. 
Unique TensorFlower def testBasic(self): 153258201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney with ops.Graph().as_default(): 153358201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney a, b, c, d, e = [ 1534689cbda96444511bd37a01b125791c45a093bec3A. Unique TensorFlower constant_op.constant([[_]], dtype=dtypes.float32) for _ in range(5) 153558201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney ] 15361d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower y = Linear(a, b, c) 15371d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower z = Linear2(a, b, c, d, e) 153858201a058853de647b37ddb0ccf63d89b2357f03Justine Tunney with session.Session() as sess: 15391d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower self.assertAllEqual([[1]], sess.run(y)) 15401d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower self.assertAllEqual([[5]], sess.run(z)) 15411d2aa9451d920ad4bc8ad1ac86c06863b590e81fA. Unique TensorFlower 15422cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev @test_util.enable_c_api 15432cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev def testBasicWithCApi(self): 15442cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev with ops.Graph().as_default(): 15452cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev a, b, c, d, e = [ 15462cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev constant_op.constant([[_]], dtype=dtypes.float32) for _ in range(5) 15472cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev ] 15482cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev y = LinearWithCApi(a, b, c) 15492cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev z = Linear2WithCApi(a, b, c, d, e) 15502cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev with session.Session() as sess: 15512cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev self.assertAllEqual([[1]], sess.run(y)) 15522cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev self.assertAllEqual([[5]], sess.run(z)) 15532cb94608bb95024ad0c3444fd1167f8765e03acbIgor Ganichev 
@test_util.with_c_api
class VariableHoistingTest(test.TestCase):
  """Tests Defun bodies that create variables via `get_variable`."""

  def _testSimpleModel(self, use_forward_func, use_resource=False):
    """Builds a sigmoid layer inside a Defun and checks loss and gradients.

    The variables `w` and `b` are created inside the function body. The test
    checks that they are reported by `function.get_extra_vars()` under the
    enclosing "Foo" variable scope, and that the loss and the gradients taken
    w.r.t. `function.get_extra_args()` match fixed reference values.

    Args:
      use_forward_func: If True, `Grad` calls the Defun-wrapped `Model`;
        otherwise it calls the plain Python function `_Model` directly.
      use_resource: Passed through as `use_resource` to both
        `variable_scope.get_variable` calls.
    """

    def _Model(x):
      w = variable_scope.get_variable(
          "w", (64, 64),
          initializer=init_ops.random_uniform_initializer(seed=312),
          use_resource=use_resource)
      # NOTE: this statement must not end with a trailing comma; that would
      # bind `b` to a 1-tuple holding the variable instead of the variable.
      b = variable_scope.get_variable(
          "b", (64,),
          initializer=init_ops.zeros_initializer(),
          use_resource=use_resource)
      return math_ops.sigmoid(math_ops.matmul(x, w) + b)

    @function.Defun()
    def Model(x):
      return _Model(x)

    # Filled in while `Grad` is being traced; read back after the call below.
    cvars = []

    @function.Defun()
    def Grad(x, y0):
      if use_forward_func:
        y = Model(x)
      else:
        y = _Model(x)
      loss = math_ops.reduce_mean(
          math_ops.reduce_sum(y0 * math_ops.log(y), 1), 0)
      arg_w, arg_b = function.get_extra_args()
      self.assertEqual(arg_w.get_shape(), tensor_shape.TensorShape([64, 64]))
      self.assertEqual(arg_b.get_shape(), tensor_shape.TensorShape([64]))
      dw, db = gradients_impl.gradients(loss, [arg_w, arg_b])
      cvars.extend(function.get_extra_vars())
      return loss, dw, db

    g = ops.Graph()
    with g.as_default():
      x = random_ops.random_normal([64, 64], seed=100)
      y0 = random_ops.random_normal([64, 64], seed=200)
      with variable_scope.variable_scope("Foo"):
        loss, dw, db = Grad(x, y0)

    self.assertEqual(2, len(cvars))
    w, b = cvars[:2]
    self.assertEqual("Foo/w", w.op.name)
    self.assertEqual("Foo/b", b.op.name)

    with self.test_session(graph=g) as sess:
      sess.run(variables.global_variables_initializer())
      # x and y0 stay in the fetch list so the run is unchanged; their values
      # are not inspected.
      w_val, b_val, _, _, loss_val, dw_val, db_val = sess.run(
          [w, b, x, y0, loss, dw, db])

    self.assertAllEqual(w_val.shape, (64, 64))
    self.assertAllClose(np.sum(w_val), 2050.44)
    self.assertAllEqual(b_val.shape, (64,))
    self.assertAllClose(np.sum(b_val), 0.0)
    self.assertAllClose(loss_val, -2.27, rtol=1e-2)
    self.assertAllEqual(dw_val.shape, (64, 64))
    self.assertAllClose(np.sum(dw_val), -1.04, rtol=1e-2)
    self.assertAllEqual(db_val.shape, (64,))
    self.assertAllClose(np.sum(db_val), 0.509, rtol=1e-2)

  def testBasic(self):
    self._testSimpleModel(True)
    self._testSimpleModel(False)

  def testBasicResource(self):
    self._testSimpleModel(True, use_resource=True)
    self._testSimpleModel(False, use_resource=True)


if __name__ == "__main__":
  test.main()