1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "rsScriptGroup.h"
18
19#include "rsContext.h"
20// TODO: Is this header needed here?
21#include "rsScriptGroup2.h"
22
23#include <algorithm>
24#include <time.h>
25
26using namespace android;
27using namespace android::renderscript;
28
29ScriptGroup::ScriptGroup(Context *rsc) : ScriptGroupBase(rsc) {
30}
31
32ScriptGroup::~ScriptGroup() {
33    if (mRSC->mHal.funcs.scriptgroup.destroy) {
34        mRSC->mHal.funcs.scriptgroup.destroy(mRSC, this);
35    }
36
37    for (size_t ct=0; ct < mLinks.size(); ct++) {
38        delete mLinks[ct];
39    }
40
41    for (auto input : mInputs) {
42        input->mAlloc.clear();
43    }
44
45    for (auto output : mOutputs) {
46        output->mAlloc.clear();
47    }
48}
49
50ScriptGroup::IO::IO(const ScriptKernelID *kid) {
51    mKernel = kid;
52}
53
54ScriptGroup::Node::Node(Script *s) {
55    mScript = s;
56    mSeen = false;
57    mOrder = 0;
58}
59
60ScriptGroup::Node * ScriptGroup::findNode(Script *s) const {
61    //ALOGE("find %p   %i", s, (int)mNodes.size());
62    for (size_t ct=0; ct < mNodes.size(); ct++) {
63        Node *n = mNodes[ct];
64        for (size_t ct2=0; ct2 < n->mKernels.size(); ct2++) {
65            if (n->mKernels[ct2]->mScript == s) {
66                return n;
67            }
68        }
69    }
70
71    return nullptr;
72}
73
74bool ScriptGroup::calcOrderRecurse(Node *n, int depth) {
75    n->mSeen = true;
76    if (n->mOrder < depth) {
77        n->mOrder = depth;
78    }
79    bool ret = true;
80
81    for (size_t ct=0; ct < n->mOutputs.size(); ct++) {
82        const Link *l = n->mOutputs[ct];
83        Node *nt = NULL;
84        if (l->mDstField.get()) {
85            nt = findNode(l->mDstField->mScript);
86        } else {
87            nt = findNode(l->mDstKernel->mScript);
88        }
89        if (nt->mSeen) {
90            return false;
91        }
92        ret &= calcOrderRecurse(nt, n->mOrder + 1);
93    }
94    return ret;
95}
96
97#if !defined(RS_SERVER) && !defined(RS_COMPATIBILITY_LIB)
98static int CompareNodeForSort(ScriptGroup::Node *const* lhs,
99                              ScriptGroup::Node *const* rhs) {
100    if (lhs[0]->mOrder > rhs[0]->mOrder) {
101        return 1;
102    }
103    return 0;
104}
105#else
106class NodeCompare {
107public:
108    bool operator() (const ScriptGroup::Node* lhs,
109                     const ScriptGroup::Node* rhs) {
110        if (lhs->mOrder > rhs->mOrder) {
111            return true;
112        }
113        return false;
114    }
115};
116#endif
117
118bool ScriptGroup::calcOrder() {
119    // Make nodes
120
121    for (size_t ct=0; ct < mKernels.size(); ct++) {
122        const ScriptKernelID *k = mKernels[ct].get();
123        //ALOGE(" kernel %i, %p  s=%p", (int)ct, k, mKernels[ct]->mScript);
124        Node *n = findNode(k->mScript);
125        //ALOGE("    n = %p", n);
126        if (n == NULL) {
127            n = new Node(k->mScript);
128            mNodes.add(n);
129        }
130        n->mKernels.add(k);
131    }
132
133    // add links
134    //ALOGE("link count %i", (int)mLinks.size());
135    for (size_t ct=0; ct < mLinks.size(); ct++) {
136        Link *l = mLinks[ct];
137        //ALOGE("link  %i %p", (int)ct, l);
138        Node *n = findNode(l->mSource->mScript);
139        //ALOGE("link n %p", n);
140        n->mOutputs.add(l);
141
142        if (l->mDstKernel.get()) {
143            //ALOGE("l->mDstKernel.get() %p", l->mDstKernel.get());
144            n = findNode(l->mDstKernel->mScript);
145            //ALOGE("  n1 %p", n);
146            n->mInputs.add(l);
147        } else {
148            n = findNode(l->mDstField->mScript);
149            //ALOGE("  n2 %p", n);
150            n->mInputs.add(l);
151        }
152    }
153
154    //ALOGE("node count %i", (int)mNodes.size());
155    // Order nodes
156    bool ret = true;
157    for (size_t ct=0; ct < mNodes.size(); ct++) {
158        Node *n = mNodes[ct];
159        if (n->mInputs.size() == 0) {
160            for (size_t ct2=0; ct2 < mNodes.size(); ct2++) {
161                mNodes[ct2]->mSeen = false;
162            }
163            ret &= calcOrderRecurse(n, 0);
164        }
165    }
166
167    for (size_t ct=0; ct < mKernels.size(); ct++) {
168        const ScriptKernelID *k = mKernels[ct].get();
169        const Node *n = findNode(k->mScript);
170
171        if (k->mHasKernelOutput) {
172            bool found = false;
173            for (size_t ct2=0; ct2 < n->mOutputs.size(); ct2++) {
174                if (n->mOutputs[ct2]->mSource.get() == k) {
175                    found = true;
176                    break;
177                }
178            }
179            if (!found) {
180                //ALOGE("add io out %p", k);
181                mOutputs.add(new IO(k));
182            }
183        }
184
185        if (k->mHasKernelInput) {
186            bool found = false;
187            for (size_t ct2=0; ct2 < n->mInputs.size(); ct2++) {
188                if (n->mInputs[ct2]->mDstKernel.get() == k) {
189                    found = true;
190                    break;
191                }
192            }
193            if (!found) {
194                //ALOGE("add io in %p", k);
195                mInputs.add(new IO(k));
196            }
197        }
198    }
199
200    // sort
201#if !defined(RS_SERVER) && !defined(RS_COMPATIBILITY_LIB)
202    mNodes.sort(&CompareNodeForSort);
203#else
204    std::sort(mNodes.begin(), mNodes.end(), NodeCompare());
205#endif
206
207    return ret;
208}
209
210ScriptGroup * ScriptGroup::create(Context *rsc,
211                           ScriptKernelID ** kernels, size_t kernelsSize,
212                           ScriptKernelID ** src, size_t srcSize,
213                           ScriptKernelID ** dstK, size_t dstKSize,
214                           ScriptFieldID  ** dstF, size_t dstFSize,
215                           const Type ** type, size_t typeSize) {
216
217    size_t kernelCount = kernelsSize / sizeof(ScriptKernelID *);
218    size_t linkCount = typeSize / sizeof(Type *);
219
220    //ALOGE("ScriptGroup::create kernels=%i  links=%i", (int)kernelCount, (int)linkCount);
221
222
223    // Start by counting unique kernel sources
224
225    ScriptGroup *sg = new ScriptGroup(rsc);
226
227    sg->mKernels.reserve(kernelCount);
228    for (size_t ct=0; ct < kernelCount; ct++) {
229        sg->mKernels.add(kernels[ct]);
230    }
231
232    sg->mLinks.reserve(linkCount);
233    for (size_t ct=0; ct < linkCount; ct++) {
234        Link *l = new Link();
235        l->mType = type[ct];
236        l->mSource = src[ct];
237        l->mDstField = dstF[ct];
238        l->mDstKernel = dstK[ct];
239        sg->mLinks.add(l);
240    }
241
242    sg->calcOrder();
243
244    // allocate links
245    for (size_t ct=0; ct < sg->mNodes.size(); ct++) {
246        const Node *n = sg->mNodes[ct];
247        for (size_t ct2=0; ct2 < n->mOutputs.size(); ct2++) {
248            Link *l = n->mOutputs[ct2];
249            if (l->mAlloc.get()) {
250                continue;
251            }
252            const ScriptKernelID *k = l->mSource.get();
253
254            Allocation * alloc = Allocation::createAllocation(rsc,
255                    l->mType.get(), RS_ALLOCATION_USAGE_SCRIPT);
256            l->mAlloc = alloc;
257
258            for (size_t ct3=ct2+1; ct3 < n->mOutputs.size(); ct3++) {
259                if (n->mOutputs[ct3]->mSource.get() == l->mSource.get()) {
260                    n->mOutputs[ct3]->mAlloc = alloc;
261                }
262            }
263        }
264    }
265
266    if (rsc->mHal.funcs.scriptgroup.init) {
267        rsc->mHal.funcs.scriptgroup.init(rsc, sg);
268    }
269    sg->incUserRef();
270    return sg;
271}
272
273void ScriptGroup::setInput(Context *rsc, ScriptKernelID *kid, Allocation *a) {
274    for (size_t ct=0; ct < mInputs.size(); ct++) {
275        if (mInputs[ct]->mKernel == kid) {
276            mInputs[ct]->mAlloc = a;
277
278            if (rsc->mHal.funcs.scriptgroup.setInput) {
279                rsc->mHal.funcs.scriptgroup.setInput(rsc, this, kid, a);
280            }
281            return;
282        }
283    }
284    rsAssert(!"ScriptGroup:setInput kid not found");
285}
286
287void ScriptGroup::setOutput(Context *rsc, ScriptKernelID *kid, Allocation *a) {
288    for (size_t ct=0; ct < mOutputs.size(); ct++) {
289        if (mOutputs[ct]->mKernel == kid) {
290            mOutputs[ct]->mAlloc = a;
291
292            if (rsc->mHal.funcs.scriptgroup.setOutput) {
293                rsc->mHal.funcs.scriptgroup.setOutput(rsc, this, kid, a);
294            }
295            return;
296        }
297    }
298    rsAssert(!"ScriptGroup:setOutput kid not found");
299}
300
301bool ScriptGroup::validateInputAndOutput(Context *rsc) {
302    for(size_t i = 0; i < mInputs.size(); i++) {
303        if (mInputs[i]->mAlloc.get() == nullptr) {
304            rsc->setError(RS_ERROR_BAD_VALUE, "ScriptGroup missing input.");
305            return false;
306        }
307    }
308
309    for(size_t i = 0; i < mOutputs.size(); i++) {
310        if (mOutputs[i]->mAlloc.get() == nullptr) {
311            rsc->setError(RS_ERROR_BAD_VALUE, "ScriptGroup missing output.");
312            return false;
313        }
314    }
315
316    return true;
317}
318
319void ScriptGroup::execute(Context *rsc) {
320    if (!validateInputAndOutput(rsc)) {
321        return;
322    }
323
324    if (rsc->mHal.funcs.scriptgroup.execute) {
325        rsc->mHal.funcs.scriptgroup.execute(rsc, this);
326        return;
327    }
328
329    for (size_t ct=0; ct < mNodes.size(); ct++) {
330        Node *n = mNodes[ct];
331        //ALOGE("node %i, order %i, in %i out %i", (int)ct, n->mOrder, (int)n->mInputs.size(), (int)n->mOutputs.size());
332
333        for (size_t ct2=0; ct2 < n->mKernels.size(); ct2++) {
334            const ScriptKernelID *k = n->mKernels[ct2];
335            Allocation *ain = NULL;
336            Allocation *aout = NULL;
337
338            for (size_t ct3=0; ct3 < n->mInputs.size(); ct3++) {
339                if (n->mInputs[ct3]->mDstKernel.get() == k) {
340                    ain = n->mInputs[ct3]->mAlloc.get();
341                    //ALOGE(" link in %p", ain);
342                }
343            }
344            for (size_t ct3=0; ct3 < mInputs.size(); ct3++) {
345                if (mInputs[ct3]->mKernel == k) {
346                    ain = mInputs[ct3]->mAlloc.get();
347                    //ALOGE(" io in %p", ain);
348                }
349            }
350
351            for (size_t ct3=0; ct3 < n->mOutputs.size(); ct3++) {
352                if (n->mOutputs[ct3]->mSource.get() == k) {
353                    aout = n->mOutputs[ct3]->mAlloc.get();
354                    //ALOGE(" link out %p", aout);
355                }
356            }
357            for (size_t ct3=0; ct3 < mOutputs.size(); ct3++) {
358                if (mOutputs[ct3]->mKernel == k) {
359                    aout = mOutputs[ct3]->mAlloc.get();
360                    //ALOGE(" io out %p", aout);
361                }
362            }
363
364            if (ain == NULL) {
365                n->mScript->runForEach(rsc, k->mSlot, NULL, 0, aout, NULL, 0);
366
367            } else {
368                const Allocation *ains[1] = {ain};
369                n->mScript->runForEach(rsc, k->mSlot, ains,
370                                       sizeof(ains) / sizeof(RsAllocation),
371                                       aout, NULL, 0);
372            }
373        }
374
375    }
376
377}
378
379ScriptGroup::Link::Link() {
380}
381
382ScriptGroup::Link::~Link() {
383}
384
385namespace android {
386namespace renderscript {
387
388
389RsScriptGroup rsi_ScriptGroupCreate(Context *rsc,
390                           RsScriptKernelID * kernels, size_t kernelsSize,
391                           RsScriptKernelID * src, size_t srcSize,
392                           RsScriptKernelID * dstK, size_t dstKSize,
393                           RsScriptFieldID * dstF, size_t dstFSize,
394                           const RsType * type, size_t typeSize) {
395
396
397    return ScriptGroup::create(rsc,
398                               (ScriptKernelID **) kernels, kernelsSize,
399                               (ScriptKernelID **) src, srcSize,
400                               (ScriptKernelID **) dstK, dstKSize,
401                               (ScriptFieldID  **) dstF, dstFSize,
402                               (const Type **) type, typeSize);
403}
404
405
406void rsi_ScriptGroupSetInput(Context *rsc, RsScriptGroup sg, RsScriptKernelID kid,
407        RsAllocation alloc) {
408    //ALOGE("rsi_ScriptGroupSetInput");
409    ScriptGroup *s = (ScriptGroup *)sg;
410    s->setInput(rsc, (ScriptKernelID *)kid, (Allocation *)alloc);
411}
412
413void rsi_ScriptGroupSetOutput(Context *rsc, RsScriptGroup sg, RsScriptKernelID kid,
414        RsAllocation alloc) {
415    //ALOGE("rsi_ScriptGroupSetOutput");
416    ScriptGroup *s = (ScriptGroup *)sg;
417    s->setOutput(rsc, (ScriptKernelID *)kid, (Allocation *)alloc);
418}
419
420void rsi_ScriptGroupExecute(Context *rsc, RsScriptGroup sg) {
421    ScriptGroupBase *s = (ScriptGroupBase *)sg;
422    s->execute(rsc);
423}
424
425}
426}
427