Repository: incubator-hivemall
Updated Branches:
  refs/heads/master c53b9ff9b -> 8aae974fc


Close #63: [HIVEMALL-90] Refine incomplete AUC UDAF implementation


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/8aae974f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/8aae974f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/8aae974f

Branch: refs/heads/master
Commit: 8aae974fc39cd16080acdf7e493152d7167aa9e7
Parents: c53b9ff
Author: Takuya Kitazawa <k.tak...@gmail.com>
Authored: Thu Apr 13 14:56:40 2017 +0900
Committer: myui <yuin...@gmail.com>
Committed: Thu Apr 13 15:10:16 2017 +0900

----------------------------------------------------------------------
 .../main/java/hivemall/evaluation/AUCUDAF.java  | 249 ++++++++++++++-----
 .../java/hivemall/utils/hadoop/HiveUtils.java   |  12 +
 .../java/hivemall/evaluation/AUCUDAFTest.java   | 156 +++++++++++-
 3 files changed, 341 insertions(+), 76 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8aae974f/core/src/main/java/hivemall/evaluation/AUCUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/AUCUDAF.java b/core/src/main/java/hivemall/evaluation/AUCUDAF.java
index ff067b9..7cbdb52 100644
--- a/core/src/main/java/hivemall/evaluation/AUCUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/AUCUDAF.java
@@ -18,11 +18,19 @@
  */
 package hivemall.evaluation;
 
+import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
+import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaLongObjectInspector;
+import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableLongObjectInspector;
 import hivemall.utils.hadoop.HiveUtils;
 
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.SortedMap;
+import java.util.TreeMap;
 
 import javax.annotation.Nonnull;
 
@@ -36,12 +44,13 @@ import 
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
 import 
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import 
org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.StructField;
 import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
-import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
 import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
@@ -86,12 +95,17 @@ public final class AUCUDAF extends 
AbstractGenericUDAFResolver {
         private PrimitiveObjectInspector labelOI;
 
         private StructObjectInspector internalMergeOI;
-        private StructField aField;
-        private StructField scorePrevField;
+        private StructField indexScoreField;
+        private StructField areaField;
         private StructField fpField;
         private StructField tpField;
         private StructField fpPrevField;
         private StructField tpPrevField;
+        private StructField areaPartialMapField;
+        private StructField fpPartialMapField;
+        private StructField tpPartialMapField;
+        private StructField fpPrevPartialMapField;
+        private StructField tpPrevPartialMapField;
 
         public ClassificationEvaluator() {}
 
@@ -107,12 +121,17 @@ public final class AUCUDAF extends 
AbstractGenericUDAFResolver {
             } else {// from partial aggregation
                 StructObjectInspector soi = (StructObjectInspector) 
parameters[0];
                 this.internalMergeOI = soi;
-                this.aField = soi.getStructFieldRef("a");
-                this.scorePrevField = soi.getStructFieldRef("scorePrev");
+                this.indexScoreField = soi.getStructFieldRef("indexScore");
+                this.areaField = soi.getStructFieldRef("area");
                 this.fpField = soi.getStructFieldRef("fp");
                 this.tpField = soi.getStructFieldRef("tp");
                 this.fpPrevField = soi.getStructFieldRef("fpPrev");
                 this.tpPrevField = soi.getStructFieldRef("tpPrev");
+                this.areaPartialMapField = 
soi.getStructFieldRef("areaPartialMap");
+                this.fpPartialMapField = soi.getStructFieldRef("fpPartialMap");
+                this.tpPartialMapField = soi.getStructFieldRef("tpPartialMap");
+                this.fpPrevPartialMapField = 
soi.getStructFieldRef("fpPrevPartialMap");
+                this.tpPrevPartialMapField = 
soi.getStructFieldRef("tpPrevPartialMap");
             }
 
             // initialize output
@@ -120,7 +139,7 @@ public final class AUCUDAF extends 
AbstractGenericUDAFResolver {
             if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
                 outputOI = internalMergeOI();
             } else {// terminate
-                outputOI = 
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+                outputOI = writableDoubleObjectInspector;
             }
             return outputOI;
         }
@@ -129,18 +148,43 @@ public final class AUCUDAF extends 
AbstractGenericUDAFResolver {
             ArrayList<String> fieldNames = new ArrayList<String>();
             ArrayList<ObjectInspector> fieldOIs = new 
ArrayList<ObjectInspector>();
 
-            fieldNames.add("a");
-            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
-            fieldNames.add("scorePrev");
-            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+            fieldNames.add("indexScore");
+            fieldOIs.add(writableDoubleObjectInspector);
+            fieldNames.add("area");
+            fieldOIs.add(writableDoubleObjectInspector);
             fieldNames.add("fp");
-            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+            fieldOIs.add(writableLongObjectInspector);
             fieldNames.add("tp");
-            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+            fieldOIs.add(writableLongObjectInspector);
             fieldNames.add("fpPrev");
-            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+            fieldOIs.add(writableLongObjectInspector);
             fieldNames.add("tpPrev");
-            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+            fieldOIs.add(writableLongObjectInspector);
+
+            MapObjectInspector areaPartialMapOI = 
ObjectInspectorFactory.getStandardMapObjectInspector(
+                javaDoubleObjectInspector, javaDoubleObjectInspector);
+            fieldNames.add("areaPartialMap");
+            fieldOIs.add(areaPartialMapOI);
+
+            MapObjectInspector fpPartialMapOI = 
ObjectInspectorFactory.getStandardMapObjectInspector(
+                javaDoubleObjectInspector, javaLongObjectInspector);
+            fieldNames.add("fpPartialMap");
+            fieldOIs.add(fpPartialMapOI);
+
+            MapObjectInspector tpPartialMapOI = 
ObjectInspectorFactory.getStandardMapObjectInspector(
+                javaDoubleObjectInspector, javaLongObjectInspector);
+            fieldNames.add("tpPartialMap");
+            fieldOIs.add(tpPartialMapOI);
+
+            MapObjectInspector fpPrevPartialMapOI = 
ObjectInspectorFactory.getStandardMapObjectInspector(
+                javaDoubleObjectInspector, javaLongObjectInspector);
+            fieldNames.add("fpPrevPartialMap");
+            fieldOIs.add(fpPrevPartialMapOI);
+
+            MapObjectInspector tpPrevPartialMapOI = 
ObjectInspectorFactory.getStandardMapObjectInspector(
+                javaDoubleObjectInspector, javaLongObjectInspector);
+            fieldNames.add("tpPrevPartialMap");
+            fieldOIs.add(tpPrevPartialMapOI);
 
             return 
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
         }
@@ -188,37 +232,65 @@ public final class AUCUDAF extends 
AbstractGenericUDAFResolver {
         public Object terminatePartial(AggregationBuffer agg) throws 
HiveException {
             ClassificationAUCAggregationBuffer myAggr = 
(ClassificationAUCAggregationBuffer) agg;
 
-            Object[] partialResult = new Object[6];
-            partialResult[0] = new DoubleWritable(myAggr.a);
-            partialResult[1] = new DoubleWritable(myAggr.scorePrev);
+            Object[] partialResult = new Object[11];
+            partialResult[0] = new DoubleWritable(myAggr.indexScore);
+            partialResult[1] = new DoubleWritable(myAggr.area);
             partialResult[2] = new LongWritable(myAggr.fp);
             partialResult[3] = new LongWritable(myAggr.tp);
             partialResult[4] = new LongWritable(myAggr.fpPrev);
             partialResult[5] = new LongWritable(myAggr.tpPrev);
+            partialResult[6] = myAggr.areaPartialMap;
+            partialResult[7] = myAggr.fpPartialMap;
+            partialResult[8] = myAggr.tpPartialMap;
+            partialResult[9] = myAggr.fpPrevPartialMap;
+            partialResult[10] = myAggr.tpPrevPartialMap;
+
             return partialResult;
         }
 
+        @SuppressWarnings("unchecked")
         @Override
         public void merge(AggregationBuffer agg, Object partial) throws 
HiveException {
             if (partial == null) {
                 return;
             }
 
-            Object aObj = internalMergeOI.getStructFieldData(partial, aField);
-            Object scorePrevObj = internalMergeOI.getStructFieldData(partial, 
scorePrevField);
+            Object indexScoreObj = internalMergeOI.getStructFieldData(partial, 
indexScoreField);
+            Object areaObj = internalMergeOI.getStructFieldData(partial, 
areaField);
             Object fpObj = internalMergeOI.getStructFieldData(partial, 
fpField);
             Object tpObj = internalMergeOI.getStructFieldData(partial, 
tpField);
             Object fpPrevObj = internalMergeOI.getStructFieldData(partial, 
fpPrevField);
             Object tpPrevObj = internalMergeOI.getStructFieldData(partial, 
tpPrevField);
-            double a = 
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(aObj);
-            double scorePrev = 
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(scorePrevObj);
-            long fp = 
PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(fpObj);
-            long tp = 
PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(tpObj);
-            long fpPrev = 
PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(fpPrevObj);
-            long tpPrev = 
PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(tpPrevObj);
+            Object areaPartialMapObj = 
internalMergeOI.getStructFieldData(partial,
+                areaPartialMapField);
+            Object fpPartialMapObj = 
internalMergeOI.getStructFieldData(partial, fpPartialMapField);
+            Object tpPartialMapObj = 
internalMergeOI.getStructFieldData(partial, tpPartialMapField);
+            Object fpPrevPartialMapObj = 
internalMergeOI.getStructFieldData(partial,
+                fpPrevPartialMapField);
+            Object tpPrevPartialMapObj = 
internalMergeOI.getStructFieldData(partial,
+                tpPrevPartialMapField);
+
+            double indexScore = 
writableDoubleObjectInspector.get(indexScoreObj);
+            double area = writableDoubleObjectInspector.get(areaObj);
+            long fp = writableLongObjectInspector.get(fpObj);
+            long tp = writableLongObjectInspector.get(tpObj);
+            long fpPrev = writableLongObjectInspector.get(fpPrevObj);
+            long tpPrev = writableLongObjectInspector.get(tpPrevObj);
+
+            StandardMapObjectInspector ddMapOI = 
ObjectInspectorFactory.getStandardMapObjectInspector(
+                javaDoubleObjectInspector, javaDoubleObjectInspector);
+            StandardMapObjectInspector dlMapOI = 
ObjectInspectorFactory.getStandardMapObjectInspector(
+                javaDoubleObjectInspector, javaLongObjectInspector);
+
+            Map<Double, Double> areaPartialMap = (Map<Double, Double>) 
ddMapOI.getMap(HiveUtils.castLazyBinaryObject(areaPartialMapObj));
+            Map<Double, Long> fpPartialMap = (Map<Double, Long>) 
dlMapOI.getMap(HiveUtils.castLazyBinaryObject(fpPartialMapObj));
+            Map<Double, Long> tpPartialMap = (Map<Double, Long>) 
dlMapOI.getMap(HiveUtils.castLazyBinaryObject(tpPartialMapObj));
+            Map<Double, Long> fpPrevPartialMap = (Map<Double, Long>) 
dlMapOI.getMap(HiveUtils.castLazyBinaryObject(fpPrevPartialMapObj));
+            Map<Double, Long> tpPrevPartialMap = (Map<Double, Long>) 
dlMapOI.getMap(HiveUtils.castLazyBinaryObject(tpPrevPartialMapObj));
 
             ClassificationAUCAggregationBuffer myAggr = 
(ClassificationAUCAggregationBuffer) agg;
-            myAggr.merge(a, scorePrev, fp, tp, fpPrev, tpPrev);
+            myAggr.merge(indexScore, area, fp, tp, fpPrev, tpPrev, 
areaPartialMap, fpPartialMap,
+                tpPartialMap, fpPrevPartialMap, tpPrevPartialMap);
         }
 
         @Override
@@ -232,67 +304,110 @@ public final class AUCUDAF extends 
AbstractGenericUDAFResolver {
 
     public static class ClassificationAUCAggregationBuffer extends 
AbstractAggregationBuffer {
 
-        double a, scorePrev;
+        double area, scorePrev, indexScore;
         long fp, tp, fpPrev, tpPrev;
+        Map<Double, Double> areaPartialMap;
+        Map<Double, Long> fpPartialMap, tpPartialMap, fpPrevPartialMap, 
tpPrevPartialMap;
 
         public ClassificationAUCAggregationBuffer() {
             super();
         }
 
         void reset() {
-            this.a = 0.d;
+            this.area = 0.d;
             this.scorePrev = Double.POSITIVE_INFINITY;
+            this.indexScore = 0.d;
             this.fp = 0;
             this.tp = 0;
             this.fpPrev = 0;
             this.tpPrev = 0;
+            this.areaPartialMap = new HashMap<Double, Double>();
+            this.fpPartialMap = new HashMap<Double, Long>();
+            this.tpPartialMap = new HashMap<Double, Long>();
+            this.fpPrevPartialMap = new HashMap<Double, Long>();
+            this.tpPrevPartialMap = new HashMap<Double, Long>();
         }
 
-        void merge(double o_a, double o_scorePrev, long o_fp, long o_tp, long 
o_fpPrev,
-                long o_tpPrev) {
-            // compute the latest, not scaled AUC
-            a += trapezoidArea(fp, fpPrev, tp, tpPrev);
-            o_a += trapezoidArea(o_fp, o_fpPrev, o_tp, o_tpPrev);
+        void merge(double o_indexScore, double o_area, long o_fp, long o_tp, 
long o_fpPrev,
+                long o_tpPrev, Map<Double, Double> o_areaPartialMap,
+                Map<Double, Long> o_fpPartialMap, Map<Double, Long> 
o_tpPartialMap,
+                Map<Double, Long> o_fpPrevPartialMap, Map<Double, Long> 
o_tpPrevPartialMap) {
+
+            // merge past partial results
+            areaPartialMap.putAll(o_areaPartialMap);
+            fpPartialMap.putAll(o_fpPartialMap);
+            tpPartialMap.putAll(o_tpPartialMap);
+            fpPrevPartialMap.putAll(o_fpPrevPartialMap);
+            tpPrevPartialMap.putAll(o_tpPrevPartialMap);
+
+            // finalize source AUC computation
+            o_area += trapezoidArea(o_fp, o_fpPrev, o_tp, o_tpPrev);
+
+            // store source results
+            areaPartialMap.put(o_indexScore, o_area);
+            fpPartialMap.put(o_indexScore, o_fp);
+            tpPartialMap.put(o_indexScore, o_tp);
+            fpPrevPartialMap.put(o_indexScore, o_fpPrev);
+            tpPrevPartialMap.put(o_indexScore, o_tpPrev);
+        }
 
-            // sum up the partial areas
-            a += o_a;
-            if (scorePrev >= o_scorePrev) { // self is left-side
-                // adjust combined area by adding missing rectangle
-                a += trapezoidArea(fp + o_fp, fp, tp, tp);
-
-                // combine TP/FP counts; left-side curve should be base
-                fp += o_fp;
-                tp += o_tp;
-                fpPrev = fp + o_fpPrev;
-                tpPrev = tp + o_tpPrev;
-            } else { // self is right-side
-                a = a + trapezoidArea(fp + o_fp, o_fp, o_tp, o_tp);
-
-                fp += o_fp;
-                tp += o_tp;
-                fpPrev += o_fp;
-                tpPrev += o_tp;
-            }
+        double get() throws HiveException {
+            // store self results
+            areaPartialMap.put(indexScore, area);
+            fpPartialMap.put(indexScore, fp);
+            tpPartialMap.put(indexScore, tp);
+            fpPrevPartialMap.put(indexScore, fpPrev);
+            tpPrevPartialMap.put(indexScore, tpPrev);
+
+            SortedMap<Double, Double> areaPartialSortedMap = new 
TreeMap<Double, Double>(
+                Collections.reverseOrder());
+            areaPartialSortedMap.putAll(areaPartialMap);
+
+            // initialize with leftmost partial result
+            double firstKey = areaPartialSortedMap.firstKey();
+            double res = areaPartialSortedMap.get(firstKey);
+            long fpAccum = fpPartialMap.get(firstKey);
+            long tpAccum = tpPartialMap.get(firstKey);
+            long fpPrevAccum = fpPrevPartialMap.get(firstKey);
+            long tpPrevAccum = tpPrevPartialMap.get(firstKey);
+
+            // Merge from left (larger score) to right (smaller score)
+            for (double k : areaPartialSortedMap.keySet()) {
+                if (k == firstKey) { // variables are already initialized with the leftmost partial result
+                    continue;
+                }
 
-            // set current appropriate `scorePrev`
-            scorePrev = Math.min(scorePrev, o_scorePrev);
+                // sum up partial area
+                res += areaPartialSortedMap.get(k);
 
-            // subtract so that get() works correctly
-            a -= trapezoidArea(fp, fpPrev, tp, tpPrev);
-        }
+                // adjust combined area by adding missing rectangle
+                res += trapezoidArea(0, fpPartialMap.get(k), tpAccum, tpAccum);
 
-        double get() throws HiveException {
-            if (tp == 0 || fp == 0) {
+                // sum up (prev) TP/FP count
+                fpPrevAccum = fpAccum + fpPrevPartialMap.get(k);
+                tpPrevAccum = tpAccum + tpPrevPartialMap.get(k);
+                fpAccum = fpAccum + fpPartialMap.get(k);
+                tpAccum = tpAccum + tpPartialMap.get(k);
+            }
+
+            if (tpAccum == 0 || fpAccum == 0) {
                 throw new HiveException(
                     "AUC score is not defined because there is only one class in `label`.");
             }
-            double res = a + trapezoidArea(fp, fpPrev, tp, tpPrev);
-            return res / (tp * fp); // scale
+
+            // finalize by adding a trapezoid based on the last tp/fp counts
+            res += trapezoidArea(fpAccum, fpPrevAccum, tpAccum, tpPrevAccum);
+
+            return res / (tpAccum * fpAccum); // scale
         }
 
         void iterate(double score, int label) {
             if (score != scorePrev) {
-                a += trapezoidArea(fp, fpPrev, tp, tpPrev); // under (fp, 
tp)-(fpPrev, tpPrev)
+                if (scorePrev == Double.POSITIVE_INFINITY) {
+                    // store maximum score as an index
+                    indexScore = score;
+                }
+                area += trapezoidArea(fp, fpPrev, tp, tpPrev); // under (fp, tp)-(fpPrev, tpPrev)
                 scorePrev = score;
                 fpPrev = fp;
                 tpPrev = tp;
@@ -347,7 +462,7 @@ public final class AUCUDAF extends 
AbstractGenericUDAFResolver {
             if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
                 outputOI = internalMergeOI();
             } else {// terminate
-                outputOI = 
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+                outputOI = writableDoubleObjectInspector;
             }
             return outputOI;
         }
@@ -357,9 +472,9 @@ public final class AUCUDAF extends 
AbstractGenericUDAFResolver {
             ArrayList<ObjectInspector> fieldOIs = new 
ArrayList<ObjectInspector>();
 
             fieldNames.add("sum");
-            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+            fieldOIs.add(writableDoubleObjectInspector);
             fieldNames.add("count");
-            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+            fieldOIs.add(writableLongObjectInspector);
 
             return 
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
         }
@@ -421,8 +536,8 @@ public final class AUCUDAF extends 
AbstractGenericUDAFResolver {
 
             Object sumObj = internalMergeOI.getStructFieldData(partial, 
sumField);
             Object countObj = internalMergeOI.getStructFieldData(partial, 
countField);
-            double sum = 
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(sumObj);
-            long count = 
PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(countObj);
+            double sum = writableDoubleObjectInspector.get(sumObj);
+            long count = writableLongObjectInspector.get(countObj);
 
             RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) 
agg;
             myAggr.merge(sum, count);

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8aae974f/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
index b3a2de1..99b300d 100644
--- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -49,6 +49,8 @@ import org.apache.hadoop.hive.serde2.io.ShortWritable;
 import org.apache.hadoop.hive.serde2.lazy.LazyInteger;
 import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe;
 import org.apache.hadoop.hive.serde2.lazy.LazyString;
+import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
+import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryMap;
 import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
@@ -987,4 +989,14 @@ public final class HiveUtils {
         serde.initialize(conf, tbl);
         return serde;
     }
+
+    @Nonnull
+    public static Object castLazyBinaryObject(@Nonnull final Object obj) {
+        if (obj instanceof LazyBinaryMap) {
+            return ((LazyBinaryMap) obj).getMap();
+        } else if (obj instanceof LazyBinaryArray) {
+            return ((LazyBinaryArray) obj).getList();
+        }
+        return obj;
+    }
 }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8aae974f/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java b/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java
index 8725756..df26175 100644
--- a/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java
+++ b/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java
@@ -21,6 +21,7 @@ package hivemall.evaluation;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
 import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
@@ -54,9 +55,9 @@ public class AUCUDAFTest {
 
         ArrayList<String> fieldNames = new ArrayList<String>();
         ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
-        fieldNames.add("a");
+        fieldNames.add("indexScore");
         
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
-        fieldNames.add("scorePrev");
+        fieldNames.add("area");
         
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
         fieldNames.add("fp");
         
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
@@ -67,6 +68,36 @@ public class AUCUDAFTest {
         fieldNames.add("tpPrev");
         
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
 
+        MapObjectInspector areaPartialMapOI = 
ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.writableDoubleObjectInspector,
+            PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+        fieldNames.add("areaPartialMap");
+        fieldOIs.add(areaPartialMapOI);
+
+        MapObjectInspector fpPartialMapOI = 
ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.writableDoubleObjectInspector,
+            PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+        fieldNames.add("fpPartialMap");
+        fieldOIs.add(fpPartialMapOI);
+
+        MapObjectInspector tpPartialMapOI = 
ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.writableDoubleObjectInspector,
+            PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+        fieldNames.add("tpPartialMap");
+        fieldOIs.add(tpPartialMapOI);
+
+        MapObjectInspector fpPrevPartialMapOI = 
ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.writableDoubleObjectInspector,
+            PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+        fieldNames.add("fpPrevPartialMap");
+        fieldOIs.add(fpPrevPartialMapOI);
+
+        MapObjectInspector tpPrevPartialMapOI = 
ObjectInspectorFactory.getStandardMapObjectInspector(
+            PrimitiveObjectInspectorFactory.writableDoubleObjectInspector,
+            PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+        fieldNames.add("tpPrevPartialMap");
+        fieldOIs.add(tpPrevPartialMapOI);
+
         partialOI = new ObjectInspector[2];
         partialOI[0] = 
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
 
@@ -76,8 +107,8 @@ public class AUCUDAFTest {
     @Test
     public void test() throws Exception {
         // should be sorted by scores in a descending order
-        final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2};
-        final int[] labels = new int[] {1, 1, 0, 1, 0};
+        final double[] scores = new double[] {0.8, 0.7, 0.5, 0.5, 0.3, 0.2};
+        final int[] labels = new int[] {1, 1, 0, 1, 1, 0};
 
         evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
         evaluator.reset(agg);
@@ -86,7 +117,7 @@ public class AUCUDAFTest {
             evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
         }
 
-        Assert.assertEquals(0.83333, agg.get(), 1e-5);
+        Assert.assertEquals(0.8125, agg.get(), 1e-5);
     }
 
     @Test(expected=HiveException.class)
@@ -177,8 +208,8 @@ public class AUCUDAFTest {
 
     @Test
     public void testMerge() throws Exception {
-        final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2};
-        final int[] labels = new int[] {1, 1, 0, 1, 0};
+        final double[] scores = new double[] {0.8, 0.7, 0.5, 0.5, 0.3, 0.2};
+        final int[] labels = new int[] {1, 1, 0, 1, 1, 0};
 
         Object[] partials = new Object[3];
 
@@ -200,11 +231,12 @@ public class AUCUDAFTest {
         evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
         evaluator.reset(agg);
         evaluator.iterate(agg, new Object[] {scores[4], labels[4]});
+        evaluator.iterate(agg, new Object[] {scores[5], labels[5]});
         partials[2] = evaluator.terminatePartial(agg);
 
         // merge bins
         // merge in a different order; e.g., <bin0, bin1>, <bin1, bin0> should return same value
-        final int[][] orders = new int[][] {{0, 1, 2}, {1, 0, 2}, {1, 2, 0}, 
{2, 1, 0}};
+        final int[][] orders = new int[][] {{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, 
{1, 2, 0}, {2, 1, 0}, {2, 0, 1}};
         for (int i = 0; i < orders.length; i++) {
             evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI);
             evaluator.reset(agg);
@@ -213,7 +245,113 @@ public class AUCUDAFTest {
             evaluator.merge(agg, partials[orders[i][1]]);
             evaluator.merge(agg, partials[orders[i][2]]);
 
-            Assert.assertEquals(0.83333, agg.get(), 1e-5);
+            Assert.assertEquals(0.8125, agg.get(), 1e-5);
+        }
+    }
+
+    @Test
+    public void test100() throws Exception {
+        final double[] scores = new double[] {
+                0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.8, 0.8,
+                0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8,
+                0.8, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
+                0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6,
+                0.6, 0.6, 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
+                0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4,
+                0.4, 0.4, 0.4, 0.4, 0.4, 0.3, 0.3, 0.3, 0.3, 0.3,
+                0.3, 0.3, 0.3, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
+                0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1,
+                0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1};
+        final int[] labels = new int[] {
+                1, 1, 1, 1, 0, 0, 0, 0, 1, 1,
+                1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
+                0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
+                1, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+                1, 0, 0, 1, 1, 1, 1, 1, 1, 0,
+                0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
+                0, 0, 0, 0, 0, 1, 1, 1, 1, 0,
+                0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
+                1, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+                1, 1, 1, 0, 0, 0, 0, 0, 0, 0};
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+
+        for (int i = 0; i < scores.length; i++) {
+            evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+        }
+
+        // should equal to scikit-learn's result
+        Assert.assertEquals(0.567226890756, agg.get(), 1e-5);
+    }
+
+    @Test
+    public void testMerge100() throws Exception {
+        final double[] scores = new double[] {
+            0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.8, 0.8,
+            0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8,
+            0.8, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7,
+            0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6,
+            0.6, 0.6, 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
+            0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4,
+            0.4, 0.4, 0.4, 0.4, 0.4, 0.3, 0.3, 0.3, 0.3, 0.3,
+            0.3, 0.3, 0.3, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2,
+            0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1,
+            0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1};
+        final int[] labels = new int[] {
+            1, 1, 1, 1, 0, 0, 0, 0, 1, 1,
+            1, 1, 1, 1, 1, 1, 1, 0, 0, 0,
+            0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
+            1, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+            1, 0, 0, 1, 1, 1, 1, 1, 1, 0,
+            0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
+            0, 0, 0, 0, 0, 1, 1, 1, 1, 0,
+            0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
+            1, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+            1, 1, 1, 0, 0, 0, 0, 0, 0, 0};
+
+        Object[] partials = new Object[3];
+
+        // bin #1 (score is in [0.9, 0.7])
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        int i = 0;
+        while (scores[i] > 0.6) {
+            evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+            i++;
+        }
+        partials[0] = evaluator.terminatePartial(agg);
+
+        // bin #2 (score is in [0.6, 0.4])
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        while (scores[i] > 0.3) {
+            evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+            i++;
+        }
+        partials[1] = evaluator.terminatePartial(agg);
+
+        // bin #3 (score is in [0.3, 0.1])
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        while (i < 100) {
+            evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+            i++;
+        }
+        partials[2] = evaluator.terminatePartial(agg);
+
+        // merge bins
+        // merge in a different order; e.g., <bin0, bin1>, <bin1, bin0> should return same value
+        final int[][] orders = new int[][] {{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, 
{1, 2, 0}, {2, 1, 0}, {2, 0, 1}};
+        for (int j = 0; j < orders.length; j++) {
+            evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI);
+            evaluator.reset(agg);
+
+            evaluator.merge(agg, partials[orders[j][0]]);
+            evaluator.merge(agg, partials[orders[j][1]]);
+            evaluator.merge(agg, partials[orders[j][2]]);
+
+            Assert.assertEquals(0.567226890756, agg.get(), 1e-5);
         }
     }
 }

Reply via email to