1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.support.v8.renderscript;
18
19import java.lang.reflect.Method;
20import java.util.ArrayList;
21
22/**
23 * ScriptGroup creates a group of kernels that are executed
24 * together with one execution call as if they were a single kernel.
25 * The kernels may be connected internally or to an external allocation.
26 * The intermediate results for internal connections are not observable
27 * after the execution of the script.
28 * <p>
29 * External connections are grouped into inputs and outputs.
30 * All outputs are produced by a script kernel and placed into a
31 * user-supplied allocation. Inputs provide the input of a kernel.
32 * Inputs bound to script globals are set directly upon the script.
33 * <p>
34 * A ScriptGroup must contain at least one kernel. A ScriptGroup
35 * must contain only a single directed acyclic graph (DAG) of
36 * script kernels and connections. Attempting to create a
37 * ScriptGroup with multiple DAGs or attempting to create
38 * a cycle within a ScriptGroup will throw an exception.
39 * <p>
40 * Currently, all kernels in a ScriptGroup must be from separate
41 * Script objects. Attempting to use multiple kernels from the same
42 * Script object will result in an
43 * {@link android.support.v8.renderscript.RSInvalidStateException}.
44 *
45 **/
46public class ScriptGroup extends BaseObj {
47    IO mOutputs[];
48    IO mInputs[];
49
50    static class IO {
51        Script.KernelID mKID;
52        Allocation mAllocation;
53
54        IO(Script.KernelID s) {
55            mKID = s;
56        }
57    }
58
59    static class ConnectLine {
60        ConnectLine(Type t, Script.KernelID from, Script.KernelID to) {
61            mFrom = from;
62            mToK = to;
63            mAllocationType = t;
64        }
65
66        ConnectLine(Type t, Script.KernelID from, Script.FieldID to) {
67            mFrom = from;
68            mToF = to;
69            mAllocationType = t;
70        }
71
72        Script.FieldID mToF;
73        Script.KernelID mToK;
74        Script.KernelID mFrom;
75        Type mAllocationType;
76    }
77
78    static class Node {
79        Script mScript;
80        ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>();
81        ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>();
82        ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>();
83        int dagNumber;
84
85        Node mNext;
86
87        Node(Script s) {
88            mScript = s;
89        }
90    }
91
92
93    ScriptGroup(int id, RenderScript rs) {
94        super(id, rs);
95    }
96
97    /**
98     * Sets an input of the ScriptGroup. This specifies an
99     * Allocation to be used for kernels that require an input
100     * Allocation provided from outside of the ScriptGroup.
101     *
102     * @param s The ID of the kernel where the allocation should be
103     *          connected.
104     * @param a The allocation to connect.
105     */
106    public void setInput(Script.KernelID s, Allocation a) {
107        for (int ct=0; ct < mInputs.length; ct++) {
108            if (mInputs[ct].mKID == s) {
109                mInputs[ct].mAllocation = a;
110                mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a));
111                return;
112            }
113        }
114        throw new RSIllegalArgumentException("Script not found");
115    }
116
117    /**
118     * Sets an output of the ScriptGroup. This specifies an
119     * Allocation to be used for the kernels that require an output
120     * Allocation visible after the ScriptGroup is executed.
121     *
122     * @param s The ID of the kernel where the allocation should be
123     *          connected.
124     * @param a The allocation to connect.
125     */
126    public void setOutput(Script.KernelID s, Allocation a) {
127        for (int ct=0; ct < mOutputs.length; ct++) {
128            if (mOutputs[ct].mKID == s) {
129                mOutputs[ct].mAllocation = a;
130                mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a));
131                return;
132            }
133        }
134        throw new RSIllegalArgumentException("Script not found");
135    }
136
137    /**
138     * Execute the ScriptGroup.  This will run all the kernels in
139     * the ScriptGroup.  No internal connection results will be visible
140     * after execution of the ScriptGroup.
141     */
142    public void execute() {
143        mRS.nScriptGroupExecute(getID(mRS));
144    }
145
146
147    /**
148     * Helper class to build a ScriptGroup. A ScriptGroup is
149     * created in two steps.
150     * <p>
151     * First, all kernels to be used by the ScriptGroup should be added.
152     * <p>
153     * Second, add connections between kernels. There are two types
154     * of connections: kernel to kernel and kernel to field.
155     * Kernel to kernel allows a kernel's output to be passed to
156     * another kernel as input. Kernel to field allows the output of
157     * one kernel to be bound as a script global. Kernel to kernel is
158     * higher performance and should be used where possible.
159     * <p>
160     * A ScriptGroup must contain a single directed acyclic graph (DAG); it
161     * cannot contain cycles. Currently, all kernels used in a ScriptGroup
162     * must come from different Script objects. Additionally, all kernels
163     * in a ScriptGroup must have at least one input, output, or internal
164     * connection.
165     * <p>
166     * Once all connections are made, a call to {@link #create} will
167     * return the ScriptGroup object.
168     *
169     */
170    public static final class Builder {
171        private RenderScript mRS;
172        private ArrayList<Node> mNodes = new ArrayList<Node>();
173        private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>();
174        private int mKernelCount;
175
176        private ScriptGroupThunker.Builder mT;
177
178        /**
179         * Create a Builder for generating a ScriptGroup.
180         *
181         *
182         * @param rs The RenderScript context.
183         */
184        public Builder(RenderScript rs) {
185            if (rs.isNative) {
186                mT = new ScriptGroupThunker.Builder(rs);
187            }
188            mRS = rs;
189        }
190
191        // do a DFS from original node, looking for original node
192        // any cycle that could be created must contain original node
193        private void validateCycle(Node target, Node original) {
194            for (int ct = 0; ct < target.mOutputs.size(); ct++) {
195                final ConnectLine cl = target.mOutputs.get(ct);
196                if (cl.mToK != null) {
197                    Node tn = findNode(cl.mToK.mScript);
198                    if (tn.equals(original)) {
199                        throw new RSInvalidStateException("Loops in group not allowed.");
200                    }
201                    validateCycle(tn, original);
202                }
203                if (cl.mToF != null) {
204                    Node tn = findNode(cl.mToF.mScript);
205                    if (tn.equals(original)) {
206                        throw new RSInvalidStateException("Loops in group not allowed.");
207                    }
208                    validateCycle(tn, original);
209                }
210            }
211        }
212
213        private void mergeDAGs(int valueUsed, int valueKilled) {
214            for (int ct=0; ct < mNodes.size(); ct++) {
215                if (mNodes.get(ct).dagNumber == valueKilled)
216                    mNodes.get(ct).dagNumber = valueUsed;
217            }
218        }
219
220        private void validateDAGRecurse(Node n, int dagNumber) {
221            // combine DAGs if this node has been seen already
222            if (n.dagNumber != 0 && n.dagNumber != dagNumber) {
223                mergeDAGs(n.dagNumber, dagNumber);
224                return;
225            }
226
227            n.dagNumber = dagNumber;
228            for (int ct=0; ct < n.mOutputs.size(); ct++) {
229                final ConnectLine cl = n.mOutputs.get(ct);
230                if (cl.mToK != null) {
231                    Node tn = findNode(cl.mToK.mScript);
232                    validateDAGRecurse(tn, dagNumber);
233                }
234                if (cl.mToF != null) {
235                    Node tn = findNode(cl.mToF.mScript);
236                    validateDAGRecurse(tn, dagNumber);
237                }
238            }
239        }
240
241        private void validateDAG() {
242            for (int ct=0; ct < mNodes.size(); ct++) {
243                Node n = mNodes.get(ct);
244                if (n.mInputs.size() == 0) {
245                    if (n.mOutputs.size() == 0 && mNodes.size() > 1) {
246                        throw new RSInvalidStateException("Groups cannot contain unconnected scripts");
247                    }
248                    validateDAGRecurse(n, ct+1);
249                }
250            }
251            int dagNumber = mNodes.get(0).dagNumber;
252            for (int ct=0; ct < mNodes.size(); ct++) {
253                if (mNodes.get(ct).dagNumber != dagNumber) {
254                    throw new RSInvalidStateException("Multiple DAGs in group not allowed.");
255                }
256            }
257        }
258
259        private Node findNode(Script s) {
260            for (int ct=0; ct < mNodes.size(); ct++) {
261                if (s == mNodes.get(ct).mScript) {
262                    return mNodes.get(ct);
263                }
264            }
265            return null;
266        }
267
268        private Node findNode(Script.KernelID k) {
269            for (int ct=0; ct < mNodes.size(); ct++) {
270                Node n = mNodes.get(ct);
271                for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
272                    if (k == n.mKernels.get(ct2)) {
273                        return n;
274                    }
275                }
276            }
277            return null;
278        }
279
280        /**
281         * Adds a Kernel to the group.
282         *
283         *
284         * @param k The kernel to add.
285         *
286         * @return Builder Returns this.
287         */
288        public Builder addKernel(Script.KernelID k) {
289            if (mT != null) {
290                mT.addKernel(k);
291                return this;
292            }
293
294            if (mLines.size() != 0) {
295                throw new RSInvalidStateException(
296                    "Kernels may not be added once connections exist.");
297            }
298
299            //android.util.Log.v("RSR", "addKernel 1 k=" + k);
300            if (findNode(k) != null) {
301                return this;
302            }
303            //android.util.Log.v("RSR", "addKernel 2 ");
304            mKernelCount++;
305            Node n = findNode(k.mScript);
306            if (n == null) {
307                //android.util.Log.v("RSR", "addKernel 3 ");
308                n = new Node(k.mScript);
309                mNodes.add(n);
310            }
311            n.mKernels.add(k);
312            return this;
313        }
314
315        /**
316         * Adds a connection to the group.
317         *
318         *
319         * @param t The type of the connection. This is used to
320         *          determine the kernel launch sizes on the source side
321         *          of this connection.
322         * @param from The source for the connection.
323         * @param to The destination of the connection.
324         *
325         * @return Builder Returns this
326         */
327        public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) {
328            //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
329
330            if (mT != null) {
331                mT.addConnection(t, from, to);
332                return this;
333            }
334
335            Node nf = findNode(from);
336            if (nf == null) {
337                throw new RSInvalidStateException("From script not found.");
338            }
339
340            Node nt = findNode(to.mScript);
341            if (nt == null) {
342                throw new RSInvalidStateException("To script not found.");
343            }
344
345            ConnectLine cl = new ConnectLine(t, from, to);
346            mLines.add(new ConnectLine(t, from, to));
347
348            nf.mOutputs.add(cl);
349            nt.mInputs.add(cl);
350
351            validateCycle(nf, nf);
352            return this;
353        }
354
355        /**
356         * Adds a connection to the group.
357         *
358         *
359         * @param t The type of the connection. This is used to
360         *          determine the kernel launch sizes for both sides of
361         *          this connection.
362         * @param from The source for the connection.
363         * @param to The destination of the connection.
364         *
365         * @return Builder Returns this
366         */
367        public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) {
368            //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
369
370            if (mT != null) {
371                mT.addConnection(t, from, to);
372                return this;
373            }
374
375            Node nf = findNode(from);
376            if (nf == null) {
377                throw new RSInvalidStateException("From script not found.");
378            }
379
380            Node nt = findNode(to);
381            if (nt == null) {
382                throw new RSInvalidStateException("To script not found.");
383            }
384
385            ConnectLine cl = new ConnectLine(t, from, to);
386            mLines.add(new ConnectLine(t, from, to));
387
388            nf.mOutputs.add(cl);
389            nt.mInputs.add(cl);
390
391            validateCycle(nf, nf);
392            return this;
393        }
394
395
396
397        /**
398         * Creates the Script group.
399         *
400         *
401         * @return ScriptGroup The new ScriptGroup
402         */
403        public ScriptGroup create() {
404
405            if (mT != null) {
406                return mT.create();
407            }
408
409            if (mNodes.size() == 0) {
410                throw new RSInvalidStateException("Empty script groups are not allowed");
411            }
412
413            // reset DAG numbers in case we're building a second group
414            for (int ct=0; ct < mNodes.size(); ct++) {
415                mNodes.get(ct).dagNumber = 0;
416            }
417            validateDAG();
418
419            ArrayList<IO> inputs = new ArrayList<IO>();
420            ArrayList<IO> outputs = new ArrayList<IO>();
421
422            int[] kernels = new int[mKernelCount];
423            int idx = 0;
424            for (int ct=0; ct < mNodes.size(); ct++) {
425                Node n = mNodes.get(ct);
426                for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
427                    final Script.KernelID kid = n.mKernels.get(ct2);
428                    kernels[idx++] = kid.getID(mRS);
429
430                    boolean hasInput = false;
431                    boolean hasOutput = false;
432                    for (int ct3=0; ct3 < n.mInputs.size(); ct3++) {
433                        if (n.mInputs.get(ct3).mToK == kid) {
434                            hasInput = true;
435                        }
436                    }
437                    for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) {
438                        if (n.mOutputs.get(ct3).mFrom == kid) {
439                            hasOutput = true;
440                        }
441                    }
442                    if (!hasInput) {
443                        inputs.add(new IO(kid));
444                    }
445                    if (!hasOutput) {
446                        outputs.add(new IO(kid));
447                    }
448
449                }
450            }
451            if (idx != mKernelCount) {
452                throw new RSRuntimeException("Count mismatch, should not happen.");
453            }
454
455            int[] src = new int[mLines.size()];
456            int[] dstk = new int[mLines.size()];
457            int[] dstf = new int[mLines.size()];
458            int[] types = new int[mLines.size()];
459
460            for (int ct=0; ct < mLines.size(); ct++) {
461                ConnectLine cl = mLines.get(ct);
462                src[ct] = cl.mFrom.getID(mRS);
463                if (cl.mToK != null) {
464                    dstk[ct] = cl.mToK.getID(mRS);
465                }
466                if (cl.mToF != null) {
467                    dstf[ct] = cl.mToF.getID(mRS);
468                }
469                types[ct] = cl.mAllocationType.getID(mRS);
470            }
471
472            int id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types);
473            if (id == 0) {
474                throw new RSRuntimeException("Object creation error, should not happen.");
475            }
476
477            ScriptGroup sg = new ScriptGroup(id, mRS);
478            sg.mOutputs = new IO[outputs.size()];
479            for (int ct=0; ct < outputs.size(); ct++) {
480                sg.mOutputs[ct] = outputs.get(ct);
481            }
482
483            sg.mInputs = new IO[inputs.size()];
484            for (int ct=0; ct < inputs.size(); ct++) {
485                sg.mInputs[ct] = inputs.get(ct);
486            }
487
488            return sg;
489        }
490
491    }
492
493
494}
495
496
497