/*
 * function: kernel_3d_denoise
 *     3D (spatial + temporal) noise reduction
 * gain:         filtering strength applied to each reference block; doubled for
 *               blocks whose local gradient reaches `threshold`
 * threshold:    noise variance of the observed image, used as the gradient threshold
 * restoredPrev: previous restored image, image2d_t, read only
 * output:       restored image, image2d_t, write only
 * input:        observed image, image2d_t, read only
 * inputPrev1:   first reference image, image2d_t, read only
 * inputPrev2:   second reference image, image2d_t, read only
 */
12
// Frames blended when ENABLE_IIR_FILERING is 0: the current frame, plus
// inputPrev1 when the count is > 1 and inputPrev2 when it is > 2.
#if !defined(REFERENCE_FRAME_COUNT)
#define REFERENCE_FRAME_COUNT 2
#endif

// When non-zero, restoredPrev (the previous restored frame) is the only
// reference blended with the current frame; inputPrev1/inputPrev2 are unused.
// NOTE(review): "FILERING" looks like a misspelling of "FILTERING"; the name is
// left as-is because host build options may define it - confirm before renaming.
#if !defined(ENABLE_IIR_FILERING)
#define ENABLE_IIR_FILERING 1
#endif

// Enables the gradient test in average_slice (gain is doubled where the
// gradient reaches the threshold).
#define ENABLE_GRADIENT     1

// Work-group tile size in image texels; each work item writes one texel.
#if !defined(WORKGROUP_WIDTH)
#define WORKGROUP_WIDTH    2
#endif

#if !defined(WORKGROUP_HEIGHT)
#define WORKGROUP_HEIGHT   32
#endif

// Halo (in texels / rows) loaded around the work-group tile on each side.
#define REF_BLOCK_X_OFFSET  1
#define REF_BLOCK_Y_OFFSET  4

// Dimensions of the tile cached in local memory: work-group tile plus halo.
#define REF_BLOCK_WIDTH     (WORKGROUP_WIDTH + 2 * REF_BLOCK_X_OFFSET)
#define REF_BLOCK_HEIGHT    (WORKGROUP_HEIGHT + 2 * REF_BLOCK_Y_OFFSET)
36
/*
 * Maps a (sub-group id, lane id) pair to a work-item position inside the
 * work-group tile; used when WORKGROUP_WIDTH == 4.
 * Column: 2 * (sg_id % 2) + (sg_lid % 2)
 * Row:    4 * (sg_id / 2) + (sg_lid / 2)
 * For lanes 0..7 this places each sub-group on a 2-wide x 4-tall patch, with
 * lanes of equal parity sharing a column.
 */
inline int2 subgroup_pos(const int sg_id, const int sg_lid)
{
    const int col = mad24(2, sg_id % 2, sg_lid % 2);
    const int row = mad24(4, sg_id / 2, sg_lid / 2);
    return (int2)(col, row);
}
45
/*
 * Accumulates one reference slice into the running weighted average.
 *
 * Each work item holds 8 values split into two 4-value halves (.lo / .hi);
 * each half receives its own weight.
 *
 * ref:        reference values for this work item
 * observe:    observed values for this work item
 * restore:    running weighted sum, updated in place (weight * ref added)
 * sum_weight: running sum of weights (.lo for the low half, .hi for the high half)
 * gain:       filtering strength; weight = native_exp(-gain * distance)
 * threshold:  a half whose gradient is >= threshold uses 2 * gain instead of gain
 * sg_id:      sub-group id (not used here; kept for a uniform call signature)
 * sg_lid:     lane id; its parity selects which lanes are combined
 *
 * Reads lanes 0..7 via intel_sub_group_shuffle, so it relies on the lane layout
 * set up by the caller (for WORKGROUP_WIDTH 2 and 4, lanes of equal parity
 * share a column and cover four consecutive rows).
 */
inline void average_slice(float8 ref,
                          float8 observe,
                          float8* restore,
                          float2* sum_weight,
                          float gain,
                          float threshold,
                          uint sg_id,
                          uint sg_lid)
{
    float8 grad = 0.0f;
    float8 distance = 0.0f;

#if ENABLE_GRADIENT
    // Compare this lane's reference values with component s1 (low half) /
    // s5 (high half) of lane 4 (even lanes) or lane 5 (odd lanes). Each branch
    // only reads a lane of its own parity, i.e. a lane active in that branch.
    if (sg_lid % 2 == 0) {
        grad = intel_sub_group_shuffle(ref, 4);
    } else {
        grad = intel_sub_group_shuffle(ref, 5);
    }
    const float8 gradient = (float8)(grad.s1, grad.s1, grad.s1, grad.s1,
                                     grad.s5, grad.s5, grad.s5, grad.s5);

    // normalize gradient "1/(4*255.0f) = 0.00098039f"
    grad = fabs(gradient - ref) * 0.00098039f;

    // Per-half gradient: sum of the four components of each half.
    grad.s0 = (grad.s0 + grad.s1 + grad.s2 + grad.s3);
    grad.s4 = (grad.s4 + grad.s5 + grad.s6 + grad.s7);
#endif

    // calculate & normalize distance "1/255.0f = 0.00392157f"
    float8 dist = (observe - ref) * 0.00392157f;
    dist = dist * dist;

    // Every lane performs all eight shuffles (uniform control flow); each lane
    // then sums the squared distances of the four lanes sharing its parity.
    float8 dist_shuffle[8];
    dist_shuffle[0] = intel_sub_group_shuffle(dist, 0);
    dist_shuffle[1] = intel_sub_group_shuffle(dist, 1);
    dist_shuffle[2] = intel_sub_group_shuffle(dist, 2);
    dist_shuffle[3] = intel_sub_group_shuffle(dist, 3);
    dist_shuffle[4] = intel_sub_group_shuffle(dist, 4);
    dist_shuffle[5] = intel_sub_group_shuffle(dist, 5);
    dist_shuffle[6] = intel_sub_group_shuffle(dist, 6);
    dist_shuffle[7] = intel_sub_group_shuffle(dist, 7);

    if (sg_lid % 2 == 0) {
        distance = dist_shuffle[0];
        distance += dist_shuffle[2];
        distance += dist_shuffle[4];
        distance += dist_shuffle[6];
    } else {
        distance = dist_shuffle[1];
        distance += dist_shuffle[3];
        distance += dist_shuffle[5];
        distance += dist_shuffle[7];
    }

    // Per-half patch distance: sum of the four components of each half.
    const float dist_lo = distance.s0 + distance.s1 + distance.s2 + distance.s3;
    const float dist_hi = distance.s4 + distance.s5 + distance.s6 + distance.s7;

    // Each half derives its gain independently from the unmodified `gain`
    // argument, so the high half is not affected by the low half's gradient.
    const float gain_lo = (grad.s0 < threshold) ? gain : 2.0f * gain;
    const float gain_hi = (grad.s4 < threshold) ? gain : 2.0f * gain;

    const float weight_lo = native_exp(-gain_lo * dist_lo);
    (*restore).lo = mad(weight_lo, ref.lo, (*restore).lo);
    (*sum_weight).lo = (*sum_weight).lo + weight_lo;

    const float weight_hi = native_exp(-gain_hi * dist_hi);
    (*restore).hi = mad(weight_hi, ref.hi, (*restore).hi);
    (*sum_weight).hi = (*sum_weight).hi + weight_hi;
}
117
/*
 * Reads 8 consecutive bytes of the cached reference tile, starting at uchar4
 * column `half_col` of tile row `row`, and converts them to float8.
 * A tile row holds REF_BLOCK_WIDTH uchar8 elements = 2 * REF_BLOCK_WIDTH uchar4.
 * vload8 only requires uchar alignment; when half_col is odd the start address
 * is only 4-byte aligned, so dereferencing it as __local uchar8* would be a
 * misaligned vector access.
 */
inline float8 load_ref_slice(__local uchar8* ref_cache, int row, int half_col)
{
    const int byte_offset = 4 * mad24(row, 2 * REF_BLOCK_WIDTH, half_col);
    return convert_float8(vload8(0, (__local uchar*)ref_cache + byte_offset));
}

/*
 * Loads a reference tile from `input` into local memory and accumulates the
 * weighted contributions of nine candidate positions (rows -REF_BLOCK_Y_OFFSET,
 * 0, +REF_BLOCK_Y_OFFSET and columns -4, 0, +4 bytes relative to this work
 * item's own position) into *restore / *sum_weight via average_slice.
 *
 * input:        image providing the reference tile; read with clamp-to-edge
 * ref_cache:    local buffer of REF_BLOCK_HEIGHT * REF_BLOCK_WIDTH uchar8,
 *               overwritten by every call
 * load_observe: when true, the centre element of the loaded tile becomes
 *               *observe, *restore is set to it and *sum_weight to 1.0, and the
 *               centre candidate is skipped (it is the observation itself)
 * observe, restore, sum_weight: per-work-item accumulators (see average_slice)
 * gain, threshold, sg_id, sg_lid: forwarded to average_slice
 *
 * Contains barriers: every work item of the work-group must call it.
 */
inline void weighted_average (__read_only image2d_t input,
                              __local uchar8* ref_cache,
                              bool load_observe,
                              float8* observe,
                              float8* restore,
                              float2* sum_weight,
                              float gain,
                              float threshold,
                              uint sg_id,
                              uint sg_lid)
{
    sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP_TO_EDGE | CLK_FILTER_NEAREST;

    int local_id_x = get_local_id(0);
    int local_id_y = get_local_id(1);
    const int group_id_x = get_group_id(0);
    const int group_id_y = get_group_id(1);

    // Top-left corner of the tile, including the halo.
    int start_x = mad24(group_id_x, WORKGROUP_WIDTH, -REF_BLOCK_X_OFFSET);
    int start_y = mad24(group_id_y, WORKGROUP_HEIGHT, -REF_BLOCK_Y_OFFSET);

    // ref_cache may still hold the tile from a previous call that other
    // sub-groups are reading; wait for all of them before overwriting it.
    barrier(CLK_LOCAL_MEM_FENCE);

    // Cooperative load: each work item fills every (work-group size)-th element.
    // Each texel's four channels are narrowed to 16 bits and reinterpreted as
    // 8 bytes.
    for (int j = local_id_x + local_id_y * WORKGROUP_WIDTH;
         j < (REF_BLOCK_HEIGHT * REF_BLOCK_WIDTH);
         j += (WORKGROUP_HEIGHT * WORKGROUP_WIDTH)) {
        int coord_x = start_x + (j % REF_BLOCK_WIDTH);
        int coord_y = start_y + (j / REF_BLOCK_WIDTH);

        ref_cache[j] = as_uchar8(convert_ushort4(read_imageui(input,
                                 sampler,
                                 (int2)(coord_x, coord_y))));
    }
    barrier(CLK_LOCAL_MEM_FENCE);

#if WORKGROUP_WIDTH == 4
    int2 pos = subgroup_pos(sg_id, sg_lid);
    local_id_x = pos.x;
    local_id_y = pos.y;
#endif

    if (load_observe) {
        (*observe) = convert_float8(
                         ref_cache[mad24(local_id_y + REF_BLOCK_Y_OFFSET,
                                         REF_BLOCK_WIDTH,
                                         local_id_x + REF_BLOCK_X_OFFSET)]);
        (*restore) = (*observe);
        (*sum_weight) = 1.0f;
    }

    // Candidate rows (tile rows) and columns (uchar4 units). The centre column
    // is this work item's own element; left/right are shifted by one uchar4.
    const int row_top = local_id_y;
    const int row_mid = local_id_y + REF_BLOCK_Y_OFFSET;
    const int row_bottom = local_id_y + 2 * REF_BLOCK_Y_OFFSET;
    const int col_centre = 2 * (local_id_x + REF_BLOCK_X_OFFSET);
    const int col_left = col_centre - 1;
    const int col_right = col_centre + 1;

    float8 left;
    float8 right;

    // top row: left, right, then centre assembled from left.hi / right.lo
    left = load_ref_slice(ref_cache, row_top, col_left);
    average_slice(left, *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    right = load_ref_slice(ref_cache, row_top, col_right);
    average_slice(right, *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    average_slice((float8)(left.hi, right.lo), *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);

    // middle row: the centre is the observation itself when load_observe
    left = load_ref_slice(ref_cache, row_mid, col_left);
    average_slice(left, *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    right = load_ref_slice(ref_cache, row_mid, col_right);
    average_slice(right, *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    if (!load_observe) {
        average_slice((float8)(left.hi, right.lo), *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    }

    // bottom row
    left = load_ref_slice(ref_cache, row_bottom, col_left);
    average_slice(left, *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    right = load_ref_slice(ref_cache, row_bottom, col_right);
    average_slice(right, *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
    average_slice((float8)(left.hi, right.lo), *observe, restore, sum_weight, gain, threshold, sg_id, sg_lid);
}
216
/*
 * Entry point: blends the current frame with its own neighbourhood and with
 * one or more reference frames (restoredPrev, or inputPrev1/inputPrev2,
 * depending on ENABLE_IIR_FILERING / REFERENCE_FRAME_COUNT), normalises by the
 * accumulated weights and writes one texel per work item to `output`.
 */
__kernel void kernel_3d_denoise ( float gain,
                                  float threshold,
                                  __read_only image2d_t restoredPrev,
                                  __write_only image2d_t output,
                                  __read_only image2d_t input,
                                  __read_only image2d_t inputPrev1,
                                  __read_only image2d_t inputPrev2)
{
    // Tile shared by all weighted_average passes of this work-group.
    __local uchar8 ref_cache[REF_BLOCK_HEIGHT * REF_BLOCK_WIDTH];

    // Lane index is derived from the flattened local id modulo 8.
    const int sg_id = get_sub_group_id();
    const int sg_lid = (get_local_id(1) * WORKGROUP_WIDTH + get_local_id(0)) % 8;

    float8 observed = 0.0f;
    float8 accum = 0.0f;
    float2 weight_sum = 0.0f;

    // Current frame: also initialises observed / accum / weight_sum.
    weighted_average (input, ref_cache, true, &observed, &accum, &weight_sum, gain, threshold, sg_id, sg_lid);

#if ENABLE_IIR_FILERING
    weighted_average (restoredPrev, ref_cache, false, &observed, &accum, &weight_sum, gain, threshold, sg_id, sg_lid);
#else
#if REFERENCE_FRAME_COUNT > 1
    weighted_average (inputPrev1, ref_cache, false, &observed, &accum, &weight_sum, gain, threshold, sg_id, sg_lid);
#endif
#if REFERENCE_FRAME_COUNT > 2
    weighted_average (inputPrev2, ref_cache, false, &observed, &accum, &weight_sum, gain, threshold, sg_id, sg_lid);
#endif
#endif

    // Normalise each half by its own accumulated weight.
    accum.lo = accum.lo / weight_sum.lo;
    accum.hi = accum.hi / weight_sum.hi;

    // Position of this work item inside the work-group tile.
#if WORKGROUP_WIDTH == 4
    const int2 tile_pos = subgroup_pos(sg_id, sg_lid);
#else
    const int2 tile_pos = (int2)((int)get_local_id(0), (int)get_local_id(1));
#endif

    const int2 out_coord = (int2)(mad24((int)get_group_id(0), WORKGROUP_WIDTH, tile_pos.x),
                                  mad24((int)get_group_id(1), WORKGROUP_HEIGHT, tile_pos.y));

    write_imageui(output,
                  out_coord,
                  convert_uint4(as_ushort4(convert_uchar8(accum))));
}
269
270