1/* 2 * Copyright (C) 2015 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package com.example.android.rs.blasbenchmark; 18 19import android.renderscript.*; 20import android.util.Log; 21import java.util.Random; 22import java.lang.Math; 23 24public class SGEMMTest extends TestBase { 25 26 static { 27 System.loadLibrary("gemmdata"); 28 } 29 30 native void getData(byte[] a, byte[] b, byte[] c); 31 32 ScriptIntrinsicBLAS mBLAS; 33 private Allocation matA; 34 private Allocation matB; 35 private Allocation matC; 36 37 private int m; 38 private int n; 39 private int k; 40 41 private int a_offset; 42 private int b_offset; 43 private int mTestSize; 44 private final float allowedError = 0.000001f; 45 46 SGEMMTest(int testSize) { 47 mTestSize = testSize; 48 } 49 50 public void createTest() { 51 mBLAS = ScriptIntrinsicBLAS.create(mRS); 52 setTest(); 53 } 54 55 private void setTest() { 56 switch (mTestSize) { 57 case 1: 58 setTestSmall(); 59 break; 60 case 2: 61 setTestMedium(); 62 break; 63 case 3: 64 setTestLarge(); 65 break; 66 default: 67 break; 68 } 69 } 70 71 // Calculate the square of the L2 norm of a matrix. 72 private float calcL2Norm(float[] input) { 73 float l2Norm = 0.f; 74 for (int i = 0; i < input.length; ++i) { 75 l2Norm += input[i] * input[i]; 76 } 77 return l2Norm; 78 } 79 80 // Test whether the error of each element is samller the allowed error range. 81 private boolean testWithTolerance(float[] out, float[] ref) { 82 float l2NormOut = calcL2Norm(out); 83 float l2NormRef = calcL2Norm(ref); 84 float tolerance = allowedError * (l2NormOut < l2NormRef ? l2NormOut : l2NormRef); 85 tolerance /= m * n; 86 for (int i = 0; i < out.length; ++i) { 87 float err = out[i] - ref[i]; 88 float absErr = err * err; 89 if (absErr > tolerance) { 90 return false; 91 } 92 } 93 return true; 94 } 95 96 // Transform byte data into float, given a offset. 97 private float[] byteToFloat(byte[] input, int offset) { 98 float[] output = new float[input.length]; 99 for (int i = 0; i < input.length; ++i) { 100 output[i] = (float)(input[i] - offset); 101 } 102 return output; 103 } 104 105 // Calculate the reference result for C = A*B 106 private float[] getGEMMResult(int m, int n, int k, float[] a_float, float[] b_float) { 107 float[] c_float = new float[m * n]; 108 for (int j = 0; j < n; j++) { 109 for (int i = 0; i < m; i++) { 110 float total = 0.f; 111 for (int l = 0; l < k; l++) { 112 int a_index = ((i * k) + l); 113 int b_index = ((l * n) + j); 114 float mult = a_float[a_index] * b_float[b_index]; 115 total += mult; 116 } 117 int c_index = ((i * n) + j); 118 c_float[c_index] = total; 119 } 120 } 121 return c_float; 122 } 123 124 // This test multiplies a couple of small float matrices, and compares the 125 // results with java-calculated expectations. The data here is arbitrary. 126 public void setTestSmall() { 127 m = 2; 128 n = 4; 129 k = 3; 130 a_offset = 0; 131 b_offset = 12; 132 133 float[] a_float = byteToFloat(new byte[] { 134 1, 2, 3, 135 4, 5, 6, 136 }, a_offset); 137 138 float[] b_float = byteToFloat(new byte[] { 139 11, 7, 3, 140 10, 6, 2, 141 9, 5, 1, 142 8, 4, 0, 143 }, b_offset); 144 145 Type.Builder builder = new Type.Builder(mRS, Element.F32(mRS)); 146 Type a_type = builder.setX(k).setY(m).create(); 147 Type b_type = builder.setX(n).setY(k).create(); 148 Type c_type = builder.setX(n).setY(m).create(); 149 150 matA = Allocation.createTyped(mRS, a_type); 151 matB = Allocation.createTyped(mRS, b_type); 152 matC = Allocation.createTyped(mRS, c_type); 153 154 matA.copyFrom(a_float); 155 matB.copyFrom(b_float); 156 157 //During setup, do a sample run to see if the result is correct. 158 mBLAS.SGEMM(ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.NO_TRANSPOSE, 159 1.0f, matA, matB, 0.f, matC); 160 float[] c_float_ref = getGEMMResult(m, n, k, a_float, b_float); 161 float[] c_float_out = new float[m * n]; 162 matC.copyTo(c_float_out); 163 if (!testWithTolerance(c_float_ref, c_float_out)) { 164 Log.e(TAG, "Result is not correct!"); 165 throw new AssertionError("Result is not correct."); 166 } 167 } 168 169 // This test multiplies another two medium matrices, and compares the 170 // results with the expected values. The data here is arbitrary. 171 public void setTestMedium() { 172 m = 7; 173 n = 9; 174 k = 23; 175 a_offset = 13; 176 b_offset = 23; 177 178 float[] a_float = byteToFloat(new byte[] { 179 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 180 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 181 1, 23, 2, 22, 3, 21, 4, 20, 5, 19, 6, 18, 7, 17, 8, 16, 9, 15, 10, 14, 11, 13, 12, 182 23, 1, 22, 2, 21, 3, 20, 4, 19, 5, 18, 6, 17, 7, 16, 8, 15, 9, 14, 10, 13, 11, 12, 183 1, 1, 1, 1, 1, 1, 1, 1, 1, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 184 3, 1, 4, 1, 5, 8, 2, 3, 1, 14, 11, 15, 18, 12, 13, 11, 14, 11, 15, 18, 12, 13, 11, 185 8, 0, 5, 8, 1, 3, 7, 5, 7, 13, 10, 23, 13, 11, 17, 23, 12, 19, 17, 13, 14, 10, 19, 186 }, a_offset); 187 188 float[] b_float = byteToFloat(new byte[] { 189 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9, 11, 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9, 190 0, 20, 40, 60, 80, 10, 11, 13, 15, 17, 19, 21, 10, 12, 14, 6, 8, 10, 1, 3, 5, 7, 9, 191 1, 21, 41, 61, 81, 11, 12, 14, 16, 18, 20, 22, 11, 13, 15, 7, 9, 11, 2, 4, 6, 8, 9, 192 0, 19, 39, 59, 79, 9, 10, 12, 14, 16, 18, 20, 9, 11, 13, 5, 7, 9, 0, 2, 4, 6, 8, 193 2, 22, 42, 62, 82, 12, 13, 15, 17, 19, 21, 23, 12, 14, 16, 8, 9, 12, 3, 5, 7, 9, 9, 194 0, 18, 38, 58, 78, 8, 9, 11, 13, 15, 17, 19, 8, 10, 12, 4, 6, 8, 0, 1, 3, 5, 7, 195 3, 23, 43, 63, 83, 13, 14, 16, 18, 20, 22, 24, 13, 15, 17, 9, 9, 13, 4, 6, 8, 9, 9, 196 0, 17, 37, 57, 77, 7, 8, 10, 12, 14, 16, 18, 7, 9, 11, 3, 5, 7, 0, 0, 2, 4, 6, 197 10, 20, 30, 40, 50, 1, 2, 3, 4, 5, 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 1, 2, 3, 198 }, b_offset); 199 200 Type.Builder builder = new Type.Builder(mRS, Element.F32(mRS)); 201 Type a_type = builder.setX(k).setY(m).create(); 202 Type b_type = builder.setX(n).setY(k).create(); 203 Type c_type = builder.setX(n).setY(m).create(); 204 205 matA = Allocation.createTyped(mRS, a_type); 206 matB = Allocation.createTyped(mRS, b_type); 207 matC = Allocation.createTyped(mRS, c_type); 208 209 matA.copyFrom(a_float); 210 matB.copyFrom(b_float); 211 212 //During setup, do a sample run to see if the result is correct. 213 mBLAS.SGEMM(ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.NO_TRANSPOSE, 214 1.0f, matA, matB, 0.f, matC); 215 float[] c_float_ref = getGEMMResult(m, n, k, a_float, b_float); 216 float[] c_float_out = new float[m * n]; 217 matC.copyTo(c_float_out); 218 if (!testWithTolerance(c_float_ref, c_float_out)) { 219 Log.e(TAG, "Result is not correct!"); 220 throw new AssertionError("Result is not correct."); 221 } 222 } 223 224 225 // This test takes a large set of real data captured from a convolutional 226 // neural network solving a computer vision problem, and runs it through SGEMM. 227 public void setTestLarge() { 228 229 m = 256; 230 n = 192; 231 k = 1152; 232 a_offset = 0; 233 b_offset = 84; 234 235 int a_count = (m * k); 236 int b_count = (n * k); 237 int c_count = (m * n); 238 239 byte[] a_byte = new byte[a_count]; 240 byte[] b_byte = new byte[b_count]; 241 byte[] c_byte = new byte[c_count]; 242 243 getData(a_byte, b_byte, c_byte); 244 245 float[] a_float = byteToFloat(a_byte, a_offset); 246 float[] b_float = byteToFloat(b_byte, b_offset); 247 248 Type.Builder builder = new Type.Builder(mRS, Element.F32(mRS)); 249 Type a_type = builder.setX(k).setY(m).create(); 250 Type b_type = builder.setX(n).setY(k).create(); 251 Type c_type = builder.setX(n).setY(m).create(); 252 253 matA = Allocation.createTyped(mRS, a_type); 254 matB = Allocation.createTyped(mRS, b_type); 255 matC = Allocation.createTyped(mRS, c_type); 256 257 matA.copyFrom(a_float); 258 matB.copyFrom(b_float); 259 260 //During setup, do a sample run to see if the result is correct. 261 mBLAS.SGEMM(ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.NO_TRANSPOSE, 262 1.0f, matA, matB, 0.f, matC); 263 float[] c_float_ref = getGEMMResult(m, n, k, a_float, b_float); 264 float[] c_float_out = new float[c_count]; 265 matC.copyTo(c_float_out); 266 if (!testWithTolerance(c_float_ref, c_float_out)) { 267 Log.e(TAG, "Result is not correct!"); 268 throw new AssertionError("Result is not correct."); 269 } 270 } 271 272 public void runTest() { 273 mBLAS.SGEMM(ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.NO_TRANSPOSE, 274 1.0f, matA, matB, 0.f, matC); 275 } 276 277 public String getTestInfo() { 278 return "SGEMM Test: m=" + m + ", n=" + n + ", k=" + k; 279 } 280} 281