1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Canonicalizing list comprehensions into for and if statements. 16 17e.g. 18result = [x * x for x in xs] 19 20becomes 21 22result = [] 23for x in xs: 24 elt = x * x 25 result.append(elt) 26""" 27 28from __future__ import absolute_import 29from __future__ import division 30from __future__ import print_function 31 32import gast 33 34from tensorflow.contrib.py2tf.pyct import parser 35from tensorflow.contrib.py2tf.pyct import templates 36from tensorflow.contrib.py2tf.pyct import transformer 37 38 39class ListCompCanonicalizationTransformer(transformer.Base): 40 """NodeTransformer to canonicalize list comprehensions.""" 41 42 def __init__(self, context): 43 super(ListCompCanonicalizationTransformer, self).__init__(context) 44 45 def make_update_list_node(self, list_, elt): 46 return templates.replace('list_.append(elt)', list_=list_, elt=elt)[0] 47 48 def instantiate_list_node(self): 49 return parser.parse_str('[]').body[0].value 50 51 def visit_Assign(self, node): 52 if not isinstance(node.value, gast.ListComp): 53 return node 54 if len(node.targets) > 1: 55 raise ValueError('Only support single assignment.') 56 return self.canonicalize_listcomp(node.targets[0], node.value) 57 58 def canonicalize_listcomp(self, result_node, list_comp_node): 59 60 make_list = templates.replace( 61 'list_ = create_list', 62 list_=result_node, 63 create_list=self.instantiate_list_node()) 64 loop_body = self.make_update_list_node(result_node, list_comp_node.elt) 65 66 for gen in reversed(list_comp_node.generators): 67 for gen_if in reversed(gen.ifs): 68 loop_body = templates.replace( 69 'if test: loop_body', test=gen_if, loop_body=loop_body) 70 loop_body = templates.replace( 71 'for target in iter_: loop_body', 72 iter_=gen.iter, 73 target=gen.target, 74 loop_body=loop_body) 75 76 return make_list + loop_body 77 78 79def transform(node, context): 80 return ListCompCanonicalizationTransformer(context).visit(node) 81