1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities that match patterns in a tf.Graph."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22
23
class Pattern(abc.ABCMeta('_PatternBase', (object,), {})):
  """The parent class of all patterns (e.g. OpTypePattern and OneofPattern).

  `@abc.abstractmethod` is only enforced when the class's metaclass is
  `abc.ABCMeta`. The base class is therefore created by calling `abc.ABCMeta`
  directly. This form works on both Python 2 and Python 3: `__metaclass__` is
  ignored by Python 3, and `metaclass=` is a syntax error in Python 2. As a
  result, `Pattern` itself, and any subclass that does not override `match`,
  cannot be instantiated.
  """

  @abc.abstractmethod
  def match(self, op, tensor):
    """Returns the result of matching op/tensor against this pattern.

    Args:
      op: The op to match against this pattern.
      tensor: The output of `op` that is being matched.

    Returns:
      A match result on success, or None if there is no match.

    Raises:
      NotImplementedError: If a subclass delegates here via `super()`.
    """
    raise NotImplementedError('Method "match" not implemented.')
31
32
class OpTypePattern(Pattern):
  """A tree pattern that matches TF expressions by the types of their ops."""

  def __init__(self, op_type, name=None, inputs=None):
    """Creates an OpTypePattern.

    Args:
      op_type: string describing which op types the root may have. It is one
        of:
        (1) a single op type, e.g. 'Conv2D',
        (2) the wildcard '*', which accepts any op type, or
        (3) several op types joined by '|', e.g. 'Relu|Relu6'.
        Regex strings could be supported instead, which may pay off once many
        similar TF op types need to be matched.
      name: Optional string under which this pattern can be looked up in a
        MatchResult.
      inputs: Optional list whose elements are `Pattern`s or strings. Strings
        are wrapped in a new `OpTypePattern`. The elements describe the inputs
        of a matching op, in order. If None (or empty), any inputs are
        accepted.
    """
    self._op_type = op_type
    self._name = name
    input_specs = [] if inputs is None else inputs
    self._inputs = []
    for spec in input_specs:
      # A bare string is shorthand for an unnamed OpTypePattern of that type.
      if not isinstance(spec, Pattern):
        spec = OpTypePattern(spec)
      self._inputs.append(spec)

  @property
  def name(self):
    """The optional lookup name given at construction (may be None)."""
    return self._name

  def match(self, op, tensor):
    """Matches `op` (and, recursively, its inputs) against this pattern.

    Args:
      op: The op to test. It must expose `type` and, if input patterns were
        given, `inputs`.
      tensor: The output of `op` being matched. It is recorded in the result
        alongside `op`.

    Returns:
      A `MatchResult` binding this pattern and every matched input pattern, or
      None if `op` does not match.
    """
    type_ok = self._op_type == '*' or op.type in self._op_type.split('|')
    if not type_ok:
      return None

    result = MatchResult()
    result.add(self, op, tensor)

    # With no input patterns, the op's inputs are not inspected at all.
    if not self._inputs:
      return result

    # Otherwise the op must have exactly as many inputs as there are patterns.
    if len(op.inputs) != len(self._inputs):
      return None

    for sub_pattern, sub_tensor in zip(self._inputs, op.inputs):
      sub_result = sub_pattern.match(sub_tensor.op, sub_tensor)
      if sub_result is None:
        return None
      result.merge_from(sub_result)
    return result
87
88
class OneofPattern(Pattern):
  """Matches if any one of several sub-patterns matches.

  Sub-patterns are tried in order, and the first successful match wins.
  """

  def __init__(self, sub_patterns):
    """Creates a OneofPattern.

    Args:
      sub_patterns: The candidate patterns, tried in iteration order on every
        call to `match`.
    """
    self._sub_patterns = sub_patterns

  def match(self, op, tensor):
    """Returns the first non-None result among the sub-patterns, else None."""
    # Lazily evaluated, so sub-patterns after the first match are never tried.
    attempts = (candidate.match(op, tensor) for candidate in self._sub_patterns)
    return next((outcome for outcome in attempts if outcome is not None), None)
101
102
class MatchResult(object):
  r"""Encapsulates the result of a match done by GraphMatcher.

  MatchResult maps each matched pattern to its matching op and tensor. It also
  maps pattern names to patterns, so results can be looked up by name. When
  the matching op has multiple output tensors, the matching tensor is the
  output tensor used by the matching op of the parent pattern. E.g., when we
  match graph

      -         +
     / \y0   y1/ \
    x    split    z
          |
          y         (nodes are ops; edges are going up)

  against add_pattern defined as

    y1_pattern = OpTypePattern('*')
    z_pattern = OpTypePattern('*')
    add_pattern = OpTypePattern('+', inputs=[y1_pattern, z_pattern])

  the matching op of `y1_pattern` is `split`, and the matching tensor of
  `y1_pattern` is `y1`, not `y0`.
  """

  def __init__(self):
    # Maps each matched pattern object to its (op, tensor) pair.
    self._pattern_to_op_tensor = {}
    # Maps a pattern's `name` to the pattern object, for lookup by string.
    self._name_to_pattern = {}

  def add(self, pattern, op, tensor):
    """Records that `pattern` matched `op`, with `tensor` as its output.

    Args:
      pattern: The matched pattern. If its `name` attribute is not None, the
        name is registered so the match can be looked up by that string.
      op: The op matched by `pattern`.
      tensor: The output of `op` consumed by the parent pattern's op.

    Raises:
      ValueError: If `pattern.name` is already registered in this result.
    """
    self._pattern_to_op_tensor[pattern] = op, tensor
    if pattern.name is not None:
      if pattern.name in self._name_to_pattern:
        raise ValueError(
            'Name %s is already bound to another pattern' % pattern.name)
      self._name_to_pattern[pattern.name] = pattern

  def _to_pattern(self, pattern_or_name):
    """Resolves `pattern_or_name` to a pattern object.

    Args:
      pattern_or_name: A `Pattern`, or the string name of a registered pattern.

    Returns:
      The pattern itself when given a `Pattern`. When given a string, the
      pattern registered under that name, or None if no pattern has it.

    Raises:
      ValueError: If `pattern_or_name` is neither a `Pattern` nor a `str`.
    """
    if isinstance(pattern_or_name, str):
      return self._name_to_pattern.get(pattern_or_name)

    # Accept any Pattern subclass, not only OpTypePattern. A pattern that was
    # never recorded then yields None from the getters instead of a misleading
    # type error.
    if isinstance(pattern_or_name, Pattern):
      return pattern_or_name

    raise ValueError('pattern_or_name has type %s. Expect Pattern or str.'
                     % type(pattern_or_name))

  def _get_op_tensor(self, pattern_or_name):
    """Returns the recorded (op, tensor) pair, or None if nothing matched."""
    pattern = self._to_pattern(pattern_or_name)
    if pattern is None:
      return None
    return self._pattern_to_op_tensor.get(pattern)

  def get_op(self, pattern_or_name):
    """Returns the op matched by the given pattern or name, or None."""
    op_tensor = self._get_op_tensor(pattern_or_name)
    return None if op_tensor is None else op_tensor[0]

  def get_tensor(self, pattern_or_name):
    """Returns the tensor matched by the given pattern or name, or None."""
    op_tensor = self._get_op_tensor(pattern_or_name)
    return None if op_tensor is None else op_tensor[1]

  def merge_from(self, other_match_result):
    """Copies all bindings from `other_match_result` into this result.

    Entries in `other_match_result` overwrite existing entries with the same
    pattern or name. No duplicate-name check is performed here.

    Args:
      other_match_result: Another `MatchResult`.
    """
    # pylint: disable=protected-access
    self._pattern_to_op_tensor.update(other_match_result._pattern_to_op_tensor)
    self._name_to_pattern.update(other_match_result._name_to_pattern)
    # pylint: enable=protected-access
175
176
class GraphMatcher(object):
  """Finds ops whose surrounding subgraph fits a given `Pattern`."""

  def __init__(self, pattern):
    """Creates a matcher for `pattern`.

    Args:
      pattern: The root `Pattern` that candidate subgraphs are compared with.
    """
    self._pattern = pattern

  def _match_pattern(self, pattern, op, tensor):
    """Tries to match `pattern` against the TF expression rooted at `op`.

    On success, the bindings produced by `pattern.match` are merged into
    `self._match_result`.

    Args:
      pattern: A `Pattern`.
      op: A `tf.Operation` to test against `pattern`.
      tensor: The output `tf.Tensor` of `op` consumed by the op matched by the
        parent of `pattern`. None when `pattern` is the root of the pattern
        tree.

    Returns:
      True if the expression rooted at `op` matches `pattern`, else False.
    """
    result = pattern.match(op, tensor)
    matched = result is not None
    if matched:
      self._match_result.merge_from(result)
    return matched

  def match_op(self, op):
    """Matches `op` against `self._pattern`.

    Args:
      op: `tf.Operation` to test against the pattern.

    Returns:
      A fresh `MatchResult` if `op` matches the pattern, otherwise None.
    """
    # Each call starts from an empty result so matches never leak across ops.
    self._match_result = MatchResult()
    if self._match_pattern(self._pattern, op, tensor=None):
      return self._match_result
    return None

  def match_ops(self, ops):
    """Matches every operation in `ops` against `self._pattern`.

    Args:
      ops: iterable of `tf.Operation`s to test.

    Yields:
      A `MatchResult` for each operation that matches, in iteration order.
    """
    for candidate in ops:
      result = self.match_op(candidate)
      if result is not None:
        yield result

  def match_graph(self, graph):
    """Matches every operation in `graph` against `self._pattern`.

    Args:
      graph: `tf.Graph` whose operations are tested.

    Yields:
      A `MatchResult` for each operation in `graph` that matches.
    """
    # Re-yield explicitly rather than `yield from`, which Python 2 lacks.
    operations = graph.get_operations()
    for result in self.match_ops(operations):
      yield result
252