1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package android.support.v8.renderscript;
18
19import java.lang.reflect.Method;
20import java.util.ArrayList;
21
/**
 * ScriptGroup creates a group of scripts which are executed
 * together based upon one execution call as if they were
 * all part of a single script.  The scripts may be connected
 * internally or to an external allocation. For the internal
 * connections the intermediate results are not observable after
 * the execution of the script.
 * <p>
 * The external connections are grouped into inputs and outputs.
 * All outputs are produced by a script kernel and placed into a
 * user supplied allocation. Inputs are similar but supply the
 * input of a kernel. Inputs bound to a script are set directly
 * upon the script.
 *
 **/
37public final class ScriptGroup extends BaseObj {
38    IO mOutputs[];
39    IO mInputs[];
40
41    static class IO {
42        Script.KernelID mKID;
43        Allocation mAllocation;
44
45        IO(Script.KernelID s) {
46            mKID = s;
47        }
48    }
49
50    static class ConnectLine {
51        ConnectLine(Type t, Script.KernelID from, Script.KernelID to) {
52            mFrom = from;
53            mToK = to;
54            mAllocationType = t;
55        }
56
57        ConnectLine(Type t, Script.KernelID from, Script.FieldID to) {
58            mFrom = from;
59            mToF = to;
60            mAllocationType = t;
61        }
62
63        Script.FieldID mToF;
64        Script.KernelID mToK;
65        Script.KernelID mFrom;
66        Type mAllocationType;
67    }
68
69    static class Node {
70        Script mScript;
71        ArrayList<Script.KernelID> mKernels = new ArrayList<Script.KernelID>();
72        ArrayList<ConnectLine> mInputs = new ArrayList<ConnectLine>();
73        ArrayList<ConnectLine> mOutputs = new ArrayList<ConnectLine>();
74        boolean mSeen;
75
76        Node mNext;
77
78        Node(Script s) {
79            mScript = s;
80        }
81    }
82
83
84    ScriptGroup(int id, RenderScript rs) {
85        super(id, rs);
86    }
87
88    /**
89     * Sets an input of the ScriptGroup. This specifies an
90     * Allocation to be used for the kernels which require a kernel
91     * input and that input is provided external to the group.
92     *
93     * @param s The ID of the kernel where the allocation should be
94     *          connected.
95     * @param a The allocation to connect.
96     */
97    public void setInput(Script.KernelID s, Allocation a) {
98        for (int ct=0; ct < mInputs.length; ct++) {
99            if (mInputs[ct].mKID == s) {
100                mInputs[ct].mAllocation = a;
101                mRS.nScriptGroupSetInput(getID(mRS), s.getID(mRS), mRS.safeID(a));
102                return;
103            }
104        }
105        throw new RSIllegalArgumentException("Script not found");
106    }
107
108    /**
109     * Sets an output of the ScriptGroup. This specifies an
110     * Allocation to be used for the kernels which require a kernel
111     * output and that output is provided external to the group.
112     *
113     * @param s The ID of the kernel where the allocation should be
114     *          connected.
115     * @param a The allocation to connect.
116     */
117    public void setOutput(Script.KernelID s, Allocation a) {
118        for (int ct=0; ct < mOutputs.length; ct++) {
119            if (mOutputs[ct].mKID == s) {
120                mOutputs[ct].mAllocation = a;
121                mRS.nScriptGroupSetOutput(getID(mRS), s.getID(mRS), mRS.safeID(a));
122                return;
123            }
124        }
125        throw new RSIllegalArgumentException("Script not found");
126    }
127
128    /**
129     * Execute the ScriptGroup.  This will run all the kernels in
130     * the script.  The state of the connecting lines will not be
131     * observable after this operation.
132     */
133    public void execute() {
134        mRS.nScriptGroupExecute(getID(mRS));
135    }
136
137
138    /**
139     * Create a ScriptGroup. There are two steps to creating a
140     * ScriptGoup.
141     * <p>
142     * First all the Kernels to be used by the group should be
143     * added.  Once this is done the kernels should be connected.
144     * Kernels cannot be added once a connection has been made.
145     * <p>
146     * Second, add connections. There are two forms of connections.
147     * Kernel to Kernel and Kernel to Field. Kernel to Kernel is
148     * higher performance and should be used where possible. The
149     * line of connections cannot form a loop. If a loop is detected
150     * an exception is thrown.
151     * <p>
152     * Once all the connections are made a call to create will
153     * return the ScriptGroup object.
154     *
155     */
156    public static final class Builder {
157        private RenderScript mRS;
158        private ArrayList<Node> mNodes = new ArrayList<Node>();
159        private ArrayList<ConnectLine> mLines = new ArrayList<ConnectLine>();
160        private int mKernelCount;
161
162        /**
163         * Create a builder for generating a ScriptGroup.
164         *
165         *
166         * @param rs The Renderscript context.
167         */
168        public Builder(RenderScript rs) {
169            mRS = rs;
170        }
171
172        private void validateRecurse(Node n, int depth) {
173            n.mSeen = true;
174
175            //android.util.Log.v("RSR", " validateRecurse outputCount " + n.mOutputs.size());
176            for (int ct=0; ct < n.mOutputs.size(); ct++) {
177                final ConnectLine cl = n.mOutputs.get(ct);
178                if (cl.mToK != null) {
179                    Node tn = findNode(cl.mToK.mScript);
180                    if (tn.mSeen) {
181                        throw new RSInvalidStateException("Loops in group not allowed.");
182                    }
183                    validateRecurse(tn, depth + 1);
184                }
185                if (cl.mToF != null) {
186                    Node tn = findNode(cl.mToF.mScript);
187                    if (tn.mSeen) {
188                        throw new RSInvalidStateException("Loops in group not allowed.");
189                    }
190                    validateRecurse(tn, depth + 1);
191                }
192            }
193        }
194
195        private void validate() {
196            //android.util.Log.v("RSR", "validate");
197
198            for (int ct=0; ct < mNodes.size(); ct++) {
199                for (int ct2=0; ct2 < mNodes.size(); ct2++) {
200                    mNodes.get(ct2).mSeen = false;
201                }
202                Node n = mNodes.get(ct);
203                if (n.mInputs.size() == 0) {
204                    validateRecurse(n, 0);
205                }
206            }
207        }
208
209        private Node findNode(Script s) {
210            for (int ct=0; ct < mNodes.size(); ct++) {
211                if (s == mNodes.get(ct).mScript) {
212                    return mNodes.get(ct);
213                }
214            }
215            return null;
216        }
217
218        private Node findNode(Script.KernelID k) {
219            for (int ct=0; ct < mNodes.size(); ct++) {
220                Node n = mNodes.get(ct);
221                for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
222                    if (k == n.mKernels.get(ct2)) {
223                        return n;
224                    }
225                }
226            }
227            return null;
228        }
229
230        /**
231         * Adds a Kernel to the group.
232         *
233         *
234         * @param k The kernel to add.
235         *
236         * @return Builder Returns this.
237         */
238        public Builder addKernel(Script.KernelID k) {
239            if (mLines.size() != 0) {
240                throw new RSInvalidStateException(
241                    "Kernels may not be added once connections exist.");
242            }
243
244            //android.util.Log.v("RSR", "addKernel 1 k=" + k);
245            if (findNode(k) != null) {
246                return this;
247            }
248            //android.util.Log.v("RSR", "addKernel 2 ");
249            mKernelCount++;
250            Node n = findNode(k.mScript);
251            if (n == null) {
252                //android.util.Log.v("RSR", "addKernel 3 ");
253                n = new Node(k.mScript);
254                mNodes.add(n);
255            }
256            n.mKernels.add(k);
257            return this;
258        }
259
260        /**
261         * Adds a connection to the group.
262         *
263         *
264         * @param t The type of the connection. This is used to
265         *          determine the kernel launch sizes on the source side
266         *          of this connection.
267         * @param from The source for the connection.
268         * @param to The destination of the connection.
269         *
270         * @return Builder Returns this
271         */
272        public Builder addConnection(Type t, Script.KernelID from, Script.FieldID to) {
273            //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
274
275            Node nf = findNode(from);
276            if (nf == null) {
277                throw new RSInvalidStateException("From kernel not found.");
278            }
279
280            Node nt = findNode(to.mScript);
281            if (nt == null) {
282                throw new RSInvalidStateException("To script not found.");
283            }
284
285            ConnectLine cl = new ConnectLine(t, from, to);
286            mLines.add(new ConnectLine(t, from, to));
287
288            nf.mOutputs.add(cl);
289            nt.mInputs.add(cl);
290
291            validate();
292            return this;
293        }
294
295        /**
296         * Adds a connection to the group.
297         *
298         *
299         * @param t The type of the connection. This is used to
300         *          determine the kernel launch sizes for both sides of
301         *          this connection.
302         * @param from The source for the connection.
303         * @param to The destination of the connection.
304         *
305         * @return Builder Returns this
306         */
307        public Builder addConnection(Type t, Script.KernelID from, Script.KernelID to) {
308            //android.util.Log.v("RSR", "addConnection " + t +", " + from + ", " + to);
309
310            Node nf = findNode(from);
311            if (nf == null) {
312                throw new RSInvalidStateException("From kernel not found.");
313            }
314
315            Node nt = findNode(to);
316            if (nt == null) {
317                throw new RSInvalidStateException("To script not found.");
318            }
319
320            ConnectLine cl = new ConnectLine(t, from, to);
321            mLines.add(new ConnectLine(t, from, to));
322
323            nf.mOutputs.add(cl);
324            nt.mInputs.add(cl);
325
326            validate();
327            return this;
328        }
329
330
331
332        /**
333         * Creates the Script group.
334         *
335         *
336         * @return ScriptGroup The new ScriptGroup
337         */
338        public ScriptGroup create() {
339            ArrayList<IO> inputs = new ArrayList<IO>();
340            ArrayList<IO> outputs = new ArrayList<IO>();
341
342            int[] kernels = new int[mKernelCount];
343            int idx = 0;
344            for (int ct=0; ct < mNodes.size(); ct++) {
345                Node n = mNodes.get(ct);
346                for (int ct2=0; ct2 < n.mKernels.size(); ct2++) {
347                    final Script.KernelID kid = n.mKernels.get(ct2);
348                    kernels[idx++] = kid.getID(mRS);
349
350                    boolean hasInput = false;
351                    boolean hasOutput = false;
352                    for (int ct3=0; ct3 < n.mInputs.size(); ct3++) {
353                        if (n.mInputs.get(ct3).mToK == kid) {
354                            hasInput = true;
355                        }
356                    }
357                    for (int ct3=0; ct3 < n.mOutputs.size(); ct3++) {
358                        if (n.mOutputs.get(ct3).mFrom == kid) {
359                            hasOutput = true;
360                        }
361                    }
362                    if (!hasInput) {
363                        inputs.add(new IO(kid));
364                    }
365                    if (!hasOutput) {
366                        outputs.add(new IO(kid));
367                    }
368
369                }
370            }
371            if (idx != mKernelCount) {
372                throw new RSRuntimeException("Count mismatch, should not happen.");
373            }
374
375            int[] src = new int[mLines.size()];
376            int[] dstk = new int[mLines.size()];
377            int[] dstf = new int[mLines.size()];
378            int[] types = new int[mLines.size()];
379
380            for (int ct=0; ct < mLines.size(); ct++) {
381                ConnectLine cl = mLines.get(ct);
382                src[ct] = cl.mFrom.getID(mRS);
383                if (cl.mToK != null) {
384                    dstk[ct] = cl.mToK.getID(mRS);
385                }
386                if (cl.mToF != null) {
387                    dstf[ct] = cl.mToF.getID(mRS);
388                }
389                types[ct] = cl.mAllocationType.getID(mRS);
390            }
391
392            int id = mRS.nScriptGroupCreate(kernels, src, dstk, dstf, types);
393            if (id == 0) {
394                throw new RSRuntimeException("Object creation error, should not happen.");
395            }
396
397            ScriptGroup sg = new ScriptGroup(id, mRS);
398            sg.mOutputs = new IO[outputs.size()];
399            for (int ct=0; ct < outputs.size(); ct++) {
400                sg.mOutputs[ct] = outputs.get(ct);
401            }
402
403            sg.mInputs = new IO[inputs.size()];
404            for (int ct=0; ct < inputs.size(); ct++) {
405                sg.mInputs[ct] = inputs.get(ct);
406            }
407
408            return sg;
409        }
410
411    }
412
413
414}
415
416
417