ScriptGroup.java revision 7d435ae5ba100be5710b685653cc351cab159c11
1/* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17package android.support.v8.renderscript; 18 19import java.lang.reflect.Method; 20import java.util.ArrayList; 21 22/** 23 * ScriptGroup creates a group of kernels that are executed 24 * together with one execution call as if they were a single kernel. 25 * The kernels may be connected internally or to an external allocation. 26 * The intermediate results for internal connections are not observable 27 * after the execution of the script. 28 * <p> 29 * External connections are grouped into inputs and outputs. 30 * All outputs are produced by a script kernel and placed into a 31 * user-supplied allocation. Inputs provide the input of a kernel. 32 * Inputs bound to script globals are set directly upon the script. 33 * <p> 34 * A ScriptGroup must contain at least one kernel. A ScriptGroup 35 * must contain only a single directed acyclic graph (DAG) of 36 * script kernels and connections. Attempting to create a 37 * ScriptGroup with multiple DAGs or attempting to create 38 * a cycle within a ScriptGroup will throw an exception. 39 * <p> 40 * Currently, all kernels in a ScriptGroup must be from separate 41 * Script objects. Attempting to use multiple kernels from the same 42 * Script object will result in an {@link android.renderscript.RSInvalidStateException}. 43 * 44 **/ 45public class ScriptGroup extends BaseObj { 46 IO mOutputs[]; 47 IO mInputs[]; 48 49 static class IO { 50 Script.KernelID mKID; 51 Allocation mAllocation; 52 53 IO(Script.KernelID s) { 54 mKID = s; 55 } 56 } 57 58 static class ConnectLine { 59 ConnectLine(Type t, Script.KernelID from, Script.KernelID to) { 60 mFrom = from; 61 mToK = to; 62 mAllocationType = t; 63 } 64 65 ConnectLine(Type t, Script.KernelID from, Script.FieldID to) { 66 mFrom = from; 67 mToF = to; 68 mAllocationType = t; 69 } 70 71 Script.FieldID mToF; 72 Script.KernelID mToK; 73 Script.KernelID mFrom; 74 Type mAllocationType; 75 } 76 77 static class Node { 78 Script mScript; 79 ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>(); 80 ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>(); 81 ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>(); 82 int dagNumber; 83 84 Node mNext; 85 86 Node(Script s) { 87 mScript = s; 88 } 89 } 90 91 92 ScriptGroup(int id, RenderScript rs) { 93 super(id, rs); 94 } 95 96 /** 97 * Sets an input of the ScriptGroup. This specifies an 98 * Allocation to be used for kernels that require an input 99 * Allocation provided from outside of the ScriptGroup. 100 * 101 * @param s The ID of the kernel where the allocation should be 102 * connected. 103 * @param a The allocation to connect. 104 */ 105 public void setInput(Script.KernelID s, Allocation a) { 106 for (int ct=0; ct < mInputs.length; ct++) { 107 if (mInputs[ct].mKID == s) { 108 mInputs[ct].mAllocation = a; 109 mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a)); 110 return; 111 } 112 } 113 throw new RSIllegalArgumentException("Script not found"); 114 } 115 116 /** 117 * Sets an output of the ScriptGroup. This specifies an 118 * Allocation to be used for the kernels that require an output 119 * Allocation visible after the ScriptGroup is executed. 120 * 121 * @param s The ID of the kernel where the allocation should be 122 * connected. 123 * @param a The allocation to connect. 124 */ 125 public void setOutput(Script.KernelID s, Allocation a) { 126 for (int ct=0; ct < mOutputs.length; ct++) { 127 if (mOutputs[ct].mKID == s) { 128 mOutputs[ct].mAllocation = a; 129 mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a)); 130 return; 131 } 132 } 133 throw new RSIllegalArgumentException("Script not found"); 134 } 135 136 /** 137 * Execute the ScriptGroup. This will run all the kernels in 138 * the ScriptGroup. No internal connection results will be visible 139 * after execution of the ScriptGroup. 140 */ 141 public void execute() { 142 mRS.nScriptGroupExecute(getID(mRS)); 143 } 144 145 146 /** 147 * Helper class to build a ScriptGroup. A ScriptGroup is 148 * created in two steps. 149 * <p> 150 * First, all kernels to be used by the ScriptGroup should be added. 151 * <p> 152 * Second, add connections between kernels. There are two types 153 * of connections: kernel to kernel and kernel to field. 154 * Kernel to kernel allows a kernel's output to be passed to 155 * another kernel as input. Kernel to field allows the output of 156 * one kernel to be bound as a script global. Kernel to kernel is 157 * higher performance and should be used where possible. 158 * <p> 159 * A ScriptGroup must contain a single directed acyclic graph (DAG); it 160 * cannot contain cycles. Currently, all kernels used in a ScriptGroup 161 * must come from different Script objects. Additionally, all kernels 162 * in a ScriptGroup must have at least one input, output, or internal 163 * connection. 164 * <p> 165 * Once all connections are made, a call to {@link #create} will 166 * return the ScriptGroup object. 167 * 168 */ 169 public static final class Builder { 170 private RenderScript mRS; 171 private ArrayList<Node> mNodes = new ArrayList<Node>(); 172 private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>(); 173 private int mKernelCount; 174 175 private ScriptGroupThunker.Builder mT; 176 177 /** 178 * Create a Builder for generating a ScriptGroup. 179 * 180 * 181 * @param rs The RenderScript context. 182 */ 183 public Builder(RenderScript rs) { 184 if (rs.isNative) { 185 mT = new ScriptGroupThunker.Builder(rs); 186 } 187 mRS = rs; 188 } 189 190 // do a DFS from original node, looking for original node 191 // any cycle that could be created must contain original node 192 private void validateCycle(Node target, Node original) { 193 for (int ct = 0; ct < target.mOutputs.size(); ct++) { 194 final ConnectLine cl = target.mOutputs.get(ct); 195 if (cl.mToK != null) { 196 Node tn = findNode(cl.mToK.mScript); 197 if (tn.equals(original)) { 198 throw new RSInvalidStateException("Loops in group not allowed."); 199 } 200 validateCycle(tn, original); 201 } 202 if (cl.mToF != null) { 203 Node tn = findNode(cl.mToF.mScript); 204 if (tn.equals(original)) { 205 throw new RSInvalidStateException("Loops in group not allowed."); 206 } 207 validateCycle(tn, original); 208 } 209 } 210 } 211 212 private void mergeDAGs(int valueUsed, int valueKilled) { 213 for (int ct=0; ct < mNodes.size(); ct++) { 214 if (mNodes.get(ct).dagNumber == valueKilled) 215 mNodes.get(ct).dagNumber = valueUsed; 216 } 217 } 218 219 private void validateDAGRecurse(Node n, int dagNumber) { 220 // combine DAGs if this node has been seen already 221 if (n.dagNumber != 0 && n.dagNumber != dagNumber) { 222 mergeDAGs(n.dagNumber, dagNumber); 223 return; 224 } 225 226 n.dagNumber = dagNumber; 227 for (int ct=0; ct < n.mOutputs.size(); ct++) { 228 final ConnectLine cl = n.mOutputs.get(ct); 229 if (cl.mToK != null) { 230 Node tn = findNode(cl.mToK.mScript); 231 validateDAGRecurse(tn, dagNumber); 232 } 233 if (cl.mToF != null) { 234 Node tn = findNode(cl.mToF.mScript); 235 validateDAGRecurse(tn, dagNumber); 236 } 237 } 238 } 239 240 private void validateDAG() { 241 for (int ct=0; ct < mNodes.size(); ct++) { 242 Node n = mNodes.get(ct); 243 if (n.mInputs.size() == 0) { 244 if (n.mOutputs.size() == 0 && mNodes.size() > 1) { 245 throw new RSInvalidStateException("Groups cannot contain unconnected scripts"); 246 } 247 validateDAGRecurse(n, ct+1); 248 } 249 } 250 int dagNumber = mNodes.get(0).dagNumber; 251 for (int ct=0; ct < mNodes.size(); ct++) { 252 if (mNodes.get(ct).dagNumber != dagNumber) { 253 throw new RSInvalidStateException("Multiple DAGs in group not allowed."); 254 } 255 } 256 } 257 258 private Node findNode(Script s) { 259 for (int ct=0; ct < mNodes.size(); ct++) { 260 if (s == mNodes.get(ct).mScript) { 261 return mNodes.get(ct); 262 } 263 } 264 return null; 265 } 266 267 private Node findNode(Script.KernelID k) { 268 for (int ct=0; ct < mNodes.size(); ct++) { 269 Node n = mNodes.get(ct); 270 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) { 271 if (k == n.mKernels.get(ct2)) { 272 return n; 273 } 274 } 275 } 276 return null; 277 } 278 279 /** 280 * Adds a Kernel to the group. 281 * 282 * 283 * @param k The kernel to add. 284 * 285 * @return Builder Returns this. 286 */ 287 public Builder addKernel(Script.KernelID k) { 288 if (mT != null) { 289 mT.addKernel(k); 290 return this; 291 } 292 293 if (mLines.size() != 0) { 294 throw new RSInvalidStateException( 295 "Kernels may not be added once connections exist."); 296 } 297 298 //android.util.Log.v("RSR", "addKernel 1 k=" + k); 299 if (findNode(k) != null) { 300 return this; 301 } 302 //android.util.Log.v("RSR", "addKernel 2 "); 303 mKernelCount++; 304 Node n = findNode(k.mScript); 305 if (n == null) { 306 //android.util.Log.v("RSR", "addKernel 3 "); 307 n = new Node(k.mScript); 308 mNodes.add(n); 309 } 310 n.mKernels.add(k); 311 return this; 312 } 313 314 /** 315 * Adds a connection to the group. 316 * 317 * 318 * @param t The type of the connection. This is used to 319 * determine the kernel launch sizes on the source side 320 * of this connection. 321 * @param from The source for the connection. 322 * @param to The destination of the connection. 323 * 324 * @return Builder Returns this 325 */ 326 public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) { 327 //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to); 328 329 if (mT != null) { 330 mT.addConnection(t, from, to); 331 return this; 332 } 333 334 Node nf = findNode(from); 335 if (nf == null) { 336 throw new RSInvalidStateException("From script not found."); 337 } 338 339 Node nt = findNode(to.mScript); 340 if (nt == null) { 341 throw new RSInvalidStateException("To script not found."); 342 } 343 344 ConnectLine cl = new ConnectLine(t, from, to); 345 mLines.add(new ConnectLine(t, from, to)); 346 347 nf.mOutputs.add(cl); 348 nt.mInputs.add(cl); 349 350 validateCycle(nf, nf); 351 return this; 352 } 353 354 /** 355 * Adds a connection to the group. 356 * 357 * 358 * @param t The type of the connection. This is used to 359 * determine the kernel launch sizes for both sides of 360 * this connection. 361 * @param from The source for the connection. 362 * @param to The destination of the connection. 363 * 364 * @return Builder Returns this 365 */ 366 public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) { 367 //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to); 368 369 if (mT != null) { 370 mT.addConnection(t, from, to); 371 return this; 372 } 373 374 Node nf = findNode(from); 375 if (nf == null) { 376 throw new RSInvalidStateException("From script not found."); 377 } 378 379 Node nt = findNode(to); 380 if (nt == null) { 381 throw new RSInvalidStateException("To script not found."); 382 } 383 384 ConnectLine cl = new ConnectLine(t, from, to); 385 mLines.add(new ConnectLine(t, from, to)); 386 387 nf.mOutputs.add(cl); 388 nt.mInputs.add(cl); 389 390 validateCycle(nf, nf); 391 return this; 392 } 393 394 395 396 /** 397 * Creates the Script group. 398 * 399 * 400 * @return ScriptGroup The new ScriptGroup 401 */ 402 public ScriptGroup create() { 403 404 if (mT != null) { 405 return mT.create(); 406 } 407 408 if (mNodes.size() == 0) { 409 throw new RSInvalidStateException("Empty script groups are not allowed"); 410 } 411 412 // reset DAG numbers in case we're building a second group 413 for (int ct=0; ct < mNodes.size(); ct++) { 414 mNodes.get(ct).dagNumber = 0; 415 } 416 validateDAG(); 417 418 ArrayList<IO> inputs = new ArrayList<IO>(); 419 ArrayList<IO> outputs = new ArrayList<IO>(); 420 421 int[] kernels = new int[mKernelCount]; 422 int idx = 0; 423 for (int ct=0; ct < mNodes.size(); ct++) { 424 Node n = mNodes.get(ct); 425 for (int ct2=0; ct2 < n.mKernels.size(); ct2++) { 426 final Script.KernelID kid = n.mKernels.get(ct2); 427 kernels[idx++] = kid.getID(mRS); 428 429 boolean hasInput = false; 430 boolean hasOutput = false; 431 for (int ct3=0; ct3 < n.mInputs.size(); ct3++) { 432 if (n.mInputs.get(ct3).mToK == kid) { 433 hasInput = true; 434 } 435 } 436 for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) { 437 if (n.mOutputs.get(ct3).mFrom == kid) { 438 hasOutput = true; 439 } 440 } 441 if (!hasInput) { 442 inputs.add(new IO(kid)); 443 } 444 if (!hasOutput) { 445 outputs.add(new IO(kid)); 446 } 447 448 } 449 } 450 if (idx != mKernelCount) { 451 throw new RSRuntimeException("Count mismatch, should not happen."); 452 } 453 454 int[] src = new int[mLines.size()]; 455 int[] dstk = new int[mLines.size()]; 456 int[] dstf = new int[mLines.size()]; 457 int[] types = new int[mLines.size()]; 458 459 for (int ct=0; ct < mLines.size(); ct++) { 460 ConnectLine cl = mLines.get(ct); 461 src[ct] = cl.mFrom.getID(mRS); 462 if (cl.mToK != null) { 463 dstk[ct] = cl.mToK.getID(mRS); 464 } 465 if (cl.mToF != null) { 466 dstf[ct] = cl.mToF.getID(mRS); 467 } 468 types[ct] = cl.mAllocationType.getID(mRS); 469 } 470 471 int id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types); 472 if (id == 0) { 473 throw new RSRuntimeException("Object creation error, should not happen."); 474 } 475 476 ScriptGroup sg = new ScriptGroup(id, mRS); 477 sg.mOutputs = new IO[outputs.size()]; 478 for (int ct=0; ct < outputs.size(); ct++) { 479 sg.mOutputs[ct] = outputs.get(ct); 480 } 481 482 sg.mInputs = new IO[inputs.size()]; 483 for (int ct=0; ct < inputs.size(); ct++) { 484 sg.mInputs[ct] = inputs.get(ct); 485 } 486 487 return sg; 488 } 489 490 } 491 492 493} 494 495 496