1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.renderscript;
18
19import java.lang.reflect.Method;
20import java.util.ArrayList;
21
22/**
 * ScriptGroup creates a group of scripts which are executed
 * together based upon one execution call as if they were
 * all part of a single script.  The scripts may be connected
 * internally or to an external allocation. For the internal
 * connections the intermediate results are not observable after
 * the execution of the script.
29 * <p>
 * The external connections are grouped into inputs and outputs.
 * All outputs are produced by a script kernel and placed into a
 * user supplied allocation. Inputs are similar but supply the
 * input of a kernel. Inputs bound to a script are set directly
 * upon the script.
35 * <p>
36 * A ScriptGroup must contain at least one kernel. A ScriptGroup
37 * must contain only a single directed acyclic graph (DAG) of
38 * script kernels and connections. Attempting to create a
39 * ScriptGroup with multiple DAGs or attempting to create
40 * a cycle within a ScriptGroup will throw an exception.
41 *
42 **/
43public final class ScriptGroup extends BaseObj {
44    IO mOutputs[];
45    IO mInputs[];
46
47    static class IO {
48        Script.KernelID mKID;
49        Allocation mAllocation;
50
51        IO(Script.KernelID s) {
52            mKID = s;
53        }
54    }
55
56    static class ConnectLine {
57        ConnectLine(Type t, Script.KernelID from, Script.KernelID to) {
58            mFrom = from;
59            mToK = to;
60            mAllocationType = t;
61        }
62
63        ConnectLine(Type t, Script.KernelID from, Script.FieldID to) {
64            mFrom = from;
65            mToF = to;
66            mAllocationType = t;
67        }
68
69        Script.FieldID mToF;
70        Script.KernelID mToK;
71        Script.KernelID mFrom;
72        Type mAllocationType;
73    }
74
75    static class Node {
76        Script mScript;
77        ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>();
78        ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>();
79        ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>();
80        int dagNumber;
81
82        Node mNext;
83
84        Node(Script s) {
85            mScript = s;
86        }
87    }
88
89
90    ScriptGroup(int id, RenderScript rs) {
91        super(id, rs);
92    }
93
94    /**
95     * Sets an input of the ScriptGroup. This specifies an
96     * Allocation to be used for the kernels which require a kernel
97     * input and that input is provided external to the group.
98     *
99     * @param s The ID of the kernel where the allocation should be
100     *          connected.
101     * @param a The allocation to connect.
102     */
103    public void setInput(Script.KernelID s, Allocation a) {
104        for (int ct=0; ct < mInputs.length; ct++) {
105            if (mInputs[ct].mKID == s) {
106                mInputs[ct].mAllocation = a;
107                mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a));
108                return;
109            }
110        }
111        throw new RSIllegalArgumentException("Script not found");
112    }
113
114    /**
115     * Sets an output of the ScriptGroup. This specifies an
116     * Allocation to be used for the kernels which require a kernel
117     * output and that output is provided external to the group.
118     *
119     * @param s The ID of the kernel where the allocation should be
120     *          connected.
121     * @param a The allocation to connect.
122     */
123    public void setOutput(Script.KernelID s, Allocation a) {
124        for (int ct=0; ct < mOutputs.length; ct++) {
125            if (mOutputs[ct].mKID == s) {
126                mOutputs[ct].mAllocation = a;
127                mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a));
128                return;
129            }
130        }
131        throw new RSIllegalArgumentException("Script not found");
132    }
133
134    /**
135     * Execute the ScriptGroup.  This will run all the kernels in
136     * the script.  The state of the connecting lines will not be
137     * observable after this operation.
138     */
139    public void execute() {
140        mRS.nScriptGroupExecute(getID(mRS));
141    }
142
143
144    /**
145     * Create a ScriptGroup. There are two steps to creating a
146     * ScriptGoup.
147     * <p>
148     * First all the Kernels to be used by the group should be
149     * added.  Once this is done the kernels should be connected.
150     * Kernels cannot be added once a connection has been made.
151     * <p>
152     * Second, add connections. There are two forms of connections.
153     * Kernel to Kernel and Kernel to Field. Kernel to Kernel is
154     * higher performance and should be used where possible. The
155     * line of connections cannot form a loop. If a loop is detected
156     * an exception is thrown.
157     * <p>
158     * Once all the connections are made a call to create will
159     * return the ScriptGroup object.
160     *
161     */
162    public static final class Builder {
163        private RenderScript mRS;
164        private ArrayList<Node> mNodes = new ArrayList<Node>();
165        private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>();
166        private int mKernelCount;
167
168        /**
169         * Create a builder for generating a ScriptGroup.
170         *
171         *
172         * @param rs The Renderscript context.
173         */
174        public Builder(RenderScript rs) {
175            mRS = rs;
176        }
177
178        // do a DFS from original node, looking for original node
179        // any cycle that could be created must contain original node
180        private void validateCycle(Node target, Node original) {
181            for (int ct = 0; ct < target.mOutputs.size(); ct++) {
182                final ConnectLine cl = target.mOutputs.get(ct);
183                if (cl.mToK != null) {
184                    Node tn = findNode(cl.mToK.mScript);
185                    if (tn.equals(original)) {
186                        throw new RSInvalidStateException("Loops in group not allowed.");
187                    }
188                    validateCycle(tn, original);
189                }
190                if (cl.mToF != null) {
191                    Node tn = findNode(cl.mToF.mScript);
192                    if (tn.equals(original)) {
193                        throw new RSInvalidStateException("Loops in group not allowed.");
194                    }
195                    validateCycle(tn, original);
196                }
197            }
198        }
199
200        private void mergeDAGs(int valueUsed, int valueKilled) {
201            for (int ct=0; ct < mNodes.size(); ct++) {
202                if (mNodes.get(ct).dagNumber == valueKilled)
203                    mNodes.get(ct).dagNumber = valueUsed;
204            }
205        }
206
207        private void validateDAGRecurse(Node n, int dagNumber) {
208            // combine DAGs if this node has been seen already
209            if (n.dagNumber != 0 && n.dagNumber != dagNumber) {
210                mergeDAGs(n.dagNumber, dagNumber);
211                return;
212            }
213
214            n.dagNumber = dagNumber;
215            for (int ct=0; ct < n.mOutputs.size(); ct++) {
216                final ConnectLine cl = n.mOutputs.get(ct);
217                if (cl.mToK != null) {
218                    Node tn = findNode(cl.mToK.mScript);
219                    validateDAGRecurse(tn, dagNumber);
220                }
221                if (cl.mToF != null) {
222                    Node tn = findNode(cl.mToF.mScript);
223                    validateDAGRecurse(tn, dagNumber);
224                }
225            }
226        }
227
228        private void validateDAG() {
229            for (int ct=0; ct < mNodes.size(); ct++) {
230                Node n = mNodes.get(ct);
231                if (n.mInputs.size() == 0) {
232                    if (n.mOutputs.size() == 0 && mNodes.size() > 1) {
233                        throw new RSInvalidStateException("Groups cannot contain unconnected scripts");
234                    }
235                    validateDAGRecurse(n, ct+1);
236                }
237            }
238            int dagNumber = mNodes.get(0).dagNumber;
239            for (int ct=0; ct < mNodes.size(); ct++) {
240                if (mNodes.get(ct).dagNumber != dagNumber) {
241                    throw new RSInvalidStateException("Multiple DAGs in group not allowed.");
242                }
243            }
244        }
245
246        private Node findNode(Script s) {
247            for (int ct=0; ct < mNodes.size(); ct++) {
248                if (s == mNodes.get(ct).mScript) {
249                    return mNodes.get(ct);
250                }
251            }
252            return null;
253        }
254
255        private Node findNode(Script.KernelID k) {
256            for (int ct=0; ct < mNodes.size(); ct++) {
257                Node n = mNodes.get(ct);
258                for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
259                    if (k == n.mKernels.get(ct2)) {
260                        return n;
261                    }
262                }
263            }
264            return null;
265        }
266
267        /**
268         * Adds a Kernel to the group.
269         *
270         *
271         * @param k The kernel to add.
272         *
273         * @return Builder Returns this.
274         */
275        public Builder addKernel(Script.KernelID k) {
276            if (mLines.size() != 0) {
277                throw new RSInvalidStateException(
278                    "Kernels may not be added once connections exist.");
279            }
280
281            //android.util.Log.v("RSR", "addKernel 1 k=" + k);
282            if (findNode(k) != null) {
283                return this;
284            }
285            //android.util.Log.v("RSR", "addKernel 2 ");
286            mKernelCount++;
287            Node n = findNode(k.mScript);
288            if (n == null) {
289                //android.util.Log.v("RSR", "addKernel 3 ");
290                n = new Node(k.mScript);
291                mNodes.add(n);
292            }
293            n.mKernels.add(k);
294            return this;
295        }
296
297        /**
298         * Adds a connection to the group.
299         *
300         *
301         * @param t The type of the connection. This is used to
302         *          determine the kernel launch sizes on the source side
303         *          of this connection.
304         * @param from The source for the connection.
305         * @param to The destination of the connection.
306         *
307         * @return Builder Returns this
308         */
309        public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) {
310            //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
311
312            Node nf = findNode(from);
313            if (nf == null) {
314                throw new RSInvalidStateException("From script not found.");
315            }
316
317            Node nt = findNode(to.mScript);
318            if (nt == null) {
319                throw new RSInvalidStateException("To script not found.");
320            }
321
322            ConnectLine cl = new ConnectLine(t, from, to);
323            mLines.add(new ConnectLine(t, from, to));
324
325            nf.mOutputs.add(cl);
326            nt.mInputs.add(cl);
327
328            validateCycle(nf, nf);
329            return this;
330        }
331
332        /**
333         * Adds a connection to the group.
334         *
335         *
336         * @param t The type of the connection. This is used to
337         *          determine the kernel launch sizes for both sides of
338         *          this connection.
339         * @param from The source for the connection.
340         * @param to The destination of the connection.
341         *
342         * @return Builder Returns this
343         */
344        public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) {
345            //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
346
347            Node nf = findNode(from);
348            if (nf == null) {
349                throw new RSInvalidStateException("From script not found.");
350            }
351
352            Node nt = findNode(to);
353            if (nt == null) {
354                throw new RSInvalidStateException("To script not found.");
355            }
356
357            ConnectLine cl = new ConnectLine(t, from, to);
358            mLines.add(new ConnectLine(t, from, to));
359
360            nf.mOutputs.add(cl);
361            nt.mInputs.add(cl);
362
363            validateCycle(nf, nf);
364            return this;
365        }
366
367
368
369        /**
370         * Creates the Script group.
371         *
372         *
373         * @return ScriptGroup The new ScriptGroup
374         */
375        public ScriptGroup create() {
376
377            if (mNodes.size() == 0) {
378                throw new RSInvalidStateException("Empty script groups are not allowed");
379            }
380
381            // reset DAG numbers in case we're building a second group
382            for (int ct=0; ct < mNodes.size(); ct++) {
383                mNodes.get(ct).dagNumber = 0;
384            }
385            validateDAG();
386
387            ArrayList<IO> inputs = new ArrayList<IO>();
388            ArrayList<IO> outputs = new ArrayList<IO>();
389
390            int[] kernels = new int[mKernelCount];
391            int idx = 0;
392            for (int ct=0; ct < mNodes.size(); ct++) {
393                Node n = mNodes.get(ct);
394                for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
395                    final Script.KernelID kid = n.mKernels.get(ct2);
396                    kernels[idx++] = kid.getID(mRS);
397
398                    boolean hasInput = false;
399                    boolean hasOutput = false;
400                    for (int ct3=0; ct3 < n.mInputs.size(); ct3++) {
401                        if (n.mInputs.get(ct3).mToK == kid) {
402                            hasInput = true;
403                        }
404                    }
405                    for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) {
406                        if (n.mOutputs.get(ct3).mFrom == kid) {
407                            hasOutput = true;
408                        }
409                    }
410                    if (!hasInput) {
411                        inputs.add(new IO(kid));
412                    }
413                    if (!hasOutput) {
414                        outputs.add(new IO(kid));
415                    }
416
417                }
418            }
419            if (idx != mKernelCount) {
420                throw new RSRuntimeException("Count mismatch, should not happen.");
421            }
422
423            int[] src = new int[mLines.size()];
424            int[] dstk = new int[mLines.size()];
425            int[] dstf = new int[mLines.size()];
426            int[] types = new int[mLines.size()];
427
428            for (int ct=0; ct < mLines.size(); ct++) {
429                ConnectLine cl = mLines.get(ct);
430                src[ct] = cl.mFrom.getID(mRS);
431                if (cl.mToK != null) {
432                    dstk[ct] = cl.mToK.getID(mRS);
433                }
434                if (cl.mToF != null) {
435                    dstf[ct] = cl.mToF.getID(mRS);
436                }
437                types[ct] = cl.mAllocationType.getID(mRS);
438            }
439
440            int id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types);
441            if (id == 0) {
442                throw new RSRuntimeException("Object creation error, should not happen.");
443            }
444
445            ScriptGroup sg = new ScriptGroup(id, mRS);
446            sg.mOutputs = new IO[outputs.size()];
447            for (int ct=0; ct < outputs.size(); ct++) {
448                sg.mOutputs[ct] = outputs.get(ct);
449            }
450
451            sg.mInputs = new IO[inputs.size()];
452            for (int ct=0; ct < inputs.size(); ct++) {
453                sg.mInputs[ct] = inputs.get(ct);
454            }
455
456            return sg;
457        }
458
459    }
460
461
462}
463
464
465