ScriptGroup.java revision 6f5555db1af436bb5aad430e6e00aa5b69d5ca6c
/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17package android.support.v8.renderscript;
18
19import java.lang.reflect.Method;
20import java.util.ArrayList;
21import java.util.Collections;
22import java.util.Comparator;
23
24/**
25 * ScriptGroup creates a group of kernels that are executed
26 * together with one execution call as if they were a single kernel.
27 * The kernels may be connected internally or to an external allocation.
28 * The intermediate results for internal connections are not observable
29 * after the execution of the script.
30 * <p>
31 * External connections are grouped into inputs and outputs.
32 * All outputs are produced by a script kernel and placed into a
33 * user-supplied allocation. Inputs provide the input of a kernel.
34 * Inputs bound to script globals are set directly upon the script.
35 * <p>
36 * A ScriptGroup must contain at least one kernel. A ScriptGroup
37 * must contain only a single directed acyclic graph (DAG) of
38 * script kernels and connections. Attempting to create a
39 * ScriptGroup with multiple DAGs or attempting to create
40 * a cycle within a ScriptGroup will throw an exception.
41 * <p>
42 * Currently, all kernels in a ScriptGroup must be from separate
43 * Script objects. Attempting to use multiple kernels from the same
44 * Script object will result in an
45 * {@link android.support.v8.renderscript.RSInvalidStateException}.
46 *
47 **/
48public class ScriptGroup extends BaseObj {
49    IO mOutputs[];
50    IO mInputs[];
51    private boolean mUseIncSupp = false;
52    private ArrayList<Node> mNodes = new ArrayList<Node>();
53
54    static class IO {
55        Script.KernelID mKID;
56        Allocation mAllocation;
57
58        IO(Script.KernelID s) {
59            mKID = s;
60        }
61    }
62
63    static class ConnectLine {
64        ConnectLine(Type t, Script.KernelID from, Script.KernelID to) {
65            mFrom = from;
66            mToK = to;
67            mAllocationType = t;
68        }
69
70        ConnectLine(Type t, Script.KernelID from, Script.FieldID to) {
71            mFrom = from;
72            mToF = to;
73            mAllocationType = t;
74        }
75
76        Script.FieldID mToF;
77        Script.KernelID mToK;
78        Script.KernelID mFrom;
79        Type mAllocationType;
80        Allocation mAllocation;
81    }
82
83    static class Node {
84        Script mScript;
85        ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>();
86        ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>();
87        ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>();
88        int dagNumber;
89        boolean mSeen;
90        int mOrder;
91
92        Node mNext;
93
94        Node(Script s) {
95            mScript = s;
96        }
97    }
98
99
100    ScriptGroup(long id, RenderScript rs) {
101        super(id, rs);
102    }
103
104    /**
105     * Sets an input of the ScriptGroup. This specifies an
106     * Allocation to be used for kernels that require an input
107     * Allocation provided from outside of the ScriptGroup.
108     *
109     * @param s The ID of the kernel where the allocation should be
110     *          connected.
111     * @param a The allocation to connect.
112     */
113    public void setInput(Script.KernelID s, Allocation a) {
114        for (int ct=0; ct < mInputs.length; ct++) {
115            if (mInputs[ct].mKID == s) {
116                mInputs[ct].mAllocation = a;
117                if (!mUseIncSupp) {
118                    mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a));
119                }
120                return;
121            }
122        }
123        throw new RSIllegalArgumentException("Script not found");
124    }
125
126    /**
127     * Sets an output of the ScriptGroup. This specifies an
128     * Allocation to be used for the kernels that require an output
129     * Allocation visible after the ScriptGroup is executed.
130     *
131     * @param s The ID of the kernel where the allocation should be
132     *          connected.
133     * @param a The allocation to connect.
134     */
135    public void setOutput(Script.KernelID s, Allocation a) {
136        for (int ct=0; ct < mOutputs.length; ct++) {
137            if (mOutputs[ct].mKID == s) {
138                mOutputs[ct].mAllocation = a;
139                if (!mUseIncSupp) {
140                    mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a));
141                }
142                return;
143            }
144        }
145        throw new RSIllegalArgumentException("Script not found");
146    }
147
148    /**
149     * Execute the ScriptGroup.  This will run all the kernels in
150     * the ScriptGroup.  No internal connection results will be visible
151     * after execution of the ScriptGroup.
152     *
153     * If Incremental Support for intrinsics is needed, the execution
154     * will take the naive path: execute kernels one by one in the
155     * correct order.
156     */
157    public void execute() {
158        if (!mUseIncSupp) {
159            mRS.nScriptGroupExecute(getID(mRS));
160        } else {
161            // setup the allocations.
162            for (int ct=0; ct < mNodes.size(); ct++) {
163                Node n = mNodes.get(ct);
164                for (int ct2=0; ct2 < n.mOutputs.size(); ct2++) {
165                    ConnectLine l = n.mOutputs.get(ct2);
166                    if (l.mAllocation !=null) {
167                        continue;
168                    }
169
170                    //create allocation here
171                    Allocation alloc = Allocation.createTyped(mRS, l.mAllocationType,
172                                                              Allocation.MipmapControl.MIPMAP_NONE,
173                                                              Allocation.USAGE_SCRIPT);
174
175                    l.mAllocation = alloc;
176                    for (int ct3=ct2+1; ct3 < n.mOutputs.size(); ct3++) {
177                        if (n.mOutputs.get(ct3).mFrom == l.mFrom) {
178                            n.mOutputs.get(ct3).mAllocation = alloc;
179                        }
180                    }
181                }
182            }
183            for (Node node : mNodes) {
184                for (Script.KernelID kernel : node.mKernels) {
185                    Allocation ain  = null;
186                    Allocation aout = null;
187
188                    for (ConnectLine nodeInput : node.mInputs) {
189                        if (nodeInput.mToK == kernel) {
190                            ain = nodeInput.mAllocation;
191                        }
192                    }
193
194                    for (IO sgInput : mInputs) {
195                        if (sgInput.mKID == kernel) {
196                            ain = sgInput.mAllocation;
197                        }
198                    }
199
200                    for (ConnectLine nodeOutput : node.mOutputs) {
201                        if (nodeOutput.mFrom == kernel) {
202                            aout = nodeOutput.mAllocation;
203                        }
204                    }
205
206                    for (IO sgOutput : mOutputs) {
207                        if (sgOutput.mKID == kernel) {
208                            aout = sgOutput.mAllocation;
209                        }
210                    }
211
212                    kernel.mScript.forEach(kernel.mSlot, ain, aout, null);
213                }
214            }
215        }
216    }
217
218
219    /**
220     * Helper class to build a ScriptGroup. A ScriptGroup is
221     * created in two steps.
222     * <p>
223     * First, all kernels to be used by the ScriptGroup should be added.
224     * <p>
225     * Second, add connections between kernels. There are two types
226     * of connections: kernel to kernel and kernel to field.
227     * Kernel to kernel allows a kernel's output to be passed to
228     * another kernel as input. Kernel to field allows the output of
229     * one kernel to be bound as a script global. Kernel to kernel is
230     * higher performance and should be used where possible.
231     * <p>
232     * A ScriptGroup must contain a single directed acyclic graph (DAG); it
233     * cannot contain cycles. Currently, all kernels used in a ScriptGroup
234     * must come from different Script objects. Additionally, all kernels
235     * in a ScriptGroup must have at least one input, output, or internal
236     * connection.
237     * <p>
238     * Once all connections are made, a call to {@link #create} will
239     * return the ScriptGroup object.
240     *
241     */
242    public static final class Builder {
243        private RenderScript mRS;
244        private ArrayList<Node> mNodes = new ArrayList<Node>();
245        private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>();
246        private int mKernelCount;
247        private boolean mUseIncSupp = false;
248
249        /**
250         * Create a Builder for generating a ScriptGroup.
251         *
252         *
253         * @param rs The RenderScript context.
254         */
255        public Builder(RenderScript rs) {
256            mRS = rs;
257        }
258
259        // do a DFS from original node, looking for original node
260        // any cycle that could be created must contain original node
261        private void validateCycle(Node target, Node original) {
262            for (int ct = 0; ct < target.mOutputs.size(); ct++) {
263                final ConnectLine cl = target.mOutputs.get(ct);
264                if (cl.mToK != null) {
265                    Node tn = findNode(cl.mToK.mScript);
266                    if (tn.equals(original)) {
267                        throw new RSInvalidStateException("Loops in group not allowed.");
268                    }
269                    validateCycle(tn, original);
270                }
271                if (cl.mToF != null) {
272                    Node tn = findNode(cl.mToF.mScript);
273                    if (tn.equals(original)) {
274                        throw new RSInvalidStateException("Loops in group not allowed.");
275                    }
276                    validateCycle(tn, original);
277                }
278            }
279        }
280
281        private void mergeDAGs(int valueUsed, int valueKilled) {
282            for (int ct=0; ct < mNodes.size(); ct++) {
283                if (mNodes.get(ct).dagNumber == valueKilled)
284                    mNodes.get(ct).dagNumber = valueUsed;
285            }
286        }
287
288        private void validateDAGRecurse(Node n, int dagNumber) {
289            // combine DAGs if this node has been seen already
290            if (n.dagNumber != 0 && n.dagNumber != dagNumber) {
291                mergeDAGs(n.dagNumber, dagNumber);
292                return;
293            }
294
295            n.dagNumber = dagNumber;
296            for (int ct=0; ct < n.mOutputs.size(); ct++) {
297                final ConnectLine cl = n.mOutputs.get(ct);
298                if (cl.mToK != null) {
299                    Node tn = findNode(cl.mToK.mScript);
300                    validateDAGRecurse(tn, dagNumber);
301                }
302                if (cl.mToF != null) {
303                    Node tn = findNode(cl.mToF.mScript);
304                    validateDAGRecurse(tn, dagNumber);
305                }
306            }
307        }
308
309        private void validateDAG() {
310            for (int ct=0; ct < mNodes.size(); ct++) {
311                Node n = mNodes.get(ct);
312                if (n.mInputs.size() == 0) {
313                    if (n.mOutputs.size() == 0 && mNodes.size() > 1) {
314                        throw new RSInvalidStateException("Groups cannot contain unconnected scripts");
315                    }
316                    validateDAGRecurse(n, ct+1);
317                }
318            }
319            int dagNumber = mNodes.get(0).dagNumber;
320            for (int ct=0; ct < mNodes.size(); ct++) {
321                if (mNodes.get(ct).dagNumber != dagNumber) {
322                    throw new RSInvalidStateException("Multiple DAGs in group not allowed.");
323                }
324            }
325        }
326
327        private Node findNode(Script s) {
328            for (int ct=0; ct < mNodes.size(); ct++) {
329                if (s == mNodes.get(ct).mScript) {
330                    return mNodes.get(ct);
331                }
332            }
333            return null;
334        }
335
336        private Node findNode(Script.KernelID k) {
337            for (int ct=0; ct < mNodes.size(); ct++) {
338                Node n = mNodes.get(ct);
339                for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
340                    if (k == n.mKernels.get(ct2)) {
341                        return n;
342                    }
343                }
344            }
345            return null;
346        }
347
348        /**
349         * Adds a Kernel to the group.
350         *
351         *
352         * @param k The kernel to add.
353         *
354         * @return Builder Returns this.
355         */
356        public Builder addKernel(Script.KernelID k) {
357            if (mLines.size() != 0) {
358                throw new RSInvalidStateException(
359                    "Kernels may not be added once connections exist.");
360            }
361            if (k.mScript.isIncSupp()) {
362                mUseIncSupp = true;
363            }
364            //android.util.Log.v("RSR", "addKernel 1 k=" + k);
365            if (findNode(k) != null) {
366                return this;
367            }
368            //android.util.Log.v("RSR", "addKernel 2 ");
369            mKernelCount++;
370            Node n = findNode(k.mScript);
371            if (n == null) {
372                //android.util.Log.v("RSR", "addKernel 3 ");
373                n = new Node(k.mScript);
374                mNodes.add(n);
375            }
376            n.mKernels.add(k);
377            return this;
378        }
379
380        /**
381         * Adds a connection to the group.
382         *
383         *
384         * @param t The type of the connection. This is used to
385         *          determine the kernel launch sizes on the source side
386         *          of this connection.
387         * @param from The source for the connection.
388         * @param to The destination of the connection.
389         *
390         * @return Builder Returns this
391         */
392        public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) {
393            //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
394            Node nf = findNode(from);
395            if (nf == null) {
396                throw new RSInvalidStateException("From script not found.");
397            }
398
399            Node nt = findNode(to.mScript);
400            if (nt == null) {
401                throw new RSInvalidStateException("To script not found.");
402            }
403
404            ConnectLine cl = new ConnectLine(t, from, to);
405            mLines.add(new ConnectLine(t, from, to));
406
407            nf.mOutputs.add(cl);
408            nt.mInputs.add(cl);
409
410            validateCycle(nf, nf);
411            return this;
412        }
413
414        /**
415         * Adds a connection to the group.
416         *
417         *
418         * @param t The type of the connection. This is used to
419         *          determine the kernel launch sizes for both sides of
420         *          this connection.
421         * @param from The source for the connection.
422         * @param to The destination of the connection.
423         *
424         * @return Builder Returns this
425         */
426        public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) {
427            //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
428            Node nf = findNode(from);
429            if (nf == null) {
430                throw new RSInvalidStateException("From script not found.");
431            }
432
433            Node nt = findNode(to);
434            if (nt == null) {
435                throw new RSInvalidStateException("To script not found.");
436            }
437
438            ConnectLine cl = new ConnectLine(t, from, to);
439            mLines.add(new ConnectLine(t, from, to));
440
441            nf.mOutputs.add(cl);
442            nt.mInputs.add(cl);
443
444            validateCycle(nf, nf);
445            return this;
446        }
447
448        /**
449         * Calculate the order of each node.
450         *
451         *
452         * @return Success or Fail
453         */
454        private boolean calcOrderRecurse(Node node0, int depth) {
455            node0.mSeen = true;
456            if (node0.mOrder < depth) {
457                node0.mOrder = depth;
458            }
459            boolean ret = true;
460
461            for (ConnectLine link : node0.mOutputs) {
462                Node node1 = null;
463                if (link.mToF != null) {
464                    node1 = findNode(link.mToF.mScript);
465                } else {
466                    node1 = findNode(link.mToK.mScript);
467                }
468                if (node1.mSeen) {
469                    return false;
470                }
471                ret &= calcOrderRecurse(node1, node0.mOrder + 1);
472            }
473
474            return ret;
475        }
476
477        private boolean calcOrder() {
478            boolean ret = true;
479            for (Node n0 : mNodes) {
480                if (n0.mInputs.size() == 0) {
481                    for (Node n1 : mNodes) {
482                        n1.mSeen = false;
483                    }
484                    ret &= calcOrderRecurse(n0, 1);
485                }
486            }
487
488            Collections.sort(mNodes, new Comparator<Node>() {
489                public int compare(Node n1, Node n2) {
490                    return n1.mOrder - n2.mOrder;
491                }
492            });
493
494            return ret;
495        }
496
497        /**
498         * Creates the Script group.
499         *
500         *
501         * @return ScriptGroup The new ScriptGroup
502         */
503        public ScriptGroup create() {
504
505            if (mNodes.size() == 0) {
506                throw new RSInvalidStateException("Empty script groups are not allowed");
507            }
508
509            // reset DAG numbers in case we're building a second group
510            for (int ct=0; ct < mNodes.size(); ct++) {
511                mNodes.get(ct).dagNumber = 0;
512            }
513            validateDAG();
514
515            ArrayList<IO> inputs = new ArrayList<IO>();
516            ArrayList<IO> outputs = new ArrayList<IO>();
517
518            long[] kernels = new long[mKernelCount];
519            int idx = 0;
520            for (int ct=0; ct < mNodes.size(); ct++) {
521                Node n = mNodes.get(ct);
522                for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
523                    final Script.KernelID kid = n.mKernels.get(ct2);
524                    kernels[idx++] = kid.getID(mRS);
525
526                    boolean hasInput = false;
527                    boolean hasOutput = false;
528                    for (int ct3=0; ct3 < n.mInputs.size(); ct3++) {
529                        if (n.mInputs.get(ct3).mToK == kid) {
530                            hasInput = true;
531                        }
532                    }
533                    for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) {
534                        if (n.mOutputs.get(ct3).mFrom == kid) {
535                            hasOutput = true;
536                        }
537                    }
538                    if (!hasInput) {
539                        inputs.add(new IO(kid));
540                    }
541                    if (!hasOutput) {
542                        outputs.add(new IO(kid));
543                    }
544                }
545            }
546            if (idx != mKernelCount) {
547                throw new RSRuntimeException("Count mismatch, should not happen.");
548            }
549
550            long id = 0;
551            if (!mUseIncSupp) {
552                long[] src = new long[mLines.size()];
553                long[] dstk = new long[mLines.size()];
554                long[] dstf = new long[mLines.size()];
555                long[] types = new long[mLines.size()];
556
557                for (int ct=0; ct < mLines.size(); ct++) {
558                    ConnectLine cl = mLines.get(ct);
559                    src[ct] = cl.mFrom.getID(mRS);
560                    if (cl.mToK != null) {
561                        dstk[ct] = cl.mToK.getID(mRS);
562                    }
563                    if (cl.mToF != null) {
564                        dstf[ct] = cl.mToF.getID(mRS);
565                    }
566                    types[ct] = cl.mAllocationType.getID(mRS);
567                }
568                id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types);
569                if (id == 0) {
570                    throw new RSRuntimeException("Object creation error, should not happen.");
571                }
572            } else {
573                //Calculate the order of the DAG so that script can run one after another.
574                calcOrder();
575            }
576
577            ScriptGroup sg = new ScriptGroup(id, mRS);
578            sg.mOutputs = new IO[outputs.size()];
579            for (int ct=0; ct < outputs.size(); ct++) {
580                sg.mOutputs[ct] = outputs.get(ct);
581            }
582
583            sg.mInputs = new IO[inputs.size()];
584            for (int ct=0; ct < inputs.size(); ct++) {
585                sg.mInputs[ct] = inputs.get(ct);
586            }
587            sg.mNodes = mNodes;
588            sg.mUseIncSupp = mUseIncSupp;
589            return sg;
590        }
591
592    }
593
594
595}
596
597
598