This is an automated email from the ASF dual-hosted git repository.
ibzib pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new f4fcccc BEAM-12166:Beam Sql - Combine Accumulator return Map fails
with class cast exception
new fc873f0 Merge pull request #14534 from anupd22/BEAM-12166
f4fcccc is described below
commit f4fccccd726481cacc47182ccf3fc12b7c93012b
Author: Anup D <[email protected]>
AuthorDate: Wed Apr 14 20:40:43 2021 +0530
BEAM-12166:Beam Sql - Combine Accumulator return Map fails with class cast
exception
---
.../extensions/sql/impl/utils/CalciteUtils.java | 21 ++--
.../sdk/extensions/sql/BeamSqlDslUdfUdafTest.java | 114 +++++++++++++++++++++
2 files changed, 128 insertions(+), 7 deletions(-)
diff --git
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
index 10ad199..34664ac 100644
---
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
+++
b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/utils/CalciteUtils.java
@@ -280,7 +280,8 @@ public class CalciteUtils {
/**
* SQL-Java type mapping, with specified Beam rules: <br>
* 1. redirect {@link AbstractInstant} to {@link Date} so Calcite can
recognize it. <br>
- * 2. For a list, the component type is needed to create a Sql array type.
+ * 2. For a list, the component type is needed to create a Sql array type.
<br>
+ * 3. For a Map, the key and value types are needed to create a Sql map type.
*
* @param type
* @return Calcite RelDataType
@@ -291,13 +292,19 @@ public class CalciteUtils {
return typeFactory.createJavaType(Date.class);
} else if (type instanceof Class &&
ByteString.class.isAssignableFrom((Class<?>) type)) {
return typeFactory.createJavaType(byte[].class);
- } else if (type instanceof ParameterizedType
- && java.util.List.class.isAssignableFrom(
- (Class<?>) ((ParameterizedType) type).getRawType())) {
+ } else if (type instanceof ParameterizedType) {
ParameterizedType parameterizedType = (ParameterizedType) type;
- Class<?> genericType = (Class<?>)
parameterizedType.getActualTypeArguments()[0];
- RelDataType collectionElementType =
typeFactory.createJavaType(genericType);
- return typeFactory.createArrayType(collectionElementType,
UNLIMITED_ARRAY_SIZE);
+ if (java.util.List.class.isAssignableFrom((Class<?>)
parameterizedType.getRawType())) {
+ Class<?> genericType = (Class<?>)
parameterizedType.getActualTypeArguments()[0];
+ RelDataType collectionElementType =
typeFactory.createJavaType(genericType);
+ return typeFactory.createArrayType(collectionElementType,
UNLIMITED_ARRAY_SIZE);
+ } else if (java.util.Map.class.isAssignableFrom((Class<?>)
parameterizedType.getRawType())) {
+ Class<?> genericKeyType = (Class<?>)
parameterizedType.getActualTypeArguments()[0];
+ Class<?> genericValueType = (Class<?>)
parameterizedType.getActualTypeArguments()[1];
+ RelDataType mapElementKeyType =
typeFactory.createJavaType(genericKeyType);
+ RelDataType mapElementValueType =
typeFactory.createJavaType(genericValueType);
+ return typeFactory.createMapType(mapElementKeyType,
mapElementValueType);
+ }
}
return typeFactory.createJavaType((Class) type);
}
diff --git
a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
index b563ea1..9f4c9a2 100644
---
a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
+++
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
@@ -27,7 +27,11 @@ import java.sql.Time;
import java.sql.Timestamp;
import java.time.LocalDate;
import java.time.LocalTime;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import org.apache.beam.sdk.extensions.sql.impl.BeamCalciteTable;
@@ -137,6 +141,56 @@ public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase {
pipeline.run().waitUntilFinish();
}
+ /** GROUP-BY with UDAF that returns Map. */
+ @Test
+ public void testUdafWithMapOutput() throws Exception {
+ Schema resultType =
+ Schema.builder()
+ .addInt32Field("f_int2")
+ .addMapField("squareAndAccumulateInMap", FieldType.STRING,
FieldType.INT32)
+ .build();
+
+ Map<String, Integer> resultMap = new HashMap<String, Integer>();
+ resultMap.put("squareOf-1", 1);
+ resultMap.put("squareOf-2", 4);
+ resultMap.put("squareOf-3", 9);
+ resultMap.put("squareOf-4", 16);
+ Row row = Row.withSchema(resultType).addValues(0, resultMap).build();
+
+ String sql =
+ "SELECT f_int2,squareAndAccumulateInMap(f_int) AS
`squareAndAccumulateInMap` FROM PCOLLECTION GROUP BY f_int2";
+ PCollection<Row> result =
+ boundedInput1.apply(
+ "testUdafWithMapOutput",
+ SqlTransform.query(sql)
+ .registerUdaf("squareAndAccumulateInMap", new
SquareAndAccumulateInMap()));
+ PAssert.that(result).containsInAnyOrder(row);
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ /** GROUP-BY with UDAF that returns List. */
+ @Test
+ public void testUdafWithListOutput() throws Exception {
+ Schema resultType =
+ Schema.builder()
+ .addInt32Field("f_int2")
+ .addArrayField("squareAndAccumulateInList", FieldType.INT32)
+ .build();
+ Row row = Row.withSchema(resultType).addValue(0).addArray(Arrays.asList(1,
4, 9, 16)).build();
+
+ String sql =
+ "SELECT f_int2,squareAndAccumulateInList(f_int) AS
`squareAndAccumulateInList` FROM PCOLLECTION GROUP BY f_int2";
+ PCollection<Row> result =
+ boundedInput1.apply(
+ "testUdafWithListOutput",
+ SqlTransform.query(sql)
+ .registerUdaf("squareAndAccumulateInList", new
SquareAndAccumulateInList()));
+ PAssert.that(result).containsInAnyOrder(row);
+
+ pipeline.run().waitUntilFinish();
+ }
+
@Test
public void testUdfWithListOutput() throws Exception {
Schema resultType = Schema.builder().addArrayField("array_field",
FieldType.INT64).build();
@@ -458,4 +512,64 @@ public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase {
return BeamCalciteTable.of(new TestBoundedTable(schema).addRows(values));
}
}
+
+ /** UDAF(CombineFn) for test, which squares each input, tags it and returns
them all in a Map. */
+ public static class SquareAndAccumulateInMap
+ extends CombineFn<Integer, Map<String, Integer>, Map<String, Integer>> {
+ @Override
+ public Map<String, Integer> createAccumulator() {
+ return new HashMap<String, Integer>();
+ }
+
+ @Override
+ public Map<String, Integer> addInput(Map<String, Integer> accumulator,
Integer input) {
+ accumulator.put("squareOf-" + input, input * input);
+ return accumulator;
+ }
+
+ @Override
+ public Map<String, Integer> mergeAccumulators(Iterable<Map<String,
Integer>> accumulators) {
+ Map<String, Integer> merged = createAccumulator();
+ for (Map<String, Integer> accumulator : accumulators) {
+ merged.putAll(accumulator);
+ }
+ return merged;
+ }
+
+ @Override
+ public Map<String, Integer> extractOutput(Map<String, Integer>
accumulator) {
+ return accumulator;
+ }
+ }
+
+ /** UDAF(CombineFn) for test, which squares each input and returns them all
in a List. */
+ public static class SquareAndAccumulateInList
+ extends CombineFn<Integer, List<Integer>, List<Integer>> {
+
+ @Override
+ public List<Integer> createAccumulator() {
+ return new ArrayList<Integer>();
+ }
+
+ @Override
+ public List<Integer> addInput(List<Integer> accumulator, Integer input) {
+ accumulator.add(input * input);
+ return accumulator;
+ }
+
+ @Override
+ public List<Integer> mergeAccumulators(Iterable<List<Integer>>
accumulators) {
+ List<Integer> merged = createAccumulator();
+ for (List<Integer> accumulator : accumulators) {
+ merged.addAll(accumulator);
+ }
+ return merged;
+ }
+
+ @Override
+ public List<Integer> extractOutput(List<Integer> accumulator) {
+ Collections.sort(accumulator);
+ return accumulator;
+ }
+ }
}