GitHub user myui commented on a diff in the pull request:
https://github.com/apache/incubator-hivemall/pull/108#discussion_r138585976
--- Diff: core/src/main/java/hivemall/tools/map/UDAFToOrderedMap.java ---
@@ -92,4 +122,172 @@ public void reset(@SuppressWarnings("deprecation")
AggregationBuffer agg)
}
+ public static class TopKOrderedMapEvaluator extends
GenericUDAFEvaluator {
+
+ protected PrimitiveObjectInspector inputKeyOI;
+ protected ObjectInspector inputValueOI;
+ protected StandardMapObjectInspector partialMapOI;
+ protected PrimitiveObjectInspector sizeOI;
+
+ protected StructObjectInspector internalMergeOI;
+
+ protected StructField partialMapField;
+ protected StructField sizeField;
+
+ @Override
+ public ObjectInspector init(Mode mode, ObjectInspector[] argOIs)
throws HiveException {
+ super.init(mode, argOIs);
+
+ // initialize input
+ if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from
original data
+ this.inputKeyOI =
HiveUtils.asPrimitiveObjectInspector(argOIs[0]);
+ this.inputValueOI = argOIs[1];
+ this.sizeOI = HiveUtils.asIntegerOI(argOIs[2]);
+ } else {// from partial aggregation
+ StructObjectInspector soi = (StructObjectInspector)
argOIs[0];
+ this.internalMergeOI = soi;
+
+ this.partialMapField = soi.getStructFieldRef("partialMap");
+ // re-extract input key/value OIs
+ StandardMapObjectInspector partialMapOI =
(StandardMapObjectInspector) partialMapField.getFieldObjectInspector();
+ this.inputKeyOI =
HiveUtils.asPrimitiveObjectInspector(partialMapOI.getMapKeyObjectInspector());
+ this.inputValueOI =
partialMapOI.getMapValueObjectInspector();
+
+ this.partialMapOI =
ObjectInspectorFactory.getStandardMapObjectInspector(
+
ObjectInspectorUtils.getStandardObjectInspector(inputKeyOI),
+
ObjectInspectorUtils.getStandardObjectInspector(inputValueOI));
+
+ this.sizeField = soi.getStructFieldRef("size");
+ this.sizeOI = (PrimitiveObjectInspector)
sizeField.getFieldObjectInspector();
+ }
+
+ // initialize output
+ final ObjectInspector outputOI;
+ if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {//
terminatePartial
+ outputOI = internalMergeOI(inputKeyOI, inputValueOI);
+ } else {// terminate
+ outputOI =
ObjectInspectorFactory.getStandardMapObjectInspector(
+
ObjectInspectorUtils.getStandardObjectInspector(inputKeyOI),
+
ObjectInspectorUtils.getStandardObjectInspector(inputValueOI));
+ }
+ return outputOI;
+ }
+
+ private static StructObjectInspector internalMergeOI(
+ @Nonnull PrimitiveObjectInspector keyOI, @Nonnull
ObjectInspector valueOI) {
+ ArrayList<String> fieldNames = new ArrayList<String>();
+ ArrayList<ObjectInspector> fieldOIs = new
ArrayList<ObjectInspector>();
+
+ fieldNames.add("partialMap");
+
fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
+ ObjectInspectorUtils.getStandardObjectInspector(keyOI),
+ ObjectInspectorUtils.getStandardObjectInspector(valueOI)));
+
+ fieldNames.add("size");
+
fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
+
+ return
ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+ }
+
+ static class MapAggregationBuffer extends
AbstractAggregationBuffer {
+ Map<Object, Object> container;
+ int size;
+
+ MapAggregationBuffer() {
+ super();
+ }
+ }
+
+ @Override
+ public void reset(@SuppressWarnings("deprecation")
AggregationBuffer agg)
+ throws HiveException {
+ MapAggregationBuffer myagg = (MapAggregationBuffer) agg;
+ myagg.container = new TreeMap<Object,
Object>(Collections.reverseOrder());
+ myagg.size = Integer.MAX_VALUE;
+ }
+
+ @Override
+ public MapAggregationBuffer getNewAggregationBuffer() throws
HiveException {
+ MapAggregationBuffer myagg = new MapAggregationBuffer();
+ reset(myagg);
+ return myagg;
+ }
+
+ @Override
+ public void iterate(@SuppressWarnings("deprecation")
AggregationBuffer agg,
+ Object[] parameters) throws HiveException {
+ assert (parameters.length == 3);
+
+ if (parameters[0] == null) {
+ return;
+ }
+
+ Object key =
ObjectInspectorUtils.copyToStandardObject(parameters[0], inputKeyOI);
+ Object value =
ObjectInspectorUtils.copyToStandardObject(parameters[1], inputValueOI);
+ int size = Math.abs(HiveUtils.getInt(parameters[2], sizeOI));
// size could be negative for tail-k
--- End diff --
The third parameter (`parameters[2]`) might be a `boolean` rather than an integer size, but that case is not handled here.
---