Repository: incubator-hivemall Updated Branches: refs/heads/master bedbd39ca -> c1cd4b2e0
Close #107: [HIVEMALL-132] Generalize f1score UDAF to support any Beta value Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/098a7f3d Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/098a7f3d Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/098a7f3d Branch: refs/heads/master Commit: 098a7f3d9999c8c910d0516ebe8079babae67146 Parents: bedbd39 Author: Kento NOZAWA <[email protected]> Authored: Wed Sep 13 22:17:45 2017 +0900 Committer: Makoto Yui <[email protected]> Committed: Wed Sep 13 22:17:45 2017 +0900 ---------------------------------------------------------------------- .../java/hivemall/UDAFEvaluatorWithOptions.java | 119 ++++++ .../java/hivemall/evaluation/F1ScoreUDAF.java | 134 ++++++ .../java/hivemall/evaluation/FMeasureUDAF.java | 428 +++++++++++++++---- .../hivemall/evaluation/FMeasureUDAFTest.java | 394 +++++++++++++++++ docs/gitbook/SUMMARY.md | 5 +- docs/gitbook/eval/auc.md | 8 +- .../eval/binary_classification_measures.md | 223 ++++++++++ .../eval/multilabel_classification_measures.md | 147 +++++++ docs/gitbook/eval/regression.md | 77 ++++ docs/gitbook/eval/stat_eval.md | 77 ---- resources/ddl/define-all-as-permanent.hive | 5 +- resources/ddl/define-all.hive | 5 +- resources/ddl/define-all.spark | 5 +- resources/ddl/define-udfs.td.hql | 3 +- 14 files changed, 1468 insertions(+), 162 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/098a7f3d/core/src/main/java/hivemall/UDAFEvaluatorWithOptions.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/UDAFEvaluatorWithOptions.java b/core/src/main/java/hivemall/UDAFEvaluatorWithOptions.java new file mode 100644 index 0000000..de1564c --- /dev/null +++ b/core/src/main/java/hivemall/UDAFEvaluatorWithOptions.java @@ -0,0 
+1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall; + +import hivemall.utils.lang.CommandLineUtils; + +import java.io.PrintWriter; +import java.io.StringWriter; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.HelpFormatter; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.MapredContext; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.mapred.Counters.Counter; +import org.apache.hadoop.mapred.Reporter; + +public abstract class UDAFEvaluatorWithOptions extends GenericUDAFEvaluator { + + @Nullable + protected MapredContext mapredContext; + + @Override + public final void configure(MapredContext mapredContext) { + this.mapredContext = mapredContext; + } + + @Nullable + protected final Reporter getReporter() { + if (mapredContext == null) { + return null; + } + return mapredContext.getReporter(); + } + + protected static void 
reportProgress(@Nullable Reporter reporter) { + if (reporter != null) { + synchronized (reporter) { + reporter.progress(); + } + } + } + + protected static void setCounterValue(@Nullable Counter counter, long value) { + if (counter != null) { + synchronized (counter) { + counter.setValue(value); + } + } + } + + protected static void incrCounter(@Nullable Counter counter, long incr) { + if (counter != null) { + synchronized (counter) { + counter.increment(incr); + } + } + } + + @Nonnull + protected abstract Options getOptions(); + + @Nonnull + protected final CommandLine parseOptions(@Nonnull String optionValue) + throws UDFArgumentException { + String[] args = optionValue.split("\\s+"); + Options opts = getOptions(); + opts.addOption("help", false, "Show function help"); + CommandLine cl = CommandLineUtils.parseOptions(args, opts); + + if (cl.hasOption("help")) { + Description funcDesc = getClass().getAnnotation(Description.class); + final String cmdLineSyntax; + if (funcDesc == null) { + cmdLineSyntax = getClass().getSimpleName(); + } else { + String funcName = funcDesc.name(); + cmdLineSyntax = funcName == null ? 
getClass().getSimpleName() + : funcDesc.value().replace("_FUNC_", funcDesc.name()); + } + StringWriter sw = new StringWriter(); + sw.write('\n'); + PrintWriter pw = new PrintWriter(sw); + HelpFormatter formatter = new HelpFormatter(); + formatter.printHelp(pw, HelpFormatter.DEFAULT_WIDTH, cmdLineSyntax, null, opts, + HelpFormatter.DEFAULT_LEFT_PAD, HelpFormatter.DEFAULT_DESC_PAD, null, true); + pw.flush(); + String helpMsg = sw.toString(); + throw new UDFArgumentException(helpMsg); + } + + return cl; + } + + protected abstract CommandLine processOptions(ObjectInspector[] argOIs) + throws UDFArgumentException; +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/098a7f3d/core/src/main/java/hivemall/evaluation/F1ScoreUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/F1ScoreUDAF.java b/core/src/main/java/hivemall/evaluation/F1ScoreUDAF.java new file mode 100644 index 0000000..ba1c44e --- /dev/null +++ b/core/src/main/java/hivemall/evaluation/F1ScoreUDAF.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.evaluation; + +import hivemall.utils.hadoop.WritableUtils; + +import java.util.List; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDAF; +import org.apache.hadoop.hive.ql.exec.UDAFEvaluator; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.io.IntWritable; + +@SuppressWarnings("deprecation") +@Description(name = "f1score", value = "_FUNC_(array[int], array[int]) - Return a F1 score") +public final class F1ScoreUDAF extends UDAF { + + public static class Evaluator implements UDAFEvaluator { + + public static class PartialResult { + long tp; + /** tp + fn */ + long totalActual; + /** tp + fp */ + long totalPredicted; + + PartialResult() { + this.tp = 0L; + this.totalPredicted = 0L; + this.totalActual = 0L; + } + + void updateScore(final List<IntWritable> actual, final List<IntWritable> predicted) { + final int numActual = actual.size(); + final int numPredicted = predicted.size(); + int countTp = 0; + for (int i = 0; i < numPredicted; i++) { + IntWritable p = predicted.get(i); + if (actual.contains(p)) { + countTp++; + } + } + this.tp += countTp; + this.totalActual += numActual; + this.totalPredicted += numPredicted; + } + + void merge(PartialResult other) { + this.tp += other.tp; + this.totalActual += other.totalActual; + this.totalPredicted += other.totalPredicted; + } + } + + private PartialResult partial; + + @Override + public void init() { + this.partial = null; + } + + public boolean iterate(List<IntWritable> actual, List<IntWritable> predicted) { + if (partial == null) { + this.partial = new PartialResult(); + } + partial.updateScore(actual, predicted); + return true; + } + + public PartialResult terminatePartial() { + return partial; + } + + public boolean merge(PartialResult other) { + if (other == null) { + return true; + } + if (partial == null) { + this.partial = new PartialResult(); + } + partial.merge(other); + return true; + } + + public 
DoubleWritable terminate() { + if (partial == null) { + return null; + } + double score = f1Score(partial); + return WritableUtils.val(score); + } + + /** + * @return 2 * precision * recall / (precision + recall) + */ + private static double f1Score(final PartialResult partial) { + double precision = precision(partial); + double recall = recall(partial); + double divisor = precision + recall; + if (divisor > 0) { + return (2.d * precision * recall) / divisor; + } else { + return -1d; + } + } + + private static double precision(final PartialResult partial) { + return (partial.totalPredicted == 0L) ? 0d : partial.tp + / (double) partial.totalPredicted; + } + + private static double recall(final PartialResult partial) { + return (partial.totalActual == 0L) ? 0d : partial.tp / (double) partial.totalActual; + } + + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/098a7f3d/core/src/main/java/hivemall/evaluation/FMeasureUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/FMeasureUDAF.java b/core/src/main/java/hivemall/evaluation/FMeasureUDAF.java index 5d41cb8..feb50b7 100644 --- a/core/src/main/java/hivemall/evaluation/FMeasureUDAF.java +++ b/core/src/main/java/hivemall/evaluation/FMeasureUDAF.java @@ -18,118 +18,396 @@ */ package hivemall.evaluation; -import hivemall.utils.hadoop.WritableUtils; +import hivemall.UDAFEvaluatorWithOptions; +import hivemall.utils.hadoop.HiveUtils; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.List; +import hivemall.utils.lang.Primitives; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; + import org.apache.hadoop.hive.ql.exec.Description; -import org.apache.hadoop.hive.ql.exec.UDAF; -import org.apache.hadoop.hive.ql.exec.UDAFEvaluator; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import 
org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.hadoop.hive.serde2.io.DoubleWritable; -import org.apache.hadoop.io.IntWritable; - -@SuppressWarnings("deprecation") -@Description(name = "f1score", - value = "_FUNC_(array[int], array[int]) - Return a F-measure/F1 score") -public final class FMeasureUDAF extends UDAF { - - public static class Evaluator implements UDAFEvaluator { - - public static class PartialResult { - long tp; - /** tp + fn */ - long totalAcutal; - /** tp + fp */ - long totalPredicted; - - PartialResult() { - this.tp = 0L; - this.totalPredicted = 0L; - this.totalAcutal = 0L; - } +import org.apache.hadoop.hive.serde2.objectinspector.*; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType; - void updateScore(final List<IntWritable> actual, final List<IntWritable> predicted) { - final int numActual = actual.size(); - final int numPredicted = predicted.size(); - int countTp = 0; - for (int i = 0; i < numPredicted; i++) { - IntWritable p = predicted.get(i); - if (actual.contains(p)) { - countTp++; - } +import javax.annotation.Nonnull; + +@Description( + name = "fmeasure", + value = "_FUNC_(array | int | boolean actual , array | int | 
boolean predicted, String) - Return a F-measure (f1score is the special with beta=1.)") +public final class FMeasureUDAF extends AbstractGenericUDAFResolver { + @Override + public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException { + if (typeInfo.length != 2 && typeInfo.length != 3) { + throw new UDFArgumentTypeException(typeInfo.length - 1, + "_FUNC_ takes two or three arguments"); + } + + boolean isArg1ListOrIntOrBoolean = HiveUtils.isListTypeInfo(typeInfo[0]) + || HiveUtils.isIntegerTypeInfo(typeInfo[0]) + || HiveUtils.isBooleanTypeInfo(typeInfo[0]); + if (!isArg1ListOrIntOrBoolean) { + throw new UDFArgumentTypeException(0, + "The first argument `array/int/boolean actual` is invalid form: " + typeInfo[0]); + } + + boolean isArg2ListOrIntOrBoolean = HiveUtils.isListTypeInfo(typeInfo[1]) + || HiveUtils.isIntegerTypeInfo(typeInfo[1]) + || HiveUtils.isBooleanTypeInfo(typeInfo[1]); + if (!isArg2ListOrIntOrBoolean) { + throw new UDFArgumentTypeException(1, + "The second argument `array/int/boolean predicted` is invalid form: " + typeInfo[1]); + } + + if (typeInfo[0] != typeInfo[1]) { + throw new UDFArgumentTypeException(1, "The first argument `actual`'s type is " + + typeInfo[0] + ", but the second argument `predicted`'s type is not match: " + + typeInfo[1]); + } + + return new Evaluator(); + } + + public static class Evaluator extends UDAFEvaluatorWithOptions { + + private ObjectInspector actualOI; + private ObjectInspector predictedOI; + private StructObjectInspector internalMergeOI; + + private StructField tpField; + private StructField totalActualField; + private StructField totalPredictedField; + private StructField betaOptionField; + private StructField averageOptionFiled; + + private double beta; + private String average; + + public Evaluator() {} + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("beta", true, "The weight of precision [default: 1.]"); + 
opts.addOption("average", true, "The way of average calculation [default: micro]"); + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = null; + + double beta = 1.0d; + String average = "micro"; + + if (argOIs.length >= 3) { + String rawArgs = HiveUtils.getConstString(argOIs[2]); + cl = parseOptions(rawArgs); + + beta = Primitives.parseDouble(cl.getOptionValue("beta"), beta); + if (beta <= 0.d) { + throw new UDFArgumentException( + "The third argument `double beta` must be greater than 0.0: " + beta); + } + + average = cl.getOptionValue("average", average); + + if (average.equals("macro")) { + throw new UDFArgumentException("\"-average macro\" is not supported"); } - this.tp += countTp; - this.totalAcutal += numActual; - this.totalPredicted += numPredicted; + + if (!(average.equals("binary") || average.equals("micro"))) { + throw new UDFArgumentException( + "The third argument `String average` must be one of the {binary, micro, macro}: " + + average); + } + } + + this.beta = beta; + this.average = average; + return cl; + } + + @Override + public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException { + assert (parameters.length == 2 || parameters.length == 3) : parameters.length; + super.init(mode, parameters); + + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data + this.processOptions(parameters); + this.actualOI = parameters[0]; + this.predictedOI = parameters[1]; + } else {// from partial aggregation + StructObjectInspector soi = (StructObjectInspector) parameters[0]; + this.internalMergeOI = soi; + this.tpField = soi.getStructFieldRef("tp"); + this.totalActualField = soi.getStructFieldRef("totalActual"); + this.totalPredictedField = soi.getStructFieldRef("totalPredicted"); + this.betaOptionField = soi.getStructFieldRef("beta"); + this.averageOptionFiled = soi.getStructFieldRef("average"); } - 
void merge(PartialResult other) { - this.tp = other.tp; - this.totalAcutal = other.totalAcutal; - this.totalPredicted = other.totalPredicted; + // initialize output + final ObjectInspector outputOI; + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial + outputOI = internalMergeOI(); + } else {// terminate + outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; } + return outputOI; } - private PartialResult partial; + private static StructObjectInspector internalMergeOI() { + ArrayList<String> fieldNames = new ArrayList<>(); + ArrayList<ObjectInspector> fieldOIs = new ArrayList<>(); + + fieldNames.add("tp"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldNames.add("totalActual"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldNames.add("totalPredicted"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); + fieldNames.add("beta"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldNames.add("average"); + fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } @Override - public void init() { - this.partial = null; + public FMeasureAggregationBuffer getNewAggregationBuffer() throws HiveException { + FMeasureAggregationBuffer myAggr = new FMeasureAggregationBuffer(); + reset(myAggr); + return myAggr; } - public boolean iterate(List<IntWritable> actual, List<IntWritable> predicted) { - if (partial == null) { - this.partial = new PartialResult(); + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg; + myAggr.reset(); + myAggr.setOptions(this.beta, this.average); + } + + @Override + public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, + 
Object[] parameters) throws HiveException { + FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg; + boolean isList = HiveUtils.isListOI(actualOI) && HiveUtils.isListOI(predictedOI); + + final List<?> actual; + final List<?> predicted; + + if (isList) {// array case + if (this.average.equals("binary")) { + throw new UDFArgumentException( + "\"-average binary\" is not supported when `predict` is array"); + } + actual = ((ListObjectInspector) actualOI).getList(parameters[0]); + predicted = ((ListObjectInspector) predictedOI).getList(parameters[1]); + } else {//binary case + if (HiveUtils.isBooleanOI(actualOI)) { // boolean case + actual = Arrays.asList(asIntLabel(parameters[0], + (BooleanObjectInspector) actualOI)); + predicted = Arrays.asList(asIntLabel(parameters[1], + (BooleanObjectInspector) predictedOI)); + } else { // int case + int actualLabel = asIntLabel(parameters[0], (IntObjectInspector) actualOI); + + if (actualLabel == 0 && this.average.equals("binary")) { + actual = Collections.emptyList(); + } else { + actual = Arrays.asList(actualLabel); + } + + int predictedLabel = asIntLabel(parameters[1], (IntObjectInspector) predictedOI); + if (predictedLabel == 0 && this.average.equals("binary")) { + predicted = Collections.emptyList(); + } else { + predicted = Arrays.asList(predictedLabel); + } + } } - partial.updateScore(actual, predicted); - return true; + myAggr.iterate(actual, predicted); } - public PartialResult terminatePartial() { - return partial; + private int asIntLabel(@Nonnull Object o, @Nonnull BooleanObjectInspector booleanOI) { + if (booleanOI.get(o)) { + return 1; + } else { + return 0; + } } - public boolean merge(PartialResult other) { - if (other == null) { - return true; + private int asIntLabel(@Nonnull Object o, @Nonnull IntObjectInspector intOI) + throws HiveException { + int value = intOI.get(o); + if (!(value == 1 || value == 0 || value == -1)) { + throw new UDFArgumentException("Int label must be 1, 0 or -1: " + value); } 
- if (partial == null) { - this.partial = new PartialResult(); + if (value == -1) { + value = 0; } - partial.merge(other); - return true; + return value; } - public DoubleWritable terminate() { + + @Override + public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg; + + Object[] partialResult = new Object[5]; + partialResult[0] = new LongWritable(myAggr.tp); + partialResult[1] = new LongWritable(myAggr.totalActual); + partialResult[2] = new LongWritable(myAggr.totalPredicted); + partialResult[3] = new DoubleWritable(myAggr.beta); + partialResult[4] = myAggr.average; + return partialResult; + } + + @Override + public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial) + throws HiveException { if (partial == null) { - return null; + return; } - double score = f1Score(partial); - return WritableUtils.val(score); + + Object tpObj = internalMergeOI.getStructFieldData(partial, tpField); + Object totalActualObj = internalMergeOI.getStructFieldData(partial, totalActualField); + Object totalPredictedObj = internalMergeOI.getStructFieldData(partial, + totalPredictedField); + Object betaObj = internalMergeOI.getStructFieldData(partial, betaOptionField); + Object averageObj = internalMergeOI.getStructFieldData(partial, averageOptionFiled); + long tp = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(tpObj); + long totalActual = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(totalActualObj); + long totalPredicted = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(totalPredictedObj); + double beta = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(betaObj); + String average = PrimitiveObjectInspectorFactory.writableStringObjectInspector.getPrimitiveJavaObject(averageObj); + + FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg; + myAggr.merge(tp, 
totalActual, totalPredicted, beta, average); } - /** - * @return 2 * precision * recall / (precision + recall) - */ - private static double f1Score(final PartialResult partial) { - double precision = precision(partial); - double recall = recall(partial); - double divisor = precision + recall; + @Override + public DoubleWritable terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg; + double result = myAggr.get(); + return new DoubleWritable(result); + } + } + + @AggregationType(estimable = true) + public static class FMeasureAggregationBuffer extends AbstractAggregationBuffer { + long tp; + /** tp + fn */ + long totalActual; + /** tp + fp */ + long totalPredicted; + double beta; + String average; + + public FMeasureAggregationBuffer() { + super(); + } + + @Override + public int estimate() { + JavaDataModel model = JavaDataModel.get(); + return model.primitive2() * 4 + model.lengthFor(average); + } + + void setOptions(double beta, String average) { + this.beta = beta; + this.average = average; + } + + void reset() { + this.tp = 0L; + this.totalActual = 0L; + this.totalPredicted = 0L; + } + + void merge(long o_tp, long o_actual, long o_predicted, double beta, String average) { + tp += o_tp; + totalActual += o_actual; + totalPredicted += o_predicted; + this.beta = beta; + this.average = average; + } + + double get() { + double squareBeta = beta * beta; + double divisor; + double numerator; + + if (average.equals("micro")) { + divisor = denom(tp, totalActual, totalPredicted, squareBeta); + numerator = (1.d + squareBeta) * tp; + } else { // binary + double precision = precision(tp, totalPredicted); + double recall = recall(tp, totalActual); + divisor = squareBeta * precision + recall; + numerator = (1.d + squareBeta) * precision * recall; + } + if (divisor > 0) { - return (2.d * precision * recall) / divisor; + return (numerator / divisor); } else { - return -1d; + 
return 0.d; } } - private static double precision(final PartialResult partial) { - return (partial.totalPredicted == 0L) ? 0d : partial.tp - / (double) partial.totalPredicted; + private static double denom(long tp, long totalActual, long totalPredicted, + double squareBeta) { + long lp = totalActual - tp; + long pl = totalPredicted - tp; + + return squareBeta * (tp + lp) + tp + pl; } - private static double recall(final PartialResult partial) { - return (partial.totalAcutal == 0L) ? 0d : partial.tp / (double) partial.totalAcutal; + private static double precision(long tp, long totalPredicted) { + return (totalPredicted == 0L) ? 0d : tp / (double) totalPredicted; } + private static double recall(long tp, long totalActual) { + return (totalActual == 0L) ? 0d : tp / (double) totalActual; + } + + void iterate(@Nonnull List<?> actual, @Nonnull List<?> predicted) { + final int numActual = actual.size(); + final int numPredicted = predicted.size(); + int countTp = 0; + + for (Object p : predicted) { + if (actual.contains(p)) { + countTp++; + } + } + this.tp += countTp; + this.totalActual += numActual; + this.totalPredicted += numPredicted; + } } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/098a7f3d/core/src/test/java/hivemall/evaluation/FMeasureUDAFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/evaluation/FMeasureUDAFTest.java b/core/src/test/java/hivemall/evaluation/FMeasureUDAFTest.java new file mode 100644 index 0000000..3974c3d --- /dev/null +++ b/core/src/test/java/hivemall/evaluation/FMeasureUDAFTest.java @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.evaluation; + +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.SimpleGenericUDAFParameterInfo; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + + +public class FMeasureUDAFTest { + FMeasureUDAF fmeasure; + GenericUDAFEvaluator evaluator; + ObjectInspector[] inputOIs; + FMeasureUDAF.FMeasureAggregationBuffer agg; + + @Before + public void setUp() throws Exception { + fmeasure = new FMeasureUDAF(); + inputOIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector), + ObjectInspectorUtils.getConstantObjectInspector( + 
PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-beta 1.")}; + + evaluator = fmeasure.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false)); + + agg = (FMeasureUDAF.FMeasureAggregationBuffer) evaluator.getNewAggregationBuffer(); + } + + private void setUpWithArguments(double beta, String average) throws Exception { + fmeasure = new FMeasureUDAF(); + inputOIs = new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-beta " + beta + + " -average " + average)}; + + evaluator = fmeasure.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false)); + agg = (FMeasureUDAF.FMeasureAggregationBuffer) evaluator.getNewAggregationBuffer(); + } + + private void binarySetUp(Object actual, Object predicted, double beta, String average) + throws Exception { + fmeasure = new FMeasureUDAF(); + inputOIs = new ObjectInspector[3]; + + String actualClassName = actual.getClass().getName(); + if (actualClassName.equals("java.lang.Integer")) { + inputOIs[0] = PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT); + } else if (actualClassName.equals("java.lang.Boolean")) { + inputOIs[0] = PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.BOOLEAN); + } else if ((actualClassName.equals("java.lang.String"))) { + inputOIs[0] = PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING); + } + + String predicatedClassName = predicted.getClass().getName(); + if (predicatedClassName.equals("java.lang.Integer")) { + inputOIs[1] = 
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.INT); + } else if (predicatedClassName.equals("java.lang.Boolean")) { + inputOIs[1] = PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.BOOLEAN); + } else if ((predicatedClassName.equals("java.lang.String"))) { + inputOIs[1] = PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(PrimitiveObjectInspector.PrimitiveCategory.STRING); + } + + inputOIs[2] = ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-beta " + beta + + " -average " + average); + + evaluator = fmeasure.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false)); + agg = (FMeasureUDAF.FMeasureAggregationBuffer) evaluator.getNewAggregationBuffer(); + } + + @Test + public void testBinaryMultiSamplesAverageBinary() throws Exception { + final int[] actual = {0, 1, 0, 0, 0, 1, 0, 0}; + final int[] predicted = {1, 0, 0, 1, 0, 1, 0, 1}; + double beta = 1.; + String average = "binary"; + binarySetUp(actual[0], predicted[0], beta, average); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < actual.length; i++) { + evaluator.iterate(agg, new Object[] {actual[i], predicted[i]}); + } + + // should equal to turi's result + // https://turi.com/learn/userguide/evaluation/classification.html#fscores-f1-fbeta- + Assert.assertEquals(0.3333d, agg.get(), 1e-4); + } + + @Test(expected = HiveException.class) + public void testBinaryMultiSamplesAverageMacro() throws Exception { + final int[] actual = {0, 1, 0, 0, 0, 1, 0, 0}; + final int[] predicted = {1, 0, 0, 1, 0, 1, 0, 1}; + double beta = 1.; + String average = "macro"; + binarySetUp(actual[0], predicted[0], beta, average); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < actual.length; i++) { + 
evaluator.iterate(agg, new Object[] {actual[i], predicted[i]}); + } + + agg.get(); + } + + @Test + public void testBinaryMultiSamples() throws Exception { + final int[] actual = {0, 1, 0, 0, 0, 1, 0, 0}; + final int[] predicted = {1, 0, 0, 1, 0, 1, 0, 1}; + double beta = 1.; + String average = "micro"; + binarySetUp(actual[0], predicted[0], beta, average); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < actual.length; i++) { + evaluator.iterate(agg, new Object[] {actual[i], predicted[i]}); + } + + Assert.assertEquals(0.5d, agg.get(), 1e-4); + } + + @Test + public void testBinaryMultiSamplesBeta2() throws Exception { + final int[] actual = {0, 1, 0, 0, 0, 1, 0, 0}; + final int[] predicted = {1, 0, 0, 1, 0, 1, 0, 1}; + double beta = 2.0; + String average = "binary"; + binarySetUp(actual[0], predicted[0], beta, average); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < actual.length; i++) { + evaluator.iterate(agg, new Object[] {actual[i], predicted[i]}); + } + + Assert.assertEquals(0.4166d, agg.get(), 1e-4); + } + + @Test + public void testBinary() throws Exception { + int actual = 1; + int predicted = 1; + double beta = 1.0; + String average = "micro"; + binarySetUp(actual, predicted, beta, average); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + evaluator.iterate(agg, new Object[] {actual, predicted}); + + Assert.assertEquals(1.d, agg.get(), 1e-4); + } + + @Test + public void testBinaryNegativeInput() throws Exception { + int actual = 1; + int predicted = -1; + binarySetUp(actual, predicted, 1.0, "binary"); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + evaluator.iterate(agg, new Object[] {actual, predicted}); + + Assert.assertEquals(0.d, agg.get(), 1e-4); + } + + @Test + public void testBinaryBooleanInput() throws Exception { + boolean actual 
= true; + boolean predicted = false; + double beta = 1.0d; + binarySetUp(actual, predicted, beta, "binary"); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + evaluator.iterate(agg, new Object[] {actual, predicted}); + + Assert.assertEquals(0.d, agg.get(), 1e-4); + } + + @Test(expected = HiveException.class) + public void testBinaryInvalidStringInput() throws Exception { + String actual = "cat"; + int predicted = 1; + binarySetUp(actual, predicted, 1.0, "micro"); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + evaluator.iterate(agg, new Object[] {actual, predicted}); + + agg.get(); + } + + @Test(expected = HiveException.class) + public void testBinaryInvalidLargeIntInput() throws Exception { + int actual = 1; + int predicted = 3; + binarySetUp(actual, predicted, 1.0, "micro"); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + evaluator.iterate(agg, new Object[] {actual, predicted}); + + agg.get(); + } + + @Test(expected = HiveException.class) + public void testMultiLabelZeroBeta() throws Exception { + List<Integer> actual = Arrays.asList(1, 3, 2, 6); + List<Integer> predicted = Arrays.asList(1, 2, 4); + double beta = 0.; + setUpWithArguments(beta, "micro"); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + evaluator.iterate(agg, new Object[] {actual, predicted}); + + // FMeasure for beta has zero value is not defined + agg.get(); + } + + @Test(expected = HiveException.class) + public void testMultiLabelNegativeBeta() throws Exception { + List<Integer> actual = Arrays.asList(1, 3, 2, 6); + List<Integer> predicted = Arrays.asList(1, 2, 4); + double beta = -1.0d; + String average = "micro"; + setUpWithArguments(beta, average); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + evaluator.iterate(agg, new Object[] {actual, predicted}); + + // FMeasure 
for beta has negative value is not defined + agg.get(); + } + + @Test + public void testMultiLabelF1score() throws Exception { + List<Integer> actual = Arrays.asList(1, 3, 2, 6); + List<Integer> predicted = Arrays.asList(1, 2, 4); + double beta = 1.0; + String average = " micro"; + setUpWithArguments(beta, average); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + evaluator.iterate(agg, new Object[] {actual, predicted}); + + // should equal to spark's micro f1 measure result + // https://spark.apache.org/docs/latest/mllib-evaluation-metrics.html#multilabel-classification + Assert.assertEquals(0.5714285714285714, agg.get(), 1e-5); + } + + @Test + public void testMultiLabelMaxFMeasure() throws Exception { + List<Integer> actual = Arrays.asList(1, 2, 3); + List<Integer> predicted = Arrays.asList(1, 2, 3); + double beta = 1.0; + String average = "micro"; + setUpWithArguments(beta, average); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + evaluator.iterate(agg, new Object[] {actual, predicted}); + + Assert.assertEquals(1.d, agg.get(), 1e-5); + } + + @Test + public void testMultiLabelMinFMeasure() throws Exception { + List<Integer> actual = Arrays.asList(0, 0, 0); + List<Integer> predicted = Arrays.asList(1, 2, 3); + double beta = 1.0; + String average = "micro"; + setUpWithArguments(beta, average); + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + evaluator.iterate(agg, new Object[] {actual, predicted}); + + Assert.assertEquals(0.d, agg.get(), 1e-5); + } + + @Test + public void testMultiLabelF1MultiSamples() throws Exception { + String[][] actual = { {"0", "2"}, {"0", "1"}, {"0"}, {"2"}, {"2", "0"}, {"0", "1"}, + {"1", "2"}}; + String[][] predicted = { {"0", "1"}, {"0", "2"}, {}, {"2"}, {"2", "0"}, {"0", "1", "2"}, + {"1"}}; + + double beta = 1.0; + String average = "micro"; + setUpWithArguments(beta, average); + + 
evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < actual.length; i++) { + evaluator.iterate(agg, new Object[] {actual[i], predicted[i]}); + } + + // should equal to spark's micro f1 measure result + // https://spark.apache.org/docs/latest/mllib-evaluation-metrics.html#multilabel-classification + Assert.assertEquals(0.6956d, agg.get(), 1e-4); + } + + @Test + public void testMultiLabelFmeasureMultiSamples() throws Exception { + String[][] actual = { {"0", "2"}, {"0", "1"}, {"0"}, {"2"}, {"2", "0"}, {"0", "1"}, + {"1", "2"}}; + String[][] predicted = { {"0", "1"}, {"0", "2"}, {}, {"2"}, {"2", "0"}, {"0", "1", "2"}, + {"1"}}; + + double beta = 2.0; + String average = "micro"; + setUpWithArguments(beta, average); + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < actual.length; i++) { + evaluator.iterate(agg, new Object[] {actual[i], predicted[i]}); + } + + Assert.assertEquals(0.6779d, agg.get(), 1e-4); + } + + @Test(expected = HiveException.class) + public void testMultiLabelFmeasureBinary() throws Exception { + String[][] actual = { {"0", "2"}, {"0", "1"}, {"0"}, {"2"}, {"2", "0"}, {"0", "1"}, + {"1", "2"}}; + String[][] predicted = { {"0", "1"}, {"0", "2"}, {}, {"2"}, {"2", "0"}, {"0", "1", "2"}, + {"1"}}; + + double beta = 1.0; + String average = "binary"; + + setUpWithArguments(beta, average); + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < actual.length; i++) { + evaluator.iterate(agg, new Object[] {actual[i], predicted[i]}); + } + + agg.get(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/098a7f3d/docs/gitbook/SUMMARY.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md index f5ab81e..3d640f8 100644 --- a/docs/gitbook/SUMMARY.md +++ b/docs/gitbook/SUMMARY.md @@ -66,9 +66,10 
@@ ## Part IV - Evaluation -* [Statistical evaluation of a prediction model](eval/stat_eval.md) +* [Binary Classification Metrics](eval/binary_classification_measures.md) * [Area Under the ROC Curve](eval/auc.md) - +* [Multi-label Classification Metrics](eval/multilabel_classification_measures.md) +* [Regression Metrics](eval/regression.md) * [Ranking Measures](eval/rank.md) * [Data Generation](eval/datagen.md) http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/098a7f3d/docs/gitbook/eval/auc.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/eval/auc.md b/docs/gitbook/eval/auc.md index 8cad8f6..b8f7f0b 100644 --- a/docs/gitbook/eval/auc.md +++ b/docs/gitbook/eval/auc.md @@ -57,7 +57,7 @@ with data as ( union all select 0.7 as prob, 1 as label ) -select +select auc(prob, label) as auc from ( select prob, label @@ -72,7 +72,7 @@ Since AUC is a metric based on ranked probability-label pairs as mentioned above ## Parallel approximate AUC computation -Meanwhile, Hive's `distribute by` clause allows you to compute AUC in parallel: +Meanwhile, Hive's `distribute by` clause allows you to compute AUC in parallel: ```sql with data as ( @@ -86,7 +86,7 @@ with data as ( union all select 0.7 as prob, 1 as label ) -select +select auc(prob, label) as auc from ( select prob, label @@ -100,7 +100,7 @@ Note that `floor(prob / 0.2)` means that the rows are distributed to 5 bins for # Difference between AUC and Logarithmic Loss -Hivemall has another metric called [Logarithmic Loss](stat_eval.html#logarithmic-loss) for binary classification. Both AUC and Logarithmic Loss compute scores for probability-label pairs. +Hivemall has another metric called [Logarithmic Loss](regression.html#logarithmic-loss) for binary classification. Both AUC and Logarithmic Loss compute scores for probability-label pairs. Score produced by AUC is a relative metric based on sorted pairs.
On the other hand, Logarithmic Loss simply gives a metric by comparing probability with its truth label one-by-one. http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/098a7f3d/docs/gitbook/eval/binary_classification_measures.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/eval/binary_classification_measures.md b/docs/gitbook/eval/binary_classification_measures.md new file mode 100644 index 0000000..5121ffe --- /dev/null +++ b/docs/gitbook/eval/binary_classification_measures.md @@ -0,0 +1,223 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +<!-- toc --> + +# Binary problems + +A binary classification problem is the task of predicting the label of each sample from two categories. + +Hivemall provides some tutorials on binary classification problems as follows: + +- [Online advertisement click prediction](../binaryclass/general.html) +- [News classification](../binaryclass/news20_dataset.html) + +This page focuses on evaluating the results of such binary classification problems. +If your classifier outputs probabilities rather than 0/1 labels, evaluation based on [Area Under the ROC Curve](./auc.md) would be more appropriate.
+ + +# Example + +For the metrics explanation, this page introduces toy example data and two metrics. + +## Data + +The following table shows a sample of binary classification predictions. +In this case, `1` means the positive label and `0` means the negative label. +The left column contains the truth (supervised) labels, +and the center column contains the labels predicted by a binary classifier. + +| truth label| predicted label | | +|:---:|:---:|:---:| +| 1 | 0 |False Negative| +| 0 | 1 |False Positive| +| 0 | 0 |True Negative| +| 1 | 1 |True Positive| +| 0 | 1 |False Positive| +| 0 | 0 |True Negative| + +## Preliminary metrics + +Some evaluation metrics are calculated based on 4 values: + +- True Positive (TP): truth label is positive and predicted label is also positive +- True Negative (TN): truth label is negative and predicted label is also negative +- False Positive (FP): truth label is negative but predicted label is positive +- False Negative (FN): truth label is positive but predicted label is negative + +`TP` and `TN` represent correct classifications, and `FP` and `FN` represent incorrect ones. + +In this example, we can obtain those values: + +- TP: 1 +- TN: 2 +- FP: 2 +- FN: 1 + +If you want to know more about those metrics, Wikipedia provides [more detailed information](https://en.wikipedia.org/wiki/Sensitivity_and_specificity). + +### Recall + +Recall is the fraction of truth positive labels that are correctly predicted as positive (i.e., the true positive rate). +The value is computed by the following equation: + +$$ +\mathrm{recall} = \frac{\mathrm{\#TP}}{\mathrm{\#TP} + \mathrm{\#FN}} +$$ + +In the previous example, $$\mathrm{recall} = \frac{1}{2}$$. + +### Precision + +Precision is the fraction of positive predictions whose truth label is also positive. +The value is computed by the following equation: + +$$ +\mathrm{precision} = \frac{\mathrm{\#TP}}{\mathrm{\#TP} + \mathrm{\#FP}} +$$ + +In the previous example, $$\mathrm{precision} = \frac{1}{3}$$. + +# Metrics + +To run the metric examples below, please create the following table.
+ +```sql +create table data as + select 1 as truth, 0 as predicted +union all + select 0 as truth, 1 as predicted +union all + select 0 as truth, 0 as predicted +union all + select 1 as truth, 1 as predicted +union all + select 0 as truth, 1 as predicted +union all + select 0 as truth, 0 as predicted +; +``` + +## F1-score + +F1-score is the harmonic mean of recall and precision. +F1-score is computed by the following equation: + +$$ +\mathrm{F}_1 = 2 \frac{\mathrm{precision} * \mathrm{recall}}{\mathrm{precision} + \mathrm{recall}} +$$ + +Hivemall's `fmeasure` function lets you choose between `micro` (default) and `binary` averaging through the `-average` option. + + +> #### Caution +> Hivemall also provides the `f1score` function, but it is a legacy function for obtaining F1-score. The value of `f1score` is based on set operations, so we recommend using the `fmeasure` function to get the F1-score described in this article. + +You can learn more about this from the following external resource: + +- [scikit-learn's F1-score](http://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html) + + +### Micro average + +If `micro` is passed to `average`, +recall and precision are modified to also count True Negatives, i.e., the negative label is treated as a second class. +Micro F1-score is then calculated from those modified recall and precision, which both reduce to accuracy in the binary case. + +$$ +\mathrm{recall} = \frac{\mathrm{\#TP} + \mathrm{\#TN}}{\mathrm{\#TP} + \mathrm{\#FN} + \mathrm{\#TN} + \mathrm{\#FP}} +$$ + +$$ +\mathrm{precision} = \frac{\mathrm{\#TP} + \mathrm{\#TN}}{\mathrm{\#TP} + \mathrm{\#FP} + \mathrm{\#TN} + \mathrm{\#FN}} +$$ + +If the `average` argument is omitted, `fmeasure` uses the default value: `'-average micro'`. + +The following query shows an example of obtaining F1-score. +Both `truth` and `predicted` should have the same type (`int` or `boolean`). +If the type is `int`, `1` is treated as the positive label, and `-1` or `0` is treated as the negative label.
+ + +```sql +select fmeasure(truth, predicted, '-average micro') from data; +``` + +> 0.5 + + +It should be noted that, since the old `f1score(truth, predicted)` function simply counts the number of "matched" elements between `truth` and `predicted`, the above query is equivalent to: + + +```sql +select f1score(array(truth), array(predicted)) from data; +``` + +### Binary average + +If `binary` is passed to `average`, `True Negative` samples are ignored when computing F1-score. + +The following query shows an example of obtaining F1-score with binary average. +```sql +select fmeasure(truth, predicted, '-average binary') from data; +``` + +> 0.4 + + +## F-measure + +F-measure is a generalization of F1-score: the weighted harmonic mean of recall and precision. +F-measure is computed by the following equation: + +$$ +\mathrm{F}_{\beta} = (1+\beta^2) \frac{\mathrm{precision} * \mathrm{recall}}{\beta^2 \mathrm{precision} + \mathrm{recall}} +$$ + +$$\beta$$ is the parameter that determines how much weight recall gets relative to precision. +F1-score is the special case of F-measure with $$\beta=1$$. + +As $$\beta$$ grows larger than `1.0`, F-measure approaches recall. +On the other hand, +as $$\beta$$ shrinks toward `0` (below `1.0`), F-measure approaches precision. + +If $$\beta$$ is omitted, Hivemall calculates F-measure with $$\beta=1$$ (i.e., F1-score). + +Hivemall's `fmeasure` function also lets you choose between `micro` (default) and `binary` averaging through the `-average` option. + + +The following query shows an example of obtaining F-measure with $$\beta=2$$ and micro average. + +```sql +select fmeasure(truth, predicted, '-beta 2. -average micro') from data; +``` + +> 0.5 + +The following query shows an example of obtaining F-measure with $$\beta=2$$ and binary average. + +```sql +select fmeasure(truth, predicted, '-beta 2.
-average binary') from data; +``` + +> 0.45454545454545453 + +You can learn more about this from the following external resource: + +- [scikit-learn's FMeasure](http://scikit-learn.org/stable/modules/generated/sklearn.metrics.fbeta_score.html) http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/098a7f3d/docs/gitbook/eval/multilabel_classification_measures.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/eval/multilabel_classification_measures.md b/docs/gitbook/eval/multilabel_classification_measures.md new file mode 100644 index 0000000..fb2d6c0 --- /dev/null +++ b/docs/gitbook/eval/multilabel_classification_measures.md @@ -0,0 +1,147 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +<!-- toc --> + +# Multi-label classification + + +A multi-label classification problem is the task of predicting a set of labels for each sample. +Each sample $$i$$ has a label set $$l_i \subseteq L$$, where $$L$$ is the set of unique labels in the dataset, so $$0 \leq |l_i| \leq |L|$$. + +This page focuses on evaluation of the results from such multi-label classification problems. + +# Example + +For the metrics explanation, this page introduces a toy example dataset.
+ +## Data + +The following table shows a sample of multi-label classification predictions. +Animal names represent the tags of blog posts. +The left column contains the truth (supervised) labels, +and the right column contains the labels predicted by a multi-label classifier. + +| truth labels| predicted labels | +|:---:|:---:| +| cat, bird | cat, dog| +| cat, dog | cat, bird| +| cat | (*no predicted label*)| +| bird | bird | +| bird, cat | bird, cat| +| cat, dog | cat, dog, bird | +| dog, bird | dog | + + +# Evaluation metrics for multi-label classification + +Hivemall provides micro F1-score and micro F-measure. + +Let $$L$$ be the set of tags of blog posts, and +$$l_i$$ be the truth tag set of the $$i$$th document. +In the same manner, +let $$p_i$$ be the predicted tag set of the $$i$$th document. + +## Micro F1-score + +F1-score is the harmonic mean of recall and precision. + +The value is computed by the following equation: + +$$ +\mathrm{F}_1 = 2 \frac +{\sum_i |l_i \cap p_i |} +{ 2 \sum_i |l_i \cap p_i | + \sum_i |l_i - p_i| + \sum_i |p_i - l_i| } +$$ + +> #### Caution +> Hivemall also provides the `f1score` function, but it is a legacy function for obtaining F1-score. The value of `f1score` is based on set operations, so we recommend using the `fmeasure` function to get the F1-score described in this article. + +The following query shows an example of obtaining F1-score.
+ +```sql +WITH data as ( + select array("cat", "bird") as actual, array("cat", "dog") as predicted +union all + select array("cat", "dog") as actual, array("cat", "bird") as predicted +union all + select array("cat") as actual, array() as predicted +union all + select array("bird") as actual, array("bird") as predicted +union all + select array("bird", "cat") as actual, array("bird", "cat") as predicted +union all + select array("cat", "dog") as actual, array("cat", "dog", "bird") as predicted +union all + select array("dog", "bird") as actual, array("dog") as predicted +) +select + fmeasure(actual, predicted) +from data +; +``` + +> 0.6956521739130435 + +## Micro F-measure + + +F-measure is a generalization of F1-score: the weighted harmonic mean of recall and precision. + +The value is computed by the following equation: +$$ +\mathrm{F}_{\beta} = (1+\beta^2) \frac +{\sum_i |l_i \cap p_i |} +{ \beta^2 (\sum_i |l_i \cap p_i | + \sum_i |l_i - p_i|) + \sum_i |l_i \cap p_i | + \sum_i |p_i - l_i|} +$$ + +$$\beta$$ is the parameter that determines how much weight recall gets relative to precision. +F1-score is the special case of F-measure with $$\beta=1$$. + +As $$\beta$$ grows larger than `1.0`, F-measure approaches micro recall. +On the other hand, +as $$\beta$$ shrinks toward `0` (below `1.0`), F-measure approaches micro precision. + +If $$\beta$$ is omitted, Hivemall calculates F-measure with $$\beta=1$$ (i.e., F1-score). + +The following query shows an example of obtaining F-measure with $$\beta=2$$.
+ +```sql +WITH data as ( + select array("cat", "bird") as actual, array("cat", "dog") as predicted +union all + select array("cat", "dog") as actual, array("cat", "bird") as predicted +union all + select array("cat") as actual, array() as predicted +union all + select array("bird") as actual, array("bird") as predicted +union all + select array("bird", "cat") as actual, array("bird", "cat") as predicted +union all + select array("cat", "dog") as actual, array("cat", "dog", "bird") as predicted +union all + select array("dog", "bird") as actual, array("dog") as predicted +) +select + fmeasure(actual, predicted, '-beta 2.') +from data +; +``` + +> 0.6779661016949152 http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/098a7f3d/docs/gitbook/eval/regression.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/eval/regression.md b/docs/gitbook/eval/regression.md new file mode 100644 index 0000000..9a7345e --- /dev/null +++ b/docs/gitbook/eval/regression.md @@ -0,0 +1,77 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +Using the [E2006 tfidf regression example](../regression/e2006_arow.html), we explain how to evaluate the prediction model on Hive. 
+ +<!-- toc --> + +# Scoring by evaluation metrics + +```sql +select avg(actual), avg(predicted) from e2006tfidf_pa2a_submit; +``` +> -3.8200363760415414 -3.9124877451612488 + +```sql +set hivevar:mean_actual=-3.8200363760415414; + +select +-- Root Mean Squared Error + rmse(predicted, actual) as RMSE, + -- sqrt(sum(pow(predicted - actual,2.0))/count(1)) as RMSE, +-- Mean Squared Error + mse(predicted, actual) as MSE, + -- sum(pow(predicted - actual,2.0))/count(1) as MSE, +-- Mean Absolute Error + mae(predicted, actual) as MAE, + -- sum(abs(predicted - actual))/count(1) as MAE, +-- coefficient of determination (R^2) + -- 1 - sum(pow(actual - predicted,2.0)) / sum(pow(actual - ${mean_actual},2.0)) as R2 + r2(actual, predicted) as R2 -- supported since Hivemall v0.4.1-alpha.5 +from + e2006tfidf_pa2a_submit; +``` +> 0.38538660838804495 0.14852283792484033 0.2466732002711477 0.48623913673053565 + +# Logarithmic Loss + +[Logarithmic Loss](https://www.kaggle.com/wiki/LogarithmicLoss) can be computed as follows: + +```sql +WITH t as ( + select + 0 as actual, + 0.01 as predicted + union all + select + 1 as actual, + 0.02 as predicted +) +select + -SUM(actual*LN(predicted)+(1-actual)*LN(1-predicted))/count(1) as logloss1, + logloss(predicted, actual) as logloss2 -- supported since Hivemall v0.4.2-rc.1 +from +  t; +``` +> 1.9610366706408238 1.9610366706408238 + +# References + +* R2 http://en.wikipedia.org/wiki/Coefficient_of_determination +* Evaluation Metrics https://www.kaggle.com/wiki/Metrics http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/098a7f3d/docs/gitbook/eval/stat_eval.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/eval/stat_eval.md b/docs/gitbook/eval/stat_eval.md deleted file mode 100644 index 149adf8..0000000 --- a/docs/gitbook/eval/stat_eval.md +++ /dev/null @@ -1,77 +0,0 @@ -<!-- - Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements.
See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. ---> - -Using the [E2006 tfidf regression example](../regression/e2006_arow.html), we explain how to evaluate the prediction model on Hive. - -<!-- toc --> - -# Scoring by evaluation metrics - -```sql -select avg(actual), avg(predicted) from e2006tfidf_pa2a_submit; -``` -> -3.8200363760415414 -3.9124877451612488 - -```sql -set hivevar:mean_actual=-3.8200363760415414; - -select --- Root Mean Squared Error - rmse(predicted, actual) as RMSE, - -- sqrt(sum(pow(predicted - actual,2.0))/count(1)) as RMSE, --- Mean Squared Error - mse(predicted, actual) as MSE, - -- sum(pow(predicted - actual,2.0))/count(1) as MSE, --- Mean Absolute Error - mae(predicted, actual) as MAE, - -- sum(abs(predicted - actual))/count(1) as MAE, --- coefficient of determination (R^2) - -- 1 - sum(pow(actual - predicted,2.0)) / sum(pow(actual - ${mean_actual},2.0)) as R2 - r2(actual, predicted) as R2 -- supported since Hivemall v0.4.1-alpha.5 -from - e2006tfidf_pa2a_submit; -``` -> 0.38538660838804495 0.14852283792484033 0.2466732002711477 0.48623913673053565 - -# Logarithmic Loss - -[Logarithmic Loss](https://www.kaggle.com/wiki/LogarithmicLoss) can be computed as follows: - -```sql -WITH t as ( - select - 0 as actual, - 0.01 as predicted - union all - select - 1 as actual, - 0.02 as predicted -) -select - 
-SUM(actual*LN(predicted)+(1-actual)*LN(1-predicted))/count(1) as logloss1, - logloss(predicted, actual) as logloss2 -- supported since Hivemall v0.4.2-rc.1 -from -from t; -``` -> 1.9610366706408238 1.9610366706408238 - -# References - -* R2 http://en.wikipedia.org/wiki/Coefficient_of_determination -* Evaluation Metrics https://www.kaggle.com/wiki/Metrics http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/098a7f3d/resources/ddl/define-all-as-permanent.hive ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive index 100fe22..8cdd371 100644 --- a/resources/ddl/define-all-as-permanent.hive +++ b/resources/ddl/define-all-as-permanent.hive @@ -565,7 +565,10 @@ CREATE FUNCTION lr_datagen as 'hivemall.dataset.LogisticRegressionDataGeneratorU -------------------------- DROP FUNCTION IF EXISTS f1score; -CREATE FUNCTION f1score as 'hivemall.evaluation.FMeasureUDAF' USING JAR '${hivemall_jar}'; +CREATE FUNCTION f1score as 'hivemall.evaluation.F1ScoreUDAF' USING JAR '${hivemall_jar}'; + +DROP FUNCTION IF EXISTS fmeasure; +CREATE FUNCTION fmeasure as 'hivemall.evaluation.FMeasureUDAF' USING JAR '${hivemall_jar}'; DROP FUNCTION IF EXISTS mae; CREATE FUNCTION mae as 'hivemall.evaluation.MeanAbsoluteErrorUDAF' USING JAR '${hivemall_jar}'; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/098a7f3d/resources/ddl/define-all.hive ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive index 6fb34ca..756c57a 100644 --- a/resources/ddl/define-all.hive +++ b/resources/ddl/define-all.hive @@ -557,7 +557,10 @@ create temporary function lr_datagen as 'hivemall.dataset.LogisticRegressionData -------------------------- drop temporary function if exists f1score; -create temporary function f1score as 'hivemall.evaluation.FMeasureUDAF'; +create temporary function 
f1score as 'hivemall.evaluation.F1ScoreUDAF'; + +drop temporary function if exists fmeasure; +create temporary function fmeasure as 'hivemall.evaluation.FMeasureUDAF'; drop temporary function if exists mae; create temporary function mae as 'hivemall.evaluation.MeanAbsoluteErrorUDAF'; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/098a7f3d/resources/ddl/define-all.spark ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark index d0a1084..bddbc85 100644 --- a/resources/ddl/define-all.spark +++ b/resources/ddl/define-all.spark @@ -541,7 +541,10 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION lr_datagen AS 'hivemall.dataset.Logist */ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS f1score") -sqlContext.sql("CREATE TEMPORARY FUNCTION f1score AS 'hivemall.evaluation.FMeasureUDAF'") +sqlContext.sql("CREATE TEMPORARY FUNCTION f1score AS 'hivemall.evaluation.F1ScoreUDAF'") + +sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS fmeasure") +sqlContext.sql("CREATE TEMPORARY FUNCTION fmeasure AS 'hivemall.evaluation.FMeasureUDAF'") sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS mae") sqlContext.sql("CREATE TEMPORARY FUNCTION mae AS 'hivemall.evaluation.MeanAbsoluteErrorUDAF'") http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/098a7f3d/resources/ddl/define-udfs.td.hql ---------------------------------------------------------------------- diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql index d90cb3c..c59b120 100644 --- a/resources/ddl/define-udfs.td.hql +++ b/resources/ddl/define-udfs.td.hql @@ -130,7 +130,8 @@ create temporary function normalize_unicode as 'hivemall.tools.text.NormalizeUni create temporary function base91 as 'hivemall.tools.text.Base91UDF'; create temporary function unbase91 as 'hivemall.tools.text.Unbase91UDF'; create temporary function lr_datagen as 'hivemall.dataset.LogisticRegressionDataGeneratorUDTF'; 
-create temporary function f1score as 'hivemall.evaluation.FMeasureUDAF'; +create temporary function f1score as 'hivemall.evaluation.F1ScoreUDAF'; +create temporary function fmeasure as 'hivemall.evaluation.FMeasureUDAF'; create temporary function mae as 'hivemall.evaluation.MeanAbsoluteErrorUDAF'; create temporary function mse as 'hivemall.evaluation.MeanSquaredErrorUDAF'; create temporary function rmse as 'hivemall.evaluation.RootMeanSquaredErrorUDAF';
