Repository: incubator-hivemall Updated Branches: refs/heads/master e91e0f2ec -> 1ae9c9d7d
[HIVEMALL-223] Add -kv_map and -vk_map options to to_ordered_list UDAF ## What changes were proposed in this pull request? Add `-kv_map` and `-vk_map` options to `to_ordered_list` UDAF. ## What type of PR is it? Improvement ## What is the Jira issue? https://issues.apache.org/jira/browse/HIVEMALL-223 ## How was this patch tested? unit tests and manual tests on EMR ## How to use this feature? Will be described in http://hivemall.incubator.apache.org/userguide/misc/generic_funcs.html#array ## Checklist - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit? - [x] Did you run system tests on Hive (or Spark)? Author: Makoto Yui <[email protected]> Closes #170 from myui/HIVEMALL-223. Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/1ae9c9d7 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/1ae9c9d7 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/1ae9c9d7 Branch: refs/heads/master Commit: 1ae9c9d7d35cb4355759bc92d4d8569dfad00263 Parents: e91e0f2 Author: Makoto Yui <[email protected]> Authored: Tue Nov 13 18:18:35 2018 +0900 Committer: Makoto Yui <[email protected]> Committed: Tue Nov 13 18:18:35 2018 +0900 ---------------------------------------------------------------------- .../hivemall/tools/list/UDAFToOrderedList.java | 181 +++++++-- .../java/hivemall/utils/hadoop/HiveUtils.java | 21 + .../tools/list/UDAFToOrderedListTest.java | 402 ++++++++++++++++++- .../collections/BoundedPriorityQueueTest.java | 33 ++ 4 files changed, 596 insertions(+), 41 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1ae9c9d7/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java 
b/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java index 83adb0f..6435a5f 100644 --- a/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java +++ b/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java @@ -31,7 +31,9 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.PriorityQueue; import javax.annotation.CheckForNull; @@ -86,7 +88,9 @@ import org.apache.hadoop.io.IntWritable; + " to_ordered_list(value, key, '-k -2 -reverse'), -- [apple, candy] (reverse tail-k = top-k)\n" + " to_ordered_list(value, '-k 2'), -- [egg, donut] (alphabetically)\n" + " to_ordered_list(key, '-k -2 -reverse'), -- [5, 4] (top-2 keys)\n" - + " to_ordered_list(key) -- [2, 3, 3, 4, 5] (natural ordered keys)\n" + + " to_ordered_list(key), -- [2, 3, 3, 4, 5] (natural ordered keys)\n" + + " to_ordered_list(value, key, '-k 2 -kv_map'), -- {4:\"candy\",5:\"apple\"}\n" + + " to_ordered_list(value, key, '-k 2 -vk_map') -- {\"candy\":4,\"apple\":5}\n" + "FROM\n" + " t") //@formatter:on public final class UDAFToOrderedList extends AbstractGenericUDAFResolver { @@ -135,17 +139,23 @@ public final class UDAFToOrderedList extends AbstractGenericUDAFResolver { private StructField keyListField; private StructField sizeField; private StructField reverseOrderField; + private StructField outKVField, outVKField; @Nonnegative private int size; private boolean reverseOrder; private boolean sortByKey; + private boolean outKV, outVK; protected Options getOptions() { Options opts = new Options(); opts.addOption("k", true, "To top-k (positive) or tail-k (negative) ordered queue"); opts.addOption("reverse", "reverse_order", false, "Sort values by key in a reverse (e.g., descending) order [default: false]"); + opts.addOption("kv", "kv_map", false, + "Return Map<K, V> for the result of to_ordered_list(V, K)"); + opts.addOption("vk", "vk_map", 
false, + "Return Map<V, K> for the result of to_ordered_list(V, K)"); return opts; } @@ -190,6 +200,7 @@ public final class UDAFToOrderedList extends AbstractGenericUDAFResolver { int k = 0; boolean reverseOrder = false; + boolean outKV = false, outVK = false; if (argOIs.length >= optionIndex + 1) { String rawArgs = HiveUtils.getConstString(argOIs[optionIndex]); cl = parseOptions(rawArgs); @@ -202,8 +213,23 @@ public final class UDAFToOrderedList extends AbstractGenericUDAFResolver { throw new UDFArgumentException("`k` must be non-zero value: " + k); } } + + outKV = cl.hasOption("kv_map"); + outVK = cl.hasOption("vk_map"); + if (outKV && outVK) { + throw new UDFArgumentException( + "Both `-kv_map` and `-vk_map` option are unexpectedly specified"); + } else if (outKV && sortByKey == false) { + throw new UDFArgumentException( + "`-kv_map` option can only be applied when both key and value are provided"); + } else if (outVK && sortByKey == false) { + throw new UDFArgumentException( + "`-vk_map` option can only be applied when both key and value are provided"); + } } this.size = Math.abs(k); + this.outKV = outKV; + this.outVK = outVK; if ((k > 0 && reverseOrder) || (k < 0 && reverseOrder == false) || (k == 0 && reverseOrder == false)) { @@ -258,23 +284,45 @@ public final class UDAFToOrderedList extends AbstractGenericUDAFResolver { this.sizeField = soi.getStructFieldRef("size"); this.reverseOrderField = soi.getStructFieldRef("reverseOrder"); + + List<? 
extends StructField> fieldRefs = soi.getAllStructFieldRefs(); + + + this.outKVField = HiveUtils.getStructFieldRef("outKV", fieldRefs); + if (outKVField != null) { + this.outKV = true; + } + this.outVKField = HiveUtils.getStructFieldRef("outVK", fieldRefs); + if (outVKField != null) { + this.outVK = true; + } } // initialize output final ObjectInspector outputOI; if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial - outputOI = internalMergeOI(valueOI, keyOI); + outputOI = internalMergeOI(valueOI, keyOI, outKV, outVK); } else {// terminate - outputOI = ObjectInspectorFactory.getStandardListObjectInspector( - ObjectInspectorUtils.getStandardObjectInspector(valueOI)); + if (outKV) { + outputOI = ObjectInspectorFactory.getStandardMapObjectInspector( + ObjectInspectorUtils.getStandardObjectInspector(keyOI), + ObjectInspectorUtils.getStandardObjectInspector(valueOI)); + } else if (outVK) { + outputOI = ObjectInspectorFactory.getStandardMapObjectInspector( + ObjectInspectorUtils.getStandardObjectInspector(valueOI), + ObjectInspectorUtils.getStandardObjectInspector(keyOI)); + } else { + outputOI = ObjectInspectorFactory.getStandardListObjectInspector( + ObjectInspectorUtils.getStandardObjectInspector(valueOI)); + } } return outputOI; } @Nonnull - private static StructObjectInspector internalMergeOI(@Nonnull ObjectInspector valueOI, - @Nonnull PrimitiveObjectInspector keyOI) { + private StructObjectInspector internalMergeOI(@Nonnull ObjectInspector valueOI, + @Nonnull PrimitiveObjectInspector keyOI, boolean outKV, boolean outVK) { List<String> fieldNames = new ArrayList<String>(); List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); @@ -288,6 +336,13 @@ public final class UDAFToOrderedList extends AbstractGenericUDAFResolver { fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); fieldNames.add("reverseOrder"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableBooleanObjectInspector); + if (outKV) { + 
fieldNames.add("outKV"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableBooleanObjectInspector); + } else if (outVK) { + fieldNames.add("outVK"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableBooleanObjectInspector); + } return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); } @@ -304,7 +359,7 @@ public final class UDAFToOrderedList extends AbstractGenericUDAFResolver { public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) throws HiveException { QueueAggregationBuffer myagg = (QueueAggregationBuffer) agg; - myagg.reset(size, reverseOrder); + myagg.reset(size, reverseOrder, outKV, outVK); } @Override @@ -344,11 +399,16 @@ public final class UDAFToOrderedList extends AbstractGenericUDAFResolver { List<Object> keyList = tuples.getKey(); List<Object> valueList = tuples.getValue(); - Object[] partialResult = new Object[4]; + Object[] partialResult = new Object[outKV || outVK ? 5 : 4]; partialResult[0] = valueList; partialResult[1] = keyList; partialResult[2] = new IntWritable(myagg.size); partialResult[3] = new BooleanWritable(myagg.reverseOrder); + if (myagg.outKV) { + partialResult[4] = new BooleanWritable(true); + } else if (myagg.outVK) { + partialResult[4] = new BooleanWritable(true); + } return partialResult; } @@ -363,17 +423,16 @@ public final class UDAFToOrderedList extends AbstractGenericUDAFResolver { final List<?> valueListRaw = valueListOI.getList(HiveUtils.castLazyBinaryObject(valueListObj)); final List<Object> valueList = new ArrayList<Object>(); - for (int i = 0, n = valueListRaw.size(); i < n; i++) { - valueList.add( - ObjectInspectorUtils.copyToStandardObject(valueListRaw.get(i), valueOI)); + for (Object v : valueListRaw) { + valueList.add(ObjectInspectorUtils.copyToStandardObject(v, valueOI)); } Object keyListObj = internalMergeOI.getStructFieldData(partial, keyListField); final List<?> keyListRaw = keyListOI.getList(HiveUtils.castLazyBinaryObject(keyListObj)); final List<Object> 
keyList = new ArrayList<Object>(); - for (int i = 0, n = keyListRaw.size(); i < n; i++) { - keyList.add(ObjectInspectorUtils.copyToStandardObject(keyListRaw.get(i), keyOI)); + for (Object k : keyListRaw) { + keyList.add(ObjectInspectorUtils.copyToStandardObject(k, keyOI)); } Object sizeObj = internalMergeOI.getStructFieldData(partial, sizeField); @@ -385,41 +444,47 @@ public final class UDAFToOrderedList extends AbstractGenericUDAFResolver { reverseOrderObj); QueueAggregationBuffer myagg = (QueueAggregationBuffer) agg; - myagg.setOptions(size, reverseOrder); + myagg.setOptions(size, reverseOrder, outKV, outVK); myagg.merge(keyList, valueList); } @Override - public List<Object> terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) + public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) throws HiveException { QueueAggregationBuffer myagg = (QueueAggregationBuffer) agg; - Pair<List<Object>, List<Object>> tuples = myagg.drainQueue(); - if (tuples == null) { - return null; + if (myagg.outKV) { + return myagg.drainMapKV(); + } else if (myagg.outVK) { + return myagg.drainMapVK(); + } else { + return myagg.drainValues(); } - return tuples.getValue(); } static class QueueAggregationBuffer extends AbstractAggregationBuffer { - private AbstractQueueHandler queueHandler; + private transient AbstractQueueHandler queueHandler; @Nonnegative private int size; private boolean reverseOrder; + private boolean outKV, outVK; QueueAggregationBuffer() { super(); } - void reset(@Nonnegative int size, boolean reverseOrder) { - setOptions(size, reverseOrder); + void reset(@Nonnegative int size, boolean reverseOrder, boolean outKV, boolean outVK) { + setOptions(size, reverseOrder, outKV, outVK); this.queueHandler = null; } - void setOptions(@Nonnegative int size, boolean reverseOrder) { + void setOptions(@Nonnegative int size, boolean reverseOrder, boolean outKV, + boolean outVK) { this.size = size; this.reverseOrder = reverseOrder; + this.outKV = outKV; 
+ this.outVK = outVK; } void iterate(@Nonnull TupleWithKey tuple) { @@ -429,22 +494,23 @@ public final class UDAFToOrderedList extends AbstractGenericUDAFResolver { queueHandler.offer(tuple); } - void merge(@Nonnull List<Object> o_keyList, @Nonnull List<Object> o_valueList) { + void merge(@Nonnull List<Object> keys, @Nonnull List<Object> values) { if (queueHandler == null) { initQueueHandler(); } - for (int i = 0, n = o_keyList.size(); i < n; i++) { - queueHandler.offer(new TupleWithKey(o_keyList.get(i), o_valueList.get(i))); + for (int i = 0, n = keys.size(); i < n; i++) { + queueHandler.offer(new TupleWithKey(keys.get(i), values.get(i))); } } + @Deprecated @Nullable Pair<List<Object>, List<Object>> drainQueue() { if (queueHandler == null) { return null; } - int n = queueHandler.size(); + final int n = queueHandler.size(); final Object[] keys = new Object[n]; final Object[] values = new Object[n]; for (int i = n - 1; i >= 0; i--) { // head element in queue should be stored to tail of array @@ -457,6 +523,67 @@ public final class UDAFToOrderedList extends AbstractGenericUDAFResolver { return Pair.of(Arrays.asList(keys), Arrays.asList(values)); } + @Nullable + List<Object> drainValues() { + if (queueHandler == null) { + return null; + } + + final int n = queueHandler.size(); + final Object[] values = new Object[n]; + for (int i = n - 1; i >= 0; i--) { // head element in queue should be stored to tail of array + TupleWithKey tuple = queueHandler.poll(); + values[i] = tuple.getValue(); + } + queueHandler.clear(); + + return Arrays.asList(values); + } + + @Nullable + Map<Object, Object> drainMapKV() { + if (queueHandler == null) { + return null; + } + + final int n = queueHandler.size(); + final Map<Object, Object> map = new HashMap<>(n * 2); + for (int i = n - 1; i >= 0; i--) { // head element in queue should be stored to tail of array + TupleWithKey tuple = queueHandler.poll(); + Object k = tuple.getKey(); + if (map.containsKey(k)) { + continue; // avoid duplicate + 
} + Object v = tuple.getValue(); + map.put(k, v); + } + queueHandler.clear(); + + return map; + } + + @Nullable + Map<Object, Object> drainMapVK() { + if (queueHandler == null) { + return null; + } + + final int n = queueHandler.size(); + final Map<Object, Object> map = new HashMap<>(n * 2); + for (int i = n - 1; i >= 0; i--) { // head element in queue should be stored to tail of array + TupleWithKey tuple = queueHandler.poll(); + Object k = tuple.getValue(); + if (map.containsKey(k)) { + continue; // avoid duplicate + } + Object v = tuple.getKey(); + map.put(k, v); + } + queueHandler.clear(); + + return map; + } + private void initQueueHandler() { final Comparator<TupleWithKey> comparator; if (reverseOrder) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1ae9c9d7/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index 12b0e97..e42d1b6 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -69,6 +69,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; import org.apache.hadoop.hive.serde2.objectinspector.StandardConstantListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; @@ -1227,4 +1228,24 @@ public final class HiveUtils { TypeInfoFactory.stringTypeInfo, new Text(str)); } + @Nullable + public 
static StructField getStructFieldRef(@Nonnull String fieldName, + @Nonnull final List<? extends StructField> fields) { + fieldName = fieldName.toLowerCase(); + for (StructField f : fields) { + if (f.getFieldName().equals(fieldName)) { + return f; + } + } + // For backward compatibility: fieldNames can also be integer Strings. + try { + final int i = Integer.parseInt(fieldName); + if (i >= 0 && i < fields.size()) { + return fields.get(i); + } + } catch (NumberFormatException e) { + // ignore + } + return null; + } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1ae9c9d7/core/src/test/java/hivemall/tools/list/UDAFToOrderedListTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/tools/list/UDAFToOrderedListTest.java b/core/src/test/java/hivemall/tools/list/UDAFToOrderedListTest.java index 78043aa..e75a10c 100644 --- a/core/src/test/java/hivemall/tools/list/UDAFToOrderedListTest.java +++ b/core/src/test/java/hivemall/tools/list/UDAFToOrderedListTest.java @@ -21,8 +21,11 @@ package hivemall.tools.list; import hivemall.tools.list.UDAFToOrderedList.UDAFToOrderedListEvaluator; import hivemall.tools.list.UDAFToOrderedList.UDAFToOrderedListEvaluator.QueueAggregationBuffer; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; @@ -45,7 +48,7 @@ public class UDAFToOrderedListTest { @Test public void testNaturalOrder() throws Exception { ObjectInspector[] inputOIs = - new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaDoubleObjectInspector}; + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector}; final String[] values = new String[] {"banana", "apple", "candy"}; @@ -56,7 
+59,8 @@ public class UDAFToOrderedListTest { evaluator.iterate(agg, new Object[] {values[i]}); } - List<Object> res = evaluator.terminate(agg); + @SuppressWarnings("unchecked") + List<Object> res = (List<Object>) evaluator.terminate(agg); Assert.assertEquals(3, res.size()); Assert.assertEquals("apple", res.get(0)); @@ -65,6 +69,56 @@ public class UDAFToOrderedListTest { } @Test + public void testIntegerNaturalOrder() throws Exception { + ObjectInspector[] inputOIs = + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaIntObjectInspector}; + + final Integer[] values = new Integer[] {3, -1, 4, 2, 5}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i]}); + } + + @SuppressWarnings("unchecked") + List<Object> res = (List<Object>) evaluator.terminate(agg); + + Assert.assertEquals(5, res.size()); + Assert.assertEquals(-1, res.get(0)); + Assert.assertEquals(2, res.get(1)); + Assert.assertEquals(3, res.get(2)); + Assert.assertEquals(4, res.get(3)); + Assert.assertEquals(5, res.get(4)); + } + + @Test + public void testDoubleNaturalOrder() throws Exception { + ObjectInspector[] inputOIs = + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaDoubleObjectInspector}; + + final Double[] values = new Double[] {3.1d, -1.1d, 4.1d, 2.1d, 5.1d}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i]}); + } + + @SuppressWarnings("unchecked") + List<Object> res = (List<Object>) evaluator.terminate(agg); + + Assert.assertEquals(5, res.size()); + Assert.assertEquals(-1.1d, res.get(0)); + Assert.assertEquals(2.1d, res.get(1)); + Assert.assertEquals(3.1d, res.get(2)); + Assert.assertEquals(4.1d, res.get(3)); + Assert.assertEquals(5.1d, res.get(4)); + } + + @Test public void testReverseOrder() throws Exception { 
ObjectInspector[] inputOIs = new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector, @@ -81,7 +135,8 @@ public class UDAFToOrderedListTest { evaluator.iterate(agg, new Object[] {values[i]}); } - List<Object> res = evaluator.terminate(agg); + @SuppressWarnings("unchecked") + List<Object> res = (List<Object>) evaluator.terminate(agg); Assert.assertEquals(3, res.size()); Assert.assertEquals("candy", res.get(0)); @@ -105,7 +160,8 @@ public class UDAFToOrderedListTest { evaluator.iterate(agg, new Object[] {values[i]}); } - List<Object> res = evaluator.terminate(agg); + @SuppressWarnings("unchecked") + List<Object> res = (List<Object>) evaluator.terminate(agg); Assert.assertEquals(2, res.size()); Assert.assertEquals("candy", res.get(0)); @@ -113,6 +169,30 @@ public class UDAFToOrderedListTest { } @Test + public void testTop2IntNuturalOrder() throws Exception { + ObjectInspector[] inputOIs = + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaIntObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-k 2")}; + + final Integer[] values = new Integer[] {3, -1, 4, 4, 2, 5}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i]}); + } + + @SuppressWarnings("unchecked") + List<Object> res = (List<Object>) evaluator.terminate(agg); + + Assert.assertEquals(2, res.size()); + Assert.assertEquals(5, res.get(0)); + Assert.assertEquals(4, res.get(1)); + } + + @Test public void testReverseTopK() throws Exception { // = tail-k ObjectInspector[] inputOIs = @@ -130,7 +210,8 @@ public class UDAFToOrderedListTest { evaluator.iterate(agg, new Object[] {values[i]}); } - List<Object> res = evaluator.terminate(agg); + @SuppressWarnings("unchecked") + List<Object> res = (List<Object>) evaluator.terminate(agg); Assert.assertEquals(2, res.size()); 
Assert.assertEquals("apple", res.get(0)); @@ -153,7 +234,8 @@ public class UDAFToOrderedListTest { evaluator.iterate(agg, new Object[] {values[i]}); } - List<Object> res = evaluator.terminate(agg); + @SuppressWarnings("unchecked") + List<Object> res = (List<Object>) evaluator.terminate(agg); Assert.assertEquals(2, res.size()); Assert.assertEquals("apple", res.get(0)); @@ -178,7 +260,8 @@ public class UDAFToOrderedListTest { evaluator.iterate(agg, new Object[] {values[i]}); } - List<Object> res = evaluator.terminate(agg); + @SuppressWarnings("unchecked") + List<Object> res = (List<Object>) evaluator.terminate(agg); Assert.assertEquals(2, res.size()); Assert.assertEquals("candy", res.get(0)); @@ -201,7 +284,8 @@ public class UDAFToOrderedListTest { evaluator.iterate(agg, new Object[] {values[i], keys[i]}); } - List<Object> res = evaluator.terminate(agg); + @SuppressWarnings("unchecked") + List<Object> res = (List<Object>) evaluator.terminate(agg); Assert.assertEquals(3, res.size()); Assert.assertEquals("apple", res.get(0)); @@ -231,6 +315,7 @@ public class UDAFToOrderedListTest { evaluator.iterate(agg, new Object[] {values[i], keys[i]}); } + @SuppressWarnings("unchecked") List<Object> res = (List<Object>) evaluator.terminate(agg); Assert.assertEquals(3, res.size()); @@ -260,7 +345,8 @@ public class UDAFToOrderedListTest { evaluator.iterate(agg, new Object[] {values[i], keys[i]}); } - List<Object> res = evaluator.terminate(agg); + @SuppressWarnings("unchecked") + List<Object> res = (List<Object>) evaluator.terminate(agg); Assert.assertEquals(2, res.size()); Assert.assertEquals("candy", res.get(0)); @@ -287,7 +373,8 @@ public class UDAFToOrderedListTest { evaluator.iterate(agg, new Object[] {values[i], keys[i]}); } - List<Object> res = evaluator.terminate(agg); + @SuppressWarnings("unchecked") + List<Object> res = (List<Object>) evaluator.terminate(agg); Assert.assertEquals(2, res.size()); Assert.assertEquals("apple", res.get(0)); @@ -312,7 +399,8 @@ public class 
UDAFToOrderedListTest { evaluator.iterate(agg, new Object[] {values[i], keys[i]}); } - List<Object> res = evaluator.terminate(agg); + @SuppressWarnings("unchecked") + List<Object> res = (List<Object>) evaluator.terminate(agg); Assert.assertEquals(2, res.size()); Assert.assertEquals("apple", res.get(0)); @@ -339,7 +427,8 @@ public class UDAFToOrderedListTest { evaluator.iterate(agg, new Object[] {values[i], keys[i]}); } - List<Object> res = evaluator.terminate(agg); + @SuppressWarnings("unchecked") + List<Object> res = (List<Object>) evaluator.terminate(agg); Assert.assertEquals(2, res.size()); Assert.assertEquals("candy", res.get(0)); @@ -360,7 +449,8 @@ public class UDAFToOrderedListTest { evaluator.iterate(agg, new Object[] {values[i]}); } - List<Object> res = evaluator.terminate(agg); + @SuppressWarnings("unchecked") + List<Object> res = (List<Object>) evaluator.terminate(agg); Assert.assertNull(res); } @@ -379,7 +469,8 @@ public class UDAFToOrderedListTest { evaluator.iterate(agg, new Object[] {values[i]}); } - List<Object> res = evaluator.terminate(agg); + @SuppressWarnings("unchecked") + List<Object> res = (List<Object>) evaluator.terminate(agg); Assert.assertEquals(3, res.size()); Assert.assertEquals("apple", res.get(0)); @@ -387,4 +478,287 @@ public class UDAFToOrderedListTest { Assert.assertEquals("candy", res.get(2)); } + @Test + public void testKVMapOption() throws Exception { + ObjectInspector[] inputOIs = + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-k 2 -kv_map")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + final double[] keys = new double[] {0.7, 0.5, 0.8}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + 
evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + Object result = evaluator.terminate(agg); + + Assert.assertEquals(HashMap.class, result.getClass()); + Map<?, ?> map = (Map<?, ?>) result; + Assert.assertEquals(2, map.size()); + + Assert.assertEquals("candy", map.get(0.8d)); + Assert.assertEquals("banana", map.get(0.7d)); + } + + @Test + public void testVKMapOption() throws Exception { + ObjectInspector[] inputOIs = + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-k 2 -vk_map")}; + + final String[] values = new String[] {"banana", "apple", "candy"}; + final double[] keys = new double[] {0.7, 0.5, 0.8}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + Object result = evaluator.terminate(agg); + + Assert.assertEquals(HashMap.class, result.getClass()); + Map<?, ?> map = (Map<?, ?>) result; + Assert.assertEquals(2, map.size()); + + Assert.assertEquals(0.8d, map.get("candy")); + Assert.assertEquals(0.7d, map.get("banana")); + } + + @Test + public void testVKMapOptionBananaOverlap() throws Exception { + ObjectInspector[] inputOIs = + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-k 2 -vk_map")}; + + final String[] values = new String[] {"banana", "banana", "candy"}; + final double[] keys = new double[] {0.7, 0.8, 0.81}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new 
Object[] {values[i], keys[i]}); + } + + Object result = evaluator.terminate(agg); + + Assert.assertEquals(HashMap.class, result.getClass()); + Map<?, ?> map = (Map<?, ?>) result; + Assert.assertEquals(2, map.size()); + + Assert.assertEquals(0.81d, map.get("candy")); + Assert.assertEquals(0.8d, map.get("banana")); + } + + @Test + public void testVKMapOptionBananaOverlap2() throws Exception { + ObjectInspector[] inputOIs = + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-k 2 -vk_map")}; + + final String[] values = new String[] {"banana", "banana", "candy"}; + final double[] keys = new double[] {0.8, 0.8, 0.7}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + Object result = evaluator.terminate(agg); + + Assert.assertEquals(HashMap.class, result.getClass()); + Map<?, ?> map = (Map<?, ?>) result; + Assert.assertEquals(1, map.size()); + + Assert.assertEquals(0.8d, map.get("banana")); + } + + @Test + public void testVKMapOptionReverseOrderTop2() throws Exception { + ObjectInspector[] inputOIs = + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-k -2 -vk_map")}; + + final String[] values = new String[] {"banana", "apple", "banana"}; + final double[] keys = new double[] {0.7, 0.6, 0.8}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + Object result = 
evaluator.terminate(agg); + + Assert.assertEquals(HashMap.class, result.getClass()); + Map<?, ?> map = (Map<?, ?>) result; + Assert.assertEquals(2, map.size()); + + Assert.assertEquals(0.6d, map.get("apple")); + Assert.assertEquals(0.7d, map.get("banana")); + } + + @Test + public void testVKMapOptionReverseOrder() throws Exception { /* -reverse -vk_map: the result maps each value to its key; "banana" occurs with keys 0.7 and 0.8, and 0.7 is the key expected to be kept. */ + ObjectInspector[] inputOIs = + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-reverse -vk_map")}; + + final String[] values = new String[] {"banana", "apple", "banana"}; + final double[] keys = new double[] {0.7, 0.6, 0.8}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + Object result = evaluator.terminate(agg); + + Assert.assertEquals(HashMap.class, result.getClass()); + Map<?, ?> map = (Map<?, ?>) result; + Assert.assertEquals(2, map.size()); + + Assert.assertEquals(0.6d, map.get("apple")); + Assert.assertEquals(0.7d, map.get("banana")); + } + + @Test + public void testVKMapOptionBananaOverlapReverseOrder() throws Exception { /* -k -2 -vk_map: only the two entries with the smallest keys (candy 0.7, banana 0.8) are expected; the banana 0.9 entry is dropped. */ + ObjectInspector[] inputOIs = + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-k -2 -vk_map")}; + + final String[] values = new String[] {"banana", "banana", "candy"}; + final double[] keys = new double[] {0.9, 0.8, 0.7}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + Object result =
evaluator.terminate(agg); + + Assert.assertEquals(HashMap.class, result.getClass()); + Map<?, ?> map = (Map<?, ?>) result; + Assert.assertEquals(2, map.size()); + + Assert.assertEquals(0.7d, map.get("candy")); + Assert.assertEquals(0.8d, map.get("banana")); + } + + @Test + public void testVKMapTop2() throws Exception { /* -k 2 -vk_map: the two entries with the largest keys (apple 5, candy 4) are expected, mapped from value to key. */ + ObjectInspector[] inputOIs = + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-k 2 -vk_map")}; + + final int[] keys = new int[] {5, 3, 4, 2, 3}; + final String[] values = new String[] {"apple", "banana", "candy", "donut", "egg"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + Object result = evaluator.terminate(agg); + + Assert.assertEquals(HashMap.class, result.getClass()); + Map<?, ?> map = (Map<?, ?>) result; + Assert.assertEquals(2, map.size()); + + Assert.assertEquals(5, map.get("apple")); + Assert.assertEquals(4, map.get("candy")); + } + + @Test + public void testKVMapTop2() throws Exception { /* -k 2 -kv_map: same input as testVKMapTop2, but the result is expected to map key to value. */ + ObjectInspector[] inputOIs = + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-k 2 -kv_map")}; + + final int[] keys = new int[] {5, 3, 4, 2, 3}; + final String[] values = new String[] {"apple", "banana", "candy", "donut", "egg"}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + evaluator.reset(agg); + + for (int i = 0; i < values.length; i++) { + evaluator.iterate(agg, new Object[] {values[i], keys[i]}); + } + + Object result = evaluator.terminate(agg); +
Assert.assertEquals(HashMap.class, result.getClass()); + Map<?, ?> map = (Map<?, ?>) result; + Assert.assertEquals(2, map.size()); + + Assert.assertEquals("apple", map.get(5)); + Assert.assertEquals("candy", map.get(4)); + } + + @Test(expected = UDFArgumentException.class) + public void testKVandVKFail() throws Exception { /* combining -kv_map and -vk_map in one option string is expected to make init fail. */ + ObjectInspector[] inputOIs = + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-k 2 -kv_map -vk_map")}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + } + + @Test(expected = UDFArgumentException.class) + public void testKVMapReturnWithoutValue() throws Exception { /* -kv_map with only one input column before the options (no separate key column) is expected to make init fail. */ + ObjectInspector[] inputOIs = + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + "-k 2 -kv_map")}; + + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputOIs); + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1ae9c9d7/core/src/test/java/hivemall/utils/collections/BoundedPriorityQueueTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/collections/BoundedPriorityQueueTest.java b/core/src/test/java/hivemall/utils/collections/BoundedPriorityQueueTest.java index b8aa559..72c3e1f 100644 --- a/core/src/test/java/hivemall/utils/collections/BoundedPriorityQueueTest.java +++ b/core/src/test/java/hivemall/utils/collections/BoundedPriorityQueueTest.java @@ -111,4 +111,37 @@ public class BoundedPriorityQueueTest { Assert.assertNull(queue.poll()); } + @Test + public void testReverseOrderForTailK() { + // Note that with reverseOrder the queue retains the tail-k (smallest) elements + BoundedPriorityQueue<Integer> queue = + new
BoundedPriorityQueue<Integer>(2, Collections.<Integer>reverseOrder()); + queue.offer(3); + queue.offer(1); + queue.offer(2); + queue.offer(4); + queue.offer(-1); + + Assert.assertEquals(2, queue.size()); + // and poll() returns them in reverse (descending) order + Assert.assertEquals(1, queue.poll().intValue()); + Assert.assertEquals(-1, queue.poll().intValue()); + } + + @Test + public void testNaturalOrderForTopK() { + // Note that with natural order the queue retains the top-k (largest) elements + BoundedPriorityQueue<Integer> queue = + new BoundedPriorityQueue<Integer>(2, NaturalComparator.<Integer>getInstance()); + queue.offer(3); + queue.offer(1); + queue.offer(2); + queue.offer(4); + queue.offer(-1); + + Assert.assertEquals(2, queue.size()); + // and poll() returns them in natural (ascending) order + Assert.assertEquals(3, queue.poll().intValue()); + Assert.assertEquals(4, queue.poll().intValue()); + } }
