1/*
2 * Copyright 2017, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "pass_queue.h"
18
19#include "file_utils.h"
20#include "spirit.h"
21#include "test_utils.h"
22#include "transformer.h"
23#include "gtest/gtest.h"
24
25#include <stdint.h>
26
27namespace android {
28namespace spirit {
29
30namespace {
31
32class MulToAddTransformer : public Transformer {
33public:
34  Instruction *transform(IMulInst *mul) override {
35    auto ret = new IAddInst(mul->mResultType, mul->mOperand1, mul->mOperand2);
36    ret->setId(mul->getId());
37    return ret;
38  }
39};
40
41class AddToDivTransformer : public Transformer {
42public:
43  Instruction *transform(IAddInst *add) override {
44    auto ret = new SDivInst(add->mResultType, add->mOperand1, add->mOperand2);
45    ret->setId(add->getId());
46    return ret;
47  }
48};
49
50class AddMulAfterAddTransformer : public Transformer {
51public:
52  Instruction *transform(IAddInst *add) override {
53    insert(add);
54    auto ret = new IMulInst(add->mResultType, add, add);
55    ret->setId(add->getId());
56    return ret;
57  }
58};
59
60class Deleter : public Transformer {
61public:
62  Instruction *transform(IMulInst *) override { return nullptr; }
63};
64
65class InPlaceModifyingPass : public Pass {
66public:
67  Module *run(Module *m, int *error) override {
68    m->getFloatType(64);
69    if (error) {
70      *error = 0;
71    }
72    return m;
73  }
74};
75
} // anonymous namespace
77
78class PassQueueTest : public ::testing::Test {
79protected:
80  virtual void SetUp() { mWordsGreyscale = readWords("greyscale.spv"); }
81
82  std::vector<uint32_t> mWordsGreyscale;
83
84private:
85  std::vector<uint32_t> readWords(const char *testFile) {
86    static const std::string testDataPath(
87        "frameworks/rs/rsov/compiler/spirit/test_data/");
88    const std::string &fullPath = getAbsolutePath(testDataPath + testFile);
89    return readFile<uint32_t>(fullPath);
90  }
91};
92
93TEST_F(PassQueueTest, testMulToAdd) {
94  std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));
95
96  ASSERT_NE(nullptr, m);
97
98  EXPECT_EQ(1, countEntity<IAddInst>(m.get()));
99  EXPECT_EQ(1, countEntity<IMulInst>(m.get()));
100
101  PassQueue passes;
102  passes.append(new MulToAddTransformer());
103  auto m1 = passes.run(m.get());
104
105  ASSERT_NE(nullptr, m1);
106
107  ASSERT_TRUE(m1->resolveIds());
108
109  EXPECT_EQ(2, countEntity<IAddInst>(m1));
110  EXPECT_EQ(0, countEntity<IMulInst>(m1));
111}
112
113TEST_F(PassQueueTest, testInPlaceModifying) {
114  std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));
115
116  ASSERT_NE(nullptr, m);
117
118  EXPECT_EQ(1, countEntity<IAddInst>(m.get()));
119  EXPECT_EQ(1, countEntity<IMulInst>(m.get()));
120  EXPECT_EQ(1, countEntity<TypeFloatInst>(m.get()));
121
122  PassQueue passes;
123  passes.append(new InPlaceModifyingPass());
124  auto m1 = passes.run(m.get());
125
126  ASSERT_NE(nullptr, m1);
127
128  ASSERT_TRUE(m1->resolveIds());
129
130  EXPECT_EQ(1, countEntity<IAddInst>(m1));
131  EXPECT_EQ(1, countEntity<IMulInst>(m1));
132  EXPECT_EQ(2, countEntity<TypeFloatInst>(m1));
133}
134
135TEST_F(PassQueueTest, testDeletion) {
136  std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));
137
138  ASSERT_NE(nullptr, m.get());
139
140  EXPECT_EQ(1, countEntity<IMulInst>(m.get()));
141
142  PassQueue passes;
143  passes.append(new Deleter());
144  auto m1 = passes.run(m.get());
145
146  // One of the ids from the input module is missing now.
147  ASSERT_EQ(nullptr, m1);
148}
149
150TEST_F(PassQueueTest, testMulToAddToDiv) {
151  std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));
152
153  ASSERT_NE(nullptr, m);
154
155  EXPECT_EQ(1, countEntity<IAddInst>(m.get()));
156  EXPECT_EQ(1, countEntity<IMulInst>(m.get()));
157
158  PassQueue passes;
159  passes.append(new MulToAddTransformer());
160  passes.append(new AddToDivTransformer());
161  auto m1 = passes.run(m.get());
162
163  ASSERT_NE(nullptr, m1);
164
165  ASSERT_TRUE(m1->resolveIds());
166
167  EXPECT_EQ(0, countEntity<IAddInst>(m1));
168  EXPECT_EQ(0, countEntity<IMulInst>(m1));
169  EXPECT_EQ(2, countEntity<SDivInst>(m1));
170}
171
172TEST_F(PassQueueTest, testAMix) {
173  std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));
174
175  ASSERT_NE(nullptr, m);
176
177  EXPECT_EQ(1, countEntity<IAddInst>(m.get()));
178  EXPECT_EQ(1, countEntity<IMulInst>(m.get()));
179  EXPECT_EQ(0, countEntity<SDivInst>(m.get()));
180  EXPECT_EQ(1, countEntity<TypeFloatInst>(m.get()));
181
182  PassQueue passes;
183  passes.append(new MulToAddTransformer());
184  passes.append(new AddToDivTransformer());
185  passes.append(new InPlaceModifyingPass());
186
187  std::unique_ptr<Module> m1(passes.run(m.get()));
188
189  ASSERT_NE(nullptr, m1);
190
191  ASSERT_TRUE(m1->resolveIds());
192
193  EXPECT_EQ(0, countEntity<IAddInst>(m1.get()));
194  EXPECT_EQ(0, countEntity<IMulInst>(m1.get()));
195  EXPECT_EQ(2, countEntity<SDivInst>(m1.get()));
196  EXPECT_EQ(2, countEntity<TypeFloatInst>(m1.get()));
197}
198
199TEST_F(PassQueueTest, testAnotherMix) {
200  std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));
201
202  ASSERT_NE(nullptr, m);
203
204  EXPECT_EQ(1, countEntity<IAddInst>(m.get()));
205  EXPECT_EQ(1, countEntity<IMulInst>(m.get()));
206  EXPECT_EQ(0, countEntity<SDivInst>(m.get()));
207  EXPECT_EQ(1, countEntity<TypeFloatInst>(m.get()));
208
209  PassQueue passes;
210  passes.append(new InPlaceModifyingPass());
211  passes.append(new MulToAddTransformer());
212  passes.append(new AddToDivTransformer());
213  auto outputWords = passes.runAndSerialize(m.get());
214
215  std::unique_ptr<Module> m1(Deserialize<Module>(outputWords));
216
217  ASSERT_NE(nullptr, m1);
218
219  ASSERT_TRUE(m1->resolveIds());
220
221  EXPECT_EQ(0, countEntity<IAddInst>(m1.get()));
222  EXPECT_EQ(0, countEntity<IMulInst>(m1.get()));
223  EXPECT_EQ(2, countEntity<SDivInst>(m1.get()));
224  EXPECT_EQ(2, countEntity<TypeFloatInst>(m1.get()));
225}
226
227TEST_F(PassQueueTest, testMulToAddToDivFromWords) {
228  PassQueue passes;
229  passes.append(new MulToAddTransformer());
230  passes.append(new AddToDivTransformer());
231  auto outputWords = passes.run(std::move(mWordsGreyscale));
232
233  std::unique_ptr<Module> m1(Deserialize<Module>(outputWords));
234
235  ASSERT_NE(nullptr, m1);
236
237  ASSERT_TRUE(m1->resolveIds());
238
239  EXPECT_EQ(0, countEntity<IAddInst>(m1.get()));
240  EXPECT_EQ(0, countEntity<IMulInst>(m1.get()));
241  EXPECT_EQ(2, countEntity<SDivInst>(m1.get()));
242}
243
244TEST_F(PassQueueTest, testMulToAddToDivToWords) {
245  std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));
246
247  ASSERT_NE(nullptr, m);
248
249  EXPECT_EQ(1, countEntity<IAddInst>(m.get()));
250  EXPECT_EQ(1, countEntity<IMulInst>(m.get()));
251
252  PassQueue passes;
253  passes.append(new MulToAddTransformer());
254  passes.append(new AddToDivTransformer());
255  auto outputWords = passes.runAndSerialize(m.get());
256
257  std::unique_ptr<Module> m1(Deserialize<Module>(outputWords));
258
259  ASSERT_NE(nullptr, m1);
260
261  ASSERT_TRUE(m1->resolveIds());
262
263  EXPECT_EQ(0, countEntity<IAddInst>(m1.get()));
264  EXPECT_EQ(0, countEntity<IMulInst>(m1.get()));
265  EXPECT_EQ(2, countEntity<SDivInst>(m1.get()));
266}
267
268TEST_F(PassQueueTest, testAddMulAfterAdd) {
269  std::unique_ptr<Module> m(Deserialize<Module>(mWordsGreyscale));
270
271  ASSERT_NE(nullptr, m);
272
273  EXPECT_EQ(1, countEntity<IAddInst>(m.get()));
274  EXPECT_EQ(1, countEntity<IMulInst>(m.get()));
275
276  constexpr int kNumMulToAdd = 100;
277
278  PassQueue passes;
279  for (int i = 0; i < kNumMulToAdd; i++) {
280    passes.append(new AddMulAfterAddTransformer());
281  }
282  auto outputWords = passes.runAndSerialize(m.get());
283
284  std::unique_ptr<Module> m1(Deserialize<Module>(outputWords));
285
286  ASSERT_NE(nullptr, m1);
287
288  ASSERT_TRUE(m1->resolveIds());
289
290  EXPECT_EQ(1, countEntity<IAddInst>(m1.get()));
291  EXPECT_EQ(1 + kNumMulToAdd, countEntity<IMulInst>(m1.get()));
292}
293
294} // namespace spirit
295} // namespace android
296