Repository: beam Updated Branches: refs/heads/DSL_SQL 6729a027d -> 523482be0
Support commonly used aggregation functions in SQL, including COUNT, SUM, AVG, MAX and MIN; rename BeamAggregationTransform to BeamAggregationTransforms; update comments. Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/f728fbe5 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/f728fbe5 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/f728fbe5 Branch: refs/heads/DSL_SQL Commit: f728fbe5c7153341ace046fa8b2465ef8844be1b Parents: 6729a02 Author: mingmxu <[email protected]> Authored: Wed May 10 16:38:13 2017 -0700 Committer: mingmxu <[email protected]> Committed: Wed May 10 20:47:40 2017 -0700 ---------------------------------------------------------------------- .../interpreter/operator/BeamSqlPrimitive.java | 35 + .../beam/dsls/sql/rel/BeamAggregationRel.java | 40 +- .../apache/beam/dsls/sql/schema/BeamSQLRow.java | 4 + .../beam/dsls/sql/schema/BeamSqlRowCoder.java | 4 +- .../sql/transform/BeamAggregationTransform.java | 120 ---- .../transform/BeamAggregationTransforms.java | 671 +++++++++++++++++++ .../transform/BeamAggregationTransformTest.java | 436 ++++++++++++ .../schema/transform/BeamTransformBaseTest.java | 96 +++ 8 files changed, 1261 insertions(+), 145 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/interpreter/operator/BeamSqlPrimitive.java ---------------------------------------------------------------------- diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/interpreter/operator/BeamSqlPrimitive.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/interpreter/operator/BeamSqlPrimitive.java index 3309577..a5938f3 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/interpreter/operator/BeamSqlPrimitive.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/interpreter/operator/BeamSqlPrimitive.java @@ -65,6 +65,41 @@ public 
class BeamSqlPrimitive<T> extends BeamSqlExpression{ return value; } + public long getLong() { + return (Long) getValue(); + } + + public double getDouble() { + return (Double) getValue(); + } + + public float getFloat() { + return (Float) getValue(); + } + + public int getInteger() { + return (Integer) getValue(); + } + + public short getShort() { + return (Short) getValue(); + } + + public byte getByte() { + return (Byte) getValue(); + } + public boolean getBoolean() { + return (Boolean) getValue(); + } + + public String getString() { + return (String) getValue(); + } + + public Date getDate() { + return (Date) getValue(); + } + @Override public boolean accept() { if (value == null) { http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java ---------------------------------------------------------------------- diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java index 2c7626d..ab98cc4 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java @@ -18,15 +18,13 @@ package org.apache.beam.dsls.sql.rel; import java.util.List; -import org.apache.beam.dsls.sql.exception.BeamSqlUnsupportedException; import org.apache.beam.dsls.sql.planner.BeamPipelineCreator; import org.apache.beam.dsls.sql.planner.BeamSQLRelUtils; import org.apache.beam.dsls.sql.schema.BeamSQLRecordType; import org.apache.beam.dsls.sql.schema.BeamSQLRow; -import org.apache.beam.dsls.sql.transform.BeamAggregationTransform; +import org.apache.beam.dsls.sql.transform.BeamAggregationTransforms; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.transforms.Combine; -import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.GroupByKey; import 
org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.WithKeys; @@ -79,7 +77,7 @@ public class BeamAggregationRel extends Aggregate implements BeamRelNode { PCollection<BeamSQLRow> upstream = planCreator.popUpstream(); if (windowFieldIdx != -1) { upstream = upstream.apply("assignEventTimestamp", WithTimestamps - .<BeamSQLRow>of(new BeamAggregationTransform.WindowTimestampFn(windowFieldIdx))); + .<BeamSQLRow>of(new BeamAggregationTransforms.WindowTimestampFn(windowFieldIdx))); } PCollection<BeamSQLRow> windowStream = upstream.apply("window", @@ -88,32 +86,26 @@ public class BeamAggregationRel extends Aggregate implements BeamRelNode { .withAllowedLateness(allowedLatence) .accumulatingFiredPanes()); + //1. extract fields in group-by key part PCollection<KV<BeamSQLRow, BeamSQLRow>> exGroupByStream = windowStream.apply("exGroupBy", WithKeys - .of(new BeamAggregationTransform.AggregationGroupByKeyFn(windowFieldIdx, groupSet))); + .of(new BeamAggregationTransforms.AggregationGroupByKeyFn(windowFieldIdx, groupSet))); + //2. apply a GroupByKey. 
PCollection<KV<BeamSQLRow, Iterable<BeamSQLRow>>> groupedStream = exGroupByStream .apply("groupBy", GroupByKey.<BeamSQLRow, BeamSQLRow>create()); - if (aggCalls.size() > 1) { - throw new BeamSqlUnsupportedException("only single aggregation is supported now."); - } - - AggregateCall aggCall = aggCalls.get(0); - switch (aggCall.getAggregation().getName()) { - case "COUNT": - PCollection<KV<BeamSQLRow, Long>> aggregatedStream = groupedStream.apply("count", - Combine.<BeamSQLRow, BeamSQLRow, Long>groupedValues(Count.combineFn())); - PCollection<BeamSQLRow> mergedStream = aggregatedStream.apply("mergeRecord", - ParDo.of(new BeamAggregationTransform.MergeAggregationRecord( - BeamSQLRecordType.from(getRowType()), aggCall.getName()))); - planCreator.pushUpstream(mergedStream); - break; - default: - //Only support COUNT now, more are added in BEAM-2008 - throw new BeamSqlUnsupportedException( - String.format("Unsupported aggregation [%s]", aggCall.getAggregation().getName())); - } + //3. run aggregation functions + PCollection<KV<BeamSQLRow, BeamSQLRow>> aggregatedStream = groupedStream.apply("aggregation", + Combine.<BeamSQLRow, BeamSQLRow, BeamSQLRow>groupedValues( + new BeamAggregationTransforms.AggregationCombineFn(getAggCallList(), + BeamSQLRecordType.from(input.getRowType())))); + + //4. 
flat KV to a single record + PCollection<BeamSQLRow> mergedStream = aggregatedStream.apply("mergeRecord", + ParDo.of(new BeamAggregationTransforms.MergeAggregationRecord( + BeamSQLRecordType.from(getRowType()), getAggCallList()))); + planCreator.pushUpstream(mergedStream); return planCreator.getPipeline(); } http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSQLRow.java ---------------------------------------------------------------------- diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSQLRow.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSQLRow.java index 65f4a41..5bdd5d2 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSQLRow.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSQLRow.java @@ -144,6 +144,10 @@ public class BeamSQLRow implements Serializable { dataValues.set(index, fieldValue); } + public byte getByte(int idx) { + return (Byte) getFieldValue(idx); + } + public short getShort(int idx) { return (Short) getFieldValue(idx); } http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlRowCoder.java ---------------------------------------------------------------------- diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlRowCoder.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlRowCoder.java index 3100ba5..0accb9a 100644 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlRowCoder.java +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlRowCoder.java @@ -70,9 +70,11 @@ public class BeamSqlRowCoder extends StandardCoder<BeamSQLRow>{ intCoder.encode(value.getInteger(idx), outStream, context.nested()); break; case SMALLINT: - case TINYINT: intCoder.encode((int) value.getShort(idx), outStream, context.nested()); break; + case TINYINT: + intCoder.encode((int) value.getByte(idx), 
outStream, context.nested()); + break; case DOUBLE: doubleCoder.encode(value.getDouble(idx), outStream, context.nested()); break; http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransform.java ---------------------------------------------------------------------- diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransform.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransform.java deleted file mode 100644 index f478363..0000000 --- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransform.java +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.beam.dsls.sql.transform; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; -import org.apache.beam.dsls.sql.schema.BeamSQLRecordType; -import org.apache.beam.dsls.sql.schema.BeamSQLRow; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.values.KV; -import org.apache.calcite.util.ImmutableBitSet; -import org.joda.time.Instant; - -/** - * Collections of {@code PTransform} and {@code DoFn} used to perform GROUP-BY operation. - */ -public class BeamAggregationTransform implements Serializable{ - /** - * Merge KV to single record. - */ - public static class MergeAggregationRecord extends DoFn<KV<BeamSQLRow, Long>, BeamSQLRow> { - private BeamSQLRecordType outRecordType; - private String aggFieldName; - - public MergeAggregationRecord(BeamSQLRecordType outRecordType, String aggFieldName) { - this.outRecordType = outRecordType; - this.aggFieldName = aggFieldName; - } - - @ProcessElement - public void processElement(ProcessContext c, BoundedWindow window) { - BeamSQLRow outRecord = new BeamSQLRow(outRecordType); - outRecord.updateWindowRange(c.element().getKey(), window); - - KV<BeamSQLRow, Long> kvRecord = c.element(); - for (String f : kvRecord.getKey().getDataType().getFieldsName()) { - outRecord.addField(f, kvRecord.getKey().getFieldValue(f)); - } - outRecord.addField(aggFieldName, kvRecord.getValue()); - -// if (c.pane().isLast()) { - c.output(outRecord); -// } - } - } - - /** - * extract group-by fields. 
- */ - public static class AggregationGroupByKeyFn - implements SerializableFunction<BeamSQLRow, BeamSQLRow> { - private List<Integer> groupByKeys; - - public AggregationGroupByKeyFn(int windowFieldIdx, ImmutableBitSet groupSet) { - this.groupByKeys = new ArrayList<>(); - for (int i : groupSet.asList()) { - if (i != windowFieldIdx) { - groupByKeys.add(i); - } - } - } - - @Override - public BeamSQLRow apply(BeamSQLRow input) { - BeamSQLRecordType typeOfKey = exTypeOfKeyRecord(input.getDataType()); - BeamSQLRow keyOfRecord = new BeamSQLRow(typeOfKey); - keyOfRecord.updateWindowRange(input, null); - - for (int idx = 0; idx < groupByKeys.size(); ++idx) { - keyOfRecord.addField(idx, input.getFieldValue(groupByKeys.get(idx))); - } - return keyOfRecord; - } - - private BeamSQLRecordType exTypeOfKeyRecord(BeamSQLRecordType dataType) { - BeamSQLRecordType typeOfKey = new BeamSQLRecordType(); - for (int idx : groupByKeys) { - typeOfKey.addField(dataType.getFieldsName().get(idx), dataType.getFieldsType().get(idx)); - } - return typeOfKey; - } - - } - - /** - * Assign event timestamp. 
- */ - public static class WindowTimestampFn implements SerializableFunction<BeamSQLRow, Instant> { - private int windowFieldIdx = -1; - - public WindowTimestampFn(int windowFieldIdx) { - super(); - this.windowFieldIdx = windowFieldIdx; - } - - @Override - public Instant apply(BeamSQLRow input) { - return new Instant(input.getDate(windowFieldIdx).getTime()); - } - } - -} http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java ---------------------------------------------------------------------- diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java new file mode 100644 index 0000000..943c897 --- /dev/null +++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java @@ -0,0 +1,671 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.beam.dsls.sql.transform;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.List;
+import org.apache.beam.dsls.sql.exception.BeamSqlUnsupportedException;
+import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlExpression;
+import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlInputRefExpression;
+import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlPrimitive;
+import org.apache.beam.dsls.sql.schema.BeamSQLRecordType;
+import org.apache.beam.dsls.sql.schema.BeamSQLRow;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.values.KV;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.joda.time.Instant;
+
+/**
+ * Collections of {@code PTransform} and {@code DoFn} used to perform GROUP-BY operation.
+ */
+public class BeamAggregationTransforms implements Serializable{
+  /**
+   * Flattens a {@code KV} of (group-by key row, aggregation result row) into a single output row.
+   *
+   * <p>Key fields are copied by name; aggregation values are copied by position and stored
+   * under the names of the corresponding {@link AggregateCall}s.
+   */
+  public static class MergeAggregationRecord extends DoFn<KV<BeamSQLRow, BeamSQLRow>, BeamSQLRow> {
+    private BeamSQLRecordType outRecordType;
+    private List<String> aggFieldNames;
+
+    /**
+     * @param outRecordType record type of the emitted rows
+     * @param aggList aggregation calls, in the same order as the fields of the aggregated row
+     */
+    public MergeAggregationRecord(BeamSQLRecordType outRecordType, List<AggregateCall> aggList) {
+      this.outRecordType = outRecordType;
+      this.aggFieldNames = new ArrayList<>();
+      for (AggregateCall ac : aggList) {
+        aggFieldNames.add(ac.getName());
+      }
+    }
+
+    @ProcessElement
+    public void processElement(ProcessContext c, BoundedWindow window) {
+      KV<BeamSQLRow, BeamSQLRow> kvRecord = c.element();
+      BeamSQLRow outRecord = new BeamSQLRow(outRecordType);
+      outRecord.updateWindowRange(kvRecord.getKey(), window);
+
+      for (String f : kvRecord.getKey().getDataType().getFieldsName()) {
+        outRecord.addField(f, kvRecord.getKey().getFieldValue(f));
+      }
+      for (int idx = 0; idx < aggFieldNames.size(); ++idx) {
+        outRecord.addField(aggFieldNames.get(idx), kvRecord.getValue().getFieldValue(idx));
+      }
+
+      // The row is emitted unconditionally, i.e. for every pane and not only the last one.
+      c.output(outRecord);
+    }
+  }
+
+  /**
+   * Extracts the group-by fields of a row into a key row.
+   */
+  public static class AggregationGroupByKeyFn
+      implements SerializableFunction<BeamSQLRow, BeamSQLRow> {
+    private List<Integer> groupByKeys;
+
+    /**
+     * @param windowFieldIdx index of the window field; it is left out of the key
+     * @param groupSet indexes of all fields listed in GROUP BY
+     */
+    public AggregationGroupByKeyFn(int windowFieldIdx, ImmutableBitSet groupSet) {
+      this.groupByKeys = new ArrayList<>();
+      for (int i : groupSet.asList()) {
+        if (i != windowFieldIdx) {
+          groupByKeys.add(i);
+        }
+      }
+    }
+
+    @Override
+    public BeamSQLRow apply(BeamSQLRow input) {
+      BeamSQLRecordType typeOfKey = exTypeOfKeyRecord(input.getDataType());
+      BeamSQLRow keyOfRecord = new BeamSQLRow(typeOfKey);
+      keyOfRecord.updateWindowRange(input, null);
+
+      for (int idx = 0; idx < groupByKeys.size(); ++idx) {
+        keyOfRecord.addField(idx, input.getFieldValue(groupByKeys.get(idx)));
+      }
+      return keyOfRecord;
+    }
+
+    /** Builds the record type of the key row from the names/types of the group-by fields. */
+    private BeamSQLRecordType exTypeOfKeyRecord(BeamSQLRecordType dataType) {
+      BeamSQLRecordType typeOfKey = new BeamSQLRecordType();
+      for (int idx : groupByKeys) {
+        typeOfKey.addField(dataType.getFieldsName().get(idx), dataType.getFieldsType().get(idx));
+      }
+      return typeOfKey;
+    }
+
+  }
+
+  /**
+   * Assigns the event timestamp of a row from the date stored in its window field.
+   */
+  public static class WindowTimestampFn implements SerializableFunction<BeamSQLRow, Instant> {
+    private int windowFieldIdx = -1;
+
+    public WindowTimestampFn(int windowFieldIdx) {
+      super();
+      this.windowFieldIdx = windowFieldIdx;
+    }
+
+    @Override
+    public Instant apply(BeamSQLRow input) {
+      return new Instant(input.getDate(windowFieldIdx).getTime());
+    }
+  }
+
+  /**
+   * Aggregation function which supports COUNT, MAX, MIN, SUM, AVG.
+   *
+   * <p>Multiple aggregation functions are combined together, one accumulator field per function.
+   * Each aggregation function accepts only a subset of the data types:<br>
+   *   1). COUNT works for any data type;<br>
+   *   2). MAX/MIN work for INT, LONG, FLOAT, DOUBLE, SMALLINT, TINYINT, TIMESTAMP;<br>
+   *   3). SUM/AVG work for INT, LONG, FLOAT, DOUBLE, SMALLINT, TINYINT;<br>
+   *
+   * <p>For AVG the accumulator holds the running sum; the division by the row count is done in
+   * {@link #extractOutput(BeamSQLRow)}. When no COUNT is part of the query, a hidden
+   * {@code __COUNT} field is appended to the accumulator for that purpose.
+   */
+  public static class AggregationCombineFn extends CombineFn<BeamSQLRow, BeamSQLRow, BeamSQLRow> {
+    /** Result types accepted by SUM and AVG; MAX/MIN additionally accept TIMESTAMP. */
+    private static final List<SqlTypeName> NUMERIC_TYPES = Arrays.asList(
+        SqlTypeName.INTEGER, SqlTypeName.BIGINT, SqlTypeName.FLOAT, SqlTypeName.DOUBLE,
+        SqlTypeName.SMALLINT, SqlTypeName.TINYINT);
+
+    private BeamSQLRecordType aggDataType;
+
+    /** Index of the accumulator field holding the row count used by AVG, or -1 if none. */
+    private int countIndex = -1;
+
+    List<String> aggFunctions;
+    List<BeamSqlExpression> aggElementExpressions;
+
+    /**
+     * @param aggregationCalls aggregation calls of the query, in output order
+     * @param sourceRowRecordType record type of the rows being aggregated
+     * @throws BeamSqlUnsupportedException if a call is DISTINCT, uses an unknown function or
+     *     has an unsupported result type
+     */
+    public AggregationCombineFn(List<AggregateCall> aggregationCalls,
+        BeamSQLRecordType sourceRowRecordType) {
+      this.aggDataType = new BeamSQLRecordType();
+      this.aggFunctions = new ArrayList<>();
+      this.aggElementExpressions = new ArrayList<>();
+
+      boolean hasAvg = false;
+      for (int idx = 0; idx < aggregationCalls.size(); ++idx) {
+        AggregateCall ac = aggregationCalls.get(idx);
+        // rejects everything but COUNT/SUM/MAX/MIN/AVG, so only those reach the code below.
+        verifySupportedAggregation(ac);
+
+        aggDataType.addField(ac.name, ac.type.getSqlTypeName());
+
+        SqlAggFunction aggFn = ac.getAggregation();
+        if ("COUNT".equals(aggFn.getName())) {
+          aggElementExpressions.add(BeamSqlPrimitive.<Long>of(SqlTypeName.BIGINT, 1L));
+          countIndex = idx;
+        } else {
+          int refIndex = ac.getArgList().get(0);
+          aggElementExpressions.add(new BeamSqlInputRefExpression(
+              sourceRowRecordType.getFieldsType().get(refIndex), refIndex));
+          hasAvg |= "AVG".equals(aggFn.getName());
+        }
+        aggFunctions.add(aggFn.getName());
+      }
+      // add a COUNT holder if there is an AVG but no COUNT to divide by
+      if (hasAvg && countIndex == -1) {
+        aggDataType.addField("__COUNT", SqlTypeName.BIGINT);
+
+        aggFunctions.add("COUNT");
+        aggElementExpressions.add(BeamSqlPrimitive.<Long>of(SqlTypeName.BIGINT, 1L));
+
+        countIndex = aggDataType.size() - 1;
+      }
+    }
+
+    private void verifySupportedAggregation(AggregateCall ac) {
+      // DISTINCT is not supported
+      if (ac.isDistinct()) {
+        throw new BeamSqlUnsupportedException("DISTINCT is not supported yet.");
+      }
+      String aggFnName = ac.getAggregation().getName();
+      switch (aggFnName) {
+        case "COUNT":
+          // COUNT works for any data type
+          break;
+        case "SUM":
+        case "AVG":
+          if (!NUMERIC_TYPES.contains(ac.type.getSqlTypeName())) {
+            throw new BeamSqlUnsupportedException(
+                aggFnName + " only support for INT, LONG, FLOAT, DOUBLE, SMALLINT, TINYINT");
+          }
+          break;
+        case "MAX":
+        case "MIN":
+          if (!NUMERIC_TYPES.contains(ac.type.getSqlTypeName())
+              && ac.type.getSqlTypeName() != SqlTypeName.TIMESTAMP) {
+            throw new BeamSqlUnsupportedException("MAX/MIN only support for INT, LONG, FLOAT,"
+                + " DOUBLE, SMALLINT, TINYINT, TIMESTAMP");
+          }
+          break;
+        default:
+          throw new BeamSqlUnsupportedException(
+              String.format("[%s] is not supported.", aggFnName));
+      }
+    }
+
+    @Override
+    public BeamSQLRow createAccumulator() {
+      BeamSQLRow initialRecord = new BeamSQLRow(aggDataType);
+      for (int idx = 0; idx < aggElementExpressions.size(); ++idx) {
+        initialRecord.addField(idx,
+            identity(aggFunctions.get(idx), aggElementExpressions.get(idx).getOutputType()));
+      }
+      return initialRecord;
+    }
+
+    @Override
+    public BeamSQLRow addInput(BeamSQLRow accumulator, BeamSQLRow input) {
+      BeamSQLRow deltaRecord = new BeamSQLRow(aggDataType);
+      for (int idx = 0; idx < aggElementExpressions.size(); ++idx) {
+        BeamSqlExpression ex = aggElementExpressions.get(idx);
+        String aggFnName = aggFunctions.get(idx);
+        Object increment;
+        if ("COUNT".equals(aggFnName)) {
+          // every input row counts as one, whatever its content
+          increment = 1L;
+        } else {
+          increment = inputValue(aggFnName, ex.evaluate(input), ex.getOutputType());
+        }
+        deltaRecord.addField(idx,
+            combine(aggFnName, ex.getOutputType(), accumulator.getFieldValue(idx), increment));
+      }
+      return deltaRecord;
+    }
+
+    @Override
+    public BeamSQLRow mergeAccumulators(Iterable<BeamSQLRow> accumulators) {
+      // Start from the identity values so that every field can be read while folding, and walk
+      // the iterable once: calling accumulators.iterator() per step would restart the iteration.
+      BeamSQLRow merged = createAccumulator();
+      for (BeamSQLRow partial : accumulators) {
+        for (int idx = 0; idx < aggElementExpressions.size(); ++idx) {
+          merged.addField(idx, combine(aggFunctions.get(idx),
+              aggElementExpressions.get(idx).getOutputType(),
+              merged.getFieldValue(idx), partial.getFieldValue(idx)));
+        }
+      }
+      return merged;
+    }
+
+    @Override
+    public BeamSQLRow extractOutput(BeamSQLRow accumulator) {
+      BeamSQLRow finalRecord = new BeamSQLRow(aggDataType);
+      for (int idx = 0; idx < aggElementExpressions.size(); ++idx) {
+        if ("AVG".equals(aggFunctions.get(idx))) {
+          finalRecord.addField(idx, average(aggElementExpressions.get(idx).getOutputType(),
+              accumulator.getFieldValue(idx), accumulator.getLong(countIndex)));
+        } else {
+          // COUNT/SUM/MAX/MIN: the accumulated value already is the result.
+          finalRecord.addField(idx, accumulator.getFieldValue(idx));
+        }
+      }
+      return finalRecord;
+    }
+
+    /** Returns the value an accumulator field of {@code aggFnName} starts with. */
+    private static Object identity(String aggFnName, SqlTypeName type) {
+      switch (aggFnName) {
+        case "COUNT":
+          return 0L;
+        case "AVG":
+        case "SUM":
+          switch (type) {
+            case INTEGER:
+              return 0;
+            case BIGINT:
+              return 0L;
+            case SMALLINT:
+              return (short) 0;
+            case TINYINT:
+              return (byte) 0;
+            case FLOAT:
+              return 0.0f;
+            case DOUBLE:
+              return 0.0;
+            default:
+              throw unsupported(aggFnName, type);
+          }
+        case "MAX":
+          switch (type) {
+            case INTEGER:
+              return Integer.MIN_VALUE;
+            case BIGINT:
+              return Long.MIN_VALUE;
+            case SMALLINT:
+              return Short.MIN_VALUE;
+            case TINYINT:
+              return Byte.MIN_VALUE;
+            case FLOAT:
+              // Float.MIN_VALUE is the smallest POSITIVE float, so it cannot seed a maximum.
+              return Float.NEGATIVE_INFINITY;
+            case DOUBLE:
+              // same for Double.MIN_VALUE
+              return Double.NEGATIVE_INFINITY;
+            case TIMESTAMP:
+              // earliest representable date, so timestamps before the epoch are handled too
+              return new Date(Long.MIN_VALUE);
+            default:
+              throw unsupported(aggFnName, type);
+          }
+        case "MIN":
+          switch (type) {
+            case INTEGER:
+              return Integer.MAX_VALUE;
+            case BIGINT:
+              return Long.MAX_VALUE;
+            case SMALLINT:
+              return Short.MAX_VALUE;
+            case TINYINT:
+              return Byte.MAX_VALUE;
+            case FLOAT:
+              return Float.POSITIVE_INFINITY;
+            case DOUBLE:
+              return Double.POSITIVE_INFINITY;
+            case TIMESTAMP:
+              return new Date(Long.MAX_VALUE);
+            default:
+              throw unsupported(aggFnName, type);
+          }
+        default:
+          throw unsupported(aggFnName, type);
+      }
+    }
+
+    /** Reads the evaluated input expression as the boxed Java value matching {@code type}. */
+    private static Object inputValue(String aggFnName, BeamSqlPrimitive<?> value,
+        SqlTypeName type) {
+      switch (type) {
+        case INTEGER:
+          return value.getInteger();
+        case BIGINT:
+          return value.getLong();
+        case SMALLINT:
+          return value.getShort();
+        case TINYINT:
+          return value.getByte();
+        case FLOAT:
+          return value.getFloat();
+        case DOUBLE:
+          return value.getDouble();
+        case TIMESTAMP:
+          return value.getDate();
+        default:
+          throw unsupported(aggFnName, type);
+      }
+    }
+
+    /**
+     * Folds {@code right} into {@code left} for one accumulator field; used both for adding an
+     * input value and for merging two partial accumulators.
+     */
+    private static Object combine(String aggFnName, SqlTypeName type, Object left,
+        Object right) {
+      switch (aggFnName) {
+        case "COUNT":
+          return (Long) left + (Long) right;
+        case "AVG":
+        case "SUM":
+          // for both AVG/SUM the running sum is held in the accumulator.
+          return sum(aggFnName, type, left, right);
+        case "MAX":
+          return extreme(true, type, left, right);
+        case "MIN":
+          return extreme(false, type, left, right);
+        default:
+          throw unsupported(aggFnName, type);
+      }
+    }
+
+    /** Adds two values of {@code type}, keeping the result in the same type. */
+    private static Object sum(String aggFnName, SqlTypeName type, Object left, Object right) {
+      switch (type) {
+        case INTEGER:
+          return (Integer) left + (Integer) right;
+        case BIGINT:
+          return (Long) left + (Long) right;
+        case SMALLINT:
+          return (short) ((Short) left + (Short) right);
+        case TINYINT:
+          return (byte) ((Byte) left + (Byte) right);
+        case FLOAT:
+          return (Float) left + (Float) right;
+        case DOUBLE:
+          return (Double) left + (Double) right;
+        default:
+          throw unsupported(aggFnName, type);
+      }
+    }
+
+    /** Returns the larger ({@code max == true}) or smaller of two values of {@code type}. */
+    private static Object extreme(boolean max, SqlTypeName type, Object left, Object right) {
+      switch (type) {
+        case INTEGER:
+          return max ? Math.max((Integer) left, (Integer) right)
+              : Math.min((Integer) left, (Integer) right);
+        case BIGINT:
+          return max ? Math.max((Long) left, (Long) right)
+              : Math.min((Long) left, (Long) right);
+        case SMALLINT:
+          return (short) (max ? Math.max((Short) left, (Short) right)
+              : Math.min((Short) left, (Short) right));
+        case TINYINT:
+          return (byte) (max ? Math.max((Byte) left, (Byte) right)
+              : Math.min((Byte) left, (Byte) right));
+        case FLOAT:
+          return max ? Math.max((Float) left, (Float) right)
+              : Math.min((Float) left, (Float) right);
+        case DOUBLE:
+          return max ? Math.max((Double) left, (Double) right)
+              : Math.min((Double) left, (Double) right);
+        case TIMESTAMP:
+          long leftMillis = ((Date) left).getTime();
+          long rightMillis = ((Date) right).getTime();
+          return (max ? leftMillis > rightMillis : leftMillis < rightMillis) ? left : right;
+        default:
+          throw unsupported(max ? "MAX" : "MIN", type);
+      }
+    }
+
+    /** Divides an accumulated sum by the row count, keeping the result in the type of the sum. */
+    private static Object average(SqlTypeName type, Object sum, long count) {
+      switch (type) {
+        case INTEGER:
+          return (int) ((Integer) sum / count);
+        case BIGINT:
+          return (Long) sum / count;
+        case SMALLINT:
+          return (short) ((Short) sum / count);
+        case TINYINT:
+          return (byte) ((Byte) sum / count);
+        case FLOAT:
+          return (float) ((Float) sum / count);
+        case DOUBLE:
+          return (Double) sum / count;
+        default:
+          throw unsupported("AVG", type);
+      }
+    }
+
+    private static BeamSqlUnsupportedException unsupported(String aggFnName, SqlTypeName type) {
+      return new BeamSqlUnsupportedException(
+          String.format("[%s] is not supported for type [%s].", aggFnName, type));
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java ---------------------------------------------------------------------- diff --git a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java new file mode 100644 index 0000000..f174b9c --- /dev/null +++ 
b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java @@ -0,0 +1,436 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * <p> + * http://www.apache.org/licenses/LICENSE-2.0 + * <p> + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.dsls.sql.schema.transform; + +import java.text.ParseException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.beam.dsls.sql.planner.BeamQueryPlanner; +import org.apache.beam.dsls.sql.schema.BeamSQLRecordType; +import org.apache.beam.dsls.sql.schema.BeamSQLRecordTypeCoder; +import org.apache.beam.dsls.sql.schema.BeamSQLRow; +import org.apache.beam.dsls.sql.schema.BeamSqlRowCoder; +import org.apache.beam.dsls.sql.transform.BeamAggregationTransforms; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.WithKeys; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.calcite.rel.core.AggregateCall; 
+import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory.FieldInfoBuilder; +import org.apache.calcite.rel.type.RelDataTypeSystem; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.fun.SqlAvgAggFunction; +import org.apache.calcite.sql.fun.SqlCountAggFunction; +import org.apache.calcite.sql.fun.SqlMinMaxAggFunction; +import org.apache.calcite.sql.fun.SqlSumAggFunction; +import org.apache.calcite.sql.type.BasicSqlType; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.ImmutableBitSet; +import org.junit.Rule; +import org.junit.Test; + +/** + * Unit tests for {@link BeamAggregationTransforms}. + * + */ +public class BeamAggregationTransformTest extends BeamTransformBaseTest{ + + @Rule + public TestPipeline p = TestPipeline.create(); + + private List<AggregateCall> aggCalls; + private BeamSQLRecordType keyType = initTypeOfSqlRow( + Arrays.asList(KV.of("f_int", SqlTypeName.INTEGER))); + + /** + * This step equals to below query. 
+ * <pre> + * SELECT `f_int` + * , COUNT(*) AS `size` + * , SUM(`f_long`) AS `sum1`, AVG(`f_long`) AS `avg1` + * , MAX(`f_long`) AS `max1`, MIN(`f_long`) AS `min1` + * , SUM(`f_short`) AS `sum2`, AVG(`f_short`) AS `avg2` + * , MAX(`f_short`) AS `max2`, MIN(`f_short`) AS `min2` + * , SUM(`f_byte`) AS `sum3`, AVG(`f_byte`) AS `avg3` + * , MAX(`f_byte`) AS `max3`, MIN(`f_byte`) AS `min3` + * , SUM(`f_float`) AS `sum4`, AVG(`f_float`) AS `avg4` + * , MAX(`f_float`) AS `max4`, MIN(`f_float`) AS `min4` + * , SUM(`f_double`) AS `sum5`, AVG(`f_double`) AS `avg5` + * , MAX(`f_double`) AS `max5`, MIN(`f_double`) AS `min5` + * , MAX(`f_timestamp`) AS `max7`, MIN(`f_timestamp`) AS `min7` + * ,SUM(`f_int2`) AS `sum8`, AVG(`f_int2`) AS `avg8` + * , MAX(`f_int2`) AS `max8`, MIN(`f_int2`) AS `min8` + * FROM TABLE_NAME + * GROUP BY `f_int` + * </pre> + * @throws ParseException + */ + @Test + public void testCountPerElementBasic() throws ParseException { + setupEnvironment(); + + PCollection<BeamSQLRow> input = p.apply(Create.of(inputRows)); + + //1. extract fields in group-by key part + PCollection<KV<BeamSQLRow, BeamSQLRow>> exGroupByStream = input.apply("exGroupBy", + WithKeys + .of(new BeamAggregationTransforms.AggregationGroupByKeyFn(-1, ImmutableBitSet.of(0)))); + + //2. apply a GroupByKey. + PCollection<KV<BeamSQLRow, Iterable<BeamSQLRow>>> groupedStream = exGroupByStream + .apply("groupBy", GroupByKey.<BeamSQLRow, BeamSQLRow>create()); + + //3. run aggregation functions + PCollection<KV<BeamSQLRow, BeamSQLRow>> aggregatedStream = groupedStream.apply("aggregation", + Combine.<BeamSQLRow, BeamSQLRow, BeamSQLRow>groupedValues( + new BeamAggregationTransforms.AggregationCombineFn(aggCalls, inputRowType))); + + //4. 
flat KV to a single record + PCollection<BeamSQLRow> mergedStream = aggregatedStream.apply("mergeRecord", + ParDo.of(new BeamAggregationTransforms.MergeAggregationRecord( + BeamSQLRecordType.from(prepareFinalRowType()), aggCalls))); + + //assert function BeamAggregationTransform.AggregationGroupByKeyFn + PAssert.that(exGroupByStream).containsInAnyOrder(prepareResultOfAggregationGroupByKeyFn()); + + //assert BeamAggregationTransform.AggregationCombineFn + PAssert.that(aggregatedStream).containsInAnyOrder(prepareResultOfAggregationCombineFn()); + + //assert BeamAggregationTransform.MergeAggregationRecord + PAssert.that(mergedStream).containsInAnyOrder(prepareResultOfMergeAggregationRecord()); + + p.run(); +} + + private void setupEnvironment() { + regiesterCoder(); + prepareAggregationCalls(); + } + + /** + * Add Coders in BeamSQL. + */ + private void regiesterCoder() { + CoderRegistry cr = p.getCoderRegistry(); + cr.registerCoder(BeamSQLRow.class, BeamSqlRowCoder.of()); + cr.registerCoder(BeamSQLRecordType.class, BeamSQLRecordTypeCoder.of()); + } + + /** + * create list of all {@link AggregateCall}. 
+ */ + @SuppressWarnings("deprecation") + private void prepareAggregationCalls() { + //aggregations for all data type + aggCalls = new ArrayList<>(); + aggCalls.add( + new AggregateCall(new SqlCountAggFunction(), false, + Arrays.<Integer>asList(), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT), + "count") + ); + aggCalls.add( + new AggregateCall(new SqlSumAggFunction( + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT)), false, + Arrays.<Integer>asList(1), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT), + "sum1") + ); + aggCalls.add( + new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false, + Arrays.<Integer>asList(1), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT), + "avg1") + ); + aggCalls.add( + new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false, + Arrays.<Integer>asList(1), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT), + "max1") + ); + aggCalls.add( + new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false, + Arrays.<Integer>asList(1), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT), + "min1") + ); + + aggCalls.add( + new AggregateCall(new SqlSumAggFunction( + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.SMALLINT)), false, + Arrays.<Integer>asList(2), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.SMALLINT), + "sum2") + ); + aggCalls.add( + new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false, + Arrays.<Integer>asList(2), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.SMALLINT), + "avg2") + ); + aggCalls.add( + new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false, + Arrays.<Integer>asList(2), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.SMALLINT), + "max2") + ); + aggCalls.add( + new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false, + Arrays.<Integer>asList(2), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.SMALLINT), + "min2") + ); + + 
aggCalls.add( + new AggregateCall( + new SqlSumAggFunction(new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TINYINT)), + false, + Arrays.<Integer>asList(3), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TINYINT), + "sum3") + ); + aggCalls.add( + new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false, + Arrays.<Integer>asList(3), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TINYINT), + "avg3") + ); + aggCalls.add( + new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false, + Arrays.<Integer>asList(3), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TINYINT), + "max3") + ); + aggCalls.add( + new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false, + Arrays.<Integer>asList(3), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TINYINT), + "min3") + ); + + aggCalls.add( + new AggregateCall( + new SqlSumAggFunction(new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.FLOAT)), + false, + Arrays.<Integer>asList(4), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.FLOAT), + "sum4") + ); + aggCalls.add( + new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false, + Arrays.<Integer>asList(4), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.FLOAT), + "avg4") + ); + aggCalls.add( + new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false, + Arrays.<Integer>asList(4), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.FLOAT), + "max4") + ); + aggCalls.add( + new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false, + Arrays.<Integer>asList(4), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.FLOAT), + "min4") + ); + + aggCalls.add( + new AggregateCall( + new SqlSumAggFunction(new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DOUBLE)), + false, + Arrays.<Integer>asList(5), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DOUBLE), + "sum5") + ); + aggCalls.add( + new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false, + 
Arrays.<Integer>asList(5), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DOUBLE), + "avg5") + ); + aggCalls.add( + new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false, + Arrays.<Integer>asList(5), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DOUBLE), + "max5") + ); + aggCalls.add( + new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false, + Arrays.<Integer>asList(5), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DOUBLE), + "min5") + ); + + aggCalls.add( + new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false, + Arrays.<Integer>asList(7), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TIMESTAMP), + "max7") + ); + aggCalls.add( + new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false, + Arrays.<Integer>asList(7), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TIMESTAMP), + "min7") + ); + + aggCalls.add( + new AggregateCall( + new SqlSumAggFunction(new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER)), + false, + Arrays.<Integer>asList(8), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER), + "sum8") + ); + aggCalls.add( + new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false, + Arrays.<Integer>asList(8), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER), + "avg8") + ); + aggCalls.add( + new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false, + Arrays.<Integer>asList(8), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER), + "max8") + ); + aggCalls.add( + new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false, + Arrays.<Integer>asList(8), + new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER), + "min8") + ); + } + + /** + * expected results after {@link BeamAggregationTransforms.AggregationGroupByKeyFn}. 
+ */ + private List<KV<BeamSQLRow, BeamSQLRow>> prepareResultOfAggregationGroupByKeyFn() { + return Arrays.asList( + KV.of(new BeamSQLRow(keyType, Arrays.<Object>asList(inputRows.get(0).getInteger(0))), + inputRows.get(0)), + KV.of(new BeamSQLRow(keyType, Arrays.<Object>asList(inputRows.get(1).getInteger(0))), + inputRows.get(1)), + KV.of(new BeamSQLRow(keyType, Arrays.<Object>asList(inputRows.get(2).getInteger(0))), + inputRows.get(2)), + KV.of(new BeamSQLRow(keyType, Arrays.<Object>asList(inputRows.get(3).getInteger(0))), + inputRows.get(3))); + } + + /** + * expected results after {@link BeamAggregationTransforms.AggregationCombineFn}. + */ + private List<KV<BeamSQLRow, BeamSQLRow>> prepareResultOfAggregationCombineFn() + throws ParseException { + BeamSQLRecordType aggPartType = initTypeOfSqlRow( + Arrays.asList(KV.of("count", SqlTypeName.BIGINT), + + KV.of("sum1", SqlTypeName.BIGINT), KV.of("avg1", SqlTypeName.BIGINT), + KV.of("max1", SqlTypeName.BIGINT), KV.of("min1", SqlTypeName.BIGINT), + + KV.of("sum2", SqlTypeName.SMALLINT), KV.of("avg2", SqlTypeName.SMALLINT), + KV.of("max2", SqlTypeName.SMALLINT), KV.of("min2", SqlTypeName.SMALLINT), + + KV.of("sum3", SqlTypeName.TINYINT), KV.of("avg3", SqlTypeName.TINYINT), + KV.of("max3", SqlTypeName.TINYINT), KV.of("min3", SqlTypeName.TINYINT), + + KV.of("sum4", SqlTypeName.FLOAT), KV.of("avg4", SqlTypeName.FLOAT), + KV.of("max4", SqlTypeName.FLOAT), KV.of("min4", SqlTypeName.FLOAT), + + KV.of("sum5", SqlTypeName.DOUBLE), KV.of("avg5", SqlTypeName.DOUBLE), + KV.of("max5", SqlTypeName.DOUBLE), KV.of("min5", SqlTypeName.DOUBLE), + + KV.of("max7", SqlTypeName.TIMESTAMP), KV.of("min7", SqlTypeName.TIMESTAMP), + + KV.of("sum8", SqlTypeName.INTEGER), KV.of("avg8", SqlTypeName.INTEGER), + KV.of("max8", SqlTypeName.INTEGER), KV.of("min8", SqlTypeName.INTEGER) + )); + return Arrays.asList( + KV.of(new BeamSQLRow(keyType, Arrays.<Object>asList(inputRows.get(0).getInteger(0))), + new BeamSQLRow(aggPartType, 
Arrays.<Object>asList( + 4L, + 10000L, 2500L, 4000L, 1000L, + (short) 10, (short) 2, (short) 4, (short) 1, + (byte) 10, (byte) 2, (byte) 4, (byte) 1, + 10.0F, 2.5F, 4.0F, 1.0F, + 10.0, 2.5, 4.0, 1.0, + format.parse("2017-01-01 02:04:03"), format.parse("2017-01-01 01:01:03"), + 10, 2, 4, 1 + ))) + ); + } + + /** + * Row type of final output row. + */ + private RelDataType prepareFinalRowType() { + FieldInfoBuilder builder = BeamQueryPlanner.TYPE_FACTORY.builder(); + List<KV<String, SqlTypeName>> columnMetadata = + Arrays.asList(KV.of("f_int", SqlTypeName.INTEGER), KV.of("count", SqlTypeName.BIGINT), + + KV.of("sum1", SqlTypeName.BIGINT), KV.of("avg1", SqlTypeName.BIGINT), + KV.of("max1", SqlTypeName.BIGINT), KV.of("min1", SqlTypeName.BIGINT), + + KV.of("sum2", SqlTypeName.SMALLINT), KV.of("avg2", SqlTypeName.SMALLINT), + KV.of("max2", SqlTypeName.SMALLINT), KV.of("min2", SqlTypeName.SMALLINT), + + KV.of("sum3", SqlTypeName.TINYINT), KV.of("avg3", SqlTypeName.TINYINT), + KV.of("max3", SqlTypeName.TINYINT), KV.of("min3", SqlTypeName.TINYINT), + + KV.of("sum4", SqlTypeName.FLOAT), KV.of("avg4", SqlTypeName.FLOAT), + KV.of("max4", SqlTypeName.FLOAT), KV.of("min4", SqlTypeName.FLOAT), + + KV.of("sum5", SqlTypeName.DOUBLE), KV.of("avg5", SqlTypeName.DOUBLE), + KV.of("max5", SqlTypeName.DOUBLE), KV.of("min5", SqlTypeName.DOUBLE), + + KV.of("max7", SqlTypeName.TIMESTAMP), KV.of("min7", SqlTypeName.TIMESTAMP), + + KV.of("sum8", SqlTypeName.INTEGER), KV.of("avg8", SqlTypeName.INTEGER), + KV.of("max8", SqlTypeName.INTEGER), KV.of("min8", SqlTypeName.INTEGER) + ); + for (KV<String, SqlTypeName> cm : columnMetadata) { + builder.add(cm.getKey(), cm.getValue()); + } + return builder.build(); + } + + /** + * expected results after {@link BeamAggregationTransforms.MergeAggregationRecord}. 
+ */ + private BeamSQLRow prepareResultOfMergeAggregationRecord() throws ParseException { + return new BeamSQLRow(BeamSQLRecordType.from(prepareFinalRowType()), Arrays.<Object>asList( + 1, 4L, + 10000L, 2500L, 4000L, 1000L, + (short) 10, (short) 2, (short) 4, (short) 1, + (byte) 10, (byte) 2, (byte) 4, (byte) 1, + 10.0F, 2.5F, 4.0F, 1.0F, + 10.0, 2.5, 4.0, 1.0, + format.parse("2017-01-01 02:04:03"), format.parse("2017-01-01 01:01:03"), + 10, 2, 4, 1 + )); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/f728fbe5/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamTransformBaseTest.java ---------------------------------------------------------------------- diff --git a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamTransformBaseTest.java b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamTransformBaseTest.java new file mode 100644 index 0000000..820d7f5 --- /dev/null +++ b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamTransformBaseTest.java @@ -0,0 +1,96 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * <p> + * http://www.apache.org/licenses/LICENSE-2.0 + * <p> + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.dsls.sql.schema.transform; + +import java.text.DateFormat; +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.util.Arrays; +import java.util.List; +import org.apache.beam.dsls.sql.planner.BeamQueryPlanner; +import org.apache.beam.dsls.sql.schema.BeamSQLRecordType; +import org.apache.beam.dsls.sql.schema.BeamSQLRow; +import org.apache.beam.sdk.values.KV; +import org.apache.calcite.rel.type.RelDataTypeFactory.FieldInfoBuilder; +import org.apache.calcite.sql.type.SqlTypeName; +import org.junit.BeforeClass; + +/** + * shared methods to test PTransforms which execute Beam SQL steps. + * + */ +public class BeamTransformBaseTest { + public static DateFormat format = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); + + public static BeamSQLRecordType inputRowType; + public static List<BeamSQLRow> inputRows; + + @BeforeClass + public static void prepareInput() throws NumberFormatException, ParseException{ + List<KV<String, SqlTypeName>> columnMetadata = Arrays.asList( + KV.of("f_int", SqlTypeName.INTEGER), KV.of("f_long", SqlTypeName.BIGINT), + KV.of("f_short", SqlTypeName.SMALLINT), KV.of("f_byte", SqlTypeName.TINYINT), + KV.of("f_float", SqlTypeName.FLOAT), KV.of("f_double", SqlTypeName.DOUBLE), + KV.of("f_string", SqlTypeName.VARCHAR), KV.of("f_timestamp", SqlTypeName.TIMESTAMP), + KV.of("f_int2", SqlTypeName.INTEGER) + ); + inputRowType = initTypeOfSqlRow(columnMetadata); + inputRows = Arrays.asList( + initBeamSqlRow(columnMetadata, + Arrays.<Object>asList(1, 1000L, Short.valueOf("1"), Byte.valueOf("1"), 1.0F, 1.0, + "string_row1", format.parse("2017-01-01 01:01:03"), 1)), + initBeamSqlRow(columnMetadata, + Arrays.<Object>asList(1, 2000L, Short.valueOf("2"), Byte.valueOf("2"), 2.0F, 2.0, + "string_row2", format.parse("2017-01-01 01:02:03"), 2)), + initBeamSqlRow(columnMetadata, + Arrays.<Object>asList(1, 3000L, Short.valueOf("3"), Byte.valueOf("3"), 3.0F, 3.0, + "string_row3", format.parse("2017-01-01 01:03:03"), 
3)), + initBeamSqlRow(columnMetadata, Arrays.<Object>asList(1, 4000L, Short.valueOf("4"), + Byte.valueOf("4"), 4.0F, 4.0, "string_row4", format.parse("2017-01-01 02:04:03"), 4))); + } + + /** + * create a {@code BeamSQLRecordType} for given column metadata. + */ + public static BeamSQLRecordType initTypeOfSqlRow(List<KV<String, SqlTypeName>> columnMetadata){ + FieldInfoBuilder builder = BeamQueryPlanner.TYPE_FACTORY.builder(); + for (KV<String, SqlTypeName> cm : columnMetadata) { + builder.add(cm.getKey(), cm.getValue()); + } + return BeamSQLRecordType.from(builder.build()); + } + + /** + * Create an empty row with given column metadata. + */ + public static BeamSQLRow initBeamSqlRow(List<KV<String, SqlTypeName>> columnMetadata) { + return initBeamSqlRow(columnMetadata, Arrays.asList()); + } + + /** + * Create a row with given column metadata, and values for each column. + * + */ + public static BeamSQLRow initBeamSqlRow(List<KV<String, SqlTypeName>> columnMetadata, + List<Object> rowValues){ + BeamSQLRecordType rowType = initTypeOfSqlRow(columnMetadata); + + return new BeamSQLRow(rowType, rowValues); + } + +}
