ScriptGroup.java revision 355707e4f665904e31d9f5fcff1e3921f7db8cdd
1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.renderscript;
18
19import java.lang.reflect.Method;
20import java.util.ArrayList;
21
22/**
23 * ScriptGroup creates a group of kernels that are executed
24 * together with one execution call as if they were a single kernel.
25 * The kernels may be connected internally or to an external allocation.
26 * The intermediate results for internal connections are not observable
27 * after the execution of the script.
28 * <p>
29 * External connections are grouped into inputs and outputs.
30 * All outputs are produced by a script kernel and placed into a
31 * user-supplied allocation. Inputs provide the input of a kernel.
32 * Inputs bound to script globals are set directly upon the script.
33 * <p>
34 * A ScriptGroup must contain at least one kernel. A ScriptGroup
35 * must contain only a single directed acyclic graph (DAG) of
36 * script kernels and connections. Attempting to create a
37 * ScriptGroup with multiple DAGs or attempting to create
38 * a cycle within a ScriptGroup will throw an exception.
39 * <p>
40 * Currently, all kernels in a ScriptGroup must be from separate
41 * Script objects. Attempting to use multiple kernels from the same
42 * Script object will result in an {@link android.renderscript.RSInvalidStateException}.
43 *
44 **/
45public final class ScriptGroup extends BaseObj {
46    IO mOutputs[];
47    IO mInputs[];
48
49    static class IO {
50        Script.KernelID mKID;
51        Allocation mAllocation;
52
53        IO(Script.KernelID s) {
54            mKID = s;
55        }
56    }
57
58    static class ConnectLine {
59        ConnectLine(Type t, Script.KernelID from, Script.KernelID to) {
60            mFrom = from;
61            mToK = to;
62            mAllocationType = t;
63        }
64
65        ConnectLine(Type t, Script.KernelID from, Script.FieldID to) {
66            mFrom = from;
67            mToF = to;
68            mAllocationType = t;
69        }
70
71        Script.FieldID mToF;
72        Script.KernelID mToK;
73        Script.KernelID mFrom;
74        Type mAllocationType;
75    }
76
77    static class Node {
78        Script mScript;
79        ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>();
80        ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>();
81        ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>();
82        int dagNumber;
83
84        Node mNext;
85
86        Node(Script s) {
87            mScript = s;
88        }
89    }
90
91
92    ScriptGroup(long id, RenderScript rs) {
93        super(id, rs);
94    }
95
96    /**
97     * Sets an input of the ScriptGroup. This specifies an
98     * Allocation to be used for kernels that require an input
99     * Allocation provided from outside of the ScriptGroup.
100     *
101     * @param s The ID of the kernel where the allocation should be
102     *          connected.
103     * @param a The allocation to connect.
104     */
105    public void setInput(Script.KernelID s, Allocation a) {
106        for (int ct=0; ct < mInputs.length; ct++) {
107            if (mInputs[ct].mKID == s) {
108                mInputs[ct].mAllocation = a;
109                mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a));
110                return;
111            }
112        }
113        throw new RSIllegalArgumentException("Script not found");
114    }
115
116    /**
117     * Sets an output of the ScriptGroup. This specifies an
118     * Allocation to be used for the kernels that require an output
119     * Allocation visible after the ScriptGroup is executed.
120     *
121     * @param s The ID of the kernel where the allocation should be
122     *          connected.
123     * @param a The allocation to connect.
124     */
125    public void setOutput(Script.KernelID s, Allocation a) {
126        for (int ct=0; ct < mOutputs.length; ct++) {
127            if (mOutputs[ct].mKID == s) {
128                mOutputs[ct].mAllocation = a;
129                mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a));
130                return;
131            }
132        }
133        throw new RSIllegalArgumentException("Script not found");
134    }
135
136    /**
137     * Execute the ScriptGroup.  This will run all the kernels in
138     * the ScriptGroup.  No internal connection results will be visible
139     * after execution of the ScriptGroup.
140     */
141    public void execute() {
142        mRS.nScriptGroupExecute(getID(mRS));
143    }
144
145
146    /**
147     * Helper class to build a ScriptGroup. A ScriptGroup is
148     * created in two steps.
149     * <p>
150     * First, all kernels to be used by the ScriptGroup should be added.
151     * <p>
152     * Second, add connections between kernels. There are two types
153     * of connections: kernel to kernel and kernel to field.
154     * Kernel to kernel allows a kernel's output to be passed to
155     * another kernel as input. Kernel to field allows the output of
156     * one kernel to be bound as a script global. Kernel to kernel is
157     * higher performance and should be used where possible.
158     * <p>
159     * A ScriptGroup must contain a single directed acyclic graph (DAG); it
160     * cannot contain cycles. Currently, all kernels used in a ScriptGroup
161     * must come from different Script objects. Additionally, all kernels
162     * in a ScriptGroup must have at least one input, output, or internal
163     * connection.
164     * <p>
165     * Once all connections are made, a call to {@link #create} will
166     * return the ScriptGroup object.
167     *
168     */
169    public static final class Builder {
170        private RenderScript mRS;
171        private ArrayList<Node> mNodes = new ArrayList<Node>();
172        private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>();
173        private int mKernelCount;
174
175        /**
176         * Create a Builder for generating a ScriptGroup.
177         *
178         *
179         * @param rs The RenderScript context.
180         */
181        public Builder(RenderScript rs) {
182            mRS = rs;
183        }
184
185        // do a DFS from original node, looking for original node
186        // any cycle that could be created must contain original node
187        private void validateCycle(Node target, Node original) {
188            for (int ct = 0; ct < target.mOutputs.size(); ct++) {
189                final ConnectLine cl = target.mOutputs.get(ct);
190                if (cl.mToK != null) {
191                    Node tn = findNode(cl.mToK.mScript);
192                    if (tn.equals(original)) {
193                        throw new RSInvalidStateException("Loops in group not allowed.");
194                    }
195                    validateCycle(tn, original);
196                }
197                if (cl.mToF != null) {
198                    Node tn = findNode(cl.mToF.mScript);
199                    if (tn.equals(original)) {
200                        throw new RSInvalidStateException("Loops in group not allowed.");
201                    }
202                    validateCycle(tn, original);
203                }
204            }
205        }
206
207        private void mergeDAGs(int valueUsed, int valueKilled) {
208            for (int ct=0; ct < mNodes.size(); ct++) {
209                if (mNodes.get(ct).dagNumber == valueKilled)
210                    mNodes.get(ct).dagNumber = valueUsed;
211            }
212        }
213
214        private void validateDAGRecurse(Node n, int dagNumber) {
215            // combine DAGs if this node has been seen already
216            if (n.dagNumber != 0 && n.dagNumber != dagNumber) {
217                mergeDAGs(n.dagNumber, dagNumber);
218                return;
219            }
220
221            n.dagNumber = dagNumber;
222            for (int ct=0; ct < n.mOutputs.size(); ct++) {
223                final ConnectLine cl = n.mOutputs.get(ct);
224                if (cl.mToK != null) {
225                    Node tn = findNode(cl.mToK.mScript);
226                    validateDAGRecurse(tn, dagNumber);
227                }
228                if (cl.mToF != null) {
229                    Node tn = findNode(cl.mToF.mScript);
230                    validateDAGRecurse(tn, dagNumber);
231                }
232            }
233        }
234
235        private void validateDAG() {
236            for (int ct=0; ct < mNodes.size(); ct++) {
237                Node n = mNodes.get(ct);
238                if (n.mInputs.size() == 0) {
239                    if (n.mOutputs.size() == 0 && mNodes.size() > 1) {
240                        throw new RSInvalidStateException("Groups cannot contain unconnected scripts");
241                    }
242                    validateDAGRecurse(n, ct+1);
243                }
244            }
245            int dagNumber = mNodes.get(0).dagNumber;
246            for (int ct=0; ct < mNodes.size(); ct++) {
247                if (mNodes.get(ct).dagNumber != dagNumber) {
248                    throw new RSInvalidStateException("Multiple DAGs in group not allowed.");
249                }
250            }
251        }
252
253        private Node findNode(Script s) {
254            for (int ct=0; ct < mNodes.size(); ct++) {
255                if (s == mNodes.get(ct).mScript) {
256                    return mNodes.get(ct);
257                }
258            }
259            return null;
260        }
261
262        private Node findNode(Script.KernelID k) {
263            for (int ct=0; ct < mNodes.size(); ct++) {
264                Node n = mNodes.get(ct);
265                for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
266                    if (k == n.mKernels.get(ct2)) {
267                        return n;
268                    }
269                }
270            }
271            return null;
272        }
273
274        /**
275         * Adds a Kernel to the group.
276         *
277         *
278         * @param k The kernel to add.
279         *
280         * @return Builder Returns this.
281         */
282        public Builder addKernel(Script.KernelID k) {
283            if (mLines.size() != 0) {
284                throw new RSInvalidStateException(
285                    "Kernels may not be added once connections exist.");
286            }
287
288            //android.util.Log.v("RSR", "addKernel 1 k=" + k);
289            if (findNode(k) != null) {
290                return this;
291            }
292            //android.util.Log.v("RSR", "addKernel 2 ");
293            mKernelCount++;
294            Node n = findNode(k.mScript);
295            if (n == null) {
296                //android.util.Log.v("RSR", "addKernel 3 ");
297                n = new Node(k.mScript);
298                mNodes.add(n);
299            }
300            n.mKernels.add(k);
301            return this;
302        }
303
304        /**
305         * Adds a connection to the group.
306         *
307         *
308         * @param t The type of the connection. This is used to
309         *          determine the kernel launch sizes on the source side
310         *          of this connection.
311         * @param from The source for the connection.
312         * @param to The destination of the connection.
313         *
314         * @return Builder Returns this
315         */
316        public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) {
317            //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
318
319            Node nf = findNode(from);
320            if (nf == null) {
321                throw new RSInvalidStateException("From script not found.");
322            }
323
324            Node nt = findNode(to.mScript);
325            if (nt == null) {
326                throw new RSInvalidStateException("To script not found.");
327            }
328
329            ConnectLine cl = new ConnectLine(t, from, to);
330            mLines.add(new ConnectLine(t, from, to));
331
332            nf.mOutputs.add(cl);
333            nt.mInputs.add(cl);
334
335            validateCycle(nf, nf);
336            return this;
337        }
338
339        /**
340         * Adds a connection to the group.
341         *
342         *
343         * @param t The type of the connection. This is used to
344         *          determine the kernel launch sizes for both sides of
345         *          this connection.
346         * @param from The source for the connection.
347         * @param to The destination of the connection.
348         *
349         * @return Builder Returns this
350         */
351        public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) {
352            //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
353
354            Node nf = findNode(from);
355            if (nf == null) {
356                throw new RSInvalidStateException("From script not found.");
357            }
358
359            Node nt = findNode(to);
360            if (nt == null) {
361                throw new RSInvalidStateException("To script not found.");
362            }
363
364            ConnectLine cl = new ConnectLine(t, from, to);
365            mLines.add(new ConnectLine(t, from, to));
366
367            nf.mOutputs.add(cl);
368            nt.mInputs.add(cl);
369
370            validateCycle(nf, nf);
371            return this;
372        }
373
374
375
376        /**
377         * Creates the Script group.
378         *
379         *
380         * @return ScriptGroup The new ScriptGroup
381         */
382        public ScriptGroup create() {
383            // FIXME: this is broken for 64-bit
384
385            if (mNodes.size() == 0) {
386                throw new RSInvalidStateException("Empty script groups are not allowed");
387            }
388
389            // reset DAG numbers in case we're building a second group
390            for (int ct=0; ct < mNodes.size(); ct++) {
391                mNodes.get(ct).dagNumber = 0;
392            }
393            validateDAG();
394
395            ArrayList<IO> inputs = new ArrayList<IO>();
396            ArrayList<IO> outputs = new ArrayList<IO>();
397
398            int[] kernels = new int[mKernelCount];
399            int idx = 0;
400            for (int ct=0; ct < mNodes.size(); ct++) {
401                Node n = mNodes.get(ct);
402                for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
403                    final Script.KernelID kid = n.mKernels.get(ct2);
404                    kernels[idx++] = (int)kid.getID(mRS);
405
406                    boolean hasInput = false;
407                    boolean hasOutput = false;
408                    for (int ct3=0; ct3 < n.mInputs.size(); ct3++) {
409                        if (n.mInputs.get(ct3).mToK == kid) {
410                            hasInput = true;
411                        }
412                    }
413                    for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) {
414                        if (n.mOutputs.get(ct3).mFrom == kid) {
415                            hasOutput = true;
416                        }
417                    }
418                    if (!hasInput) {
419                        inputs.add(new IO(kid));
420                    }
421                    if (!hasOutput) {
422                        outputs.add(new IO(kid));
423                    }
424
425                }
426            }
427            if (idx != mKernelCount) {
428                throw new RSRuntimeException("Count mismatch, should not happen.");
429            }
430
431            int[] src = new int[mLines.size()];
432            int[] dstk = new int[mLines.size()];
433            int[] dstf = new int[mLines.size()];
434            int[] types = new int[mLines.size()];
435
436            for (int ct=0; ct < mLines.size(); ct++) {
437                ConnectLine cl = mLines.get(ct);
438                src[ct] = (int)cl.mFrom.getID(mRS);
439                if (cl.mToK != null) {
440                    dstk[ct] = (int)cl.mToK.getID(mRS);
441                }
442                if (cl.mToF != null) {
443                    dstf[ct] = (int)cl.mToF.getID(mRS);
444                }
445                types[ct] = (int)cl.mAllocationType.getID(mRS);
446            }
447
448            long id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types);
449            if (id == 0) {
450                throw new RSRuntimeException("Object creation error, should not happen.");
451            }
452
453            ScriptGroup sg = new ScriptGroup(id, mRS);
454            sg.mOutputs = new IO[outputs.size()];
455            for (int ct=0; ct < outputs.size(); ct++) {
456                sg.mOutputs[ct] = outputs.get(ct);
457            }
458
459            sg.mInputs = new IO[inputs.size()];
460            for (int ct=0; ct < inputs.size(); ct++) {
461                sg.mInputs[ct] = inputs.get(ct);
462            }
463
464            return sg;
465        }
466
467    }
468
469
470}
471
472
473