This is an automated email from the ASF dual-hosted git repository. twalthr pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit f5c99c6f2612bc2ae437e85f5c44cae50f631e4e Author: Sergey Nuyanzin <[email protected]> AuthorDate: Wed Dec 15 18:21:28 2021 +0100 [FLINK-17321][table] Add support casting of map to map and multiset to multiset This closes #18287. --- .../functions/casting/CastRuleProvider.java | 1 + .../MapToMapAndMultisetToMultisetCastRule.java | 198 +++++++++++++++++++++ .../planner/functions/CastFunctionITCase.java | 45 ++++- .../planner/functions/casting/CastRulesTest.java | 59 ++++++ 4 files changed, 297 insertions(+), 6 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CastRuleProvider.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CastRuleProvider.java index 5083519..961e81f 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CastRuleProvider.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/CastRuleProvider.java @@ -81,6 +81,7 @@ public class CastRuleProvider { .addRule(RawToBinaryCastRule.INSTANCE) // Collection rules .addRule(ArrayToArrayCastRule.INSTANCE) + .addRule(MapToMapAndMultisetToMultisetCastRule.INSTANCE) .addRule(RowToRowCastRule.INSTANCE) // Special rules .addRule(CharVarCharTrimPadCastRule.INSTANCE) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/MapToMapAndMultisetToMultisetCastRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/MapToMapAndMultisetToMultisetCastRule.java new file mode 100644 index 0000000..89e0351 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/casting/MapToMapAndMultisetToMultisetCastRule.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.functions.casting; + +import org.apache.flink.table.data.GenericMapData; +import org.apache.flink.table.data.MapData; +import org.apache.flink.table.types.logical.IntType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.logical.MapType; +import org.apache.flink.table.types.logical.MultisetType; + +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.table.planner.codegen.CodeGenUtils.boxedTypeTermForType; +import static org.apache.flink.table.planner.codegen.CodeGenUtils.className; +import static org.apache.flink.table.planner.codegen.CodeGenUtils.newName; +import static org.apache.flink.table.planner.codegen.CodeGenUtils.rowFieldReadAccess; +import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.constructorCall; +import static org.apache.flink.table.planner.functions.casting.CastRuleUtils.methodCall; + +/** + * {@link LogicalTypeRoot#MAP} to {@link LogicalTypeRoot#MAP} and {@link LogicalTypeRoot#MULTISET} + * to {@link LogicalTypeRoot#MULTISET} cast rule. + */ +class MapToMapAndMultisetToMultisetCastRule + extends AbstractNullAwareCodeGeneratorCastRule<MapData, MapData> { + + static final MapToMapAndMultisetToMultisetCastRule INSTANCE = + new MapToMapAndMultisetToMultisetCastRule(); + + private MapToMapAndMultisetToMultisetCastRule() { + super( + CastRulePredicate.builder() + .predicate( + MapToMapAndMultisetToMultisetCastRule + ::isValidMapToMapOrMultisetToMultisetCasting) + .build()); + } + + private static boolean isValidMapToMapOrMultisetToMultisetCasting( + LogicalType input, LogicalType target) { + return input.is(LogicalTypeRoot.MAP) + && target.is(LogicalTypeRoot.MAP) + && CastRuleProvider.resolve( + ((MapType) input).getKeyType(), + ((MapType) target).getKeyType()) + != null + && CastRuleProvider.resolve( + ((MapType) input).getValueType(), + ((MapType) target).getValueType()) + != null + || input.is(LogicalTypeRoot.MULTISET) + && target.is(LogicalTypeRoot.MULTISET) + && CastRuleProvider.resolve( + ((MultisetType) input).getElementType(), + ((MultisetType) target).getElementType()) + != null; + } + + /* Example generated code for MULTISET<INT> -> MULTISET<FLOAT>: + org.apache.flink.table.data.MapData _myInput = ((org.apache.flink.table.data.MapData)(_myInputObj)); + boolean _myInputIsNull = _myInputObj == null; + boolean isNull$0; + org.apache.flink.table.data.MapData result$1; + float result$2; + isNull$0 = _myInputIsNull; + if (!isNull$0) { + java.util.Map map$838 = new java.util.HashMap(); + for (int i$841 = 0; i$841 < _myInput.size(); i$841++) { + java.lang.Float key$839 = null; + java.lang.Integer value$840 = null; + if (!_myInput.keyArray().isNullAt(i$841)) { + result$2 = ((float)(_myInput.keyArray().getInt(i$841))); + key$839 = result$2; + } + value$840 = _myInput.valueArray().getInt(i$841); + map$838.put(key$839, value$840); + } + result$1 = new org.apache.flink.table.data.GenericMapData(map$838); + isNull$0 = result$1 == null; + } else { + result$1 = null; + } + return result$1; + + */ + @Override + protected String generateCodeBlockInternal( + CodeGeneratorCastRule.Context context, + String inputTerm, + String returnVariable, + LogicalType inputLogicalType, + LogicalType targetLogicalType) { + final LogicalType innerInputKeyType; + final LogicalType innerInputValueType; + + final LogicalType innerTargetKeyType; + final LogicalType innerTargetValueType; + if (inputLogicalType.is(LogicalTypeRoot.MULTISET)) { + innerInputKeyType = ((MultisetType) inputLogicalType).getElementType(); + innerInputValueType = new IntType(false); + innerTargetKeyType = ((MultisetType) targetLogicalType).getElementType(); + innerTargetValueType = new IntType(false); + } else { + innerInputKeyType = ((MapType) inputLogicalType).getKeyType(); + innerInputValueType = ((MapType) inputLogicalType).getValueType(); + innerTargetKeyType = ((MapType) targetLogicalType).getKeyType(); + innerTargetValueType = ((MapType) targetLogicalType).getValueType(); + } + + final String innerTargetKeyTypeTerm = boxedTypeTermForType(innerTargetKeyType); + final String innerTargetValueTypeTerm = boxedTypeTermForType(innerTargetValueType); + final String keyArrayTerm = methodCall(inputTerm, "keyArray"); + final String valueArrayTerm = methodCall(inputTerm, "valueArray"); + final String size = methodCall(inputTerm, "size"); + final String map = newName("map"); + final String key = newName("key"); + final String value = newName("value"); + + return new CastRuleUtils.CodeWriter() + .declStmt(className(Map.class), map, constructorCall(HashMap.class)) + .forStmt( + size, + (index, codeWriter) -> { + final CastCodeBlock keyCodeBlock = + CastRuleProvider.generateAlwaysNonNullCodeBlock( + context, + rowFieldReadAccess( + index, keyArrayTerm, innerInputKeyType), + innerInputKeyType, + innerTargetKeyType); + assert keyCodeBlock != null; + + final CastCodeBlock valueCodeBlock = + CastRuleProvider.generateAlwaysNonNullCodeBlock( + context, + rowFieldReadAccess( + index, valueArrayTerm, innerInputValueType), + innerInputValueType, + innerTargetValueType); + assert valueCodeBlock != null; + + codeWriter + .declStmt(innerTargetKeyTypeTerm, key, null) + .declStmt(innerTargetValueTypeTerm, value, null); + if (innerTargetKeyType.isNullable()) { + codeWriter.ifStmt( + "!" + methodCall(keyArrayTerm, "isNullAt", index), + thenWriter -> + thenWriter + .append(keyCodeBlock) + .assignStmt( + key, keyCodeBlock.getReturnTerm())); + } else { + codeWriter + .append(keyCodeBlock) + .assignStmt(key, keyCodeBlock.getReturnTerm()); + } + + if (inputLogicalType.is(LogicalTypeRoot.MAP) + && innerTargetValueType.isNullable()) { + codeWriter.ifStmt( + "!" + methodCall(valueArrayTerm, "isNullAt", index), + thenWriter -> + thenWriter + .append(valueCodeBlock) + .assignStmt( + value, + valueCodeBlock.getReturnTerm())); + } else { + codeWriter + .append(valueCodeBlock) + .assignStmt(value, valueCodeBlock.getReturnTerm()); + } + codeWriter.stmt(methodCall(map, "put", key, value)); + }) + .assignStmt(returnVariable, constructorCall(GenericMapData.class, map)) + .toString(); + } +} diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionITCase.java index a0449a8..ae9e595 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionITCase.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/CastFunctionITCase.java @@ -40,10 +40,13 @@ import java.time.LocalTime; import java.time.Period; import java.time.ZoneId; import java.time.ZoneOffset; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static org.apache.flink.table.api.DataTypes.ARRAY; import static org.apache.flink.table.api.DataTypes.BIGINT; @@ -58,6 +61,7 @@ import static org.apache.flink.table.api.DataTypes.DOUBLE; import static org.apache.flink.table.api.DataTypes.FLOAT; import static org.apache.flink.table.api.DataTypes.INT; import static org.apache.flink.table.api.DataTypes.INTERVAL; +import static org.apache.flink.table.api.DataTypes.MAP; import static org.apache.flink.table.api.DataTypes.MONTH; import static org.apache.flink.table.api.DataTypes.ROW; import static org.apache.flink.table.api.DataTypes.SECOND; @@ -1142,14 +1146,27 @@ public class CastFunctionITCase extends BuiltInFunctionTestBase { public static List<TestSpec> constructedTypes() { return Arrays.asList( - // https://issues.apache.org/jira/browse/FLINK-17321 - // MULTISET - // MAP + CastTestSpecBuilder.testCastTo(MAP(STRING(), STRING())) + .fromCase(MAP(FLOAT(), DOUBLE()), null, null) + .fromCase( + MAP(INT(), INT()), + Collections.singletonMap(1, 2), + Collections.singletonMap("1", "2")) + .build(), + // https://issues.apache.org/jira/browse/FLINK-25567 + // CastTestSpecBuilder.testCastTo(MULTISET(STRING())) + // .fromCase(MULTISET(TIMESTAMP()), null, null) + // .fromCase( + // MULTISET(INT()), + // map(entry(1, 2), entry(3, 4)), + // map(entry("1", 2), entry("3", 4))) + // .build(), CastTestSpecBuilder.testCastTo(ARRAY(INT())) .fromCase(ARRAY(INT()), null, null) - // https://issues.apache.org/jira/browse/FLINK-17321 - // .fromCase(ARRAY(STRING()), new String[] {'1', '2', '3'}, new Integer[] - // {1, 2, 3}) + .fromCase( + ARRAY(STRING()), + new String[] {"1", "2", "3"}, + new Integer[] {1, 2, 3}) // https://issues.apache.org/jira/browse/FLINK-24425 Cast from corresponding // single type // .fromCase(INT(), DEFAULT_POSITIVE_INT, new int[] {DEFAULT_POSITIVE_INT}) @@ -1314,4 +1331,20 @@ public class CastFunctionITCase extends BuiltInFunctionTestBase { private static boolean isTimestampToNumeric(LogicalType srcType, LogicalType trgType) { return srcType.is(LogicalTypeFamily.TIMESTAMP) && trgType.is(LogicalTypeFamily.NUMERIC); } + + private static <K, V> Map.Entry<K, V> entry(K k, V v) { + return new AbstractMap.SimpleImmutableEntry<>(k, v); + } + + @SafeVarargs + private static <K, V> Map<K, V> map(Map.Entry<K, V>... entries) { + if (entries == null) { + return Collections.emptyMap(); + } + Map<K, V> map = new HashMap<>(); + for (Map.Entry<K, V> entry : entries) { + map.put(entry.getKey(), entry.getValue()); + } + return map; + } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRulesTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRulesTest.java index 9bb13a9..a0327c5 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRulesTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/casting/CastRulesTest.java @@ -1131,6 +1131,56 @@ class CastRulesTest { new GenericArrayData(new Integer[] {3}) }), NullPointerException.class), + CastTestSpecBuilder.testCastTo(MAP(DOUBLE().notNull(), DOUBLE().notNull())) + .fromCase( + MAP(INT().nullable(), INT().nullable()), + mapData(entry(1, 2)), + mapData(entry(1d, 2d))), + CastTestSpecBuilder.testCastTo(MAP(BIGINT().nullable(), BIGINT().nullable())) + .fromCase( + MAP(INT().nullable(), INT().nullable()), + mapData(entry(1, 2)), + mapData(entry(1L, 2L))), + CastTestSpecBuilder.testCastTo(MAP(BIGINT().nullable(), BIGINT().nullable())) + .fromCase( + MAP(INT().nullable(), INT().nullable()), + mapData(entry(1, 2), entry(null, 3), entry(4, null)), + mapData(entry(1L, 2L), entry(null, 3L), entry(4L, null))), + CastTestSpecBuilder.testCastTo(MAP(STRING().nullable(), STRING().nullable())) + .fromCase( + MAP(TIMESTAMP().nullable(), DOUBLE().nullable()), + mapData(entry(TIMESTAMP, 123.456)), + mapData(entry(TIMESTAMP_STRING, fromString("123.456")))), + CastTestSpecBuilder.testCastTo(MAP(STRING().notNull(), STRING().nullable())) + .fail( + MAP(INT().nullable(), DOUBLE().nullable()), + mapData(entry(null, 1d)), + NullPointerException.class), + CastTestSpecBuilder.testCastTo(MAP(STRING().notNull(), STRING().notNull())) + .fail( + MAP(INT().nullable(), DOUBLE().nullable()), + mapData(entry(123, null)), + NullPointerException.class), + CastTestSpecBuilder.testCastTo(MULTISET(DOUBLE().notNull())) + .fromCase( + MULTISET(INT().nullable()), + mapData(entry(1, 1)), + mapData(entry(1d, 1))), + CastTestSpecBuilder.testCastTo(MULTISET(STRING().notNull())) + .fromCase( + MULTISET(INT().nullable()), + mapData(entry(1, 1)), + mapData(entry(fromString("1"), 1))), + CastTestSpecBuilder.testCastTo(MULTISET(FLOAT().nullable())) + .fromCase( + MULTISET(INT().nullable()), + mapData(entry(null, 1)), + mapData(entry(null, 1))), + CastTestSpecBuilder.testCastTo(MULTISET(STRING().notNull())) + .fail( + MULTISET(INT().nullable()), + mapData(entry(null, 1)), + NullPointerException.class), CastTestSpecBuilder.testCastTo( ROW(BIGINT().notNull(), BIGINT(), STRING(), ARRAY(STRING()))) .fromCase( @@ -1174,6 +1224,15 @@ class CastRulesTest { fromString("b"), fromString("c") }))), + CastTestSpecBuilder.testCastTo( + ROW(MAP(BIGINT().notNull(), STRING()), MULTISET(STRING()))) + .fromCase( + ROW(MAP(INT().notNull(), INT()), MULTISET(TIMESTAMP())), + GenericRowData.of( + mapData(entry(1, 2)), mapData(entry(TIMESTAMP, 1))), + GenericRowData.of( + mapData(entry(1L, fromString("2"))), + mapData(entry(TIMESTAMP_STRING, 1)))), CastTestSpecBuilder.testCastTo(MY_STRUCTURED_TYPE) .fromCase( ROW(INT().notNull(), INT(), TIME(5), ARRAY(TIMESTAMP())),
