This is an automated email from the ASF dual-hosted git repository.

glauesppen pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/incubator-wayang.git
commit d0b49634f9fe72f92dea7b452b4732e40f680790 Author: AdityaGoel11 <[email protected]> AuthorDate: Fri Oct 6 06:08:57 2023 +0530 Add feature: Support for basic SQL Aggregation queries for SUM, MIN, MAX, COUNT and AVG aggregation functions --- .../calcite/converter/WayangAggregateVisitor.java | 62 ++++++++++++---------- .../java/org/apache/wayang/api/sql/SqlAPI.java | 2 +- 2 files changed, 36 insertions(+), 28 deletions(-) diff --git a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangAggregateVisitor.java b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangAggregateVisitor.java index 59369fdc..7e4eaf9c 100644 --- a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangAggregateVisitor.java +++ b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangAggregateVisitor.java @@ -49,9 +49,7 @@ public class WayangAggregateVisitor extends WayangRelNodeVisitor<WayangAggregate for (AggregateCall aggregateCall : aggregateCalls) { if (aggregateCall.getAggregation().getName().equals("SUM")) { int fieldIndex = aggregateCall.getArgList().get(0); - //System.out.println(fieldIndex); int groupCount = wayangRelNode.getGroupCount(); - //System.out.println(groupCount); if (groupCount > 0) { int groupIndex = 0; // wayangRelNode.getGroupSets() // Create the ReduceByOperator @@ -79,9 +77,7 @@ public class WayangAggregateVisitor extends WayangRelNodeVisitor<WayangAggregate } else if (aggregateCall.getAggregation().getName().equals("MIN")) { int fieldIndex = aggregateCall.getArgList().get(0); - //System.out.println(fieldIndex); int groupCount = wayangRelNode.getGroupCount(); - //System.out.println(groupCount); if (groupCount > 0) { int groupIndex = 0; // wayangRelNode.getGroupSets() // Create the ReduceByOperator @@ -109,9 +105,8 @@ public class WayangAggregateVisitor extends WayangRelNodeVisitor<WayangAggregate } else if 
(aggregateCall.getAggregation().getName().equals("MAX")) { int fieldIndex = aggregateCall.getArgList().get(0); - //System.out.println(fieldIndex); + int groupCount = wayangRelNode.getGroupCount(); - //System.out.println(groupCount); if (groupCount > 0) { int groupIndex = 0; // wayangRelNode.getGroupSets() // Create the ReduceByOperator @@ -137,10 +132,42 @@ public class WayangAggregateVisitor extends WayangRelNodeVisitor<WayangAggregate return globalReduceOperator; } } + else if (aggregateCall.getAggregation().getName().equals("COUNT")) { + MapOperator mapOperator = new MapOperator( + new addCountCol(), + Record.class, + Record.class + ); + childOp.connectTo(0, mapOperator, 0); + + int groupCount = wayangRelNode.getGroupCount(); + if (groupCount > 0) { + int groupIndex = 0; // wayangRelNode.getGroupSets() + // Create the ReduceByOperator + ReduceByOperator<Record, Object> reduceByOperator; + reduceByOperator = new ReduceByOperator<>( + new TransformationDescriptor<>(new KeyExtractor(groupIndex), Record.class, Object.class), + new ReduceDescriptor<>(new countFunction(), + DataUnitType.createGrouped(Record.class), + DataUnitType.createBasicUnchecked(Record.class)) + ); + // Connect it to the child operator + mapOperator.connectTo(0, reduceByOperator, 0); + + return reduceByOperator; + } else { + GlobalReduceOperator<Record> globalReduceOperator; + globalReduceOperator = new GlobalReduceOperator<>( + new ReduceDescriptor<>(new countFunction(), + DataUnitType.createGrouped(Record.class), + DataUnitType.createBasicUnchecked(Record.class)) + ); + mapOperator.connectTo(0, globalReduceOperator, 0); + return globalReduceOperator; + } + } else if (aggregateCall.getAggregation().getName().equals("AVG")) { - //System.out.println(aggregateCall.getArgList()); int fieldIndex = aggregateCall.getArgList().get(0); - //System.out.println(fieldIndex); MapOperator mapOperator1 = new MapOperator( new addCountCol(), Record.class, @@ -149,7 +176,6 @@ public class WayangAggregateVisitor 
extends WayangRelNodeVisitor<WayangAggregate childOp.connectTo(0, mapOperator1, 0); int groupCount = wayangRelNode.getGroupCount(); - //System.out.println(groupCount); if (groupCount > 0) { int groupIndex = 0; // wayangRelNode.getGroupSets() // Create the ReduceByOperator @@ -327,24 +353,6 @@ class addCountCol implements FunctionDescriptor.SerializableFunction<Record, Rec } } -class removeCountCol implements FunctionDescriptor.SerializableFunction<Record, Record> { - public removeCountCol() {} - - @Override - public Record apply(final Record record) { - int l = record.size(); - int count = record.getInt(l-1); - Object[] resValues = new Object[l-1]; - resValues[0] = count; - for(int i=0; i<l-2; i++){ - resValues[i+1] = record.getField(i); - } - - return new Record(resValues); - - } -} - class getAvg implements FunctionDescriptor.SerializableFunction<Record, Record> { private final int fieldIndex; public getAvg(int fieldindex) { diff --git a/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlAPI.java b/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlAPI.java index 87febd3d..67d287a9 100755 --- a/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlAPI.java +++ b/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlAPI.java @@ -148,7 +148,7 @@ public class SqlAPI { Collection<Record> result = sqlContext.executeSql( - "SELECT avg(amount) FROM postgres.payment" + "SELECT sum(amount),customer_id, staff_id FROM postgres.payment group by customer_id, staff_id" ); printResults(10, result);
