1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities that match patterns in a tf.Graph.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22 23 24class Pattern(object): 25 """The parent class of all patterns (e.g. OpTypePattern and OneofPattern).""" 26 27 @abc.abstractmethod 28 def match(self, op, tensor): 29 """Returns the result of matching op/tensor against this pattern.""" 30 raise NotImplementedError('Method "match" not implemented.') 31 32 33class OpTypePattern(Pattern): 34 """A tree pattern that matches TF expressions with certain op types.""" 35 36 def __init__(self, op_type, name=None, inputs=None): 37 """Initializes an OpTypePattern. 38 39 Args: 40 op_type: string that specifies the allowed types of the root. It can be 41 (1) an op type, e.g. 'Conv2D', 42 (2) '*', i.e. wildcard, or 43 (3) multiple op types separated by '|', e.g., 'Relu|Relu6'. 44 We could use regex strings, which might be worthwhile when we have many 45 similar TF op types. 46 name: Optional string. The name of the pattern that can be looked up in 47 MatchResult. 48 inputs: Optional list of `Pattern`s or strings that specify the 49 patterns for the inputs of a matching op. If None, this pattern accepts 50 any inputs of a matching op. 51 """ 52 self._op_type = op_type 53 self._name = name 54 if inputs is None: 55 inputs = [] 56 self._inputs = [ 57 input_pattern 58 if isinstance(input_pattern, Pattern) else OpTypePattern(input_pattern) 59 for input_pattern in inputs 60 ] 61 62 @property 63 def name(self): 64 return self._name 65 66 def match(self, op, tensor): 67 if self._op_type != '*': 68 if op.type not in self._op_type.split('|'): 69 return None 70 71 match_result = MatchResult() 72 match_result.add(self, op, tensor) 73 74 if not self._inputs: 75 # If pattern.inputs is empty, skips the rest and accepts all the inputs. 76 return match_result 77 78 if len(op.inputs) != len(self._inputs): 79 return None 80 81 for input_tensor, input_pattern in zip(op.inputs, self._inputs): 82 input_match_result = input_pattern.match(input_tensor.op, input_tensor) 83 if input_match_result is None: 84 return None 85 match_result.merge_from(input_match_result) 86 return match_result 87 88 89class OneofPattern(Pattern): 90 """Matches one of the given sub-patterns.""" 91 92 def __init__(self, sub_patterns): 93 self._sub_patterns = sub_patterns 94 95 def match(self, op, tensor): 96 for sub_pattern in self._sub_patterns: 97 match_result = sub_pattern.match(op, tensor) 98 if match_result is not None: 99 return match_result 100 return None 101 102 103class MatchResult(object): 104 r"""Encapsulates the result of a match done by GraphMatcher. 105 106 MatchResult contains a map from OpTypePattern to the matching op and tensor. 107 When the matching op has multiple output tensors, the matching tensor is the 108 output tensor used by the matching op of the parent pattern. E.g., when we 109 match graph 110 111 - + 112 / \y0 y1/ \ 113 x split z 114 | 115 y (nodes are ops; edges are going up) 116 117 against add_pattern defined as 118 119 y1_pattern = OpTypePattern('*') 120 z_pattern = OpTypePattern('*') 121 add_pattern = OpTypePattern('+', inputs=[y1_pattern, z_pattern]) 122 123 the matching op of `y1_pattern` is `split`, and the matching tensor of 124 `y1_pattern` 125 is `y1` not `y0`. 126 """ 127 128 def __init__(self): 129 self._pattern_to_op_tensor = {} 130 self._name_to_pattern = {} 131 132 def add(self, pattern, op, tensor): 133 self._pattern_to_op_tensor[pattern] = op, tensor 134 if pattern.name is not None: 135 if pattern.name in self._name_to_pattern: 136 raise ValueError( 137 'Name %s is already bound to another pattern' % pattern.name) 138 self._name_to_pattern[pattern.name] = pattern 139 140 def _to_pattern(self, pattern_or_name): 141 if isinstance(pattern_or_name, OpTypePattern): 142 return pattern_or_name 143 144 if isinstance(pattern_or_name, str): 145 if pattern_or_name not in self._name_to_pattern: 146 return None 147 return self._name_to_pattern[pattern_or_name] 148 149 raise ValueError('pattern_or_name has type %s. Expect OpTypePattern or str.' 150 % type(pattern_or_name)) 151 152 def _get_op_tensor(self, pattern_or_name): 153 pattern = self._to_pattern(pattern_or_name) 154 if pattern is None: 155 return None 156 157 if pattern not in self._pattern_to_op_tensor: 158 return None 159 160 return self._pattern_to_op_tensor[pattern] 161 162 def get_op(self, pattern_or_name): 163 op_tensor = self._get_op_tensor(pattern_or_name) 164 return op_tensor[0] if op_tensor else None 165 166 def get_tensor(self, pattern_or_name): 167 op_tensor = self._get_op_tensor(pattern_or_name) 168 return op_tensor[1] if op_tensor else None 169 170 def merge_from(self, other_match_result): 171 # pylint: disable=protected-access 172 self._pattern_to_op_tensor.update(other_match_result._pattern_to_op_tensor) 173 self._name_to_pattern.update(other_match_result._name_to_pattern) 174 # pylint: enable=protected-access 175 176 177class GraphMatcher(object): 178 """Checks if a particular subgraph matches a given pattern.""" 179 180 def __init__(self, pattern): 181 """Initializes a GraphMatcher. 182 183 Args: 184 pattern: The `Pattern` against which `GraphMatcher` matches 185 subgraphs. 186 """ 187 self._pattern = pattern 188 189 def _match_pattern(self, pattern, op, tensor): 190 """Returns whether an TF expression rooted at `op` matches `pattern`. 191 192 If there is a match, adds to `self._match_result` the matching op and tensor 193 with key `pattern`. 194 195 Args: 196 pattern: An `Pattern`. 197 op: A `tf.Operation` to match against the pattern. 198 tensor: the output `tf.Tensor` of `op` that is used by the matching op of 199 `pattern`'s parent. Can be None if `pattern` is already the root of the 200 pattern tree. 201 202 Returns: 203 True if an TF expression rooted at `op` matches `pattern`. 204 """ 205 match_result = pattern.match(op, tensor) 206 if match_result is None: 207 return False 208 self._match_result.merge_from(match_result) 209 return True 210 211 def match_op(self, op): 212 """Matches `op` against `self._pattern`. 213 214 Args: 215 op: `tf.Operation` to match against the pattern. 216 217 Returns: 218 Returns a `MatchResult` if `op` matches the pattern; otherwise, returns 219 None. 220 """ 221 self._match_result = MatchResult() 222 if not self._match_pattern(self._pattern, op, tensor=None): 223 return None 224 return self._match_result 225 226 def match_ops(self, ops): 227 """Matches each operation in `ops` against `self._pattern`. 228 229 Args: 230 ops: collection of `tf.Operation` to match against the pattern. 231 232 Yields: 233 `MatchResult` for each `tf.Operation` that matches the pattern. 234 """ 235 for op in ops: 236 match_result = self.match_op(op) 237 if match_result: 238 yield match_result 239 240 def match_graph(self, graph): 241 """Matches each operation in `graph` against `self._pattern`. 242 243 Args: 244 graph: `tf.Graph` containing operations to match. 245 246 Yields: 247 `MatchResult` for each `tf.Operation` in `graph` that matches the pattern. 248 """ 249 # Python 3.3.2+ implements `yield from`, but for now: 250 for match_result in self.match_ops(graph.get_operations()): 251 yield match_result 252