Minor refactoring of FMeasureUDAF

Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/c1cd4b2e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/c1cd4b2e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/c1cd4b2e

Branch: refs/heads/master
Commit: c1cd4b2e050fd6c8f8c768140ab7e4f3e9d04c14
Parents: b058473
Author: Makoto Yui <m...@apache.org>
Authored: Wed Sep 13 22:55:05 2017 +0900
Committer: Makoto Yui <m...@apache.org>
Committed: Wed Sep 13 22:55:05 2017 +0900

----------------------------------------------------------------------
 .../java/hivemall/evaluation/FMeasureUDAF.java  | 93 +++++++++++---------
 1 file changed, 51 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/c1cd4b2e/core/src/main/java/hivemall/evaluation/FMeasureUDAF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/evaluation/FMeasureUDAF.java b/core/src/main/java/hivemall/evaluation/FMeasureUDAF.java
index feb50b7..e64dc12 100644
--- a/core/src/main/java/hivemall/evaluation/FMeasureUDAF.java
+++ b/core/src/main/java/hivemall/evaluation/FMeasureUDAF.java
@@ -20,40 +20,45 @@ package hivemall.evaluation;
 
 import hivemall.UDAFEvaluatorWithOptions;
 import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Primitives;
 
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
-import hivemall.utils.lang.Primitives;
+import javax.annotation.Nonnull;
+
 import org.apache.commons.cli.CommandLine;
 import org.apache.commons.cli.Options;
-
 import org.apache.hadoop.hive.ql.exec.Description;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import 
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
+import 
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType;
 import org.apache.hadoop.hive.ql.util.JavaDataModel;
 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
-import org.apache.hadoop.hive.serde2.objectinspector.*;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
-import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
+import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.io.LongWritable;
-import 
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
-import 
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationType;
-
-import javax.annotation.Nonnull;
 
 @Description(
         name = "fmeasure",
-        value = "_FUNC_(array | int | boolean actual , array | int | boolean 
predicted, String) - Return a F-measure (f1score is the special with beta=1.)")
+        value = "_FUNC_(array|int|boolean actual, array|int| boolean predicted 
[, const string options])"
+                + " - Return a F-measure (f1score is the special with 
beta=1.0)")
 public final class FMeasureUDAF extends AbstractGenericUDAFResolver {
+
     @Override
     public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) 
throws SemanticException {
         if (typeInfo.length != 2 && typeInfo.length != 3) {
@@ -176,9 +181,10 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver {
             return outputOI;
         }
 
+        @Nonnull
         private static StructObjectInspector internalMergeOI() {
-            ArrayList<String> fieldNames = new ArrayList<>();
-            ArrayList<ObjectInspector> fieldOIs = new ArrayList<>();
+            List<String> fieldNames = new ArrayList<>();
+            List<ObjectInspector> fieldOIs = new ArrayList<>();
 
             fieldNames.add("tp");
             
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
@@ -206,7 +212,7 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver {
                 throws HiveException {
             FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) agg;
             myAggr.reset();
-            myAggr.setOptions(this.beta, this.average);
+            myAggr.setOptions(beta, average);
         }
 
         @Override
@@ -219,7 +225,7 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver {
             final List<?> predicted;
 
             if (isList) {// array case
-                if (this.average.equals("binary")) {
+                if ("binary".equals(average)) {
                     throw new UDFArgumentException(
                         "\"-average binary\" is not supported when `predict` 
is array");
                 }
@@ -232,16 +238,16 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver {
                     predicted = Arrays.asList(asIntLabel(parameters[1],
                         (BooleanObjectInspector) predictedOI));
                 } else { // int case
-                    int actualLabel = asIntLabel(parameters[0], 
(IntObjectInspector) actualOI);
-
-                    if (actualLabel == 0 && this.average.equals("binary")) {
+                    final int actualLabel = asIntLabel(parameters[0], 
(IntObjectInspector) actualOI);
+                    if (actualLabel == 0 && "binary".equals(average)) {
                         actual = Collections.emptyList();
                     } else {
                         actual = Arrays.asList(actualLabel);
                     }
 
-                    int predictedLabel = asIntLabel(parameters[1], 
(IntObjectInspector) predictedOI);
-                    if (predictedLabel == 0 && this.average.equals("binary")) {
+                    final int predictedLabel = asIntLabel(parameters[1],
+                        (IntObjectInspector) predictedOI);
+                    if (predictedLabel == 0 && "binary".equals(average)) {
                         predicted = Collections.emptyList();
                     } else {
                         predicted = Arrays.asList(predictedLabel);
@@ -251,7 +257,8 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver {
             myAggr.iterate(actual, predicted);
         }
 
-        private int asIntLabel(@Nonnull Object o, @Nonnull 
BooleanObjectInspector booleanOI) {
+        private static int asIntLabel(@Nonnull final Object o,
+                @Nonnull final BooleanObjectInspector booleanOI) {
             if (booleanOI.get(o)) {
                 return 1;
             } else {
@@ -259,19 +266,20 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver {
             }
         }
 
-        private int asIntLabel(@Nonnull Object o, @Nonnull IntObjectInspector 
intOI)
-                throws HiveException {
-            int value = intOI.get(o);
-            if (!(value == 1 || value == 0 || value == -1)) {
-                throw new UDFArgumentException("Int label must be 1, 0 or -1: 
" + value);
+        private static int asIntLabel(@Nonnull final Object o,
+                @Nonnull final IntObjectInspector intOI) throws 
UDFArgumentException {
+            final int value = intOI.get(o);
+            switch (value) {
+                case 1:
+                    return 1;
+                case 0:
+                case -1:
+                    return 0;
+                default:
+                    throw new UDFArgumentException("Int label must be 1, 0 or 
-1: " + value);
             }
-            if (value == -1) {
-                value = 0;
-            }
-            return value;
         }
 
-
         @Override
         public Object terminatePartial(@SuppressWarnings("deprecation") 
AggregationBuffer agg)
                 throws HiveException {
@@ -349,7 +357,8 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver {
             this.totalPredicted = 0L;
         }
 
-        void merge(long o_tp, long o_actual, long o_predicted, double beta, 
String average) {
+        void merge(final long o_tp, final long o_actual, final long 
o_predicted, final double beta,
+                final String average) {
             tp += o_tp;
             totalActual += o_actual;
             totalPredicted += o_predicted;
@@ -358,11 +367,11 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver {
         }
 
         double get() {
-            double squareBeta = beta * beta;
-            double divisor;
-            double numerator;
+            final double squareBeta = beta * beta;
 
-            if (average.equals("micro")) {
+            final double divisor;
+            final double numerator;
+            if ("micro".equals(average)) {
                 divisor = denom(tp, totalActual, totalPredicted, squareBeta);
                 numerator = (1.d + squareBeta) * tp;
             } else { // binary
@@ -379,23 +388,23 @@ public final class FMeasureUDAF extends AbstractGenericUDAFResolver {
             }
         }
 
-        private static double denom(long tp, long totalActual, long 
totalPredicted,
-                double squareBeta) {
+        private static double denom(final long tp, final long totalActual,
+                final long totalPredicted, double squareBeta) {
             long lp = totalActual - tp;
             long pl = totalPredicted - tp;
 
             return squareBeta * (tp + lp) + tp + pl;
         }
 
-        private static double precision(long tp, long totalPredicted) {
-            return (totalPredicted == 0L) ? 0d : tp / (double) totalPredicted;
+        private static double precision(final long tp, final long 
totalPredicted) {
+            return (totalPredicted == 0L) ? 0.d : tp / (double) totalPredicted;
         }
 
-        private static double recall(long tp, long totalActual) {
-            return (totalActual == 0L) ? 0d : tp / (double) totalActual;
+        private static double recall(final long tp, final long totalActual) {
+            return (totalActual == 0L) ? 0.d : tp / (double) totalActual;
         }
 
-        void iterate(@Nonnull List<?> actual, @Nonnull List<?> predicted) {
+        void iterate(@Nonnull final List<?> actual, @Nonnull final List<?> 
predicted) {
             final int numActual = actual.size();
             final int numPredicted = predicted.size();
             int countTp = 0;

Reply via email to