From 9675ff0714df15e433dbe78d6e40c2430c21b519 Mon Sep 17 00:00:00 2001
From: Kieran Kunhya <kieran@kunhya.com>
Date: Wed, 27 Dec 2017 01:08:39 +0000
Subject: [PATCH] simple_idct: Template functions to support an input bitdepth
 parameter

---
 libavcodec/bit_depth_template.c   | 17 +++++++++++++++-
 libavcodec/idctdsp.c              | 18 ++++++++--------
 libavcodec/me_cmp.c               |  2 +-
 libavcodec/simple_idct.c          | 15 +++++++++++---
 libavcodec/simple_idct.h          | 24 +++++++++++++---------
 libavcodec/simple_idct_template.c | 43 +++++++++++++++++++++++++--------------
 libavcodec/vc1.c                  |  4 ++--
 7 files changed, 82 insertions(+), 41 deletions(-)

diff --git a/libavcodec/bit_depth_template.c b/libavcodec/bit_depth_template.c
index 8018489..bd7237f 100644
--- a/libavcodec/bit_depth_template.c
+++ b/libavcodec/bit_depth_template.c
@@ -29,6 +29,7 @@
 #   undef pixel2
 #   undef pixel4
 #   undef dctcoef
+#   undef idctin
 #   undef INIT_CLIP
 #   undef no_rnd_avg_pixel4
 #   undef rnd_avg_pixel4
@@ -53,6 +54,16 @@
 #   define pixel4 uint64_t
 #   define dctcoef int32_t
 
+#ifdef IN_IDCT_DEPTH
+#if IN_IDCT_DEPTH == 32
+#   define idctin int32_t
+#else
+#   define idctin int16_t
+#endif
+#else
+#   define idctin int16_t
+#endif
+
 #   define INIT_CLIP
 #   define no_rnd_avg_pixel4 no_rnd_avg64
 #   define    rnd_avg_pixel4    rnd_avg64
@@ -71,6 +82,7 @@
 #   define pixel2 uint16_t
 #   define pixel4 uint32_t
 #   define dctcoef int16_t
+#   define idctin  int16_t
 
 #   define INIT_CLIP
 #   define no_rnd_avg_pixel4 no_rnd_avg32
@@ -87,7 +99,10 @@
 #   define CLIP(a) av_clip_uint8(a)
 #endif
 
-#define FUNC3(a, b, c)  a ## _ ## b ## c
+#define FUNC3(a, b, c)  a ## _ ## b ##  c
 #define FUNC2(a, b, c)  FUNC3(a, b, c)
 #define FUNC(a)  FUNC2(a, BIT_DEPTH,)
 #define FUNCC(a) FUNC2(a, BIT_DEPTH, _c)
+#define FUNC4(a, b, c)  a ## _ ## b ## _ ## c
+#define FUNC5(a, b, c)  FUNC4(a, b, c)
+#define FUNC6(a)  FUNC5(a, IN_IDCT_DEPTH, BIT_DEPTH)
\ No newline at end of file
diff --git a/libavcodec/idctdsp.c b/libavcodec/idctdsp.c
index 0ff74d8..16703aa 100644
--- a/libavcodec/idctdsp.c
+++ b/libavcodec/idctdsp.c
@@ -256,14 +256,14 @@ av_cold void ff_idctdsp_init(IDCTDSPContext *c, AVCodecContext *avctx)
         c->perm_type = FF_IDCT_PERM_NONE;
     } else {
         if (avctx->bits_per_raw_sample == 10 || avctx->bits_per_raw_sample == 9) {
-            c->idct_put              = ff_simple_idct_put_10;
-            c->idct_add              = ff_simple_idct_add_10;
-            c->idct                  = ff_simple_idct_10;
+            c->idct_put              = ff_simple_idct_put_16_10;
+            c->idct_add              = ff_simple_idct_add_16_10;
+            c->idct                  = ff_simple_idct_16_10;
             c->perm_type             = FF_IDCT_PERM_NONE;
         } else if (avctx->bits_per_raw_sample == 12) {
-            c->idct_put              = ff_simple_idct_put_12;
-            c->idct_add              = ff_simple_idct_add_12;
-            c->idct                  = ff_simple_idct_12;
+            c->idct_put              = ff_simple_idct_put_16_12;
+            c->idct_add              = ff_simple_idct_add_16_12;
+            c->idct                  = ff_simple_idct_16_12;
             c->perm_type             = FF_IDCT_PERM_NONE;
         } else {
             if (avctx->idct_algo == FF_IDCT_INT) {
@@ -280,9 +280,9 @@ av_cold void ff_idctdsp_init(IDCTDSPContext *c, AVCodecContext *avctx)
 #endif /* CONFIG_FAANIDCT */
             } else { // accurate/default
                 /* Be sure FF_IDCT_NONE will select this one, since it uses FF_IDCT_PERM_NONE */
-                c->idct_put  = ff_simple_idct_put_8;
-                c->idct_add  = ff_simple_idct_add_8;
-                c->idct      = ff_simple_idct_8;
+                c->idct_put  = ff_simple_idct_put_16_8;
+                c->idct_add  = ff_simple_idct_add_16_8;
+                c->idct      = ff_simple_idct_16_8;
                 c->perm_type = FF_IDCT_PERM_NONE;
             }
         }
diff --git a/libavcodec/me_cmp.c b/libavcodec/me_cmp.c
index 5e34a11..eb9d76c 100644
--- a/libavcodec/me_cmp.c
+++ b/libavcodec/me_cmp.c
@@ -721,7 +721,7 @@ static int quant_psnr8x8_c(MpegEncContext *s, uint8_t *src1,
     s->block_last_index[0 /* FIXME */] =
         s->fast_dct_quantize(s, temp, 0 /* FIXME */, s->qscale, &i);
     s->dct_unquantize_inter(s, temp, 0, s->qscale);
-    ff_simple_idct_8(temp); // FIXME
+    ff_simple_idct_16_8(temp); // FIXME
 
     for (i = 0; i < 64; i++)
         sum += (temp[i] - bak[i]) * (temp[i] - bak[i]);
diff --git a/libavcodec/simple_idct.c b/libavcodec/simple_idct.c
index 1d05b2f..dfe2a8c 100644
--- a/libavcodec/simple_idct.c
+++ b/libavcodec/simple_idct.c
@@ -30,6 +30,8 @@
 #include "mathops.h"
 #include "simple_idct.h"
 
+#define IN_IDCT_DEPTH 16
+
 #define BIT_DEPTH 8
 #include "simple_idct_template.c"
 #undef BIT_DEPTH
@@ -46,6 +48,13 @@
 #define BIT_DEPTH 12
 #include "simple_idct_template.c"
 #undef BIT_DEPTH
+#undef IN_IDCT_DEPTH
+
+#define IN_IDCT_DEPTH 32
+#define BIT_DEPTH 10
+#include "simple_idct_template.c"
+#undef BIT_DEPTH
+#undef IN_IDCT_DEPTH
 
 /* 2x4x8 idct */
 
@@ -115,7 +124,7 @@ void ff_simple_idct248_put(uint8_t *dest, ptrdiff_t line_size, int16_t *block)
 
     /* IDCT8 on each line */
     for(i=0; i<8; i++) {
-        idctRowCondDC_8(block + i*8, 0);
+        idctRowCondDC_16_8(block + i*8, 0);
     }
 
     /* IDCT4 and store */
@@ -188,7 +197,7 @@ void ff_simple_idct84_add(uint8_t *dest, ptrdiff_t line_size, int16_t *block)
 
     /* IDCT8 on each line */
     for(i=0; i<4; i++) {
-        idctRowCondDC_8(block + i*8, 0);
+        idctRowCondDC_16_8(block + i*8, 0);
     }
 
     /* IDCT4 and store */
@@ -208,7 +217,7 @@ void ff_simple_idct48_add(uint8_t *dest, ptrdiff_t line_size, int16_t *block)
 
     /* IDCT8 and store */
     for(i=0; i<4; i++){
-        idctSparseColAdd_8(dest + i, line_size, block + i);
+        idctSparseColAdd_16_8(dest + i, line_size, block + i);
     }
 }
 
diff --git a/libavcodec/simple_idct.h b/libavcodec/simple_idct.h
index 2a5e1d7..67761e1 100644
--- a/libavcodec/simple_idct.h
+++ b/libavcodec/simple_idct.h
@@ -31,20 +31,24 @@
 #include <stddef.h>
 #include <stdint.h>
 
-void ff_simple_idct_put_8(uint8_t *dest, ptrdiff_t line_size, int16_t *block);
-void ff_simple_idct_add_8(uint8_t *dest, ptrdiff_t line_size, int16_t *block);
-void ff_simple_idct_8(int16_t *block);
+void ff_simple_idct_put_16_8(uint8_t *dest, ptrdiff_t line_size, int16_t *block);
+void ff_simple_idct_add_16_8(uint8_t *dest, ptrdiff_t line_size, int16_t *block);
+void ff_simple_idct_16_8(int16_t *block);
 
-void ff_simple_idct_put_10(uint8_t *dest, ptrdiff_t line_size, int16_t *block);
-void ff_simple_idct_add_10(uint8_t *dest, ptrdiff_t line_size, int16_t *block);
-void ff_simple_idct_10(int16_t *block);
+void ff_simple_idct_put_16_10(uint8_t *dest, ptrdiff_t line_size, int16_t *block);
+void ff_simple_idct_add_16_10(uint8_t *dest, ptrdiff_t line_size, int16_t *block);
+void ff_simple_idct_16_10(int16_t *block);
 
-void ff_simple_idct_put_12(uint8_t *dest, ptrdiff_t line_size, int16_t *block);
-void ff_simple_idct_add_12(uint8_t *dest, ptrdiff_t line_size, int16_t *block);
-void ff_simple_idct_12(int16_t *block);
+void ff_simple_idct_put_32_10(uint8_t *dest, ptrdiff_t line_size, int16_t *block);
+void ff_simple_idct_add_32_10(uint8_t *dest, ptrdiff_t line_size, int16_t *block);
+void ff_simple_idct_32_10(int16_t *block);
+
+void ff_simple_idct_put_16_12(uint8_t *dest, ptrdiff_t line_size, int16_t *block);
+void ff_simple_idct_add_16_12(uint8_t *dest, ptrdiff_t line_size, int16_t *block);
+void ff_simple_idct_16_12(int16_t *block);
 
 /**
- * Special version of ff_simple_idct_10() which does dequantization
+ * Special version of ff_simple_idct_16_10() which does dequantization
  * and scales by a factor of 2 more between the two IDCTs to account
  * for larger scale of input coefficients.
  */
diff --git a/libavcodec/simple_idct_template.c b/libavcodec/simple_idct_template.c
index f532313..8d60b50 100644
--- a/libavcodec/simple_idct_template.c
+++ b/libavcodec/simple_idct_template.c
@@ -77,6 +77,10 @@
 #define ROW_SHIFT 13
 #define COL_SHIFT 18
 #define DC_SHIFT  1
+#   elif IN_IDCT_DEPTH == 32
+#define ROW_SHIFT 13
+#define COL_SHIFT 21
+#define DC_SHIFT  2
 #   else
 #define ROW_SHIFT 12
 #define COL_SHIFT 19
@@ -109,11 +113,12 @@
 #ifdef EXTRA_SHIFT
 static inline void FUNC(idctRowCondDC_extrashift)(int16_t *row, int extra_shift)
 #else
-static inline void FUNC(idctRowCondDC)(int16_t *row, int extra_shift)
+static inline void FUNC6(idctRowCondDC)(idctin *row, int extra_shift)
 #endif
 {
     SUINT a0, a1, a2, a3, b0, b1, b2, b3;
 
+#if IN_IDCT_DEPTH == 16
 #if HAVE_FAST_64BIT
 #define ROW0_MASK (0xffffLL << 48 * HAVE_BIGENDIAN)
     if (((AV_RN64A(row) & ~ROW0_MASK) | AV_RN64A(row+4)) == 0) {
@@ -148,6 +153,7 @@ static inline void FUNC(idctRowCondDC)(int16_t *row, int extra_shift)
         return;
     }
 #endif
+#endif
 
     a0 = (W4 * row[0]) + (1 << (ROW_SHIFT + extra_shift - 1));
     a1 = a0;
@@ -168,7 +174,11 @@ static inline void FUNC(idctRowCondDC)(int16_t *row, int extra_shift)
     b3 = MUL(W7, row[1]);
     MAC(b3, -W5, row[3]);
 
+#if IN_IDCT_DEPTH == 32
+    if (1) {
+#else
     if (AV_RN64A(row + 4)) {
+#endif
         a0 +=   W4*row[4] + W6*row[6];
         a1 += - W4*row[4] - W2*row[6];
         a2 += - W4*row[4] + W2*row[6];
@@ -250,8 +260,8 @@ static inline void FUNC(idctRowCondDC)(int16_t *row, int extra_shift)
 #ifdef EXTRA_SHIFT
 static inline void FUNC(idctSparseCol_extrashift)(int16_t *col)
 #else
-static inline void FUNC(idctSparseColPut)(pixel *dest, ptrdiff_t line_size,
-                                          int16_t *col)
+static inline void FUNC6(idctSparseColPut)(pixel *dest, ptrdiff_t line_size,
+                                          idctin *col)
 {
     SUINT a0, a1, a2, a3, b0, b1, b2, b3;
 
@@ -274,8 +284,8 @@ static inline void FUNC(idctSparseColPut)(pixel *dest, ptrdiff_t line_size,
     dest[0] = av_clip_pixel((int)(a0 - b0) >> COL_SHIFT);
 }
 
-static inline void FUNC(idctSparseColAdd)(pixel *dest, ptrdiff_t line_size,
-                                          int16_t *col)
+static inline void FUNC6(idctSparseColAdd)(pixel *dest, ptrdiff_t line_size,
+                                          idctin *col)
 {
     int a0, a1, a2, a3, b0, b1, b2, b3;
 
@@ -298,7 +308,7 @@ static inline void FUNC(idctSparseColAdd)(pixel *dest, ptrdiff_t line_size,
     dest[0] = av_clip_pixel(dest[0] + ((a0 - b0) >> COL_SHIFT));
 }
 
-static inline void FUNC(idctSparseCol)(int16_t *col)
+static inline void FUNC6(idctSparseCol)(idctin *col)
 #endif
 {
     int a0, a1, a2, a3, b0, b1, b2, b3;
@@ -316,21 +326,23 @@ static inline void FUNC(idctSparseCol)(int16_t *col)
 }
 
 #ifndef EXTRA_SHIFT
-void FUNC(ff_simple_idct_put)(uint8_t *dest_, ptrdiff_t line_size, int16_t *block)
+void FUNC6(ff_simple_idct_put)(uint8_t *dest_, ptrdiff_t line_size, int16_t *block_)
 {
+    idctin *block = (idctin *)block_;
     pixel *dest = (pixel *)dest_;
     int i;
 
     line_size /= sizeof(pixel);
 
     for (i = 0; i < 8; i++)
-        FUNC(idctRowCondDC)(block + i*8, 0);
+        FUNC6(idctRowCondDC)(block + i*8, 0);
 
     for (i = 0; i < 8; i++)
-        FUNC(idctSparseColPut)(dest + i, line_size, block + i);
+        FUNC6(idctSparseColPut)(dest + i, line_size, block + i);
 }
 
-void FUNC(ff_simple_idct_add)(uint8_t *dest_, ptrdiff_t line_size, int16_t *block)
+#if IN_IDCT_DEPTH == 16
+void FUNC6(ff_simple_idct_add)(uint8_t *dest_, ptrdiff_t line_size, int16_t *block)
 {
     pixel *dest = (pixel *)dest_;
     int i;
@@ -338,20 +350,21 @@ void FUNC(ff_simple_idct_add)(uint8_t *dest_, ptrdiff_t line_size, int16_t *bloc
     line_size /= sizeof(pixel);
 
     for (i = 0; i < 8; i++)
-        FUNC(idctRowCondDC)(block + i*8, 0);
+        FUNC6(idctRowCondDC)(block + i*8, 0);
 
     for (i = 0; i < 8; i++)
-        FUNC(idctSparseColAdd)(dest + i, line_size, block + i);
+        FUNC6(idctSparseColAdd)(dest + i, line_size, block + i);
 }
 
-void FUNC(ff_simple_idct)(int16_t *block)
+void FUNC6(ff_simple_idct)(int16_t *block)
 {
     int i;
 
     for (i = 0; i < 8; i++)
-        FUNC(idctRowCondDC)(block + i*8, 0);
+        FUNC6(idctRowCondDC)(block + i*8, 0);
 
     for (i = 0; i < 8; i++)
-        FUNC(idctSparseCol)(block + i);
+        FUNC6(idctSparseCol)(block + i);
 }
 #endif
+#endif
diff --git a/libavcodec/vc1.c b/libavcodec/vc1.c
index 48a2cc1..62a2bf5 100644
--- a/libavcodec/vc1.c
+++ b/libavcodec/vc1.c
@@ -314,11 +314,11 @@ int ff_vc1_decode_sequence_header(AVCodecContext *avctx, VC1Context *v, GetBitCo
     v->multires        = get_bits1(gb);
     v->res_fasttx      = get_bits1(gb);
     if (!v->res_fasttx) {
-        v->vc1dsp.vc1_inv_trans_8x8    = ff_simple_idct_8;
+        v->vc1dsp.vc1_inv_trans_8x8    = ff_simple_idct_16_8;
         v->vc1dsp.vc1_inv_trans_8x4    = ff_simple_idct84_add;
         v->vc1dsp.vc1_inv_trans_4x8    = ff_simple_idct48_add;
         v->vc1dsp.vc1_inv_trans_4x4    = ff_simple_idct44_add;
-        v->vc1dsp.vc1_inv_trans_8x8_dc = ff_simple_idct_add_8;
+        v->vc1dsp.vc1_inv_trans_8x8_dc = ff_simple_idct_add_16_8;
         v->vc1dsp.vc1_inv_trans_8x4_dc = ff_simple_idct84_add;
         v->vc1dsp.vc1_inv_trans_4x8_dc = ff_simple_idct48_add;
         v->vc1dsp.vc1_inv_trans_4x4_dc = ff_simple_idct44_add;
-- 
1.9.1

