Repository: incubator-hivemall Updated Branches: refs/heads/master c53b9ff9b -> 8aae974fc
Close #63: [HIVEMALL-90] Refine incomplete AUC UDAF implementation Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/8aae974f Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/8aae974f Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/8aae974f Branch: refs/heads/master Commit: 8aae974fc39cd16080acdf7e493152d7167aa9e7 Parents: c53b9ff Author: Takuya Kitazawa <k.tak...@gmail.com> Authored: Thu Apr 13 14:56:40 2017 +0900 Committer: myui <yuin...@gmail.com> Committed: Thu Apr 13 15:10:16 2017 +0900 ---------------------------------------------------------------------- .../main/java/hivemall/evaluation/AUCUDAF.java | 249 ++++++++++++++----- .../java/hivemall/utils/hadoop/HiveUtils.java | 12 + .../java/hivemall/evaluation/AUCUDAFTest.java | 156 +++++++++++- 3 files changed, 341 insertions(+), 76 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8aae974f/core/src/main/java/hivemall/evaluation/AUCUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/AUCUDAF.java b/core/src/main/java/hivemall/evaluation/AUCUDAF.java index ff067b9..7cbdb52 100644 --- a/core/src/main/java/hivemall/evaluation/AUCUDAF.java +++ b/core/src/main/java/hivemall/evaluation/AUCUDAF.java @@ -18,11 +18,19 @@ */ package hivemall.evaluation; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaDoubleObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.javaLongObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableLongObjectInspector; import hivemall.utils.hadoop.HiveUtils; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; import javax.annotation.Nonnull; @@ -36,12 +44,13 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer; import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; @@ -86,12 +95,17 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { private PrimitiveObjectInspector labelOI; private StructObjectInspector internalMergeOI; - private StructField aField; - private StructField scorePrevField; + private StructField indexScoreField; + private StructField areaField; private StructField fpField; private StructField tpField; private StructField fpPrevField; private StructField tpPrevField; + private StructField areaPartialMapField; + private StructField fpPartialMapField; + private StructField tpPartialMapField; + private StructField fpPrevPartialMapField; + private StructField tpPrevPartialMapField; public ClassificationEvaluator() {} @@ -107,12 +121,17 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { } else {// from partial aggregation StructObjectInspector soi = (StructObjectInspector) parameters[0]; this.internalMergeOI = soi; - this.aField = soi.getStructFieldRef("a"); - this.scorePrevField = soi.getStructFieldRef("scorePrev"); + this.indexScoreField = soi.getStructFieldRef("indexScore"); + this.areaField = soi.getStructFieldRef("area"); this.fpField = soi.getStructFieldRef("fp"); this.tpField = soi.getStructFieldRef("tp"); this.fpPrevField = soi.getStructFieldRef("fpPrev"); this.tpPrevField = soi.getStructFieldRef("tpPrev"); + this.areaPartialMapField = soi.getStructFieldRef("areaPartialMap"); + this.fpPartialMapField = soi.getStructFieldRef("fpPartialMap"); + this.tpPartialMapField = soi.getStructFieldRef("tpPartialMap"); + this.fpPrevPartialMapField = soi.getStructFieldRef("fpPrevPartialMap"); + this.tpPrevPartialMapField = soi.getStructFieldRef("tpPrevPartialMap"); } // initialize output @@ -120,7 +139,7 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial outputOI = internalMergeOI(); } else {// terminate - outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + outputOI = writableDoubleObjectInspector; } return outputOI; } @@ -129,18 +148,43 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { ArrayList<String> fieldNames = new ArrayList<String>(); ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); - fieldNames.add("a"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); - fieldNames.add("scorePrev"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldNames.add("indexScore"); + fieldOIs.add(writableDoubleObjectInspector); + fieldNames.add("area"); + fieldOIs.add(writableDoubleObjectInspector); fieldNames.add("fp"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldOIs.add(writableLongObjectInspector); fieldNames.add("tp"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldOIs.add(writableLongObjectInspector); fieldNames.add("fpPrev"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldOIs.add(writableLongObjectInspector); fieldNames.add("tpPrev"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldOIs.add(writableLongObjectInspector); + + MapObjectInspector areaPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector( + javaDoubleObjectInspector, javaDoubleObjectInspector); + fieldNames.add("areaPartialMap"); + fieldOIs.add(areaPartialMapOI); + + MapObjectInspector fpPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector( + javaDoubleObjectInspector, javaLongObjectInspector); + fieldNames.add("fpPartialMap"); + fieldOIs.add(fpPartialMapOI); + + MapObjectInspector tpPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector( + javaDoubleObjectInspector, javaLongObjectInspector); + fieldNames.add("tpPartialMap"); + fieldOIs.add(tpPartialMapOI); + + MapObjectInspector fpPrevPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector( + javaDoubleObjectInspector, javaLongObjectInspector); + fieldNames.add("fpPrevPartialMap"); + fieldOIs.add(fpPrevPartialMapOI); + + MapObjectInspector tpPrevPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector( + javaDoubleObjectInspector, javaLongObjectInspector); + fieldNames.add("tpPrevPartialMap"); + fieldOIs.add(tpPrevPartialMapOI); return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); } @@ -188,37 +232,65 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { public Object terminatePartial(AggregationBuffer agg) throws HiveException { ClassificationAUCAggregationBuffer myAggr = (ClassificationAUCAggregationBuffer) agg; - Object[] partialResult = new Object[6]; - partialResult[0] = new DoubleWritable(myAggr.a); - partialResult[1] = new DoubleWritable(myAggr.scorePrev); + Object[] partialResult = new Object[11]; + partialResult[0] = new DoubleWritable(myAggr.indexScore); + partialResult[1] = new DoubleWritable(myAggr.area); partialResult[2] = new LongWritable(myAggr.fp); partialResult[3] = new LongWritable(myAggr.tp); partialResult[4] = new LongWritable(myAggr.fpPrev); partialResult[5] = new LongWritable(myAggr.tpPrev); + partialResult[6] = myAggr.areaPartialMap; + partialResult[7] = myAggr.fpPartialMap; + partialResult[8] = myAggr.tpPartialMap; + partialResult[9] = myAggr.fpPrevPartialMap; + partialResult[10] = myAggr.tpPrevPartialMap; + return partialResult; } + @SuppressWarnings("unchecked") @Override public void merge(AggregationBuffer agg, Object partial) throws HiveException { if (partial == null) { return; } - Object aObj = internalMergeOI.getStructFieldData(partial, aField); - Object scorePrevObj = internalMergeOI.getStructFieldData(partial, scorePrevField); + Object indexScoreObj = internalMergeOI.getStructFieldData(partial, indexScoreField); + Object areaObj = internalMergeOI.getStructFieldData(partial, areaField); Object fpObj = internalMergeOI.getStructFieldData(partial, fpField); Object tpObj = internalMergeOI.getStructFieldData(partial, tpField); Object fpPrevObj = internalMergeOI.getStructFieldData(partial, fpPrevField); Object tpPrevObj = internalMergeOI.getStructFieldData(partial, tpPrevField); - double a = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(aObj); - double scorePrev = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(scorePrevObj); - long fp = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(fpObj); - long tp = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(tpObj); - long fpPrev = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(fpPrevObj); - long tpPrev = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(tpPrevObj); + Object areaPartialMapObj = internalMergeOI.getStructFieldData(partial, + areaPartialMapField); + Object fpPartialMapObj = internalMergeOI.getStructFieldData(partial, fpPartialMapField); + Object tpPartialMapObj = internalMergeOI.getStructFieldData(partial, tpPartialMapField); + Object fpPrevPartialMapObj = internalMergeOI.getStructFieldData(partial, + fpPrevPartialMapField); + Object tpPrevPartialMapObj = internalMergeOI.getStructFieldData(partial, + tpPrevPartialMapField); + + double indexScore = writableDoubleObjectInspector.get(indexScoreObj); + double area = writableDoubleObjectInspector.get(areaObj); + long fp = writableLongObjectInspector.get(fpObj); + long tp = writableLongObjectInspector.get(tpObj); + long fpPrev = writableLongObjectInspector.get(fpPrevObj); + long tpPrev = writableLongObjectInspector.get(tpPrevObj); + + StandardMapObjectInspector ddMapOI = ObjectInspectorFactory.getStandardMapObjectInspector( + javaDoubleObjectInspector, javaDoubleObjectInspector); + StandardMapObjectInspector dlMapOI = ObjectInspectorFactory.getStandardMapObjectInspector( + javaDoubleObjectInspector, javaLongObjectInspector); + + Map<Double, Double> areaPartialMap = (Map<Double, Double>) ddMapOI.getMap(HiveUtils.castLazyBinaryObject(areaPartialMapObj)); + Map<Double, Long> fpPartialMap = (Map<Double, Long>) dlMapOI.getMap(HiveUtils.castLazyBinaryObject(fpPartialMapObj)); + Map<Double, Long> tpPartialMap = (Map<Double, Long>) dlMapOI.getMap(HiveUtils.castLazyBinaryObject(tpPartialMapObj)); + Map<Double, Long> fpPrevPartialMap = (Map<Double, Long>) dlMapOI.getMap(HiveUtils.castLazyBinaryObject(fpPrevPartialMapObj)); + Map<Double, Long> tpPrevPartialMap = (Map<Double, Long>) dlMapOI.getMap(HiveUtils.castLazyBinaryObject(tpPrevPartialMapObj)); ClassificationAUCAggregationBuffer myAggr = (ClassificationAUCAggregationBuffer) agg; - myAggr.merge(a, scorePrev, fp, tp, fpPrev, tpPrev); + myAggr.merge(indexScore, area, fp, tp, fpPrev, tpPrev, areaPartialMap, fpPartialMap, + tpPartialMap, fpPrevPartialMap, tpPrevPartialMap); } @Override @@ -232,67 +304,110 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { public static class ClassificationAUCAggregationBuffer extends AbstractAggregationBuffer { - double a, scorePrev; + double area, scorePrev, indexScore; long fp, tp, fpPrev, tpPrev; + Map<Double, Double> areaPartialMap; + Map<Double, Long> fpPartialMap, tpPartialMap, fpPrevPartialMap, tpPrevPartialMap; public ClassificationAUCAggregationBuffer() { super(); } void reset() { - this.a = 0.d; + this.area = 0.d; this.scorePrev = Double.POSITIVE_INFINITY; + this.indexScore = 0.d; this.fp = 0; this.tp = 0; this.fpPrev = 0; this.tpPrev = 0; + this.areaPartialMap = new HashMap<Double, Double>(); + this.fpPartialMap = new HashMap<Double, Long>(); + this.tpPartialMap = new HashMap<Double, Long>(); + this.fpPrevPartialMap = new HashMap<Double, Long>(); + this.tpPrevPartialMap = new HashMap<Double, Long>(); } - void merge(double o_a, double o_scorePrev, long o_fp, long o_tp, long o_fpPrev, - long o_tpPrev) { - // compute the latest, not scaled AUC - a += trapezoidArea(fp, fpPrev, tp, tpPrev); - o_a += trapezoidArea(o_fp, o_fpPrev, o_tp, o_tpPrev); + void merge(double o_indexScore, double o_area, long o_fp, long o_tp, long o_fpPrev, + long o_tpPrev, Map<Double, Double> o_areaPartialMap, + Map<Double, Long> o_fpPartialMap, Map<Double, Long> o_tpPartialMap, + Map<Double, Long> o_fpPrevPartialMap, Map<Double, Long> o_tpPrevPartialMap) { + + // merge past partial results + areaPartialMap.putAll(o_areaPartialMap); + fpPartialMap.putAll(o_fpPartialMap); + tpPartialMap.putAll(o_tpPartialMap); + fpPrevPartialMap.putAll(o_fpPrevPartialMap); + tpPrevPartialMap.putAll(o_tpPrevPartialMap); + + // finalize source AUC computation + o_area += trapezoidArea(o_fp, o_fpPrev, o_tp, o_tpPrev); + + // store source results + areaPartialMap.put(o_indexScore, o_area); + fpPartialMap.put(o_indexScore, o_fp); + tpPartialMap.put(o_indexScore, o_tp); + fpPrevPartialMap.put(o_indexScore, o_fpPrev); + tpPrevPartialMap.put(o_indexScore, o_tpPrev); + } - // sum up the partial areas - a += o_a; - if (scorePrev >= o_scorePrev) { // self is left-side - // adjust combined area by adding missing rectangle - a += trapezoidArea(fp + o_fp, fp, tp, tp); - - // combine TP/FP counts; left-side curve should be base - fp += o_fp; - tp += o_tp; - fpPrev = fp + o_fpPrev; - tpPrev = tp + o_tpPrev; - } else { // self is right-side - a = a + trapezoidArea(fp + o_fp, o_fp, o_tp, o_tp); - - fp += o_fp; - tp += o_tp; - fpPrev += o_fp; - tpPrev += o_tp; - } + double get() throws HiveException { + // store self results + areaPartialMap.put(indexScore, area); + fpPartialMap.put(indexScore, fp); + tpPartialMap.put(indexScore, tp); + fpPrevPartialMap.put(indexScore, fpPrev); + tpPrevPartialMap.put(indexScore, tpPrev); + + SortedMap<Double, Double> areaPartialSortedMap = new TreeMap<Double, Double>( + Collections.reverseOrder()); + areaPartialSortedMap.putAll(areaPartialMap); + + // initialize with leftmost partial result + double firstKey = areaPartialSortedMap.firstKey(); + double res = areaPartialSortedMap.get(firstKey); + long fpAccum = fpPartialMap.get(firstKey); + long tpAccum = tpPartialMap.get(firstKey); + long fpPrevAccum = fpPrevPartialMap.get(firstKey); + long tpPrevAccum = tpPrevPartialMap.get(firstKey); + + // Merge from left (larger score) to right (smaller score) + for (double k : areaPartialSortedMap.keySet()) { + if (k == firstKey) { // variables are already initialized with the leftmost partial result + continue; + } - // set current appropriate `scorePrev` - scorePrev = Math.min(scorePrev, o_scorePrev); + // sum up partial area + res += areaPartialSortedMap.get(k); - // subtract so that get() works correctly - a -= trapezoidArea(fp, fpPrev, tp, tpPrev); - } + // adjust combined area by adding missing rectangle + res += trapezoidArea(0, fpPartialMap.get(k), tpAccum, tpAccum); - double get() throws HiveException { - if (tp == 0 || fp == 0) { + // sum up (prev) TP/FP count + fpPrevAccum = fpAccum + fpPrevPartialMap.get(k); + tpPrevAccum = tpAccum + tpPrevPartialMap.get(k); + fpAccum = fpAccum + fpPartialMap.get(k); + tpAccum = tpAccum + tpPartialMap.get(k); + } + + if (tpAccum == 0 || fpAccum == 0) { throw new HiveException( "AUC score is not defined because there is only one class in `label`."); } - double res = a + trapezoidArea(fp, fpPrev, tp, tpPrev); - return res / (tp * fp); // scale + + // finalize by adding a trapezoid based on the last tp/fp counts + res += trapezoidArea(fpAccum, fpPrevAccum, tpAccum, tpPrevAccum); + + return res / (tpAccum * fpAccum); // scale } void iterate(double score, int label) { if (score != scorePrev) { - a += trapezoidArea(fp, fpPrev, tp, tpPrev); // under (fp, tp)-(fpPrev, tpPrev) + if (scorePrev == Double.POSITIVE_INFINITY) { + // store maximum score as an index + indexScore = score; + } + area += trapezoidArea(fp, fpPrev, tp, tpPrev); // under (fp, tp)-(fpPrev, tpPrev) scorePrev = score; fpPrev = fp; tpPrev = tp; @@ -347,7 +462,7 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial outputOI = internalMergeOI(); } else {// terminate - outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + outputOI = writableDoubleObjectInspector; } return outputOI; } @@ -357,9 +472,9 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); fieldNames.add("sum"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldOIs.add(writableDoubleObjectInspector); fieldNames.add("count"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldOIs.add(writableLongObjectInspector); return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); } @@ -421,8 +536,8 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { Object sumObj = internalMergeOI.getStructFieldData(partial, sumField); Object countObj = internalMergeOI.getStructFieldData(partial, countField); - double sum = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(sumObj); - long count = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(countObj); + double sum = writableDoubleObjectInspector.get(sumObj); + long count = writableLongObjectInspector.get(countObj); RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) agg; myAggr.merge(sum, count); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8aae974f/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index b3a2de1..99b300d 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -49,6 +49,8 @@ import org.apache.hadoop.hive.serde2.io.ShortWritable; import org.apache.hadoop.hive.serde2.lazy.LazyInteger; import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe; import org.apache.hadoop.hive.serde2.lazy.LazyString; +import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray; +import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryMap; import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; @@ -987,4 +989,14 @@ public final class HiveUtils { serde.initialize(conf, tbl); return serde; } + + @Nonnull + public static Object castLazyBinaryObject(@Nonnull final Object obj) { + if (obj instanceof LazyBinaryMap) { + return ((LazyBinaryMap) obj).getMap(); + } else if (obj instanceof LazyBinaryArray) { + return ((LazyBinaryArray) obj).getList(); + } + return obj; + } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8aae974f/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java b/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java index 8725756..df26175 100644 --- a/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java +++ b/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java @@ -21,6 +21,7 @@ package hivemall.evaluation; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; @@ -54,9 +55,9 @@ public class AUCUDAFTest { ArrayList<String> fieldNames = new ArrayList<String>(); ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); - fieldNames.add("a"); + fieldNames.add("indexScore"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); - fieldNames.add("scorePrev"); + fieldNames.add("area"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); fieldNames.add("fp"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); @@ -67,6 +68,36 @@ public class AUCUDAFTest { fieldNames.add("tpPrev"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + MapObjectInspector areaPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector( + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector, + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldNames.add("areaPartialMap"); + fieldOIs.add(areaPartialMapOI); + + MapObjectInspector fpPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector( + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector, + PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldNames.add("fpPartialMap"); + fieldOIs.add(fpPartialMapOI); + + MapObjectInspector tpPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector( + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector, + PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldNames.add("tpPartialMap"); + fieldOIs.add(tpPartialMapOI); + + MapObjectInspector fpPrevPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector( + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector, + PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldNames.add("fpPrevPartialMap"); + fieldOIs.add(fpPrevPartialMapOI); + + MapObjectInspector tpPrevPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector( + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector, + PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldNames.add("tpPrevPartialMap"); + fieldOIs.add(tpPrevPartialMapOI); + partialOI = new ObjectInspector[2]; partialOI[0] = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); @@ -76,8 +107,8 @@ public class AUCUDAFTest { @Test public void test() throws Exception { // should be sorted by scores in a descending order - final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2}; - final int[] labels = new int[] {1, 1, 0, 1, 0}; + final double[] scores = new double[] {0.8, 0.7, 0.5, 0.5, 0.3, 0.2}; + final int[] labels = new int[] {1, 1, 0, 1, 1, 0}; evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); evaluator.reset(agg); @@ -86,7 +117,7 @@ public class AUCUDAFTest { evaluator.iterate(agg, new Object[] {scores[i], labels[i]}); } - Assert.assertEquals(0.83333, agg.get(), 1e-5); + Assert.assertEquals(0.8125, agg.get(), 1e-5); } @Test(expected=HiveException.class) @@ -177,8 +208,8 @@ public class AUCUDAFTest { @Test public void testMerge() throws Exception { - final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2}; - final int[] labels = new int[] {1, 1, 0, 1, 0}; + final double[] scores = new double[] {0.8, 0.7, 0.5, 0.5, 0.3, 0.2}; + final int[] labels = new int[] {1, 1, 0, 1, 1, 0}; Object[] partials = new Object[3]; @@ -200,11 +231,12 @@ public class AUCUDAFTest { evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); evaluator.reset(agg); evaluator.iterate(agg, new Object[] {scores[4], labels[4]}); + evaluator.iterate(agg, new Object[] {scores[5], labels[5]}); partials[2] = evaluator.terminatePartial(agg); // merge bins // merge in a different order; e.g., <bin0, bin1>, <bin1, bin0> should return same value - final int[][] orders = new int[][] {{0, 1, 2}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0}}; + final int[][] orders = new int[][] {{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0}, {2, 0, 1}}; for (int i = 0; i < orders.length; i++) { evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI); evaluator.reset(agg); @@ -213,7 +245,113 @@ public class AUCUDAFTest { evaluator.merge(agg, partials[orders[i][1]]); evaluator.merge(agg, partials[orders[i][2]]); - Assert.assertEquals(0.83333, agg.get(), 1e-5); + Assert.assertEquals(0.8125, agg.get(), 1e-5); + } + } + + @Test + public void test100() throws Exception { + final double[] scores = new double[] { + 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.8, 0.8, + 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, + 0.8, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, + 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, + 0.6, 0.6, 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, + 0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, + 0.4, 0.4, 0.4, 0.4, 0.4, 0.3, 0.3, 0.3, 0.3, 0.3, + 0.3, 0.3, 0.3, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, + 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1, + 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1}; + final int[] labels = new int[] { + 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 1, 1, 0, 0, 0, 0, 0, 0, 0}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < scores.length; i++) { + evaluator.iterate(agg, new Object[] {scores[i], labels[i]}); + } + + // should equal to scikit-learn's result + Assert.assertEquals(0.567226890756, agg.get(), 1e-5); + } + + @Test + public void testMerge100() throws Exception { + final double[] scores = new double[] { + 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.8, 0.8, + 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, + 0.8, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, + 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, + 0.6, 0.6, 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, + 0.5, 0.5, 0.5, 0.5, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, + 0.4, 0.4, 0.4, 0.4, 0.4, 0.3, 0.3, 0.3, 0.3, 0.3, + 0.3, 0.3, 0.3, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, + 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1, + 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1}; + final int[] labels = new int[] { + 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 1, 1, 0, 0, 0, 0, 0, 0, 0}; + + Object[] partials = new Object[3]; + + // bin #1 (score is in [0.9, 0.7]) + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + int i = 0; + while (scores[i] > 0.6) { + evaluator.iterate(agg, new Object[] {scores[i], labels[i]}); + i++; + } + partials[0] = evaluator.terminatePartial(agg); + + // bin #2 (score is in [0.6, 0.4]) + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + while (scores[i] > 0.3) { + evaluator.iterate(agg, new Object[] {scores[i], labels[i]}); + i++; + } + partials[1] = evaluator.terminatePartial(agg); + + // bin #3 (score is in [0.3, 0.1]) + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + while (i < 100) { + evaluator.iterate(agg, new Object[] {scores[i], labels[i]}); + i++; + } + partials[2] = evaluator.terminatePartial(agg); + + // merge bins + // merge in a different order; e.g., <bin0, bin1>, <bin1, bin0> should return same value + final int[][] orders = new int[][] {{0, 1, 2}, {0, 2, 1}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0}, {2, 0, 1}}; + for (int j = 0; j < orders.length; j++) { + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI); + evaluator.reset(agg); + + evaluator.merge(agg, partials[orders[j][0]]); + evaluator.merge(agg, partials[orders[j][1]]); + evaluator.merge(agg, partials[orders[j][2]]); + + Assert.assertEquals(0.567226890756, agg.get(), 1e-5); } } }