Repository: incubator-hivemall Updated Branches: refs/heads/master 380478916 -> 688daa5f8
Close #108: [HIVEMALL-138] to_ordered_map & to_ordered_list as a UDAF variant of each_top_k Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/9cd3c59a Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/9cd3c59a Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/9cd3c59a Branch: refs/heads/master Commit: 9cd3c59aebb67cc6b58cdd611b96fcf42f297cde Parents: 3804789 Author: Takuya Kitazawa <[email protected]> Authored: Mon Sep 11 15:38:05 2017 +0900 Committer: Makoto Yui <[email protected]> Committed: Mon Sep 11 15:38:05 2017 +0900 ---------------------------------------------------------------------- .../hivemall/tools/list/UDAFToOrderedList.java | 535 +++++++++++++++++++ .../hivemall/tools/map/UDAFToOrderedMap.java | 214 +++++++- .../java/hivemall/utils/hadoop/HiveUtils.java | 7 + .../tools/array/SelectKBeatUDFTest.java | 69 --- .../tools/array/SelectKBestUDFTest.java | 69 +++ .../tools/list/UDAFToOrderedListTest.java | 344 ++++++++++++ .../tools/map/UDAFToOrderedMapTest.java | 153 ++++++ docs/gitbook/eval/rank.md | 5 + docs/gitbook/misc/generic_funcs.md | 7 +- docs/gitbook/misc/topk.md | 63 +++ docs/gitbook/recommend/item_based_cf.md | 5 + docs/gitbook/recommend/movielens_cf.md | 5 + resources/ddl/define-all-as-permanent.hive | 7 + resources/ddl/define-all.hive | 7 + resources/ddl/define-all.spark | 7 + resources/ddl/define-udfs.td.hql | 1 + 16 files changed, 1419 insertions(+), 79 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9cd3c59a/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java b/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java new file mode 100644 index 
0000000..16c966a --- /dev/null +++ b/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java @@ -0,0 +1,535 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.tools.list; + +import hivemall.utils.collections.BoundedPriorityQueue; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.CommandLineUtils; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.HelpFormatter; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo; +import org.apache.hadoop.hive.serde2.objectinspector.*; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import 
org.apache.hadoop.io.BooleanWritable; +import org.apache.hadoop.io.IntWritable; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.*; + +/** + * Return list of values sorted by value itself or specific key. + */ +@Description( + name = "to_ordered_list", + value = "_FUNC_(value [, key, const string options]) - Return list of values sorted by value itself or specific key") +public class UDAFToOrderedList extends AbstractGenericUDAFResolver { + + @Override + public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) + throws SemanticException { + @SuppressWarnings("deprecation") + TypeInfo[] typeInfo = info.getParameters(); + ObjectInspector[] argOIs = info.getParameterObjectInspectors(); + if ((typeInfo.length == 1) || (typeInfo.length == 2 && HiveUtils.isConstString(argOIs[1]))) { + // sort values by value itself w/o key + if (typeInfo[0].getCategory() != ObjectInspector.Category.PRIMITIVE) { + throw new UDFArgumentTypeException(0, + "Only primitive type arguments are accepted for value but " + + typeInfo[0].getTypeName() + " was passed as the first parameter."); + } + } else if ((typeInfo.length == 2) + || (typeInfo.length == 3 && HiveUtils.isConstString(argOIs[2]))) { + // sort values by key + if (typeInfo[1].getCategory() != ObjectInspector.Category.PRIMITIVE) { + throw new UDFArgumentTypeException(1, + "Only primitive type arguments are accepted for key but " + + typeInfo[1].getTypeName() + " was passed as the second parameter."); + } + } else { + throw new UDFArgumentTypeException(typeInfo.length - 1, + "Number of arguments must be in [1, 3] including constant string for options: " + + typeInfo.length); + } + return new UDAFToOrderedListEvaluator(); + } + + public static class UDAFToOrderedListEvaluator extends GenericUDAFEvaluator { + + private ObjectInspector valueOI; + private PrimitiveObjectInspector keyOI; + + private ListObjectInspector 
valueListOI; + private ListObjectInspector keyListOI; + + private StructObjectInspector internalMergeOI; + + private StructField valueListField; + private StructField keyListField; + private StructField sizeField; + private StructField reverseOrderField; + + @Nonnegative + private int size; + private boolean reverseOrder; + private boolean sortByKey; + + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("k", true, "To top-k (positive) or tail-k (negative) ordered queue"); + opts.addOption("reverse", "reverse_order", false, + "Sort values by key in a reverse (e.g., descending) order [default: false]"); + return opts; + } + + @Nonnull + protected final CommandLine parseOptions(String optionValue) throws UDFArgumentException { + String[] args = optionValue.split("\\s+"); + Options opts = getOptions(); + opts.addOption("help", false, "Show function help"); + CommandLine cl = CommandLineUtils.parseOptions(args, opts); + + if (cl.hasOption("help")) { + Description funcDesc = getClass().getAnnotation(Description.class); + final String cmdLineSyntax; + if (funcDesc == null) { + cmdLineSyntax = getClass().getSimpleName(); + } else { + String funcName = funcDesc.name(); + cmdLineSyntax = funcName == null ? 
getClass().getSimpleName() + : funcDesc.value().replace("_FUNC_", funcDesc.name()); + } + StringWriter sw = new StringWriter(); + sw.write('\n'); + PrintWriter pw = new PrintWriter(sw); + HelpFormatter formatter = new HelpFormatter(); + formatter.printHelp(pw, HelpFormatter.DEFAULT_WIDTH, cmdLineSyntax, null, opts, + HelpFormatter.DEFAULT_LEFT_PAD, HelpFormatter.DEFAULT_DESC_PAD, null, true); + pw.flush(); + String helpMsg = sw.toString(); + throw new UDFArgumentException(helpMsg); + } + + return cl; + } + + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = null; + + int optionIndex = 1; + if (sortByKey) { + optionIndex = 2; + } + + int k = 0; + boolean reverseOrder = false; + + if (argOIs.length >= optionIndex + 1) { + String rawArgs = HiveUtils.getConstString(argOIs[optionIndex]); + cl = parseOptions(rawArgs); + + reverseOrder = cl.hasOption("reverse_order"); + + if (cl.hasOption("k")) { + k = Integer.parseInt(cl.getOptionValue("k")); + if (k == 0) { + throw new UDFArgumentException("`k` must be nonzero: " + k); + } + } + } + + this.size = Math.abs(k); + + if ((k > 0 && reverseOrder) || (k < 0 && !reverseOrder) || (k == 0 && !reverseOrder)) { + // reverse top-k, natural tail-k = ascending = natural order output = reverse order priority queue + this.reverseOrder = true; + } else { // (k > 0 && !reverseOrder) || (k < 0 && reverseOrder) || (k == 0 && reverseOrder) + // natural top-k or reverse tail-k = descending = reverse order output = natural order priority queue + this.reverseOrder = false; + } + + return cl; + } + + @Override + public ObjectInspector init(Mode mode, ObjectInspector[] argOIs) throws HiveException { + super.init(mode, argOIs); + + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data + // this flag will be used in `processOptions` and `iterate` (= when Mode.PARTIAL1 or Mode.COMPLETE) + this.sortByKey = (argOIs.length == 2 && 
!HiveUtils.isConstString(argOIs[1])) + || (argOIs.length == 3 && HiveUtils.isConstString(argOIs[2])); + + if (sortByKey) { + this.valueOI = argOIs[0]; + this.keyOI = HiveUtils.asPrimitiveObjectInspector(argOIs[1]); + } else { + // sort values by value itself + this.valueOI = HiveUtils.asPrimitiveObjectInspector(argOIs[0]); + this.keyOI = HiveUtils.asPrimitiveObjectInspector(argOIs[0]); + } + + processOptions(argOIs); + } else {// from partial aggregation + StructObjectInspector soi = (StructObjectInspector) argOIs[0]; + this.internalMergeOI = soi; + + // re-extract input value OI + this.valueListField = soi.getStructFieldRef("valueList"); + StandardListObjectInspector valueListOI = (StandardListObjectInspector) valueListField.getFieldObjectInspector(); + this.valueOI = valueListOI.getListElementObjectInspector(); + this.valueListOI = ObjectInspectorFactory.getStandardListObjectInspector(valueOI); + + // re-extract input key OI + this.keyListField = soi.getStructFieldRef("keyList"); + StandardListObjectInspector keyListOI = (StandardListObjectInspector) keyListField.getFieldObjectInspector(); + this.keyOI = HiveUtils.asPrimitiveObjectInspector(keyListOI.getListElementObjectInspector()); + this.keyListOI = ObjectInspectorFactory.getStandardListObjectInspector(keyOI); + + this.sizeField = soi.getStructFieldRef("size"); + this.reverseOrderField = soi.getStructFieldRef("reverseOrder"); + } + + // initialize output + final ObjectInspector outputOI; + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial + outputOI = internalMergeOI(valueOI, keyOI); + } else {// terminate + outputOI = ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorUtils.getStandardObjectInspector(valueOI)); + } + + return outputOI; + } + + private static StructObjectInspector internalMergeOI(@Nonnull ObjectInspector valueOI, + @Nonnull PrimitiveObjectInspector keyOI) { + ArrayList<String> fieldNames = new ArrayList<String>(); + ArrayList<ObjectInspector> fieldOIs = 
new ArrayList<ObjectInspector>(); + + fieldNames.add("valueList"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorUtils.getStandardObjectInspector(valueOI))); + + fieldNames.add("keyList"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorUtils.getStandardObjectInspector(keyOI))); + + fieldNames.add("size"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + + fieldNames.add("reverseOrder"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableBooleanObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @SuppressWarnings("deprecation") + @Override + public AggregationBuffer getNewAggregationBuffer() throws HiveException { + QueueAggregationBuffer myagg = new QueueAggregationBuffer(); + reset(myagg); + return myagg; + } + + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + QueueAggregationBuffer myagg = (QueueAggregationBuffer) agg; + myagg.reset(size, reverseOrder); + } + + @Override + public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, + Object[] parameters) throws HiveException { + if (parameters[0] == null) { + return; + } + Object value = ObjectInspectorUtils.copyToStandardObject(parameters[0], valueOI); + + final Object key; + if (sortByKey) { + if (parameters[1] == null) { + return; + } + key = ObjectInspectorUtils.copyToStandardObject(parameters[1], keyOI); + } else { + // set value to key + key = ObjectInspectorUtils.copyToStandardObject(parameters[0], valueOI); + } + + TupleWithKey tuple = new TupleWithKey(key, value); + QueueAggregationBuffer myagg = (QueueAggregationBuffer) agg; + + myagg.iterate(tuple); + } + + @Override + public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + QueueAggregationBuffer myagg = (QueueAggregationBuffer) agg; + + 
Map<String, List<Object>> tuples = myagg.drainQueue(); + List<Object> valueList = tuples.get("value"); + List<Object> keyList = tuples.get("key"); + if (valueList.size() == 0) { + return null; + } + + Object[] partialResult = new Object[4]; + partialResult[0] = valueList; + partialResult[1] = keyList; + partialResult[2] = new IntWritable(myagg.size); + partialResult[3] = new BooleanWritable(myagg.reverseOrder); + + return partialResult; + } + + @Override + public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial) + throws HiveException { + if (partial == null) { + return; + } + + Object valueListObj = internalMergeOI.getStructFieldData(partial, valueListField); + final List<?> valueListRaw = valueListOI.getList(HiveUtils.castLazyBinaryObject(valueListObj)); + final List<Object> valueList = new ArrayList<Object>(); + for (int i = 0, n = valueListRaw.size(); i < n; i++) { + valueList.add(ObjectInspectorUtils.copyToStandardObject(valueListRaw.get(i), + valueOI)); + } + + Object keyListObj = internalMergeOI.getStructFieldData(partial, keyListField); + final List<?> keyListRaw = keyListOI.getList(HiveUtils.castLazyBinaryObject(keyListObj)); + final List<Object> keyList = new ArrayList<Object>(); + for (int i = 0, n = keyListRaw.size(); i < n; i++) { + keyList.add(ObjectInspectorUtils.copyToStandardObject(keyListRaw.get(i), keyOI)); + } + + Object sizeObj = internalMergeOI.getStructFieldData(partial, sizeField); + int size = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(sizeObj); + + Object reverseOrderObj = internalMergeOI.getStructFieldData(partial, reverseOrderField); + boolean reverseOrder = PrimitiveObjectInspectorFactory.writableBooleanObjectInspector.get(reverseOrderObj); + + QueueAggregationBuffer myagg = (QueueAggregationBuffer) agg; + myagg.setOptions(size, reverseOrder); + myagg.merge(keyList, valueList); + } + + @Override + public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws 
HiveException { + QueueAggregationBuffer myagg = (QueueAggregationBuffer) agg; + Map<String, List<Object>> tuples = myagg.drainQueue(); + return tuples.get("value"); + } + + static class QueueAggregationBuffer extends AbstractAggregationBuffer { + + private AbstractQueueHandler queueHandler; + + @Nonnegative + private int size; + private boolean reverseOrder; + + QueueAggregationBuffer() { + super(); + } + + void reset(@Nonnegative int size, boolean reverseOrder) { + setOptions(size, reverseOrder); + this.queueHandler = null; + } + + void setOptions(@Nonnegative int size, boolean reverseOrder) { + this.size = size; + this.reverseOrder = reverseOrder; + } + + void iterate(TupleWithKey tuple) { + if (queueHandler == null) { + initQueueHandler(); + } + queueHandler.offer(tuple); + } + + void merge(List<Object> o_keyList, List<Object> o_valueList) { + if (queueHandler == null) { + initQueueHandler(); + } + for (int i = 0, n = o_keyList.size(); i < n; i++) { + queueHandler.offer(new TupleWithKey(o_keyList.get(i), o_valueList.get(i))); + } + } + + @Nonnull + Map<String, List<Object>> drainQueue() { + int n = queueHandler.size(); + final Object[] keys = new Object[n]; + final Object[] values = new Object[n]; + for (int i = n - 1; i >= 0; i--) { // head element in queue should be stored to tail of array + TupleWithKey tuple = queueHandler.poll(); + keys[i] = tuple.getKey(); + values[i] = tuple.getValue(); + } + queueHandler.clear(); + + Map<String, List<Object>> res = new HashMap<String, List<Object>>(); + res.put("key", Arrays.asList(keys)); + res.put("value", Arrays.asList(values)); + return res; + } + + private void initQueueHandler() { + final Comparator<TupleWithKey> comparator; + if (reverseOrder) { + comparator = Collections.reverseOrder(); + } else { + comparator = new Comparator<TupleWithKey>() { + @Override + public int compare(TupleWithKey o1, TupleWithKey o2) { + return o1.compareTo(o2); + } + }; + } + + if (size > 0) { + this.queueHandler = new 
BoundedQueueHandler(size, comparator); + } else { + this.queueHandler = new QueueHandler(comparator); + } + } + + } + + /** + * Since BoundedPriorityQueue does not directly inherit PriorityQueue, we provide handler + * class which wraps each of PriorityQueue and BoundedPriorityQueue. + */ + private static abstract class AbstractQueueHandler { + + abstract void offer(TupleWithKey tuple); + + abstract int size(); + + abstract TupleWithKey poll(); + + abstract void clear(); + + } + + private static final class QueueHandler extends AbstractQueueHandler { + + private static final int DEFAULT_INITIAL_CAPACITY = 11; // same as PriorityQueue + + private final PriorityQueue<TupleWithKey> queue; + + QueueHandler(@Nonnull Comparator<TupleWithKey> comparator) { + this.queue = new PriorityQueue<TupleWithKey>(DEFAULT_INITIAL_CAPACITY, comparator); + } + + @Override + void offer(TupleWithKey tuple) { + queue.offer(tuple); + } + + @Override + int size() { + return queue.size(); + } + + @Override + TupleWithKey poll() { + return queue.poll(); + } + + @Override + void clear() { + queue.clear(); + } + + } + + private static final class BoundedQueueHandler extends AbstractQueueHandler { + + private final BoundedPriorityQueue<TupleWithKey> queue; + + BoundedQueueHandler(int size, @Nonnull Comparator<TupleWithKey> comparator) { + this.queue = new BoundedPriorityQueue<TupleWithKey>(size, comparator); + } + + @Override + void offer(TupleWithKey tuple) { + queue.offer(tuple); + } + + @Override + int size() { + return queue.size(); + } + + @Override + TupleWithKey poll() { + return queue.poll(); + } + + @Override + void clear() { + queue.clear(); + } + + } + + private static final class TupleWithKey implements Comparable<TupleWithKey> { + private Object key; + private Object value; + + TupleWithKey(Object key, Object value) { + this.key = key; + this.value = value; + } + + Object getKey() { + return key; + } + + Object getValue() { + return value; + } + + @Override + public int 
compareTo(TupleWithKey o) { + Comparable<? super Object> k = (Comparable<? super Object>) key; + return k.compareTo(o.getKey()); + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9cd3c59a/core/src/main/java/hivemall/tools/map/UDAFToOrderedMap.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/map/UDAFToOrderedMap.java b/core/src/main/java/hivemall/tools/map/UDAFToOrderedMap.java index 4d011cd..3e6caa4 100644 --- a/core/src/main/java/hivemall/tools/map/UDAFToOrderedMap.java +++ b/core/src/main/java/hivemall/tools/map/UDAFToOrderedMap.java @@ -20,23 +20,37 @@ package hivemall.tools.map; import hivemall.utils.hadoop.HiveUtils; +import java.util.ArrayList; import java.util.Collections; +import java.util.Map; +import java.util.SortedMap; import java.util.TreeMap; import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import 
org.apache.hadoop.io.IntWritable; + +import javax.annotation.Nonnull; /** * Convert two aggregated columns into a sorted key-value map. */ @Description(name = "to_ordered_map", - value = "_FUNC_(key, value [, const boolean reverseOrder=false]) " + value = "_FUNC_(key, value [, const int k|const boolean reverseOrder=false]) " + "- Convert two aggregated columns into an ordered key-value map") public class UDAFToOrderedMap extends UDAFToMap { @@ -54,19 +68,35 @@ public class UDAFToOrderedMap extends UDAFToMap { "Only primitive type arguments are accepted for the key but " + typeInfo[0].getTypeName() + " was passed as parameter 1."); } + boolean reverseOrder = false; + int size = 0; if (typeInfo.length == 3) { - if (HiveUtils.isBooleanTypeInfo(typeInfo[2]) == false) { - throw new UDFArgumentTypeException(2, "The three argument must be boolean type: " - + typeInfo[2].getTypeName()); - } ObjectInspector[] argOIs = info.getParameterObjectInspectors(); - reverseOrder = HiveUtils.getConstBoolean(argOIs[2]); + if (HiveUtils.isBooleanTypeInfo(typeInfo[2])) { + reverseOrder = HiveUtils.getConstBoolean(argOIs[2]); + } else if (HiveUtils.isIntegerTypeInfo(typeInfo[2])) { + size = HiveUtils.getConstInt(argOIs[2]); + if (size == 0) { + throw new UDFArgumentException("Map size must be nonzero: " + size); + } + reverseOrder = (size > 0); // positive size => top-k + } else { + throw new UDFArgumentTypeException(2, + "The third argument must be boolean or integer type: " + + typeInfo[2].getTypeName()); + } } - if (reverseOrder) { + if (reverseOrder) { // descending + if (size != 0) { + return new TopKOrderedMapEvaluator(); + } return new ReverseOrderedMapEvaluator(); - } else { + } else { // ascending + if (size != 0) { + return new TailKOrderedMapEvaluator(); + } return new NaturalOrderedMapEvaluator(); } } @@ -92,4 +122,172 @@ public class UDAFToOrderedMap extends UDAFToMap { } + public static class TopKOrderedMapEvaluator extends GenericUDAFEvaluator { + + protected 
PrimitiveObjectInspector inputKeyOI; + protected ObjectInspector inputValueOI; + protected StandardMapObjectInspector partialMapOI; + protected PrimitiveObjectInspector sizeOI; + + protected StructObjectInspector internalMergeOI; + + protected StructField partialMapField; + protected StructField sizeField; + + @Override + public ObjectInspector init(Mode mode, ObjectInspector[] argOIs) throws HiveException { + super.init(mode, argOIs); + + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data + this.inputKeyOI = HiveUtils.asPrimitiveObjectInspector(argOIs[0]); + this.inputValueOI = argOIs[1]; + this.sizeOI = HiveUtils.asIntegerOI(argOIs[2]); + } else {// from partial aggregation + StructObjectInspector soi = (StructObjectInspector) argOIs[0]; + this.internalMergeOI = soi; + + this.partialMapField = soi.getStructFieldRef("partialMap"); + // re-extract input key/value OIs + StandardMapObjectInspector partialMapOI = (StandardMapObjectInspector) partialMapField.getFieldObjectInspector(); + this.inputKeyOI = HiveUtils.asPrimitiveObjectInspector(partialMapOI.getMapKeyObjectInspector()); + this.inputValueOI = partialMapOI.getMapValueObjectInspector(); + + this.partialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector( + ObjectInspectorUtils.getStandardObjectInspector(inputKeyOI), + ObjectInspectorUtils.getStandardObjectInspector(inputValueOI)); + + this.sizeField = soi.getStructFieldRef("size"); + this.sizeOI = (PrimitiveObjectInspector) sizeField.getFieldObjectInspector(); + } + + // initialize output + final ObjectInspector outputOI; + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial + outputOI = internalMergeOI(inputKeyOI, inputValueOI); + } else {// terminate + outputOI = ObjectInspectorFactory.getStandardMapObjectInspector( + ObjectInspectorUtils.getStandardObjectInspector(inputKeyOI), + ObjectInspectorUtils.getStandardObjectInspector(inputValueOI)); + } + return outputOI; + } + + private 
static StructObjectInspector internalMergeOI( + @Nonnull PrimitiveObjectInspector keyOI, @Nonnull ObjectInspector valueOI) { + ArrayList<String> fieldNames = new ArrayList<String>(); + ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + + fieldNames.add("partialMap"); + fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector( + ObjectInspectorUtils.getStandardObjectInspector(keyOI), + ObjectInspectorUtils.getStandardObjectInspector(valueOI))); + + fieldNames.add("size"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + static class MapAggregationBuffer extends AbstractAggregationBuffer { + Map<Object, Object> container; + int size; + + MapAggregationBuffer() { + super(); + } + } + + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + MapAggregationBuffer myagg = (MapAggregationBuffer) agg; + myagg.container = new TreeMap<Object, Object>(Collections.reverseOrder()); + myagg.size = Integer.MAX_VALUE; + } + + @Override + public MapAggregationBuffer getNewAggregationBuffer() throws HiveException { + MapAggregationBuffer myagg = new MapAggregationBuffer(); + reset(myagg); + return myagg; + } + + @Override + public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, + Object[] parameters) throws HiveException { + assert (parameters.length == 3); + + if (parameters[0] == null) { + return; + } + + Object key = ObjectInspectorUtils.copyToStandardObject(parameters[0], inputKeyOI); + Object value = ObjectInspectorUtils.copyToStandardObject(parameters[1], inputValueOI); + int size = Math.abs(HiveUtils.getInt(parameters[2], sizeOI)); // size could be negative for tail-k + + MapAggregationBuffer myagg = (MapAggregationBuffer) agg; + myagg.container.put(key, value); + myagg.size = size; + } + + @Override + public Object 
terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + MapAggregationBuffer myagg = (MapAggregationBuffer) agg; + + Object[] partialResult = new Object[2]; + partialResult[0] = myagg.container; + partialResult[1] = new IntWritable(myagg.size); + + return partialResult; + } + + @Override + public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial) + throws HiveException { + if (partial == null) { + return; + } + + MapAggregationBuffer myagg = (MapAggregationBuffer) agg; + + Object partialMapObj = internalMergeOI.getStructFieldData(partial, partialMapField); + Map<?, ?> partialMap = partialMapOI.getMap(HiveUtils.castLazyBinaryObject(partialMapObj)); + for (Map.Entry<?, ?> e : partialMap.entrySet()) { + Object key = ObjectInspectorUtils.copyToStandardObject(e.getKey(), inputKeyOI); + Object value = ObjectInspectorUtils.copyToStandardObject(e.getValue(), inputValueOI); + myagg.container.put(key, value); + } + + Object sizeObj = internalMergeOI.getStructFieldData(partial, sizeField); + int size = HiveUtils.getInt(sizeObj, sizeOI); + myagg.size = size; + } + + @Override + public Map<Object, Object> terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + MapAggregationBuffer myagg = (MapAggregationBuffer) agg; + if (myagg.size < myagg.container.size()) { + Object toKey = myagg.container.keySet().toArray()[myagg.size]; + return ((SortedMap<Object, Object>) myagg.container).headMap(toKey); + } + return myagg.container; + } + + } + + public static class TailKOrderedMapEvaluator extends TopKOrderedMapEvaluator { + + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + MapAggregationBuffer myagg = (MapAggregationBuffer) agg; + myagg.container = new TreeMap<Object, Object>(); + myagg.size = Integer.MAX_VALUE; + } + + } + } 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9cd3c59a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index db56b82..afa8a58 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -440,6 +440,13 @@ public final class HiveUtils { return PrimitiveObjectInspectorUtils.getDouble(o, oi); } + public static int getInt(@Nullable Object o, @Nonnull PrimitiveObjectInspector oi) { + if (o == null) { + return 0; + } + return PrimitiveObjectInspectorUtils.getInt(o, oi); + } + @SuppressWarnings("unchecked") @Nullable public static <T extends Writable> T getConstValue(@Nonnull final ObjectInspector oi) http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9cd3c59a/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java b/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java deleted file mode 100644 index 3e3fc12..0000000 --- a/core/src/test/java/hivemall/tools/array/SelectKBeatUDFTest.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.tools.array; - -import hivemall.utils.hadoop.WritableUtils; - -import java.util.List; - -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; -import org.apache.hadoop.hive.serde2.io.DoubleWritable; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.junit.Assert; -import org.junit.Test; - -public class SelectKBeatUDFTest { - - @Test - public void test() throws Exception { - final SelectKBestUDF selectKBest = new SelectKBestUDF(); - final int k = 2; - final double[] data = new double[] {250.29999999999998, 170.90000000000003, 73.2, - 12.199999999999996}; - final double[] importanceList = new double[] {292.1666753739119, 152.70000455081467, - 187.93333893418327, 59.93333511948589}; - - final GenericUDF.DeferredObject[] dObjs = new GenericUDF.DeferredObject[] { - new GenericUDF.DeferredJavaObject(WritableUtils.toWritableList(data)), - new GenericUDF.DeferredJavaObject(WritableUtils.toWritableList(importanceList)), - new GenericUDF.DeferredJavaObject(k)}; - - selectKBest.initialize(new ObjectInspector[] { - ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), - ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), - 
ObjectInspectorUtils.getConstantObjectInspector( - PrimitiveObjectInspectorFactory.javaIntObjectInspector, k)}); - final List<DoubleWritable> resultObj = selectKBest.evaluate(dObjs); - - Assert.assertEquals(resultObj.size(), k); - - final double[] result = new double[k]; - for (int i = 0; i < k; i++) { - result[i] = resultObj.get(i).get(); - } - - final double[] answer = new double[] {250.29999999999998, 73.2}; - - Assert.assertArrayEquals(answer, result, 0.d); - selectKBest.close(); - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9cd3c59a/core/src/test/java/hivemall/tools/array/SelectKBestUDFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/tools/array/SelectKBestUDFTest.java b/core/src/test/java/hivemall/tools/array/SelectKBestUDFTest.java new file mode 100644 index 0000000..15366a7 --- /dev/null +++ b/core/src/test/java/hivemall/tools/array/SelectKBestUDFTest.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.tools.array; + +import hivemall.utils.hadoop.WritableUtils; + +import java.util.List; + +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.junit.Assert; +import org.junit.Test; + +public class SelectKBestUDFTest { + + @Test + public void test() throws Exception { + final SelectKBestUDF selectKBest = new SelectKBestUDF(); + final int k = 2; + final double[] data = new double[] {250.29999999999998, 170.90000000000003, 73.2, + 12.199999999999996}; + final double[] importanceList = new double[] {292.1666753739119, 152.70000455081467, + 187.93333893418327, 59.93333511948589}; + + final GenericUDF.DeferredObject[] dObjs = new GenericUDF.DeferredObject[] { + new GenericUDF.DeferredJavaObject(WritableUtils.toWritableList(data)), + new GenericUDF.DeferredJavaObject(WritableUtils.toWritableList(importanceList)), + new GenericUDF.DeferredJavaObject(k)}; + + selectKBest.initialize(new ObjectInspector[] { + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaIntObjectInspector, k)}); + final List<DoubleWritable> resultObj = selectKBest.evaluate(dObjs); + + Assert.assertEquals(resultObj.size(), k); + + final double[] result = new double[k]; + for (int i = 0; i < k; i++) { + result[i] = resultObj.get(i).get(); + } + + final double[] answer = new double[] {250.29999999999998, 73.2}; + + 
Assert.assertArrayEquals(answer, result, 0.d); + selectKBest.close(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9cd3c59a/core/src/test/java/hivemall/tools/list/UDAFToOrderedListTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/tools/list/UDAFToOrderedListTest.java b/core/src/test/java/hivemall/tools/list/UDAFToOrderedListTest.java new file mode 100644 index 0000000..f466dbc --- /dev/null +++ b/core/src/test/java/hivemall/tools/list/UDAFToOrderedListTest.java @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.tools.list; + +import hivemall.tools.list.UDAFToOrderedList.UDAFToOrderedListEvaluator; +import hivemall.tools.list.UDAFToOrderedList.UDAFToOrderedListEvaluator.QueueAggregationBuffer; + +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.List; + +@SuppressWarnings("deprecation") +public class UDAFToOrderedListTest { + + GenericUDAFEvaluator evaluator; + QueueAggregationBuffer agg; + + @Before + public void setUp() throws Exception { + this.evaluator = new UDAFToOrderedListEvaluator(); + this.agg = (QueueAggregationBuffer) evaluator.getNewAggregationBuffer(); + } + + @Test + public void testNaturalOrder() throws Exception { + ObjectInspector[] inputOIs = new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaDoubleObjectInspector}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i]}); + } + + List<Object> res = (List<Object>) evaluator.terminate(agg); + + Assert.assertEquals(3, res.size()); + Assert.assertEquals("apple", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + Assert.assertEquals("candy", res.get(2)); + } + + @Test + public void testReverseOrder() throws Exception { + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-reverse_order")}; + + final String[] values = new String[] {"banana", 
"apple", "candy"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i]}); + } + + List<Object> res = (List<Object>) evaluator.terminate(agg); + + Assert.assertEquals(3, res.size()); + Assert.assertEquals("candy", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + Assert.assertEquals("apple", res.get(2)); + } + + @Test + public void testTopK() throws Exception { + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k 2")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i]}); + } + + List<Object> res = (List<Object>) evaluator.terminate(agg); + + Assert.assertEquals(2, res.size()); + Assert.assertEquals("candy", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + } + + @Test + public void testReverseTopK() throws Exception { + // = tail-k + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k 2 -reverse")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i]}); + } + + List<Object> res = (List<Object>) evaluator.terminate(agg); + + Assert.assertEquals(2, res.size()); + Assert.assertEquals("apple", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + } + + @Test + 
public void testTailK() throws Exception { + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k -2")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i]}); + } + + List<Object> res = (List<Object>) evaluator.terminate(agg); + + Assert.assertEquals(2, res.size()); + Assert.assertEquals("apple", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + } + + @Test + public void testReverseTailK() throws Exception { + // = top-k + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k -2 -reverse")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i]}); + } + + List<Object> res = (List<Object>) evaluator.terminate(agg); + + Assert.assertEquals(2, res.size()); + Assert.assertEquals("candy", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + } + + @Test + public void testNaturalOrderWithKey() throws Exception { + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + final double[] keys = new double[] {0.7, 0.5, 0.7}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int 
i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + List<Object> res = (List<Object>) evaluator.terminate(agg); + + Assert.assertEquals(3, res.size()); + Assert.assertEquals("apple", res.get(0)); + if (res.get(1) == "banana") { // duplicated key (0.7) + Assert.assertEquals("candy", res.get(2)); + } else { + Assert.assertEquals("banana", res.get(2)); + } + } + + @Test + public void testReverseOrderWithKey() throws Exception { + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-reverse_order")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + final double[] keys = new double[] {0.7, 0.5, 0.7}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + List<Object> res = (List<Object>) evaluator.terminate(agg); + + Assert.assertEquals(3, res.size()); + if (res.get(0) == "banana") { // duplicated key (0.7) + Assert.assertEquals("candy", res.get(1)); + } else { + Assert.assertEquals("banana", res.get(1)); + } + Assert.assertEquals("apple", res.get(2)); + } + + @Test + public void testTopKWithKey() throws Exception { + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k 2")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + final double[] keys = new double[] {0.7, 0.5, 0.8}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + 
evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + List<Object> res = (List<Object>) evaluator.terminate(agg); + + Assert.assertEquals(2, res.size()); + Assert.assertEquals("candy", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + } + + @Test + public void testReverseTopKWithKey() throws Exception { + // = tail-k + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k 2 -reverse")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + final double[] keys = new double[] {0.7, 0.5, 0.8}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + List<Object> res = (List<Object>) evaluator.terminate(agg); + + Assert.assertEquals(2, res.size()); + Assert.assertEquals("apple", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + } + + @Test + public void testTailKWithKey() throws Exception { + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k -2")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + final double[] keys = new double[] {0.7, 0.5, 0.8}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + List<Object> res = (List<Object>) evaluator.terminate(agg); + + 
Assert.assertEquals(2, res.size()); + Assert.assertEquals("apple", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + } + + @Test + public void testReverseTailKWithKey() throws Exception { + // = top-k + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k -2 -reverse")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + final double[] keys = new double[] {0.7, 0.5, 0.8}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + List<Object> res = (List<Object>) evaluator.terminate(agg); + + Assert.assertEquals(2, res.size()); + Assert.assertEquals("candy", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9cd3c59a/core/src/test/java/hivemall/tools/map/UDAFToOrderedMapTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/tools/map/UDAFToOrderedMapTest.java b/core/src/test/java/hivemall/tools/map/UDAFToOrderedMapTest.java new file mode 100644 index 0000000..9289a02 --- /dev/null +++ b/core/src/test/java/hivemall/tools/map/UDAFToOrderedMapTest.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.tools.map; + +import hivemall.tools.map.UDAFToOrderedMap.NaturalOrderedMapEvaluator; +import hivemall.tools.map.UDAFToOrderedMap.ReverseOrderedMapEvaluator; +import hivemall.tools.map.UDAFToOrderedMap.TopKOrderedMapEvaluator; +import hivemall.tools.map.UDAFToOrderedMap.TailKOrderedMapEvaluator; + +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.SortedMap; + +@SuppressWarnings("deprecation") +public class UDAFToOrderedMapTest { + + @Test + public void testNaturalOrder() throws Exception { + GenericUDAFEvaluator evaluator = new NaturalOrderedMapEvaluator(); + NaturalOrderedMapEvaluator.MapAggregationBuffer agg = (NaturalOrderedMapEvaluator.MapAggregationBuffer) evaluator.getNewAggregationBuffer(); + + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + PrimitiveObjectInspectorFactory.javaStringObjectInspector}; + + final double[] keys = new double[] {0.7, 0.5, 0.8}; + final String[] values = new String[] {"banana", "apple", "candy"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < keys.length; i++) { + evaluator.iterate(agg, new Object[] {keys[i], values[i]}); + } + + SortedMap<Object, Object> res = (SortedMap<Object, Object>) 
evaluator.terminate(agg); + Object[] sortedValues = res.values().toArray(); + + Assert.assertEquals(3, sortedValues.length); + Assert.assertEquals("apple", sortedValues[0]); + Assert.assertEquals("banana", sortedValues[1]); + Assert.assertEquals("candy", sortedValues[2]); + } + + @Test + public void testReverseOrder() throws Exception { + GenericUDAFEvaluator evaluator = new ReverseOrderedMapEvaluator(); + ReverseOrderedMapEvaluator.MapAggregationBuffer agg = (ReverseOrderedMapEvaluator.MapAggregationBuffer) evaluator.getNewAggregationBuffer(); + + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaBooleanObjectInspector}; + + final double[] keys = new double[] {0.7, 0.5, 0.8}; + final String[] values = new String[] {"banana", "apple", "candy"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < keys.length; i++) { + evaluator.iterate(agg, new Object[] {keys[i], values[i]}); + } + + SortedMap<Object, Object> res = (SortedMap<Object, Object>) evaluator.terminate(agg); + Object[] sortedValues = res.values().toArray(); + + Assert.assertEquals(3, sortedValues.length); + Assert.assertEquals("candy", sortedValues[0]); + Assert.assertEquals("banana", sortedValues[1]); + Assert.assertEquals("apple", sortedValues[2]); + } + + @Test + public void testTopK() throws Exception { + GenericUDAFEvaluator evaluator = new TopKOrderedMapEvaluator(); + TopKOrderedMapEvaluator.MapAggregationBuffer agg = (TopKOrderedMapEvaluator.MapAggregationBuffer) evaluator.getNewAggregationBuffer(); + + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector}; + + final double[] keys = new double[] {0.7, 
0.5, 0.8}; + final String[] values = new String[] {"banana", "apple", "candy"}; + int size = 2; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < keys.length; i++) { + evaluator.iterate(agg, new Object[] {keys[i], values[i], size}); + } + + SortedMap<Object, Object> res = (SortedMap<Object, Object>) evaluator.terminate(agg); + Object[] sortedValues = res.values().toArray(); + + Assert.assertEquals(size, sortedValues.length); + Assert.assertEquals("candy", sortedValues[0]); + Assert.assertEquals("banana", sortedValues[1]); + } + + @Test + public void testTailK() throws Exception { + GenericUDAFEvaluator evaluator = new TailKOrderedMapEvaluator(); + TailKOrderedMapEvaluator.MapAggregationBuffer agg = (TailKOrderedMapEvaluator.MapAggregationBuffer) evaluator.getNewAggregationBuffer(); + + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector}; + + final double[] keys = new double[] {0.7, 0.5, 0.8}; + final String[] values = new String[] {"banana", "apple", "candy"}; + int size = -2; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < keys.length; i++) { + evaluator.iterate(agg, new Object[] {keys[i], values[i], size}); + } + + SortedMap<Object, Object> res = (SortedMap<Object, Object>) evaluator.terminate(agg); + Object[] sortedValues = res.values().toArray(); + + Assert.assertEquals(Math.abs(size), sortedValues.length); + Assert.assertEquals("apple", sortedValues[0]); + Assert.assertEquals("banana", sortedValues[1]); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9cd3c59a/docs/gitbook/eval/rank.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/eval/rank.md b/docs/gitbook/eval/rank.md index 
ed1a44c..db681ac 100644 --- a/docs/gitbook/eval/rank.md +++ b/docs/gitbook/eval/rank.md @@ -28,6 +28,11 @@ Practical machine learning applications such as information retrieval and recomm This page focuses on evaluation of the results from such ranking problems. +> #### Caution +> In order to obtain a ranked list of items, this page introduces queries using `to_ordered_map()` such as `map_values(to_ordered_map(score, itemid, true))`. However, this kind of usage has a potential issue that multiple `itemid`-s (i.e., values) which have exactly the same `score` (i.e., key) will be aggregated to a single arbitrary `itemid`, because `to_ordered_map()` creates a key-value map which uses the duplicated `score` as key. +> +> Hence, if a map key could be duplicated across more than one map value, we recommend using `to_ordered_list(value, key, '-reverse')` instead of `map_values(to_ordered_map(key, value, true))`. The alternative approach is available from Hivemall v0.5-rc.1 or later. + # Binary Response Measures In a context of ranking problem, **binary response** means that binary labels are assigned to items, and positive items are considered as *truth* observations. http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9cd3c59a/docs/gitbook/misc/generic_funcs.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/misc/generic_funcs.md b/docs/gitbook/misc/generic_funcs.md index b3a0421..b27117f 100644 --- a/docs/gitbook/misc/generic_funcs.md +++ b/docs/gitbook/misc/generic_funcs.md @@ -83,6 +83,10 @@ This page describes a list of useful Hivemall generic functions. 
- `array_sum(array<NUMBER>)` - Returns an array<double> in which each element is summed up +## List UDAF + +- `to_ordered_list(value [, const string options])` or `to_ordered_list(value, key [, const string options])` - Return list of values sorted by value itself or specific key + # Bitset functions ## Bitset UDF @@ -141,8 +145,7 @@ The compression level must be in range [-1,9] - `to_map(key, value)` - Convert two aggregated columns into a key-value map -- `to_ordered_map(key, value [, const boolean reverseOrder=false])` - Convert two aggregated columns into an ordered key-value map - +- `to_ordered_map(key, value [, const int|boolean k|reverseOrder=false])` - Convert two aggregated columns into an ordered key-value map # MapReduce functions http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9cd3c59a/docs/gitbook/misc/topk.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/misc/topk.md b/docs/gitbook/misc/topk.md index 6a80514..27cf7ad 100644 --- a/docs/gitbook/misc/topk.md +++ b/docs/gitbook/misc/topk.md @@ -379,3 +379,66 @@ FROM | 4 | 0.4432108402252197 | 3 | 26220 | 1 | | 5 | 0.44323229789733887 | 3 | 18541 | 0 | | .. | .. | .. | .. | .. | + +# Alternative approaches + +In order to utilize mapper-side aggregation and reduce computational cost of shuffling, you can use [`to_ordered_map`](./generic_funcs.md#map-udafs) or [`to_ordered_list`](./generic_funcs.md#list-udaf) to get top/tail-k elements instead of `each_top_k`. 
+ +As long as `key` is unique in each `id`, the following queries return same result: + +```sql +with t as ( + select + each_top_k( + 10, id, key, + id, value + ) as (rank, key, id, value) + from ( + select + * + from + test + cluster by + id + ) t +) +select + id, collect_list(value) as topk +from + t +group by + id +``` + +```sql +with t as ( + select + id, to_ordered_map(key, value, 10) as m + from + test + group by + id +) +select + id, collect_list(value) as topk +from + t +lateral view explode(m) t as key, value +group by + id +``` + +```sql +select + id, to_ordered_list(value, key, '-k 10') as topk +from + test +group by + id +``` + +> #### Caution +> +> In case that `key` could duplicate in `id`, `to_ordered_map` behaves differently because key `K` is always unique in `Map<K, V>`. + +Similarly to `each_top_k`, tail-k can also be represented as: `to_ordered_map(key, value, -10)` and `to_ordered_list(value, key, '-k -10')`. \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9cd3c59a/docs/gitbook/recommend/item_based_cf.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/recommend/item_based_cf.md b/docs/gitbook/recommend/item_based_cf.md index 9e4f7e4..053b225 100644 --- a/docs/gitbook/recommend/item_based_cf.md +++ b/docs/gitbook/recommend/item_based_cf.md @@ -437,6 +437,11 @@ from ( In order to generate a list of recommended items, you can use either cooccurrence count or similarity as a relevance score. +> #### Caution +> In order to obtain ranked list of items, this section introduces queries using `map_values(to_ordered_map(rank, rec_item))`. However, this kind of usage has a potential issue that multiple `rec_item`-s which have the exactly same `rank` will be aggregated to single arbitrary `rec_item`, because `to_ordered_map()` creates a key-value map which uses duplicated `rank` as key. 
+> +> Since such a situation is possible when `each_top_k()` is executed for different `userid`-s who have the same `cnt` or `similarity`, we recommend using `to_ordered_list(rec_item, rank, '-reverse')` instead of `map_values(to_ordered_map(rank, rec_item, true))`. The alternative approach is available from Hivemall v0.5-rc.1 or later. + ### Cooccurrence-based ```sql http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9cd3c59a/docs/gitbook/recommend/movielens_cf.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/recommend/movielens_cf.md b/docs/gitbook/recommend/movielens_cf.md index faa555c..08268a8 100644 --- a/docs/gitbook/recommend/movielens_cf.md +++ b/docs/gitbook/recommend/movielens_cf.md @@ -21,6 +21,11 @@ <!-- toc --> +> #### Caution +> In order to obtain a ranked list of items, this page introduces queries using `to_ordered_map()` such as `map_values(to_ordered_map(rating, movieid, true))`. However, this kind of usage has a potential issue that multiple `movieid`-s (i.e., values) which have exactly the same `rating` (i.e., key) will be aggregated to a single arbitrary `movieid`, because `to_ordered_map()` creates a key-value map which uses the duplicated `rating` as key. +> +> Hence, if a map key could be duplicated across more than one map value, we recommend using `to_ordered_list(value, key, '-reverse')` instead of `map_values(to_ordered_map(key, value, true))`. The alternative approach is available from Hivemall v0.5-rc.1 or later. 
+ # Compute movie-movie similarity [As we explained in the general introduction of item-based CF](item_based_cf.html#dimsum-approximated-all-pairs-cosine-similarity-computation.md), following query finds top-$$k$$ nearest-neighborhood movies for each movie: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9cd3c59a/resources/ddl/define-all-as-permanent.hive ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive index c2b38fb..cda33f9 100644 --- a/resources/ddl/define-all-as-permanent.hive +++ b/resources/ddl/define-all-as-permanent.hive @@ -467,6 +467,13 @@ DROP FUNCTION IF EXISTS to_ordered_map; CREATE FUNCTION to_ordered_map as 'hivemall.tools.map.UDAFToOrderedMap' USING JAR '${hivemall_jar}'; --------------------- +-- list functions -- +--------------------- + +DROP FUNCTION IF EXISTS to_ordered_list; +CREATE FUNCTION to_ordered_list as 'hivemall.tools.list.UDAFToOrderedList' USING JAR '${hivemall_jar}'; + +--------------------- -- Math functions -- --------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9cd3c59a/resources/ddl/define-all.hive ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive index 89821f8..6e116ac 100644 --- a/resources/ddl/define-all.hive +++ b/resources/ddl/define-all.hive @@ -459,6 +459,13 @@ drop temporary function if exists to_ordered_map; create temporary function to_ordered_map as 'hivemall.tools.map.UDAFToOrderedMap'; --------------------- +-- list functions -- +--------------------- + +drop temporary function if exists to_ordered_list; +create temporary function to_ordered_list as 'hivemall.tools.list.UDAFToOrderedList'; + +--------------------- -- Math functions -- --------------------- 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9cd3c59a/resources/ddl/define-all.spark ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark index b4926e3..d3eb3cd 100644 --- a/resources/ddl/define-all.spark +++ b/resources/ddl/define-all.spark @@ -458,6 +458,13 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS to_ordered_map") sqlContext.sql("CREATE TEMPORARY FUNCTION to_ordered_map AS 'hivemall.tools.map.UDAFToOrderedMap'") /** + * List functions + */ + +sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS to_ordered_list") +sqlContext.sql("CREATE TEMPORARY FUNCTION to_ordered_list AS 'hivemall.tools.list.UDAFToOrderedList'") + +/** * Math functions */ http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9cd3c59a/resources/ddl/define-udfs.td.hql ---------------------------------------------------------------------- diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql index c7fdd49..2662260 100644 --- a/resources/ddl/define-udfs.td.hql +++ b/resources/ddl/define-udfs.td.hql @@ -177,6 +177,7 @@ create temporary function tree_export as 'hivemall.smile.tools.TreeExportUDF'; create temporary function train_ffm as 'hivemall.fm.FieldAwareFactorizationMachineUDTF'; create temporary function ffm_predict as 'hivemall.fm.FFMPredictGenericUDAF'; create temporary function add_field_indicies as 'hivemall.ftvec.trans.AddFieldIndicesUDF'; +create temporary function to_ordered_list as 'hivemall.tools.list.UDAFToOrderedList'; -- NLP features create temporary function tokenize_ja as 'hivemall.nlp.tokenizer.KuromojiUDF';
