Repository: incubator-hivemall Updated Branches: refs/heads/master 19d472b54 -> 97bc91247
Close #52: [HIVEMALL-78] Implement AUC UDAF for binary classification Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/97bc9124 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/97bc9124 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/97bc9124 Branch: refs/heads/master Commit: 97bc91247f1a453b6ffb49212800fa601d25c297 Parents: 19d472b Author: Takuya Kitazawa <[email protected]> Authored: Tue Feb 28 18:45:11 2017 +0900 Committer: myui <[email protected]> Committed: Tue Feb 28 18:51:45 2017 +0900 ---------------------------------------------------------------------- .../main/java/hivemall/evaluation/AUCUDAF.java | 286 +++++++++++++++++-- .../java/hivemall/evaluation/AUCUDAFTest.java | 219 ++++++++++++++ docs/gitbook/SUMMARY.md | 1 + docs/gitbook/eval/auc.md | 104 +++++++ 4 files changed, 584 insertions(+), 26 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/97bc9124/core/src/main/java/hivemall/evaluation/AUCUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/AUCUDAF.java b/core/src/main/java/hivemall/evaluation/AUCUDAF.java index bc39b4c..ff067b9 100644 --- a/core/src/main/java/hivemall/evaluation/AUCUDAF.java +++ b/core/src/main/java/hivemall/evaluation/AUCUDAF.java @@ -38,24 +38,22 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructField; import 
org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.LongWritable; @SuppressWarnings("deprecation") -@Description( - name = "auc", - value = "_FUNC_(array rankItems, array correctItems [, const int recommendSize = rankItems.size])" - + " - Returns AUC") +@Description(name = "auc", + value = "_FUNC_(array rankItems | double score, array correctItems | int label " + + "[, const int recommendSize = rankItems.size ])" + " - Returns AUC") public final class AUCUDAF extends AbstractGenericUDAFResolver { - // prevent instantiation - private AUCUDAF() {} - @Override public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException { if (typeInfo.length != 2 && typeInfo.length != 3) { @@ -63,21 +61,257 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { "_FUNC_ takes two or three arguments"); } - ListTypeInfo arg1type = HiveUtils.asListTypeInfo(typeInfo[0]); - if (!HiveUtils.isPrimitiveTypeInfo(arg1type.getListElementTypeInfo())) { - throw new UDFArgumentTypeException(0, - "The first argument `array rankItems` is invalid form: " + typeInfo[0]); + if (HiveUtils.isNumberTypeInfo(typeInfo[0]) && HiveUtils.isIntegerTypeInfo(typeInfo[1])) { + return new ClassificationEvaluator(); + } else { + ListTypeInfo arg1type = HiveUtils.asListTypeInfo(typeInfo[0]); + if (!HiveUtils.isPrimitiveTypeInfo(arg1type.getListElementTypeInfo())) { + throw new UDFArgumentTypeException(0, + "The first argument `array rankItems` is invalid form: " + typeInfo[0]); + } + + ListTypeInfo arg2type = 
HiveUtils.asListTypeInfo(typeInfo[1]); + if (!HiveUtils.isPrimitiveTypeInfo(arg2type.getListElementTypeInfo())) { + throw new UDFArgumentTypeException(1, + "The second argument `array correctItems` is invalid form: " + typeInfo[1]); + } + + return new RankingEvaluator(); } - ListTypeInfo arg2type = HiveUtils.asListTypeInfo(typeInfo[1]); - if (!HiveUtils.isPrimitiveTypeInfo(arg2type.getListElementTypeInfo())) { - throw new UDFArgumentTypeException(1, - "The second argument `array correctItems` is invalid form: " + typeInfo[1]); + } + + public static class ClassificationEvaluator extends GenericUDAFEvaluator { + + private PrimitiveObjectInspector scoreOI; + private PrimitiveObjectInspector labelOI; + + private StructObjectInspector internalMergeOI; + private StructField aField; + private StructField scorePrevField; + private StructField fpField; + private StructField tpField; + private StructField fpPrevField; + private StructField tpPrevField; + + public ClassificationEvaluator() {} + + @Override + public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException { + assert (parameters.length == 2 || parameters.length == 3) : parameters.length; + super.init(mode, parameters); + + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data + this.scoreOI = HiveUtils.asDoubleCompatibleOI(parameters[0]); + this.labelOI = HiveUtils.asIntegerOI(parameters[1]); + } else {// from partial aggregation + StructObjectInspector soi = (StructObjectInspector) parameters[0]; + this.internalMergeOI = soi; + this.aField = soi.getStructFieldRef("a"); + this.scorePrevField = soi.getStructFieldRef("scorePrev"); + this.fpField = soi.getStructFieldRef("fp"); + this.tpField = soi.getStructFieldRef("tp"); + this.fpPrevField = soi.getStructFieldRef("fpPrev"); + this.tpPrevField = soi.getStructFieldRef("tpPrev"); + } + + // initialize output + final ObjectInspector outputOI; + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// 
terminatePartial + outputOI = internalMergeOI(); + } else {// terminate + outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; + } + return outputOI; + } + + private static StructObjectInspector internalMergeOI() { + ArrayList<String> fieldNames = new ArrayList<String>(); + ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + + fieldNames.add("a"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldNames.add("scorePrev"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldNames.add("fp"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldNames.add("tp"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldNames.add("fpPrev"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldNames.add("tpPrev"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @Override + public AggregationBuffer getNewAggregationBuffer() throws HiveException { + AggregationBuffer myAggr = new ClassificationAUCAggregationBuffer(); + reset(myAggr); + return myAggr; + } + + @Override + public void reset(AggregationBuffer agg) throws HiveException { + ClassificationAUCAggregationBuffer myAggr = (ClassificationAUCAggregationBuffer) agg; + myAggr.reset(); } - return new Evaluator(); + @Override + public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException { + ClassificationAUCAggregationBuffer myAggr = (ClassificationAUCAggregationBuffer) agg; + + if (parameters[0] == null) { + return; + } + if (parameters[1] == null) { + return; + } + + double score = HiveUtils.getDouble(parameters[0], scoreOI); + if (score < 0.0d || score > 1.0d) { + throw new UDFArgumentException("score value MUST be in range [0,1]: " + score); + } + + int label = 
PrimitiveObjectInspectorUtils.getInt(parameters[1], labelOI); + if (label == -1) { + label = 0; + } else if (label != 0 && label != 1) { + throw new UDFArgumentException("label MUST be 0/1 or -1/1: " + label); + } + + myAggr.iterate(score, label); + } + + @Override + public Object terminatePartial(AggregationBuffer agg) throws HiveException { + ClassificationAUCAggregationBuffer myAggr = (ClassificationAUCAggregationBuffer) agg; + + Object[] partialResult = new Object[6]; + partialResult[0] = new DoubleWritable(myAggr.a); + partialResult[1] = new DoubleWritable(myAggr.scorePrev); + partialResult[2] = new LongWritable(myAggr.fp); + partialResult[3] = new LongWritable(myAggr.tp); + partialResult[4] = new LongWritable(myAggr.fpPrev); + partialResult[5] = new LongWritable(myAggr.tpPrev); + return partialResult; + } + + @Override + public void merge(AggregationBuffer agg, Object partial) throws HiveException { + if (partial == null) { + return; + } + + Object aObj = internalMergeOI.getStructFieldData(partial, aField); + Object scorePrevObj = internalMergeOI.getStructFieldData(partial, scorePrevField); + Object fpObj = internalMergeOI.getStructFieldData(partial, fpField); + Object tpObj = internalMergeOI.getStructFieldData(partial, tpField); + Object fpPrevObj = internalMergeOI.getStructFieldData(partial, fpPrevField); + Object tpPrevObj = internalMergeOI.getStructFieldData(partial, tpPrevField); + double a = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(aObj); + double scorePrev = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(scorePrevObj); + long fp = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(fpObj); + long tp = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(tpObj); + long fpPrev = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(fpPrevObj); + long tpPrev = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(tpPrevObj); + + ClassificationAUCAggregationBuffer myAggr 
= (ClassificationAUCAggregationBuffer) agg; + myAggr.merge(a, scorePrev, fp, tp, fpPrev, tpPrev); + } + + @Override + public DoubleWritable terminate(AggregationBuffer agg) throws HiveException { + ClassificationAUCAggregationBuffer myAggr = (ClassificationAUCAggregationBuffer) agg; + double result = myAggr.get(); + return new DoubleWritable(result); + } + + } + + public static class ClassificationAUCAggregationBuffer extends AbstractAggregationBuffer { + + double a, scorePrev; + long fp, tp, fpPrev, tpPrev; + + public ClassificationAUCAggregationBuffer() { + super(); + } + + void reset() { + this.a = 0.d; + this.scorePrev = Double.POSITIVE_INFINITY; + this.fp = 0; + this.tp = 0; + this.fpPrev = 0; + this.tpPrev = 0; + } + + void merge(double o_a, double o_scorePrev, long o_fp, long o_tp, long o_fpPrev, + long o_tpPrev) { + // compute the latest, not scaled AUC + a += trapezoidArea(fp, fpPrev, tp, tpPrev); + o_a += trapezoidArea(o_fp, o_fpPrev, o_tp, o_tpPrev); + + // sum up the partial areas + a += o_a; + if (scorePrev >= o_scorePrev) { // self is left-side + // adjust combined area by adding missing rectangle + a += trapezoidArea(fp + o_fp, fp, tp, tp); + + // combine TP/FP counts; left-side curve should be base + fp += o_fp; + tp += o_tp; + fpPrev = fp + o_fpPrev; + tpPrev = tp + o_tpPrev; + } else { // self is right-side + a = a + trapezoidArea(fp + o_fp, o_fp, o_tp, o_tp); + + fp += o_fp; + tp += o_tp; + fpPrev += o_fp; + tpPrev += o_tp; + } + + // set current appropriate `scorePrev` + scorePrev = Math.min(scorePrev, o_scorePrev); + + // subtract so that get() works correctly + a -= trapezoidArea(fp, fpPrev, tp, tpPrev); + } + + double get() throws HiveException { + if (tp == 0 || fp == 0) { + throw new HiveException( + "AUC score is not defined because there is only one class in `label`."); + } + double res = a + trapezoidArea(fp, fpPrev, tp, tpPrev); + return res / (tp * fp); // scale + } + + void iterate(double score, int label) { + if (score != scorePrev) 
{ + a += trapezoidArea(fp, fpPrev, tp, tpPrev); // under (fp, tp)-(fpPrev, tpPrev) + scorePrev = score; + fpPrev = fp; + tpPrev = tp; + } + if (label == 1) { + tp++; // this finally will be the number of positive samples + } else { + fp++; // this finally will be the number of negative samples + } + } + + private double trapezoidArea(double x1, double x2, double y1, double y2) { + double base = Math.abs(x1 - x2); + double height = (y1 + y2) / 2.d; + return base * height; + } } - public static class Evaluator extends GenericUDAFEvaluator { + public static class RankingEvaluator extends GenericUDAFEvaluator { private ListObjectInspector recommendListOI; private ListObjectInspector truthListOI; @@ -87,7 +321,7 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { private StructField countField; private StructField sumField; - public Evaluator() {} + public RankingEvaluator() {} @Override public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException { @@ -132,20 +366,20 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { @Override public AggregationBuffer getNewAggregationBuffer() throws HiveException { - AggregationBuffer myAggr = new AUCAggregationBuffer(); + AggregationBuffer myAggr = new RankingAUCAggregationBuffer(); reset(myAggr); return myAggr; } @Override public void reset(AggregationBuffer agg) throws HiveException { - AUCAggregationBuffer myAggr = (AUCAggregationBuffer) agg; + RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) agg; myAggr.reset(); } @Override public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException { - AUCAggregationBuffer myAggr = (AUCAggregationBuffer) agg; + RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) agg; List<?> recommendList = recommendListOI.getList(parameters[0]); if (recommendList == null) { @@ -171,7 +405,7 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { @Override public Object 
terminatePartial(AggregationBuffer agg) throws HiveException { - AUCAggregationBuffer myAggr = (AUCAggregationBuffer) agg; + RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) agg; Object[] partialResult = new Object[2]; partialResult[0] = new DoubleWritable(myAggr.sum); @@ -190,25 +424,25 @@ public final class AUCUDAF extends AbstractGenericUDAFResolver { double sum = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(sumObj); long count = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(countObj); - AUCAggregationBuffer myAggr = (AUCAggregationBuffer) agg; + RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) agg; myAggr.merge(sum, count); } @Override public DoubleWritable terminate(AggregationBuffer agg) throws HiveException { - AUCAggregationBuffer myAggr = (AUCAggregationBuffer) agg; + RankingAUCAggregationBuffer myAggr = (RankingAUCAggregationBuffer) agg; double result = myAggr.get(); return new DoubleWritable(result); } } - public static class AUCAggregationBuffer extends AbstractAggregationBuffer { + public static class RankingAUCAggregationBuffer extends AbstractAggregationBuffer { double sum; long count; - public AUCAggregationBuffer() { + public RankingAUCAggregationBuffer() { super(); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/97bc9124/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java b/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java new file mode 100644 index 0000000..8725756 --- /dev/null +++ b/core/src/test/java/hivemall/evaluation/AUCUDAFTest.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.evaluation; + +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +import java.util.ArrayList; +import java.util.Arrays; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class AUCUDAFTest { + AUCUDAF auc; + GenericUDAFEvaluator evaluator; + ObjectInspector[] inputOIs; + ObjectInspector[] partialOI; + AUCUDAF.ClassificationAUCAggregationBuffer agg; + + @Before + public void setUp() throws Exception { + auc = new AUCUDAF(); + + inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.DOUBLE), + PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( + PrimitiveObjectInspector.PrimitiveCategory.INT)}; + + evaluator = auc.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false)); + + ArrayList<String> fieldNames = new ArrayList<String>(); + 
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + fieldNames.add("a"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldNames.add("scorePrev"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldNames.add("fp"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldNames.add("tp"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldNames.add("fpPrev"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldNames.add("tpPrev"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + + partialOI = new ObjectInspector[2]; + partialOI[0] = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + + agg = (AUCUDAF.ClassificationAUCAggregationBuffer) evaluator.getNewAggregationBuffer(); + } + + @Test + public void test() throws Exception { + // should be sorted by scores in a descending order + final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2}; + final int[] labels = new int[] {1, 1, 0, 1, 0}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < scores.length; i++) { + evaluator.iterate(agg, new Object[] {scores[i], labels[i]}); + } + + Assert.assertEquals(0.83333, agg.get(), 1e-5); + } + + @Test(expected=HiveException.class) + public void testAllTruePositive() throws Exception { + final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2}; + final int[] labels = new int[] {1, 1, 1, 1, 1}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < scores.length; i++) { + evaluator.iterate(agg, new Object[] {scores[i], labels[i]}); + } + + // AUC for all TP scores are not defined + agg.get(); + } + + @Test(expected=HiveException.class) + public void testAllFalsePositive() throws Exception { + final double[] 
scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2}; + final int[] labels = new int[] {0, 0, 0, 0, 0}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < scores.length; i++) { + evaluator.iterate(agg, new Object[] {scores[i], labels[i]}); + } + + // AUC for all FP scores are not defined + agg.get(); + } + + @Test + public void testMaxAUC() throws Exception { + final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2}; + final int[] labels = new int[] {1, 1, 1, 1, 0}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < scores.length; i++) { + evaluator.iterate(agg, new Object[] {scores[i], labels[i]}); + } + + // All TPs are ranked higher than FPs => AUC=1.0 + Assert.assertEquals(1.d, agg.get(), 1e-5); + } + + @Test + public void testMinAUC() throws Exception { + final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2}; + final int[] labels = new int[] {0, 0, 0, 1, 1}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < scores.length; i++) { + evaluator.iterate(agg, new Object[] {scores[i], labels[i]}); + } + + // All TPs are ranked lower than FPs => AUC=0.0 + Assert.assertEquals(0.d, agg.get(), 1e-5); + } + + @Test + public void testMidAUC() throws Exception { + final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2}; + + // if TP and FP appear alternately, AUC=0.5 + final int[] labels1 = new int[] {1, 0, 1, 0, 1}; + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + for (int i = 0; i < scores.length; i++) { + evaluator.iterate(agg, new Object[] {scores[i], labels1[i]}); + } + Assert.assertEquals(0.5, agg.get(), 1e-5); + + final int[] labels2 = new int[] {0, 1, 0, 1, 0}; + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + for (int i = 0; i < scores.length; i++) { + evaluator.iterate(agg, new 
Object[] {scores[i], labels2[i]}); + } + Assert.assertEquals(0.5, agg.get(), 1e-5); + } + + @Test + public void testMerge() throws Exception { + final double[] scores = new double[] {0.8, 0.7, 0.5, 0.3, 0.2}; + final int[] labels = new int[] {1, 1, 0, 1, 0}; + + Object[] partials = new Object[3]; + + // bin #1 + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + evaluator.iterate(agg, new Object[] {scores[0], labels[0]}); + evaluator.iterate(agg, new Object[] {scores[1], labels[1]}); + partials[0] = evaluator.terminatePartial(agg); + + // bin #2 + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + evaluator.iterate(agg, new Object[] {scores[2], labels[2]}); + evaluator.iterate(agg, new Object[] {scores[3], labels[3]}); + partials[1] = evaluator.terminatePartial(agg); + + // bin #3 + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + evaluator.iterate(agg, new Object[] {scores[4], labels[4]}); + partials[2] = evaluator.terminatePartial(agg); + + // merge bins + // merge in a different order; e.g., <bin0, bin1>, <bin1, bin0> should return same value + final int[][] orders = new int[][] {{0, 1, 2}, {1, 0, 2}, {1, 2, 0}, {2, 1, 0}}; + for (int i = 0; i < orders.length; i++) { + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, partialOI); + evaluator.reset(agg); + + evaluator.merge(agg, partials[orders[i][0]]); + evaluator.merge(agg, partials[orders[i][1]]); + evaluator.merge(agg, partials[orders[i][2]]); + + Assert.assertEquals(0.83333, agg.get(), 1e-5); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/97bc9124/docs/gitbook/SUMMARY.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md index 40f20a8..994f9d8 100644 --- a/docs/gitbook/SUMMARY.md +++ b/docs/gitbook/SUMMARY.md @@ -66,6 +66,7 @@ ## Part IV - Evaluation * [Statistical evaluation of a 
prediction model](eval/stat_eval.md) + * [Area Under the ROC Curve](eval/auc.md) * [Data Generation](eval/datagen.md) * [Logistic Regression data generation](eval/lr_datagen.md) http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/97bc9124/docs/gitbook/eval/auc.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/eval/auc.md b/docs/gitbook/eval/auc.md new file mode 100644 index 0000000..3c8de95 --- /dev/null +++ b/docs/gitbook/eval/auc.md @@ -0,0 +1,104 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +<!-- toc --> + +# Area Under the ROC Curve + +The [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) and the Area Under the ROC Curve (AUC) are widely used metrics for binary (i.e., positive or negative) classification problems such as [Logistic Regression](../binaryclass/a9a_lr.html). + +Binary classifiers generally predict how likely a sample is to be positive by computing a probability. Ultimately, we can evaluate the classifiers by comparing the probabilities with the truth (positive/negative) labels. 
+ +Now assume that there is a table that contains predicted scores (i.e., probabilities) and truth labels as follows: + +| probability<br/>(predicted score) | truth label | +|:---:|:---:| +| 0.5 | 0 | +| 0.3 | 1 | +| 0.2 | 0 | +| 0.8 | 1 | +| 0.7 | 1 | + +Once the rows are sorted by the probabilities in descending order, AUC gives a metric based on how many positive (`label=1`) samples are ranked higher than negative (`label=0`) samples. If many positive rows get larger scores than negative rows, the AUC is large, which indicates that the classifier performs well. + +# Compute AUC on Hivemall + +On Hivemall, a function `auc(double score, int label)` provides a way to compute AUC for pairs of probability and truth label. + +For instance, the following query computes the AUC of the table shown above: + +```sql +with data as ( + select 0.5 as prob, 0 as label + union all + select 0.3 as prob, 1 as label + union all + select 0.2 as prob, 0 as label + union all + select 0.8 as prob, 1 as label + union all + select 0.7 as prob, 1 as label +), data_ordered as ( + select prob, label + from data + order by prob desc +) +select auc(prob, label) +from data_ordered; +``` + +This query returns `0.83333` as AUC. + +Since AUC is a metric based on ranked probability-label pairs as mentioned above, the input rows need to be ordered by score in descending order. 
+ +Meanwhile, Hive's `distribute by` clause allows you to compute AUC in parallel: + +```sql +with data as ( + select 0.5 as prob, 0 as label + union all + select 0.3 as prob, 1 as label + union all + select 0.2 as prob, 0 as label + union all + select 0.8 as prob, 1 as label + union all + select 0.7 as prob, 1 as label +), data_ordered as ( + select prob, label + from data + order by prob desc +) +select auc(prob, label) +from ( + select prob, label + from data_ordered + distribute by floor(prob / 0.2) +) t; +``` + +Note that `floor(prob / 0.2)` means that the rows are distributed into 5 bins for the AUC computation because the column `prob` is in the [0, 1] range. + +# Difference between AUC and Logarithmic Loss + +Hivemall has another metric called [Logarithmic Loss](stat_eval.html#logarithmic-loss) for binary classification. Both AUC and Logarithmic Loss compute scores for probability-label pairs. + +The score produced by AUC is a relative metric based on sorted pairs. On the other hand, Logarithmic Loss simply gives a metric by comparing each probability with its truth label one by one. + +To give an example, `auc(prob, label)` and `logloss(prob, label)` return `0.83333` and `0.54001`, respectively, in the above case. Note that a larger AUC and a smaller Logarithmic Loss are better. \ No newline at end of file
