Author: xuefu
Date: Mon Dec  9 21:09:01 2013
New Revision: 1549679

URL: http://svn.apache.org/r1549679
Log:
HIVE-5872: Make UDAFs such as GenericUDAFSum report accurate precision/scale 
for decimal types (reviewed by Sergey Shelukhin)

Modified:
    hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
    hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java
    hive/trunk/ql/src/test/queries/clientpositive/decimal_precision.q
    hive/trunk/ql/src/test/queries/clientpositive/decimal_udf.q
    hive/trunk/ql/src/test/results/clientpositive/decimal_precision.q.out
    hive/trunk/ql/src/test/results/clientpositive/decimal_udf.q.out
    hive/trunk/serde/src/java/org/apache/hadoop/hive/serde2/lazybinary/LazyBinarySerDe.java

Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java?rev=1549679&r1=1549678&r2=1549679&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFAverage.java Mon Dec  9 21:09:01 2013
@@ -40,8 +40,10 @@ import org.apache.hadoop.hive.serde2.obj
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
 import org.apache.hadoop.io.LongWritable;
 import org.apache.hadoop.util.StringUtils;
 
@@ -65,7 +67,7 @@ public class GenericUDAFAverage extends 
     if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
       throw new UDFArgumentTypeException(0,
           "Only primitive type arguments are accepted but "
-          + parameters[0].getTypeName() + " is passed.");
+              + parameters[0].getTypeName() + " is passed.");
     }
     switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
     case BYTE:
@@ -84,7 +86,7 @@ public class GenericUDAFAverage extends 
     default:
       throw new UDFArgumentTypeException(0,
           "Only numeric or string type arguments are accepted but "
-          + parameters[0].getTypeName() + " is passed.");
+              + parameters[0].getTypeName() + " is passed.");
     }
   }
 
@@ -160,11 +162,29 @@ public class GenericUDAFAverage extends 
 
     @Override
     protected ObjectInspector getSumFieldJavaObjectInspector() {
-      return PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector;
+      DecimalTypeInfo typeInfo = deriveResultDecimalTypeInfo();
+      return 
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(typeInfo);
     }
+
     @Override
     protected ObjectInspector getSumFieldWritableObjectInspector() {
-      return 
PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector;
+      DecimalTypeInfo typeInfo = deriveResultDecimalTypeInfo();
+      return 
PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(typeInfo);
+    }
+
+    /**
+     * The result type has the same number of integer digits and 4 more 
decimal digits.
+     */
+    private DecimalTypeInfo deriveResultDecimalTypeInfo() {
+      if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
+        int scale = inputOI.scale();
+        int intPart = inputOI.precision() - scale;
+        scale = Math.min(scale + 4, HiveDecimal.MAX_SCALE - intPart);
+        return TypeInfoFactory.getDecimalTypeInfo(intPart + scale, scale);
+      } else {
+        PrimitiveObjectInspector sfOI = (PrimitiveObjectInspector) sumFieldOI;
+        return (DecimalTypeInfo) sfOI.getTypeInfo();
+      }
     }
 
     @Override
@@ -231,13 +251,13 @@ public class GenericUDAFAverage extends 
   public static abstract class AbstractGenericUDAFAverageEvaluator<TYPE> 
extends GenericUDAFEvaluator {
 
     // For PARTIAL1 and COMPLETE
-    private transient PrimitiveObjectInspector inputOI;
+    protected transient PrimitiveObjectInspector inputOI;
     // For PARTIAL2 and FINAL
     private transient StructObjectInspector soi;
     private transient StructField countField;
     private transient StructField sumField;
     private LongObjectInspector countFieldOI;
-    private ObjectInspector sumFieldOI;
+    protected ObjectInspector sumFieldOI;
     // For PARTIAL1 and PARTIAL2
     protected transient Object[] partialResult;
 

Modified: hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java?rev=1549679&r1=1549678&r2=1549679&view=diff
==============================================================================
--- hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java (original)
+++ hive/trunk/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java Mon Dec  9 21:09:01 2013
@@ -24,6 +24,7 @@ import org.apache.hadoop.hive.ql.exec.De
 import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode;
 import org.apache.hadoop.hive.ql.util.JavaDataModel;
 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
@@ -31,8 +32,10 @@ import org.apache.hadoop.hive.serde2.obj
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
 import org.apache.hadoop.io.LongWritable;
 import org.apache.hadoop.util.StringUtils;
 
@@ -94,7 +97,16 @@ public class GenericUDAFSum extends Abst
       super.init(m, parameters);
       result = new HiveDecimalWritable(HiveDecimal.ZERO);
       inputOI = (PrimitiveObjectInspector) parameters[0];
-      return 
PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector;
+      // The output precision is 10 greater than the input which should cover 
at least
+      // 10b rows. The scale is the same as the input.
+      DecimalTypeInfo outputTypeInfo = null;
+      if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
+        int precision = Math.min(HiveDecimal.MAX_PRECISION, 
inputOI.precision() + 10);
+        outputTypeInfo = TypeInfoFactory.getDecimalTypeInfo(precision, 
inputOI.scale());
+      } else {
+        outputTypeInfo = (DecimalTypeInfo) inputOI.getTypeInfo();
+      }
+      return 
PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(outputTypeInfo);
     }
 
     /** class for storing decimal sum value. */

Modified: hive/trunk/ql/src/test/queries/clientpositive/decimal_precision.q
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/test/queries/clientpositive/decimal_precision.q?rev=1549679&r1=1549678&r2=1549679&view=diff
==============================================================================
--- hive/trunk/ql/src/test/queries/clientpositive/decimal_precision.q (original)
+++ hive/trunk/ql/src/test/queries/clientpositive/decimal_precision.q Mon Dec  9 21:09:01 2013
@@ -15,6 +15,7 @@ SELECT dec, dec / 9 FROM DECIMAL_PRECISI
 SELECT dec, dec / 27 FROM DECIMAL_PRECISION ORDER BY dec;
 SELECT dec, dec * dec FROM DECIMAL_PRECISION ORDER BY dec;
 
+EXPLAIN SELECT avg(dec), sum(dec) FROM DECIMAL_PRECISION;
 SELECT avg(dec), sum(dec) FROM DECIMAL_PRECISION;
 
 SELECT dec * cast('12345678901234567890.12345678' as decimal(38,18)) FROM 
DECIMAL_PRECISION LIMIT 1;

Modified: hive/trunk/ql/src/test/queries/clientpositive/decimal_udf.q
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/test/queries/clientpositive/decimal_udf.q?rev=1549679&r1=1549678&r2=1549679&view=diff
==============================================================================
--- hive/trunk/ql/src/test/queries/clientpositive/decimal_udf.q (original)
+++ hive/trunk/ql/src/test/queries/clientpositive/decimal_udf.q Mon Dec  9 21:09:01 2013
@@ -72,8 +72,8 @@ EXPLAIN SELECT abs(key) FROM DECIMAL_UDF
 SELECT abs(key) FROM DECIMAL_UDF;
 
 -- avg
-EXPLAIN SELECT value, sum(key) / count(key), avg(key) FROM DECIMAL_UDF GROUP 
BY value ORDER BY value;
-SELECT value, sum(key) / count(key), avg(key) FROM DECIMAL_UDF GROUP BY value 
ORDER BY value;
+EXPLAIN SELECT value, sum(key) / count(key), avg(key), sum(key) FROM 
DECIMAL_UDF GROUP BY value ORDER BY value;
+SELECT value, sum(key) / count(key), avg(key), sum(key) FROM DECIMAL_UDF GROUP 
BY value ORDER BY value;
 
 -- negative
 EXPLAIN SELECT -key FROM DECIMAL_UDF;

Modified: hive/trunk/ql/src/test/results/clientpositive/decimal_precision.q.out
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/test/results/clientpositive/decimal_precision.q.out?rev=1549679&r1=1549678&r2=1549679&view=diff
==============================================================================
--- hive/trunk/ql/src/test/results/clientpositive/decimal_precision.q.out (original)
+++ hive/trunk/ql/src/test/results/clientpositive/decimal_precision.q.out Mon Dec  9 21:09:01 2013
@@ -517,6 +517,72 @@ NULL       NULL
 123456789.0123456789   15241578753238836.75019051998750190521
 1234567890.123456      NULL
 1234567890.123456789   NULL
+PREHOOK: query: EXPLAIN SELECT avg(dec), sum(dec) FROM DECIMAL_PRECISION
+PREHOOK: type: QUERY
+POSTHOOK: query: EXPLAIN SELECT avg(dec), sum(dec) FROM DECIMAL_PRECISION
+POSTHOOK: type: QUERY
+ABSTRACT SYNTAX TREE:
+  (TOK_QUERY (TOK_FROM (TOK_TABREF (TOK_TABNAME DECIMAL_PRECISION))) 
(TOK_INSERT (TOK_DESTINATION (TOK_DIR TOK_TMP_FILE)) (TOK_SELECT (TOK_SELEXPR 
(TOK_FUNCTION avg (TOK_TABLE_OR_COL dec))) (TOK_SELEXPR (TOK_FUNCTION sum 
(TOK_TABLE_OR_COL dec))))))
+
+STAGE DEPENDENCIES:
+  Stage-1 is a root stage
+  Stage-0 is a root stage
+
+STAGE PLANS:
+  Stage: Stage-1
+    Map Reduce
+      Alias -> Map Operator Tree:
+        decimal_precision 
+          TableScan
+            alias: decimal_precision
+            Select Operator
+              expressions:
+                    expr: dec
+                    type: decimal(20,10)
+              outputColumnNames: dec
+              Group By Operator
+                aggregations:
+                      expr: avg(dec)
+                      expr: sum(dec)
+                bucketGroup: false
+                mode: hash
+                outputColumnNames: _col0, _col1
+                Reduce Output Operator
+                  sort order: 
+                  tag: -1
+                  value expressions:
+                        expr: _col0
+                        type: struct<count:bigint,sum:decimal(24,14)>
+                        expr: _col1
+                        type: decimal(30,10)
+      Reduce Operator Tree:
+        Group By Operator
+          aggregations:
+                expr: avg(VALUE._col0)
+                expr: sum(VALUE._col1)
+          bucketGroup: false
+          mode: mergepartial
+          outputColumnNames: _col0, _col1
+          Select Operator
+            expressions:
+                  expr: _col0
+                  type: decimal(24,14)
+                  expr: _col1
+                  type: decimal(30,10)
+            outputColumnNames: _col0, _col1
+            File Output Operator
+              compressed: false
+              GlobalTableId: 0
+              table:
+                  input format: org.apache.hadoop.mapred.TextInputFormat
+                  output format: 
org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
+                  serde: org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe
+
+  Stage: Stage-0
+    Fetch Operator
+      limit: -1
+
+
 PREHOOK: query: SELECT avg(dec), sum(dec) FROM DECIMAL_PRECISION
 PREHOOK: type: QUERY
 PREHOOK: Input: default@decimal_precision
@@ -525,7 +591,7 @@ POSTHOOK: query: SELECT avg(dec), sum(de
 POSTHOOK: type: QUERY
 POSTHOOK: Input: default@decimal_precision
 #### A masked pattern was here ####
-88499534.575865762206451613    2743485571.8518386284
+88499534.57586576220645        2743485571.8518386284
 PREHOOK: query: SELECT dec * cast('12345678901234567890.12345678' as 
decimal(38,18)) FROM DECIMAL_PRECISION LIMIT 1
 PREHOOK: type: QUERY
 PREHOOK: Input: default@decimal_precision

Modified: hive/trunk/ql/src/test/results/clientpositive/decimal_udf.q.out
URL: http://svn.apache.org/viewvc/hive/trunk/ql/src/test/results/clientpositive/decimal_udf.q.out?rev=1549679&r1=1549678&r2=1549679&view=diff
==============================================================================
--- hive/trunk/ql/src/test/results/clientpositive/decimal_udf.q.out (original)
+++ hive/trunk/ql/src/test/results/clientpositive/decimal_udf.q.out Mon Dec  9 21:09:01 2013
@@ -1268,13 +1268,13 @@ NULL
 1234567890.123456789
 1234567890.12345678
 PREHOOK: query: -- avg
-EXPLAIN SELECT value, sum(key) / count(key), avg(key) FROM DECIMAL_UDF GROUP 
BY value ORDER BY value
+EXPLAIN SELECT value, sum(key) / count(key), avg(key), sum(key) FROM 
DECIMAL_UDF GROUP BY value ORDER BY value
 PREHOOK: type: QUERY
 POSTHOOK: query: -- avg
-EXPLAIN SELECT value, sum(key) / count(key), avg(key) FROM DECIMAL_UDF GROUP 
BY value ORDER BY value
+EXPLAIN SELECT value, sum(key) / count(key), avg(key), sum(key) FROM 
DECIMAL_UDF GROUP BY value ORDER BY value
 POSTHOOK: type: QUERY
 ABSTRACT SYNTAX TREE:
-  (TOK_QUERY (TOK_FROM (TOK_TABREF (TOK_TABNAME DECIMAL_UDF))) (TOK_INSERT 
(TOK_DESTINATION (TOK_DIR TOK_TMP_FILE)) (TOK_SELECT (TOK_SELEXPR 
(TOK_TABLE_OR_COL value)) (TOK_SELEXPR (/ (TOK_FUNCTION sum (TOK_TABLE_OR_COL 
key)) (TOK_FUNCTION count (TOK_TABLE_OR_COL key)))) (TOK_SELEXPR (TOK_FUNCTION 
avg (TOK_TABLE_OR_COL key)))) (TOK_GROUPBY (TOK_TABLE_OR_COL value)) 
(TOK_ORDERBY (TOK_TABSORTCOLNAMEASC (TOK_TABLE_OR_COL value)))))
+  (TOK_QUERY (TOK_FROM (TOK_TABREF (TOK_TABNAME DECIMAL_UDF))) (TOK_INSERT 
(TOK_DESTINATION (TOK_DIR TOK_TMP_FILE)) (TOK_SELECT (TOK_SELEXPR 
(TOK_TABLE_OR_COL value)) (TOK_SELEXPR (/ (TOK_FUNCTION sum (TOK_TABLE_OR_COL 
key)) (TOK_FUNCTION count (TOK_TABLE_OR_COL key)))) (TOK_SELEXPR (TOK_FUNCTION 
avg (TOK_TABLE_OR_COL key))) (TOK_SELEXPR (TOK_FUNCTION sum (TOK_TABLE_OR_COL 
key)))) (TOK_GROUPBY (TOK_TABLE_OR_COL value)) (TOK_ORDERBY 
(TOK_TABSORTCOLNAMEASC (TOK_TABLE_OR_COL value)))))
 
 STAGE DEPENDENCIES:
   Stage-1 is a root stage
@@ -1317,11 +1317,11 @@ STAGE PLANS:
                   tag: -1
                   value expressions:
                         expr: _col1
-                        type: decimal(38,18)
+                        type: decimal(30,10)
                         expr: _col2
                         type: bigint
                         expr: _col3
-                        type: struct<count:bigint,sum:decimal(38,18)>
+                        type: struct<count:bigint,sum:decimal(24,14)>
       Reduce Operator Tree:
         Group By Operator
           aggregations:
@@ -1339,10 +1339,12 @@ STAGE PLANS:
                   expr: _col0
                   type: int
                   expr: (_col1 / _col2)
-                  type: decimal(38,27)
+                  type: decimal(38,23)
                   expr: _col3
-                  type: decimal(38,18)
-            outputColumnNames: _col0, _col1, _col2
+                  type: decimal(24,14)
+                  expr: _col1
+                  type: decimal(30,10)
+            outputColumnNames: _col0, _col1, _col2, _col3
             File Output Operator
               compressed: false
               GlobalTableId: 0
@@ -1366,9 +1368,11 @@ STAGE PLANS:
                     expr: _col0
                     type: int
                     expr: _col1
-                    type: decimal(38,27)
+                    type: decimal(38,23)
                     expr: _col2
-                    type: decimal(38,18)
+                    type: decimal(24,14)
+                    expr: _col3
+                    type: decimal(30,10)
       Reduce Operator Tree:
         Extract
           File Output Operator
@@ -1383,31 +1387,31 @@ STAGE PLANS:
     Fetch Operator
       limit: -1
 
-PREHOOK: query: SELECT value, sum(key) / count(key), avg(key) FROM DECIMAL_UDF 
GROUP BY value ORDER BY value
+PREHOOK: query: SELECT value, sum(key) / count(key), avg(key), sum(key) FROM 
DECIMAL_UDF GROUP BY value ORDER BY value
 PREHOOK: type: QUERY
 PREHOOK: Input: default@decimal_udf
 #### A masked pattern was here ####
-POSTHOOK: query: SELECT value, sum(key) / count(key), avg(key) FROM 
DECIMAL_UDF GROUP BY value ORDER BY value
+POSTHOOK: query: SELECT value, sum(key) / count(key), avg(key), sum(key) FROM 
DECIMAL_UDF GROUP BY value ORDER BY value
 POSTHOOK: type: QUERY
 POSTHOOK: Input: default@decimal_udf
 #### A masked pattern was here ####
--1234567890    -1234567890.123456789   -1234567890.123456789
--1255  -1255.49        -1255.49
--11    -1.122  -1.122
--1     -1.12   -1.12
-0      0.025384615384615384615384615   0.025384615384615385
-1      1.0484  1.0484
-2      2       2
-3      3.14    3.14
-4      3.14    3.14
-10     10      10
-20     20      20
-100    100     100
-124    124     124
-125    125.2   125.2
-200    200     200
-4400   -4400   -4400
-1234567890     1234567890.12345678     1234567890.12345678
+-1234567890    -1234567890.123456789   -1234567890.123456789   
-1234567890.123456789
+-1255  -1255.49        -1255.49        -1255.49
+-11    -1.122  -1.122  -1.122
+-1     -1.12   -1.12   -2.24
+0      0.02538461538461538461538       0.02538461538462        0.33
+1      1.0484  1.0484  5.242
+2      2       2       4
+3      3.14    3.14    9.42
+4      3.14    3.14    3.14
+10     10      10      10
+20     20      20      20
+100    100     100     100
+124    124     124     124
+125    125.2   125.2   125.2
+200    200     200     200
+4400   -4400   -4400   -4400
+1234567890     1234567890.12345678     1234567890.12345678     
1234567890.12345678
 PREHOOK: query: -- negative
 EXPLAIN SELECT -key FROM DECIMAL_UDF
 PREHOOK: type: QUERY

Modified: hive/trunk/serde/src/java/org/apache/hadoop/hive/serde2/lazybinary/LazyBinarySerDe.java
URL: http://svn.apache.org/viewvc/hive/trunk/serde/src/java/org/apache/hadoop/hive/serde2/lazybinary/LazyBinarySerDe.java?rev=1549679&r1=1549678&r2=1549679&view=diff
==============================================================================
--- hive/trunk/serde/src/java/org/apache/hadoop/hive/serde2/lazybinary/LazyBinarySerDe.java (original)
+++ hive/trunk/serde/src/java/org/apache/hadoop/hive/serde2/lazybinary/LazyBinarySerDe.java Mon Dec  9 21:09:01 2013
@@ -413,6 +413,9 @@ public class LazyBinarySerDe extends Abs
       case DECIMAL: {
         HiveDecimalObjectInspector bdoi = (HiveDecimalObjectInspector) poi;
         HiveDecimalWritable t = bdoi.getPrimitiveWritableObject(obj);
+        if (t == null) {
+          return warnedOnceNullMapKey;
+        }
         t.writeToByteStream(byteStream);
         return warnedOnceNullMapKey;
       }


Reply via email to