GitHub user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/111#discussion_r139901202
--- Diff: core/src/main/java/hivemall/evaluation/HitRateUDAF.java ---
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+package hivemall.evaluation;
+
+import hivemall.utils.hadoop.HiveUtils;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.parse.SemanticException;
+import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import
org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.StructField;
+import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import
org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
+import org.apache.hadoop.io.LongWritable;
+
+
+@Description(
+ name = "hitrate",
+ value = "_FUNC_(array rankItems, array correctItems [, const int
recommendSize = rankItems.size])"
+ + " - Returns HitRate")
+public final class HitRateUDAF extends AbstractGenericUDAFResolver {
+
+ // prevent instantiation
+ private HitRateUDAF() {}
+
+ @Override
+ public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo)
throws SemanticException {
+ if (typeInfo.length != 2 && typeInfo.length != 3) {
+ throw new UDFArgumentTypeException(typeInfo.length - 1,
+ "_FUNC_ takes two or three arguments");
+ }
+
+ ListTypeInfo arg1type = HiveUtils.asListTypeInfo(typeInfo[0]);
+ if
(!HiveUtils.isPrimitiveTypeInfo(arg1type.getListElementTypeInfo())) {
+ throw new UDFArgumentTypeException(0,
+ "The first argument `array rankItems` is invalid form: " +
typeInfo[0]);
+ }
+ ListTypeInfo arg2type = HiveUtils.asListTypeInfo(typeInfo[1]);
+ if
(!HiveUtils.isPrimitiveTypeInfo(arg2type.getListElementTypeInfo())) {
+ throw new UDFArgumentTypeException(1,
+ "The second argument `array correctItems` is invalid form:
" + typeInfo[1]);
+ }
+
+ return new HitRateUDAF.Evaluator();
+ }
+
+ public static class Evaluator extends GenericUDAFEvaluator {
+
+ private ListObjectInspector recommendListOI;
+ private ListObjectInspector truthListOI;
+ private WritableIntObjectInspector recommendSizeOI;
+
+ private StructObjectInspector internalMergeOI;
+ private StructField countField;
+ private StructField sumField;
+
+ public Evaluator() {}
+
+ @Override
+ public ObjectInspector init(Mode mode, ObjectInspector[]
parameters) throws HiveException {
+ assert (parameters.length == 2 || parameters.length == 3) :
parameters.length;
+ super.init(mode, parameters);
+
+ // initialize input
+ if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from
original data
+ this.recommendListOI = (ListObjectInspector) parameters[0];
+ this.truthListOI = (ListObjectInspector) parameters[1];
+ if (parameters.length == 3) {
+ this.recommendSizeOI = (WritableIntObjectInspector)
parameters[2];
+ }
+ } else {// from partial aggregation
+ StructObjectInspector soi = (StructObjectInspector)
parameters[0];
+ this.internalMergeOI = soi;
+ this.countField = soi.getStructFieldRef("count");
+ this.sumField = soi.getStructFieldRef("sum");
+ }
+
+ // initialize output
+ final ObjectInspector outputOI;
+ if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {//
terminatePartial
+ outputOI = internalMergeOI();
+ } else {// terminate
+ outputOI =
PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+ }
+ return outputOI;
+ }
+
+ private static StructObjectInspector internalMergeOI() {
+ ArrayList<String> fieldNames = new ArrayList<String>();
+ ArrayList<ObjectInspector> fieldOIs = new
ArrayList<ObjectInspector>();
+
+ fieldNames.add("sum");
+
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+ fieldNames.add("count");
+
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
+
+ return
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ @Override
+ public HitRateAggregationBuffer getNewAggregationBuffer() throws
HiveException {
+ HitRateAggregationBuffer myAggr = new
HitRateAggregationBuffer();
+ reset(myAggr);
+ return myAggr;
+ }
+
+ @Override
+ public void reset(@SuppressWarnings("deprecation")
AggregationBuffer agg)
+ throws HiveException {
+ HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer)
agg;
+ myAggr.reset();
+ }
+
+ @Override
+ public void iterate(@SuppressWarnings("deprecation")
AggregationBuffer agg,
+ Object[] parameters) throws HiveException {
+ HitRateAggregationBuffer myAggr = (HitRateAggregationBuffer)
agg;
+
+ List<?> recommendList = recommendListOI.getList(parameters[0]);
+ if (recommendList == null) {
+ recommendList = Collections.emptyList();
+ }
+ List<?> truthList = truthListOI.getList(parameters[1]);
+ if (truthList == null) {
+ return;
+ }
+
+ int recommendSize = recommendList.size();
+ if (parameters.length == 3) {
+ recommendSize = recommendSizeOI.get(parameters[2]);
+ }
+ if (recommendSize < 0 || recommendSize > recommendList.size())
{
--- End diff --
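Hit rate should accept `recommendSize > recommendList.size()`: requesting more items than the recommended list contains is valid, so this check should reject only a negative `recommendSize`.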
---