1/* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package android.support.v8.renderscript; 18 19import java.lang.reflect.Method; 20import java.util.ArrayList; 21 22/** 23 * ScriptGroup creates a group of kernels that are executed 24 * together with one execution call as if they were a single kernel. 25 * The kernels may be connected internally or to an external allocation. 26 * The intermediate results for internal connections are not observable 27 * after the execution of the script. 28 * <p> 29 * External connections are grouped into inputs and outputs. 30 * All outputs are produced by a script kernel and placed into a 31 * user-supplied allocation. Inputs provide the input of a kernel. 32 * Inputs bound to script globals are set directly upon the script. 33 * <p> 34 * A ScriptGroup must contain at least one kernel. A ScriptGroup 35 * must contain only a single directed acyclic graph (DAG) of 36 * script kernels and connections. Attempting to create a 37 * ScriptGroup with multiple DAGs or attempting to create 38 * a cycle within a ScriptGroup will throw an exception. 39 * <p> 40 * Currently, all kernels in a ScriptGroup must be from separate 41 * Script objects. Attempting to use multiple kernels from the same 42 * Script object will result in an 43 * {@link android.support.v8.renderscript.RSInvalidStateException}. 44 * 45 **/ 46public class ScriptGroup extends BaseObj { 47 IO mOutputs[]; 48 IO mInputs[]; 49 50 static class IO { 51 Script.KernelID mKID; 52 Allocation mAllocation; 53 54 IO(Script.KernelID s) { 55 mKID = s; 56 } 57 } 58 59 static class ConnectLine { 60 ConnectLine(Type t, Script.KernelID from, Script.KernelID to) { 61 mFrom = from; 62 mToK = to; 63 mAllocationType = t; 64 } 65 66 ConnectLine(Type t, Script.KernelID from, Script.FieldID to) { 67 mFrom = from; 68 mToF = to; 69 mAllocationType = t; 70 } 71 72 Script.FieldID mToF; 73 Script.KernelID mToK; 74 Script.KernelID mFrom; 75 Type mAllocationType; 76 } 77 78 static class Node { 79 Script mScript; 80 ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>(); 81 ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>(); 82 ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>(); 83 int dagNumber; 84 85 Node mNext; 86 87 Node(Script s) { 88 mScript = s; 89 } 90 } 91 92 93 ScriptGroup(int id, RenderScript rs) { 94 super(id, rs); 95 } 96 97 /** 98 * Sets an input of the ScriptGroup. This specifies an 99 * Allocation to be used for kernels that require an input 100 * Allocation provided from outside of the ScriptGroup. 101 * 102 * @param s The ID of the kernel where the allocation should be 103 * connected. 104 * @param a The allocation to connect. 105 */ 106 public void setInput(Script.KernelID s, Allocation a) { 107 for (int ct=0; ct < mInputs.length; ct++) { 108 if (mInputs[ct].mKID == s) { 109 mInputs[ct].mAllocation = a; 110 mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a)); 111 return; 112 } 113 } 114 throw new RSIllegalArgumentException("Script not found"); 115 } 116 117 /** 118 * Sets an output of the ScriptGroup. This specifies an 119 * Allocation to be used for the kernels that require an output 120 * Allocation visible after the ScriptGroup is executed. 121 * 122 * @param s The ID of the kernel where the allocation should be 123 * connected. 124 * @param a The allocation to connect. 125 */ 126 public void setOutput(Script.KernelID s, Allocation a) { 127 for (int ct=0; ct < mOutputs.length; ct++) { 128 if (mOutputs[ct].mKID == s) { 129 mOutputs[ct].mAllocation = a; 130 mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a)); 131 return; 132 } 133 } 134 throw new RSIllegalArgumentException("Script not found"); 135 } 136 137 /** 138 * Execute the ScriptGroup. This will run all the kernels in 139 * the ScriptGroup. No internal connection results will be visible 140 * after execution of the ScriptGroup. 141 */ 142 public void execute() { 143 mRS.nScriptGroupExecute(getID(mRS)); 144 } 145 146 147 /** 148 * Helper class to build a ScriptGroup. A ScriptGroup is 149 * created in two steps. 150 * <p> 151 * First, all kernels to be used by the ScriptGroup should be added. 152 * <p> 153 * Second, add connections between kernels. There are two types 154 * of connections: kernel to kernel and kernel to field. 155 * Kernel to kernel allows a kernel's output to be passed to 156 * another kernel as input. Kernel to field allows the output of 157 * one kernel to be bound as a script global. Kernel to kernel is 158 * higher performance and should be used where possible. 159 * <p> 160 * A ScriptGroup must contain a single directed acyclic graph (DAG); it 161 * cannot contain cycles. Currently, all kernels used in a ScriptGroup 162 * must come from different Script objects. Additionally, all kernels 163 * in a ScriptGroup must have at least one input, output, or internal 164 * connection. 165 * <p> 166 * Once all connections are made, a call to {@link #create} will 167 * return the ScriptGroup object. 168 * 169 */ 170 public static final class Builder { 171 private RenderScript mRS; 172 private ArrayList<Node> mNodes = new ArrayList<Node>(); 173 private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>(); 174 private int mKernelCount; 175 176 private ScriptGroupThunker.Builder mT; 177 178 /** 179 * Create a Builder for generating a ScriptGroup. 180 * 181 * 182 * @param rs The RenderScript context. 183 */ 184 public Builder(RenderScript rs) { 185 if (rs.isNative) { 186 mT = new ScriptGroupThunker.Builder(rs); 187 } 188 mRS = rs; 189 } 190 191 // do a DFS from original node, looking for original node 192 // any cycle that could be created must contain original node 193 private void validateCycle(Node target, Node original) { 194 for (int ct = 0; ct < target.mOutputs.size(); ct++) { 195 final ConnectLine cl = target.mOutputs.get(ct); 196 if (cl.mToK != null) { 197 Node tn = findNode(cl.mToK.mScript); 198 if (tn.equals(original)) { 199 throw new RSInvalidStateException("Loops in group not allowed."); 200 } 201 validateCycle(tn, original); 202 } 203 if (cl.mToF != null) { 204 Node tn = findNode(cl.mToF.mScript); 205 if (tn.equals(original)) { 206 throw new RSInvalidStateException("Loops in group not allowed."); 207 } 208 validateCycle(tn, original); 209 } 210 } 211 } 212 213 private void mergeDAGs(int valueUsed, int valueKilled) { 214 for (int ct=0; ct < mNodes.size(); ct++) { 215 if (mNodes.get(ct).dagNumber == valueKilled) 216 mNodes.get(ct).dagNumber = valueUsed; 217 } 218 } 219 220 private void validateDAGRecurse(Node n, int dagNumber) { 221 // combine DAGs if this node has been seen already 222 if (n.dagNumber != 0 && n.dagNumber != dagNumber) { 223 mergeDAGs(n.dagNumber, dagNumber); 224 return; 225 } 226 227 n.dagNumber = dagNumber; 228 for (int ct=0; ct < n.mOutputs.size(); ct++) { 229 final ConnectLine cl = n.mOutputs.get(ct); 230 if (cl.mToK != null) { 231 Node tn = findNode(cl.mToK.mScript); 232 validateDAGRecurse(tn, dagNumber); 233 } 234 if (cl.mToF != null) { 235 Node tn = findNode(cl.mToF.mScript); 236 validateDAGRecurse(tn, dagNumber); 237 } 238 } 239 } 240 241 private void validateDAG() { 242 for (int ct=0; ct < mNodes.size(); ct++) { 243 Node n = mNodes.get(ct); 244 if (n.mInputs.size() == 0) { 245 if (n.mOutputs.size() == 0 && mNodes.size() > 1) { 246 throw new RSInvalidStateException("Groups cannot contain unconnected scripts"); 247 } 248 validateDAGRecurse(n, ct+1); 249 } 250 } 251 int dagNumber = mNodes.get(0).dagNumber; 252 for (int ct=0; ct < mNodes.size(); ct++) { 253 if (mNodes.get(ct).dagNumber != dagNumber) { 254 throw new RSInvalidStateException("Multiple DAGs in group not allowed."); 255 } 256 } 257 } 258 259 private Node findNode(Script s) { 260 for (int ct=0; ct < mNodes.size(); ct++) { 261 if (s == mNodes.get(ct).mScript) { 262 return mNodes.get(ct); 263 } 264 } 265 return null; 266 } 267 268 private Node findNode(Script.KernelID k) { 269 for (int ct=0; ct < mNodes.size(); ct++) { 270 Node n = mNodes.get(ct); 271 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) { 272 if (k == n.mKernels.get(ct2)) { 273 return n; 274 } 275 } 276 } 277 return null; 278 } 279 280 /** 281 * Adds a Kernel to the group. 282 * 283 * 284 * @param k The kernel to add. 285 * 286 * @return Builder Returns this. 287 */ 288 public Builder addKernel(Script.KernelID k) { 289 if (mT != null) { 290 mT.addKernel(k); 291 return this; 292 } 293 294 if (mLines.size() != 0) { 295 throw new RSInvalidStateException( 296 "Kernels may not be added once connections exist."); 297 } 298 299 //android.util.Log.v("RSR", "addKernel 1 k=" + k); 300 if (findNode(k) != null) { 301 return this; 302 } 303 //android.util.Log.v("RSR", "addKernel 2 "); 304 mKernelCount++; 305 Node n = findNode(k.mScript); 306 if (n == null) { 307 //android.util.Log.v("RSR", "addKernel 3 "); 308 n = new Node(k.mScript); 309 mNodes.add(n); 310 } 311 n.mKernels.add(k); 312 return this; 313 } 314 315 /** 316 * Adds a connection to the group. 317 * 318 * 319 * @param t The type of the connection. This is used to 320 * determine the kernel launch sizes on the source side 321 * of this connection. 322 * @param from The source for the connection. 323 * @param to The destination of the connection. 324 * 325 * @return Builder Returns this 326 */ 327 public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) { 328 //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to); 329 330 if (mT != null) { 331 mT.addConnection(t, from, to); 332 return this; 333 } 334 335 Node nf = findNode(from); 336 if (nf == null) { 337 throw new RSInvalidStateException("From script not found."); 338 } 339 340 Node nt = findNode(to.mScript); 341 if (nt == null) { 342 throw new RSInvalidStateException("To script not found."); 343 } 344 345 ConnectLine cl = new ConnectLine(t, from, to); 346 mLines.add(new ConnectLine(t, from, to)); 347 348 nf.mOutputs.add(cl); 349 nt.mInputs.add(cl); 350 351 validateCycle(nf, nf); 352 return this; 353 } 354 355 /** 356 * Adds a connection to the group. 357 * 358 * 359 * @param t The type of the connection. This is used to 360 * determine the kernel launch sizes for both sides of 361 * this connection. 362 * @param from The source for the connection. 363 * @param to The destination of the connection. 364 * 365 * @return Builder Returns this 366 */ 367 public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) { 368 //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to); 369 370 if (mT != null) { 371 mT.addConnection(t, from, to); 372 return this; 373 } 374 375 Node nf = findNode(from); 376 if (nf == null) { 377 throw new RSInvalidStateException("From script not found."); 378 } 379 380 Node nt = findNode(to); 381 if (nt == null) { 382 throw new RSInvalidStateException("To script not found."); 383 } 384 385 ConnectLine cl = new ConnectLine(t, from, to); 386 mLines.add(new ConnectLine(t, from, to)); 387 388 nf.mOutputs.add(cl); 389 nt.mInputs.add(cl); 390 391 validateCycle(nf, nf); 392 return this; 393 } 394 395 396 397 /** 398 * Creates the Script group. 399 * 400 * 401 * @return ScriptGroup The new ScriptGroup 402 */ 403 public ScriptGroup create() { 404 405 if (mT != null) { 406 return mT.create(); 407 } 408 409 if (mNodes.size() == 0) { 410 throw new RSInvalidStateException("Empty script groups are not allowed"); 411 } 412 413 // reset DAG numbers in case we're building a second group 414 for (int ct=0; ct < mNodes.size(); ct++) { 415 mNodes.get(ct).dagNumber = 0; 416 } 417 validateDAG(); 418 419 ArrayList<IO> inputs = new ArrayList<IO>(); 420 ArrayList<IO> outputs = new ArrayList<IO>(); 421 422 int[] kernels = new int[mKernelCount]; 423 int idx = 0; 424 for (int ct=0; ct < mNodes.size(); ct++) { 425 Node n = mNodes.get(ct); 426 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) { 427 final Script.KernelID kid = n.mKernels.get(ct2); 428 kernels[idx++] = kid.getID(mRS); 429 430 boolean hasInput = false; 431 boolean hasOutput = false; 432 for (int ct3=0; ct3 < n.mInputs.size(); ct3++) { 433 if (n.mInputs.get(ct3).mToK == kid) { 434 hasInput = true; 435 } 436 } 437 for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) { 438 if (n.mOutputs.get(ct3).mFrom == kid) { 439 hasOutput = true; 440 } 441 } 442 if (!hasInput) { 443 inputs.add(new IO(kid)); 444 } 445 if (!hasOutput) { 446 outputs.add(new IO(kid)); 447 } 448 449 } 450 } 451 if (idx != mKernelCount) { 452 throw new RSRuntimeException("Count mismatch, should not happen."); 453 } 454 455 int[] src = new int[mLines.size()]; 456 int[] dstk = new int[mLines.size()]; 457 int[] dstf = new int[mLines.size()]; 458 int[] types = new int[mLines.size()]; 459 460 for (int ct=0; ct < mLines.size(); ct++) { 461 ConnectLine cl = mLines.get(ct); 462 src[ct] = cl.mFrom.getID(mRS); 463 if (cl.mToK != null) { 464 dstk[ct] = cl.mToK.getID(mRS); 465 } 466 if (cl.mToF != null) { 467 dstf[ct] = cl.mToF.getID(mRS); 468 } 469 types[ct] = cl.mAllocationType.getID(mRS); 470 } 471 472 int id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types); 473 if (id == 0) { 474 throw new RSRuntimeException("Object creation error, should not happen."); 475 } 476 477 ScriptGroup sg = new ScriptGroup(id, mRS); 478 sg.mOutputs = new IO[outputs.size()]; 479 for (int ct=0; ct < outputs.size(); ct++) { 480 sg.mOutputs[ct] = outputs.get(ct); 481 } 482 483 sg.mInputs = new IO[inputs.size()]; 484 for (int ct=0; ct < inputs.size(); ct++) { 485 sg.mInputs[ct] = inputs.get(ct); 486 } 487 488 return sg; 489 } 490 491 } 492 493 494} 495 496 497