Modified: hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFSum.txt URL: http://svn.apache.org/viewvc/hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFSum.txt?rev=1485419&r1=1485418&r2=1485419&view=diff ============================================================================== --- hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFSum.txt (original) +++ hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFSum.txt Wed May 22 20:58:08 2013 @@ -25,6 +25,7 @@ import org.apache.hadoop.hive.ql.exec.ve import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates. VectorAggregateExpression.AggregationBuffer; +import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; @@ -50,9 +51,18 @@ public class <ClassName> extends VectorA /** /* class for storing the current aggregate value. */ - static private final class Aggregation implements AggregationBuffer { + private static final class Aggregation implements AggregationBuffer { <ValueType> sum; boolean isNull; + + public void sumValue(<ValueType> value) { + if (isNull) { + sum = value; + isNull = false; + } else { + sum += value; + } + } } VectorExpression inputExpression; @@ -63,17 +73,207 @@ public class <ClassName> extends VectorA this.inputExpression = inputExpression; result = new <OutputType>(); } + + private Aggregation getCurrentAggregationBuffer( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + int row) { + VectorAggregationBufferRow mySet = aggregationBufferSets[row]; + Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(aggregateIndex); + return myagg; + } + + @Override + public void aggregateInputSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + VectorizedRowBatch batch) throws HiveException { + + int batchSize = batch.size; + + if (batchSize == 0) { + return; + } + + inputExpression.evaluate(batch); + + LongColumnVector inputVector = (LongColumnVector)batch. 
+ cols[this.inputExpression.getOutputColumn()]; + long[] vector = inputVector.vector; + + if (inputVector.noNulls) { + if (inputVector.isRepeating) { + iterateNoNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregateIndex, + vector[0], batchSize); + } else { + if (batch.selectedInUse) { + iterateNoNullsSelectionWithAggregationSelection( + aggregationBufferSets, aggregateIndex, + vector, batch.selected, batchSize); + } else { + iterateNoNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, + vector, batchSize); + } + } + } else { + if (inputVector.isRepeating) { + if (batch.selectedInUse) { + iterateHasNullsRepeatingSelectionWithAggregationSelection( + aggregationBufferSets, aggregateIndex, + vector[0], batchSize, batch.selected, inputVector.isNull); + } else { + iterateHasNullsRepeatingWithAggregationSelection( + aggregationBufferSets, aggregateIndex, + vector[0], batchSize, inputVector.isNull); + } + } else { + if (batch.selectedInUse) { + iterateHasNullsSelectionWithAggregationSelection( + aggregationBufferSets, aggregateIndex, + vector, batchSize, batch.selected, inputVector.isNull); + } else { + iterateHasNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, + vector, batchSize, inputVector.isNull); + } + } + } + } + + private void iterateNoNullsRepeatingWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + long value, + int batchSize) { + + for (int i=0; i < batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregateIndex, + i); + myagg.sumValue(value); + } + } + + private void iterateNoNullsSelectionWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + long[] values, + int[] selection, + int batchSize) { + + for (int i=0; i < batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregateIndex, + i); + myagg.sumValue(values[selection[i]]); + } + } + + private void iterateNoNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + long[] values, + int batchSize) { + for (int i=0; i < batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregateIndex, + i); + myagg.sumValue(values[i]); + } + } + + private void iterateHasNullsRepeatingSelectionWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + long value, + int batchSize, + int[] selection, + boolean[] isNull) { + + for (int i=0; i < batchSize; ++i) { + if (!isNull[selection[i]]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregateIndex, + i); + myagg.sumValue(value); + } + } + + } + + private void iterateHasNullsRepeatingWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + long value, + int batchSize, + boolean[] isNull) { + + for (int i=0; i < batchSize; ++i) { + if (!isNull[i]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregateIndex, + i); + myagg.sumValue(value); + } + } + } + + private void iterateHasNullsSelectionWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + long[] values, + int batchSize, + int[] selection, + boolean[] isNull) { + + for (int j=0; j < batchSize; ++j) { + int i = selection[j]; + if (!isNull[i]) { + Aggregation myagg = getCurrentAggregationBuffer( + 
aggregationBufferSets, + aggregateIndex, + j); + myagg.sumValue(values[i]); + } + } + } + + private void iterateHasNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + long[] values, + int batchSize, + boolean[] isNull) { + + for (int i=0; i < batchSize; ++i) { + if (!isNull[i]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregateIndex, + i); + myagg.sumValue(values[i]); + } + } + } + @Override - public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch unit) + public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch) throws HiveException { - inputExpression.evaluate(unit); + inputExpression.evaluate(batch); - <InputColumnVectorType> inputVector = (<InputColumnVectorType>)unit. + <InputColumnVectorType> inputVector = (<InputColumnVectorType>)batch. cols[this.inputExpression.getOutputColumn()]; - int batchSize = unit.size; + int batchSize = batch.size; if (batchSize == 0) { return; @@ -94,17 +294,17 @@ public class <ClassName> extends VectorA return; } - if (!unit.selectedInUse && inputVector.noNulls) { + if (!batch.selectedInUse && inputVector.noNulls) { iterateNoSelectionNoNulls(myagg, vector, batchSize); } - else if (!unit.selectedInUse) { + else if (!batch.selectedInUse) { iterateNoSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull); } else if (inputVector.noNulls){ - iterateSelectionNoNulls(myagg, vector, batchSize, unit.selected); + iterateSelectionNoNulls(myagg, vector, batchSize, batch.selected); } else { - iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, unit.selected); + iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, batch.selected); } }
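The aggregateInputSelection() path added above is the per-row entry point used for hash GROUP BY: each row of the batch owns a VectorAggregationBufferRow, and the method fans out to eight hand-specialized loops, one per combination of (noNulls, isRepeating, selectedInUse), so those three checks are hoisted out of the hot per-row loop. The sketch below is a condensed, standalone reading of the shared logic, including the SUM null semantics where the first non-null input seeds the sum. The class names and the single generic loop are illustrative only; the real template specializes per case instead of branching per row.

    // Minimal sketch with simplified stand-in classes, not the Hive types.
    public class SumPathSketch {

      // Mirrors the Aggregation inner class above: SUM is NULL until the
      // first non-null input arrives, which then seeds the running sum.
      static final class Aggregation {
        long sum;
        boolean isNull = true;

        void sumValue(long value) {
          if (isNull) {
            sum = value;
            isNull = false;
          } else {
            sum += value;
          }
        }
      }

      // Simplified stand-in for VectorizedRowBatch plus LongColumnVector.
      static final class Batch {
        long[] vector;         // column values
        boolean[] isNull;      // per-row null flags, meaningful when !noNulls
        boolean noNulls;       // no row in this column is null
        boolean isRepeating;   // vector[0] / isNull[0] apply to every row
        int[] selected;        // live row indices when selectedInUse
        boolean selectedInUse;
        int size;              // number of live rows
      }

      // One aggregation buffer per batch position, as with
      // VectorAggregationBufferRow[] indexed by row.
      static void aggregateInputSelection(Aggregation[] buffers, Batch batch) {
        for (int i = 0; i < batch.size; ++i) {
          int row = batch.selectedInUse ? batch.selected[i] : i;
          int dataRow = batch.isRepeating ? 0 : row;  // where the data lives
          if (batch.noNulls || !batch.isNull[dataRow]) {
            buffers[i].sumValue(batch.vector[dataRow]);
          }
        }
      }

      public static void main(String[] args) {
        Batch b = new Batch();
        b.vector = new long[] {13, 5, 7, 19};
        b.isNull = new boolean[] {false, true, false, false};
        b.noNulls = false;
        b.size = 4;
        Aggregation[] buffers = new Aggregation[b.size];
        for (int i = 0; i < b.size; ++i) {
          buffers[i] = new Aggregation();
        }
        aggregateInputSelection(buffers, b);
        for (Aggregation a : buffers) {
          // Prints 13, NULL, 7, 19: buffer 1 never saw a non-null value.
          System.out.println(a.isNull ? "NULL" : Long.toString(a.sum));
        }
      }
    }

One detail worth noting: the sketch consults isNull[0] for repeating vectors, while the template's iterateHasNullsRepeatingSelectionWithAggregationSelection checks isNull[selection[i]] per row; for a truly repeating column vector the two agree only when the isNull flags repeat as well.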
Modified: hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFVar.txt URL: http://svn.apache.org/viewvc/hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFVar.txt?rev=1485419&r1=1485418&r2=1485419&view=diff ============================================================================== --- hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFVar.txt (original) +++ hive/branches/vectorization/ql/src/java/org/apache/hadoop/hive/ql/exec/vector/expressions/templates/VectorUDAFVar.txt Wed May 22 20:58:08 2013 @@ -25,6 +25,7 @@ import org.apache.hadoop.hive.ql.exec.ve import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression; import org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates .VectorAggregateExpression.AggregationBuffer; +import org.apache.hadoop.hive.ql.exec.vector.VectorAggregationBufferRow; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; @@ -50,13 +51,13 @@ public class <ClassName> extends VectorA /** /* class for storing the current aggregate value. */ - static private final class Aggregation implements AggregationBuffer { + private static final class Aggregation implements AggregationBuffer { double sum; long count; double variance; boolean isNull; - public void init () { + public void init() { isNull = false; sum = 0; count = 0; @@ -86,7 +87,7 @@ public class <ClassName> extends VectorA initPartialResultInspector(); } - private void initPartialResultInspector () { + private void initPartialResultInspector() { ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>(); foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector); foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); @@ -99,17 +100,200 @@ public class <ClassName> extends VectorA soi = ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi); } + + private Aggregation getCurrentAggregationBuffer( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + int row) { + VectorAggregationBufferRow mySet = aggregationBufferSets[row]; + Aggregation myagg = (Aggregation) mySet.getAggregationBuffer(aggregateIndex); + return myagg; + } + + + @Override + public void aggregateInputSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + VectorizedRowBatch batch) throws HiveException { + + inputExpression.evaluate(batch); + + <InputColumnVectorType> inputVector = (<InputColumnVectorType>)batch. 
+ cols[this.inputExpression.getOutputColumn()]; + + int batchSize = batch.size; + + if (batchSize == 0) { + return; + } + + <ValueType>[] vector = inputVector.vector; + + if (inputVector.isRepeating) { + if (inputVector.noNulls || !inputVector.isNull[0]) { + iterateRepeatingNoNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, vector[0], batchSize); + } + } + else if (!batch.selectedInUse && inputVector.noNulls) { + iterateNoSelectionNoNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, vector, batchSize); + } + else if (!batch.selectedInUse) { + iterateNoSelectionHasNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, vector, batchSize, inputVector.isNull); + } + else if (inputVector.noNulls){ + iterateSelectionNoNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, vector, batchSize, batch.selected); + } + else { + iterateSelectionHasNullsWithAggregationSelection( + aggregationBufferSets, aggregateIndex, vector, batchSize, + inputVector.isNull, batch.selected); + } + + } + private void iterateRepeatingNoNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + <ValueType> value, + int batchSize) { + + for (int i=0; i<batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregateIndex, + i); + if (myagg.isNull) { + myagg.init (); + } + myagg.sum += value; + myagg.count += 1; + if(myagg.count > 1) { + double t = myagg.count*value - myagg.sum; + myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); + } + } + } + + private void iterateSelectionHasNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + <ValueType>[] vector, + int batchSize, + boolean[] isNull, + int[] selected) { + + for (int j=0; j< batchSize; ++j) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregateIndex, + j); + int i = selected[j]; + if (!isNull[i]) { + <ValueType> value = vector[i]; + if (myagg.isNull) { + myagg.init (); + } + myagg.sum += value; + myagg.count += 1; + if(myagg.count > 1) { + double t = myagg.count*value - myagg.sum; + myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); + } + } + } + } + + private void iterateSelectionNoNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + <ValueType>[] vector, + int batchSize, + int[] selected) { + + for (int i=0; i< batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregateIndex, + i); + <ValueType> value = vector[selected[i]]; + if (myagg.isNull) { + myagg.init (); + } + myagg.sum += value; + myagg.count += 1; + if(myagg.count > 1) { + double t = myagg.count*value - myagg.sum; + myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); + } + } + } + + private void iterateNoSelectionHasNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + <ValueType>[] vector, + int batchSize, + boolean[] isNull) { + + for(int i=0;i<batchSize;++i) { + if (!isNull[i]) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregateIndex, + i); + <ValueType> value = vector[i]; + if (myagg.isNull) { + myagg.init (); + } + myagg.sum += value; + myagg.count += 1; + if(myagg.count > 1) { + double t = myagg.count*value - myagg.sum; + myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); + } + } + } + } + + private 
void iterateNoSelectionNoNullsWithAggregationSelection( + VectorAggregationBufferRow[] aggregationBufferSets, + int aggregateIndex, + <ValueType>[] vector, + int batchSize) { + + for (int i=0; i<batchSize; ++i) { + Aggregation myagg = getCurrentAggregationBuffer( + aggregationBufferSets, + aggregateIndex, + i); + if (myagg.isNull) { + myagg.init (); + } + <ValueType> value = vector[i]; + myagg.sum += value; + myagg.count += 1; + if(myagg.count > 1) { + double t = myagg.count*value - myagg.sum; + myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1)); + } + } + } + @Override - public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch unit) + public void aggregateInput(AggregationBuffer agg, VectorizedRowBatch batch) throws HiveException { - inputExpression.evaluate(unit); + inputExpression.evaluate(batch); - <InputColumnVectorType> inputVector = (<InputColumnVectorType>)unit. + <InputColumnVectorType> inputVector = (<InputColumnVectorType>)batch. cols[this.inputExpression.getOutputColumn()]; - int batchSize = unit.size; + int batchSize = batch.size; if (batchSize == 0) { return; @@ -124,17 +308,17 @@ public class <ClassName> extends VectorA iterateRepeatingNoNulls(myagg, vector[0], batchSize); } } - else if (!unit.selectedInUse && inputVector.noNulls) { + else if (!batch.selectedInUse && inputVector.noNulls) { iterateNoSelectionNoNulls(myagg, vector, batchSize); } - else if (!unit.selectedInUse) { + else if (!batch.selectedInUse) { iterateNoSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull); } else if (inputVector.noNulls){ - iterateSelectionNoNulls(myagg, vector, batchSize, unit.selected); + iterateSelectionNoNulls(myagg, vector, batchSize, batch.selected); } else { - iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, unit.selected); + iterateSelectionHasNulls(myagg, vector, batchSize, inputVector.isNull, batch.selected); } } Modified: hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java URL: http://svn.apache.org/viewvc/hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java?rev=1485419&r1=1485418&r2=1485419&view=diff ============================================================================== --- hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java (original) +++ hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/TestVectorGroupByOperator.java Wed May 22 20:58:08 2013 @@ -28,8 +28,10 @@ import java.lang.reflect.Constructor; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import org.apache.hadoop.hive.ql.exec.vector.util.FakeCaptureOutputOperator; import org.apache.hadoop.hive.ql.exec.vector.util.FakeVectorRowBatchFromConcat; @@ -93,8 +95,144 @@ public class TestVectorGroupByOperator { return desc; } + private static GroupByDesc buildKeyGroupByDesc( + VectorizationContext ctx, + String aggregate, + String column, + String key) { + + GroupByDesc desc = buildGroupByDesc(ctx, aggregate, column); + + ExprNodeDesc keyExp = buildColumnDesc(ctx, key); + ArrayList<ExprNodeDesc> keys = new ArrayList<ExprNodeDesc>(); + keys.add(keyExp); + desc.setKeys(keys); + + return desc; + } + + @Test + public void testMinLongKeyGroupByCompactBatch() throws HiveException { + testAggregateLongKeyAggregate( + "min", + 2, + Arrays.asList(new 
Long[]{01L,1L,2L,02L}), + Arrays.asList(new Long[]{13L,5L,7L,19L}), + buildHashMap(1L, 5L, 2L, 7L)); + } + + @Test + public void testMinLongKeyGroupBySingleBatch() throws HiveException { + testAggregateLongKeyAggregate( + "min", + 4, + Arrays.asList(new Long[]{01L,1L,2L,02L}), + Arrays.asList(new Long[]{13L,5L,7L,19L}), + buildHashMap(1L, 5L, 2L, 7L)); + } + + @Test + public void testMinLongKeyGroupByCrossBatch() throws HiveException { + testAggregateLongKeyAggregate( + "min", + 2, + Arrays.asList(new Long[]{01L,2L,1L,02L}), + Arrays.asList(new Long[]{13L,5L,7L,19L}), + buildHashMap(1L, 7L, 2L, 5L)); + } + + @Test + public void testMinLongNullKeyGroupByCrossBatch() throws HiveException { + testAggregateLongKeyAggregate( + "min", + 2, + Arrays.asList(new Long[]{null,2L,null,02L}), + Arrays.asList(new Long[]{13L,5L,7L,19L}), + buildHashMap(null, 7L, 2L, 5L)); + } + + @Test + public void testMinLongNullKeyGroupBySingleBatch() throws HiveException { + testAggregateLongKeyAggregate( + "min", + 4, + Arrays.asList(new Long[]{null,2L,null,02L}), + Arrays.asList(new Long[]{13L,5L,7L,19L}), + buildHashMap(null, 7L, 2L, 5L)); + } + + @Test + public void testMaxLongNullKeyGroupBySingleBatch() throws HiveException { + testAggregateLongKeyAggregate( + "max", + 4, + Arrays.asList(new Long[]{null,2L,null,02L}), + Arrays.asList(new Long[]{13L,5L,7L,19L}), + buildHashMap(null, 13L, 2L, 19L)); + } + + @Test + public void testCountLongNullKeyGroupBySingleBatch() throws HiveException { + testAggregateLongKeyAggregate( + "count", + 4, + Arrays.asList(new Long[]{null,2L,null,02L}), + Arrays.asList(new Long[]{13L,5L,7L,19L}), + buildHashMap(null, 2L, 2L, 2L)); + } + + @Test + public void testSumLongNullKeyGroupBySingleBatch() throws HiveException { + testAggregateLongKeyAggregate( + "sum", + 4, + Arrays.asList(new Long[]{null,2L,null,02L}), + Arrays.asList(new Long[]{13L,5L,7L,19L}), + buildHashMap(null, 20L, 2L, 24L)); + } + + @Test + public void testAvgLongNullKeyGroupBySingleBatch() throws HiveException { + testAggregateLongKeyAggregate( + "avg", + 4, + Arrays.asList(new Long[]{null,2L,null,02L}), + Arrays.asList(new Long[]{13L,5L,7L,19L}), + buildHashMap(null, 10.0, 2L, 12.0)); + } + + @Test + public void testVarLongNullKeyGroupBySingleBatch() throws HiveException { + testAggregateLongKeyAggregate( + "variance", + 4, + Arrays.asList(new Long[]{null,2L,01L,02L,01L,01L}), + Arrays.asList(new Long[]{13L, 5L,18L,19L,12L,15L}), + buildHashMap(null, 0.0, 2L, 49.0, 01L, 6.0)); + } + + @Test + public void testMinNullLongNullKeyGroupBy() throws HiveException { + testAggregateLongKeyAggregate( + "min", + 4, + Arrays.asList(new Long[]{null,2L,null,02L}), + Arrays.asList(new Long[]{null, null, null, null}), + buildHashMap(null, null, 2L, null)); + } + + @Test + public void testMinLongGroupBy() throws HiveException { + testAggregateLongAggregate( + "min", + 2, + Arrays.asList(new Long[]{13L,5L,7L,19L}), + 5L); + } + + @Test - public void testMinLongSimple () throws HiveException { + public void testMinLongSimple() throws HiveException { testAggregateLongAggregate( "min", 2, @@ -735,7 +873,28 @@ public class TestVectorGroupByOperator { new Long[] {value}, repeat, batchSize); testAggregateLongIterable (aggregateName, fdr, expected); } + + public HashMap<Object, Object> buildHashMap(Object... 
pairs) { + HashMap<Object, Object> map = new HashMap<Object, Object>(); + for(int i = 0; i < pairs.length; i += 2) { + map.put(pairs[i], pairs[i+1]); + } + return map; + } + + + public void testAggregateLongKeyAggregate ( + String aggregateName, + int batchSize, + Iterable<Long> keys, + Iterable<Long> values, + HashMap<Object, Object> expected) throws HiveException { + @SuppressWarnings("unchecked") + FakeVectorRowBatchFromIterables fdr = new FakeVectorRowBatchFromIterables(batchSize, keys, values); + testAggregateLongKeyIterable (aggregateName, fdr, expected); + } + public void testAggregateLongAggregate ( String aggregateName, int batchSize, @@ -915,5 +1074,68 @@ public class TestVectorGroupByOperator { Validator validator = getValidator(aggregateName); validator.validate(expected, result); } + + public void testAggregateLongKeyIterable ( + String aggregateName, + Iterable<VectorizedRowBatch> data, + HashMap<Object,Object> expected) throws HiveException { + Map<String, Integer> mapColumnNames = new HashMap<String, Integer>(); + mapColumnNames.put("Key", 0); + mapColumnNames.put("Value", 1); + VectorizationContext ctx = new VectorizationContext(mapColumnNames, 2); + Set<Object> keys = new HashSet<Object>(); + + GroupByDesc desc = buildKeyGroupByDesc (ctx, aggregateName, "Value", "Key"); + + VectorGroupByOperator vgo = new VectorGroupByOperator(ctx, desc); + + FakeCaptureOutputOperator out = FakeCaptureOutputOperator.addCaptureOutputChild(vgo); + vgo.initialize(null, null); + out.setOutputInspector(new FakeCaptureOutputOperator.OutputInspector() { + + private int rowIndex; + private String aggregateName; + private HashMap<Object,Object> expected; + private Set<Object> keys; + + @Override + public void inspectRow(Object row, int tag) throws HiveException { + assertTrue(row instanceof Object[]); + Object[] fields = (Object[]) row; + assertEquals(2, fields.length); + Object key = fields[0]; + Long keyValue = null; + if (null != key) { + assertTrue(key instanceof LongWritable); + LongWritable lwKey = (LongWritable)key; + keyValue = lwKey.get(); + } + assertTrue(expected.containsKey(keyValue)); + Object expectedValue = expected.get(keyValue); + Object value = fields[1]; + Validator validator = getValidator(aggregateName); + validator.validate(expectedValue, new Object[] {value}); + keys.add(keyValue); + } + + private FakeCaptureOutputOperator.OutputInspector init( + String aggregateName, HashMap<Object,Object> expected, Set<Object> keys) { + this.aggregateName = aggregateName; + this.expected = expected; + this.keys = keys; + return this; + } + }.init(aggregateName, expected, keys)); + + for (VectorizedRowBatch unit: data) { + vgo.process(unit, 0); + } + vgo.close(false); + + List<Object> outBatchList = out.getCapturedRows(); + assertNotNull(outBatchList); + assertEquals(expected.size(), outBatchList.size()); + assertEquals(expected.size(), keys.size()); + } } Modified: hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestConstantVectorExpression.java URL: http://svn.apache.org/viewvc/hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestConstantVectorExpression.java?rev=1485419&r1=1485418&r2=1485419&view=diff ============================================================================== --- hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestConstantVectorExpression.java (original) +++ 
hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/expressions/TestConstantVectorExpression.java Wed May 22 20:58:08 2013 @@ -27,6 +27,7 @@ import org.apache.hadoop.hive.ql.exec.ve import org.apache.hadoop.hive.ql.exec.vector.DoubleColumnVector; import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector; import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch; +import org.apache.hadoop.hive.ql.exec.vector.util.VectorizedRowGroupGenUtil; import org.junit.Test; public class TestConstantVectorExpression { Modified: hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeCaptureOutputOperator.java URL: http://svn.apache.org/viewvc/hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeCaptureOutputOperator.java?rev=1485419&r1=1485418&r2=1485419&view=diff ============================================================================== --- hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeCaptureOutputOperator.java (original) +++ hive/branches/vectorization/ql/src/test/org/apache/hadoop/hive/ql/exec/vector/util/FakeCaptureOutputOperator.java Wed May 22 20:58:08 2013 @@ -20,6 +20,7 @@ package org.apache.hadoop.hive.ql.exec.v import java.io.Serializable; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.apache.hadoop.conf.Configuration; @@ -35,6 +36,20 @@ import org.apache.hadoop.hive.ql.plan.ap public class FakeCaptureOutputOperator extends Operator<FakeCaptureOutputDesc> implements Serializable { private static final long serialVersionUID = 1L; + + public interface OutputInspector { + public void inspectRow(Object row, int tag) throws HiveException; + } + + private OutputInspector outputInspector; + + public void setOutputInspector(OutputInspector outputInspector) { + this.outputInspector = outputInspector; + } + + public OutputInspector getOutputInspector() { + return outputInspector; + } private transient List<Object> rows; @@ -52,6 +67,7 @@ public class FakeCaptureOutputOperator e return out; } + public List<Object> getCapturedRows() { return rows; } @@ -64,6 +80,9 @@ public class FakeCaptureOutputOperator e @Override public void processOp(Object row, int tag) throws HiveException { rows.add(row); + if (null != outputInspector) { + outputInspector.inspectRow(row, tag); + } } @Override
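For reference, every variance loop added in VectorUDAFVar.txt above applies the same streaming update: after folding a value into sum and count, it computes t = count*value - sum and adds t*t / (count*(count-1)) to the variance field. That field therefore accumulates the running sum of squared deviations, not the variance itself; dividing by count at finalization yields the population variance, which is consistent with testVarLongNullKeyGroupBySingleBatch expecting 6.0 for the key-1 group {18, 12, 15}. Below is a small standalone check of the recurrence (the class name is illustrative; this is not Hive code):

    // Verifies the incremental update used in the template:
    //   t   = n * x_n - sum_n
    //   M_n = M_{n-1} + t*t / (n * (n-1))
    // where M_n is the sum of squared deviations after n values.
    public class VarianceUpdateCheck {
      public static void main(String[] args) {
        // Values of the key-1 group in testVarLongNullKeyGroupBySingleBatch;
        // the test expects a variance of 6.0 for this group.
        long[] xs = {18, 12, 15};

        // Streaming pass, mirroring the template's per-row loop body.
        double sum = 0;
        double m = 0;
        long count = 0;
        for (long x : xs) {
          sum += x;
          count += 1;
          if (count > 1) {
            double t = count * x - sum;
            m += (t * t) / ((double) count * (count - 1));
          }
        }

        // Two-pass reference: mean first, then squared deviations.
        double mean = sum / count;
        double ss = 0;
        for (long x : xs) {
          ss += (x - mean) * (x - mean);
        }

        // Both M values print 18.0; 18.0 / 3 = 6.0, matching the test.
        System.out.println("streaming M = " + m + ", two-pass M = " + ss
            + ", population variance = " + (m / count));
      }
    }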