ScriptGroup.java revision 7d435ae5ba100be5710b685653cc351cab159c11
1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.support.v8.renderscript;
18
19import java.lang.reflect.Method;
20import java.util.ArrayList;
21
22/**
23 * ScriptGroup creates a group of kernels that are executed
24 * together with one execution call as if they were a single kernel.
25 * The kernels may be connected internally or to an external allocation.
26 * The intermediate results for internal connections are not observable
27 * after the execution of the script.
28 * <p>
29 * External connections are grouped into inputs and outputs.
30 * All outputs are produced by a script kernel and placed into a
31 * user-supplied allocation. Inputs provide the input of a kernel.
32 * Inputs bound to script globals are set directly upon the script.
33 * <p>
34 * A ScriptGroup must contain at least one kernel. A ScriptGroup
35 * must contain only a single directed acyclic graph (DAG) of
36 * script kernels and connections. Attempting to create a
37 * ScriptGroup with multiple DAGs or attempting to create
38 * a cycle within a ScriptGroup will throw an exception.
39 * <p>
40 * Currently, all kernels in a ScriptGroup must be from separate
41 * Script objects. Attempting to use multiple kernels from the same
42 * Script object will result in an {@link android.renderscript.RSInvalidStateException}.
43 *
44 **/
45public class ScriptGroup extends BaseObj {
46    IO mOutputs[];
47    IO mInputs[];
48
49    static class IO {
50        Script.KernelID mKID;
51        Allocation mAllocation;
52
53        IO(Script.KernelID s) {
54            mKID = s;
55        }
56    }
57
58    static class ConnectLine {
59        ConnectLine(Type t, Script.KernelID from, Script.KernelID to) {
60            mFrom = from;
61            mToK = to;
62            mAllocationType = t;
63        }
64
65        ConnectLine(Type t, Script.KernelID from, Script.FieldID to) {
66            mFrom = from;
67            mToF = to;
68            mAllocationType = t;
69        }
70
71        Script.FieldID mToF;
72        Script.KernelID mToK;
73        Script.KernelID mFrom;
74        Type mAllocationType;
75    }
76
77    static class Node {
78        Script mScript;
79        ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>();
80        ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>();
81        ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>();
82        int dagNumber;
83
84        Node mNext;
85
86        Node(Script s) {
87            mScript = s;
88        }
89    }
90
91
92    ScriptGroup(int id, RenderScript rs) {
93        super(id, rs);
94    }
95
96    /**
97     * Sets an input of the ScriptGroup. This specifies an
98     * Allocation to be used for kernels that require an input
99     * Allocation provided from outside of the ScriptGroup.
100     *
101     * @param s The ID of the kernel where the allocation should be
102     *          connected.
103     * @param a The allocation to connect.
104     */
105    public void setInput(Script.KernelID s, Allocation a) {
106        for (int ct=0; ct < mInputs.length; ct++) {
107            if (mInputs[ct].mKID == s) {
108                mInputs[ct].mAllocation = a;
109                mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a));
110                return;
111            }
112        }
113        throw new RSIllegalArgumentException("Script not found");
114    }
115
116    /**
117     * Sets an output of the ScriptGroup. This specifies an
118     * Allocation to be used for the kernels that require an output
119     * Allocation visible after the ScriptGroup is executed.
120     *
121     * @param s The ID of the kernel where the allocation should be
122     *          connected.
123     * @param a The allocation to connect.
124     */
125    public void setOutput(Script.KernelID s, Allocation a) {
126        for (int ct=0; ct < mOutputs.length; ct++) {
127            if (mOutputs[ct].mKID == s) {
128                mOutputs[ct].mAllocation = a;
129                mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a));
130                return;
131            }
132        }
133        throw new RSIllegalArgumentException("Script not found");
134    }
135
136    /**
137     * Execute the ScriptGroup.  This will run all the kernels in
138     * the ScriptGroup.  No internal connection results will be visible
139     * after execution of the ScriptGroup.
140     */
141    public void execute() {
142        mRS.nScriptGroupExecute(getID(mRS));
143    }
144
145
146    /**
147     * Helper class to build a ScriptGroup. A ScriptGroup is
148     * created in two steps.
149     * <p>
150     * First, all kernels to be used by the ScriptGroup should be added.
151     * <p>
152     * Second, add connections between kernels. There are two types
153     * of connections: kernel to kernel and kernel to field.
154     * Kernel to kernel allows a kernel's output to be passed to
155     * another kernel as input. Kernel to field allows the output of
156     * one kernel to be bound as a script global. Kernel to kernel is
157     * higher performance and should be used where possible.
158     * <p>
159     * A ScriptGroup must contain a single directed acyclic graph (DAG); it
160     * cannot contain cycles. Currently, all kernels used in a ScriptGroup
161     * must come from different Script objects. Additionally, all kernels
162     * in a ScriptGroup must have at least one input, output, or internal
163     * connection.
164     * <p>
165     * Once all connections are made, a call to {@link #create} will
166     * return the ScriptGroup object.
167     *
168     */
169    public static final class Builder {
170        private RenderScript mRS;
171        private ArrayList<Node> mNodes = new ArrayList<Node>();
172        private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>();
173        private int mKernelCount;
174
175        private ScriptGroupThunker.Builder mT;
176
177        /**
178         * Create a Builder for generating a ScriptGroup.
179         *
180         *
181         * @param rs The RenderScript context.
182         */
183        public Builder(RenderScript rs) {
184            if (rs.isNative) {
185                mT = new ScriptGroupThunker.Builder(rs);
186            }
187            mRS = rs;
188        }
189
190        // do a DFS from original node, looking for original node
191        // any cycle that could be created must contain original node
192        private void validateCycle(Node target, Node original) {
193            for (int ct = 0; ct < target.mOutputs.size(); ct++) {
194                final ConnectLine cl = target.mOutputs.get(ct);
195                if (cl.mToK != null) {
196                    Node tn = findNode(cl.mToK.mScript);
197                    if (tn.equals(original)) {
198                        throw new RSInvalidStateException("Loops in group not allowed.");
199                    }
200                    validateCycle(tn, original);
201                }
202                if (cl.mToF != null) {
203                    Node tn = findNode(cl.mToF.mScript);
204                    if (tn.equals(original)) {
205                        throw new RSInvalidStateException("Loops in group not allowed.");
206                    }
207                    validateCycle(tn, original);
208                }
209            }
210        }
211
212        private void mergeDAGs(int valueUsed, int valueKilled) {
213            for (int ct=0; ct < mNodes.size(); ct++) {
214                if (mNodes.get(ct).dagNumber == valueKilled)
215                    mNodes.get(ct).dagNumber = valueUsed;
216            }
217        }
218
219        private void validateDAGRecurse(Node n, int dagNumber) {
220            // combine DAGs if this node has been seen already
221            if (n.dagNumber != 0 && n.dagNumber != dagNumber) {
222                mergeDAGs(n.dagNumber, dagNumber);
223                return;
224            }
225
226            n.dagNumber = dagNumber;
227            for (int ct=0; ct < n.mOutputs.size(); ct++) {
228                final ConnectLine cl = n.mOutputs.get(ct);
229                if (cl.mToK != null) {
230                    Node tn = findNode(cl.mToK.mScript);
231                    validateDAGRecurse(tn, dagNumber);
232                }
233                if (cl.mToF != null) {
234                    Node tn = findNode(cl.mToF.mScript);
235                    validateDAGRecurse(tn, dagNumber);
236                }
237            }
238        }
239
240        private void validateDAG() {
241            for (int ct=0; ct < mNodes.size(); ct++) {
242                Node n = mNodes.get(ct);
243                if (n.mInputs.size() == 0) {
244                    if (n.mOutputs.size() == 0 && mNodes.size() > 1) {
245                        throw new RSInvalidStateException("Groups cannot contain unconnected scripts");
246                    }
247                    validateDAGRecurse(n, ct+1);
248                }
249            }
250            int dagNumber = mNodes.get(0).dagNumber;
251            for (int ct=0; ct < mNodes.size(); ct++) {
252                if (mNodes.get(ct).dagNumber != dagNumber) {
253                    throw new RSInvalidStateException("Multiple DAGs in group not allowed.");
254                }
255            }
256        }
257
258        private Node findNode(Script s) {
259            for (int ct=0; ct < mNodes.size(); ct++) {
260                if (s == mNodes.get(ct).mScript) {
261                    return mNodes.get(ct);
262                }
263            }
264            return null;
265        }
266
267        private Node findNode(Script.KernelID k) {
268            for (int ct=0; ct < mNodes.size(); ct++) {
269                Node n = mNodes.get(ct);
270                for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
271                    if (k == n.mKernels.get(ct2)) {
272                        return n;
273                    }
274                }
275            }
276            return null;
277        }
278
279        /**
280         * Adds a Kernel to the group.
281         *
282         *
283         * @param k The kernel to add.
284         *
285         * @return Builder Returns this.
286         */
287        public Builder addKernel(Script.KernelID k) {
288            if (mT != null) {
289                mT.addKernel(k);
290                return this;
291            }
292
293            if (mLines.size() != 0) {
294                throw new RSInvalidStateException(
295                    "Kernels may not be added once connections exist.");
296            }
297
298            //android.util.Log.v("RSR", "addKernel 1 k=" + k);
299            if (findNode(k) != null) {
300                return this;
301            }
302            //android.util.Log.v("RSR", "addKernel 2 ");
303            mKernelCount++;
304            Node n = findNode(k.mScript);
305            if (n == null) {
306                //android.util.Log.v("RSR", "addKernel 3 ");
307                n = new Node(k.mScript);
308                mNodes.add(n);
309            }
310            n.mKernels.add(k);
311            return this;
312        }
313
314        /**
315         * Adds a connection to the group.
316         *
317         *
318         * @param t The type of the connection. This is used to
319         *          determine the kernel launch sizes on the source side
320         *          of this connection.
321         * @param from The source for the connection.
322         * @param to The destination of the connection.
323         *
324         * @return Builder Returns this
325         */
326        public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) {
327            //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
328
329            if (mT != null) {
330                mT.addConnection(t, from, to);
331                return this;
332            }
333
334            Node nf = findNode(from);
335            if (nf == null) {
336                throw new RSInvalidStateException("From script not found.");
337            }
338
339            Node nt = findNode(to.mScript);
340            if (nt == null) {
341                throw new RSInvalidStateException("To script not found.");
342            }
343
344            ConnectLine cl = new ConnectLine(t, from, to);
345            mLines.add(new ConnectLine(t, from, to));
346
347            nf.mOutputs.add(cl);
348            nt.mInputs.add(cl);
349
350            validateCycle(nf, nf);
351            return this;
352        }
353
354        /**
355         * Adds a connection to the group.
356         *
357         *
358         * @param t The type of the connection. This is used to
359         *          determine the kernel launch sizes for both sides of
360         *          this connection.
361         * @param from The source for the connection.
362         * @param to The destination of the connection.
363         *
364         * @return Builder Returns this
365         */
366        public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) {
367            //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
368
369            if (mT != null) {
370                mT.addConnection(t, from, to);
371                return this;
372            }
373
374            Node nf = findNode(from);
375            if (nf == null) {
376                throw new RSInvalidStateException("From script not found.");
377            }
378
379            Node nt = findNode(to);
380            if (nt == null) {
381                throw new RSInvalidStateException("To script not found.");
382            }
383
384            ConnectLine cl = new ConnectLine(t, from, to);
385            mLines.add(new ConnectLine(t, from, to));
386
387            nf.mOutputs.add(cl);
388            nt.mInputs.add(cl);
389
390            validateCycle(nf, nf);
391            return this;
392        }
393
394
395
396        /**
397         * Creates the Script group.
398         *
399         *
400         * @return ScriptGroup The new ScriptGroup
401         */
402        public ScriptGroup create() {
403
404            if (mT != null) {
405                return mT.create();
406            }
407
408            if (mNodes.size() == 0) {
409                throw new RSInvalidStateException("Empty script groups are not allowed");
410            }
411
412            // reset DAG numbers in case we're building a second group
413            for (int ct=0; ct < mNodes.size(); ct++) {
414                mNodes.get(ct).dagNumber = 0;
415            }
416            validateDAG();
417
418            ArrayList<IO> inputs = new ArrayList<IO>();
419            ArrayList<IO> outputs = new ArrayList<IO>();
420
421            int[] kernels = new int[mKernelCount];
422            int idx = 0;
423            for (int ct=0; ct < mNodes.size(); ct++) {
424                Node n = mNodes.get(ct);
425                for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
426                    final Script.KernelID kid = n.mKernels.get(ct2);
427                    kernels[idx++] = kid.getID(mRS);
428
429                    boolean hasInput = false;
430                    boolean hasOutput = false;
431                    for (int ct3=0; ct3 < n.mInputs.size(); ct3++) {
432                        if (n.mInputs.get(ct3).mToK == kid) {
433                            hasInput = true;
434                        }
435                    }
436                    for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) {
437                        if (n.mOutputs.get(ct3).mFrom == kid) {
438                            hasOutput = true;
439                        }
440                    }
441                    if (!hasInput) {
442                        inputs.add(new IO(kid));
443                    }
444                    if (!hasOutput) {
445                        outputs.add(new IO(kid));
446                    }
447
448                }
449            }
450            if (idx != mKernelCount) {
451                throw new RSRuntimeException("Count mismatch, should not happen.");
452            }
453
454            int[] src = new int[mLines.size()];
455            int[] dstk = new int[mLines.size()];
456            int[] dstf = new int[mLines.size()];
457            int[] types = new int[mLines.size()];
458
459            for (int ct=0; ct < mLines.size(); ct++) {
460                ConnectLine cl = mLines.get(ct);
461                src[ct] = cl.mFrom.getID(mRS);
462                if (cl.mToK != null) {
463                    dstk[ct] = cl.mToK.getID(mRS);
464                }
465                if (cl.mToF != null) {
466                    dstf[ct] = cl.mToF.getID(mRS);
467                }
468                types[ct] = cl.mAllocationType.getID(mRS);
469            }
470
471            int id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types);
472            if (id == 0) {
473                throw new RSRuntimeException("Object creation error, should not happen.");
474            }
475
476            ScriptGroup sg = new ScriptGroup(id, mRS);
477            sg.mOutputs = new IO[outputs.size()];
478            for (int ct=0; ct < outputs.size(); ct++) {
479                sg.mOutputs[ct] = outputs.get(ct);
480            }
481
482            sg.mInputs = new IO[inputs.size()];
483            for (int ct=0; ct < inputs.size(); ct++) {
484                sg.mInputs[ct] = inputs.get(ct);
485            }
486
487            return sg;
488        }
489
490    }
491
492
493}
494
495
496