GitHub user takuti commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/107#discussion_r134157657
  
    --- Diff: core/src/main/java/hivemall/evaluation/FMeasureUDAF.java ---
    @@ -18,118 +18,387 @@
      */
     package hivemall.evaluation;
     
    -import hivemall.utils.hadoop.WritableUtils;
    +import hivemall.UDAFEvaluatorWithOptions;
    +import hivemall.utils.hadoop.HiveUtils;
     
    +import java.util.ArrayList;
    +import java.util.Arrays;
    +import java.util.Collections;
     import java.util.List;
     
    +import hivemall.utils.lang.Primitives;
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +
     import org.apache.hadoop.hive.ql.exec.Description;
    -import org.apache.hadoop.hive.ql.exec.UDAF;
    -import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.ql.parse.SemanticException;
    +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
    +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
     import org.apache.hadoop.hive.serde2.io.DoubleWritable;
    -import org.apache.hadoop.io.IntWritable;
    -
    -@SuppressWarnings("deprecation")
    -@Description(name = "f1score",
    -        value = "_FUNC_(array[int], array[int]) - Return a F-measure/F1 
score")
    -public final class FMeasureUDAF extends UDAF {
    -
    -    public static class Evaluator implements UDAFEvaluator {
    -
    -        public static class PartialResult {
    -            long tp;
    -            /** tp + fn */
    -            long totalAcutal;
    -            /** tp + fp */
    -            long totalPredicted;
    -
    -            PartialResult() {
    -                this.tp = 0L;
    -                this.totalPredicted = 0L;
    -                this.totalAcutal = 0L;
    -            }
    +import org.apache.hadoop.hive.serde2.objectinspector.*;
    +import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
    +import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
    +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
    +import org.apache.hadoop.io.LongWritable;
     
    -            void updateScore(final List<IntWritable> actual, final 
List<IntWritable> predicted) {
    -                final int numActual = actual.size();
    -                final int numPredicted = predicted.size();
    -                int countTp = 0;
    -                for (int i = 0; i < numPredicted; i++) {
    -                    IntWritable p = predicted.get(i);
    -                    if (actual.contains(p)) {
    -                        countTp++;
    -                    }
    +import javax.annotation.Nonnull;
    +
    +@Description(
    +        name = "fmeasure",
    +        value = "_FUNC_(array | int | boolean, array | int | boolean, 
String) - Return a F-measure (f1score is the special with beta=1.)")
    +public final class FMeasureUDAF extends AbstractGenericUDAFResolver {
    +    @Override
    +    public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) 
throws SemanticException {
    +        if (typeInfo.length != 2 && typeInfo.length != 3) {
    +            throw new UDFArgumentTypeException(typeInfo.length - 1,
    +                "_FUNC_ takes two or three arguments");
    +        }
    +
    +        boolean isArg1ListOrIntOrBoolean = 
HiveUtils.isListTypeInfo(typeInfo[0])
    +                || HiveUtils.isIntegerTypeInfo(typeInfo[0])
    +                || HiveUtils.isBooleanTypeInfo(typeInfo[0]);
    +        if (!isArg1ListOrIntOrBoolean) {
    +            throw new UDFArgumentTypeException(0,
    +                "The first argument `array/int/boolean actual` is invalid 
form: " + typeInfo[0]);
    +        }
    +
    +        boolean isArg2ListOrIntOrBoolean = 
HiveUtils.isListTypeInfo(typeInfo[1])
    +                || HiveUtils.isIntegerTypeInfo(typeInfo[1])
    +                || HiveUtils.isBooleanTypeInfo(typeInfo[1]);
    +        if (!isArg2ListOrIntOrBoolean) {
    +            throw new UDFArgumentTypeException(1,
    +                "The first argument `array/int/boolean actual` is invalid 
form: " + typeInfo[1]);
    +        }
    +
    +        if (typeInfo[0] != typeInfo[1]) {
    +            throw new UDFArgumentTypeException(1, "The first argument's 
`actual` type is "
    +                    + typeInfo[0] + ", but the second argument 
`predicated`'s type is not match: "
    +                    + typeInfo[1]);
    +        }
    +
    +        return new Evaluator();
    +    }
    +
    +    public static class Evaluator extends UDAFEvaluatorWithOptions {
    +
    +        private ObjectInspector actualOI;
    +        private ObjectInspector predictedOI;
    +        private StructObjectInspector internalMergeOI;
    +
    +        private StructField tpField;
    +        private StructField totalActualField;
    +        private StructField totalPredictedField;
    +        private StructField betaOptionField;
    +        private StructField averageOptionFiled;
    +
    +        private double beta;
    +        private String average;
    +
    +        public Evaluator() {}
    +
    +        @Override
    +        protected Options getOptions() {
    +            Options opts = new Options();
    +            opts.addOption("beta", true, "The weight of precision 
[default: 1.]");
    +            opts.addOption("average", true, "The way of average 
calculation [default: micro]");
    +            return opts;
    +        }
    +
    +        @Override
    +        protected CommandLine processOptions(ObjectInspector[] argOIs) 
throws UDFArgumentException {
    +            CommandLine cl = null;
    +
    +            if (argOIs.length >= 3) {
    +                String rawArgs = HiveUtils.getConstString(argOIs[2]);
    +                cl = parseOptions(rawArgs);
    +
    +                this.beta = 
Primitives.parseDouble(cl.getOptionValue("beta"), 1.0d);
    +                if (this.beta <= 0.d) {
    +                    throw new UDFArgumentException(
    +                        "The third argument `double beta` must be greater 
than 0.0: " + beta);
    +                }
    +
    +                this.average = cl.getOptionValue("average", "micro");
    +                if (!(this.average.equals("binary") || 
this.average.equals("macro") || this.average.equals("micro"))) {
    +                    throw new UDFArgumentException(
    +                        "The third argument `String average` must be one 
of the {binary, micro, macro}: "
    +                                + this.average);
                     }
    -                this.tp += countTp;
    -                this.totalAcutal += numActual;
    -                this.totalPredicted += numPredicted;
    +            } else {
    +                this.beta = 1.0d;
    +                this.average = "micro";
                 }
    +            return cl;
    +        }
     
    -            void merge(PartialResult other) {
    -                this.tp = other.tp;
    -                this.totalAcutal = other.totalAcutal;
    -                this.totalPredicted = other.totalPredicted;
    +        @Override
    +        public ObjectInspector init(Mode mode, ObjectInspector[] 
parameters) throws HiveException {
    +            assert (parameters.length == 2 || parameters.length == 3) : 
parameters.length;
    +            super.init(mode, parameters);
    +
    +            // initialize input
    +            if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from 
original data
    +                this.processOptions(parameters);
    +                this.actualOI = parameters[0];
    +                this.predictedOI = parameters[1];
    +            } else {// from partial aggregation
    +                StructObjectInspector soi = (StructObjectInspector) 
parameters[0];
    +                this.internalMergeOI = soi;
    +                this.tpField = soi.getStructFieldRef("tp");
    +                this.totalActualField = 
soi.getStructFieldRef("totalActual");
    +                this.totalPredictedField = 
soi.getStructFieldRef("totalPredicted");
    +                this.betaOptionField = soi.getStructFieldRef("beta");
    +                this.averageOptionFiled = soi.getStructFieldRef("average");
                 }
    +
    +            // initialize output
    +            final ObjectInspector outputOI;
    +            if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// 
terminatePartial
    +                outputOI = internalMergeOI();
    +            } else {// terminate
    +                outputOI = 
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
    +            }
    +            return outputOI;
             }
     
    -        private PartialResult partial;
    +        private static StructObjectInspector internalMergeOI() {
    +            ArrayList<String> fieldNames = new ArrayList<>();
    +            ArrayList<ObjectInspector> fieldOIs = new ArrayList<>();
     
    -        @Override
    -        public void init() {
    -            this.partial = null;
    +            fieldNames.add("tp");
    +            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
    +            fieldNames.add("totalActual");
    +            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
    +            fieldNames.add("totalPredicted");
    +            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
    +            fieldNames.add("beta");
    +            
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
    +            fieldNames.add("average");
    +            
fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    +
    +
    +            return 
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
             }
     
    -        public boolean iterate(List<IntWritable> actual, List<IntWritable> 
predicted) {
    -            if (partial == null) {
    -                this.partial = new PartialResult();
    -            }
    -            partial.updateScore(actual, predicted);
    -            return true;
    +        @Override
    +        public FMeasureAggregationBuffer getNewAggregationBuffer() throws 
HiveException {
    +            FMeasureAggregationBuffer myAggr = new 
FMeasureAggregationBuffer();
    +            reset(myAggr);
    +            return myAggr;
             }
     
    -        public PartialResult terminatePartial() {
    -            return partial;
    +        @Override
    +        public void reset(@SuppressWarnings("deprecation") 
AggregationBuffer agg)
    +                throws HiveException {
    +            FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) 
agg;
    +            myAggr.reset();
    +            myAggr.setOptions(this.beta, this.average);
             }
     
    -        public boolean merge(PartialResult other) {
    -            if (other == null) {
    -                return true;
    +        @Override
    +        public void iterate(@SuppressWarnings("deprecation") 
AggregationBuffer agg,
    +                Object[] parameters) throws HiveException {
    +            FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) 
agg;
    +            boolean isList = HiveUtils.isListOI(actualOI) && 
HiveUtils.isListOI(predictedOI);
    +
    +            List<?> actual = Collections.emptyList();
    +            List<?> predicted = Collections.emptyList();
    +
    +            if (this.average.equals("macro")) {
    +                throw new UnsupportedOperationException();
                 }
    -            if (partial == null) {
    -                this.partial = new PartialResult();
    +
    +
    +            if (isList) {// array case
    +                if (this.average.equals("binary")) {
    +                    throw new UnsupportedOperationException();
    +                }
    +                actual = ((ListObjectInspector) 
predictedOI).getList(parameters[0]);
    +                predicted = ((ListObjectInspector) 
predictedOI).getList(parameters[1]);
    +            } else {//binary case
    +                if (HiveUtils.isBooleanOI(actualOI)) { // boolean case
    +                    if (((BooleanObjectInspector) 
actualOI).get(parameters[0])) {
    +                        actual = Arrays.asList(1);
    +                    } else {
    +                        actual = Arrays.asList(0);
    +                    }
    +
    +                    if (((BooleanObjectInspector) 
predictedOI).get(parameters[1])) {
    +                        predicted = Arrays.asList(1);
    +                    } else {
    +                        predicted = Arrays.asList(0);
    +                    }
    +                } else { // int case
    +                    int actualOIValue = ((IntObjectInspector) 
actualOI).get(parameters[0]);
    +                    if (actualOIValue == 1) {
    +                        actual = Arrays.asList(1);
    +                    } else if (!(actualOIValue == 0 || actualOIValue == 
-1)) {
    +                        throw new UDFArgumentException(
    +                            "The first argument `int actual` must be 1, 0, 
or -1:" + actualOIValue);
    +                    } else if (!this.average.equals("binary")) {
    +                        actual = Arrays.asList(0);
    +                    }
    +
    +                    int predictedOIValue = ((IntObjectInspector) 
predictedOI).get(parameters[1]);
    +                    if (predictedOIValue == 1) {
    +                        predicted = Arrays.asList(1);
    +                    } else if (!(predictedOIValue == 0 || predictedOIValue 
== -1)) {
    +                        throw new UDFArgumentException(
    +                            "The second argument `int predicted` must be 
1, 0, or -1:"
    +                                    + predictedOIValue);
    +                    } else if (!this.average.equals("binary")) {
    +                        predicted = Arrays.asList(0);
    +                    }
    +                }
                 }
    -            partial.merge(other);
    -            return true;
    +            myAggr.iterate(actual, predicted);
             }
     
    -        public DoubleWritable terminate() {
    +        @Override
    +        public Object terminatePartial(@SuppressWarnings("deprecation") 
AggregationBuffer agg)
    +                throws HiveException {
    +            FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) 
agg;
    +
    +            Object[] partialResult = new Object[5];
    +            partialResult[0] = new LongWritable(myAggr.tp);
    +            partialResult[1] = new LongWritable(myAggr.totalActual);
    +            partialResult[2] = new LongWritable(myAggr.totalPredicted);
    +            partialResult[3] = new DoubleWritable(myAggr.beta);
    +            partialResult[4] = myAggr.average;
    +            return partialResult;
    +        }
    +
    +        @Override
    +        public void merge(@SuppressWarnings("deprecation") 
AggregationBuffer agg, Object partial)
    +                throws HiveException {
                 if (partial == null) {
    -                return null;
    +                return;
                 }
    -            double score = f1Score(partial);
    -            return WritableUtils.val(score);
    +
    +            Object tpObj = internalMergeOI.getStructFieldData(partial, 
tpField);
    +            Object totalActualObj = 
internalMergeOI.getStructFieldData(partial, totalActualField);
    +            Object totalPredictedObj = 
internalMergeOI.getStructFieldData(partial,
    +                totalPredictedField);
    +            Object betaObj = internalMergeOI.getStructFieldData(partial, 
betaOptionField);
    +            Object averageObj = 
internalMergeOI.getStructFieldData(partial, averageOptionFiled);
    +            long tp = 
PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(tpObj);
    +            long totalActual = 
PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(totalActualObj);
    +            long totalPredicted = 
PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(totalPredictedObj);
    +            double beta = 
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(betaObj);
    +            String average = 
PrimitiveObjectInspectorFactory.writableStringObjectInspector.getPrimitiveJavaObject(averageObj);
    +
    +            FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) 
agg;
    +            myAggr.merge(tp, totalActual, totalPredicted, beta, average);
    +        }
    +
    +        @Override
    +        public DoubleWritable terminate(@SuppressWarnings("deprecation") 
AggregationBuffer agg)
    +                throws HiveException {
    +            FMeasureAggregationBuffer myAggr = (FMeasureAggregationBuffer) 
agg;
    +            double result = myAggr.get();
    +            return new DoubleWritable(result);
    +        }
    +    }
    +
    +    public static class FMeasureAggregationBuffer extends
    +            GenericUDAFEvaluator.AbstractAggregationBuffer {
    +        long tp;
    +        /** tp + fn */
    +        long totalActual;
    +        /** tp + fp */
    +        long totalPredicted;
    +        double beta;
    +        String average;
    +
    +        public FMeasureAggregationBuffer() {
    +            super();
             }
     
    -        /**
    -         * @return 2 * precision * recall / (precision + recall)
    -         */
    -        private static double f1Score(final PartialResult partial) {
    -            double precision = precision(partial);
    -            double recall = recall(partial);
    -            double divisor = precision + recall;
    +        void setOptions(double beta, String average) {
    +            this.beta = beta;
    +            this.average = average;
    +        }
    +
    +        void reset() {
    +            this.tp = 0L;
    +            this.totalActual = 0L;
    +            this.totalPredicted = 0L;
    +        }
    +
    +        void merge(long o_tp, long o_actual, long o_predicted, double 
beta, String average) {
    +            tp += o_tp;
    +            totalActual += o_actual;
    +            totalPredicted += o_predicted;
    +            this.beta = beta;
    +            this.average = average;
    +        }
    +
    +        double get() {
    +            double squareBeta = Math.pow(beta, 2.d);
    +            double divisor;
    +            double numerator;
    +
    +            if (average.equals("micro")) {
    +                divisor = denom(tp, totalActual, totalPredicted, 
squareBeta);
    +                numerator = (1.d + squareBeta) * tp;
    +            } else if (average.equals("binary")) {
    +                double precision = precision(tp, totalPredicted);
    +                double recall = recall(tp, totalActual);
    +                divisor = squareBeta * precision + recall;
    +                numerator = (1.d + squareBeta) * precision * recall;
    +            } else {
    +                throw new UnsupportedOperationException();
    +            }
    +
                 if (divisor > 0) {
    -                return (2.d * precision * recall) / divisor;
    +                return (numerator / divisor);
                 } else {
    -                return -1d;
    +                return -1.d;
    +            }
    +        }
    +
    +        private static double denom(long tp, long totalActual, long 
totalPredicted,
    +                double squareBeta) {
    +            long lp = totalPredicted - tp;
    +
    +            if (lp < 0) {
    --- End diff --
    
    Can either of these situations (`totalPredicted < tp`, `totalActual < tp`)
actually occur? I would expect TP to always be a subset of both the predicted and the actual labels.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it enabled, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

Reply via email to