This is an automated email from the ASF dual-hosted git repository.

glauesppen pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/incubator-wayang.git

commit d0b49634f9fe72f92dea7b452b4732e40f680790
Author: AdityaGoel11 <[email protected]>
AuthorDate: Fri Oct 6 06:08:57 2023 +0530

    Add feature: support basic SQL aggregation queries using the SUM, MIN, MAX,
    COUNT and AVG aggregate functions
---
 .../calcite/converter/WayangAggregateVisitor.java  | 62 ++++++++++++----------
 .../java/org/apache/wayang/api/sql/SqlAPI.java     |  2 +-
 2 files changed, 36 insertions(+), 28 deletions(-)

diff --git 
a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangAggregateVisitor.java
 
b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangAggregateVisitor.java
index 59369fdc..7e4eaf9c 100644
--- 
a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangAggregateVisitor.java
+++ 
b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangAggregateVisitor.java
@@ -49,9 +49,7 @@ public class WayangAggregateVisitor extends 
WayangRelNodeVisitor<WayangAggregate
         for (AggregateCall aggregateCall : aggregateCalls) {
             if (aggregateCall.getAggregation().getName().equals("SUM")) {
                 int fieldIndex = aggregateCall.getArgList().get(0);
-                //System.out.println(fieldIndex);
                 int groupCount = wayangRelNode.getGroupCount();
-                //System.out.println(groupCount);
                 if (groupCount > 0) {
                     int groupIndex = 0; // wayangRelNode.getGroupSets()
                     // Create the ReduceByOperator
@@ -79,9 +77,7 @@ public class WayangAggregateVisitor extends 
WayangRelNodeVisitor<WayangAggregate
             }
             else if (aggregateCall.getAggregation().getName().equals("MIN")) {
                 int fieldIndex = aggregateCall.getArgList().get(0);
-                //System.out.println(fieldIndex);
                 int groupCount = wayangRelNode.getGroupCount();
-                //System.out.println(groupCount);
                 if (groupCount > 0) {
                     int groupIndex = 0; // wayangRelNode.getGroupSets()
                     // Create the ReduceByOperator
@@ -109,9 +105,8 @@ public class WayangAggregateVisitor extends 
WayangRelNodeVisitor<WayangAggregate
             }
             else if (aggregateCall.getAggregation().getName().equals("MAX")) {
                 int fieldIndex = aggregateCall.getArgList().get(0);
-                //System.out.println(fieldIndex);
+
                 int groupCount = wayangRelNode.getGroupCount();
-                //System.out.println(groupCount);
                 if (groupCount > 0) {
                     int groupIndex = 0; // wayangRelNode.getGroupSets()
                     // Create the ReduceByOperator
@@ -137,10 +132,42 @@ public class WayangAggregateVisitor extends 
WayangRelNodeVisitor<WayangAggregate
                     return globalReduceOperator;
                 }
             }
+            else if (aggregateCall.getAggregation().getName().equals("COUNT")) 
{
+                MapOperator mapOperator = new MapOperator(
+                        new addCountCol(),
+                        Record.class,
+                        Record.class
+                );
+                childOp.connectTo(0, mapOperator, 0);
+
+                int groupCount = wayangRelNode.getGroupCount();
+                if (groupCount > 0) {
+                    int groupIndex = 0; // wayangRelNode.getGroupSets()
+                    // Create the ReduceByOperator
+                    ReduceByOperator<Record, Object> reduceByOperator;
+                    reduceByOperator = new ReduceByOperator<>(
+                            new TransformationDescriptor<>(new 
KeyExtractor(groupIndex), Record.class, Object.class),
+                            new ReduceDescriptor<>(new countFunction(),
+                                    DataUnitType.createGrouped(Record.class),
+                                    
DataUnitType.createBasicUnchecked(Record.class))
+                    );
+                    // Connect it to the child operator
+                    mapOperator.connectTo(0, reduceByOperator, 0);
+
+                    return reduceByOperator;
+                } else {
+                    GlobalReduceOperator<Record> globalReduceOperator;
+                    globalReduceOperator = new GlobalReduceOperator<>(
+                            new ReduceDescriptor<>(new countFunction(),
+                                    DataUnitType.createGrouped(Record.class),
+                                    
DataUnitType.createBasicUnchecked(Record.class))
+                    );
+                    mapOperator.connectTo(0, globalReduceOperator, 0);
+                    return globalReduceOperator;
+                }
+            }
             else if (aggregateCall.getAggregation().getName().equals("AVG")) {
-                //System.out.println(aggregateCall.getArgList());
                 int fieldIndex = aggregateCall.getArgList().get(0);
-                //System.out.println(fieldIndex);
                 MapOperator mapOperator1 = new MapOperator(
                         new addCountCol(),
                         Record.class,
@@ -149,7 +176,6 @@ public class WayangAggregateVisitor extends 
WayangRelNodeVisitor<WayangAggregate
                 childOp.connectTo(0, mapOperator1, 0);
 
                 int groupCount = wayangRelNode.getGroupCount();
-                //System.out.println(groupCount);
                 if (groupCount > 0) {
                     int groupIndex = 0; // wayangRelNode.getGroupSets()
                     // Create the ReduceByOperator
@@ -327,24 +353,6 @@ class addCountCol implements 
FunctionDescriptor.SerializableFunction<Record, Rec
     }
 }
 
-class removeCountCol implements 
FunctionDescriptor.SerializableFunction<Record, Record> {
-    public removeCountCol() {}
-
-    @Override
-    public Record apply(final Record record) {
-        int l = record.size();
-        int count = record.getInt(l-1);
-        Object[] resValues = new Object[l-1];
-        resValues[0] = count;
-        for(int i=0; i<l-2; i++){
-            resValues[i+1] = record.getField(i);
-        }
-
-        return new Record(resValues);
-
-    }
-}
-
 class getAvg implements FunctionDescriptor.SerializableFunction<Record, 
Record> {
     private final int fieldIndex;
     public getAvg(int fieldindex) {
diff --git 
a/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlAPI.java 
b/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlAPI.java
index 87febd3d..67d287a9 100755
--- 
a/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlAPI.java
+++ 
b/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlAPI.java
@@ -148,7 +148,7 @@ public class SqlAPI {
 
 
         Collection<Record> result = sqlContext.executeSql(
-                "SELECT avg(amount) FROM postgres.payment"
+                "SELECT sum(amount),customer_id, staff_id FROM 
postgres.payment group by customer_id, staff_id"
         );
 
         printResults(10, result);

Reply via email to