1/*M///////////////////////////////////////////////////////////////////////////////////////
2//
3//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4//
5//  By downloading, copying, installing or using the software you agree to this license.
6//  If you do not agree to this license, do not download, install,
7//  copy or use the software.
8//
9//
10//                           License Agreement
11//                For Open Source Computer Vision Library
12//
13// Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
14// Copyright (C) 2009, Willow Garage Inc., all rights reserved.
15// Third party copyrights are property of their respective owners.
16//
17// Redistribution and use in source and binary forms, with or without modification,
18// are permitted provided that the following conditions are met:
19//
20//   * Redistribution's of source code must retain the above copyright notice,
21//     this list of conditions and the following disclaimer.
22//
23//   * Redistribution's in binary form must reproduce the above copyright notice,
24//     this list of conditions and the following disclaimer in the documentation
25//     and/or other materials provided with the distribution.
26//
27//   * The name of the copyright holders may not be used to endorse or promote products
28//     derived from this software without specific prior written permission.
29//
30// This software is provided by the copyright holders and contributors "as is" and
31// any express or implied warranties, including, but not limited to, the implied
32// warranties of merchantability and fitness for a particular purpose are disclaimed.
33// In no event shall the Intel Corporation or contributors be liable for any direct,
34// indirect, incidental, special, exemplary, or consequential damages
35// (including, but not limited to, procurement of substitute goods or services;
36// loss of use, data, or profits; or business interruption) however caused
37// and on any theory of liability, whether in contract, strict liability,
38// or tort (including negligence or otherwise) arising in any way out of
39// the use of this software, even if advised of the possibility of such damage.
40//
41//M*/
42
43#if !defined CUDA_DISABLER
44
45#include "opencv2/core/cuda/common.hpp"
46#include "opencv2/core/cuda/saturate_cast.hpp"
47#include "opencv2/core/cuda/limits.hpp"
48#include "opencv2/core/cuda/reduce.hpp"
49#include "opencv2/core/cuda/functional.hpp"
50
51#include "cuda/stereocsbp.hpp"
52
53namespace cv { namespace cuda { namespace device
54{
55    namespace stereocsbp
56    {
57        ///////////////////////////////////////////////////////////////
58        /////////////////////// init data cost ////////////////////////
59        ///////////////////////////////////////////////////////////////
60
61        template <int channels> static float __device__ pixeldiff(const uchar* left, const uchar* right, float max_data_term);
62        template<> __device__ __forceinline__ float pixeldiff<1>(const uchar* left, const uchar* right, float max_data_term)
63        {
64            return fminf( ::abs((int)*left - *right), max_data_term);
65        }
66        template<> __device__ __forceinline__ float pixeldiff<3>(const uchar* left, const uchar* right, float max_data_term)
67        {
68            float tb = 0.114f * ::abs((int)left[0] - right[0]);
69            float tg = 0.587f * ::abs((int)left[1] - right[1]);
70            float tr = 0.299f * ::abs((int)left[2] - right[2]);
71
72            return fminf(tr + tg + tb, max_data_term);
73        }
74        template<> __device__ __forceinline__ float pixeldiff<4>(const uchar* left, const uchar* right, float max_data_term)
75        {
76            uchar4 l = *((const uchar4*)left);
77            uchar4 r = *((const uchar4*)right);
78
79            float tb = 0.114f * ::abs((int)l.x - r.x);
80            float tg = 0.587f * ::abs((int)l.y - r.y);
81            float tr = 0.299f * ::abs((int)l.z - r.z);
82
83            return fminf(tr + tg + tb, max_data_term);
84        }
85
86        template <typename T>
87        __global__ void get_first_k_initial_global(uchar *ctemp, T* data_cost_selected_, T *selected_disp_pyr, int h, int w, int nr_plane, int ndisp,
88            size_t msg_step, size_t disp_step)
89        {
90            int x = blockIdx.x * blockDim.x + threadIdx.x;
91            int y = blockIdx.y * blockDim.y + threadIdx.y;
92
93            if (y < h && x < w)
94            {
95                T* selected_disparity = selected_disp_pyr + y * msg_step + x;
96                T* data_cost_selected = data_cost_selected_ + y * msg_step + x;
97                T* data_cost = (T*)ctemp + y * msg_step + x;
98
99                for(int i = 0; i < nr_plane; i++)
100                {
101                    T minimum = device::numeric_limits<T>::max();
102                    int id = 0;
103                    for(int d = 0; d < ndisp; d++)
104                    {
105                        T cur = data_cost[d * disp_step];
106                        if(cur < minimum)
107                        {
108                            minimum = cur;
109                            id = d;
110                        }
111                    }
112
113                    data_cost_selected[i  * disp_step] = minimum;
114                    selected_disparity[i  * disp_step] = id;
115                    data_cost         [id * disp_step] = numeric_limits<T>::max();
116                }
117            }
118        }
119
120
121        template <typename T>
122        __global__ void get_first_k_initial_local(uchar *ctemp, T* data_cost_selected_, T* selected_disp_pyr, int h, int w, int nr_plane, int ndisp,
123            size_t msg_step, size_t disp_step)
124        {
125            int x = blockIdx.x * blockDim.x + threadIdx.x;
126            int y = blockIdx.y * blockDim.y + threadIdx.y;
127
128            if (y < h && x < w)
129            {
130                T* selected_disparity = selected_disp_pyr + y * msg_step + x;
131                T* data_cost_selected = data_cost_selected_ + y * msg_step + x;
132                T* data_cost = (T*)ctemp + y * msg_step + x;
133
134                int nr_local_minimum = 0;
135
136                T prev = data_cost[0 * disp_step];
137                T cur  = data_cost[1 * disp_step];
138                T next = data_cost[2 * disp_step];
139
140                for (int d = 1; d < ndisp - 1 && nr_local_minimum < nr_plane; d++)
141                {
142                    if (cur < prev && cur < next)
143                    {
144                        data_cost_selected[nr_local_minimum * disp_step] = cur;
145                        selected_disparity[nr_local_minimum * disp_step] = d;
146
147                        data_cost[d * disp_step] = numeric_limits<T>::max();
148
149                        nr_local_minimum++;
150                    }
151                    prev = cur;
152                    cur = next;
153                    next = data_cost[(d + 1) * disp_step];
154                }
155
156                for (int i = nr_local_minimum; i < nr_plane; i++)
157                {
158                    T minimum = numeric_limits<T>::max();
159                    int id = 0;
160
161                    for (int d = 0; d < ndisp; d++)
162                    {
163                        cur = data_cost[d * disp_step];
164                        if (cur < minimum)
165                        {
166                            minimum = cur;
167                            id = d;
168                        }
169                    }
170                    data_cost_selected[i * disp_step] = minimum;
171                    selected_disparity[i * disp_step] = id;
172
173                    data_cost[id * disp_step] = numeric_limits<T>::max();
174                }
175            }
176        }
177
178        template <typename T, int channels>
179        __global__ void init_data_cost(const uchar *cleft, const uchar *cright, uchar *ctemp, size_t cimg_step,
180                                      int h, int w, int level, int ndisp, float data_weight, float max_data_term,
181                                      int min_disp, size_t msg_step, size_t disp_step)
182        {
183            int x = blockIdx.x * blockDim.x + threadIdx.x;
184            int y = blockIdx.y * blockDim.y + threadIdx.y;
185
186            if (y < h && x < w)
187            {
188                int y0 = y << level;
189                int yt = (y + 1) << level;
190
191                int x0 = x << level;
192                int xt = (x + 1) << level;
193
194                T* data_cost = (T*)ctemp + y * msg_step + x;
195
196                for(int d = 0; d < ndisp; ++d)
197                {
198                    float val = 0.0f;
199                    for(int yi = y0; yi < yt; yi++)
200                    {
201                        for(int xi = x0; xi < xt; xi++)
202                        {
203                            int xr = xi - d;
204                            if(d < min_disp || xr < 0)
205                                val += data_weight * max_data_term;
206                            else
207                            {
208                                const uchar* lle = cleft + yi * cimg_step + xi * channels;
209                                const uchar* lri = cright + yi * cimg_step + xr * channels;
210
211                                val += data_weight * pixeldiff<channels>(lle, lri, max_data_term);
212                            }
213                        }
214                    }
215                    data_cost[disp_step * d] = saturate_cast<T>(val);
216                }
217            }
218        }
219
220        template <typename T, int winsz, int channels>
221        __global__ void init_data_cost_reduce(const uchar *cleft, const uchar *cright, uchar *ctemp, size_t cimg_step,
222                                              int level, int rows, int cols, int h, int ndisp, float data_weight, float max_data_term,
223                                              int min_disp, size_t msg_step, size_t disp_step)
224        {
225            int x_out = blockIdx.x;
226            int y_out = blockIdx.y % h;
227            int d = (blockIdx.y / h) * blockDim.z + threadIdx.z;
228
229            int tid = threadIdx.x;
230
231            if (d < ndisp)
232            {
233                int x0 = x_out << level;
234                int y0 = y_out << level;
235
236                int len = ::min(y0 + winsz, rows) - y0;
237
238                float val = 0.0f;
239                if (x0 + tid < cols)
240                {
241                    if (x0 + tid - d < 0 || d < min_disp)
242                        val = data_weight * max_data_term * len;
243                    else
244                    {
245                        const uchar* lle =  cleft + y0 * cimg_step + channels * (x0 + tid    );
246                        const uchar* lri = cright + y0 * cimg_step + channels * (x0 + tid - d);
247
248                        for(int y = 0; y < len; ++y)
249                        {
250                            val += data_weight * pixeldiff<channels>(lle, lri, max_data_term);
251
252                            lle += cimg_step;
253                            lri += cimg_step;
254                        }
255                    }
256                }
257
258                extern __shared__ float smem[];
259
260                reduce<winsz>(smem + winsz * threadIdx.z, val, tid, plus<float>());
261
262                T* data_cost = (T*)ctemp + y_out * msg_step + x_out;
263
264                if (tid == 0)
265                    data_cost[disp_step * d] = saturate_cast<T>(val);
266            }
267        }
268
269
270        template <typename T>
271        void init_data_cost_caller_(const uchar *cleft, const uchar *cright, uchar *ctemp, size_t cimg_step, int /*rows*/, int /*cols*/, int h, int w, int level, int ndisp, int channels, float data_weight, float max_data_term, int min_disp, size_t msg_step, size_t disp_step, cudaStream_t stream)
272        {
273            dim3 threads(32, 8, 1);
274            dim3 grid(1, 1, 1);
275
276            grid.x = divUp(w, threads.x);
277            grid.y = divUp(h, threads.y);
278
279            switch (channels)
280            {
281            case 1: init_data_cost<T, 1><<<grid, threads, 0, stream>>>(cleft, cright, ctemp, cimg_step, h, w, level, ndisp, data_weight, max_data_term, min_disp, msg_step, disp_step); break;
282            case 3: init_data_cost<T, 3><<<grid, threads, 0, stream>>>(cleft, cright, ctemp, cimg_step, h, w, level, ndisp, data_weight, max_data_term, min_disp, msg_step, disp_step); break;
283            case 4: init_data_cost<T, 4><<<grid, threads, 0, stream>>>(cleft, cright, ctemp, cimg_step, h, w, level, ndisp, data_weight, max_data_term, min_disp, msg_step, disp_step); break;
284            default: CV_Error(cv::Error::BadNumChannels, "Unsupported channels count");
285            }
286        }
287
288        template <typename T, int winsz>
289        void init_data_cost_reduce_caller_(const uchar *cleft, const uchar *cright, uchar *ctemp, size_t cimg_step, int rows, int cols, int h, int w, int level, int ndisp, int channels, float data_weight, float max_data_term, int min_disp, size_t msg_step, size_t disp_step, cudaStream_t stream)
290        {
291            const int threadsNum = 256;
292            const size_t smem_size = threadsNum * sizeof(float);
293
294            dim3 threads(winsz, 1, threadsNum / winsz);
295            dim3 grid(w, h, 1);
296            grid.y *= divUp(ndisp, threads.z);
297
298            switch (channels)
299            {
300            case 1: init_data_cost_reduce<T, winsz, 1><<<grid, threads, smem_size, stream>>>(cleft, cright, ctemp, cimg_step, level, rows, cols, h, ndisp, data_weight, max_data_term, min_disp, msg_step, disp_step); break;
301            case 3: init_data_cost_reduce<T, winsz, 3><<<grid, threads, smem_size, stream>>>(cleft, cright, ctemp, cimg_step, level, rows, cols, h, ndisp, data_weight, max_data_term, min_disp, msg_step, disp_step); break;
302            case 4: init_data_cost_reduce<T, winsz, 4><<<grid, threads, smem_size, stream>>>(cleft, cright, ctemp, cimg_step, level, rows, cols, h, ndisp, data_weight, max_data_term, min_disp, msg_step, disp_step); break;
303            default: CV_Error(cv::Error::BadNumChannels, "Unsupported channels count");
304            }
305        }
306
307        template<class T>
308        void init_data_cost(const uchar *cleft, const uchar *cright, uchar *ctemp, size_t cimg_step, int rows, int cols, T* disp_selected_pyr, T* data_cost_selected, size_t msg_step,
309                    int h, int w, int level, int nr_plane, int ndisp, int channels, float data_weight, float max_data_term, int min_disp, bool use_local_init_data_cost, cudaStream_t stream)
310        {
311
312            typedef void (*InitDataCostCaller)(const uchar *cleft, const uchar *cright, uchar *ctemp, size_t cimg_step, int cols, int rows, int w, int h, int level, int ndisp, int channels, float data_weight, float max_data_term, int min_disp, size_t msg_step, size_t disp_step, cudaStream_t stream);
313
314            static const InitDataCostCaller init_data_cost_callers[] =
315            {
316                init_data_cost_caller_<T>, init_data_cost_caller_<T>, init_data_cost_reduce_caller_<T, 4>,
317                init_data_cost_reduce_caller_<T, 8>, init_data_cost_reduce_caller_<T, 16>, init_data_cost_reduce_caller_<T, 32>,
318                init_data_cost_reduce_caller_<T, 64>, init_data_cost_reduce_caller_<T, 128>, init_data_cost_reduce_caller_<T, 256>
319            };
320
321            size_t disp_step = msg_step * h;
322
323            init_data_cost_callers[level](cleft, cright, ctemp, cimg_step, rows, cols, h, w, level, ndisp, channels, data_weight, max_data_term, min_disp, msg_step, disp_step, stream);
324            cudaSafeCall( cudaGetLastError() );
325
326            if (stream == 0)
327                cudaSafeCall( cudaDeviceSynchronize() );
328
329            dim3 threads(32, 8, 1);
330            dim3 grid(1, 1, 1);
331
332            grid.x = divUp(w, threads.x);
333            grid.y = divUp(h, threads.y);
334
335            if (use_local_init_data_cost == true)
336                get_first_k_initial_local<<<grid, threads, 0, stream>>> (ctemp, data_cost_selected, disp_selected_pyr, h, w, nr_plane, ndisp, msg_step, disp_step);
337            else
338                get_first_k_initial_global<<<grid, threads, 0, stream>>>(ctemp, data_cost_selected, disp_selected_pyr, h, w, nr_plane, ndisp, msg_step, disp_step);
339
340            cudaSafeCall( cudaGetLastError() );
341
342            if (stream == 0)
343                cudaSafeCall( cudaDeviceSynchronize() );
344        }
345
346        template void init_data_cost<short>(const uchar *cleft, const uchar *cright, uchar *ctemp, size_t cimg_step, int rows, int cols, short* disp_selected_pyr, short* data_cost_selected, size_t msg_step,
347                    int h, int w, int level, int nr_plane, int ndisp, int channels, float data_weight, float max_data_term, int min_disp, bool use_local_init_data_cost, cudaStream_t stream);
348
349        template void init_data_cost<float>(const uchar *cleft, const uchar *cright, uchar *ctemp, size_t cimg_step, int rows, int cols, float* disp_selected_pyr, float* data_cost_selected, size_t msg_step,
350                    int h, int w, int level, int nr_plane, int ndisp, int channels, float data_weight, float max_data_term, int min_disp, bool use_local_init_data_cost, cudaStream_t stream);
351
352        ///////////////////////////////////////////////////////////////
353        ////////////////////// compute data cost //////////////////////
354        ///////////////////////////////////////////////////////////////
355
356        template <typename T, int channels>
357        __global__ void compute_data_cost(const uchar *cleft, const uchar *cright, size_t cimg_step, const T* selected_disp_pyr, T* data_cost_, int h, int w, int level, int nr_plane, float data_weight, float max_data_term, int min_disp, size_t msg_step, size_t disp_step1, size_t disp_step2)
358        {
359            int x = blockIdx.x * blockDim.x + threadIdx.x;
360            int y = blockIdx.y * blockDim.y + threadIdx.y;
361
362            if (y < h && x < w)
363            {
364                int y0 = y << level;
365                int yt = (y + 1) << level;
366
367                int x0 = x << level;
368                int xt = (x + 1) << level;
369
370                const T* selected_disparity = selected_disp_pyr + y/2 * msg_step + x/2;
371                T* data_cost = data_cost_ + y * msg_step + x;
372
373                for(int d = 0; d < nr_plane; d++)
374                {
375                    float val = 0.0f;
376                    for(int yi = y0; yi < yt; yi++)
377                    {
378                        for(int xi = x0; xi < xt; xi++)
379                        {
380                            int sel_disp = selected_disparity[d * disp_step2];
381                            int xr = xi - sel_disp;
382
383                            if (xr < 0 || sel_disp < min_disp)
384                                val += data_weight * max_data_term;
385                            else
386                            {
387                                const uchar* left_x = cleft + yi * cimg_step + xi * channels;
388                                const uchar* right_x = cright + yi * cimg_step + xr * channels;
389
390                                val += data_weight * pixeldiff<channels>(left_x, right_x, max_data_term);
391                            }
392                        }
393                    }
394                    data_cost[disp_step1 * d] = saturate_cast<T>(val);
395                }
396            }
397        }
398
399        template <typename T, int winsz, int channels>
400        __global__ void compute_data_cost_reduce(const uchar *cleft, const uchar *cright, size_t cimg_step, const T* selected_disp_pyr, T* data_cost_, int level, int rows, int cols, int h, int nr_plane, float data_weight, float max_data_term, int min_disp, size_t msg_step, size_t disp_step1, size_t disp_step2)
401        {
402            int x_out = blockIdx.x;
403            int y_out = blockIdx.y % h;
404            int d = (blockIdx.y / h) * blockDim.z + threadIdx.z;
405
406            int tid = threadIdx.x;
407
408            const T* selected_disparity = selected_disp_pyr + y_out/2 * msg_step + x_out/2;
409            T* data_cost = data_cost_ + y_out * msg_step + x_out;
410
411            if (d < nr_plane)
412            {
413                int sel_disp = selected_disparity[d * disp_step2];
414
415                int x0 = x_out << level;
416                int y0 = y_out << level;
417
418                int len = ::min(y0 + winsz, rows) - y0;
419
420                float val = 0.0f;
421                if (x0 + tid < cols)
422                {
423                    if (x0 + tid - sel_disp < 0 || sel_disp < min_disp)
424                        val = data_weight * max_data_term * len;
425                    else
426                    {
427                        const uchar* lle =  cleft + y0 * cimg_step + channels * (x0 + tid    );
428                        const uchar* lri = cright + y0 * cimg_step + channels * (x0 + tid - sel_disp);
429
430                        for(int y = 0; y < len; ++y)
431                        {
432                            val += data_weight * pixeldiff<channels>(lle, lri, max_data_term);
433
434                            lle += cimg_step;
435                            lri += cimg_step;
436                        }
437                    }
438                }
439
440                extern __shared__ float smem[];
441
442                reduce<winsz>(smem + winsz * threadIdx.z, val, tid, plus<float>());
443
444                if (tid == 0)
445                    data_cost[disp_step1 * d] = saturate_cast<T>(val);
446            }
447        }
448
449        template <typename T>
450        void compute_data_cost_caller_(const uchar *cleft, const uchar *cright, size_t cimg_step, const T* disp_selected_pyr, T* data_cost, int /*rows*/, int /*cols*/,
451                                      int h, int w, int level, int nr_plane, int channels, float data_weight, float max_data_term, int min_disp, size_t msg_step, size_t disp_step1, size_t disp_step2, cudaStream_t stream)
452        {
453            dim3 threads(32, 8, 1);
454            dim3 grid(1, 1, 1);
455
456            grid.x = divUp(w, threads.x);
457            grid.y = divUp(h, threads.y);
458
459            switch(channels)
460            {
461            case 1: compute_data_cost<T, 1><<<grid, threads, 0, stream>>>(cleft, cright, cimg_step, disp_selected_pyr, data_cost, h, w, level, nr_plane, data_weight, max_data_term, min_disp, msg_step, disp_step1, disp_step2); break;
462            case 3: compute_data_cost<T, 3><<<grid, threads, 0, stream>>>(cleft, cright, cimg_step, disp_selected_pyr, data_cost, h, w, level, nr_plane, data_weight, max_data_term, min_disp, msg_step, disp_step1, disp_step2); break;
463            case 4: compute_data_cost<T, 4><<<grid, threads, 0, stream>>>(cleft, cright, cimg_step, disp_selected_pyr, data_cost, h, w, level, nr_plane, data_weight, max_data_term, min_disp, msg_step, disp_step1, disp_step2); break;
464            default: CV_Error(cv::Error::BadNumChannels, "Unsupported channels count");
465            }
466        }
467
468        template <typename T, int winsz>
469        void compute_data_cost_reduce_caller_(const uchar *cleft, const uchar *cright, size_t cimg_step, const T* disp_selected_pyr, T* data_cost, int rows, int cols,
470                                      int h, int w, int level, int nr_plane, int channels, float data_weight, float max_data_term, int min_disp, size_t msg_step, size_t disp_step1, size_t disp_step2, cudaStream_t stream)
471        {
472            const int threadsNum = 256;
473            const size_t smem_size = threadsNum * sizeof(float);
474
475            dim3 threads(winsz, 1, threadsNum / winsz);
476            dim3 grid(w, h, 1);
477            grid.y *= divUp(nr_plane, threads.z);
478
479            switch (channels)
480            {
481            case 1: compute_data_cost_reduce<T, winsz, 1><<<grid, threads, smem_size, stream>>>(cleft, cright, cimg_step, disp_selected_pyr, data_cost, level, rows, cols, h, nr_plane, data_weight, max_data_term, min_disp, msg_step, disp_step1, disp_step2); break;
482            case 3: compute_data_cost_reduce<T, winsz, 3><<<grid, threads, smem_size, stream>>>(cleft, cright, cimg_step, disp_selected_pyr, data_cost, level, rows, cols, h, nr_plane, data_weight, max_data_term, min_disp, msg_step, disp_step1, disp_step2); break;
483            case 4: compute_data_cost_reduce<T, winsz, 4><<<grid, threads, smem_size, stream>>>(cleft, cright, cimg_step, disp_selected_pyr, data_cost, level, rows, cols, h, nr_plane, data_weight, max_data_term, min_disp, msg_step, disp_step1, disp_step2); break;
484            default: CV_Error(cv::Error::BadNumChannels, "Unsupported channels count");
485            }
486        }
487
488        template<class T>
489        void compute_data_cost(const uchar *cleft, const uchar *cright, size_t cimg_step, const T* disp_selected_pyr, T* data_cost, size_t msg_step,
490                               int rows, int cols, int h, int w, int h2, int level, int nr_plane, int channels, float data_weight, float max_data_term,
491                               int min_disp, cudaStream_t stream)
492        {
493            typedef void (*ComputeDataCostCaller)(const uchar *cleft, const uchar *cright, size_t cimg_step, const T* disp_selected_pyr, T* data_cost, int rows, int cols,
494                int h, int w, int level, int nr_plane, int channels, float data_weight, float max_data_term, int min_disp, size_t msg_step, size_t disp_step1, size_t disp_step2, cudaStream_t stream);
495
496            static const ComputeDataCostCaller callers[] =
497            {
498                compute_data_cost_caller_<T>, compute_data_cost_caller_<T>, compute_data_cost_reduce_caller_<T, 4>,
499                compute_data_cost_reduce_caller_<T, 8>, compute_data_cost_reduce_caller_<T, 16>, compute_data_cost_reduce_caller_<T, 32>,
500                compute_data_cost_reduce_caller_<T, 64>, compute_data_cost_reduce_caller_<T, 128>, compute_data_cost_reduce_caller_<T, 256>
501            };
502
503            size_t disp_step1 = msg_step * h;
504            size_t disp_step2 = msg_step * h2;
505
506            callers[level](cleft, cright, cimg_step, disp_selected_pyr, data_cost, rows, cols, h, w, level, nr_plane, channels, data_weight, max_data_term, min_disp, msg_step, disp_step1, disp_step2, stream);
507            cudaSafeCall( cudaGetLastError() );
508
509            if (stream == 0)
510                cudaSafeCall( cudaDeviceSynchronize() );
511        }
512
513        template void compute_data_cost(const uchar *cleft, const uchar *cright, size_t cimg_step, const short* disp_selected_pyr, short* data_cost, size_t msg_step,
514                               int rows, int cols, int h, int w, int h2, int level, int nr_plane, int channels, float data_weight, float max_data_term, int min_disp, cudaStream_t stream);
515
516        template void compute_data_cost(const uchar *cleft, const uchar *cright, size_t cimg_step, const float* disp_selected_pyr, float* data_cost, size_t msg_step,
517                               int rows, int cols, int h, int w, int h2, int level, int nr_plane, int channels, float data_weight, float max_data_term, int min_disp, cudaStream_t stream);
518
519
520        ///////////////////////////////////////////////////////////////
521        //////////////////////// init message /////////////////////////
522        ///////////////////////////////////////////////////////////////
523
524
525         template <typename T>
526        __device__ void get_first_k_element_increase(T* u_new, T* d_new, T* l_new, T* r_new,
527                                                     const T* u_cur, const T* d_cur, const T* l_cur, const T* r_cur,
528                                                     T* data_cost_selected, T* disparity_selected_new, T* data_cost_new,
529                                                     const T* data_cost_cur, const T* disparity_selected_cur,
530                                                     int nr_plane, int nr_plane2, size_t disp_step1, size_t disp_step2)
531        {
532            for(int i = 0; i < nr_plane; i++)
533            {
534                T minimum = numeric_limits<T>::max();
535                int id = 0;
536                for(int j = 0; j < nr_plane2; j++)
537                {
538                    T cur = data_cost_new[j * disp_step1];
539                    if(cur < minimum)
540                    {
541                        minimum = cur;
542                        id = j;
543                    }
544                }
545
546                data_cost_selected[i * disp_step1] = data_cost_cur[id * disp_step1];
547                disparity_selected_new[i * disp_step1] = disparity_selected_cur[id * disp_step2];
548
549                u_new[i * disp_step1] = u_cur[id * disp_step2];
550                d_new[i * disp_step1] = d_cur[id * disp_step2];
551                l_new[i * disp_step1] = l_cur[id * disp_step2];
552                r_new[i * disp_step1] = r_cur[id * disp_step2];
553
554                data_cost_new[id * disp_step1] = numeric_limits<T>::max();
555            }
556        }
557
558        template <typename T>
559        __global__ void init_message(uchar *ctemp, T* u_new_, T* d_new_, T* l_new_, T* r_new_,
560                                     const T* u_cur_, const T* d_cur_, const T* l_cur_, const T* r_cur_,
561                                     T* selected_disp_pyr_new, const T* selected_disp_pyr_cur,
562                                     T* data_cost_selected_, const T* data_cost_,
563                                     int h, int w, int nr_plane, int h2, int w2, int nr_plane2,
564                                     size_t msg_step, size_t disp_step1, size_t disp_step2)
565        {
566            int x = blockIdx.x * blockDim.x + threadIdx.x;
567            int y = blockIdx.y * blockDim.y + threadIdx.y;
568
569            if (y < h && x < w)
570            {
571                const T* u_cur = u_cur_ + ::min(h2-1, y/2 + 1) * msg_step + x/2;
572                const T* d_cur = d_cur_ + ::max(0, y/2 - 1)    * msg_step + x/2;
573                const T* l_cur = l_cur_ + (y/2)                * msg_step + ::min(w2-1, x/2 + 1);
574                const T* r_cur = r_cur_ + (y/2)                * msg_step + ::max(0, x/2 - 1);
575
576                T* data_cost_new = (T*)ctemp + y * msg_step + x;
577
578                const T* disparity_selected_cur = selected_disp_pyr_cur + y/2 * msg_step + x/2;
579                const T* data_cost = data_cost_ + y * msg_step + x;
580
581                for(int d = 0; d < nr_plane2; d++)
582                {
583                    int idx2 = d * disp_step2;
584
585                    T val  = data_cost[d * disp_step1] + u_cur[idx2] + d_cur[idx2] + l_cur[idx2] + r_cur[idx2];
586                    data_cost_new[d * disp_step1] = val;
587                }
588
589                T* data_cost_selected = data_cost_selected_ + y * msg_step + x;
590                T* disparity_selected_new = selected_disp_pyr_new + y * msg_step + x;
591
592                T* u_new = u_new_ + y * msg_step + x;
593                T* d_new = d_new_ + y * msg_step + x;
594                T* l_new = l_new_ + y * msg_step + x;
595                T* r_new = r_new_ + y * msg_step + x;
596
597                u_cur = u_cur_ + y/2 * msg_step + x/2;
598                d_cur = d_cur_ + y/2 * msg_step + x/2;
599                l_cur = l_cur_ + y/2 * msg_step + x/2;
600                r_cur = r_cur_ + y/2 * msg_step + x/2;
601
602                get_first_k_element_increase(u_new, d_new, l_new, r_new, u_cur, d_cur, l_cur, r_cur,
603                                             data_cost_selected, disparity_selected_new, data_cost_new,
604                                             data_cost, disparity_selected_cur, nr_plane, nr_plane2,
605                                             disp_step1, disp_step2);
606            }
607        }
608
609
// Host-side launcher for the init_message kernel.
//
// Launches one thread per (x, y) position of the w x h grid using 32x8 thread
// blocks. Two per-plane offsets are derived here and forwarded to the kernel:
// msg_step * h (for the buffers indexed with h) and msg_step * h2 (for the
// buffers indexed with h2). When `stream` is the default stream (0), the call
// blocks until the device has finished.
template<class T>
void init_message(uchar *ctemp, T* u_new, T* d_new, T* l_new, T* r_new,
                  const T* u_cur, const T* d_cur, const T* l_cur, const T* r_cur,
                  T* selected_disp_pyr_new, const T* selected_disp_pyr_cur,
                  T* data_cost_selected, const T* data_cost, size_t msg_step,
                  int h, int w, int nr_plane, int h2, int w2, int nr_plane2, cudaStream_t stream)
{
    // Offset (in elements) between consecutive planes for each of the two sizes.
    const size_t plane_offset_new = msg_step * h;
    const size_t plane_offset_cur = msg_step * h2;

    // Enough 32x8 blocks to cover every column and row.
    const dim3 block(32, 8, 1);
    const dim3 blocks_per_grid(divUp(w, block.x), divUp(h, block.y), 1);

    init_message<<<blocks_per_grid, block, 0, stream>>>(ctemp, u_new, d_new, l_new, r_new,
                                                        u_cur, d_cur, l_cur, r_cur,
                                                        selected_disp_pyr_new, selected_disp_pyr_cur,
                                                        data_cost_selected, data_cost,
                                                        h, w, nr_plane, h2, w2, nr_plane2,
                                                        msg_step, plane_offset_new, plane_offset_cur);

    // A bad launch configuration is only reported through the last-error state.
    cudaSafeCall( cudaGetLastError() );

    // Only the default stream is synchronized here; other streams stay asynchronous.
    if (stream == 0)
    {
        cudaSafeCall( cudaDeviceSynchronize() );
    }
}
638
639
// Explicit instantiation of the host-side init_message launcher for short elements.
template void init_message(uchar *ctemp,
                           short* u_new, short* d_new, short* l_new, short* r_new,
                           const short* u_cur, const short* d_cur,
                           const short* l_cur, const short* r_cur,
                           short* selected_disp_pyr_new,
                           const short* selected_disp_pyr_cur,
                           short* data_cost_selected,
                           const short* data_cost,
                           size_t msg_step,
                           int h, int w, int nr_plane,
                           int h2, int w2, int nr_plane2,
                           cudaStream_t stream);

// Explicit instantiation of the host-side init_message launcher for float elements.
template void init_message(uchar *ctemp,
                           float* u_new, float* d_new, float* l_new, float* r_new,
                           const float* u_cur, const float* d_cur,
                           const float* l_cur, const float* r_cur,
                           float* selected_disp_pyr_new,
                           const float* selected_disp_pyr_cur,
                           float* data_cost_selected,
                           const float* data_cost,
                           size_t msg_step,
                           int h, int w, int nr_plane,
                           int h2, int w2, int nr_plane2,
                           cudaStream_t stream);
651
652        ///////////////////////////////////////////////////////////////
653        ////////////////////  calc all iterations /////////////////////
654        ///////////////////////////////////////////////////////////////
655
// Recomputes one outgoing message (msg_dst) for a single pixel.
//
// All pointer arguments address plane 0 of the pixel; plane p is reached at
// offset p * disp_step. The update runs in three passes over the nr_plane planes:
//   1. msg_dst[p] = data[p] + msg1[p] + msg2[p] + msg3[p], tracking the smallest sum.
//   2. For every plane p, take the minimum over all planes q of
//        msg_dst[q] + disc_single_jump * |dst_disp[q] - src_disp[p]|,
//      capped from above by (smallest sum + max_disc_term); the result is stored
//      in temp[p] and accumulated for the mean.
//   3. msg_dst[p] = temp[p] - mean of the pass-2 values (saturated to T).
template <typename T>
__device__ void message_per_pixel(const T* data, T* msg_dst, const T* msg1, const T* msg2, const T* msg3,
                                  const T* dst_disp, const T* src_disp, int nr_plane, int max_disc_term, float disc_single_jump, volatile T* temp,
                                  size_t disp_step)
{
    // Pass 1: aggregate the inputs per plane and remember the smallest aggregate.
    T smallest = numeric_limits<T>::max();

    for (int plane = 0; plane < nr_plane; ++plane)
    {
        const int offset = plane * disp_step;
        const T aggregate = data[offset] + msg1[offset] + msg2[offset] + msg3[offset];

        msg_dst[offset] = aggregate;

        if (aggregate < smallest)
            smallest = aggregate;
    }

    // Upper bound for every pass-2 value; it does not depend on the plane.
    const float cost_cap = smallest + max_disc_term;

    // Pass 2: per-plane minimum of aggregate + weighted label distance.
    float total = 0;

    for (int plane = 0; plane < nr_plane; ++plane)
    {
        const T src_label = src_disp[plane * disp_step];
        float best_cost = cost_cap;

        for (int other = 0; other < nr_plane; ++other)
        {
            best_cost = fmin(best_cost, msg_dst[other * disp_step] + disc_single_jump * ::abs(dst_disp[other * disp_step] - src_label));
        }

        temp[plane * disp_step] = saturate_cast<T>(best_cost);
        total += best_cost;
    }

    total /= nr_plane;

    // Pass 3: subtract the mean so the stored message is centered.
    for (int plane = 0; plane < nr_plane; ++plane)
    {
        msg_dst[plane * disp_step] = saturate_cast<T>(temp[plane * disp_step] - total);
    }
}
691
// Kernel: one message-update pass over the interior pixels of an h x w grid.
//
// Each thread handles row y and column 2 * gx + ((y + i) & 1), where gx is the
// thread's global x index, so a launch with a given `i` only touches pixels
// whose column parity equals the parity of (y + i). Border pixels (first/last
// row and column) are skipped. For every processed pixel the four buffers
// u, d, l, r are rewritten in that order via message_per_pixel, using
// `ctemp` (reinterpreted as T*) as per-pixel scratch storage.
template <typename T>
__global__ void compute_message(uchar *ctemp, T* u_, T* d_, T* l_, T* r_, const T* data_cost_selected, const T* selected_disp_pyr_cur, int h, int w, int nr_plane, int i, int max_disc_term, float disc_single_jump, size_t msg_step, size_t disp_step)
{
    const int y = blockIdx.y * blockDim.y + threadIdx.y;
    const int x = ((blockIdx.x * blockDim.x + threadIdx.x) << 1) + ((y + i) & 1);

    // Only interior pixels are updated: every neighbor access below stays in range.
    if (y <= 0 || y >= h - 1 || x <= 0 || x >= w - 1)
        return;

    // Element offset of this pixel, shared by all buffers.
    const size_t pixel = y * msg_step + x;

    const T* data = data_cost_selected + pixel;
    const T* disp = selected_disp_pyr_cur + pixel;

    T* u = u_ + pixel;
    T* d = d_ + pixel;
    T* l = l_ + pixel;
    T* r = r_ + pixel;

    T* temp = (T*)ctemp + pixel;

    // Neighbor entries read as inputs:
    //   u at row y + 1, d at row y - 1, l at column x + 1, r at column x - 1.
    const T* u_next_row = u + msg_step;
    const T* d_prev_row = d - msg_step;
    const T* l_next_col = l + 1;
    const T* r_prev_col = r - 1;

    message_per_pixel(data, u, r_prev_col, u_next_row, l_next_col, disp, disp - msg_step, nr_plane, max_disc_term, disc_single_jump, temp, disp_step);
    message_per_pixel(data, d, d_prev_row, r_prev_col, l_next_col, disp, disp + msg_step, nr_plane, max_disc_term, disc_single_jump, temp, disp_step);
    message_per_pixel(data, l, u_next_row, d_prev_row, l_next_col, disp, disp - 1, nr_plane, max_disc_term, disc_single_jump, temp, disp_step);
    message_per_pixel(data, r, u_next_row, d_prev_row, r_prev_col, disp, disp + 1, nr_plane, max_disc_term, disc_single_jump, temp, disp_step);
}
717
718
// Host-side driver that runs `iters` passes of the compute_message kernel.
//
// Parameters are forwarded to the kernel unchanged, except:
//   - disp_step is derived here as msg_step * h (offset between planes);
//   - the kernel's `i` argument is the pass parity (t & 1), so consecutive
//     passes alternate between the two column parities the kernel selects.
// Every launch is followed by a cudaGetLastError() check. When `stream` is the
// default stream (0) the function blocks until all passes have completed;
// otherwise it returns with the passes queued on `stream`.
template<class T>
void calc_all_iterations(uchar *ctemp, T* u, T* d, T* l, T* r, const T* data_cost_selected,
    const T* selected_disp_pyr_cur, size_t msg_step, int h, int w, int nr_plane, int iters, int max_disc_term, float disc_single_jump, cudaStream_t stream)
{
    size_t disp_step = msg_step * h;

    dim3 threads(32, 8, 1);
    dim3 grid(1, 1, 1);

    // The kernel maps each thread to column 2 * gx + parity, so one thread
    // stands for two columns and the grid only needs half the width.
    grid.x = divUp(w, threads.x << 1);
    grid.y = divUp(h, threads.y);

    for (int t = 0; t < iters; ++t)
    {
        compute_message<<<grid, threads, 0, stream>>>(ctemp, u, d, l, r, data_cost_selected, selected_disp_pyr_cur, h, w, nr_plane, t & 1, max_disc_term, disc_single_jump, msg_step, disp_step);
        cudaSafeCall( cudaGetLastError() );
    }

    // Synchronize once after all passes, and only on the default stream.
    if (stream == 0)
        cudaSafeCall( cudaDeviceSynchronize() );
}
739
// Explicit instantiation of the calc_all_iterations host driver for short elements.
template void calc_all_iterations(uchar *ctemp,
                                  short* u, short* d, short* l, short* r,
                                  const short* data_cost_selected,
                                  const short* selected_disp_pyr_cur,
                                  size_t msg_step,
                                  int h, int w, int nr_plane, int iters,
                                  int max_disc_term, float disc_single_jump,
                                  cudaStream_t stream);

// Explicit instantiation of the calc_all_iterations host driver for float elements.
template void calc_all_iterations(uchar *ctemp,
                                  float* u, float* d, float* l, float* r,
                                  const float* data_cost_selected,
                                  const float* selected_disp_pyr_cur,
                                  size_t msg_step,
                                  int h, int w, int nr_plane, int iters,
                                  int max_disc_term, float disc_single_jump,
                                  cudaStream_t stream);
745
746
747        ///////////////////////////////////////////////////////////////
748        /////////////////////////// output ////////////////////////////
749        ///////////////////////////////////////////////////////////////
750
751
// Kernel: writes the output value for every interior pixel of `disp`.
//
// For pixel (x, y) the per-plane total is
//   data(y, x) + u(y + 1, x) + d(y - 1, x) + l(y, x + 1) + r(y, x - 1),
// with plane p located p * disp_step elements after plane 0. The plane with the
// smallest total wins (the first one on ties) and the corresponding entry of
// disp_selected_pyr, saturated to short, is stored in disp(y, x). If no plane
// has a total below numeric_limits<T>::max(), 0 is stored. Pixels in the first
// or last row/column are not written.
template <typename T>
__global__ void compute_disp(const T* u_, const T* d_, const T* l_, const T* r_,
                             const T* data_cost_selected, const T* disp_selected_pyr,
                             PtrStepSz<short> disp, int nr_plane, size_t msg_step, size_t disp_step)
{
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;

    // Skip the border (and any thread outside the image): the neighbor reads
    // below need rows y - 1 / y + 1 and columns x - 1 / x + 1 to exist.
    if (y <= 0 || y >= disp.rows - 1 || x <= 0 || x >= disp.cols - 1)
        return;

    // Element offset of (y, x) in every buffer.
    const size_t center = y * msg_step + x;

    const T* data = data_cost_selected + center;
    const T* disp_selected = disp_selected_pyr + center;

    const T* u = u_ + center + msg_step;   // row y + 1
    const T* d = d_ + center - msg_step;   // row y - 1
    const T* l = l_ + center + 1;          // column x + 1
    const T* r = r_ + center - 1;          // column x - 1

    // Find the plane with the strictly smallest total; -1 means "none found".
    int winning_offset = -1;
    T lowest = numeric_limits<T>::max();

    for (int plane = 0; plane < nr_plane; ++plane)
    {
        const int offset = plane * disp_step;
        const T total = data[offset]+ u[offset] + d[offset] + l[offset] + r[offset];

        if (total < lowest)
        {
            lowest = total;
            winning_offset = offset;
        }
    }

    short result = 0;
    if (winning_offset >= 0)
        result = saturate_cast<short>(disp_selected[winning_offset]);

    disp(y, x) = result;
}
786
// Host-side launcher for the compute_disp kernel.
//
// Covers the disp.cols x disp.rows output with 32x8 thread blocks and passes
// disp.rows * msg_step as the offset between planes. When `stream` is the
// default stream (0), the call blocks until the device has finished.
template<class T>
void compute_disp(const T* u, const T* d, const T* l, const T* r, const T* data_cost_selected, const T* disp_selected, size_t msg_step,
    const PtrStepSz<short>& disp, int nr_plane, cudaStream_t stream)
{
    const size_t plane_offset = disp.rows * msg_step;

    // One thread per output pixel.
    const dim3 block(32, 8, 1);
    const dim3 blocks_per_grid(divUp(disp.cols, block.x), divUp(disp.rows, block.y), 1);

    compute_disp<<<blocks_per_grid, block, 0, stream>>>(u, d, l, r, data_cost_selected, disp_selected, disp, nr_plane, msg_step, plane_offset);

    // Launch-configuration errors are only visible through the last-error state.
    cudaSafeCall( cudaGetLastError() );

    // Only the default stream is synchronized here.
    if (stream == 0)
    {
        cudaSafeCall( cudaDeviceSynchronize() );
    }
}
805
// Explicit instantiation of the host-side compute_disp launcher for short elements.
template void compute_disp(const short* u, const short* d,
                           const short* l, const short* r,
                           const short* data_cost_selected,
                           const short* disp_selected,
                           size_t msg_step,
                           const PtrStepSz<short>& disp,
                           int nr_plane,
                           cudaStream_t stream);

// Explicit instantiation of the host-side compute_disp launcher for float elements.
template void compute_disp(const float* u, const float* d,
                           const float* l, const float* r,
                           const float* data_cost_selected,
                           const float* disp_selected,
                           size_t msg_step,
                           const PtrStepSz<short>& disp,
                           int nr_plane,
                           cudaStream_t stream);
811    } // namespace stereocsbp
812}}} // namespace cv { namespace cuda { namespace cudev {
813
814#endif /* CUDA_DISABLER */
815