ScriptGroup.java revision 6f5555db1af436bb5aad430e6e00aa5b69d5ca6c
1/* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package android.support.v8.renderscript; 18 19import java.lang.reflect.Method; 20import java.util.ArrayList; 21import java.util.Collections; 22import java.util.Comparator; 23 24/** 25 * ScriptGroup creates a group of kernels that are executed 26 * together with one execution call as if they were a single kernel. 27 * The kernels may be connected internally or to an external allocation. 28 * The intermediate results for internal connections are not observable 29 * after the execution of the script. 30 * <p> 31 * External connections are grouped into inputs and outputs. 32 * All outputs are produced by a script kernel and placed into a 33 * user-supplied allocation. Inputs provide the input of a kernel. 34 * Inputs bound to script globals are set directly upon the script. 35 * <p> 36 * A ScriptGroup must contain at least one kernel. A ScriptGroup 37 * must contain only a single directed acyclic graph (DAG) of 38 * script kernels and connections. Attempting to create a 39 * ScriptGroup with multiple DAGs or attempting to create 40 * a cycle within a ScriptGroup will throw an exception. 41 * <p> 42 * Currently, all kernels in a ScriptGroup must be from separate 43 * Script objects. Attempting to use multiple kernels from the same 44 * Script object will result in an 45 * {@link android.support.v8.renderscript.RSInvalidStateException}. 46 * 47 **/ 48public class ScriptGroup extends BaseObj { 49 IO mOutputs[]; 50 IO mInputs[]; 51 private boolean mUseIncSupp = false; 52 private ArrayList<Node> mNodes = new ArrayList<Node>(); 53 54 static class IO { 55 Script.KernelID mKID; 56 Allocation mAllocation; 57 58 IO(Script.KernelID s) { 59 mKID = s; 60 } 61 } 62 63 static class ConnectLine { 64 ConnectLine(Type t, Script.KernelID from, Script.KernelID to) { 65 mFrom = from; 66 mToK = to; 67 mAllocationType = t; 68 } 69 70 ConnectLine(Type t, Script.KernelID from, Script.FieldID to) { 71 mFrom = from; 72 mToF = to; 73 mAllocationType = t; 74 } 75 76 Script.FieldID mToF; 77 Script.KernelID mToK; 78 Script.KernelID mFrom; 79 Type mAllocationType; 80 Allocation mAllocation; 81 } 82 83 static class Node { 84 Script mScript; 85 ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>(); 86 ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>(); 87 ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>(); 88 int dagNumber; 89 boolean mSeen; 90 int mOrder; 91 92 Node mNext; 93 94 Node(Script s) { 95 mScript = s; 96 } 97 } 98 99 100 ScriptGroup(long id, RenderScript rs) { 101 super(id, rs); 102 } 103 104 /** 105 * Sets an input of the ScriptGroup. This specifies an 106 * Allocation to be used for kernels that require an input 107 * Allocation provided from outside of the ScriptGroup. 108 * 109 * @param s The ID of the kernel where the allocation should be 110 * connected. 111 * @param a The allocation to connect. 112 */ 113 public void setInput(Script.KernelID s, Allocation a) { 114 for (int ct=0; ct < mInputs.length; ct++) { 115 if (mInputs[ct].mKID == s) { 116 mInputs[ct].mAllocation = a; 117 if (!mUseIncSupp) { 118 mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a)); 119 } 120 return; 121 } 122 } 123 throw new RSIllegalArgumentException("Script not found"); 124 } 125 126 /** 127 * Sets an output of the ScriptGroup. This specifies an 128 * Allocation to be used for the kernels that require an output 129 * Allocation visible after the ScriptGroup is executed. 130 * 131 * @param s The ID of the kernel where the allocation should be 132 * connected. 133 * @param a The allocation to connect. 134 */ 135 public void setOutput(Script.KernelID s, Allocation a) { 136 for (int ct=0; ct < mOutputs.length; ct++) { 137 if (mOutputs[ct].mKID == s) { 138 mOutputs[ct].mAllocation = a; 139 if (!mUseIncSupp) { 140 mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a)); 141 } 142 return; 143 } 144 } 145 throw new RSIllegalArgumentException("Script not found"); 146 } 147 148 /** 149 * Execute the ScriptGroup. This will run all the kernels in 150 * the ScriptGroup. No internal connection results will be visible 151 * after execution of the ScriptGroup. 152 * 153 * If Incremental Support for intrinsics is needed, the execution 154 * will take the naive path: execute kernels one by one in the 155 * correct order. 156 */ 157 public void execute() { 158 if (!mUseIncSupp) { 159 mRS.nScriptGroupExecute(getID(mRS)); 160 } else { 161 // setup the allocations. 162 for (int ct=0; ct < mNodes.size(); ct++) { 163 Node n = mNodes.get(ct); 164 for (int ct2=0; ct2 < n.mOutputs.size(); ct2++) { 165 ConnectLine l = n.mOutputs.get(ct2); 166 if (l.mAllocation !=null) { 167 continue; 168 } 169 170 //create allocation here 171 Allocation alloc = Allocation.createTyped(mRS, l.mAllocationType, 172 Allocation.MipmapControl.MIPMAP_NONE, 173 Allocation.USAGE_SCRIPT); 174 175 l.mAllocation = alloc; 176 for (int ct3=ct2+1; ct3 < n.mOutputs.size(); ct3++) { 177 if (n.mOutputs.get(ct3).mFrom == l.mFrom) { 178 n.mOutputs.get(ct3).mAllocation = alloc; 179 } 180 } 181 } 182 } 183 for (Node node : mNodes) { 184 for (Script.KernelID kernel : node.mKernels) { 185 Allocation ain = null; 186 Allocation aout = null; 187 188 for (ConnectLine nodeInput : node.mInputs) { 189 if (nodeInput.mToK == kernel) { 190 ain = nodeInput.mAllocation; 191 } 192 } 193 194 for (IO sgInput : mInputs) { 195 if (sgInput.mKID == kernel) { 196 ain = sgInput.mAllocation; 197 } 198 } 199 200 for (ConnectLine nodeOutput : node.mOutputs) { 201 if (nodeOutput.mFrom == kernel) { 202 aout = nodeOutput.mAllocation; 203 } 204 } 205 206 for (IO sgOutput : mOutputs) { 207 if (sgOutput.mKID == kernel) { 208 aout = sgOutput.mAllocation; 209 } 210 } 211 212 kernel.mScript.forEach(kernel.mSlot, ain, aout, null); 213 } 214 } 215 } 216 } 217 218 219 /** 220 * Helper class to build a ScriptGroup. A ScriptGroup is 221 * created in two steps. 222 * <p> 223 * First, all kernels to be used by the ScriptGroup should be added. 224 * <p> 225 * Second, add connections between kernels. There are two types 226 * of connections: kernel to kernel and kernel to field. 227 * Kernel to kernel allows a kernel's output to be passed to 228 * another kernel as input. Kernel to field allows the output of 229 * one kernel to be bound as a script global. Kernel to kernel is 230 * higher performance and should be used where possible. 231 * <p> 232 * A ScriptGroup must contain a single directed acyclic graph (DAG); it 233 * cannot contain cycles. Currently, all kernels used in a ScriptGroup 234 * must come from different Script objects. Additionally, all kernels 235 * in a ScriptGroup must have at least one input, output, or internal 236 * connection. 237 * <p> 238 * Once all connections are made, a call to {@link #create} will 239 * return the ScriptGroup object. 240 * 241 */ 242 public static final class Builder { 243 private RenderScript mRS; 244 private ArrayList<Node> mNodes = new ArrayList<Node>(); 245 private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>(); 246 private int mKernelCount; 247 private boolean mUseIncSupp = false; 248 249 /** 250 * Create a Builder for generating a ScriptGroup. 251 * 252 * 253 * @param rs The RenderScript context. 254 */ 255 public Builder(RenderScript rs) { 256 mRS = rs; 257 } 258 259 // do a DFS from original node, looking for original node 260 // any cycle that could be created must contain original node 261 private void validateCycle(Node target, Node original) { 262 for (int ct = 0; ct < target.mOutputs.size(); ct++) { 263 final ConnectLine cl = target.mOutputs.get(ct); 264 if (cl.mToK != null) { 265 Node tn = findNode(cl.mToK.mScript); 266 if (tn.equals(original)) { 267 throw new RSInvalidStateException("Loops in group not allowed."); 268 } 269 validateCycle(tn, original); 270 } 271 if (cl.mToF != null) { 272 Node tn = findNode(cl.mToF.mScript); 273 if (tn.equals(original)) { 274 throw new RSInvalidStateException("Loops in group not allowed."); 275 } 276 validateCycle(tn, original); 277 } 278 } 279 } 280 281 private void mergeDAGs(int valueUsed, int valueKilled) { 282 for (int ct=0; ct < mNodes.size(); ct++) { 283 if (mNodes.get(ct).dagNumber == valueKilled) 284 mNodes.get(ct).dagNumber = valueUsed; 285 } 286 } 287 288 private void validateDAGRecurse(Node n, int dagNumber) { 289 // combine DAGs if this node has been seen already 290 if (n.dagNumber != 0 && n.dagNumber != dagNumber) { 291 mergeDAGs(n.dagNumber, dagNumber); 292 return; 293 } 294 295 n.dagNumber = dagNumber; 296 for (int ct=0; ct < n.mOutputs.size(); ct++) { 297 final ConnectLine cl = n.mOutputs.get(ct); 298 if (cl.mToK != null) { 299 Node tn = findNode(cl.mToK.mScript); 300 validateDAGRecurse(tn, dagNumber); 301 } 302 if (cl.mToF != null) { 303 Node tn = findNode(cl.mToF.mScript); 304 validateDAGRecurse(tn, dagNumber); 305 } 306 } 307 } 308 309 private void validateDAG() { 310 for (int ct=0; ct < mNodes.size(); ct++) { 311 Node n = mNodes.get(ct); 312 if (n.mInputs.size() == 0) { 313 if (n.mOutputs.size() == 0 && mNodes.size() > 1) { 314 throw new RSInvalidStateException("Groups cannot contain unconnected scripts"); 315 } 316 validateDAGRecurse(n, ct+1); 317 } 318 } 319 int dagNumber = mNodes.get(0).dagNumber; 320 for (int ct=0; ct < mNodes.size(); ct++) { 321 if (mNodes.get(ct).dagNumber != dagNumber) { 322 throw new RSInvalidStateException("Multiple DAGs in group not allowed."); 323 } 324 } 325 } 326 327 private Node findNode(Script s) { 328 for (int ct=0; ct < mNodes.size(); ct++) { 329 if (s == mNodes.get(ct).mScript) { 330 return mNodes.get(ct); 331 } 332 } 333 return null; 334 } 335 336 private Node findNode(Script.KernelID k) { 337 for (int ct=0; ct < mNodes.size(); ct++) { 338 Node n = mNodes.get(ct); 339 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) { 340 if (k == n.mKernels.get(ct2)) { 341 return n; 342 } 343 } 344 } 345 return null; 346 } 347 348 /** 349 * Adds a Kernel to the group. 350 * 351 * 352 * @param k The kernel to add. 353 * 354 * @return Builder Returns this. 355 */ 356 public Builder addKernel(Script.KernelID k) { 357 if (mLines.size() != 0) { 358 throw new RSInvalidStateException( 359 "Kernels may not be added once connections exist."); 360 } 361 if (k.mScript.isIncSupp()) { 362 mUseIncSupp = true; 363 } 364 //android.util.Log.v("RSR", "addKernel 1 k=" + k); 365 if (findNode(k) != null) { 366 return this; 367 } 368 //android.util.Log.v("RSR", "addKernel 2 "); 369 mKernelCount++; 370 Node n = findNode(k.mScript); 371 if (n == null) { 372 //android.util.Log.v("RSR", "addKernel 3 "); 373 n = new Node(k.mScript); 374 mNodes.add(n); 375 } 376 n.mKernels.add(k); 377 return this; 378 } 379 380 /** 381 * Adds a connection to the group. 382 * 383 * 384 * @param t The type of the connection. This is used to 385 * determine the kernel launch sizes on the source side 386 * of this connection. 387 * @param from The source for the connection. 388 * @param to The destination of the connection. 389 * 390 * @return Builder Returns this 391 */ 392 public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) { 393 //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to); 394 Node nf = findNode(from); 395 if (nf == null) { 396 throw new RSInvalidStateException("From script not found."); 397 } 398 399 Node nt = findNode(to.mScript); 400 if (nt == null) { 401 throw new RSInvalidStateException("To script not found."); 402 } 403 404 ConnectLine cl = new ConnectLine(t, from, to); 405 mLines.add(new ConnectLine(t, from, to)); 406 407 nf.mOutputs.add(cl); 408 nt.mInputs.add(cl); 409 410 validateCycle(nf, nf); 411 return this; 412 } 413 414 /** 415 * Adds a connection to the group. 416 * 417 * 418 * @param t The type of the connection. This is used to 419 * determine the kernel launch sizes for both sides of 420 * this connection. 421 * @param from The source for the connection. 422 * @param to The destination of the connection. 423 * 424 * @return Builder Returns this 425 */ 426 public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) { 427 //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to); 428 Node nf = findNode(from); 429 if (nf == null) { 430 throw new RSInvalidStateException("From script not found."); 431 } 432 433 Node nt = findNode(to); 434 if (nt == null) { 435 throw new RSInvalidStateException("To script not found."); 436 } 437 438 ConnectLine cl = new ConnectLine(t, from, to); 439 mLines.add(new ConnectLine(t, from, to)); 440 441 nf.mOutputs.add(cl); 442 nt.mInputs.add(cl); 443 444 validateCycle(nf, nf); 445 return this; 446 } 447 448 /** 449 * Calculate the order of each node. 450 * 451 * 452 * @return Success or Fail 453 */ 454 private boolean calcOrderRecurse(Node node0, int depth) { 455 node0.mSeen = true; 456 if (node0.mOrder < depth) { 457 node0.mOrder = depth; 458 } 459 boolean ret = true; 460 461 for (ConnectLine link : node0.mOutputs) { 462 Node node1 = null; 463 if (link.mToF != null) { 464 node1 = findNode(link.mToF.mScript); 465 } else { 466 node1 = findNode(link.mToK.mScript); 467 } 468 if (node1.mSeen) { 469 return false; 470 } 471 ret &= calcOrderRecurse(node1, node0.mOrder + 1); 472 } 473 474 return ret; 475 } 476 477 private boolean calcOrder() { 478 boolean ret = true; 479 for (Node n0 : mNodes) { 480 if (n0.mInputs.size() == 0) { 481 for (Node n1 : mNodes) { 482 n1.mSeen = false; 483 } 484 ret &= calcOrderRecurse(n0, 1); 485 } 486 } 487 488 Collections.sort(mNodes, new Comparator<Node>() { 489 public int compare(Node n1, Node n2) { 490 return n1.mOrder - n2.mOrder; 491 } 492 }); 493 494 return ret; 495 } 496 497 /** 498 * Creates the Script group. 499 * 500 * 501 * @return ScriptGroup The new ScriptGroup 502 */ 503 public ScriptGroup create() { 504 505 if (mNodes.size() == 0) { 506 throw new RSInvalidStateException("Empty script groups are not allowed"); 507 } 508 509 // reset DAG numbers in case we're building a second group 510 for (int ct=0; ct < mNodes.size(); ct++) { 511 mNodes.get(ct).dagNumber = 0; 512 } 513 validateDAG(); 514 515 ArrayList<IO> inputs = new ArrayList<IO>(); 516 ArrayList<IO> outputs = new ArrayList<IO>(); 517 518 long[] kernels = new long[mKernelCount]; 519 int idx = 0; 520 for (int ct=0; ct < mNodes.size(); ct++) { 521 Node n = mNodes.get(ct); 522 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) { 523 final Script.KernelID kid = n.mKernels.get(ct2); 524 kernels[idx++] = kid.getID(mRS); 525 526 boolean hasInput = false; 527 boolean hasOutput = false; 528 for (int ct3=0; ct3 < n.mInputs.size(); ct3++) { 529 if (n.mInputs.get(ct3).mToK == kid) { 530 hasInput = true; 531 } 532 } 533 for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) { 534 if (n.mOutputs.get(ct3).mFrom == kid) { 535 hasOutput = true; 536 } 537 } 538 if (!hasInput) { 539 inputs.add(new IO(kid)); 540 } 541 if (!hasOutput) { 542 outputs.add(new IO(kid)); 543 } 544 } 545 } 546 if (idx != mKernelCount) { 547 throw new RSRuntimeException("Count mismatch, should not happen."); 548 } 549 550 long id = 0; 551 if (!mUseIncSupp) { 552 long[] src = new long[mLines.size()]; 553 long[] dstk = new long[mLines.size()]; 554 long[] dstf = new long[mLines.size()]; 555 long[] types = new long[mLines.size()]; 556 557 for (int ct=0; ct < mLines.size(); ct++) { 558 ConnectLine cl = mLines.get(ct); 559 src[ct] = cl.mFrom.getID(mRS); 560 if (cl.mToK != null) { 561 dstk[ct] = cl.mToK.getID(mRS); 562 } 563 if (cl.mToF != null) { 564 dstf[ct] = cl.mToF.getID(mRS); 565 } 566 types[ct] = cl.mAllocationType.getID(mRS); 567 } 568 id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types); 569 if (id == 0) { 570 throw new RSRuntimeException("Object creation error, should not happen."); 571 } 572 } else { 573 //Calculate the order of the DAG so that script can run one after another. 574 calcOrder(); 575 } 576 577 ScriptGroup sg = new ScriptGroup(id, mRS); 578 sg.mOutputs = new IO[outputs.size()]; 579 for (int ct=0; ct < outputs.size(); ct++) { 580 sg.mOutputs[ct] = outputs.get(ct); 581 } 582 583 sg.mInputs = new IO[inputs.size()]; 584 for (int ct=0; ct < inputs.size(); ct++) { 585 sg.mInputs[ct] = inputs.get(ct); 586 } 587 sg.mNodes = mNodes; 588 sg.mUseIncSupp = mUseIncSupp; 589 return sg; 590 } 591 592 } 593 594 595} 596 597 598