1/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.example.android.rs.blasbenchmark;
18
19import android.renderscript.*;
20import android.util.Log;
21import java.util.Random;
22import java.lang.Math;
23
24public class SGEMMTest extends TestBase {
25
26    static {
27        System.loadLibrary("gemmdata");
28    }
29
30    native void getData(byte[] a, byte[] b, byte[] c);
31
32    ScriptIntrinsicBLAS mBLAS;
33    private Allocation matA;
34    private Allocation matB;
35    private Allocation matC;
36
37    private int m;
38    private int n;
39    private int k;
40
41    private int a_offset;
42    private int b_offset;
43    private int mTestSize;
44    private final float allowedError = 0.000001f;
45
46    SGEMMTest(int testSize) {
47        mTestSize = testSize;
48    }
49
50    public void createTest() {
51        mBLAS = ScriptIntrinsicBLAS.create(mRS);
52        setTest();
53    }
54
55    private void setTest() {
56        switch (mTestSize) {
57            case 1:
58                setTestSmall();
59                break;
60            case 2:
61                setTestMedium();
62                break;
63            case 3:
64                setTestLarge();
65                break;
66            default:
67                break;
68        }
69    }
70
71    // Calculate the square of the L2 norm of a matrix.
72    private float calcL2Norm(float[] input) {
73        float l2Norm = 0.f;
74        for (int i = 0; i < input.length; ++i) {
75            l2Norm += input[i] * input[i];
76        }
77        return l2Norm;
78    }
79
80    // Test whether the error of each element is samller the allowed error range.
81    private boolean testWithTolerance(float[] out, float[] ref) {
82        float l2NormOut = calcL2Norm(out);
83        float l2NormRef = calcL2Norm(ref);
84        float tolerance = allowedError * (l2NormOut < l2NormRef ? l2NormOut : l2NormRef);
85        tolerance /= m * n;
86        for (int i = 0; i < out.length; ++i) {
87            float err = out[i] - ref[i];
88            float absErr = err * err;
89            if (absErr > tolerance) {
90                return false;
91            }
92        }
93        return true;
94    }
95
96    // Transform byte data into float, given a offset.
97    private float[] byteToFloat(byte[] input, int offset) {
98        float[] output = new float[input.length];
99        for (int i = 0; i < input.length; ++i) {
100            output[i] = (float)(input[i] - offset);
101        }
102        return output;
103    }
104
105    // Calculate the reference result for C = A*B
106    private float[] getGEMMResult(int m, int n, int k, float[] a_float, float[] b_float) {
107        float[] c_float = new float[m * n];
108        for (int j = 0; j < n; j++) {
109            for (int i = 0; i < m; i++) {
110                float total = 0.f;
111                for (int l = 0; l < k; l++) {
112                    int a_index = ((i * k) + l);
113                    int b_index = ((l * n) + j);
114                    float mult = a_float[a_index] * b_float[b_index];
115                    total += mult;
116                }
117                int c_index = ((i * n) + j);
118                c_float[c_index] = total;
119            }
120        }
121        return c_float;
122    }
123
124    // This test multiplies a couple of small float matrices, and compares the
125    // results with java-calculated expectations. The data here is arbitrary.
126    public void setTestSmall() {
127        m = 2;
128        n = 4;
129        k = 3;
130        a_offset = 0;
131        b_offset = 12;
132
133        float[] a_float = byteToFloat(new byte[] {
134                1, 2, 3,
135                4, 5, 6,
136            }, a_offset);
137
138        float[] b_float = byteToFloat(new byte[] {
139                11, 7, 3,
140                10, 6, 2,
141                9, 5, 1,
142                8, 4, 0,
143            }, b_offset);
144
145        Type.Builder builder = new Type.Builder(mRS, Element.F32(mRS));
146        Type a_type = builder.setX(k).setY(m).create();
147        Type b_type = builder.setX(n).setY(k).create();
148        Type c_type = builder.setX(n).setY(m).create();
149
150        matA = Allocation.createTyped(mRS, a_type);
151        matB = Allocation.createTyped(mRS, b_type);
152        matC = Allocation.createTyped(mRS, c_type);
153
154        matA.copyFrom(a_float);
155        matB.copyFrom(b_float);
156
157        //During setup, do a sample run to see if the result is correct.
158        mBLAS.SGEMM(ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.NO_TRANSPOSE,
159                    1.0f, matA, matB, 0.f, matC);
160        float[] c_float_ref = getGEMMResult(m, n, k, a_float, b_float);
161        float[] c_float_out = new float[m * n];
162        matC.copyTo(c_float_out);
163        if (!testWithTolerance(c_float_ref, c_float_out)) {
164            Log.e(TAG, "Result is not correct!");
165            throw new AssertionError("Result is not correct.");
166        }
167    }
168
169    // This test multiplies another two medium matrices, and compares the
170    // results with the expected values. The data here is arbitrary.
171    public void setTestMedium() {
172        m = 7;
173        n = 9;
174        k = 23;
175        a_offset = 13;
176        b_offset = 23;
177
178        float[] a_float = byteToFloat(new byte[] {
179                1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
180                23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
181                1, 23, 2, 22, 3, 21, 4, 20, 5, 19, 6, 18, 7, 17, 8, 16, 9, 15, 10, 14, 11, 13, 12,
182                23, 1, 22, 2, 21, 3, 20, 4, 19, 5, 18, 6, 17, 7, 16, 8, 15, 9, 14, 10, 13, 11, 12,
183                1, 1, 1, 1, 1, 1, 1, 1, 1, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
184                3, 1, 4, 1, 5, 8, 2, 3, 1, 14, 11, 15, 18, 12, 13, 11, 14, 11, 15, 18, 12, 13, 11,
185                8, 0, 5, 8, 1, 3, 7, 5, 7, 13, 10, 23, 13, 11, 17, 23, 12, 19, 17, 13, 14, 10, 19,
186            }, a_offset);
187
188        float[] b_float = byteToFloat(new byte[] {
189                0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9, 11, 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9,
190                0, 20, 40, 60, 80, 10, 11, 13, 15, 17, 19, 21, 10, 12, 14, 6, 8, 10, 1, 3, 5, 7, 9,
191                1, 21, 41, 61, 81, 11, 12, 14, 16, 18, 20, 22, 11, 13, 15, 7, 9, 11, 2, 4, 6, 8, 9,
192                0, 19, 39, 59, 79, 9, 10, 12, 14, 16, 18, 20, 9, 11, 13, 5, 7, 9, 0, 2, 4, 6, 8,
193                2, 22, 42, 62, 82, 12, 13, 15, 17, 19, 21, 23, 12, 14, 16, 8, 9, 12, 3, 5, 7, 9, 9,
194                0, 18, 38, 58, 78, 8, 9, 11, 13, 15, 17, 19, 8, 10, 12, 4, 6, 8, 0, 1, 3, 5, 7,
195                3, 23, 43, 63, 83, 13, 14, 16, 18, 20, 22, 24, 13, 15, 17, 9, 9, 13, 4, 6, 8, 9, 9,
196                0, 17, 37, 57, 77, 7, 8, 10, 12, 14, 16, 18, 7, 9, 11, 3, 5, 7, 0, 0, 2, 4, 6,
197                10, 20, 30, 40, 50, 1, 2, 3, 4, 5, 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 1, 2, 3,
198            }, b_offset);
199
200        Type.Builder builder = new Type.Builder(mRS, Element.F32(mRS));
201        Type a_type = builder.setX(k).setY(m).create();
202        Type b_type = builder.setX(n).setY(k).create();
203        Type c_type = builder.setX(n).setY(m).create();
204
205        matA = Allocation.createTyped(mRS, a_type);
206        matB = Allocation.createTyped(mRS, b_type);
207        matC = Allocation.createTyped(mRS, c_type);
208
209        matA.copyFrom(a_float);
210        matB.copyFrom(b_float);
211
212        //During setup, do a sample run to see if the result is correct.
213        mBLAS.SGEMM(ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.NO_TRANSPOSE,
214                    1.0f, matA, matB, 0.f, matC);
215        float[] c_float_ref = getGEMMResult(m, n, k, a_float, b_float);
216        float[] c_float_out = new float[m * n];
217        matC.copyTo(c_float_out);
218        if (!testWithTolerance(c_float_ref, c_float_out)) {
219            Log.e(TAG, "Result is not correct!");
220            throw new AssertionError("Result is not correct.");
221        }
222    }
223
224
225    // This test takes a large set of real data captured from a convolutional
226    // neural network solving a computer vision problem, and runs it through SGEMM.
227    public void setTestLarge() {
228
229        m = 256;
230        n = 192;
231        k = 1152;
232        a_offset = 0;
233        b_offset = 84;
234
235        int a_count = (m * k);
236        int b_count = (n * k);
237        int c_count = (m * n);
238
239        byte[] a_byte = new byte[a_count];
240        byte[] b_byte = new byte[b_count];
241        byte[] c_byte = new byte[c_count];
242
243        getData(a_byte, b_byte, c_byte);
244
245        float[] a_float = byteToFloat(a_byte, a_offset);
246        float[] b_float = byteToFloat(b_byte, b_offset);
247
248        Type.Builder builder = new Type.Builder(mRS, Element.F32(mRS));
249        Type a_type = builder.setX(k).setY(m).create();
250        Type b_type = builder.setX(n).setY(k).create();
251        Type c_type = builder.setX(n).setY(m).create();
252
253        matA = Allocation.createTyped(mRS, a_type);
254        matB = Allocation.createTyped(mRS, b_type);
255        matC = Allocation.createTyped(mRS, c_type);
256
257        matA.copyFrom(a_float);
258        matB.copyFrom(b_float);
259
260        //During setup, do a sample run to see if the result is correct.
261        mBLAS.SGEMM(ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.NO_TRANSPOSE,
262                    1.0f, matA, matB, 0.f, matC);
263        float[] c_float_ref = getGEMMResult(m, n, k, a_float, b_float);
264        float[] c_float_out = new float[c_count];
265        matC.copyTo(c_float_out);
266        if (!testWithTolerance(c_float_ref, c_float_out)) {
267            Log.e(TAG, "Result is not correct!");
268            throw new AssertionError("Result is not correct.");
269        }
270    }
271
272    public void runTest() {
273        mBLAS.SGEMM(ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.NO_TRANSPOSE,
274                    1.0f, matA, matB, 0.f, matC);
275    }
276
277    public String getTestInfo() {
278        return "SGEMM Test: m=" + m + ", n=" + n + ", k=" + k;
279    }
280}
281