Repository: incubator-hivemall
Updated Branches:
  refs/heads/master 19d472b54 -> 97bc91247


Close #52: [HIVEMALL-78] Implement AUC UDAF for binary classification


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/97bc9124
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/97bc9124
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/97bc9124

Branch: refs/heads/master
Commit: 97bc91247f1a453b6ffb49212800fa601d25c297
Parents: 19d472b
Author: Takuya Kitazawa <[email protected]>
Authored: Tue Feb 28 18:45:11 2017 +0900
Committer: myui <[email protected]>
Committed: Tue Feb 28 18:51:45 2017 +0900

----------------------------------------------------------------------
 .../main/java/hivemall/evaluation/AUCUDAF.java  | 286 +++++++++++++++++--
 .../java/hivemall/evaluation/AUCUDAFTest.java   | 219 ++++++++++++++
 docs/gitbook/SUMMARY.md                         |   1 +
 docs/gitbook/eval/auc.md                        | 104 +++++++
 4 files changed, 584 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/97bc9124/core/src/main/java/hivemall/evaluation/AUCUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/AUCUDAF.java 
b/core/src/main/java/hivemall/evaluation/AUCUDAF.java
index bc39b4c..ff067b9 100644
--- a/core/src/main/java/hivemall/evaluation/AUCUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/AUCUDAF.java
@@ -38,24 +38,22 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.StructField;
 import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
 import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.io.LongWritable;
 
 @SuppressWarnings("deprecation")
-@Description(
-        name = "auc",
-        value = "_FUNC_(array rankItems, array correctItems [, const int 
recommendSize = rankItems.size])"
-                + " - Returns AUC")
+@Description(name = "auc",
+        value = "_FUNC_(array rankItems | double score, array correctItems | 
int label "
+                + "[, const int recommendSize = rankItems.size ])" + " - 
Returns AUC")
 public final class AUCUDAF extends AbstractGenericUDAFResolver {
 
-    // prevent instantiation
-    private AUCUDAF() {}
-
     @Override
     public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) 
throws SemanticException {
         if (typeInfo.length != 2 && typeInfo.length != 3) {
@@ -63,21 +61,257 @@ public final class AUCUDAF extends 
AbstractGenericUDAFResolver {
                 "_FUNC_ takes two or three arguments");
         }
 
-        ListTypeInfo arg1type = HiveUtils.asListTypeInfo(typeInfo[0]);
-        if (!HiveUtils.isPrimitiveTypeInfo(arg1type.getListElementTypeInfo())) 
{
-            throw new UDFArgumentTypeException(0,
-                "The first argument `array rankItems` is invalid form: " + 
typeInfo[0]);
+        if (HiveUtils.isNumberTypeInfo(typeInfo[0]) && 
HiveUtils.isIntegerTypeInfo(typeInfo[1])) {
+            return new ClassificationEvaluator();
+        } else {
+            ListTypeInfo arg1type = HiveUtils.asListTypeInfo(typeInfo[0]);
+            if 
(!HiveUtils.isPrimitiveTypeInfo(arg1type.getListElementTypeInfo())) {
+                throw new UDFArgumentTypeException(0,
+                    "The first argument `array rankItems` is invalid form: " + 
typeInfo[0]);
+            }
+
+            ListTypeInfo arg2type = HiveUtils.asListTypeInfo(typeInfo[1]);
+            if 
(!HiveUtils.isPrimitiveTypeInfo(arg2type.getListElementTypeInfo())) {
+                throw new UDFArgumentTypeException(1,
+                    "The second argument `array correctItems` is invalid form: 
" + typeInfo[1]);
+            }
+
+            return new RankingEvaluator();
         }
-        ListTypeInfo arg2type = HiveUtils.asListTypeInfo(typeInfo[1]);
-        if (!HiveUtils.isPrimitiveTypeInfo(arg2type.getListElementTypeInfo())) 
{
-            throw new UDFArgumentTypeException(1,
-                "The second argument `array correctItems` is invalid form: " + 
typeInfo[1]);
+    }
+
+    public static class ClassificationEvaluator extends GenericUDAFEvaluator {
+
+        private PrimitiveObjectInspector scoreOI;
+        private PrimitiveObjectInspector labelOI;
+
+        private StructObjectInspector internalMergeOI;
+        private StructField aField;
+        private StructField scorePrevField;
+        private StructField fpField;
+        private StructField tpField;
+        private StructField fpPrevField;
+        private StructField tpPrevField;
+
+        public ClassificationEvaluator() {}
+
+        @Override
+        public ObjectInspector init(Mode mode, ObjectInspector[] parameters) 
throws HiveException {
+            assert (parameters.length == 2 || parameters.length == 3) : 
parameters.length;
+            super.init(mode, parameters);
+
+            // initialize input
+            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from 
original data
+                this.scoreOI = HiveUtils.asDoubleCompatibleOI(parameters[0]);
+                this.labelOI = HiveUtils.asIntegerOI(parameters[1]);
+            } else {// from partial aggregation
+                StructObjectInspector soi = (StructObjectInspector) 
parameters[0];
+                this.internalMergeOI = soi;
+                this.aField = soi.getStructFieldRef("a");
+                this.scorePrevField = soi.getStructFieldRef("scorePrev");
+                this.fpField = soi.getStructFieldRef("fp");
+                this.tpField = soi.getStructFieldRef("tp");
+                this.fpPrevField = soi.getStructFieldRef("fpPrev");
+                this.tpPrevField = soi.getStructFieldRef("tpPrev");
+            }
+
+            // initialize output
+            final ObjectInspector outputOI;
+            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// 
terminatePartial
+                outputOI = internalMergeOI();
+            } else {// terminate
+                outputOI = 
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+            }
+            return outputOI;
+        }
+
+        private static StructObjectInspector internalMergeOI() {
+            ArrayList<String> fieldNames = new ArrayList<String>();
+            ArrayList<ObjectInspector> fieldOIs = new 
ArrayList<ObjectInspector>();
+
+            fieldNames.add("a");
+            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+            fieldNames.add("scorePrev");
+            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+            fieldNames.add("fp");
+            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+            fieldNames.add("tp");
+            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+            fieldNames.add("fpPrev");
+            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+            fieldNames.add("tpPrev");
+            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+
+            return 
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+        }
+
+        @Override
+        public AggregationBuffer getNewAggregationBuffer() throws 
HiveException {
+            AggregationBuffer myAggr = new 
ClassificationAUCAggregationBuffer();
+            reset(myAggr);
+            return myAggr;
+        }
+
+        @Override
+        public void reset(AggregationBuffer agg) throws HiveException {
+            ClassificationAUCAggregationBuffer myAggr = 
(ClassificationAUCAggregationBuffer) agg;
+            myAggr.reset();
         }
 
-        return new Evaluator();
+        @Override
+        public void iterate(AggregationBuffer agg, Object[] parameters) throws 
HiveException {
+            ClassificationAUCAggregationBuffer myAggr = 
(ClassificationAUCAggregationBuffer) agg;
+
+            if (parameters[0] == null) {
+                return;
+            }
+            if (parameters[1] == null) {
+                return;
+            }
+
+            double score = HiveUtils.getDouble(parameters[0], scoreOI);
+            if (score < 0.0d || score > 1.0d) {
+                throw new UDFArgumentException("score value MUST be in range 
[0,1]: " + score);
+            }
+
+            int label = PrimitiveObjectInspectorUtils.getInt(parameters[1], 
labelOI);
+            if (label == -1) {
+                label = 0;
+            } else if (label != 0 && label != 1) {
+                throw new UDFArgumentException("label MUST be 0/1 or -1/1: " + 
label);
+            }
+
+            myAggr.iterate(score, label);
+        }
+
+        @Override
+        public Object terminatePartial(AggregationBuffer agg) throws 
HiveException {
+            ClassificationAUCAggregationBuffer myAggr = 
(ClassificationAUCAggregationBuffer) agg;
+
+            Object[] partialResult = new Object[6];
+            partialResult[0] = new DoubleWritable(myAggr.a);
+            partialResult[1] = new DoubleWritable(myAggr.scorePrev);
+            partialResult[2] = new LongWritable(myAggr.fp);
+            partialResult[3] = new LongWritable(myAggr.tp);
+            partialResult[4] = new LongWritable(myAggr.fpPrev);
+            partialResult[5] = new LongWritable(myAggr.tpPrev);
+            return partialResult;
+        }
+
+        @Override
+        public void merge(AggregationBuffer agg, Object partial) throws 
HiveException {
+            if (partial == null) {
+                return;
+            }
+
+            Object aObj = internalMergeOI.getStructFieldData(partial, aField);
+            Object scorePrevObj = internalMergeOI.getStructFieldData(partial, 
scorePrevField);
+            Object fpObj = internalMergeOI.getStructFieldData(partial, 
fpField);
+            Object tpObj = internalMergeOI.getStructFieldData(partial, 
tpField);
+            Object fpPrevObj = internalMergeOI.getStructFieldData(partial, 
fpPrevField);
+            Object tpPrevObj = internalMergeOI.getStructFieldData(partial, 
tpPrevField);
+            double a = 
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(aObj);
+            double scorePrev = 
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(scorePrevObj);
+            long fp = 
PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(fpObj);
+            long tp = 
PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(tpObj);
+            long fpPrev = 
PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(fpPrevObj);
+            long tpPrev = 
PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(tpPrevObj);
+
+            ClassificationAUCAggregationBuffer myAggr = 
(ClassificationAUCAggregationBuffer) agg;
+            myAggr.merge(a, scorePrev, fp, tp, fpPrev, tpPrev);
+        }
+
+        @Override
+        public DoubleWritable terminate(AggregationBuffer agg) throws 
HiveException {
+            ClassificationAUCAggregationBuffer myAggr = 
(ClassificationAUCAggregationBuffer) agg;
+            double result = myAggr.get();
+            return new DoubleWritable(result);
+        }
+
+    }
+
+    public static class ClassificationAUCAggregationBuffer extends 
AbstractAggregationBuffer {
+
+        double a, scorePrev;
+        long fp, tp, fpPrev, tpPrev;
+
+        public ClassificationAUCAggregationBuffer() {
+            super();
+        }
+
+        void reset() {
+            this.a = 0.d;
+            this.scorePrev = Double.POSITIVE_INFINITY;
+            this.fp = 0;
+            this.tp = 0;
+            this.fpPrev = 0;
+            this.tpPrev = 0;
+        }
+
+        void merge(double o_a, double o_scorePrev, long o_fp, long o_tp, long 
o_fpPrev,
+                long o_tpPrev) {
+            // compute the latest, not scaled AUC
+            a += trapezoidArea(fp, fpPrev, tp, tpPrev);
+            o_a += trapezoidArea(o_fp, o_fpPrev, o_tp, o_tpPrev);
+
+            // sum up the partial areas
+            a += o_a;
+            if (scorePrev >= o_scorePrev) { // self is left-side
+                // adjust combined area by adding missing rectangle
+                a += trapezoidArea(fp + o_fp, fp, tp, tp);
+
+                // combine TP/FP counts; left-side curve should be base
+                fp += o_fp;
+                tp += o_tp;
+                fpPrev = fp + o_fpPrev;
+                tpPrev = tp + o_tpPrev;
+            } else { // self is right-side
+                a = a + trapezoidArea(fp + o_fp, o_fp, o_tp, o_tp);
+
+                fp += o_fp;
+                tp += o_tp;
+                fpPrev += o_fp;
+                tpPrev += o_tp;
+            }
+
+            // set current appropriate `scorePrev`
+            scorePrev = Math.min(scorePrev, o_scorePrev);
+
+            // subtract so that get() works correctly
+            a -= trapezoidArea(fp, fpPrev, tp, tpPrev);
+        }
+
+        double get() throws HiveException {
+            if (tp == 0 || fp == 0) {
+                throw new HiveException(
+                    "AUC score is not defined because there is only one class 
in `label`.");
+            }
+            double res = a + trapezoidArea(fp, fpPrev, tp, tpPrev);
+            return res / (tp * fp); // scale
+        }
+
+        void iterate(double score, int label) {
+            if (score != scorePrev) {
+                a += trapezoidArea(fp, fpPrev, tp, tpPrev); // under (fp, 
tp)-(fpPrev, tpPrev)
+                scorePrev = score;
+                fpPrev = fp;
+                tpPrev = tp;
+            }
+            if (label == 1) {
+                tp++; // this finally will be the number of positive samples
+            } else {
+                fp++; // this finally will be the number of negative samples
+            }
+        }
+
+        private double trapezoidArea(double x1, double x2, double y1, double 
y2) {
+            double base = Math.abs(x1 - x2);
+            double height = (y1 + y2) / 2.d;
+            return base * height;
+        }
     }
 
-    public static class Evaluator extends GenericUDAFEvaluator {
+    public static class RankingEvaluator extends GenericUDAFEvaluator {
 
         private ListObjectInspector recommendListOI;
         private ListObjectInspector truthListOI;
@@ -87,7 +321,7 @@ public final class AUCUDAF extends 
AbstractGenericUDAFResolver {
         private StructField countField;
         private StructField sumField;
 
-        public Evaluator() {}
+        public RankingEvaluator() {}
 
         @Override
         public ObjectInspector init(Mode mode, ObjectInspector[] parameters) 
throws HiveException {
@@ -132,20 +366,20 @@ public final class AUCUDAF extends 
AbstractGenericUDAFResolver {
 
         @Override
         public AggregationBuffer getNewAggregationBuffer() throws 
HiveException {
-            AggregationBuffer myAggr = new AUCAggregationBuffer();
+            AggregationBuffer myAggr = new RankingAUCAggregationBuffer();
             reset(myAggr);
             return myAggr;
         }
 
         @Override
         public void reset(AggregationBuffer agg) throws HiveException {
-            AUCAggregationBuffer myAggr = (AUCAggregationBuffer) agg;
+            RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) 
agg;
             myAggr.reset();
         }
 
         @Override
         public void iterate(AggregationBuffer agg, Object[] parameters) throws 
HiveException {
-            AUCAggregationBuffer myAggr = (AUCAggregationBuffer) agg;
+            RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) 
agg;
 
             List<?> recommendList = recommendListOI.getList(parameters[0]);
             if (recommendList == null) {
@@ -171,7 +405,7 @@ public final class AUCUDAF extends 
AbstractGenericUDAFResolver {
 
         @Override
         public Object terminatePartial(AggregationBuffer agg) throws 
HiveException {
-            AUCAggregationBuffer myAggr = (AUCAggregationBuffer) agg;
+            RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) 
agg;
 
             Object[] partialResult = new Object[2];
             partialResult[0] = new DoubleWritable(myAggr.sum);
@@ -190,25 +424,25 @@ public final class AUCUDAF extends 
AbstractGenericUDAFResolver {
             double sum = 
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(sumObj);
             long count = 
PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(countObj);
 
-            AUCAggregationBuffer myAggr = (AUCAggregationBuffer) agg;
+            RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) 
agg;
             myAggr.merge(sum, count);
         }
 
         @Override
         public DoubleWritable terminate(AggregationBuffer agg) throws 
HiveException {
-            AUCAggregationBuffer myAggr = (AUCAggregationBuffer) agg;
+            RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) 
agg;
             double result = myAggr.get();
             return new DoubleWritable(result);
         }
 
     }
 
-    public static class AUCAggregationBuffer extends AbstractAggregationBuffer 
{
+    public static class RankingAUCAggregationBuffer extends 
AbstractAggregationBuffer {
 
         double sum;
         long count;
 
-        public AUCAggregationBuffer() {
+        public RankingAUCAggregationBuffer() {
             super();
         }
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/97bc9124/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java 
b/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java
new file mode 100644
index 0000000..8725756
--- /dev/null
+++ b/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.evaluation;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+public class AUCUDAFTest {
+    AUCUDAF auc;
+    GenericUDAFEvaluator evaluator;
+    ObjectInspector[] inputOIs;
+    ObjectInspector[] partialOI;
+    AUCUDAF.ClassificationAUCAggregationBuffer agg;
+
+    @Before
+    public void setUp() throws Exception {
+        auc = new AUCUDAF();
+
+        inputOIs = new ObjectInspector[] {
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.DOUBLE),
+                
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(
+                        PrimitiveObjectInspector.PrimitiveCategory.INT)};
+
+        evaluator = auc.getEvaluator(new 
SimpleGenericUDAFParameterInfo(inputOIs, false, false));
+
+        ArrayList<String> fieldNames = new ArrayList<String>();
+        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+        fieldNames.add("a");
+        
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+        fieldNames.add("scorePrev");
+        
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+        fieldNames.add("fp");
+        
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+        fieldNames.add("tp");
+        
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+        fieldNames.add("fpPrev");
+        
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+        fieldNames.add("tpPrev");
+        
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+
+        partialOI = new ObjectInspector[2];
+        partialOI[0] = 
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+
+        agg = (AUCUDAF.ClassificationAUCAggregationBuffer) 
evaluator.getNewAggregationBuffer();
+    }
+
+    @Test
+    public void test() throws Exception {
+        // should be sorted by scores in a descending order
+        final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2};
+        final int[] labels = new int[] {1, 1, 0, 1, 0};
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+
+        for (int i = 0; i < scores.length; i++) {
+            evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+        }
+
+        Assert.assertEquals(0.83333, agg.get(), 1e-5);
+    }
+
+    @Test(expected=HiveException.class)
+    public void testAllTruePositive() throws Exception {
+        final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2};
+        final int[] labels = new int[] {1, 1, 1, 1, 1};
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+
+        for (int i = 0; i < scores.length; i++) {
+            evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+        }
+
+        // AUC for all-TP scores is not defined
+        agg.get();
+    }
+
+    @Test(expected=HiveException.class)
+    public void testAllFalsePositive() throws Exception {
+        final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2};
+        final int[] labels = new int[] {0, 0, 0, 0, 0};
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+
+        for (int i = 0; i < scores.length; i++) {
+            evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+        }
+
+        // AUC for all-FP scores is not defined
+        agg.get();
+    }
+
+    @Test
+    public void testMaxAUC() throws Exception {
+        final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2};
+        final int[] labels = new int[] {1, 1, 1, 1, 0};
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+
+        for (int i = 0; i < scores.length; i++) {
+            evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+        }
+
+        // All TPs are ranked higher than FPs => AUC=1.0
+        Assert.assertEquals(1.d, agg.get(), 1e-5);
+    }
+
+    @Test
+    public void testMinAUC() throws Exception {
+        final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2};
+        final int[] labels = new int[] {0, 0, 0, 1, 1};
+
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+
+        for (int i = 0; i < scores.length; i++) {
+            evaluator.iterate(agg, new Object[] {scores[i], labels[i]});
+        }
+
+        // All TPs are ranked lower than FPs => AUC=0.0
+        Assert.assertEquals(0.d, agg.get(), 1e-5);
+    }
+
+    @Test
+    public void testMidAUC() throws Exception {
+        final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2};
+
+        // if TP and FP appear alternately, AUC=0.5
+        final int[] labels1 = new int[] {1, 0, 1, 0, 1};
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 0; i < scores.length; i++) {
+            evaluator.iterate(agg, new Object[] {scores[i], labels1[i]});
+        }
+        Assert.assertEquals(0.5, agg.get(), 1e-5);
+
+        final int[] labels2 = new int[] {0, 1, 0, 1, 0};
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        for (int i = 0; i < scores.length; i++) {
+            evaluator.iterate(agg, new Object[] {scores[i], labels2[i]});
+        }
+        Assert.assertEquals(0.5, agg.get(), 1e-5);
+    }
+
+    @Test
+    public void testMerge() throws Exception {
+        final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2};
+        final int[] labels = new int[] {1, 1, 0, 1, 0};
+
+        Object[] partials = new Object[3];
+
+        // bin #1
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        evaluator.iterate(agg, new Object[] {scores[0], labels[0]});
+        evaluator.iterate(agg, new Object[] {scores[1], labels[1]});
+        partials[0] = evaluator.terminatePartial(agg);
+
+        // bin #2
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        evaluator.iterate(agg, new Object[] {scores[2], labels[2]});
+        evaluator.iterate(agg, new Object[] {scores[3], labels[3]});
+        partials[1] = evaluator.terminatePartial(agg);
+
+        // bin #3
+        evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs);
+        evaluator.reset(agg);
+        evaluator.iterate(agg, new Object[] {scores[4], labels[4]});
+        partials[2] = evaluator.terminatePartial(agg);
+
+        // merge bins
+        // merge in a different order; e.g., <bin0, bin1>, <bin1, bin0> should 
return same value
+        final int[][] orders = new int[][] {{0, 1, 2}, {1, 0, 2}, {1, 2, 0}, 
{2, 1, 0}};
+        for (int i = 0; i < orders.length; i++) {
+            evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI);
+            evaluator.reset(agg);
+
+            evaluator.merge(agg, partials[orders[i][0]]);
+            evaluator.merge(agg, partials[orders[i][1]]);
+            evaluator.merge(agg, partials[orders[i][2]]);
+
+            Assert.assertEquals(0.83333, agg.get(), 1e-5);
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/97bc9124/docs/gitbook/SUMMARY.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md
index 40f20a8..994f9d8 100644
--- a/docs/gitbook/SUMMARY.md
+++ b/docs/gitbook/SUMMARY.md
@@ -66,6 +66,7 @@
 ## Part IV - Evaluation
 
 * [Statistical evaluation of a prediction model](eval/stat_eval.md)
+    * [Area Under the ROC Curve](eval/auc.md)
 
 * [Data Generation](eval/datagen.md)
     * [Logistic Regression data generation](eval/lr_datagen.md)

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/97bc9124/docs/gitbook/eval/auc.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/eval/auc.md b/docs/gitbook/eval/auc.md
new file mode 100644
index 0000000..3c8de95
--- /dev/null
+++ b/docs/gitbook/eval/auc.md
@@ -0,0 +1,104 @@
+<!--
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing,
+  software distributed under the License is distributed on an
+  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+  KIND, either express or implied.  See the License for the
+  specific language governing permissions and limitations
+  under the License.
+-->
+
+<!-- toc -->
+
+# Area Under the ROC Curve
+
+[ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) 
and Area Under the ROC Curve (AUC) are widely-used metrics for binary (i.e., 
positive or negative) classification problems such as [Logistic 
Regression](../binaryclass/a9a_lr.html).
+
+Binary classifiers generally predict how likely a sample is to be positive by 
computing a probability. Ultimately, we can evaluate the classifiers by comparing 
the probabilities with truth positive/negative labels.
+
+Now we assume that there is a table which contains predicted scores (i.e., 
probabilities) and truth labels as follows:
+
+| probability<br/>(predicted score) | truth label |
+|:---:|:---:|
+| 0.5 | 0 |
+| 0.3 | 1 |
+| 0.2 | 0 |
+| 0.8 | 1 |
+| 0.7 | 1 |
+
+Once the rows are sorted by the probabilities in a descending order, AUC gives 
a metric based on how many positive (`label=1`) samples are ranked higher than 
negative (`label=0`) samples. If many positive rows get larger scores than 
negative rows, AUC would be large, and hence our classifier would perform well.
+
+# Compute AUC on Hivemall
+
+On Hivemall, a function `auc(double score, int label)` provides a way to 
compute AUC for pairs of probability and truth label.
+
+For instance, following query computes AUC of the table which was shown above:
+
+```sql
+with data as (
+  select 0.5 as prob, 0 as label
+  union all
+  select 0.3 as prob, 1 as label
+  union all
+  select 0.2 as prob, 0 as label
+  union all
+  select 0.8 as prob, 1 as label
+  union all
+  select 0.7 as prob, 1 as label
+), data_ordered as (
+  select prob, label
+  from data
+  order by prob desc
+)
+select auc(prob, label)
+from data_ordered;
+```
+
+This query returns `0.83333` as AUC.
+
+Since AUC is a metric based on ranked probability-label pairs as mentioned 
above, input data (rows) needs to be ordered by scores in a descending order.
+
+Meanwhile, Hive's `distribute by` clause allows you to compute AUC in 
parallel: 
+
+```sql
+with data as (
+  select 0.5 as prob, 0 as label
+  union all
+  select 0.3 as prob, 1 as label
+  union all
+  select 0.2 as prob, 0 as label
+  union all
+  select 0.8 as prob, 1 as label
+  union all
+  select 0.7 as prob, 1 as label
+), data_ordered as (
+  select prob, label
+  from data
+  order by prob desc
+)
+select auc(prob, label)
+from (
+  select prob, label
+  from data_ordered
+  distribute by floor(prob / 0.2)
+) t;
+```
+
+Note that `floor(prob / 0.2)` means that the rows are distributed to 5 bins 
for the AUC computation because the column `prob` is in a [0, 1] range.
+
+# Difference between AUC and Logarithmic Loss
+
+Hivemall has another metric called [Logarithmic 
Loss](stat_eval.html#logarithmic-loss) for binary classification. Both AUC and 
Logarithmic Loss compute scores for probability-label pairs. 
+
+Score produced by AUC is a relative metric based on sorted pairs. On the other 
hand, Logarithmic Loss simply gives a metric by comparing probability with its 
truth label one-by-one.
+
+To give an example, `auc(prob, label)` and `logloss(prob, label)` respectively 
return `0.83333` and `0.54001` in the above case. Note that larger AUC and 
smaller Logarithmic Loss are better.
\ No newline at end of file

Reply via email to