HIVEMALL-138: Update `to_ordered_map` & implement `to_ordered_list`
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/e3b27280 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/e3b27280 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/e3b27280 Branch: refs/heads/dev/v0.4.2 Commit: e3b27280451ab30c7628312ca2648931cfed9433 Parents: 07a7d51 Author: Takuya Kitazawa <[email protected]> Authored: Wed Sep 20 16:21:47 2017 +0900 Committer: Takuya Kitazawa <[email protected]> Committed: Fri Sep 22 15:49:02 2017 +0900 ---------------------------------------------------------------------- .../hivemall/tools/list/UDAFToOrderedList.java | 549 +++++++++++++++++++ .../hivemall/tools/map/UDAFToOrderedMap.java | 283 ++++++++-- .../collections/maps/BoundedSortedMap.java | 59 ++ .../java/hivemall/utils/hadoop/HiveUtils.java | 32 ++ .../hivemall/utils/lang/NaturalComparator.java | 48 ++ .../java/hivemall/utils/lang/StringUtils.java | 22 +- .../main/java/hivemall/utils/struct/Pair.java | 38 ++ .../tools/list/UDAFToOrderedListTest.java | 342 ++++++++++++ .../tools/map/UDAFToOrderedMapTest.java | 159 ++++++ .../collections/BoundedPriorityQueueTest.java | 114 ++++ .../collections/maps/BoundedSortedMapTest.java | 84 +++ resources/ddl/define-all-as-permanent.hive | 9 +- resources/ddl/define-all.hive | 9 +- resources/ddl/define-all.spark | 7 + resources/ddl/define-udfs.td.hql | 1 + 15 files changed, 1718 insertions(+), 38 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3b27280/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java b/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java new file mode 100644 index 0000000..52bd533 --- /dev/null +++ 
b/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java @@ -0,0 +1,549 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.tools.list; + +import hivemall.utils.collections.BoundedPriorityQueue; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.CommandLineUtils; +import hivemall.utils.lang.NaturalComparator; +import hivemall.utils.lang.Preconditions; +import hivemall.utils.struct.Pair; + +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.PriorityQueue; + +import javax.annotation.CheckForNull; +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.HelpFormatter; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import 
org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.io.BooleanWritable; +import org.apache.hadoop.io.IntWritable; + +/** + * Return list of values sorted by value itself or specific key. 
+ */ +@Description(name = "to_ordered_list", + value = "_FUNC_(PRIMITIVE value [, PRIMITIVE key, const string options])" + + " - Return list of values sorted by value itself or specific key") +public final class UDAFToOrderedList extends AbstractGenericUDAFResolver { + + @Override + public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) + throws SemanticException { + @SuppressWarnings("deprecation") + TypeInfo[] typeInfo = info.getParameters(); + ObjectInspector[] argOIs = info.getParameterObjectInspectors(); + if ((typeInfo.length == 1) || (typeInfo.length == 2 && HiveUtils.isConstString(argOIs[1]))) { + // sort values by value itself w/o key + if (typeInfo[0].getCategory() != ObjectInspector.Category.PRIMITIVE) { + throw new UDFArgumentTypeException(0, + "Only primitive type arguments are accepted for value but " + + typeInfo[0].getTypeName() + " was passed as the first parameter."); + } + } else if ((typeInfo.length == 2) + || (typeInfo.length == 3 && HiveUtils.isConstString(argOIs[2]))) { + // sort values by key + if (typeInfo[1].getCategory() != ObjectInspector.Category.PRIMITIVE) { + throw new UDFArgumentTypeException(1, + "Only primitive type arguments are accepted for key but " + + typeInfo[1].getTypeName() + " was passed as the second parameter."); + } + } else { + throw new UDFArgumentTypeException(typeInfo.length - 1, + "Number of arguments must be in [1, 3] including constant string for options: " + + typeInfo.length); + } + return new UDAFToOrderedListEvaluator(); + } + + public static class UDAFToOrderedListEvaluator extends GenericUDAFEvaluator { + + private ObjectInspector valueOI; + private PrimitiveObjectInspector keyOI; + + private ListObjectInspector valueListOI; + private ListObjectInspector keyListOI; + + private StructObjectInspector internalMergeOI; + + private StructField valueListField; + private StructField keyListField; + private StructField sizeField; + private StructField reverseOrderField; + + @Nonnegative + private 
int size; + private boolean reverseOrder; + private boolean sortByKey; + + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("k", true, "To top-k (positive) or tail-k (negative) ordered queue"); + opts.addOption("reverse", "reverse_order", false, + "Sort values by key in a reverse (e.g., descending) order [default: false]"); + return opts; + } + + @Nonnull + protected final CommandLine parseOptions(String optionValue) throws UDFArgumentException { + String[] args = optionValue.split("\\s+"); + Options opts = getOptions(); + opts.addOption("help", false, "Show function help"); + CommandLine cl = CommandLineUtils.parseOptions(args, opts); + + if (cl.hasOption("help")) { + Description funcDesc = getClass().getAnnotation(Description.class); + final String cmdLineSyntax; + if (funcDesc == null) { + cmdLineSyntax = getClass().getSimpleName(); + } else { + String funcName = funcDesc.name(); + cmdLineSyntax = funcName == null ? getClass().getSimpleName() + : funcDesc.value().replace("_FUNC_", funcDesc.name()); + } + StringWriter sw = new StringWriter(); + sw.write('\n'); + PrintWriter pw = new PrintWriter(sw); + HelpFormatter formatter = new HelpFormatter(); + formatter.printHelp(pw, HelpFormatter.DEFAULT_WIDTH, cmdLineSyntax, null, opts, + HelpFormatter.DEFAULT_LEFT_PAD, HelpFormatter.DEFAULT_DESC_PAD, null, true); + pw.flush(); + String helpMsg = sw.toString(); + throw new UDFArgumentException(helpMsg); + } + + return cl; + } + + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = null; + + int optionIndex = 1; + if (sortByKey) { + optionIndex = 2; + } + + int k = 0; + boolean reverseOrder = false; + if (argOIs.length >= optionIndex + 1) { + String rawArgs = HiveUtils.getConstString(argOIs[optionIndex]); + cl = parseOptions(rawArgs); + + reverseOrder = cl.hasOption("reverse_order"); + + if (cl.hasOption("k")) { + k = Integer.parseInt(cl.getOptionValue("k")); + if (k == 0) { + 
throw new UDFArgumentException("`k` must be non-zero value: " + k); + } + } + } + this.size = Math.abs(k); + + if ((k > 0 && reverseOrder) || (k < 0 && reverseOrder == false) + || (k == 0 && reverseOrder == false)) { + // top-k on reverse order = tail-k on natural order (so, top-k on descending) + this.reverseOrder = true; + } else { // (k > 0 && reverseOrder == false) || (k < 0 && reverseOrder) || (k == 0 && reverseOrder) + // top-k on natural order = tail-k on reverse order (so, top-k on ascending) + this.reverseOrder = false; + } + + return cl; + } + + @Override + public ObjectInspector init(Mode mode, ObjectInspector[] argOIs) throws HiveException { + super.init(mode, argOIs); + + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data + // this flag will be used in `processOptions` and `iterate` (= when Mode.PARTIAL1 or Mode.COMPLETE) + this.sortByKey = (argOIs.length == 2 && !HiveUtils.isConstString(argOIs[1])) + || (argOIs.length == 3 && HiveUtils.isConstString(argOIs[2])); + + if (sortByKey) { + this.valueOI = HiveUtils.asPrimitiveObjectInspector(argOIs[0]); + this.keyOI = HiveUtils.asPrimitiveObjectInspector(argOIs[1]); + } else { + // sort values by value itself + this.valueOI = HiveUtils.asPrimitiveObjectInspector(argOIs[0]); + this.keyOI = HiveUtils.asPrimitiveObjectInspector(argOIs[0]); + } + + processOptions(argOIs); + } else {// from partial aggregation + StructObjectInspector soi = (StructObjectInspector) argOIs[0]; + this.internalMergeOI = soi; + + // re-extract input value OI + this.valueListField = soi.getStructFieldRef("valueList"); + StandardListObjectInspector valueListOI = (StandardListObjectInspector) valueListField.getFieldObjectInspector(); + this.valueOI = valueListOI.getListElementObjectInspector(); + this.valueListOI = ObjectInspectorFactory.getStandardListObjectInspector(valueOI); + + // re-extract input key OI + this.keyListField = soi.getStructFieldRef("keyList"); + StandardListObjectInspector 
keyListOI = (StandardListObjectInspector) keyListField.getFieldObjectInspector(); + this.keyOI = HiveUtils.asPrimitiveObjectInspector(keyListOI.getListElementObjectInspector()); + this.keyListOI = ObjectInspectorFactory.getStandardListObjectInspector(keyOI); + + this.sizeField = soi.getStructFieldRef("size"); + this.reverseOrderField = soi.getStructFieldRef("reverseOrder"); + } + + // initialize output + final ObjectInspector outputOI; + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial + outputOI = internalMergeOI(valueOI, keyOI); + } else {// terminate + outputOI = ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorUtils.getStandardObjectInspector(valueOI)); + } + + return outputOI; + } + + @Nonnull + private static StructObjectInspector internalMergeOI(@Nonnull ObjectInspector valueOI, + @Nonnull PrimitiveObjectInspector keyOI) { + List<String> fieldNames = new ArrayList<String>(); + List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + + fieldNames.add("valueList"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorUtils.getStandardObjectInspector(valueOI))); + fieldNames.add("keyList"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorUtils.getStandardObjectInspector(keyOI))); + fieldNames.add("size"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("reverseOrder"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableBooleanObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @SuppressWarnings("deprecation") + @Override + public AggregationBuffer getNewAggregationBuffer() throws HiveException { + QueueAggregationBuffer myagg = new QueueAggregationBuffer(); + reset(myagg); + return myagg; + } + + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + QueueAggregationBuffer 
myagg = (QueueAggregationBuffer) agg; + myagg.reset(size, reverseOrder); + } + + @Override + public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, + Object[] parameters) throws HiveException { + if (parameters[0] == null) { + return; + } + Object value = ObjectInspectorUtils.copyToStandardObject(parameters[0], valueOI); + + final Object key; + if (sortByKey) { + if (parameters[1] == null) { + return; + } + key = ObjectInspectorUtils.copyToStandardObject(parameters[1], keyOI); + } else { + // set value to key + key = ObjectInspectorUtils.copyToStandardObject(parameters[0], valueOI); + } + + TupleWithKey tuple = new TupleWithKey(key, value); + QueueAggregationBuffer myagg = (QueueAggregationBuffer) agg; + + myagg.iterate(tuple); + } + + @Override + public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + QueueAggregationBuffer myagg = (QueueAggregationBuffer) agg; + + Pair<List<Object>, List<Object>> tuples = myagg.drainQueue(); + List<Object> keyList = tuples.getKey(); + List<Object> valueList = tuples.getValue(); + if (valueList.isEmpty()) { + return null; + } + + Object[] partialResult = new Object[4]; + partialResult[0] = valueList; + partialResult[1] = keyList; + partialResult[2] = new IntWritable(myagg.size); + partialResult[3] = new BooleanWritable(myagg.reverseOrder); + return partialResult; + } + + @Override + public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial) + throws HiveException { + if (partial == null) { + return; + } + + Object valueListObj = internalMergeOI.getStructFieldData(partial, valueListField); + final List<?> valueListRaw = valueListOI.getList(HiveUtils.castLazyBinaryObject(valueListObj)); + final List<Object> valueList = new ArrayList<Object>(); + for (int i = 0, n = valueListRaw.size(); i < n; i++) { + valueList.add(ObjectInspectorUtils.copyToStandardObject(valueListRaw.get(i), + valueOI)); + } + + Object keyListObj = 
internalMergeOI.getStructFieldData(partial, keyListField); + final List<?> keyListRaw = keyListOI.getList(HiveUtils.castLazyBinaryObject(keyListObj)); + final List<Object> keyList = new ArrayList<Object>(); + for (int i = 0, n = keyListRaw.size(); i < n; i++) { + keyList.add(ObjectInspectorUtils.copyToStandardObject(keyListRaw.get(i), keyOI)); + } + + Object sizeObj = internalMergeOI.getStructFieldData(partial, sizeField); + int size = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(sizeObj); + + Object reverseOrderObj = internalMergeOI.getStructFieldData(partial, reverseOrderField); + boolean reverseOrder = PrimitiveObjectInspectorFactory.writableBooleanObjectInspector.get(reverseOrderObj); + + QueueAggregationBuffer myagg = (QueueAggregationBuffer) agg; + myagg.setOptions(size, reverseOrder); + myagg.merge(keyList, valueList); + } + + @Override + public List<Object> terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + QueueAggregationBuffer myagg = (QueueAggregationBuffer) agg; + Pair<List<Object>, List<Object>> tuples = myagg.drainQueue(); + return tuples.getValue(); + } + + static class QueueAggregationBuffer implements AggregationBuffer { + + private AbstractQueueHandler queueHandler; + + @Nonnegative + private int size; + private boolean reverseOrder; + + QueueAggregationBuffer() { + super(); + } + + void reset(@Nonnegative int size, boolean reverseOrder) { + setOptions(size, reverseOrder); + this.queueHandler = null; + } + + void setOptions(@Nonnegative int size, boolean reverseOrder) { + this.size = size; + this.reverseOrder = reverseOrder; + } + + void iterate(@Nonnull TupleWithKey tuple) { + if (queueHandler == null) { + initQueueHandler(); + } + queueHandler.offer(tuple); + } + + void merge(@Nonnull List<Object> o_keyList, @Nonnull List<Object> o_valueList) { + if (queueHandler == null) { + initQueueHandler(); + } + for (int i = 0, n = o_keyList.size(); i < n; i++) { + queueHandler.offer(new 
TupleWithKey(o_keyList.get(i), o_valueList.get(i))); + } + } + + @Nonnull + Pair<List<Object>, List<Object>> drainQueue() { + int n = queueHandler.size(); + final Object[] keys = new Object[n]; + final Object[] values = new Object[n]; + for (int i = n - 1; i >= 0; i--) { // head element in queue should be stored to tail of array + TupleWithKey tuple = queueHandler.poll(); + keys[i] = tuple.getKey(); + values[i] = tuple.getValue(); + } + queueHandler.clear(); + + return Pair.of(Arrays.asList(keys), Arrays.asList(values)); + } + + private void initQueueHandler() { + final Comparator<TupleWithKey> comparator; + if (reverseOrder) { + comparator = Collections.reverseOrder(); + } else { + comparator = NaturalComparator.getInstance(); + } + + if (size > 0) { + this.queueHandler = new BoundedQueueHandler(size, comparator); + } else { + this.queueHandler = new QueueHandler(comparator); + } + } + + } + + /** + * Since BoundedPriorityQueue does not directly inherit PriorityQueue, we provide handler + * class which wraps each of PriorityQueue and BoundedPriorityQueue. 
+ */ + private static abstract class AbstractQueueHandler { + + abstract void offer(@Nonnull TupleWithKey tuple); + + abstract int size(); + + @Nullable + abstract TupleWithKey poll(); + + abstract void clear(); + + } + + private static final class QueueHandler extends AbstractQueueHandler { + + private static final int DEFAULT_INITIAL_CAPACITY = 11; // same as PriorityQueue + + @Nonnull + private final PriorityQueue<TupleWithKey> queue; + + QueueHandler(@Nonnull Comparator<TupleWithKey> comparator) { + this.queue = new PriorityQueue<TupleWithKey>(DEFAULT_INITIAL_CAPACITY, comparator); + } + + @Override + void offer(TupleWithKey tuple) { + queue.offer(tuple); + } + + @Override + int size() { + return queue.size(); + } + + @Override + TupleWithKey poll() { + return queue.poll(); + } + + @Override + void clear() { + queue.clear(); + } + + } + + private static final class BoundedQueueHandler extends AbstractQueueHandler { + + @Nonnull + private final BoundedPriorityQueue<TupleWithKey> queue; + + BoundedQueueHandler(int size, @Nonnull Comparator<TupleWithKey> comparator) { + this.queue = new BoundedPriorityQueue<TupleWithKey>(size, comparator); + } + + @Override + void offer(TupleWithKey tuple) { + queue.offer(tuple); + } + + @Override + int size() { + return queue.size(); + } + + @Override + TupleWithKey poll() { + return queue.poll(); + } + + @Override + void clear() { + queue.clear(); + } + + } + + private static final class TupleWithKey implements Comparable<TupleWithKey> { + @Nonnull + private final Object key; + @Nonnull + private final Object value; + + TupleWithKey(@CheckForNull Object key, @CheckForNull Object value) { + this.key = Preconditions.checkNotNull(key); + this.value = Preconditions.checkNotNull(value); + } + + @Nonnull + Object getKey() { + return key; + } + + @Nonnull + Object getValue() { + return value; + } + + @Override + public int compareTo(TupleWithKey o) { + @SuppressWarnings("unchecked") + Comparable<? super Object> k = (Comparable<? 
super Object>) key; + return k.compareTo(o.getKey()); + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3b27280/core/src/main/java/hivemall/tools/map/UDAFToOrderedMap.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/map/UDAFToOrderedMap.java b/core/src/main/java/hivemall/tools/map/UDAFToOrderedMap.java index 5782180..97bb7b1 100644 --- a/core/src/main/java/hivemall/tools/map/UDAFToOrderedMap.java +++ b/core/src/main/java/hivemall/tools/map/UDAFToOrderedMap.java @@ -1,93 +1,308 @@ /* - * Hivemall: Hive scalable Machine Learning Library + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at * - * Copyright (C) 2015 Makoto YUI - * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) + * http://www.apache.org/licenses/LICENSE-2.0 * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
+ * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. */ package hivemall.tools.map; +import hivemall.utils.collections.maps.BoundedSortedMap; import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Preconditions; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; +import java.util.Map; import java.util.TreeMap; +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.parse.SemanticException; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.io.IntWritable; /** * Convert two aggregated columns into a sorted key-value map. 
*/ @Description(name = "to_ordered_map", - value = "_FUNC_(key, value [, const boolean reverseOrder=false]) " + value = "_FUNC_(key, value [, const int k|const boolean reverseOrder=false]) " + "- Convert two aggregated columns into an ordered key-value map") -public class UDAFToOrderedMap extends UDAFToMap { +public final class UDAFToOrderedMap extends UDAFToMap { @Override public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException { @SuppressWarnings("deprecation") - TypeInfo[] typeInfo = info.getParameters(); + final TypeInfo[] typeInfo = info.getParameters(); if (typeInfo.length != 2 && typeInfo.length != 3) { throw new UDFArgumentTypeException(typeInfo.length - 1, - "Expecting two or three arguments: " + typeInfo.length); + "Expecting two or three arguments: " + typeInfo.length); } if (typeInfo[0].getCategory() != ObjectInspector.Category.PRIMITIVE) { throw new UDFArgumentTypeException(0, - "Only primitive type arguments are accepted for the key but " - + typeInfo[0].getTypeName() + " was passed as parameter 1."); + "Only primitive type arguments are accepted for the key but " + + typeInfo[0].getTypeName() + " was passed as parameter 1."); } + boolean reverseOrder = false; + int size = 0; if (typeInfo.length == 3) { - if (HiveUtils.isBooleanTypeInfo(typeInfo[2]) == false) { - throw new UDFArgumentTypeException(2, "The three argument must be boolean type: " - + typeInfo[2].getTypeName()); - } ObjectInspector[] argOIs = info.getParameterObjectInspectors(); - reverseOrder = HiveUtils.getConstBoolean(argOIs[2]); + ObjectInspector argOI2 = argOIs[2]; + if (HiveUtils.isConstBoolean(argOI2)) { + reverseOrder = HiveUtils.getConstBoolean(argOI2); + } else if (HiveUtils.isConstInteger(argOI2)) { + size = HiveUtils.getConstInt(argOI2); + if (size == 0) { + throw new UDFArgumentException("Map size must be non-zero value: " + size); + } + reverseOrder = (size > 0); // positive size => top-k + } else { + throw new 
UDFArgumentTypeException(2, + "The third argument must be boolean or int type: " + typeInfo[2].getTypeName()); + } } - if (reverseOrder) { - return new ReverseOrdereMapEvaluator(); - } else { - return new NaturalOrdereMapEvaluator(); + if (reverseOrder) { // descending + if (size == 0) { + return new ReverseOrderedMapEvaluator(); + } else { + return new TopKOrderedMapEvaluator(); + } + } else { // ascending + if (size == 0) { + return new NaturalOrderedMapEvaluator(); + } else { + return new TailKOrderedMapEvaluator(); + } } } - public static class NaturalOrdereMapEvaluator extends UDAFToMapEvaluator { + public static class NaturalOrderedMapEvaluator extends UDAFToMapEvaluator { @Override - public void reset(AggregationBuffer agg) throws HiveException { + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { ((MapAggregationBuffer) agg).container = new TreeMap<Object, Object>(); } } - public static class ReverseOrdereMapEvaluator extends UDAFToMapEvaluator { + public static class ReverseOrderedMapEvaluator extends UDAFToMapEvaluator { @Override - public void reset(AggregationBuffer agg) throws HiveException { + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { ((MapAggregationBuffer) agg).container = new TreeMap<Object, Object>( - Collections.reverseOrder()); + Collections.reverseOrder()); + } + + } + + public static class TopKOrderedMapEvaluator extends GenericUDAFEvaluator { + + protected PrimitiveObjectInspector inputKeyOI; + protected ObjectInspector inputValueOI; + protected MapObjectInspector partialMapOI; + protected PrimitiveObjectInspector sizeOI; + + protected StructObjectInspector internalMergeOI; + + protected StructField partialMapField; + protected StructField sizeField; + + @Override + public ObjectInspector init(Mode mode, ObjectInspector[] argOIs) throws HiveException { + super.init(mode, argOIs); + + // initialize input + if (mode == Mode.PARTIAL1 || mode 
== Mode.COMPLETE) {// from original data + this.inputKeyOI = HiveUtils.asPrimitiveObjectInspector(argOIs[0]); + this.inputValueOI = argOIs[1]; + this.sizeOI = HiveUtils.asIntegerOI(argOIs[2]); + } else {// from partial aggregation + StructObjectInspector soi = (StructObjectInspector) argOIs[0]; + this.internalMergeOI = soi; + + this.partialMapField = soi.getStructFieldRef("partialMap"); + // re-extract input key/value OIs + MapObjectInspector partialMapOI = (MapObjectInspector) partialMapField.getFieldObjectInspector(); + this.inputKeyOI = HiveUtils.asPrimitiveObjectInspector(partialMapOI.getMapKeyObjectInspector()); + this.inputValueOI = partialMapOI.getMapValueObjectInspector(); + + this.partialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector( + ObjectInspectorUtils.getStandardObjectInspector(inputKeyOI), + ObjectInspectorUtils.getStandardObjectInspector(inputValueOI)); + + this.sizeField = soi.getStructFieldRef("size"); + this.sizeOI = (PrimitiveObjectInspector) sizeField.getFieldObjectInspector(); + } + + // initialize output + final ObjectInspector outputOI; + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial + outputOI = internalMergeOI(inputKeyOI, inputValueOI); + } else {// terminate + outputOI = ObjectInspectorFactory.getStandardMapObjectInspector( + ObjectInspectorUtils.getStandardObjectInspector(inputKeyOI), + ObjectInspectorUtils.getStandardObjectInspector(inputValueOI)); + } + return outputOI; + } + + @Nonnull + private static StructObjectInspector internalMergeOI( + @Nonnull PrimitiveObjectInspector keyOI, @Nonnull ObjectInspector valueOI) { + List<String> fieldNames = new ArrayList<String>(); + List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + + fieldNames.add("partialMap"); + fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector( + ObjectInspectorUtils.getStandardObjectInspector(keyOI), + ObjectInspectorUtils.getStandardObjectInspector(valueOI))); + + fieldNames.add("size"); + 
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + static class MapAggregationBuffer implements AggregationBuffer { + @Nullable + Map<Object, Object> container; + int size; + + MapAggregationBuffer() { + super(); + } + } + + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + MapAggregationBuffer myagg = (MapAggregationBuffer) agg; + myagg.container = null; + myagg.size = 0; } + @Override + public MapAggregationBuffer getNewAggregationBuffer() throws HiveException { + MapAggregationBuffer myagg = new MapAggregationBuffer(); + reset(myagg); + return myagg; + } + + @Override + public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, + Object[] parameters) throws HiveException { + assert (parameters.length == 3); + if (parameters[0] == null) { + return; + } + + Object key = ObjectInspectorUtils.copyToStandardObject(parameters[0], inputKeyOI); + Object value = ObjectInspectorUtils.copyToStandardObject(parameters[1], inputValueOI); + int size = Math.abs(HiveUtils.getInt(parameters[2], sizeOI)); // size could be negative for tail-k + + MapAggregationBuffer myagg = (MapAggregationBuffer) agg; + if (myagg.container == null) { + initBuffer(myagg, size); + } + myagg.container.put(key, value); + } + + void initBuffer(@Nonnull MapAggregationBuffer agg, @Nonnegative int size) { + Preconditions.checkArgument(size > 0, "size MUST be greather than zero: " + size); + + agg.container = new BoundedSortedMap<Object, Object>(size, true); + agg.size = size; + } + + @Override + public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + MapAggregationBuffer myagg = (MapAggregationBuffer) agg; + + Object[] partialResult = new Object[2]; + partialResult[0] = myagg.container; + partialResult[1] = new IntWritable(myagg.size); + + return 
partialResult; + } + + @Override + public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial) + throws HiveException { + if (partial == null) { + return; + } + + MapAggregationBuffer myagg = (MapAggregationBuffer) agg; + + Object partialMapObj = internalMergeOI.getStructFieldData(partial, partialMapField); + Map<?, ?> partialMap = partialMapOI.getMap(HiveUtils.castLazyBinaryObject(partialMapObj)); + if (partialMap == null) { + return; + } + + if (myagg.container == null) { + Object sizeObj = internalMergeOI.getStructFieldData(partial, sizeField); + int size = HiveUtils.getInt(sizeObj, sizeOI); + initBuffer(myagg, size); + } + for (Map.Entry<?, ?> e : partialMap.entrySet()) { + Object key = ObjectInspectorUtils.copyToStandardObject(e.getKey(), inputKeyOI); + Object value = ObjectInspectorUtils.copyToStandardObject(e.getValue(), inputValueOI); + myagg.container.put(key, value); + } + } + + @Override + @Nullable + public Map<Object, Object> terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + MapAggregationBuffer myagg = (MapAggregationBuffer) agg; + return myagg.container; + } + + } + + public static class TailKOrderedMapEvaluator extends TopKOrderedMapEvaluator { + + @Override + void initBuffer(MapAggregationBuffer agg, int size) { + agg.container = new BoundedSortedMap<Object, Object>(size); + agg.size = size; + } } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3b27280/core/src/main/java/hivemall/utils/collections/maps/BoundedSortedMap.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/collections/maps/BoundedSortedMap.java b/core/src/main/java/hivemall/utils/collections/maps/BoundedSortedMap.java new file mode 100644 index 0000000..b1bf806 --- /dev/null +++ b/core/src/main/java/hivemall/utils/collections/maps/BoundedSortedMap.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.collections.maps; + +import hivemall.utils.lang.Preconditions; + +import java.util.Collections; +import java.util.Map.Entry; +import java.util.TreeMap; + +import javax.annotation.CheckForNull; +import javax.annotation.Nonnegative; +import javax.annotation.Nullable; + +public final class BoundedSortedMap<K, V> extends TreeMap<K, V> { + private static final long serialVersionUID = 4580890152997313541L; + + private final int bound; + + public BoundedSortedMap(@Nonnegative int size) { + this(size, false); + } + + public BoundedSortedMap(@Nonnegative int size, boolean reverseOrder) { + super(reverseOrder ? 
Collections.reverseOrder() : null); + Preconditions.checkArgument(size > 0, "size must be greater than zero: " + size); + this.bound = size; + } + + @Nullable + public V put(@CheckForNull final K key, @Nullable final V value) { + final V old = super.put(key, value); + if (size() > bound) { + Entry<K, V> e = pollLastEntry(); + if (e == null) { + return null; + } + return e.getValue(); + } + return old; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3b27280/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index ad0dac6..1cc8607 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -50,6 +50,8 @@ import org.apache.hadoop.hive.serde2.io.ShortWritable; import org.apache.hadoop.hive.serde2.lazy.LazyInteger; import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe; import org.apache.hadoop.hive.serde2.lazy.LazyString; +import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray; +import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryMap; import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @@ -233,6 +235,18 @@ public final class HiveUtils { return category == Category.LIST; } + public static boolean isConstInt(@Nonnull final ObjectInspector oi) { + return ObjectInspectorUtils.isConstantObjectInspector(oi) && isIntOI(oi); + } + + public static boolean isConstInteger(@Nonnull final ObjectInspector oi) { + return ObjectInspectorUtils.isConstantObjectInspector(oi) && isIntegerOI(oi); + } + + public static boolean isConstBoolean(@Nonnull final ObjectInspector oi) { + return 
ObjectInspectorUtils.isConstantObjectInspector(oi) && isBooleanOI(oi); + } + public static boolean isPrimitiveTypeInfo(@Nonnull TypeInfo typeInfo) { return typeInfo.getCategory() == ObjectInspector.Category.PRIMITIVE; } @@ -524,6 +538,13 @@ public final class HiveUtils { return ary; } + public static int getInt(@Nullable Object o, @Nonnull PrimitiveObjectInspector oi) { + if (o == null) { + return 0; + } + return PrimitiveObjectInspectorUtils.getInt(o, oi); + } + /** * @return the number of true bits */ @@ -772,4 +793,15 @@ public final class HiveUtils { serde.initialize(conf, tbl); return serde; } + + @Nonnull + public static Object castLazyBinaryObject(@Nonnull final Object obj) { + if (obj instanceof LazyBinaryMap) { + return ((LazyBinaryMap) obj).getMap(); + } else if (obj instanceof LazyBinaryArray) { + return ((LazyBinaryArray) obj).getList(); + } + return obj; + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3b27280/core/src/main/java/hivemall/utils/lang/NaturalComparator.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lang/NaturalComparator.java b/core/src/main/java/hivemall/utils/lang/NaturalComparator.java new file mode 100644 index 0000000..d451f1b --- /dev/null +++ b/core/src/main/java/hivemall/utils/lang/NaturalComparator.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.lang; + +import java.util.Comparator; + +import javax.annotation.Nonnull; + +public final class NaturalComparator<T extends Comparable<? super T>> implements Comparator<T> { + + @SuppressWarnings("rawtypes") + private final static NaturalComparator INSTANCE = new NaturalComparator(); + + private NaturalComparator() {}// avoid instantiation + + @Override + public int compare(T o1, T o2) { + return o1.compareTo(o2); + } + + @SuppressWarnings("unchecked") + @Nonnull + public final static <T extends Comparable<? super T>> Comparator<T> getInstance() { + return (Comparator<T>) INSTANCE; + } + + @Nonnull + public final static <T extends Comparable<? super T>> Comparator<T> newInstance() { + return new NaturalComparator<T>(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3b27280/core/src/main/java/hivemall/utils/lang/StringUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lang/StringUtils.java b/core/src/main/java/hivemall/utils/lang/StringUtils.java index 16d92cb..c2d17ca 100644 --- a/core/src/main/java/hivemall/utils/lang/StringUtils.java +++ b/core/src/main/java/hivemall/utils/lang/StringUtils.java @@ -53,7 +53,7 @@ public final class StringUtils { /** * Checks whether the String a valid Java number. this code is ported from jakarta commons lang. 
- * + * * @link http://jakarta.apache.org/commons/lang/apidocs/org/apache/commons/lang * /math/NumberUtils.html */ @@ -97,7 +97,7 @@ public final class StringUtils { } else if (chars[i] == '.') { if (hasDecPoint || hasExp) { - // two decimal points or dec in exponent + // two decimal points or dec in exponent return false; } hasDecPoint = true; @@ -251,4 +251,22 @@ public final class StringUtils { } + public static int compare(@Nullable final String o1, @Nullable final String o2) { + return compare(o1, o2, true); + } + + public static int compare(@Nullable final String o1, @Nullable final String o2, + final boolean nullIsLess) { + if (o1 == o2) { + return 0; + } + if (o1 == null) { + return nullIsLess ? -1 : 1; + } + if (o2 == null) { + return nullIsLess ? 1 : -1; + } + return o1.compareTo(o2); + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3b27280/core/src/main/java/hivemall/utils/struct/Pair.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/struct/Pair.java b/core/src/main/java/hivemall/utils/struct/Pair.java new file mode 100644 index 0000000..32ea826 --- /dev/null +++ b/core/src/main/java/hivemall/utils/struct/Pair.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.struct; + +import java.util.AbstractMap; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +public class Pair<K, V> extends AbstractMap.SimpleEntry<K, V> { + private static final long serialVersionUID = 6411527075103472113L; + + public Pair(@Nullable K key, @Nullable V value) { + super(key, value); + } + + @Nonnull + public static <K, V> Pair<K, V> of(@Nullable K key, @Nullable V value) { + return new Pair<K, V>(key, value); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3b27280/core/src/test/java/hivemall/tools/list/UDAFToOrderedListTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/tools/list/UDAFToOrderedListTest.java b/core/src/test/java/hivemall/tools/list/UDAFToOrderedListTest.java new file mode 100644 index 0000000..78f2de6 --- /dev/null +++ b/core/src/test/java/hivemall/tools/list/UDAFToOrderedListTest.java @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.tools.list; + +import hivemall.tools.list.UDAFToOrderedList.UDAFToOrderedListEvaluator; +import hivemall.tools.list.UDAFToOrderedList.UDAFToOrderedListEvaluator.QueueAggregationBuffer; + +import java.util.List; + +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class UDAFToOrderedListTest { + + private UDAFToOrderedListEvaluator evaluator; + private QueueAggregationBuffer agg; + + @Before + public void setUp() throws Exception { + this.evaluator = new UDAFToOrderedListEvaluator(); + this.agg = (QueueAggregationBuffer) evaluator.getNewAggregationBuffer(); + } + + @Test + public void testNaturalOrder() throws Exception { + ObjectInspector[] inputOIs = new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector}; // the aggregated values are Strings, so declare a String inspector + + final String[] values = new String[] {"banana", "apple", "candy"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i]}); + } + + List<Object> res = evaluator.terminate(agg); + + Assert.assertEquals(3, res.size()); + Assert.assertEquals("apple", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + Assert.assertEquals("candy", res.get(2)); + } + + @Test + public void testReverseOrder() throws Exception { + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-reverse_order")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + + 
evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i]}); + } + + List<Object> res = evaluator.terminate(agg); + + Assert.assertEquals(3, res.size()); + Assert.assertEquals("candy", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + Assert.assertEquals("apple", res.get(2)); + } + + @Test + public void testTopK() throws Exception { + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k 2")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i]}); + } + + List<Object> res = evaluator.terminate(agg); + + Assert.assertEquals(2, res.size()); + Assert.assertEquals("candy", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + } + + @Test + public void testReverseTopK() throws Exception { + // = tail-k + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k 2 -reverse")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i]}); + } + + List<Object> res = evaluator.terminate(agg); + + Assert.assertEquals(2, res.size()); + Assert.assertEquals("apple", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + } + + @Test + public void testTailK() throws Exception { + ObjectInspector[] 
inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k -2")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i]}); + } + + List<Object> res = evaluator.terminate(agg); + + Assert.assertEquals(2, res.size()); + Assert.assertEquals("apple", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + } + + @Test + public void testReverseTailK() throws Exception { + // = top-k + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k -2 -reverse")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i]}); + } + + List<Object> res = evaluator.terminate(agg); + + Assert.assertEquals(2, res.size()); + Assert.assertEquals("candy", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + } + + @Test + public void testNaturalOrderWithKey() throws Exception { + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + final double[] keys = new double[] {0.7, 0.5, 0.7}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], 
keys[i]}); + } + + List<Object> res = evaluator.terminate(agg); + + Assert.assertEquals(3, res.size()); + Assert.assertEquals("apple", res.get(0)); + if ("banana".equals(res.get(1))) { // duplicated key (0.7): either tie order is accepted; compare by value, not reference + Assert.assertEquals("candy", res.get(2)); + } else { + Assert.assertEquals("banana", res.get(2)); + } + } + + @Test + public void testReverseOrderWithKey() throws Exception { + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-reverse_order")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + final double[] keys = new double[] {0.7, 0.5, 0.7}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + List<Object> res = evaluator.terminate(agg); + + Assert.assertEquals(3, res.size()); + if ("banana".equals(res.get(0))) { // duplicated key (0.7): either tie order is accepted; compare by value, not reference + Assert.assertEquals("candy", res.get(1)); + } else { + Assert.assertEquals("banana", res.get(1)); + } + Assert.assertEquals("apple", res.get(2)); + } + + @Test + public void testTopKWithKey() throws Exception { + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k 2")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + final double[] keys = new double[] {0.7, 0.5, 0.8}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], 
keys[i]}); + } + + List<Object> res = evaluator.terminate(agg); + + Assert.assertEquals(2, res.size()); + Assert.assertEquals("candy", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + } + + @Test + public void testReverseTopKWithKey() throws Exception { + // = tail-k + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k 2 -reverse")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + final double[] keys = new double[] {0.7, 0.5, 0.8}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + List<Object> res = evaluator.terminate(agg); + + Assert.assertEquals(2, res.size()); + Assert.assertEquals("apple", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + } + + @Test + public void testTailKWithKey() throws Exception { + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k -2")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + final double[] keys = new double[] {0.7, 0.5, 0.8}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + List<Object> res = evaluator.terminate(agg); + + Assert.assertEquals(2, res.size()); + Assert.assertEquals("apple", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + } + + @Test + public void 
testReverseTailKWithKey() throws Exception { + // = top-k + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k -2 -reverse")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + final double[] keys = new double[] {0.7, 0.5, 0.8}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + List<Object> res = evaluator.terminate(agg); + + Assert.assertEquals(2, res.size()); + Assert.assertEquals("candy", res.get(0)); + Assert.assertEquals("banana", res.get(1)); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3b27280/core/src/test/java/hivemall/tools/map/UDAFToOrderedMapTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/tools/map/UDAFToOrderedMapTest.java b/core/src/test/java/hivemall/tools/map/UDAFToOrderedMapTest.java new file mode 100644 index 0000000..38bc5ae --- /dev/null +++ b/core/src/test/java/hivemall/tools/map/UDAFToOrderedMapTest.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.tools.map; + +import hivemall.tools.map.UDAFToOrderedMap.NaturalOrderedMapEvaluator; +import hivemall.tools.map.UDAFToOrderedMap.ReverseOrderedMapEvaluator; +import hivemall.tools.map.UDAFToOrderedMap.TailKOrderedMapEvaluator; +import hivemall.tools.map.UDAFToOrderedMap.TopKOrderedMapEvaluator; + +import java.util.Map; + +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.junit.Assert; +import org.junit.Test; + +public class UDAFToOrderedMapTest { + + @Test + public void testNaturalOrder() throws Exception { + NaturalOrderedMapEvaluator evaluator = new NaturalOrderedMapEvaluator(); + NaturalOrderedMapEvaluator.MapAggregationBuffer agg = (NaturalOrderedMapEvaluator.MapAggregationBuffer) evaluator.getNewAggregationBuffer(); + + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + PrimitiveObjectInspectorFactory.javaStringObjectInspector}; + + final double[] keys = new double[] {0.7, 0.5, 0.8}; + final String[] values = new String[] {"banana", "apple", "candy"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < keys.length; i++) { + evaluator.iterate(agg, new Object[] {keys[i], values[i]}); + } + + Map<Object, Object> res = evaluator.terminate(agg); + Object[] sortedValues = res.values().toArray(); + + 
Assert.assertEquals(3, sortedValues.length); + Assert.assertEquals("apple", sortedValues[0]); + Assert.assertEquals("banana", sortedValues[1]); + Assert.assertEquals("candy", sortedValues[2]); + + evaluator.close(); + } + + @Test + public void testReverseOrder() throws Exception { + ReverseOrderedMapEvaluator evaluator = new ReverseOrderedMapEvaluator(); + ReverseOrderedMapEvaluator.MapAggregationBuffer agg = (ReverseOrderedMapEvaluator.MapAggregationBuffer) evaluator.getNewAggregationBuffer(); + + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaBooleanObjectInspector}; + + final double[] keys = new double[] {0.7, 0.5, 0.8}; + final String[] values = new String[] {"banana", "apple", "candy"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < keys.length; i++) { + evaluator.iterate(agg, new Object[] {keys[i], values[i]}); + } + + Map<Object, Object> res = evaluator.terminate(agg); + Object[] sortedValues = res.values().toArray(); + + Assert.assertEquals(3, sortedValues.length); + Assert.assertEquals("candy", sortedValues[0]); + Assert.assertEquals("banana", sortedValues[1]); + Assert.assertEquals("apple", sortedValues[2]); + + evaluator.close(); + } + + @Test + public void testTopK() throws Exception { + TopKOrderedMapEvaluator evaluator = new TopKOrderedMapEvaluator(); + TopKOrderedMapEvaluator.MapAggregationBuffer agg = (TopKOrderedMapEvaluator.MapAggregationBuffer) evaluator.getNewAggregationBuffer(); + + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector}; + + final double[] keys = new double[] {0.7, 0.5, 0.8}; + final String[] values = new String[] {"banana", 
"apple", "candy"}; + int size = 2; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < keys.length; i++) { + evaluator.iterate(agg, new Object[] {keys[i], values[i], size}); + } + + Map<Object, Object> res = evaluator.terminate(agg); + Object[] sortedValues = res.values().toArray(); + + Assert.assertEquals(size, sortedValues.length); + Assert.assertEquals("candy", sortedValues[0]); + Assert.assertEquals("banana", sortedValues[1]); + + evaluator.close(); + } + + @Test + public void testTailK() throws Exception { + TailKOrderedMapEvaluator evaluator = new TailKOrderedMapEvaluator(); + TailKOrderedMapEvaluator.MapAggregationBuffer agg = (TailKOrderedMapEvaluator.MapAggregationBuffer) evaluator.getNewAggregationBuffer(); + + ObjectInspector[] inputOIs = new ObjectInspector[] { + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector}; + + final double[] keys = new double[] {0.7, 0.5, 0.8}; + final String[] values = new String[] {"banana", "apple", "candy"}; + int size = -2; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < keys.length; i++) { + evaluator.iterate(agg, new Object[] {keys[i], values[i], size}); + } + + Map<Object, Object> res = evaluator.terminate(agg); + Object[] sortedValues = res.values().toArray(); + + Assert.assertEquals(Math.abs(size), sortedValues.length); + Assert.assertEquals("apple", sortedValues[0]); + Assert.assertEquals("banana", sortedValues[1]); + + evaluator.close(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3b27280/core/src/test/java/hivemall/utils/collections/BoundedPriorityQueueTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/BoundedPriorityQueueTest.java 
b/core/src/test/java/hivemall/utils/collections/BoundedPriorityQueueTest.java new file mode 100644 index 0000000..b9cfee0 --- /dev/null +++ b/core/src/test/java/hivemall/utils/collections/BoundedPriorityQueueTest.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.collections; + +import hivemall.utils.lang.NaturalComparator; +import hivemall.utils.lang.StringUtils; + +import java.util.Collections; +import java.util.Comparator; + +import org.junit.Assert; +import org.junit.Test; + +public class BoundedPriorityQueueTest { + + @Test + public void testTop3() { + BoundedPriorityQueue<Integer> queue = new BoundedPriorityQueue<Integer>(3, + new Comparator<Integer>() { + @Override + public int compare(Integer o1, Integer o2) { + return 
o1.compareTo(o2); // compare by value: == on boxed Integers tests identity and only works inside the small-value cache + } + }); + Assert.assertTrue(queue.offer(1)); + Assert.assertTrue(queue.offer(4)); + Assert.assertTrue(queue.offer(3)); + Assert.assertTrue(queue.offer(2)); + Assert.assertFalse(queue.offer(1)); + Assert.assertTrue(queue.offer(2)); + Assert.assertTrue(queue.offer(3)); + + Assert.assertEquals(3, queue.size()); + + Assert.assertEquals(Integer.valueOf(3), queue.peek()); + Assert.assertEquals(Integer.valueOf(3), queue.poll()); + Assert.assertEquals(Integer.valueOf(3), queue.poll()); + Assert.assertEquals(Integer.valueOf(4), queue.poll()); + Assert.assertNull(queue.poll()); + Assert.assertEquals(0, queue.size()); + } + + @Test + public void testTail3() { + BoundedPriorityQueue<Integer> queue = new BoundedPriorityQueue<Integer>(3, + Collections.<Integer>reverseOrder()); + Assert.assertTrue(queue.offer(1)); + Assert.assertTrue(queue.offer(4)); + Assert.assertTrue(queue.offer(3)); + Assert.assertTrue(queue.offer(2)); + Assert.assertTrue(queue.offer(1)); + Assert.assertTrue(queue.offer(2)); + Assert.assertFalse(queue.offer(3)); + + Assert.assertEquals(3, queue.size()); + + Assert.assertEquals(Integer.valueOf(2), queue.peek()); + Assert.assertEquals(Integer.valueOf(2), queue.poll()); + Assert.assertEquals(Integer.valueOf(1), queue.poll()); + Assert.assertEquals(Integer.valueOf(1), queue.poll()); + Assert.assertNull(queue.poll()); + Assert.assertEquals(0, queue.size()); + } + + @Test + public void testString1() { + BoundedPriorityQueue<String> queue = new BoundedPriorityQueue<String>(3, + new Comparator<String>() { + @Override + public int compare(String o1, String o2) { + return StringUtils.compare(o1, o2); + } + }); + queue.offer("B"); + queue.offer("A"); + queue.offer("C"); + queue.offer("D"); + Assert.assertEquals("B", queue.poll()); + Assert.assertEquals("C", queue.poll()); + Assert.assertEquals("D", queue.poll()); + Assert.assertNull(queue.poll()); + } + + @Test + public void testString2() { + BoundedPriorityQueue<String> queue = new 
BoundedPriorityQueue<String>(3, + NaturalComparator.<String>getInstance()); + queue.offer("B"); + queue.offer("A"); + queue.offer("C"); + queue.offer("D"); + Assert.assertEquals("B", queue.poll()); + Assert.assertEquals("C", queue.poll()); + Assert.assertEquals("D", queue.poll()); + Assert.assertNull(queue.poll()); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3b27280/core/src/test/java/hivemall/utils/collections/maps/BoundedSortedMapTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/maps/BoundedSortedMapTest.java b/core/src/test/java/hivemall/utils/collections/maps/BoundedSortedMapTest.java new file mode 100644 index 0000000..ce376cf --- /dev/null +++ b/core/src/test/java/hivemall/utils/collections/maps/BoundedSortedMapTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.utils.collections.maps; + +import java.util.Iterator; +import java.util.Map.Entry; +import java.util.SortedMap; + +import org.junit.Assert; +import org.junit.Test; + +public class BoundedSortedMapTest { + + @Test + public void testNaturalOrderTop3() { + // natural order = ascending + SortedMap<Integer, Double> map = new BoundedSortedMap<Integer, Double>(3); + Assert.assertNull(map.put(1, 1.d)); + Assert.assertEquals(Double.valueOf(1.d), map.put(1, 1.1d)); + Assert.assertNull(map.put(4, 4.d)); + Assert.assertNull(map.put(2, 2.d)); + Assert.assertEquals(Double.valueOf(2.d), map.put(2, 2.2d)); + Assert.assertEquals(Double.valueOf(4.d), map.put(3, 3.d)); + Assert.assertEquals(Double.valueOf(3.d), map.put(3, 3.3d)); + + Assert.assertEquals(3, map.size()); + + Iterator<Entry<Integer, Double>> itor = map.entrySet().iterator(); + Entry<Integer, Double> e = itor.next(); + Assert.assertEquals(Integer.valueOf(1), e.getKey()); + Assert.assertEquals(Double.valueOf(1.1d), e.getValue()); + e = itor.next(); + Assert.assertEquals(Integer.valueOf(2), e.getKey()); + Assert.assertEquals(Double.valueOf(2.2d), e.getValue()); + e = itor.next(); + Assert.assertEquals(Integer.valueOf(3), e.getKey()); + Assert.assertEquals(Double.valueOf(3.3d), e.getValue()); + Assert.assertFalse(itor.hasNext()); + } + + @Test + public void testReverseOrderTop3() { + // reverse order = descending + SortedMap<Integer, Double> map = new BoundedSortedMap<Integer, Double>(3, true); + Assert.assertNull(map.put(1, 1.d)); + Assert.assertEquals(Double.valueOf(1.d), map.put(1, 1.1d)); + Assert.assertNull(map.put(4, 4.d)); + Assert.assertNull(map.put(2, 2.d)); + Assert.assertEquals(Double.valueOf(2.d), map.put(2, 2.2d)); + Assert.assertEquals(Double.valueOf(1.1d), map.put(3, 3.d)); + Assert.assertEquals(Double.valueOf(3.d), map.put(3, 3.3d)); + + Assert.assertEquals(3, map.size()); + + Iterator<Entry<Integer, Double>> itor = map.entrySet().iterator(); + Entry<Integer, Double> e = itor.next(); + 
Assert.assertEquals(Integer.valueOf(4), e.getKey()); + Assert.assertEquals(Double.valueOf(4.d), e.getValue()); + e = itor.next(); + Assert.assertEquals(Integer.valueOf(3), e.getKey()); + Assert.assertEquals(Double.valueOf(3.3d), e.getValue()); + e = itor.next(); + Assert.assertEquals(Integer.valueOf(2), e.getKey()); + Assert.assertEquals(Double.valueOf(2.2d), e.getValue()); + Assert.assertFalse(itor.hasNext()); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3b27280/resources/ddl/define-all-as-permanent.hive ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all-as-permanent.hive b/resources/ddl/define-all-as-permanent.hive index 20511d4..5f4a57c 100644 --- a/resources/ddl/define-all-as-permanent.hive +++ b/resources/ddl/define-all-as-permanent.hive @@ -45,7 +45,7 @@ CREATE FUNCTION train_adagrad_rda as 'hivemall.classifier.AdaGradRDAUDTF' USING -------------------------------- -- Multiclass classification -- --------------------------------- +-------------------------------- DROP FUNCTION IF EXISTS train_multiclass_perceptron; CREATE FUNCTION train_multiclass_perceptron as 'hivemall.classifier.multiclass.MulticlassPerceptronUDTF' USING JAR '${hivemall_jar}'; @@ -423,6 +423,13 @@ DROP FUNCTION IF EXISTS to_ordered_map; CREATE FUNCTION to_ordered_map as 'hivemall.tools.map.UDAFToOrderedMap' USING JAR '${hivemall_jar}'; --------------------- +-- list functions -- +--------------------- + +DROP FUNCTION IF EXISTS to_ordered_list; +CREATE FUNCTION to_ordered_list as 'hivemall.tools.list.UDAFToOrderedList' USING JAR '${hivemall_jar}'; + +--------------------- -- Math functions -- --------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3b27280/resources/ddl/define-all.hive ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive index f0dbb42..2dd61c7 100644 --- 
a/resources/ddl/define-all.hive +++ b/resources/ddl/define-all.hive @@ -41,7 +41,7 @@ create temporary function train_adagrad_rda as 'hivemall.classifier.AdaGradRDAUD -------------------------------- -- Multiclass classification -- --------------------------------- +-------------------------------- drop temporary function train_multiclass_perceptron; create temporary function train_multiclass_perceptron as 'hivemall.classifier.multiclass.MulticlassPerceptronUDTF'; @@ -419,6 +419,13 @@ drop temporary function to_ordered_map; create temporary function to_ordered_map as 'hivemall.tools.map.UDAFToOrderedMap'; --------------------- +-- list functions -- +--------------------- + +drop temporary function if exists to_ordered_list; +create temporary function to_ordered_list as 'hivemall.tools.list.UDAFToOrderedList'; + +--------------------- -- Math functions -- --------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3b27280/resources/ddl/define-all.spark ---------------------------------------------------------------------- diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark index 69d8c3b..7d6e0b2 100644 --- a/resources/ddl/define-all.spark +++ b/resources/ddl/define-all.spark @@ -342,6 +342,13 @@ sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS to_ordered_map") sqlContext.sql("CREATE TEMPORARY FUNCTION to_ordered_map AS 'hivemall.tools.map.UDAFToOrderedMap'") /** + * List functions + */ + +sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS to_ordered_list") +sqlContext.sql("CREATE TEMPORARY FUNCTION to_ordered_list AS 'hivemall.tools.list.UDAFToOrderedList'") + +/** * Math functions */ http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3b27280/resources/ddl/define-udfs.td.hql ---------------------------------------------------------------------- diff --git a/resources/ddl/define-udfs.td.hql b/resources/ddl/define-udfs.td.hql index db54b79..00ecd30 100644 --- a/resources/ddl/define-udfs.td.hql +++ 
b/resources/ddl/define-udfs.td.hql @@ -143,6 +143,7 @@ create temporary function train_randomforest_regr as 'hivemall.smile.regression. create temporary function tree_predict as 'hivemall.smile.tools.TreePredictUDF'; create temporary function rf_ensemble as 'hivemall.smile.tools.RandomForestEnsembleUDAF'; create temporary function guess_attribute_types as 'hivemall.smile.tools.GuessAttributesUDF'; +create temporary function to_ordered_list as 'hivemall.tools.list.UDAFToOrderedList'; -- NLP features create temporary function tokenize_ja as 'hivemall.nlp.tokenizer.KuromojiUDF';
