1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for print_selective_registration_header."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import sys
23
24from google.protobuf import text_format
25
26from tensorflow.core.framework import graph_pb2
27from tensorflow.python.platform import gfile
28from tensorflow.python.platform import test
29from tensorflow.python.tools import selective_registration_header_lib
30
# Text-format GraphDef used as test input: one Reshape node and two MatMul
# nodes (one DT_FLOAT, one DT_DOUBLE), each assigned to "/cpu:0".
#
# Note that this graph def is not valid to be loaded: every node lists the
# placeholder input "none", so its inputs are not wired up correctly.
GRAPH_DEF_TXT = """
  node: {
    name: "node_1"
    op: "Reshape"
    input: [ "none", "none" ]
    device: "/cpu:0"
    attr: { key: "T" value: { type: DT_FLOAT } }
  }
  node: {
    name: "node_2"
    op: "MatMul"
    input: [ "none", "none" ]
    device: "/cpu:0"
    attr: { key: "T" value: { type: DT_FLOAT } }
    attr: { key: "transpose_a" value: { b: false } }
    attr: { key: "transpose_b" value: { b: false } }
  }
  node: {
    name: "node_3"
    op: "MatMul"
    input: [ "none", "none" ]
    device: "/cpu:0"
    attr: { key: "T" value: { type: DT_DOUBLE } }
    attr: { key: "transpose_a" value: { b: false } }
    attr: { key: "transpose_b" value: { b: false } }
  }
"""
60
# Second text-format GraphDef used as test input: a single float BiasAdd node
# assigned to "/cpu:0". Like the graph above it uses the placeholder input
# "none", so it is only suitable for parsing, not for loading.
GRAPH_DEF_TXT_2 = """
  node: {
    name: "node_4"
    op: "BiasAdd"
    input: [ "none", "none" ]
    device: "/cpu:0"
    attr: { key: "T" value: { type: DT_FLOAT } }
  }

"""
71
72
class PrintOpFilegroupTest(test.TestCase):
  """Tests `selective_registration_header_lib` against small graph files."""

  def setUp(self):
    """Runs the base fixture and records the running script's base name."""
    # Run the base-class fixture first so the override does not silently skip
    # whatever per-test setup `test.TestCase` performs.
    super(PrintOpFilegroupTest, self).setUp()
    # The expected headers below interpolate this name into their first line.
    self.script_name = os.path.basename(sys.argv[0])

  def _ParseGraphs(self, *graph_def_texts):
    """Parses each text-format string into a fresh `GraphDef` message."""
    return [
        text_format.Parse(text, graph_pb2.GraphDef())
        for text in graph_def_texts
    ]

  def WriteGraphFiles(self, graphs):
    """Serializes each graph to its own file in the test's temp directory.

    Args:
      graphs: Iterable of messages exposing `SerializeToString()`; the tests
        in this class pass `GraphDef` protos.

    Returns:
      List of the written file paths, in the same order as `graphs`. Files are
      named `graph<index>.pb`, so repeated calls reuse the same file names.
    """
    fnames = []
    for i, graph in enumerate(graphs):
      fname = os.path.join(self.get_temp_dir(), 'graph%s.pb' % i)
      with gfile.GFile(fname, 'wb') as f:
        f.write(graph.SerializeToString())
      fnames.append(fname)
    return fnames

  def testGetOps(self):
    """Checks the (op, kernel class) pairs extracted from two graph files."""
    default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'
    graphs = self._ParseGraphs(GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)

    # Ops from both graphs plus the three default ops. Both calls below are
    # expected to produce exactly this list, in this order.
    expected = [
        ('BiasAdd', 'BiasOp<CPUDevice, float>'),
        ('MatMul', 'MatMulOp<CPUDevice, double, false >'),
        ('MatMul', 'MatMulOp<CPUDevice, float, false >'),
        ('NoOp', 'NoOp'),
        ('Reshape', 'ReshapeOp'),
        ('_Recv', 'RecvOp'),
        ('_Send', 'SendOp'),
    ]

    ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
        'rawproto', self.WriteGraphFiles(graphs), default_ops)
    self.assertListEqual(expected, ops_and_kernels)

    # Dropping the explicit device from two nodes must not change the result.
    graphs[0].node[0].ClearField('device')
    graphs[0].node[2].ClearField('device')
    ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
        'rawproto', self.WriteGraphFiles(graphs), default_ops)
    self.assertListEqual(expected, ops_and_kernels)

  def testAll(self):
    """Checks the header produced when every op and kernel is registered."""
    default_ops = 'all'
    graphs = self._ParseGraphs(GRAPH_DEF_TXT, GRAPH_DEF_TXT_2)
    ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
        'rawproto', self.WriteGraphFiles(graphs), default_ops)

    header = selective_registration_header_lib.get_header_from_ops_and_kernels(
        ops_and_kernels, include_all_ops_and_kernels=True)
    self.assertListEqual(
        [
            '// This file was autogenerated by %s' % self.script_name,
            '#ifndef OPS_TO_REGISTER',
            '#define OPS_TO_REGISTER',
            '#define SHOULD_REGISTER_OP(op) true',
            '#define SHOULD_REGISTER_OP_KERNEL(clz) true',
            '#define SHOULD_REGISTER_OP_GRADIENT true',
            '#endif',
        ],
        header.split('\n'))

    # The one-shot `get_header` entry point must yield the same header as the
    # two-step path exercised above.
    self.assertListEqual(
        header.split('\n'),
        selective_registration_header_lib.get_header(
            self.WriteGraphFiles(graphs), 'rawproto', default_ops).split('\n'))

  def testGetSelectiveHeader(self):
    """Checks the full selective header for a graph holding only BiasAdd."""
    default_ops = ''
    graphs = self._ParseGraphs(GRAPH_DEF_TXT_2)

    # The expected text is compared line by line, so the irregular indentation
    # inside the literal is significant and must be kept exactly as is.
    expected = '''// This file was autogenerated by %s
#ifndef OPS_TO_REGISTER
#define OPS_TO_REGISTER

    namespace {
      constexpr const char* skip(const char* x) {
        return (*x) ? (*x == ' ' ? skip(x + 1) : x) : x;
      }

      constexpr bool isequal(const char* x, const char* y) {
        return (*skip(x) && *skip(y))
                   ? (*skip(x) == *skip(y) && isequal(skip(x) + 1, skip(y) + 1))
                   : (!*skip(x) && !*skip(y));
      }

      template<int N>
      struct find_in {
        static constexpr bool f(const char* x, const char* const y[N]) {
          return isequal(x, y[0]) || find_in<N - 1>::f(x, y + 1);
        }
      };

      template<>
      struct find_in<0> {
        static constexpr bool f(const char* x, const char* const y[]) {
          return false;
        }
      };
    }  // end namespace
    constexpr const char* kNecessaryOpKernelClasses[] = {
"BiasOp<CPUDevice, float>",
};
#define SHOULD_REGISTER_OP_KERNEL(clz) (find_in<sizeof(kNecessaryOpKernelClasses) / sizeof(*kNecessaryOpKernelClasses)>::f(clz, kNecessaryOpKernelClasses))

constexpr inline bool ShouldRegisterOp(const char op[]) {
  return false
     || isequal(op, "BiasAdd")
  ;
}
#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)

#define SHOULD_REGISTER_OP_GRADIENT false
#endif''' % self.script_name

    header = selective_registration_header_lib.get_header(
        self.WriteGraphFiles(graphs), 'rawproto', default_ops)
    self.assertListEqual(expected.split('\n'), header.split('\n'))
204
205
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when executed as a script.
  test.main()
208