Repository: incubator-hivemall Updated Branches: refs/heads/master 06f2f8220 -> c2b95783c
Close #115: [HIVEMALL-124][BUGFIX] Fixed bugs in BinaryResponseMeasure (nDCG, MRR, AP) Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/c2b95783 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/c2b95783 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/c2b95783 Branch: refs/heads/master Commit: c2b95783cf9d6fc1646a48ac928e96152eab98c6 Parents: 06f2f82 Author: Makoto Yui <[email protected]> Authored: Fri Sep 15 18:52:33 2017 +0900 Committer: Makoto Yui <[email protected]> Committed: Fri Sep 15 18:52:33 2017 +0900 ---------------------------------------------------------------------- .../main/java/hivemall/HivemallConstants.java | 2 + .../evaluation/BinaryResponsesMeasures.java | 122 ++++++++++++++----- .../main/java/hivemall/evaluation/MAPUDAF.java | 2 +- .../main/java/hivemall/evaluation/MRRUDAF.java | 2 +- .../main/java/hivemall/evaluation/NDCGUDAF.java | 32 +++-- .../hivemall/tools/list/UDAFToOrderedList.java | 2 +- .../java/hivemall/utils/hadoop/HiveUtils.java | 18 ++- .../java/hivemall/utils/math/MathUtils.java | 5 + .../evaluation/BinaryResponsesMeasuresTest.java | 101 +++++++++++++-- docs/gitbook/eval/rank.md | 33 ++--- 10 files changed, 250 insertions(+), 69 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/main/java/hivemall/HivemallConstants.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/HivemallConstants.java b/core/src/main/java/hivemall/HivemallConstants.java index 0eb9feb..67bb228 100644 --- a/core/src/main/java/hivemall/HivemallConstants.java +++ b/core/src/main/java/hivemall/HivemallConstants.java @@ -18,6 +18,7 @@ */ package hivemall; + public final class HivemallConstants { public static final String VERSION = "0.4.2-rc.2"; @@ -35,6 +36,7 @@ public final class HivemallConstants { public static final String BIGINT_TYPE_NAME = "bigint"; public static final String FLOAT_TYPE_NAME = "float"; public static final String DOUBLE_TYPE_NAME = "double"; + public static final String DECIMAL_TYPE_NAME = "decimal"; public static final String STRING_TYPE_NAME = "string"; public static final String DATE_TYPE_NAME = "date"; public static final String DATETIME_TYPE_NAME = "datetime"; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java b/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java index 81cf075..7c21849 100644 --- a/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java +++ b/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java @@ -18,8 +18,12 @@ */ package hivemall.evaluation; +import hivemall.utils.lang.Preconditions; +import hivemall.utils.math.MathUtils; + import java.util.List; +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; /** @@ -40,19 +44,25 @@ public final class BinaryResponsesMeasures { * @return nDCG */ public static double nDCG(@Nonnull final List<?> rankedList, - @Nonnull final List<?> groundTruth, @Nonnull final int recommendSize) { + @Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) { + Preconditions.checkArgument(recommendSize > 0); + double dcg = 0.d; - double idcg = IDCG(Math.min(recommendSize, groundTruth.size())); - for (int i = 0, n = recommendSize; i < n; i++) { + final int k = Math.min(rankedList.size(), recommendSize); + for (int i = 0; i < k; i++) { Object item_id = rankedList.get(i); if (!groundTruth.contains(item_id)) { continue; } int rank = i + 1; - dcg += Math.log(2) / Math.log(rank + 1); + dcg += 1.d / MathUtils.log2(rank + 1); } + final double idcg = IDCG(Math.min(groundTruth.size(), k)); + if (idcg == 0.d) { + return 0.d; + } return dcg / idcg; } @@ -62,10 +72,12 @@ public final class BinaryResponsesMeasures { * @param n the number of positive items * @return ideal DCG */ - public static double IDCG(final int n) { + public static double IDCG(@Nonnegative final int n) { + Preconditions.checkArgument(n >= 0); + double idcg = 0.d; for (int i = 0; i < n; i++) { - idcg += Math.log(2) / Math.log(i + 2); + idcg += 1.d / MathUtils.log2(i + 2); } return idcg; } @@ -79,8 +91,26 @@ public final class BinaryResponsesMeasures { * @return Precision */ public static double Precision(@Nonnull final List<?> rankedList, - @Nonnull final List<?> groundTruth, @Nonnull final int recommendSize) { - return (double) countTruePositive(rankedList, groundTruth, recommendSize) / recommendSize; + @Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) { + if (rankedList.isEmpty()) { + if (groundTruth.isEmpty()) { + return 1.d; + } + return 0.d; + } + + Preconditions.checkArgument(recommendSize > 0); // can be zero when groundTruth is empty + + int nTruePositive = 0; + final int k = Math.min(rankedList.size(), recommendSize); + for (int i = 0; i < k; i++) { + Object item_id = rankedList.get(i); + if (groundTruth.contains(item_id)) { + nTruePositive++; + } + } + + return ((double) nTruePositive) / k; } /** @@ -92,8 +122,15 @@ public final class BinaryResponsesMeasures { * @return Recall */ public static double Recall(@Nonnull final List<?> rankedList, - @Nonnull final List<?> groundTruth, @Nonnull final int recommendSize) { - return (double) countTruePositive(rankedList, groundTruth, recommendSize) + @Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) { + if (groundTruth.isEmpty()) { + if (rankedList.isEmpty()) { + return 1.d; + } + return 0.d; + } + + return ((double) TruePositives(rankedList, groundTruth, recommendSize)) / groundTruth.size(); } @@ -105,11 +142,14 @@ public final class BinaryResponsesMeasures { * @param recommendSize top-`recommendSize` items in `rankedList` are recommended * @return number of true positives */ - public static int countTruePositive(final List<?> rankedList, final List<?> groundTruth, - final int recommendSize) { + public static int TruePositives(final List<?> rankedList, final List<?> groundTruth, + @Nonnegative final int recommendSize) { + Preconditions.checkArgument(recommendSize > 0); + int nTruePositive = 0; - for (int i = 0, n = recommendSize; i < n; i++) { + final int k = Math.min(rankedList.size(), recommendSize); + for (int i = 0; i < k; i++) { Object item_id = rankedList.get(i); if (groundTruth.contains(item_id)) { nTruePositive++; @@ -120,48 +160,65 @@ public final class BinaryResponsesMeasures { } /** - * Computes Mean Reciprocal Rank (MRR) + * Computes Reciprocal Rank * * @param rankedList a list of ranked item IDs (first item is highest-ranked) * @param groundTruth a collection of positive/correct item IDs * @param recommendSize top-`recommendSize` items in `rankedList` are recommended - * @return MRR + * @return Reciprocal Rank + * @link https://en.wikipedia.org/wiki/Mean_reciprocal_rank */ - public static double MRR(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth, - @Nonnull final int recommendSize) { - for (int i = 0, n = recommendSize; i < n; i++) { + public static double ReciprocalRank(@Nonnull final List<?> rankedList, + @Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) { + Preconditions.checkArgument(recommendSize > 0); + + final int k = Math.min(rankedList.size(), recommendSize); + for (int i = 0; i < k; i++) { Object item_id = rankedList.get(i); if (groundTruth.contains(item_id)) { - return 1.0 / (i + 1.0); + return 1.d / (i + 1); } } - return 0.0; + return 0.d; } /** - * Computes Mean Average Precision (MAP) + * Computes Average Precision (AP) * * @param rankedList a list of ranked item IDs (first item is highest-ranked) * @param groundTruth a collection of positive/correct item IDs * @param recommendSize top-`recommendSize` items in `rankedList` are recommended - * @return MAP + * @return AveragePrecision */ - public static double MAP(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth, - @Nonnull final int recommendSize) { + public static double AveragePrecision(@Nonnull final List<?> rankedList, + @Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) { + Preconditions.checkArgument(recommendSize > 0); + + if (groundTruth.isEmpty()) { + if (rankedList.isEmpty()) { + return 1.d; + } + return 0.d; + } + int nTruePositive = 0; - double sumPrecision = 0.0; + double sumPrecision = 0.d; // accumulate precision@1 to @recommendSize - for (int i = 0, n = recommendSize; i < n; i++) { + final int k = Math.min(rankedList.size(), recommendSize); + for (int i = 0; i < k; i++) { Object item_id = rankedList.get(i); if (groundTruth.contains(item_id)) { nTruePositive++; - sumPrecision += nTruePositive / (i + 1.0); + sumPrecision += nTruePositive / (i + 1.d); } } - return sumPrecision / groundTruth.size(); + if (nTruePositive == 0) { + return 0.d; + } + return sumPrecision / nTruePositive; } /** @@ -173,11 +230,14 @@ public final class BinaryResponsesMeasures { * @return AUC */ public static double AUC(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth, - @Nonnull final int recommendSize) { + @Nonnegative final int recommendSize) { + Preconditions.checkArgument(recommendSize > 0); + int nTruePositive = 0, nCorrectPairs = 0; // count # of pairs of items that are ranked in the correct order (i.e. TP > FP) - for (int i = 0, n = recommendSize; i < n; i++) { + final int k = Math.min(rankedList.size(), recommendSize); + for (int i = 0; i < k; i++) { Object item_id = rankedList.get(i); if (groundTruth.contains(item_id)) { // # of true positives which are ranked higher position than i-th recommended item @@ -197,7 +257,7 @@ public final class BinaryResponsesMeasures { } // AUC can equivalently be calculated by counting the portion of correctly ordered pairs - return (double) nCorrectPairs / nPairs; + return ((double) nCorrectPairs) / nPairs; } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/main/java/hivemall/evaluation/MAPUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/MAPUDAF.java b/core/src/main/java/hivemall/evaluation/MAPUDAF.java index cac6de5..3878684 100644 --- a/core/src/main/java/hivemall/evaluation/MAPUDAF.java +++ b/core/src/main/java/hivemall/evaluation/MAPUDAF.java @@ -235,7 +235,7 @@ public final class MAPUDAF extends AbstractGenericUDAFResolver { void iterate(@Nonnull List<?> recommendList, @Nonnull List<?> truthList, @Nonnull int recommendSize) { - sum += BinaryResponsesMeasures.MAP(recommendList, truthList, recommendSize); + sum += BinaryResponsesMeasures.AveragePrecision(recommendList, truthList, recommendSize); count++; } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/main/java/hivemall/evaluation/MRRUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/MRRUDAF.java b/core/src/main/java/hivemall/evaluation/MRRUDAF.java index 41a236d..f5aba3b 100644 --- a/core/src/main/java/hivemall/evaluation/MRRUDAF.java +++ b/core/src/main/java/hivemall/evaluation/MRRUDAF.java @@ -235,7 +235,7 @@ public final class MRRUDAF extends AbstractGenericUDAFResolver { void iterate(@Nonnull List<?> recommendList, @Nonnull List<?> truthList, @Nonnull int recommendSize) { - sum += BinaryResponsesMeasures.MRR(recommendList, truthList, recommendSize); + sum += BinaryResponsesMeasures.ReciprocalRank(recommendList, truthList, recommendSize); count++; } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/main/java/hivemall/evaluation/NDCGUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/NDCGUDAF.java b/core/src/main/java/hivemall/evaluation/NDCGUDAF.java index f50d27a..f1ba832 100644 --- a/core/src/main/java/hivemall/evaluation/NDCGUDAF.java +++ b/core/src/main/java/hivemall/evaluation/NDCGUDAF.java @@ -18,6 +18,8 @@ */ package hivemall.evaluation; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableLongObjectInspector; import hivemall.utils.hadoop.HiveUtils; import java.util.ArrayList; @@ -38,10 +40,11 @@ import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; @@ -120,8 +123,8 @@ public final class NDCGUDAF extends AbstractGenericUDAFResolver { } private static StructObjectInspector internalMergeOI() { - ArrayList<String> fieldNames = new ArrayList<String>(); - ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + List<String> fieldNames = new ArrayList<>(); + List<ObjectInspector> fieldOIs = new ArrayList<>(); fieldNames.add("sum"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); @@ -180,20 +183,31 @@ public final class NDCGUDAF extends AbstractGenericUDAFResolver { StructObjectInspector sOI = (StructObjectInspector) recommendListOI.getListElementObjectInspector(); List<?> fieldRefList = sOI.getAllStructFieldRefs(); StructField relScoreField = (StructField) fieldRefList.get(0); - WritableDoubleObjectInspector relScoreFieldOI = (WritableDoubleObjectInspector) relScoreField.getFieldObjectInspector(); + PrimitiveObjectInspector relScoreFieldOI = HiveUtils.asDoubleCompatibleOI(relScoreField.getFieldObjectInspector()); for (int i = 0, n = recommendList.size(); i < n; i++) { Object structObj = recommendList.get(i); List<Object> fieldList = sOI.getStructFieldsDataAsList(structObj); - double relScore = (double) relScoreFieldOI.get(fieldList.get(0)); + Object field0 = fieldList.get(0); + if (field0 == null) { + throw new UDFArgumentException("Field 0 of a struct field is null: " + + fieldList); + } + double relScore = PrimitiveObjectInspectorUtils.getDouble(field0, + relScoreFieldOI); recommendRelScoreList.add(relScore); } // Create a ordered list of relevance scores for truth items List<Double> truthRelScoreList = new ArrayList<Double>(); - WritableDoubleObjectInspector truthRelScoreOI = (WritableDoubleObjectInspector) truthListOI.getListElementObjectInspector(); + PrimitiveObjectInspector truthRelScoreOI = HiveUtils.asDoubleCompatibleOI(truthListOI.getListElementObjectInspector()); for (int i = 0, n = truthList.size(); i < n; i++) { Object relScoreObj = truthList.get(i); - double relScore = (double) truthRelScoreOI.get(relScoreObj); + if (relScoreObj == null) { + throw new UDFArgumentException("Found null in the ground truth: " + + truthList); + } + double relScore = PrimitiveObjectInspectorUtils.getDouble(relScoreObj, + truthRelScoreOI); truthRelScoreList.add(relScore); } @@ -224,8 +238,8 @@ public final class NDCGUDAF extends AbstractGenericUDAFResolver { Object sumObj = internalMergeOI.getStructFieldData(partial, sumField); Object countObj = internalMergeOI.getStructFieldData(partial, countField); - double sum = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(sumObj); - long count = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(countObj); + double sum = writableDoubleObjectInspector.get(sumObj); + long count = writableLongObjectInspector.get(countObj); NDCGAggregationBuffer myAggr = (NDCGAggregationBuffer) agg; myAggr.merge(sum, count); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java b/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java index e88a16c..52c521c 100644 --- a/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java +++ b/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java @@ -207,7 +207,7 @@ public final class UDAFToOrderedList extends AbstractGenericUDAFResolver { || (argOIs.length == 3 && HiveUtils.isConstString(argOIs[2])); if (sortByKey) { - this.valueOI = HiveUtils.asPrimitiveObjectInspector(argOIs[0]); + this.valueOI = argOIs[0]; this.keyOI = HiveUtils.asPrimitiveObjectInspector(argOIs[1]); } else { // sort values by value itself http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index 8fba349..b8b344c 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -21,6 +21,7 @@ package hivemall.utils.hadoop; import static hivemall.HivemallConstants.BIGINT_TYPE_NAME; import static hivemall.HivemallConstants.BINARY_TYPE_NAME; import static hivemall.HivemallConstants.BOOLEAN_TYPE_NAME; +import static hivemall.HivemallConstants.DECIMAL_TYPE_NAME; import static hivemall.HivemallConstants.DOUBLE_TYPE_NAME; import static hivemall.HivemallConstants.FLOAT_TYPE_NAME; import static hivemall.HivemallConstants.INT_TYPE_NAME; @@ -47,6 +48,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject; import org.apache.hadoop.hive.serde2.SerDeException; import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.io.ShortWritable; import org.apache.hadoop.hive.serde2.lazy.ByteArrayRef; import org.apache.hadoop.hive.serde2.lazy.LazyDouble; @@ -265,6 +267,7 @@ public final class HiveUtils { case LONG: case FLOAT: case DOUBLE: + case DECIMAL: case BYTE: //case TIMESTAMP: return true; @@ -357,6 +360,7 @@ public final class HiveUtils { case LONG: case FLOAT: case DOUBLE: + case DECIMAL: return true; default: return false; @@ -404,6 +408,7 @@ public final class HiveUtils { switch (((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()) { case DOUBLE: case FLOAT: + case DECIMAL: return true; default: return false; @@ -630,6 +635,9 @@ public final class HiveUtils { } else if (TINYINT_TYPE_NAME.equals(typeName)) { ByteWritable v = getConstValue(numberOI); return v.get(); + } else if (DECIMAL_TYPE_NAME.equals(typeName)) { + HiveDecimalWritable v = getConstValue(numberOI); + return v.getHiveDecimal().floatValue(); } throw new UDFArgumentException("Unexpected argument type to cast as double: " + TypeInfoUtils.getTypeInfoFromObjectInspector(numberOI)); @@ -656,6 +664,9 @@ public final class HiveUtils { } else if (TINYINT_TYPE_NAME.equals(typeName)) { ByteWritable v = getConstValue(numberOI); return v.get(); + } else if (DECIMAL_TYPE_NAME.equals(typeName)) { + HiveDecimalWritable v = getConstValue(numberOI); + return v.getHiveDecimal().doubleValue(); } throw new UDFArgumentException("Unexpected argument type to cast as double: " + TypeInfoUtils.getTypeInfoFromObjectInspector(numberOI)); @@ -923,10 +934,10 @@ public final class HiveUtils { case LONG: case FLOAT: case DOUBLE: + case DECIMAL: case BOOLEAN: case BYTE: case STRING: - case DECIMAL: break; default: throw new UDFArgumentTypeException(0, "Unxpected type '" + argOI.getTypeName() @@ -951,9 +962,9 @@ public final class HiveUtils { case BOOLEAN: case FLOAT: case DOUBLE: + case DECIMAL: case STRING: case TIMESTAMP: - case DECIMAL: break; default: throw new UDFArgumentTypeException(0, "Unxpected type '" + argOI.getTypeName() @@ -998,6 +1009,7 @@ public final class HiveUtils { case LONG: case FLOAT: case DOUBLE: + case DECIMAL: case STRING: case TIMESTAMP: break; @@ -1020,6 +1032,7 @@ public final class HiveUtils { switch (oi.getPrimitiveCategory()) { case FLOAT: case DOUBLE: + case DECIMAL: break; default: throw new UDFArgumentTypeException(0, @@ -1044,6 +1057,7 @@ public final class HiveUtils { case LONG: case FLOAT: case DOUBLE: + case DECIMAL: break; default: throw new UDFArgumentTypeException(0, http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/main/java/hivemall/utils/math/MathUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java index 6162adb..ee533dc 100644 --- a/core/src/main/java/hivemall/utils/math/MathUtils.java +++ b/core/src/main/java/hivemall/utils/math/MathUtils.java @@ -43,6 +43,7 @@ import javax.annotation.Nullable; import org.apache.commons.math3.special.Gamma; public final class MathUtils { + private static final double LOG2 = Math.log(2); private MathUtils() {} @@ -246,6 +247,10 @@ public final class MathUtils { return Math.log(n) / Math.log(base); } + public static double log2(final double n) { + return Math.log(n) / LOG2; + } + public static int floorDiv(final int x, final int y) { int r = x / y; // if the signs are different and modulo not zero, round down http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java b/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java index 9f8a04e..5e8f253 100644 --- a/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java +++ b/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java @@ -18,8 +18,8 @@ */ package hivemall.evaluation; -import java.util.Collections; import java.util.Arrays; +import java.util.Collections; import java.util.List; import org.junit.Assert; @@ -40,6 +40,18 @@ public class BinaryResponsesMeasuresTest { } @Test + public void testNDCG2() { + List<Integer> rankedList = Arrays.asList(3, 2, 1, 6); + List<Integer> groundTruth = Arrays.asList(1); + + double actual = BinaryResponsesMeasures.nDCG(rankedList, groundTruth, 2); + Assert.assertEquals(0.d, actual, 0.0001d); + + actual = BinaryResponsesMeasures.nDCG(rankedList, groundTruth, 3); + Assert.assertEquals(0.5d, actual, 0.0001d); + } + + @Test public void testRecall() { List<Integer> rankedList = Arrays.asList(1, 3, 2, 6); List<Integer> groundTruth = Arrays.asList(1, 2, 4); @@ -52,6 +64,16 @@ public class BinaryResponsesMeasuresTest { } @Test + public void testRecallEmpty() { + Assert.assertEquals(1.d, + BinaryResponsesMeasures.Recall(Collections.emptyList(), Collections.emptyList(), 2), + 0.d); + + Assert.assertEquals(0.d, + BinaryResponsesMeasures.Recall(Arrays.asList(1, 3, 2), Collections.emptyList(), 2), 0.d); + } + + @Test public void testPrecision() { List<Integer> rankedList = Arrays.asList(1, 3, 2, 6); List<Integer> groundTruth = Arrays.asList(1, 2, 4); @@ -65,32 +87,91 @@ public class BinaryResponsesMeasuresTest { } @Test - public void testMRR() { + public void testPrecisionEmpty() { + Assert.assertEquals(1.d, + BinaryResponsesMeasures.Precision(Collections.emptyList(), Collections.emptyList(), 2), + 0.d); + + Assert.assertEquals(0.d, + BinaryResponsesMeasures.Precision(Arrays.asList(1, 3, 2), Collections.emptyList(), 2), + 0.d); + } + + @Test + public void testRR() { List<Integer> rankedList = Arrays.asList(1, 3, 2, 6); List<Integer> groundTruth = Arrays.asList(1, 2, 4); - double actual = BinaryResponsesMeasures.MRR(rankedList, groundTruth, rankedList.size()); + double actual = BinaryResponsesMeasures.ReciprocalRank(rankedList, groundTruth, + rankedList.size()); Assert.assertEquals(1.0d, actual, 0.0001d); Collections.reverse(rankedList); - actual = BinaryResponsesMeasures.MRR(rankedList, groundTruth, rankedList.size()); + actual = BinaryResponsesMeasures.ReciprocalRank(rankedList, groundTruth, rankedList.size()); Assert.assertEquals(0.5d, actual, 0.0001d); - actual = BinaryResponsesMeasures.MRR(rankedList, groundTruth, 1); + actual = BinaryResponsesMeasures.ReciprocalRank(rankedList, groundTruth, 1); Assert.assertEquals(0.0d, actual, 0.0001d); } @Test - public void testMAP() { + public void testAP() { List<Integer> rankedList = Arrays.asList(1, 3, 2, 6); List<Integer> groundTruth = Arrays.asList(1, 2, 4); - double actual = BinaryResponsesMeasures.MAP(rankedList, groundTruth, rankedList.size()); - Assert.assertEquals(0.5555555555555555d, actual, 0.0001d); + double actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, + rankedList.size()); + Assert.assertEquals(1.0 / 2.0 * (1.0 / 1.0 + 2.0 / 3.0), actual, 0.0001d); + + actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 4); + Assert.assertEquals(1.0 / 2.0 * (1.0 / 1.0 + 2.0 / 3.0), actual, 0.0001d); - actual = BinaryResponsesMeasures.MAP(rankedList, groundTruth, 2); - Assert.assertEquals(0.3333333333333333d, actual, 0.0001d); + actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 3); + Assert.assertEquals(1.0 / 2.0 * (1.0 / 1.0 + 2.0 / 3.0), actual, 0.0001d); + + actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 2); + Assert.assertEquals(1.0 / 1.0 * (1.0 / 1.0), actual, 0.0001d); + + rankedList = Arrays.asList(3, 1, 2, 6); + actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 2); + Assert.assertEquals(1.0 / 1.0 * (1.0 / 2.0), actual, 0.0001d); + + groundTruth = Arrays.asList(1, 2, 3); + actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 2); + Assert.assertEquals(1.0 / 2.0 * (1.0 / 1.0 + 2.0 / 2.0), actual, 0.0001d); + + rankedList = Arrays.asList(3, 1); + groundTruth = Arrays.asList(1, 2); + actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 2); + Assert.assertEquals(1.0 / 1.0 * (1.0 / 2.0), actual, 0.0001d); + } + + @Test + public void testAPString() { + List<String> rankedList = Arrays.asList("a", "b", "c", "d", "e", "f", "g"); + List<String> groundTruth = Arrays.asList("a", "x", "x", "d", "x", "x"); + + double actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 6); + Assert.assertEquals(0.75d, actual, 0.0001d); + } + + @Test + public void testAPString10() { + List<String> rankedList = Arrays.asList("a", "b", "c", "d", "e", "f", "g", "h", "i", "j"); + List<String> groundTruth = Arrays.asList("a", "x", "c", "x", "e", "f"); + + double actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 10); + Assert.assertEquals(1.0 / 4.0 * (1.0 / 1.0 + 2.0 / 3.0 + 3.0 / 5.0 + 4.0 / 6.0), actual, + 0.0001d); + + actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 5); + Assert.assertEquals(1.0 / 3.0 * (1.0 / 1.0 + 2.0 / 3.0 + 3.0 / 5.0), actual, 0.0001d); + + groundTruth = Arrays.asList("a", "x", "c", "x", "e", "f", "x", "x", "x", "x"); + actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 10); + Assert.assertEquals(1.0 / 4.0 * (1.0 / 1.0 + 2.0 / 3.0 + 3.0 / 5.0 + 4.0 / 6.0), actual, + 0.0001d); } @Test http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c2b95783/docs/gitbook/eval/rank.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/eval/rank.md b/docs/gitbook/eval/rank.md index 207418e..30d82e5 100644 --- a/docs/gitbook/eval/rank.md +++ b/docs/gitbook/eval/rank.md @@ -83,7 +83,8 @@ with truth as ( rec as ( select userid, - map_values(to_ordered_map(score, itemid, true)) as rec, + -- map_values(to_ordered_map(score, itemid, true)) as rec, + to_ordered_list(itemid, score, '-reverse') as rec, cast(count(itemid) as int) as max_k from dummy_rec group by userid @@ -222,7 +223,7 @@ While the binary response setting simply considers positive-only ranked list of Unlike separated `dummy_truth` and `dummy_rec` table in the binary setting, we assume the following single table named `dummy_recrel` which contains item-$$\mathrm{rel}_n$$ pairs: -| userid | itemid | score<br/>(predicted) | rel<br/>(expected) | +| userid | itemid | score<br/>(predicted) | relscore<br/>(expected) | | :-: | :-: | :-: | :-: | | 1 | 1 | 10.0 | 5.0 | | 1 | 3 | 8.0 | 2.0 | @@ -244,27 +245,31 @@ The function `ndcg()` can take non-binary `truth` values as the second argument: ```sql with truth as ( - select userid, map_keys(to_ordered_map(relscore, itemid, true)) as truth - from dummy_recrel - group by userid + select + userid, + to_ordered_list(relscore, '-reverse') as truth + from + dummy_recrel + group by + userid ), rec as ( select userid, - map_values ( - to_ordered_map(score, struct(relscore, itemid), true) - ) as rec, - cast(count(itemid) as int) as max_k - from dummy_recrel - group by userid + to_ordered_list(struct(relscore, itemid), score, "-reverse") as rec, + count(itemid) as max_k + from + dummy_recrel + group by + userid ) select -- top-2 recommendation ndcg(t1.rec, t2.truth, 2), -- => 0.8128912838590544 - -- top-3 recommendation ndcg(t1.rec, t2.truth, 3) -- => 0.9187707805346093 -from rec t1 -join truth t2 on (t1.userid = t2.userid) +from + rec t1 + join truth t2 on (t1.userid = t2.userid) ; ```
