Hi,

I have a problem with creating a UDAF using Map with String key.

My intention is to create a UDAF for multi-counting, as follows.

Suppose there is an input table like

col1 col2
----------------
a       apple
b       apple
a       orange
a       apple
a       grape
b       grape

and the UDAF MultiCount(),

select col1, MultiCount(col2) group by col1;

, is intended to count the occurrences of each element in col2
and to return the result as a Map,
so the expected result is

a       {"grape":1,"orange":1,"apple":2}
b       {"grape":1,"apple":1}


HashMap<String,Long> is used for storing the multi count numbers in
MultiCountAgg implements AggregationBuffer.

Then partialResult from map function and result from the UDAF are
defined as MapWritable.

The problem is that partialResult at the end of
terminatePartial(AggregationBuffer agg)
is fine; however, the partialResult is mixed up across the col1 values
if the key String
is taken from the table column in the iterate function, as in

String keyStr = new String
(PrimitiveObjectInspectorUtils.getString(parameters[0],(PrimitiveObjectInspector)inputOI)
);
.

Actual output is now

a       {"grape":1,"orange":1,"apple":2}
b       {"grape":1,"orange":1,"apple":2}

I confirmed that partialResult stores the correct data at the end of terminatePartial();
however, it is wrong at the beginning of merge().

The strange thing is that if the key String is set to a fixed String, as in

String keyStr = "Count";

the result is

a       {"Count":4}
b       {"Count":2}

and it seems fine.

Is there any solution for this problem?


-- 
Kotaro Ogino

E-mail: kokotat...@gmail.com


----code---

/**
 * Evaluator for the MultiCount UDAF: counts how often each distinct string
 * value occurs within a group and returns the counts as a map from value to
 * count.
 *
 * <p>Counts are accumulated per group in a {@link MultiCountAgg} buffer. The
 * partial result is a one-field struct (field name {@code "map"}) holding a
 * {@code MapWritable} of {@code Text -> LongWritable}; the final result is a
 * {@code MapWritable} of the same shape.
 */
public static class GenericUDAFMultiCountEvaluator extends GenericUDAFEvaluator {
    /** Inspector for the raw input column (set in PARTIAL1 and COMPLETE modes). */
    ObjectInspector inputOI;

    /** Reusable output object for {@link #terminate}; cleared before every use. */
    MapWritable result;
    StandardMapObjectInspector mapOI;

    /** Reusable one-element struct returned by {@link #terminatePartial}. */
    Object[] partialResult;

    /** Inspector for the partial-result struct (set in PARTIAL2 and FINAL modes). */
    StructObjectInspector soi;
    StructField mapField;

    /**
     * Sets up the input inspectors for the given mode and returns the
     * inspector describing this evaluator's output in that mode.
     *
     * @param m          aggregation mode
     * @param parameters input inspectors; element 0 is the raw column in
     *                   PARTIAL1/COMPLETE and the partial struct otherwise
     * @return a struct inspector with the single field {@code "map"} for
     *         PARTIAL1/PARTIAL2, or the map inspector for FINAL/COMPLETE
     */
    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters)
            throws HiveException {
        super.init(m, parameters);

        mapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
                PrimitiveObjectInspectorFactory.writableStringObjectInspector,
                PrimitiveObjectInspectorFactory.writableLongObjectInspector);
        result = new MapWritable();

        // Use the parameter and the enum type directly instead of reading
        // enum constants through the superclass's instance field.
        if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
            inputOI = (PrimitiveObjectInspector) parameters[0];
        } else {
            soi = (StructObjectInspector) parameters[0];
            mapField = soi.getStructFieldRef("map");
        }

        if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
            ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
            foi.add(mapOI);
            ArrayList<String> fname = new ArrayList<String>();
            fname.add("map");
            partialResult = new Object[1];
            partialResult[0] = new MapWritable();
            return ObjectInspectorFactory.getStandardStructObjectInspector(
                    fname, foi);
        }
        return mapOI;
    }

    /** Aggregation buffer holding the per-value counts of one group. */
    static class MultiCountAgg implements AggregationBuffer {
        HashMap<String, Long> map;
    }

    /** Creates a new, empty aggregation buffer. */
    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
        MultiCountAgg buffer = new MultiCountAgg();
        reset(buffer);
        return buffer;
    }

    /** Discards all counts held in {@code agg}. */
    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
        ((MultiCountAgg) agg).map = new HashMap<String, Long>();
    }

    /**
     * Counts one occurrence of the input value; null inputs are ignored.
     *
     * @param agg        buffer of the current group
     * @param parameters single-element array holding the raw column value
     */
    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters)
            throws HiveException {
        assert (parameters.length == 1);
        if (parameters[0] == null) {
            return;
        }
        String keyStr = PrimitiveObjectInspectorUtils.getString(
                parameters[0], (PrimitiveObjectInspector) inputOI);
        addCount(((MultiCountAgg) agg).map, keyStr, 1L);
    }

    /**
     * Adds the counts of a partial result to the buffer.
     *
     * @param agg     buffer of the current group
     * @param partial partial struct produced by {@link #terminatePartial};
     *                ignored when null
     */
    @Override
    public void merge(AggregationBuffer agg, Object partial)
            throws HiveException {
        if (partial == null) {
            return;
        }
        MultiCountAgg myagg = (MultiCountAgg) agg;
        Object partialMap = soi.getStructFieldData(partial, mapField);
        if (partialMap == null) {
            return;
        }

        // NOTE(review): this assumes the partial map always arrives as a
        // LazyBinaryMap, as the original code did - confirm for all modes.
        HashMap<?, ?> parMap = (HashMap<?, ?>) ((LazyBinaryMap) partialMap).getMap();
        for (Object parKeyObj : parMap.keySet()) {
            LongWritable parValObj = (LongWritable) parMap.get(parKeyObj);
            // toString() copies the key out of the (possibly reused) writable.
            addCount(myagg.map, parKeyObj.toString(), parValObj.get());
        }
    }

    /**
     * Produces the final map of counts for one group.
     *
     * @param agg buffer of the current group
     * @return the reusable {@code MapWritable} holding this group's counts,
     *         or null when nothing was counted
     */
    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
        MultiCountAgg myagg = (MultiCountAgg) agg;
        if (myagg.map.size() < 1) {
            return null;
        }
        // The output object is shared by all groups; without clearing it,
        // keys emitted for an earlier group would remain in later results.
        result.clear();
        copyCounts(myagg.map, result);
        return result;
    }

    /**
     * Produces the partial result (a one-field struct holding the map of
     * counts) for one group.
     *
     * @param agg buffer of the current group
     * @return the reusable partial-result struct
     */
    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
        MultiCountAgg myagg = (MultiCountAgg) agg;
        MapWritable partialMap = (MapWritable) partialResult[0];
        // Same shared-object issue as in terminate(): start from an empty map
        // so that one group's keys do not leak into the next group's partial.
        partialMap.clear();
        copyCounts(myagg.map, partialMap);
        return partialResult;
    }

    /** Adds {@code delta} to the count stored under {@code key}, starting from zero. */
    private static void addCount(HashMap<String, Long> counts, String key, long delta) {
        Long current = counts.get(key);
        long updated = (current == null) ? delta : current.longValue() + delta;
        counts.put(key, Long.valueOf(updated));
    }

    /** Writes every count into {@code target} as a {@code Text -> LongWritable} entry. */
    private static void copyCounts(HashMap<String, Long> counts, MapWritable target) {
        for (String key : counts.keySet()) {
            target.put(new Text(key), new LongWritable(counts.get(key).longValue()));
        }
    }
}

Reply via email to