/*
 *  Copyright 2011 The LibYuv Project Authors. All rights reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include "libyuv/compare.h"

#include <float.h>
#include <math.h>
#ifdef _OPENMP
#include <omp.h>
#endif

#include "libyuv/basic_types.h"
#include "libyuv/cpu_id.h"
#include "libyuv/row.h"

#ifdef __cplusplus
namespace libyuv {
extern "C" {
#endif

// hash seed of 5381 recommended.
// Internal C version of HashDjb2 with int sized count for efficiency.
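// Each iteration computes hash = hash * 33 + src[i]; the shift-add form
// below exploits hash * 33 == (hash << 5) + hash.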
static uint32 HashDjb2_C(const uint8* src, int count, uint32 seed) {
  uint32 hash = seed;
  for (int i = 0; i < count; ++i) {
    hash += (hash << 5) + src[i];
  }
  return hash;
}

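// The SSE4.1 versions below unroll the recurrence 16 bytes at a time:
//   hash' = hash * 33^16 + sum of src[i] * 33^(15 - i) for i in [0, 15]
// using the precomputed powers of 33 in kHash16x33 and kHashMul0..kHashMul3.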
// This module is for Visual C x86
#if !defined(YUV_DISABLE_ASM) && defined(_M_IX86)
#define HAS_HASHDJB2_SSE41
static const uvec32 kHash16x33 = { 0x92d9e201, 0, 0, 0 };  // 33 ^ 16
static const uvec32 kHashMul0 = {
  0x0c3525e1,  // 33 ^ 15
  0xa3476dc1,  // 33 ^ 14
  0x3b4039a1,  // 33 ^ 13
  0x4f5f0981,  // 33 ^ 12
};
static const uvec32 kHashMul1 = {
  0x30f35d61,  // 33 ^ 11
  0x855cb541,  // 33 ^ 10
  0x040a9121,  // 33 ^ 9
  0x747c7101,  // 33 ^ 8
};
static const uvec32 kHashMul2 = {
  0xec41d4e1,  // 33 ^ 7
  0x4cfa3cc1,  // 33 ^ 6
  0x025528a1,  // 33 ^ 5
  0x00121881,  // 33 ^ 4
};
static const uvec32 kHashMul3 = {
  0x00008c61,  // 33 ^ 3
  0x00000441,  // 33 ^ 2
  0x00000021,  // 33 ^ 1
  0x00000001,  // 33 ^ 0
};

// 27: 66 0F 38 40 C6     pmulld      xmm0,xmm6
// 44: 66 0F 38 40 DD     pmulld      xmm3,xmm5
// 59: 66 0F 38 40 E5     pmulld      xmm4,xmm5
// 72: 66 0F 38 40 D5     pmulld      xmm2,xmm5
// 83: 66 0F 38 40 CD     pmulld      xmm1,xmm5
#define pmulld(reg) _asm _emit 0x66 _asm _emit 0x0F _asm _emit 0x38 \
    _asm _emit 0x40 _asm _emit reg

__declspec(naked) __declspec(align(16))
static uint32 HashDjb2_SSE41(const uint8* src, int count, uint32 seed) {
  __asm {
    mov        eax, [esp + 4]    // src
    mov        ecx, [esp + 8]    // count
    movd       xmm0, [esp + 12]  // seed

    pxor       xmm7, xmm7        // constant 0 for unpck
    movdqa     xmm6, kHash16x33

    align      16
  wloop:
    movdqu     xmm1, [eax]       // src[0-15]
    lea        eax, [eax + 16]
    pmulld(0xc6)                 // pmulld      xmm0,xmm6  hash *= 33 ^ 16
    movdqa     xmm5, kHashMul0
    movdqa     xmm2, xmm1
    punpcklbw  xmm2, xmm7        // src[0-7]
    movdqa     xmm3, xmm2
    punpcklwd  xmm3, xmm7        // src[0-3]
    pmulld(0xdd)                 // pmulld     xmm3, xmm5
    movdqa     xmm5, kHashMul1
    movdqa     xmm4, xmm2
    punpckhwd  xmm4, xmm7        // src[4-7]
    pmulld(0xe5)                 // pmulld     xmm4, xmm5
    movdqa     xmm5, kHashMul2
    punpckhbw  xmm1, xmm7        // src[8-15]
    movdqa     xmm2, xmm1
    punpcklwd  xmm2, xmm7        // src[8-11]
    pmulld(0xd5)                 // pmulld     xmm2, xmm5
    movdqa     xmm5, kHashMul3
    punpckhwd  xmm1, xmm7        // src[12-15]
    pmulld(0xcd)                 // pmulld     xmm1, xmm5
    paddd      xmm3, xmm4        // add 16 results
    paddd      xmm1, xmm2
    sub        ecx, 16
    paddd      xmm1, xmm3

    pshufd     xmm2, xmm1, 14    // upper 2 dwords
    paddd      xmm1, xmm2
    pshufd     xmm2, xmm1, 1
    paddd      xmm1, xmm2
    paddd      xmm0, xmm1
    jg         wloop

    movd       eax, xmm0        // return hash
    ret
  }
}

#elif !defined(YUV_DISABLE_ASM) && \
    (defined(__x86_64__) || (defined(__i386__) && !defined(__pic__)))
// GCC 4.2 on OSX has a link error when passing static or const to inline.
// TODO(fbarchard): Use static const when gcc 4.2 support is dropped.
#ifdef __APPLE__
#define CONST
#else
#define CONST static const
#endif
#define HAS_HASHDJB2_SSE41
CONST uvec32 kHash16x33 = { 0x92d9e201, 0, 0, 0 };  // 33 ^ 16
CONST uvec32 kHashMul0 = {
  0x0c3525e1,  // 33 ^ 15
  0xa3476dc1,  // 33 ^ 14
  0x3b4039a1,  // 33 ^ 13
  0x4f5f0981,  // 33 ^ 12
};
CONST uvec32 kHashMul1 = {
  0x30f35d61,  // 33 ^ 11
  0x855cb541,  // 33 ^ 10
  0x040a9121,  // 33 ^ 9
  0x747c7101,  // 33 ^ 8
};
CONST uvec32 kHashMul2 = {
  0xec41d4e1,  // 33 ^ 7
  0x4cfa3cc1,  // 33 ^ 6
  0x025528a1,  // 33 ^ 5
  0x00121881,  // 33 ^ 4
};
CONST uvec32 kHashMul3 = {
  0x00008c61,  // 33 ^ 3
  0x00000441,  // 33 ^ 2
  0x00000021,  // 33 ^ 1
  0x00000001,  // 33 ^ 0
};
static uint32 HashDjb2_SSE41(const uint8* src, int count, uint32 seed) {
  uint32 hash;
  asm volatile (
    "movd      %2,%%xmm0                       \n"
    "pxor      %%xmm7,%%xmm7                   \n"
    "movdqa    %4,%%xmm6                       \n"
    ".p2align  4                               \n"
  "1:                                          \n"
    "movdqu    (%0),%%xmm1                     \n"
    "lea       0x10(%0),%0                     \n"
    "pmulld    %%xmm6,%%xmm0                   \n"
    "movdqa    %5,%%xmm5                       \n"
    "movdqa    %%xmm1,%%xmm2                   \n"
    "punpcklbw %%xmm7,%%xmm2                   \n"
    "movdqa    %%xmm2,%%xmm3                   \n"
    "punpcklwd %%xmm7,%%xmm3                   \n"
    "pmulld    %%xmm5,%%xmm3                   \n"
    "movdqa    %6,%%xmm5                       \n"
    "movdqa    %%xmm2,%%xmm4                   \n"
    "punpckhwd %%xmm7,%%xmm4                   \n"
    "pmulld    %%xmm5,%%xmm4                   \n"
    "movdqa    %7,%%xmm5                       \n"
    "punpckhbw %%xmm7,%%xmm1                   \n"
    "movdqa    %%xmm1,%%xmm2                   \n"
    "punpcklwd %%xmm7,%%xmm2                   \n"
    "pmulld    %%xmm5,%%xmm2                   \n"
    "movdqa    %8,%%xmm5                       \n"
    "punpckhwd %%xmm7,%%xmm1                   \n"
    "pmulld    %%xmm5,%%xmm1                   \n"
    "paddd     %%xmm4,%%xmm3                   \n"
    "paddd     %%xmm2,%%xmm1                   \n"
    "sub       $0x10,%1                        \n"
    "paddd     %%xmm3,%%xmm1                   \n"
    "pshufd    $0xe,%%xmm1,%%xmm2              \n"
    "paddd     %%xmm2,%%xmm1                   \n"
    "pshufd    $0x1,%%xmm1,%%xmm2              \n"
    "paddd     %%xmm2,%%xmm1                   \n"
    "paddd     %%xmm1,%%xmm0                   \n"
    "jg        1b                              \n"
    "movd      %%xmm0,%3                       \n"
  : "+r"(src),        // %0
    "+r"(count),      // %1
    "+rm"(seed),      // %2
    "=g"(hash)        // %3
  : "m"(kHash16x33),  // %4
    "m"(kHashMul0),   // %5
    "m"(kHashMul1),   // %6
    "m"(kHashMul2),   // %7
    "m"(kHashMul3)    // %8
  : "memory", "cc"
#if defined(__SSE2__)
    , "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7"
#endif
  );
  return hash;
}
#endif  // HAS_HASHDJB2_SSE41

// hash seed of 5381 recommended.
LIBYUV_API
uint32 HashDjb2(const uint8* src, uint64 count, uint32 seed) {
  uint32 (*HashDjb2_SSE)(const uint8* src, int count, uint32 seed) =
      HashDjb2_C;
#if defined(HAS_HASHDJB2_SSE41)
  if (TestCpuFlag(kCpuHasSSE41)) {
    HashDjb2_SSE = HashDjb2_SSE41;
  }
#endif

  const int kBlockSize = 1 << 15;  // 32768
  while (count >= static_cast<uint64>(kBlockSize)) {
    seed = HashDjb2_SSE(src, kBlockSize, seed);
    src += kBlockSize;
    count -= kBlockSize;
  }
  int remainder = static_cast<int>(count) & ~15;
  if (remainder) {
    seed = HashDjb2_SSE(src, remainder, seed);
    src += remainder;
    count -= remainder;
  }
  remainder = static_cast<int>(count) & 15;
  if (remainder) {
    seed = HashDjb2_C(src, remainder, seed);
  }
  return seed;
}
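
// Example (illustrative): hash a whole frame with the recommended seed, where
// frame_data and frame_size stand in for the caller's buffer:
//   uint32 hash = HashDjb2(frame_data, frame_size, 5381);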

#if !defined(YUV_DISABLE_ASM) && (defined(__ARM_NEON__) || defined(LIBYUV_NEON))
#define HAS_SUMSQUAREERROR_NEON

uint32 SumSquareError_NEON(const uint8* src_a, const uint8* src_b, int count);

#elif !defined(YUV_DISABLE_ASM) && defined(_M_IX86)
#define HAS_SUMSQUAREERROR_SSE2
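// The SSE2 versions compute |a - b| per byte without an abs instruction:
// psubusb saturates negative byte differences to zero, so the OR of the two
// saturated subtractions yields |a - b|. pmaddwd then squares the widened
// differences and adds adjacent products, which are accumulated as dwords.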
__declspec(naked) __declspec(align(16))
static uint32 SumSquareError_SSE2(const uint8* src_a, const uint8* src_b,
                                  int count) {
  __asm {
    mov        eax, [esp + 4]    // src_a
    mov        edx, [esp + 8]    // src_b
    mov        ecx, [esp + 12]   // count
    pxor       xmm0, xmm0
    pxor       xmm5, xmm5
    sub        edx, eax

    align      16
  wloop:
    movdqa     xmm1, [eax]
    movdqa     xmm2, [eax + edx]
    lea        eax,  [eax + 16]
    sub        ecx, 16
    movdqa     xmm3, xmm1  // abs trick
    psubusb    xmm1, xmm2
    psubusb    xmm2, xmm3
    por        xmm1, xmm2
    movdqa     xmm2, xmm1
    punpcklbw  xmm1, xmm5
    punpckhbw  xmm2, xmm5
    pmaddwd    xmm1, xmm1
    pmaddwd    xmm2, xmm2
    paddd      xmm0, xmm1
    paddd      xmm0, xmm2
    jg         wloop

    pshufd     xmm1, xmm0, 0EEh
    paddd      xmm0, xmm1
    pshufd     xmm1, xmm0, 01h
    paddd      xmm0, xmm1
    movd       eax, xmm0
    ret
  }
}

#elif !defined(YUV_DISABLE_ASM) && (defined(__x86_64__) || defined(__i386__))
#define HAS_SUMSQUAREERROR_SSE2
static uint32 SumSquareError_SSE2(const uint8* src_a, const uint8* src_b,
                                  int count) {
  uint32 sse;
  asm volatile (
    "pxor      %%xmm0,%%xmm0                   \n"
    "pxor      %%xmm5,%%xmm5                   \n"
    "sub       %0,%1                           \n"
    ".p2align  4                               \n"
    "1:                                        \n"
    "movdqa    (%0),%%xmm1                     \n"
    "movdqa    (%0,%1,1),%%xmm2                \n"
    "lea       0x10(%0),%0                     \n"
    "sub       $0x10,%2                        \n"
    "movdqa    %%xmm1,%%xmm3                   \n"
    "psubusb   %%xmm2,%%xmm1                   \n"
    "psubusb   %%xmm3,%%xmm2                   \n"
    "por       %%xmm2,%%xmm1                   \n"
    "movdqa    %%xmm1,%%xmm2                   \n"
    "punpcklbw %%xmm5,%%xmm1                   \n"
    "punpckhbw %%xmm5,%%xmm2                   \n"
    "pmaddwd   %%xmm1,%%xmm1                   \n"
    "pmaddwd   %%xmm2,%%xmm2                   \n"
    "paddd     %%xmm1,%%xmm0                   \n"
    "paddd     %%xmm2,%%xmm0                   \n"
    "jg        1b                              \n"

    "pshufd    $0xee,%%xmm0,%%xmm1             \n"
    "paddd     %%xmm1,%%xmm0                   \n"
    "pshufd    $0x1,%%xmm0,%%xmm1              \n"
    "paddd     %%xmm1,%%xmm0                   \n"
    "movd      %%xmm0,%3                       \n"

  : "+r"(src_a),      // %0
    "+r"(src_b),      // %1
    "+r"(count),      // %2
    "=g"(sse)         // %3
  :
  : "memory", "cc"
#if defined(__SSE2__)
    , "xmm0", "xmm1", "xmm2", "xmm3", "xmm5"
#endif
  );
  return sse;
}
#endif

static uint32 SumSquareError_C(const uint8* src_a, const uint8* src_b,
                               int count) {
  uint32 sse = 0u;
  for (int i = 0; i < count; ++i) {
    int diff = src_a[i] - src_b[i];
    sse += static_cast<uint32>(diff * diff);
  }
  return sse;
}

LIBYUV_API
uint64 ComputeSumSquareError(const uint8* src_a, const uint8* src_b,
                             int count) {
  uint32 (*SumSquareError)(const uint8* src_a, const uint8* src_b, int count) =
      SumSquareError_C;
#if defined(HAS_SUMSQUAREERROR_NEON)
  if (TestCpuFlag(kCpuHasNEON)) {
    SumSquareError = SumSquareError_NEON;
  }
#elif defined(HAS_SUMSQUAREERROR_SSE2)
  if (TestCpuFlag(kCpuHasSSE2) &&
      IS_ALIGNED(src_a, 16) && IS_ALIGNED(src_b, 16)) {
    // Note: only used for multiples of 16, so count is not checked.
    SumSquareError = SumSquareError_SSE2;
  }
#endif
  // 32K pixels fit the 32-bit return value of SumSquareError:
  // 32768 * 255^2 = 2,130,739,200, which is below 2^31.
  // After each block of 32K, accumulate into a 64-bit int.
  const int kBlockSize = 1 << 15;  // 32768
  uint64 sse = 0;
#ifdef _OPENMP
#pragma omp parallel for reduction(+: sse)
#endif
  for (int i = 0; i < (count - (kBlockSize - 1)); i += kBlockSize) {
    sse += SumSquareError(src_a + i, src_b + i, kBlockSize);
  }
  src_a += count & ~(kBlockSize - 1);
  src_b += count & ~(kBlockSize - 1);
  int remainder = count & (kBlockSize - 1) & ~15;
  if (remainder) {
    sse += SumSquareError(src_a, src_b, remainder);
    src_a += remainder;
    src_b += remainder;
  }
  remainder = count & 15;
  if (remainder) {
    sse += SumSquareError_C(src_a, src_b, remainder);
  }
  return sse;
}
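
// Example (illustrative): global PSNR between two equally sized buffers,
// where buf_a, buf_b and size stand in for the caller's data:
//   uint64 sse = ComputeSumSquareError(buf_a, buf_b, size);
//   double psnr = SumSquareErrorToPsnr(sse, size);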

LIBYUV_API
uint64 ComputeSumSquareErrorPlane(const uint8* src_a, int stride_a,
                                  const uint8* src_b, int stride_b,
                                  int width, int height) {
  uint32 (*SumSquareError)(const uint8* src_a, const uint8* src_b, int count) =
      SumSquareError_C;
#if defined(HAS_SUMSQUAREERROR_NEON)
  if (TestCpuFlag(kCpuHasNEON)) {
    SumSquareError = SumSquareError_NEON;
  }
#elif defined(HAS_SUMSQUAREERROR_SSE2)
  if (TestCpuFlag(kCpuHasSSE2) && IS_ALIGNED(width, 16) &&
      IS_ALIGNED(src_a, 16) && IS_ALIGNED(stride_a, 16) &&
      IS_ALIGNED(src_b, 16) && IS_ALIGNED(stride_b, 16)) {
    SumSquareError = SumSquareError_SSE2;
  }
#endif

  uint64 sse = 0;
  for (int h = 0; h < height; ++h) {
    sse += SumSquareError(src_a, src_b, width);
    src_a += stride_a;
    src_b += stride_b;
  }

  return sse;
}

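// PSNR = 10 * log10(255^2 / MSE), where MSE = sse / count; computed here as
// 10 * log10(255^2 * count / sse).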
LIBYUV_API
double SumSquareErrorToPsnr(uint64 sse, uint64 count) {
  double psnr;
  if (sse > 0) {
    // count / sse is the reciprocal of the mean square error.
    double inv_mse = static_cast<double>(count) / static_cast<double>(sse);
    psnr = 10.0 * log10(255.0 * 255.0 * inv_mse);
  } else {
    psnr = kMaxPsnr;      // Limit to prevent divide by 0.
  }

  if (psnr > kMaxPsnr)
    psnr = kMaxPsnr;

  return psnr;
}

LIBYUV_API
double CalcFramePsnr(const uint8* src_a, int stride_a,
                     const uint8* src_b, int stride_b,
                     int width, int height) {
  // Widen before multiplying to avoid int overflow on large frames.
  const uint64 samples = static_cast<uint64>(width) * height;
  const uint64 sse = ComputeSumSquareErrorPlane(src_a, stride_a,
                                                src_b, stride_b,
                                                width, height);
  return SumSquareErrorToPsnr(sse, samples);
}

LIBYUV_API
double I420Psnr(const uint8* src_y_a, int stride_y_a,
                const uint8* src_u_a, int stride_u_a,
                const uint8* src_v_a, int stride_v_a,
                const uint8* src_y_b, int stride_y_b,
                const uint8* src_u_b, int stride_u_b,
                const uint8* src_v_b, int stride_v_b,
                int width, int height) {
  const uint64 sse_y = ComputeSumSquareErrorPlane(src_y_a, stride_y_a,
                                                  src_y_b, stride_y_b,
                                                  width, height);
  const int width_uv = (width + 1) >> 1;
  const int height_uv = (height + 1) >> 1;
  const uint64 sse_u = ComputeSumSquareErrorPlane(src_u_a, stride_u_a,
                                                  src_u_b, stride_u_b,
                                                  width_uv, height_uv);
  const uint64 sse_v = ComputeSumSquareErrorPlane(src_v_a, stride_v_a,
                                                  src_v_b, stride_v_b,
                                                  width_uv, height_uv);
  const uint64 samples = static_cast<uint64>(width) * height +
                         2 * (static_cast<uint64>(width_uv) * height_uv);
  const uint64 sse = sse_y + sse_u + sse_v;
  return SumSquareErrorToPsnr(sse, samples);
}

static const int64 cc1 =  26634;  // 64^2 * (0.01 * 255)^2, rounded.
static const int64 cc2 = 239708;  // 64^2 * (0.03 * 255)^2, rounded.
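// Integer form of the per-window SSIM:
//   SSIM = (2*mean_a*mean_b + C1) * (2*cov_ab + C2) /
//          ((mean_a^2 + mean_b^2 + C1) * (var_a + var_b + C2))
// Ssim8x8_C below works with raw sums rather than means, so each factor
// carries a count^2 scale that cancels in the ratio; c1 and c2 are rescaled
// to match by (cc * count * count) >> 12.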

static double Ssim8x8_C(const uint8* src_a, int stride_a,
                        const uint8* src_b, int stride_b) {
  int64 sum_a = 0;
  int64 sum_b = 0;
  int64 sum_sq_a = 0;
  int64 sum_sq_b = 0;
  int64 sum_axb = 0;

  for (int i = 0; i < 8; ++i) {
    for (int j = 0; j < 8; ++j) {
      sum_a += src_a[j];
      sum_b += src_b[j];
      sum_sq_a += src_a[j] * src_a[j];
      sum_sq_b += src_b[j] * src_b[j];
      sum_axb += src_a[j] * src_b[j];
    }

    src_a += stride_a;
    src_b += stride_b;
  }

  const int64 count = 64;
  // Scale the constants by the number of pixels.
  const int64 c1 = (cc1 * count * count) >> 12;
  const int64 c2 = (cc2 * count * count) >> 12;

  const int64 sum_a_x_sum_b = sum_a * sum_b;

  const int64 ssim_n = (2 * sum_a_x_sum_b + c1) *
                       (2 * count * sum_axb - 2 * sum_a_x_sum_b + c2);

  const int64 sum_a_sq = sum_a * sum_a;
  const int64 sum_b_sq = sum_b * sum_b;

  const int64 ssim_d = (sum_a_sq + sum_b_sq + c1) *
                       (count * sum_sq_a - sum_a_sq +
                        count * sum_sq_b - sum_b_sq + c2);

  if (ssim_d == 0)
    return DBL_MAX;
  return ssim_n * 1.0 / ssim_d;
}

// An 8x8 moving window is used, with each window starting on the 4x4 pixel
// grid. This arrangement lets windows overlap block boundaries to penalize
// blocking artifacts.
LIBYUV_API
double CalcFrameSsim(const uint8* src_a, int stride_a,
                     const uint8* src_b, int stride_b,
                     int width, int height) {
  int samples = 0;
  double ssim_total = 0;

  double (*Ssim8x8)(const uint8* src_a, int stride_a,
                    const uint8* src_b, int stride_b);

  Ssim8x8 = Ssim8x8_C;

  // Sample points start at each 4x4 location.
  for (int i = 0; i < height - 8; i += 4) {
    for (int j = 0; j < width - 8; j += 4) {
      ssim_total += Ssim8x8(src_a + j, stride_a, src_b + j, stride_b);
      samples++;
    }

    src_a += stride_a * 4;
    src_b += stride_b * 4;
  }

  ssim_total /= samples;
  return ssim_total;
}

LIBYUV_API
double I420Ssim(const uint8* src_y_a, int stride_y_a,
                const uint8* src_u_a, int stride_u_a,
                const uint8* src_v_a, int stride_v_a,
                const uint8* src_y_b, int stride_y_b,
                const uint8* src_u_b, int stride_u_b,
                const uint8* src_v_b, int stride_v_b,
                int width, int height) {
  const double ssim_y = CalcFrameSsim(src_y_a, stride_y_a,
                                      src_y_b, stride_y_b, width, height);
  const int width_uv = (width + 1) >> 1;
  const int height_uv = (height + 1) >> 1;
  const double ssim_u = CalcFrameSsim(src_u_a, stride_u_a,
                                      src_u_b, stride_u_b,
                                      width_uv, height_uv);
  const double ssim_v = CalcFrameSsim(src_v_a, stride_v_a,
                                      src_v_b, stride_v_b,
                                      width_uv, height_uv);
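  // Weighted average emphasizing luma: 80% Y, 10% each for U and V.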
  return ssim_y * 0.8 + 0.1 * (ssim_u + ssim_v);
}

#ifdef __cplusplus
}  // extern "C"
}  // namespace libyuv
#endif