1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "rsScriptGroup.h"
18
19#include "rsContext.h"
20// TODO: Is this header needed here?
21#include "rsScriptGroup2.h"
22
23#include <algorithm>
24#include <time.h>
25
26namespace android {
27namespace renderscript {
28
29ScriptGroup::ScriptGroup(Context *rsc) : ScriptGroupBase(rsc) {
30}
31
32ScriptGroup::~ScriptGroup() {
33    if (mRSC->mHal.funcs.scriptgroup.destroy) {
34        mRSC->mHal.funcs.scriptgroup.destroy(mRSC, this);
35    }
36
37    for (size_t ct=0; ct < mLinks.size(); ct++) {
38        delete mLinks[ct];
39    }
40
41    for (auto input : mInputs) {
42        input->mAlloc.clear();
43    }
44
45    for (auto output : mOutputs) {
46        output->mAlloc.clear();
47    }
48}
49
50ScriptGroup::IO::IO(const ScriptKernelID *kid) {
51    mKernel = kid;
52}
53
54ScriptGroup::Node::Node(Script *s) {
55    mScript = s;
56    mSeen = false;
57    mOrder = 0;
58}
59
60ScriptGroup::Node * ScriptGroup::findNode(Script *s) const {
61    //ALOGE("find %p   %i", s, (int)mNodes.size());
62    for (size_t ct=0; ct < mNodes.size(); ct++) {
63        Node *n = mNodes[ct];
64        for (size_t ct2=0; ct2 < n->mKernels.size(); ct2++) {
65            if (n->mKernels[ct2]->mScript == s) {
66                return n;
67            }
68        }
69    }
70
71    return nullptr;
72}
73
74bool ScriptGroup::calcOrderRecurse(Node *n, int depth) {
75    n->mSeen = true;
76    if (n->mOrder < depth) {
77        n->mOrder = depth;
78    }
79    bool ret = true;
80
81    for (size_t ct=0; ct < n->mOutputs.size(); ct++) {
82        const Link *l = n->mOutputs[ct];
83        Node *nt = NULL;
84        if (l->mDstField.get()) {
85            nt = findNode(l->mDstField->mScript);
86        } else {
87            nt = findNode(l->mDstKernel->mScript);
88        }
89        if (nt->mSeen) {
90            return false;
91        }
92        ret &= calcOrderRecurse(nt, n->mOrder + 1);
93    }
94    return ret;
95}
96
97class NodeCompare {
98public:
99    bool operator() (const ScriptGroup::Node* lhs,
100                     const ScriptGroup::Node* rhs) {
101        return (lhs->mOrder < rhs->mOrder);
102    }
103};
104
105bool ScriptGroup::calcOrder() {
106    // Make nodes
107
108    for (size_t ct=0; ct < mKernels.size(); ct++) {
109        const ScriptKernelID *k = mKernels[ct].get();
110        //ALOGE(" kernel %i, %p  s=%p", (int)ct, k, mKernels[ct]->mScript);
111        Node *n = findNode(k->mScript);
112        //ALOGE("    n = %p", n);
113        if (n == NULL) {
114            n = new Node(k->mScript);
115            mNodes.push_back(n);
116        }
117        n->mKernels.push_back(k);
118    }
119
120    // add links
121    //ALOGE("link count %i", (int)mLinks.size());
122    for (size_t ct=0; ct < mLinks.size(); ct++) {
123        Link *l = mLinks[ct];
124        //ALOGE("link  %i %p", (int)ct, l);
125        Node *n = findNode(l->mSource->mScript);
126        //ALOGE("link n %p", n);
127        n->mOutputs.push_back(l);
128
129        if (l->mDstKernel.get()) {
130            //ALOGE("l->mDstKernel.get() %p", l->mDstKernel.get());
131            n = findNode(l->mDstKernel->mScript);
132            //ALOGE("  n1 %p", n);
133            n->mInputs.push_back(l);
134        } else {
135            n = findNode(l->mDstField->mScript);
136            //ALOGE("  n2 %p", n);
137            n->mInputs.push_back(l);
138        }
139    }
140
141    //ALOGE("node count %i", (int)mNodes.size());
142    // Order nodes
143    bool ret = true;
144    for (size_t ct=0; ct < mNodes.size(); ct++) {
145        Node *n = mNodes[ct];
146        if (n->mInputs.size() == 0) {
147            for (size_t ct2=0; ct2 < mNodes.size(); ct2++) {
148                mNodes[ct2]->mSeen = false;
149            }
150            ret &= calcOrderRecurse(n, 0);
151        }
152    }
153
154    for (size_t ct=0; ct < mKernels.size(); ct++) {
155        const ScriptKernelID *k = mKernels[ct].get();
156        const Node *n = findNode(k->mScript);
157
158        if (k->mHasKernelOutput) {
159            bool found = false;
160            for (size_t ct2=0; ct2 < n->mOutputs.size(); ct2++) {
161                if (n->mOutputs[ct2]->mSource.get() == k) {
162                    found = true;
163                    break;
164                }
165            }
166            if (!found) {
167                //ALOGE("add io out %p", k);
168                mOutputs.push_back(new IO(k));
169            }
170        }
171
172        if (k->mHasKernelInput) {
173            bool found = false;
174            for (size_t ct2=0; ct2 < n->mInputs.size(); ct2++) {
175                if (n->mInputs[ct2]->mDstKernel.get() == k) {
176                    found = true;
177                    break;
178                }
179            }
180            if (!found) {
181                //ALOGE("add io in %p", k);
182                mInputs.push_back(new IO(k));
183            }
184        }
185    }
186
187    // Sort mNodes in the increasing order.
188    std::sort(mNodes.begin(), mNodes.end(), NodeCompare());
189    return ret;
190}
191
192ScriptGroup * ScriptGroup::create(Context *rsc,
193                           ScriptKernelID ** kernels, size_t kernelsSize,
194                           ScriptKernelID ** src, size_t srcSize,
195                           ScriptKernelID ** dstK, size_t dstKSize,
196                           ScriptFieldID  ** dstF, size_t dstFSize,
197                           const Type ** type, size_t typeSize) {
198
199    size_t kernelCount = kernelsSize / sizeof(ScriptKernelID *);
200    size_t linkCount = typeSize / sizeof(Type *);
201
202    //ALOGE("ScriptGroup::create kernels=%i  links=%i", (int)kernelCount, (int)linkCount);
203
204
205    // Start by counting unique kernel sources
206
207    ScriptGroup *sg = new ScriptGroup(rsc);
208
209    sg->mKernels.reserve(kernelCount);
210    for (size_t ct=0; ct < kernelCount; ct++) {
211        sg->mKernels.push_back(kernels[ct]);
212    }
213
214    sg->mLinks.reserve(linkCount);
215    for (size_t ct=0; ct < linkCount; ct++) {
216        Link *l = new Link();
217        l->mType = type[ct];
218        l->mSource = src[ct];
219        l->mDstField = dstF[ct];
220        l->mDstKernel = dstK[ct];
221        sg->mLinks.push_back(l);
222    }
223
224    sg->calcOrder();
225
226    // allocate links
227    for (size_t ct=0; ct < sg->mNodes.size(); ct++) {
228        const Node *n = sg->mNodes[ct];
229        for (size_t ct2=0; ct2 < n->mOutputs.size(); ct2++) {
230            Link *l = n->mOutputs[ct2];
231            if (l->mAlloc.get()) {
232                continue;
233            }
234            Allocation * alloc = Allocation::createAllocation(rsc,
235                    l->mType.get(), RS_ALLOCATION_USAGE_SCRIPT);
236            l->mAlloc = alloc;
237
238            for (size_t ct3=ct2+1; ct3 < n->mOutputs.size(); ct3++) {
239                if (n->mOutputs[ct3]->mSource.get() == l->mSource.get()) {
240                    n->mOutputs[ct3]->mAlloc = alloc;
241                }
242            }
243        }
244    }
245
246    if (rsc->mHal.funcs.scriptgroup.init) {
247        rsc->mHal.funcs.scriptgroup.init(rsc, sg);
248    }
249    sg->incUserRef();
250    return sg;
251}
252
253void ScriptGroup::setInput(Context *rsc, ScriptKernelID *kid, Allocation *a) {
254    for (size_t ct=0; ct < mInputs.size(); ct++) {
255        if (mInputs[ct]->mKernel == kid) {
256            mInputs[ct]->mAlloc = a;
257
258            if (rsc->mHal.funcs.scriptgroup.setInput) {
259                rsc->mHal.funcs.scriptgroup.setInput(rsc, this, kid, a);
260            }
261            return;
262        }
263    }
264    rsAssert(!"ScriptGroup:setInput kid not found");
265}
266
267void ScriptGroup::setOutput(Context *rsc, ScriptKernelID *kid, Allocation *a) {
268    for (size_t ct=0; ct < mOutputs.size(); ct++) {
269        if (mOutputs[ct]->mKernel == kid) {
270            mOutputs[ct]->mAlloc = a;
271
272            if (rsc->mHal.funcs.scriptgroup.setOutput) {
273                rsc->mHal.funcs.scriptgroup.setOutput(rsc, this, kid, a);
274            }
275            return;
276        }
277    }
278    rsAssert(!"ScriptGroup:setOutput kid not found");
279}
280
281bool ScriptGroup::validateInputAndOutput(Context *rsc) {
282    for(size_t i = 0; i < mInputs.size(); i++) {
283        if (mInputs[i]->mAlloc.get() == nullptr) {
284            rsc->setError(RS_ERROR_BAD_VALUE, "ScriptGroup missing input.");
285            return false;
286        }
287    }
288
289    for(size_t i = 0; i < mOutputs.size(); i++) {
290        if (mOutputs[i]->mAlloc.get() == nullptr) {
291            rsc->setError(RS_ERROR_BAD_VALUE, "ScriptGroup missing output.");
292            return false;
293        }
294    }
295
296    return true;
297}
298
299void ScriptGroup::execute(Context *rsc) {
300    if (!validateInputAndOutput(rsc)) {
301        return;
302    }
303
304    if (rsc->mHal.funcs.scriptgroup.execute) {
305        rsc->mHal.funcs.scriptgroup.execute(rsc, this);
306        return;
307    }
308
309    for (size_t ct=0; ct < mNodes.size(); ct++) {
310        Node *n = mNodes[ct];
311        //ALOGE("node %i, order %i, in %i out %i", (int)ct, n->mOrder, (int)n->mInputs.size(), (int)n->mOutputs.size());
312
313        for (size_t ct2=0; ct2 < n->mKernels.size(); ct2++) {
314            const ScriptKernelID *k = n->mKernels[ct2];
315            Allocation *ain = NULL;
316            Allocation *aout = NULL;
317
318            for (size_t ct3=0; ct3 < n->mInputs.size(); ct3++) {
319                if (n->mInputs[ct3]->mDstKernel.get() == k) {
320                    ain = n->mInputs[ct3]->mAlloc.get();
321                    //ALOGE(" link in %p", ain);
322                }
323            }
324            for (size_t ct3=0; ct3 < mInputs.size(); ct3++) {
325                if (mInputs[ct3]->mKernel == k) {
326                    ain = mInputs[ct3]->mAlloc.get();
327                    //ALOGE(" io in %p", ain);
328                }
329            }
330
331            for (size_t ct3=0; ct3 < n->mOutputs.size(); ct3++) {
332                if (n->mOutputs[ct3]->mSource.get() == k) {
333                    aout = n->mOutputs[ct3]->mAlloc.get();
334                    //ALOGE(" link out %p", aout);
335                }
336            }
337            for (size_t ct3=0; ct3 < mOutputs.size(); ct3++) {
338                if (mOutputs[ct3]->mKernel == k) {
339                    aout = mOutputs[ct3]->mAlloc.get();
340                    //ALOGE(" io out %p", aout);
341                }
342            }
343
344            if (ain == NULL) {
345                n->mScript->runForEach(rsc, k->mSlot, NULL, 0, aout, NULL, 0);
346
347            } else {
348                const Allocation *ains[1] = {ain};
349                n->mScript->runForEach(rsc, k->mSlot, ains,
350                                       sizeof(ains) / sizeof(RsAllocation),
351                                       aout, NULL, 0);
352            }
353        }
354
355    }
356
357}
358
359ScriptGroup::Link::Link() {
360}
361
362ScriptGroup::Link::~Link() {
363}
364
365
366RsScriptGroup rsi_ScriptGroupCreate(Context *rsc,
367                           RsScriptKernelID * kernels, size_t kernelsSize,
368                           RsScriptKernelID * src, size_t srcSize,
369                           RsScriptKernelID * dstK, size_t dstKSize,
370                           RsScriptFieldID * dstF, size_t dstFSize,
371                           const RsType * type, size_t typeSize) {
372
373
374    return ScriptGroup::create(rsc,
375                               (ScriptKernelID **) kernels, kernelsSize,
376                               (ScriptKernelID **) src, srcSize,
377                               (ScriptKernelID **) dstK, dstKSize,
378                               (ScriptFieldID  **) dstF, dstFSize,
379                               (const Type **) type, typeSize);
380}
381
382
383void rsi_ScriptGroupSetInput(Context *rsc, RsScriptGroup sg, RsScriptKernelID kid,
384        RsAllocation alloc) {
385    //ALOGE("rsi_ScriptGroupSetInput");
386    ScriptGroup *s = (ScriptGroup *)sg;
387    s->setInput(rsc, (ScriptKernelID *)kid, (Allocation *)alloc);
388}
389
390void rsi_ScriptGroupSetOutput(Context *rsc, RsScriptGroup sg, RsScriptKernelID kid,
391        RsAllocation alloc) {
392    //ALOGE("rsi_ScriptGroupSetOutput");
393    ScriptGroup *s = (ScriptGroup *)sg;
394    s->setOutput(rsc, (ScriptKernelID *)kid, (Allocation *)alloc);
395}
396
397void rsi_ScriptGroupExecute(Context *rsc, RsScriptGroup sg) {
398    ScriptGroupBase *s = (ScriptGroupBase *)sg;
399    s->execute(rsc);
400}
401
402} // namespace renderscript
403} // namespace android
404