Minor refactoring of FMeasureUDAF
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/c1cd4b2e Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/c1cd4b2e Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/c1cd4b2e Branch: refs/heads/master Commit: c1cd4b2e050fd6c8f8c768140ab7e4f3e9d04c14 Parents: b058473 Author: Makoto Yui <m...@apache.org> Authored: Wed Sep 13 22:55:05 2017 +0900 Committer: Makoto Yui <m...@apache.org> Committed: Wed Sep 13 22:55:05 2017 +0900 ---------------------------------------------------------------------- .../java/hivemall/evaluation/FMeasureUDAF.java | 93 +++++++++++--------- 1 file changed, 51 insertions(+), 42 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c1cd4b2e/core/src/main/java/hivemall/evaluation/FMeasureUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/evaluation/FMeasureUDAF.java b/core/src/main/java/hivemall/evaluation/FMeasureUDAF.java index feb50b7..e64dc12 100644 --- a/core/src/main/java/hivemall/evaluation/FMeasureUDAF.java +++ b/core/src/main/java/hivemall/evaluation/FMeasureUDAF.java @@ -20,40 +20,45 @@ package hivemall.evaluation; import hivemall.UDAFEvaluatorWithOptions; import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Primitives; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; -import hivemall.utils.lang.Primitives; +import javax.annotation.Nonnull; + import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; - import org.apache.hadoop.hive.ql.exec.Description; -import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType; import org.apache.hadoop.hive.ql.util.JavaDataModel; import org.apache.hadoop.hive.serde2.io.DoubleWritable; -import org.apache.hadoop.hive.serde2.objectinspector.*; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; import org.apache.hadoop.io.LongWritable; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer; -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType; - -import javax.annotation.Nonnull; @Description( name = "fmeasure", - value = "_FUNC_(array | int | boolean actual , array | int | boolean predicted, String) - Return a F-measure (f1score is the special with beta=1.)") + value = "_FUNC_(array|int|boolean actual, array|int| boolean predicted [, const string options])" + + " - Return a F-measure (f1score is the special with beta=1.0)") public final class FMeasureUDAF extends AbstractGenericUDAFResolver { + @Override public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException { if (typeInfo.length != 2 && typeInfo.length != 3) { @@ -176,9 +181,10 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver { return outputOI; } + @Nonnull private static StructObjectInspector internalMergeOI() { - ArrayList<String> fieldNames = new ArrayList<>(); - ArrayList<ObjectInspector> fieldOIs = new ArrayList<>(); + List<String> fieldNames = new ArrayList<>(); + List<ObjectInspector> fieldOIs = new ArrayList<>(); fieldNames.add("tp"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); @@ -206,7 +212,7 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver { throws HiveException { FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg; myAggr.reset(); - myAggr.setOptions(this.beta, this.average); + myAggr.setOptions(beta, average); } @Override @@ -219,7 +225,7 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver { final List<?> predicted; if (isList) {// array case - if (this.average.equals("binary")) { + if ("binary".equals(average)) { throw new UDFArgumentException( "\"-average binary\" is not supported when `predict` is array"); } @@ -232,16 +238,16 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver { predicted = Arrays.asList(asIntLabel(parameters[1], (BooleanObjectInspector) predictedOI)); } else { // int case - int actualLabel = asIntLabel(parameters[0], (IntObjectInspector) actualOI); - - if (actualLabel == 0 && this.average.equals("binary")) { + final int actualLabel = asIntLabel(parameters[0], (IntObjectInspector) actualOI); + if (actualLabel == 0 && "binary".equals(average)) { actual = Collections.emptyList(); } else { actual = Arrays.asList(actualLabel); } - int predictedLabel = asIntLabel(parameters[1], (IntObjectInspector) predictedOI); - if (predictedLabel == 0 && this.average.equals("binary")) { + final int predictedLabel = asIntLabel(parameters[1], + (IntObjectInspector) predictedOI); + if (predictedLabel == 0 && "binary".equals(average)) { predicted = Collections.emptyList(); } else { predicted = Arrays.asList(predictedLabel); @@ -251,7 +257,8 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver { myAggr.iterate(actual, predicted); } - private int asIntLabel(@Nonnull Object o, @Nonnull BooleanObjectInspector booleanOI) { + private static int asIntLabel(@Nonnull final Object o, + @Nonnull final BooleanObjectInspector booleanOI) { if (booleanOI.get(o)) { return 1; } else { @@ -259,19 +266,20 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver { } } - private int asIntLabel(@Nonnull Object o, @Nonnull IntObjectInspector intOI) - throws HiveException { - int value = intOI.get(o); - if (!(value == 1 || value == 0 || value == -1)) { - throw new UDFArgumentException("Int label must be 1, 0 or -1: " + value); + private static int asIntLabel(@Nonnull final Object o, + @Nonnull final IntObjectInspector intOI) throws UDFArgumentException { + final int value = intOI.get(o); + switch (value) { + case 1: + return 1; + case 0: + case -1: + return 0; + default: + throw new UDFArgumentException("Int label must be 1, 0 or -1: " + value); } - if (value == -1) { - value = 0; - } - return value; } - @Override public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg) throws HiveException { @@ -349,7 +357,8 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver { this.totalPredicted = 0L; } - void merge(long o_tp, long o_actual, long o_predicted, double beta, String average) { + void merge(final long o_tp, final long o_actual, final long o_predicted, final double beta, + final String average) { tp += o_tp; totalActual += o_actual; totalPredicted += o_predicted; @@ -358,11 +367,11 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver { } double get() { - double squareBeta = beta * beta; - double divisor; - double numerator; + final double squareBeta = beta * beta; - if (average.equals("micro")) { + final double divisor; + final double numerator; + if ("micro".equals(average)) { divisor = denom(tp, totalActual, totalPredicted, squareBeta); numerator = (1.d + squareBeta) * tp; } else { // binary @@ -379,23 +388,23 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver { } } - private static double denom(long tp, long totalActual, long totalPredicted, - double squareBeta) { + private static double denom(final long tp, final long totalActual, + final long totalPredicted, double squareBeta) { long lp = totalActual - tp; long pl = totalPredicted - tp; return squareBeta * (tp + lp) + tp + pl; } - private static double precision(long tp, long totalPredicted) { - return (totalPredicted == 0L) ? 0d : tp / (double) totalPredicted; + private static double precision(final long tp, final long totalPredicted) { + return (totalPredicted == 0L) ? 0.d : tp / (double) totalPredicted; } - private static double recall(long tp, long totalActual) { - return (totalActual == 0L) ? 0d : tp / (double) totalActual; + private static double recall(final long tp, final long totalActual) { + return (totalActual == 0L) ? 0.d : tp / (double) totalActual; } - void iterate(@Nonnull List<?> actual, @Nonnull List<?> predicted) { + void iterate(@Nonnull final List<?> actual, @Nonnull final List<?> predicted) { final int numActual = actual.size(); final int numPredicted = predicted.size(); int countTp = 0;