1/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.example.android.rs.blasbenchmark;
18
19import android.renderscript.*;
20import android.util.Log;
21import java.util.Random;
22import java.lang.Math;
23
24public class BNNMTest extends TestBase {
25
26    static {
27        System.loadLibrary("gemmdata");
28    }
29
30    native void getData(byte[] a, byte[] b, byte[] c);
31
32    ScriptIntrinsicBLAS mBLAS;
33    private Allocation matA;
34    private Allocation matB;
35    private Allocation matC;
36
37    private int m;
38    private int n;
39    private int k;
40
41    private int a_offset;
42    private int b_offset;
43    private int c_offset;
44    private int c_mult_int;
45
46    private int mTestSize;
47
48    BNNMTest(int testSize) {
49        mTestSize = testSize;
50    }
51
52    public void createTest() {
53        mBLAS = ScriptIntrinsicBLAS.create(mRS);
54        setTest();
55    }
56
57    private void setTest() {
58        switch (mTestSize) {
59            case 1:
60                setTestSmall();
61                break;
62            case 2:
63                setTestMedium();
64                break;
65            case 3:
66                setTestLarge();
67                break;
68            default:
69                break;
70        }
71    }
72
73    // In Java, the eight-bit 'byte' type is signed, but the API for the 8-bit
74    // matrix multiplication deals with unsigned bytes. This is a convenience
75    // function that converts arrays of unsigned ints to their equivalent
76    // representations as signed bytes. For example, the bit pattern 0xff is 255
77    // as an unsigned value, but -127 as a Java signed byte. So if you pass in an
78    // array of int[] {255} into this function, you'll get back byte[] {-127}.
79    private byte[] unsignedToSignedByte(int[] input) {
80        byte[] output = new byte[input.length];
81        for (int i = 0; i < input.length; ++i) {
82            output[i] = (byte)(input[i]);
83        }
84        return output;
85    }
86
87
88    private void addByteNoise(byte[] data, int count, float frequency, int maxDelta) {
89        Random rand = new Random();
90        for (int n = 0; n < count; ++n) {
91            if (rand.nextFloat() < frequency) {
92                final int originalValue = data[n];
93                final float direction = rand.nextFloat();
94                int delta = (int)(Math.ceil(rand.nextFloat() * maxDelta));
95                if (direction < 0.5f) {
96                    delta = -delta;
97                }
98                int newValue = (originalValue + delta);
99                if (newValue < -127) {
100                    newValue = -127;
101                }
102                if (newValue > 127) {
103                    newValue = 127;
104                }
105                data[n] = (byte)(newValue);
106            }
107        }
108    }
109
110    private boolean testWithTolerance(byte[] c_byte, byte[] c_byte_output) {
111
112        // The testing procedure here is a bit complex, but the aim is to mimic the
113        // requirements we've empirically found running deep neural networks in real
114        // applications. We want to open the door to vendors using approximations that
115        // produce slightly different results for optimization's sake, but keep the
116        // precision loss within small enough bounds that we don't lose accuracy in
117        // the final result.
118        // After experimentation, we've found that we can tolerate around 5% of the
119        // output bytes being different by 1. Any larger differences are not tolerable
120        // and we can't get good results if the frequency of small differences is
121        // higher than 5%. This test tries to measure those properties on an example
122        // set of parameters that were captured from a real application.
123        // For example, if you uncommented this function that adds random noise to the
124        // results at a 3% specified frequency, the test should fail:
125        // AddByteNoise(c_byte_output, c_count, 0.03f, 1);
126
127        final boolean areSizesDifferent = (c_byte.length != c_byte_output.length);
128        final int c_count = Math.min(c_byte.length, c_byte_output.length);
129
130        int howManyDifferent = 0;
131        boolean areAnyTooDifferent = false;
132        for (int i = 0; i < c_count; i++) {
133            byte expectedValue = c_byte[i];
134            byte actualValue = c_byte_output[i];
135            int delta = (expectedValue - actualValue);
136            // First make sure that the difference is no more than one.
137            if ((delta < -1) || (delta > 1)) {
138                areAnyTooDifferent = true;
139            }
140            // If there is a difference, increment the counter to track it.
141            if (delta != 0) {
142                // Don't spam the logs if too many are different.
143                if (howManyDifferent < 50) {
144                    android.util.Log.e("BNNM", "Mismatch at " + i +
145                                       ": expected " + (expectedValue & 0xff) +
146                                       ", got " + (actualValue & 0xff));
147                }
148                ++howManyDifferent;
149            }
150        }
151        // We want no more than 2% of the values to show any differences, so work out
152        // what that means in absolute numbers.
153        final int percentThreshold = 2;
154        final int differenceThreshold = Math.max((percentThreshold * c_count) / 100, 1);
155        final boolean areTooManyDifferent = (howManyDifferent >= differenceThreshold);
156
157        if (areAnyTooDifferent) {
158            android.util.Log.e("BNNM", "Some outputs were too different.");
159        }
160
161        if (areTooManyDifferent) {
162            android.util.Log.e("BNNM", "There were too many small differences." +
163                               " We can tolerate " + percentThreshold + "% (" +
164                               differenceThreshold + "), but there were " + howManyDifferent);
165        }
166
167        return !(areAnyTooDifferent || areTooManyDifferent);
168    }
169
170    // This test multiplies a couple of small 8-bit matrices, and compares the
171    // results with hand-calculated expectations.
172    public void setTestSmall() {
173        // The A matrix is:
174        // |   1 |   4 |
175        // |   2 |   5 |
176        // |   3 |   6 |
177        byte[] a_byte = unsignedToSignedByte(new int[] {
178                1, 2, 3,
179                4, 5, 6,
180            });
181        final int a_rows = 3;
182        final int a_cols = 2;
183        a_offset = 0;
184        // The B matrix is:
185        // |  -1 |  -2 |  -3 |  -4 |
186        // |  -5 |  -6 |  -7 |  -8 |
187        // |  -9 | -10 | -11 | -12 |
188        byte[] b_byte = unsignedToSignedByte(new int[] {
189                11, 7, 3,
190                10, 6, 2,
191                9, 5, 1,
192                8, 4, 0,
193            });
194        final int b_cols = 4;
195        b_offset = 12;
196        // EightBitGemm implements C = B.transposed() * A,
197        // so we expect to get these results:
198        // 1*-1 + 2*-5 + 3*-9 + 128 = 90
199        // 1*-2 + 2*-6 + 3*-10 + 128 = 84
200        // 1*-3 + 2*-7 + 3*-11 + 128 = 78
201        // 1*-4 + 2*-8 + 3*-12 + 128 = 72
202        // 4*-1 + 5*-5 + 6*-9 + 128 = 45
203        // 4*-2 + 5*-6 + 6*-10 + 128 = 30
204        // 4*-3 + 5*-7 + 6*-11 + 128 = 15
205        // 4*-4 + 5*-8 + 6*-12 + 128 = 0
206        // | 90 |  45 |
207        // | 84 |  30 |
208        // | 78 | 15 |
209        // | 72 | 0 |
210        c_offset = 128;
211        final int c_shift = 21;
212        c_mult_int = (1 << c_shift);
213        byte[] expected_data = unsignedToSignedByte(new int[] {
214                90, 84, 78, 72,
215                45, 30, 15, 0,
216            });
217
218        m = a_cols;
219        n = b_cols;
220        k = a_rows;
221
222        Type.Builder builder = new Type.Builder(mRS, Element.U8(mRS));
223        Type a_type = builder.setX(k).setY(m).create();
224        Type b_type = builder.setX(k).setY(n).create();
225        Type c_type = builder.setX(n).setY(m).create();
226
227        matA = Allocation.createTyped(mRS, a_type);
228        matB = Allocation.createTyped(mRS, b_type);
229        matC = Allocation.createTyped(mRS, c_type);
230        matA.copyFrom(a_byte);
231        matB.copyFrom(b_byte);
232
233        //During setup, do a sample run to see if the result is correct.
234        mBLAS.BNNM(matA, a_offset, matB, b_offset, matC, c_offset, c_mult_int);
235        int c_count = (m * n);
236        byte[] c_byte_output = new byte[c_count];
237        matC.copyTo(c_byte_output);
238        if (!testWithTolerance(expected_data, c_byte_output)) {
239            Log.e(TAG, "Result is not correct!");
240            throw new AssertionError("Result is not correct.");
241        }
242    }
243
244
245    // This test multiplies another two medium 8-bit matrices, and compares the
246    // results with the expected values. The data here is arbitrary.
247    public void setTestMedium() {
248        byte[] a_byte = unsignedToSignedByte(new int[] {
249                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
250                23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
251                1, 23, 2, 22, 3, 21, 4, 20, 5, 19, 6, 18, 7, 17, 8, 16, 9, 15, 10, 14, 11, 13, 12,
252                23, 1, 22, 2, 21, 3, 20, 4, 19, 5, 18, 6, 17, 7, 16, 8, 15, 9, 14, 10, 13, 11, 12,
253                1, 1, 1, 1, 1, 1, 1, 1, 1, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
254                3, 1, 4, 1, 5, 8, 2, 3, 1, 14, 11, 15, 18, 12, 13, 11, 14, 11, 15, 18, 12, 13, 11,
255                8, 0, 5, 8, 1, 3, 7, 5, 7, 13, 10, 23, 13, 11, 17, 23, 12, 19, 17, 13, 14, 10, 19,
256            });
257        final int a_rows = 23;
258        final int a_cols = 7;
259        a_offset = 13;
260        byte[] b_byte = unsignedToSignedByte(new int[] {
261                0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9, 11, 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9,
262                0, 20, 40, 60, 80, 10, 11, 13, 15, 17, 19, 21, 10, 12, 14, 6, 8, 10, 1, 3, 5, 7, 9,
263                1, 21, 41, 61, 81, 11, 12, 14, 16, 18, 20, 22, 11, 13, 15, 7, 9, 11, 2, 4, 6, 8, 9,
264                0, 19, 39, 59, 79, 9, 10, 12, 14, 16, 18, 20, 9, 11, 13, 5, 7, 9, 0, 2, 4, 6, 8,
265                2, 22, 42, 62, 82, 12, 13, 15, 17, 19, 21, 23, 12, 14, 16, 8, 9, 12, 3, 5, 7, 9, 9,
266                0, 18, 38, 58, 78, 8, 9, 11, 13, 15, 17, 19, 8, 10, 12, 4, 6, 8, 0, 1, 3, 5, 7,
267                3, 23, 43, 63, 83, 13, 14, 16, 18, 20, 22, 24, 13, 15, 17, 9, 9, 13, 4, 6, 8, 9, 9,
268                0, 17, 37, 57, 77, 7, 8, 10, 12, 14, 16, 18, 7, 9, 11, 3, 5, 7, 0, 0, 2, 4, 6,
269                10, 20, 30, 40, 50, 1, 2, 3, 4, 5, 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 1, 2, 3,
270            });
271        final int b_cols = 9;
272        b_offset = 23;
273        c_offset = 2121;
274        final int c_shift = 21;
275        c_mult_int = 132359;
276        byte[] expected_data = unsignedToSignedByte(new int[] {
277                167, 53, 51, 54, 49, 55, 46,
278                56, 116, 153, 232, 232, 234, 231,
279                236, 232, 237, 174, 168, 131, 130,
280                132, 129, 133, 128, 133, 134, 151,
281                154, 152, 156, 151, 158, 150, 160,
282                156, 255, 113, 106, 120, 98, 127,
283                91, 134, 178, 231, 102, 97, 107,
284                92, 111, 87, 116, 164, 187, 76,
285                73, 78, 70, 81, 67, 83, 139,
286            });
287
288        m = a_cols;
289        n = b_cols;
290        k = a_rows;
291
292        Type.Builder builder = new Type.Builder(mRS, Element.U8(mRS));
293        Type a_type = builder.setX(k).setY(m).create();
294        Type b_type = builder.setX(k).setY(n).create();
295        Type c_type = builder.setX(n).setY(m).create();
296
297        matA = Allocation.createTyped(mRS, a_type);
298        matB = Allocation.createTyped(mRS, b_type);
299        matC = Allocation.createTyped(mRS, c_type);
300
301        matA.copyFrom(a_byte);
302        matB.copyFrom(b_byte);
303
304        //During setup, do a sample run to see if the result is correct.
305        mBLAS.BNNM(matA, a_offset, matB, b_offset, matC, c_offset, c_mult_int);
306        int c_count = (m * n);
307        byte[] c_byte_output = new byte[c_count];
308        matC.copyTo(c_byte_output);
309        if (!testWithTolerance(expected_data, c_byte_output)) {
310            Log.e(TAG, "Result is not correct!");
311            throw new AssertionError("Result is not correct.");
312        }
313    }
314
315
316
317    // This test takes a large set of real data captured from a convolutional
318    // neural network solving a computer vision problem, and runs it through the
319    // eight-bit matrix multiply. We test the results to make sure they're close
320    // enough to be usable.
321    public void setTestLarge() {
322
323        m = 256;
324        n = 192;
325        k = 1152;
326        a_offset = 0;
327        b_offset = 84;
328        c_mult_int = 3401;
329        c_offset = 74980;
330
331        int a_count = (m * k);
332        int b_count = (n * k);
333        int c_count = (m * n);
334
335        byte[] a_byte = new byte[a_count];
336        byte[] b_byte = new byte[b_count];
337        byte[] c_byte = new byte[c_count];
338
339        getData(a_byte, b_byte, c_byte);
340
341        Type.Builder builder = new Type.Builder(mRS, Element.U8(mRS));
342        Type a_type = builder.setX(k).setY(m).create();
343        Type b_type = builder.setX(k).setY(n).create();
344        Type c_type = builder.setX(n).setY(m).create();
345
346        matA = Allocation.createTyped(mRS, a_type);
347        matB = Allocation.createTyped(mRS, b_type);
348        matC = Allocation.createTyped(mRS, c_type);
349
350        matA.copyFrom(a_byte);
351        matB.copyFrom(b_byte);
352
353        //During setup, do a sample run to see if the result is correct.
354        mBLAS.BNNM(matA, a_offset, matB, b_offset, matC, c_offset, c_mult_int);
355        byte[] c_byte_output = new byte[c_count];
356        matC.copyTo(c_byte_output);
357        if (!testWithTolerance(c_byte, c_byte_output)) {
358            Log.e(TAG, "Result is not correct!");
359            throw new AssertionError("Result is not correct.");
360        }
361    }
362
363    public void runTest() {
364        mBLAS.BNNM(matA, a_offset, matB, b_offset, matC, c_offset, c_mult_int);
365    }
366
367    public String getTestInfo() {
368        return "8Bit GEMM Test: m=" + m + ", n=" + n + ", k=" + k;
369    }
370}
371