1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "rsContext.h"
18#include <time.h>
19
20using namespace android;
21using namespace android::renderscript;
22
23ScriptGroup::ScriptGroup(Context *rsc) : ObjectBase(rsc) {
24}
25
26ScriptGroup::~ScriptGroup() {
27    if (mRSC->mHal.funcs.scriptgroup.destroy) {
28        mRSC->mHal.funcs.scriptgroup.destroy(mRSC, this);
29    }
30
31    for (size_t ct=0; ct < mLinks.size(); ct++) {
32        delete mLinks[ct];
33    }
34}
35
36ScriptGroup::IO::IO(const ScriptKernelID *kid) {
37    mKernel = kid;
38}
39
40ScriptGroup::Node::Node(Script *s) {
41    mScript = s;
42    mSeen = false;
43    mOrder = 0;
44}
45
46ScriptGroup::Node * ScriptGroup::findNode(Script *s) const {
47    //ALOGE("find %p   %i", s, (int)mNodes.size());
48    for (size_t ct=0; ct < mNodes.size(); ct++) {
49        Node *n = mNodes[ct];
50        for (size_t ct2=0; ct2 < n->mKernels.size(); ct2++) {
51            if (n->mKernels[ct2]->mScript == s) {
52                return n;
53            }
54        }
55    }
56    return NULL;
57}
58
59bool ScriptGroup::calcOrderRecurse(Node *n, int depth) {
60    n->mSeen = true;
61    if (n->mOrder < depth) {
62        n->mOrder = depth;
63    }
64    bool ret = true;
65    for (size_t ct=0; ct < n->mOutputs.size(); ct++) {
66        const Link *l = n->mOutputs[ct];
67        Node *nt = NULL;
68        if (l->mDstField.get()) {
69            nt = findNode(l->mDstField->mScript);
70        } else {
71            nt = findNode(l->mDstKernel->mScript);
72        }
73        if (nt->mSeen) {
74            return false;
75        }
76        ret &= calcOrderRecurse(nt, n->mOrder + 1);
77    }
78    return ret;
79}
80
81static int CompareNodeForSort(ScriptGroup::Node *const* lhs,
82                              ScriptGroup::Node *const* rhs) {
83    if (lhs[0]->mOrder > rhs[0]->mOrder) {
84        return 1;
85    }
86    return 0;
87}
88
89
90bool ScriptGroup::calcOrder() {
91    // Make nodes
92    for (size_t ct=0; ct < mKernels.size(); ct++) {
93        const ScriptKernelID *k = mKernels[ct].get();
94        //ALOGE(" kernel %i, %p  s=%p", (int)ct, k, mKernels[ct]->mScript);
95        Node *n = findNode(k->mScript);
96        //ALOGE("    n = %p", n);
97        if (n == NULL) {
98            n = new Node(k->mScript);
99            mNodes.add(n);
100        }
101        n->mKernels.add(k);
102    }
103
104    // add links
105    //ALOGE("link count %i", (int)mLinks.size());
106    for (size_t ct=0; ct < mLinks.size(); ct++) {
107        Link *l = mLinks[ct];
108        //ALOGE("link  %i %p", (int)ct, l);
109        Node *n = findNode(l->mSource->mScript);
110        //ALOGE("link n %p", n);
111        n->mOutputs.add(l);
112
113        if (l->mDstKernel.get()) {
114            //ALOGE("l->mDstKernel.get() %p", l->mDstKernel.get());
115            n = findNode(l->mDstKernel->mScript);
116            //ALOGE("  n1 %p", n);
117            n->mInputs.add(l);
118        } else {
119            n = findNode(l->mDstField->mScript);
120            //ALOGE("  n2 %p", n);
121            n->mInputs.add(l);
122        }
123    }
124
125    //ALOGE("node count %i", (int)mNodes.size());
126    // Order nodes
127    bool ret = true;
128    for (size_t ct=0; ct < mNodes.size(); ct++) {
129        Node *n = mNodes[ct];
130        if (n->mInputs.size() == 0) {
131            for (size_t ct2=0; ct2 < mNodes.size(); ct2++) {
132                mNodes[ct2]->mSeen = false;
133            }
134            ret &= calcOrderRecurse(n, 0);
135        }
136    }
137
138    for (size_t ct=0; ct < mKernels.size(); ct++) {
139        const ScriptKernelID *k = mKernels[ct].get();
140        const Node *n = findNode(k->mScript);
141
142        if (k->mHasKernelOutput) {
143            bool found = false;
144            for (size_t ct2=0; ct2 < n->mOutputs.size(); ct2++) {
145                if (n->mOutputs[ct2]->mSource.get() == k) {
146                    found = true;
147                    break;
148                }
149            }
150            if (!found) {
151                //ALOGE("add io out %p", k);
152                mOutputs.add(new IO(k));
153            }
154        }
155
156        if (k->mHasKernelInput) {
157            bool found = false;
158            for (size_t ct2=0; ct2 < n->mInputs.size(); ct2++) {
159                if (n->mInputs[ct2]->mDstKernel.get() == k) {
160                    found = true;
161                    break;
162                }
163            }
164            if (!found) {
165                //ALOGE("add io in %p", k);
166                mInputs.add(new IO(k));
167            }
168        }
169    }
170
171    // sort
172    mNodes.sort(&CompareNodeForSort);
173
174    return ret;
175}
176
177ScriptGroup * ScriptGroup::create(Context *rsc,
178                           ScriptKernelID ** kernels, size_t kernelsSize,
179                           ScriptKernelID ** src, size_t srcSize,
180                           ScriptKernelID ** dstK, size_t dstKSize,
181                           ScriptFieldID  ** dstF, size_t dstFSize,
182                           const Type ** type, size_t typeSize) {
183
184    size_t kernelCount = kernelsSize / sizeof(ScriptKernelID *);
185    size_t linkCount = typeSize / sizeof(Type *);
186
187    //ALOGE("ScriptGroup::create kernels=%i  links=%i", (int)kernelCount, (int)linkCount);
188
189
190    // Start by counting unique kernel sources
191
192    ScriptGroup *sg = new ScriptGroup(rsc);
193
194    sg->mKernels.reserve(kernelCount);
195    for (size_t ct=0; ct < kernelCount; ct++) {
196        sg->mKernels.add(kernels[ct]);
197    }
198
199    sg->mLinks.reserve(linkCount);
200    for (size_t ct=0; ct < linkCount; ct++) {
201        Link *l = new Link();
202        l->mType = type[ct];
203        l->mSource = src[ct];
204        l->mDstField = dstF[ct];
205        l->mDstKernel = dstK[ct];
206        sg->mLinks.add(l);
207    }
208
209    sg->calcOrder();
210
211    // allocate links
212    for (size_t ct=0; ct < sg->mNodes.size(); ct++) {
213        const Node *n = sg->mNodes[ct];
214        for (size_t ct2=0; ct2 < n->mOutputs.size(); ct2++) {
215            Link *l = n->mOutputs[ct2];
216            if (l->mAlloc.get()) {
217                continue;
218            }
219            const ScriptKernelID *k = l->mSource.get();
220
221            Allocation * alloc = Allocation::createAllocation(rsc,
222                    l->mType.get(), RS_ALLOCATION_USAGE_SCRIPT);
223            l->mAlloc = alloc;
224
225            for (size_t ct3=ct2+1; ct3 < n->mOutputs.size(); ct3++) {
226                if (n->mOutputs[ct3]->mSource.get() == l->mSource.get()) {
227                    n->mOutputs[ct3]->mAlloc = alloc;
228                }
229            }
230        }
231    }
232
233    if (rsc->mHal.funcs.scriptgroup.init) {
234        rsc->mHal.funcs.scriptgroup.init(rsc, sg);
235    }
236    return sg;
237}
238
239void ScriptGroup::setInput(Context *rsc, ScriptKernelID *kid, Allocation *a) {
240    for (size_t ct=0; ct < mInputs.size(); ct++) {
241        if (mInputs[ct]->mKernel == kid) {
242            mInputs[ct]->mAlloc = a;
243
244            if (rsc->mHal.funcs.scriptgroup.setInput) {
245                rsc->mHal.funcs.scriptgroup.setInput(rsc, this, kid, a);
246            }
247            return;
248        }
249    }
250    rsAssert(!"ScriptGroup:setInput kid not found");
251}
252
253void ScriptGroup::setOutput(Context *rsc, ScriptKernelID *kid, Allocation *a) {
254    for (size_t ct=0; ct < mOutputs.size(); ct++) {
255        if (mOutputs[ct]->mKernel == kid) {
256            mOutputs[ct]->mAlloc = a;
257
258            if (rsc->mHal.funcs.scriptgroup.setOutput) {
259                rsc->mHal.funcs.scriptgroup.setOutput(rsc, this, kid, a);
260            }
261            return;
262        }
263    }
264    rsAssert(!"ScriptGroup:setOutput kid not found");
265}
266
267void ScriptGroup::execute(Context *rsc) {
268    //ALOGE("ScriptGroup::execute");
269    if (rsc->mHal.funcs.scriptgroup.execute) {
270        rsc->mHal.funcs.scriptgroup.execute(rsc, this);
271        return;
272    }
273
274    for (size_t ct=0; ct < mNodes.size(); ct++) {
275        Node *n = mNodes[ct];
276        //ALOGE("node %i, order %i, in %i out %i", (int)ct, n->mOrder, (int)n->mInputs.size(), (int)n->mOutputs.size());
277
278        for (size_t ct2=0; ct2 < n->mKernels.size(); ct2++) {
279            const ScriptKernelID *k = n->mKernels[ct2];
280            Allocation *ain = NULL;
281            Allocation *aout = NULL;
282
283            for (size_t ct3=0; ct3 < n->mInputs.size(); ct3++) {
284                if (n->mInputs[ct3]->mDstKernel.get() == k) {
285                    ain = n->mInputs[ct3]->mAlloc.get();
286                    //ALOGE(" link in %p", ain);
287                }
288            }
289            for (size_t ct3=0; ct3 < mInputs.size(); ct3++) {
290                if (mInputs[ct3]->mKernel == k) {
291                    ain = mInputs[ct3]->mAlloc.get();
292                    //ALOGE(" io in %p", ain);
293                }
294            }
295
296            for (size_t ct3=0; ct3 < n->mOutputs.size(); ct3++) {
297                if (n->mOutputs[ct3]->mSource.get() == k) {
298                    aout = n->mOutputs[ct3]->mAlloc.get();
299                    //ALOGE(" link out %p", aout);
300                }
301            }
302            for (size_t ct3=0; ct3 < mOutputs.size(); ct3++) {
303                if (mOutputs[ct3]->mKernel == k) {
304                    aout = mOutputs[ct3]->mAlloc.get();
305                    //ALOGE(" io out %p", aout);
306                }
307            }
308
309            n->mScript->runForEach(rsc, k->mSlot, ain, aout, NULL, 0);
310        }
311
312    }
313
314}
315
316void ScriptGroup::serialize(Context *rsc, OStream *stream) const {
317}
318
319RsA3DClassID ScriptGroup::getClassId() const {
320    return RS_A3D_CLASS_ID_SCRIPT_GROUP;
321}
322
323ScriptGroup::Link::Link() {
324}
325
326ScriptGroup::Link::~Link() {
327}
328
329namespace android {
330namespace renderscript {
331
332
333RsScriptGroup rsi_ScriptGroupCreate(Context *rsc,
334                           RsScriptKernelID * kernels, size_t kernelsSize,
335                           RsScriptKernelID * src, size_t srcSize,
336                           RsScriptKernelID * dstK, size_t dstKSize,
337                           RsScriptFieldID * dstF, size_t dstFSize,
338                           const RsType * type, size_t typeSize) {
339
340
341    return ScriptGroup::create(rsc,
342                               (ScriptKernelID **) kernels, kernelsSize,
343                               (ScriptKernelID **) src, srcSize,
344                               (ScriptKernelID **) dstK, dstKSize,
345                               (ScriptFieldID  **) dstF, dstFSize,
346                               (const Type **) type, typeSize);
347}
348
349
350void rsi_ScriptGroupSetInput(Context *rsc, RsScriptGroup sg, RsScriptKernelID kid,
351        RsAllocation alloc) {
352    //ALOGE("rsi_ScriptGroupSetInput");
353    ScriptGroup *s = (ScriptGroup *)sg;
354    s->setInput(rsc, (ScriptKernelID *)kid, (Allocation *)alloc);
355}
356
357void rsi_ScriptGroupSetOutput(Context *rsc, RsScriptGroup sg, RsScriptKernelID kid,
358        RsAllocation alloc) {
359    //ALOGE("rsi_ScriptGroupSetOutput");
360    ScriptGroup *s = (ScriptGroup *)sg;
361    s->setOutput(rsc, (ScriptKernelID *)kid, (Allocation *)alloc);
362}
363
364void rsi_ScriptGroupExecute(Context *rsc, RsScriptGroup sg) {
365    //ALOGE("rsi_ScriptGroupExecute");
366    ScriptGroup *s = (ScriptGroup *)sg;
367    s->execute(rsc);
368}
369
370}
371}
372
373