This is an automated email from the ASF dual-hosted git repository.
twalthr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 3832967cb73 [FLINK-39088][table] Fix upsert key preservation by
introducing injective cast checks for CAST
3832967cb73 is described below
commit 3832967cb736055b369be037c6394686e4e3b3ab
Author: Gustavo de Morais <[email protected]>
AuthorDate: Thu Feb 19 15:27:44 2026 +0100
[FLINK-39088][table] Fix upsert key preservation by introducing injective
cast checks for CAST
This closes #27603.
---
.../types/logical/utils/LogicalTypeCasts.java | 266 +++++++++++++++++++--
.../flink/table/types/LogicalTypeCastsTest.java | 219 +++++++++++++++++
.../plan/metadata/FlinkRelMdUniqueKeys.scala | 15 +-
.../planner/plan/stream/sql/TableSinkTest.xml | 48 ++++
.../plan/metadata/FlinkRelMdUniqueKeysTest.scala | 97 +++++++-
.../plan/metadata/FlinkRelMdUpsertKeysTest.scala | 75 +++++-
.../planner/plan/stream/sql/TableSinkTest.scala | 49 ++++
.../runtime/stream/sql/ChangelogSourceITCase.scala | 8 +-
8 files changed, 750 insertions(+), 27 deletions(-)
diff --git
a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeCasts.java
b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeCasts.java
index 310bc71e29b..307f47f76d3 100644
---
a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeCasts.java
+++
b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeCasts.java
@@ -31,6 +31,7 @@ import org.apache.flink.table.types.logical.VarBinaryType;
import org.apache.flink.table.types.logical.VarCharType;
import org.apache.flink.table.types.logical.YearMonthIntervalType;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
@@ -38,6 +39,7 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
+import java.util.function.BiPredicate;
import java.util.stream.Collectors;
import static
org.apache.flink.table.types.logical.LogicalTypeFamily.BINARY_STRING;
@@ -79,6 +81,8 @@ import static
org.apache.flink.table.types.logical.LogicalTypeRoot.VARCHAR;
import static
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getDayPrecision;
import static
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getFractionalPrecision;
import static
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getLength;
+import static
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getPrecision;
+import static
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getScale;
import static
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getYearPrecision;
import static
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.isSingleFieldInterval;
@@ -110,116 +114,203 @@ public final class LogicalTypeCasts {
private static final Map<LogicalTypeRoot, Set<LogicalTypeRoot>>
explicitCastingRules;
+ private static final Map<LogicalTypeRoot, List<InjectiveRule>>
injectiveRules;
+
+ // Types with deterministic, unique string representations (for injective
casts to STRING)
+ private static final LogicalTypeRoot[] STRING_INJECTIVE_SOURCES = {
+ TINYINT,
+ SMALLINT,
+ INTEGER,
+ BIGINT,
+ FLOAT,
+ DOUBLE,
+ BOOLEAN,
+ DATE,
+ TIME_WITHOUT_TIME_ZONE,
+ TIMESTAMP_WITHOUT_TIME_ZONE,
+ TIMESTAMP_WITH_LOCAL_TIME_ZONE
+ };
+
+ // ----- Injective cast conditions -----
+
+    /** Injective when the target length can hold any value of the source
type's length. */
+ private static final BiPredicate<LogicalType, LogicalType>
WHEN_LENGTH_FITS =
+ (source, target) -> getLength(target) >= getLength(source);
+
+ /** Injective when the target length can hold the source type's max string
representation. */
+ private static final BiPredicate<LogicalType, LogicalType>
WHEN_MAX_CHAR_LENGTH_FITS =
+ (source, target) -> getLength(target) >=
maxStringRepresentationLength(source);
+
+ /** Injective when source and target share identical precision. */
+ private static final BiPredicate<LogicalType, LogicalType>
WHEN_PRECISION_MATCHES =
+ (source, target) -> getPrecision(source) == getPrecision(target);
+
+ /** Injective when source and target share identical precision and scale
(DECIMAL). */
+ private static final BiPredicate<LogicalType, LogicalType>
WHEN_PRECISION_AND_SCALE_MATCH =
+ (source, target) ->
+ getPrecision(source) == getPrecision(target)
+ && getScale(source) == getScale(target);
+
static {
implicitCastingRules = new HashMap<>();
explicitCastingRules = new HashMap<>();
+ injectiveRules = new HashMap<>();
- // identity casts
-
+ // Identity casts: all types can be implicitly cast to themselves.
+ // Injective identity for parameterized types is declared per-type
with conditions.
for (LogicalTypeRoot typeRoot : allTypes()) {
castTo(typeRoot).implicitFrom(typeRoot).build();
}
- // cast specification
+ //
-----------------------------------------------------------------------------------------
+ // Character string types
+ //
-----------------------------------------------------------------------------------------
castTo(CHAR)
.implicitFrom(CHAR)
.explicitFromFamily(PREDEFINED, CONSTRUCTED)
.explicitFrom(RAW, NULL, STRUCTURED_TYPE)
+ .injectiveFrom(WHEN_LENGTH_FITS, CHAR)
+ .injectiveFrom(WHEN_MAX_CHAR_LENGTH_FITS,
STRING_INJECTIVE_SOURCES)
.build();
castTo(VARCHAR)
.implicitFromFamily(CHARACTER_STRING)
.explicitFromFamily(PREDEFINED, CONSTRUCTED)
.explicitFrom(RAW, NULL, STRUCTURED_TYPE)
+ .injectiveFrom(WHEN_LENGTH_FITS, CHAR, VARCHAR)
+ .injectiveFrom(WHEN_MAX_CHAR_LENGTH_FITS,
STRING_INJECTIVE_SOURCES)
.build();
- castTo(BOOLEAN)
- .implicitFrom(BOOLEAN)
- .explicitFromFamily(CHARACTER_STRING, INTEGER_NUMERIC)
- .build();
+ //
-----------------------------------------------------------------------------------------
+ // Binary string types
+ //
-----------------------------------------------------------------------------------------
castTo(BINARY)
.implicitFrom(BINARY)
.explicitFromFamily(CHARACTER_STRING)
- .explicitFrom(VARBINARY)
- .explicitFrom(RAW)
+ .explicitFrom(VARBINARY, RAW)
+ .injectiveFrom(WHEN_LENGTH_FITS, BINARY)
.build();
castTo(VARBINARY)
.implicitFromFamily(BINARY_STRING)
.explicitFromFamily(CHARACTER_STRING)
- .explicitFrom(BINARY)
- .explicitFrom(RAW)
+ .explicitFrom(BINARY, RAW)
+ .injectiveFrom(WHEN_LENGTH_FITS, BINARY, VARBINARY)
.build();
- castTo(DECIMAL)
- .implicitFromFamily(NUMERIC)
- .explicitFromFamily(CHARACTER_STRING, INTERVAL)
- .explicitFrom(BOOLEAN, TIMESTAMP_WITHOUT_TIME_ZONE,
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
- .build();
+ //
-----------------------------------------------------------------------------------------
+ // Exact numeric types
+ //
-----------------------------------------------------------------------------------------
castTo(TINYINT)
.implicitFrom(TINYINT)
.explicitFromFamily(NUMERIC, CHARACTER_STRING, INTERVAL)
.explicitFrom(BOOLEAN, TIMESTAMP_WITHOUT_TIME_ZONE,
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
+ .injectiveFrom(TINYINT)
.build();
castTo(SMALLINT)
.implicitFrom(TINYINT, SMALLINT)
.explicitFromFamily(NUMERIC, CHARACTER_STRING, INTERVAL)
.explicitFrom(BOOLEAN, TIMESTAMP_WITHOUT_TIME_ZONE,
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
+ .injectiveFrom(TINYINT, SMALLINT)
.build();
castTo(INTEGER)
.implicitFrom(TINYINT, SMALLINT, INTEGER)
.explicitFromFamily(NUMERIC, CHARACTER_STRING, INTERVAL)
.explicitFrom(BOOLEAN, TIMESTAMP_WITHOUT_TIME_ZONE,
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
+ .injectiveFrom(TINYINT, SMALLINT, INTEGER)
.build();
castTo(BIGINT)
.implicitFrom(TINYINT, SMALLINT, INTEGER, BIGINT)
.explicitFromFamily(NUMERIC, CHARACTER_STRING, INTERVAL)
.explicitFrom(BOOLEAN, TIMESTAMP_WITHOUT_TIME_ZONE,
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
+ .injectiveFrom(TINYINT, SMALLINT, INTEGER, BIGINT)
.build();
+ castTo(DECIMAL)
+ .implicitFromFamily(NUMERIC)
+ .explicitFromFamily(CHARACTER_STRING, INTERVAL)
+ .explicitFrom(BOOLEAN, TIMESTAMP_WITHOUT_TIME_ZONE,
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
+ .injectiveFrom(WHEN_PRECISION_AND_SCALE_MATCH, DECIMAL)
+ .build();
+
+ //
-----------------------------------------------------------------------------------------
+ // Approximate numeric types
+ //
-----------------------------------------------------------------------------------------
+
castTo(FLOAT)
.implicitFrom(TINYINT, SMALLINT, INTEGER, BIGINT, FLOAT,
DECIMAL)
.explicitFromFamily(NUMERIC, CHARACTER_STRING)
.explicitFrom(BOOLEAN, TIMESTAMP_WITHOUT_TIME_ZONE,
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
+ .injectiveFrom(FLOAT)
.build();
castTo(DOUBLE)
.implicitFromFamily(NUMERIC)
.explicitFromFamily(CHARACTER_STRING)
.explicitFrom(BOOLEAN, TIMESTAMP_WITHOUT_TIME_ZONE,
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
+ .injectiveFrom(DOUBLE)
.build();
+ //
-----------------------------------------------------------------------------------------
+ // Boolean type
+ //
-----------------------------------------------------------------------------------------
+
+ castTo(BOOLEAN)
+ .implicitFrom(BOOLEAN)
+ .explicitFromFamily(CHARACTER_STRING, INTEGER_NUMERIC)
+ .injectiveFrom(BOOLEAN)
+ .build();
+
+ //
-----------------------------------------------------------------------------------------
+ // Date and time types
+ //
-----------------------------------------------------------------------------------------
+
castTo(DATE)
.implicitFrom(DATE, TIMESTAMP_WITHOUT_TIME_ZONE)
.explicitFromFamily(TIMESTAMP, CHARACTER_STRING)
+ .injectiveFrom(DATE)
.build();
castTo(TIME_WITHOUT_TIME_ZONE)
.implicitFrom(TIME_WITHOUT_TIME_ZONE,
TIMESTAMP_WITHOUT_TIME_ZONE)
.explicitFromFamily(TIME, TIMESTAMP, CHARACTER_STRING)
+ .injectiveFrom(WHEN_PRECISION_MATCHES, TIME_WITHOUT_TIME_ZONE)
.build();
castTo(TIMESTAMP_WITHOUT_TIME_ZONE)
.implicitFrom(TIMESTAMP_WITHOUT_TIME_ZONE,
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
.explicitFromFamily(DATETIME, CHARACTER_STRING, NUMERIC)
+ .injectiveFrom(
+ WHEN_PRECISION_MATCHES,
+ TIMESTAMP_WITHOUT_TIME_ZONE,
+ TIMESTAMP_WITH_LOCAL_TIME_ZONE)
.build();
castTo(TIMESTAMP_WITH_TIME_ZONE)
.implicitFrom(TIMESTAMP_WITH_TIME_ZONE)
.explicitFromFamily(DATETIME, CHARACTER_STRING)
+ .injectiveFrom(WHEN_PRECISION_MATCHES,
TIMESTAMP_WITH_TIME_ZONE)
.build();
castTo(TIMESTAMP_WITH_LOCAL_TIME_ZONE)
.implicitFrom(TIMESTAMP_WITH_LOCAL_TIME_ZONE,
TIMESTAMP_WITHOUT_TIME_ZONE)
.explicitFromFamily(DATETIME, CHARACTER_STRING, NUMERIC)
+ .injectiveFrom(
+ WHEN_PRECISION_MATCHES,
+ TIMESTAMP_WITH_LOCAL_TIME_ZONE,
+ TIMESTAMP_WITHOUT_TIME_ZONE)
.build();
+ //
-----------------------------------------------------------------------------------------
+ // Interval types
+ //
-----------------------------------------------------------------------------------------
+
castTo(INTERVAL_YEAR_MONTH)
.implicitFrom(INTERVAL_YEAR_MONTH)
.explicitFromFamily(EXACT_NUMERIC, CHARACTER_STRING)
@@ -293,6 +384,118 @@ public final class LogicalTypeCasts {
return supportsCasting(sourceType, targetType, true);
}
+ /**
+ * Returns whether the cast from source type to target type is injective
+ * (uniqueness-preserving).
+ *
+ * <p>An injective cast guarantees that every distinct input value maps to
a distinct output
+ * value. This property is useful for upsert key tracking through
projections: if a key column
+ * is cast using an injective conversion, the uniqueness of the key is
preserved.
+ *
+ * <p>Injective casts are explicitly defined in the casting rules,
separate from implicit casts.
+ * Not all implicit casts are injective (e.g., TIMESTAMP → DATE loses time
information).
+ *
+ * <p>For constructed types (ROW), this method recursively checks if all
field casts are
+ * injective.
+ *
+ * @param sourceType the source type
+ * @param targetType the target type
+ * @return {@code true} if the cast preserves uniqueness
+ */
+ public static boolean supportsInjectiveCast(
+ final LogicalType sourceType, final LogicalType targetType) {
+ final LogicalTypeRoot sourceRoot = sourceType.getTypeRoot();
+ final LogicalTypeRoot targetRoot = targetType.getTypeRoot();
+
+ // Handle DISTINCT types by unwrapping
+ if (sourceRoot == DISTINCT_TYPE) {
+ return supportsInjectiveCast(((DistinctType)
sourceType).getSourceType(), targetType);
+ }
+ if (targetRoot == DISTINCT_TYPE) {
+ return supportsInjectiveCast(sourceType, ((DistinctType)
targetType).getSourceType());
+ }
+
+ // Handle NULL type
+ if (sourceRoot == NULL) {
+ return true;
+ }
+
+ // Handle constructed types (ROW, STRUCTURED_TYPE) with recursive
field checks
+ final boolean isSourceConstructed = sourceRoot == ROW || sourceRoot ==
STRUCTURED_TYPE;
+ final boolean isTargetConstructed = targetRoot == ROW || targetRoot ==
STRUCTURED_TYPE;
+ if (isSourceConstructed && isTargetConstructed) {
+ return supportsInjectiveConstructedCast(sourceType, targetType);
+ }
+
+ // Check declarative injective rules
+ final List<InjectiveRule> rules = injectiveRules.get(targetRoot);
+ if (rules == null) {
+ return false;
+ }
+ for (final InjectiveRule rule : rules) {
+ if (rule.test(sourceType, targetType)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ private static boolean supportsInjectiveConstructedCast(
+ final LogicalType sourceType, final LogicalType targetType) {
+ final List<LogicalType> sourceChildren = sourceType.getChildren();
+ final List<LogicalType> targetChildren = targetType.getChildren();
+
+ if (sourceChildren.size() != targetChildren.size()) {
+ return false;
+ }
+
+ for (int i = 0; i < sourceChildren.size(); i++) {
+ if (!supportsInjectiveCast(sourceChildren.get(i),
targetChildren.get(i))) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
+ * Returns the maximum number of characters needed to represent any value
of the given type as a
+ * string, or -1 if the type does not have a bounded deterministic string
representation that
+ * qualifies for injective casts.
+ */
+ private static int maxStringRepresentationLength(final LogicalType type) {
+ switch (type.getTypeRoot()) {
+ case BOOLEAN:
+ return 5; // "false"
+ case TINYINT:
+ return 4; // "-128"
+ case SMALLINT:
+ return 6; // "-32768"
+ case INTEGER:
+ return 11; // "-2147483648"
+ case BIGINT:
+ return 20; // "-9223372036854775808"
+ case FLOAT:
+                return 15; // e.g. "-1.17549435E-38" (sign + 9 digits + point + "E-38" = 15 chars)
+ case DOUBLE:
+ return 24; // e.g. "-2.2250738585072014E-308"
+ case DATE:
+ return 10; // "9999-12-31"
+ case TIME_WITHOUT_TIME_ZONE:
+ {
+ final int p = getPrecision(type);
+ return p > 0 ? 9 + p : 8; // "HH:MM:SS" + optional
".fff..."
+ }
+ case TIMESTAMP_WITHOUT_TIME_ZONE:
+ case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
+ {
+ final int p = getPrecision(type);
+ return p > 0 ? 20 + p : 19; // "YYYY-MM-DD HH:MM:SS" +
optional ".fff..."
+ }
+ default:
+ return -1;
+ }
+ }
+
/**
* Returns whether the source type can be reinterpreted as the target type.
*
@@ -484,11 +687,30 @@ public final class LogicalTypeCasts {
return LogicalTypeRoot.values();
}
+ /** A declarative rule describing when a cast from certain source types is
injective. */
+ private static final class InjectiveRule {
+
+ private final Set<LogicalTypeRoot> sourceRoots;
+ private final BiPredicate<LogicalType, LogicalType> condition;
+
+ InjectiveRule(
+ Set<LogicalTypeRoot> sourceRoots, BiPredicate<LogicalType,
LogicalType> condition) {
+ this.sourceRoots = sourceRoots;
+ this.condition = condition;
+ }
+
+ boolean test(LogicalType sourceType, LogicalType targetType) {
+ return sourceRoots.contains(sourceType.getTypeRoot())
+ && condition.test(sourceType, targetType);
+ }
+ }
+
private static class CastingRuleBuilder {
private final LogicalTypeRoot targetType;
private final Set<LogicalTypeRoot> implicitSourceTypes = new
HashSet<>();
private final Set<LogicalTypeRoot> explicitSourceTypes = new
HashSet<>();
+ private final List<InjectiveRule> injectiveRuleList = new
ArrayList<>();
CastingRuleBuilder(LogicalTypeRoot targetType) {
this.targetType = targetType;
@@ -526,9 +748,21 @@ public final class LogicalTypeCasts {
return this;
}
+ CastingRuleBuilder injectiveFrom(LogicalTypeRoot... sourceTypes) {
+ return injectiveFrom((s, t) -> true, sourceTypes);
+ }
+
+ CastingRuleBuilder injectiveFrom(
+ BiPredicate<LogicalType, LogicalType> condition,
LogicalTypeRoot... sourceTypes) {
+ injectiveRuleList.add(
+ new InjectiveRule(new
HashSet<>(Arrays.asList(sourceTypes)), condition));
+ return this;
+ }
+
void build() {
implicitCastingRules.put(targetType, implicitSourceTypes);
explicitCastingRules.put(targetType, explicitSourceTypes);
+ injectiveRules.put(targetType, injectiveRuleList);
}
}
diff --git
a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastsTest.java
b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastsTest.java
index 383874a29c5..7428e5adc10 100644
---
a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastsTest.java
+++
b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastsTest.java
@@ -26,7 +26,10 @@ import org.apache.flink.table.types.logical.ArrayType;
import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.BinaryType;
import org.apache.flink.table.types.logical.BooleanType;
+import org.apache.flink.table.types.logical.CharType;
+import org.apache.flink.table.types.logical.DateType;
import org.apache.flink.table.types.logical.DecimalType;
+import org.apache.flink.table.types.logical.DoubleType;
import org.apache.flink.table.types.logical.FloatType;
import org.apache.flink.table.types.logical.IntType;
import org.apache.flink.table.types.logical.LocalZonedTimestampType;
@@ -38,10 +41,13 @@ import
org.apache.flink.table.types.logical.RowType.RowField;
import org.apache.flink.table.types.logical.SmallIntType;
import org.apache.flink.table.types.logical.StructuredType;
import org.apache.flink.table.types.logical.StructuredType.StructuredAttribute;
+import org.apache.flink.table.types.logical.TimeType;
import org.apache.flink.table.types.logical.TimestampType;
import org.apache.flink.table.types.logical.TinyIntType;
+import org.apache.flink.table.types.logical.VarBinaryType;
import org.apache.flink.table.types.logical.VarCharType;
import org.apache.flink.table.types.logical.YearMonthIntervalType;
+import org.apache.flink.table.types.logical.ZonedTimestampType;
import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;
import org.junit.jupiter.api.parallel.Execution;
@@ -273,4 +279,217 @@ class LogicalTypeCastsTest {
.as("Supports explicit casting")
.isEqualTo(supportsExplicit);
}
+
+ /**
+ * Test data for injective cast tests. Each argument contains:
(sourceType, targetType,
+ * expectedInjective).
+ */
+ private static Stream<Arguments> injectiveCastTestData() {
+ return Stream.of(
+ // Integer widenings are injective
+ Arguments.of(new SmallIntType(), new BigIntType(), true),
+ Arguments.of(new IntType(), new BigIntType(), true),
+ Arguments.of(new TinyIntType(), new IntType(), true),
+ Arguments.of(new TinyIntType(), new SmallIntType(), true),
+
+ // Explicit casts to STRING from integer types are injective
+ Arguments.of(new TinyIntType(), VarCharType.STRING_TYPE, true),
+ Arguments.of(new SmallIntType(), VarCharType.STRING_TYPE,
true),
+ Arguments.of(new IntType(), VarCharType.STRING_TYPE, true),
+ Arguments.of(new BigIntType(), VarCharType.STRING_TYPE, true),
+
+ // FLOAT/DOUBLE to STRING are injective
+ Arguments.of(new FloatType(), VarCharType.STRING_TYPE, true),
+ Arguments.of(new DoubleType(), VarCharType.STRING_TYPE, true),
+
+ // Explicit casts to STRING from boolean are injective
+ Arguments.of(new BooleanType(), VarCharType.STRING_TYPE, true),
+
+ // Explicit casts to STRING from date/time types are injective
+ Arguments.of(new DateType(), VarCharType.STRING_TYPE, true),
+ Arguments.of(new TimeType(3), VarCharType.STRING_TYPE, true),
+ Arguments.of(new TimestampType(3), VarCharType.STRING_TYPE,
true),
+ Arguments.of(new TimestampType(9), VarCharType.STRING_TYPE,
true),
+ Arguments.of(new LocalZonedTimestampType(3),
VarCharType.STRING_TYPE, true),
+
+ // Casts to CHAR are injective if the target length is
sufficient
+ Arguments.of(new IntType(), new CharType(100), true),
+ Arguments.of(new BigIntType(), new CharType(100), true),
+ Arguments.of(new IntType(), new CharType(11), true), // exact
minimum
+ Arguments.of(new IntType(), new CharType(3), false), // too
short for "-2147483648"
+ Arguments.of(new BigIntType(), new CharType(20), true), //
exact minimum
+ Arguments.of(new BigIntType(), new CharType(19), false), //
too short
+ Arguments.of(new BooleanType(), new CharType(5), true), //
exact minimum for "false"
+ Arguments.of(new BooleanType(), new CharType(4), false), //
too short
+
+ // CHAR → VARCHAR widening is injective
+ Arguments.of(new CharType(10), VarCharType.STRING_TYPE, true),
+
+ // BINARY → VARBINARY widening is injective
+ Arguments.of(new BinaryType(10), new VarBinaryType(100), true),
+
+ // Narrowing casts are NOT injective (lossy)
+ Arguments.of(VarCharType.STRING_TYPE, new IntType(), false),
+ Arguments.of(new BigIntType(), new IntType(), false),
+ Arguments.of(new DoubleType(), new FloatType(), false),
+
+ // TIMESTAMP → DATE is NOT injective (loses time-of-day
information)
+ // even though it is an implicit cast
+ Arguments.of(new TimestampType(3), new DateType(), false),
+
+ // DECIMAL to STRING is NOT considered injective
+ Arguments.of(new DecimalType(10, 2), VarCharType.STRING_TYPE,
false),
+
+ // BYTES to STRING is NOT injective (invalid UTF-8 sequences
collapse)
+ Arguments.of(new VarBinaryType(100), VarCharType.STRING_TYPE,
false),
+ Arguments.of(new BinaryType(100), VarCharType.STRING_TYPE,
false),
+
+                // TIMESTAMP_WITH_TIME_ZONE to STRING is NOT injective
+                // (two values denoting the same instant in different zones could
+                // normalize to the same string, depending on the implementation)
+                Arguments.of(new ZonedTimestampType(3),
VarCharType.STRING_TYPE, false),
+
+ // INT → FLOAT/DOUBLE are theoretically injective
+ // However, we decided not to support decimal, float and double
+ // injective conversions at first since it's not a practical
use case
+ Arguments.of(new IntType(), new FloatType(), false),
+ Arguments.of(new IntType(), new DoubleType(), false),
+
+ // STRING → BOOLEAN is NOT injective
+ Arguments.of(VarCharType.STRING_TYPE, new BooleanType(),
false),
+
+ // DOUBLE → INT is NOT injective
+ Arguments.of(new DoubleType(), new IntType(), false),
+
+ // DECIMAL → DECIMAL: only identity casts are injective
+ // (changing precision/scale can lose data in various ways)
+ Arguments.of(new DecimalType(10, 2), new DecimalType(10, 2),
true), // identity
+ Arguments.of(new DecimalType(10, 2), new DecimalType(20, 4),
false), // not identity
+ Arguments.of(
+ new DecimalType(10, 2), new DecimalType(15, 2),
false), // precision change
+ Arguments.of(new DecimalType(10, 2), new DecimalType(10, 4),
false), // scale change
+ Arguments.of(new DecimalType(20, 4), new DecimalType(10, 2),
false), // narrowing
+ Arguments.of(
+ new DecimalType(10, 4), new DecimalType(10, 2),
false), // scale narrowing
+
+ // Timestamp conversions between variants are injective
+ Arguments.of(new TimestampType(3), new
LocalZonedTimestampType(3), true),
+ Arguments.of(new LocalZonedTimestampType(3), new
TimestampType(3), true),
+
+ // ROW types with injective field casts
+ Arguments.of(
+ new RowType(
+ Arrays.asList(
+ new RowField("id", new IntType()),
+ new RowField("name",
VarCharType.STRING_TYPE))),
+ new RowType(
+ Arrays.asList(
+ new RowField("id", new BigIntType()),
+ new RowField("name",
VarCharType.STRING_TYPE))),
+ true),
+
+ // ROW types with non-injective field cast (TIMESTAMP → DATE)
+ Arguments.of(
+ new RowType(
+ Arrays.asList(
+ new RowField("id", new IntType()),
+ new RowField("ts", new
TimestampType(3)))),
+ new RowType(
+ Arrays.asList(
+ new RowField("id", new IntType()),
+ new RowField("ts", new DateType()))),
+ false),
+
+ // ROW types with mixed casts (one injective, one not)
+ Arguments.of(
+ new RowType(
+ Arrays.asList(
+ new RowField("id", new IntType()),
+ new RowField("val", new
DoubleType()))),
+ new RowType(
+ Arrays.asList(
+ new RowField("id",
VarCharType.STRING_TYPE),
+ new RowField("val", new IntType()))),
+ false),
+
+ // ---- Parameter-aware injective cast checks ----
+
+ // CHAR length checks
+ Arguments.of(new CharType(10), new CharType(10), true), //
identity
+ Arguments.of(new CharType(10), new CharType(20), true), //
widening
+ Arguments.of(new CharType(20), new CharType(10), false), //
truncation
+
+ // VARCHAR length checks
+ Arguments.of(new VarCharType(10), new VarCharType(100), true),
// widening
+ Arguments.of(new VarCharType(100), new VarCharType(10),
false), // truncation
+
+ // CHAR → VARCHAR length checks
+ Arguments.of(new CharType(10), new VarCharType(20), true), //
widening
+ Arguments.of(new CharType(20), new VarCharType(10), false), //
truncation
+
+ // BINARY length checks
+ Arguments.of(new BinaryType(10), new BinaryType(10), true), //
identity
+ Arguments.of(new BinaryType(10), new BinaryType(20), true), //
widening
+ Arguments.of(new BinaryType(20), new BinaryType(10), false),
// truncation
+
+ // VARBINARY length checks
+ Arguments.of(new VarBinaryType(10), new VarBinaryType(100),
true), // widening
+ Arguments.of(new VarBinaryType(100), new VarBinaryType(10),
false), // truncation
+
+ // BINARY → VARBINARY length checks
+ Arguments.of(new BinaryType(10), new VarBinaryType(20), true),
// widening
+ Arguments.of(new BinaryType(20), new VarBinaryType(10),
false), // truncation
+
+ // TIMESTAMP precision checks (identity only)
+ Arguments.of(new TimestampType(3), new TimestampType(3),
true), // identity
+ Arguments.of(
+ new TimestampType(3), new TimestampType(6), false), //
widening rejected
+ Arguments.of(
+ new TimestampType(6), new TimestampType(3), false), //
narrowing rejected
+
+ // TIMESTAMP ↔ TIMESTAMP_LTZ precision checks (identity only)
+ Arguments.of(
+ new TimestampType(3), new LocalZonedTimestampType(6),
false), // rejected
+ Arguments.of(
+ new LocalZonedTimestampType(6), new TimestampType(3),
false), // rejected
+
+ // TIMESTAMP_LTZ precision checks (identity only)
+ Arguments.of(
+ new LocalZonedTimestampType(3),
+ new LocalZonedTimestampType(3),
+ true), // identity
+ Arguments.of(
+ new LocalZonedTimestampType(3),
+ new LocalZonedTimestampType(6),
+ false), // rejected
+
+ // TIMESTAMP_TZ precision checks (identity only)
+ Arguments.of(
+ new ZonedTimestampType(3), new ZonedTimestampType(3),
true), // identity
+ Arguments.of(
+ new ZonedTimestampType(3), new ZonedTimestampType(6),
false), // rejected
+
+ // TIME precision checks (identity only)
+ Arguments.of(new TimeType(0), new TimeType(0), true), //
identity
+ Arguments.of(new TimeType(0), new TimeType(3), false), //
widening rejected
+ Arguments.of(new TimeType(3), new TimeType(0), false), //
narrowing rejected
+
+ // Cross-family casts to VARCHAR with insufficient length
+ Arguments.of(new IntType(), new VarCharType(3), false), // too
short
+ Arguments.of(new TimestampType(3), new VarCharType(5), false),
// too short
+
+ // DOUBLE identity
+ Arguments.of(new DoubleType(), new DoubleType(), true));
+ }
+
+ @ParameterizedTest(name = "{index}: [From: {0}, To: {1}, Injective: {2}]")
+ @MethodSource("injectiveCastTestData")
+ void testInjectiveCast(
+ LogicalType sourceType, LogicalType targetType, boolean
expectedInjective) {
+ assertThat(LogicalTypeCasts.supportsInjectiveCast(sourceType,
targetType))
+ .as(
+ "Cast from %s to %s should %s injective",
+ sourceType, targetType, expectedInjective ? "be" :
"not be")
+ .isEqualTo(expectedInjective);
+ }
}
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala
index 046c131edf5..c445d8b2949 100644
---
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala
@@ -173,9 +173,9 @@ class FlinkRelMdUniqueKeys private extends
MetadataHandler[BuiltInMetadata.Uniqu
}
case _ => // ignore
}
- // rename or cast
+ // rename or key-preserving cast (fidelity or injective)
case a: RexCall
- if (a.getKind.equals(SqlKind.AS) || isFidelityCast(a)) &&
+ if (a.getKind.equals(SqlKind.AS) || isKeyPreservingCast(a)) &&
a.getOperands.get(0).isInstanceOf[RexInputRef] =>
appendMapInToOutPos(a.getOperands.get(0).asInstanceOf[RexInputRef].getIndex, i)
case _ => // ignore
@@ -214,14 +214,19 @@ class FlinkRelMdUniqueKeys private extends
MetadataHandler[BuiltInMetadata.Uniqu
projUniqueKeySet
}
- /** Whether the [[RexCall]] is a cast that doesn't lose any information. */
- private def isFidelityCast(call: RexCall): Boolean = {
+ /**
+ * Whether the [[RexCall]] is a cast that preserves key uniqueness
(injective cast).
+ *
+ * An injective cast is one where each distinct input maps to a distinct
output, ensuring that
+ * unique keys remain unique after the cast.
+ */
+ private def isKeyPreservingCast(call: RexCall): Boolean = {
if (call.getKind != SqlKind.CAST) {
return false
}
val originalType =
FlinkTypeFactory.toLogicalType(call.getOperands.get(0).getType)
val newType = FlinkTypeFactory.toLogicalType(call.getType)
- LogicalTypeCasts.supportsImplicitCast(originalType, newType)
+ LogicalTypeCasts.supportsInjectiveCast(originalType, newType)
}
def getUniqueKeys(
diff --git
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableSinkTest.xml
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableSinkTest.xml
index edf9e924f6b..179e09c8ad3 100644
---
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableSinkTest.xml
+++
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableSinkTest.xml
@@ -485,6 +485,27 @@ Sink(table=[default_catalog.default_database.sink],
fields=[a, b])
+- Exchange(distribution=[single])
+- Calc(select=[a, b])
+- DataStreamScan(table=[[default_catalog, default_database,
MyTable]], fields=[a, b, c])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testInjectiveCastPreservesUpsertKey">
+ <Resource name="ast">
+ <![CDATA[
+LogicalSink(table=[default_catalog.default_database.sink_agg_with_string_pk],
fields=[EXPR$0, EXPR$1])
++- LogicalProject(EXPR$0=[CAST($0):VARCHAR(2147483647) CHARACTER SET
"UTF-16LE"], EXPR$1=[$1])
+ +- LogicalAggregate(group=[{0}], EXPR$1=[COUNT()])
+ +- LogicalProject(a=[$0])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
MyTable]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Sink(table=[default_catalog.default_database.sink_agg_with_string_pk],
fields=[EXPR$0, EXPR$1], changelogMode=[NONE])
++- Calc(select=[CAST(a AS VARCHAR(2147483647)) AS EXPR$0, EXPR$1],
changelogMode=[I,UA])
+ +- GroupAggregate(groupBy=[a], select=[a, COUNT(*) AS EXPR$1],
changelogMode=[I,UA])
+ +- Exchange(distribution=[hash[a]], changelogMode=[I])
+ +- Calc(select=[a], changelogMode=[I])
+ +- DataStreamScan(table=[[default_catalog, default_database,
MyTable]], fields=[a, b, c], changelogMode=[I])
]]>
</Resource>
</TestCase>
@@ -596,6 +617,33 @@
LogicalSink(table=[default_catalog.default_database.MetadataTable], fields=[meta
<![CDATA[
Sink(table=[default_catalog.default_database.MetadataTable],
fields=[metadata_1, metadata_2, other, m_2])
+- TableSourceScan(table=[[default_catalog, default_database, MetadataTable,
project=[metadata_1, metadata_2, other], metadata=[metadata_2]]],
fields=[metadata_1, metadata_2, other, m_2])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testNonInjectiveCastLosesUpsertKey">
+ <Resource name="explain">
+ <![CDATA[== Abstract Syntax Tree ==
+LogicalSink(table=[default_catalog.default_database.sink_agg_with_int_pk],
fields=[EXPR$0, EXPR$1])
++- LogicalProject(EXPR$0=[CAST($0):INTEGER], EXPR$1=[$1])
+ +- LogicalAggregate(group=[{0}], EXPR$1=[COUNT()])
+ +- LogicalProject(c=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
MyTable]])
+
+== Optimized Physical Plan ==
+Sink(table=[default_catalog.default_database.sink_agg_with_int_pk],
fields=[EXPR$0, EXPR$1], upsertMaterialize=[true],
conflictStrategy=[DEDUPLICATE], changelogMode=[NONE])
++- Calc(select=[CAST(c AS INTEGER) AS EXPR$0, EXPR$1], changelogMode=[I,UB,UA])
+ +- GroupAggregate(groupBy=[c], select=[c, COUNT(*) AS EXPR$1],
changelogMode=[I,UB,UA])
+ +- Exchange(distribution=[hash[c]], changelogMode=[I])
+ +- Calc(select=[c], changelogMode=[I])
+ +- DataStreamScan(table=[[default_catalog, default_database,
MyTable]], fields=[a, b, c], changelogMode=[I])
+
+== Optimized Execution Plan ==
+Sink(table=[default_catalog.default_database.sink_agg_with_int_pk],
fields=[EXPR$0, EXPR$1], upsertMaterialize=[true],
conflictStrategy=[DEDUPLICATE])
++- Calc(select=[CAST(c AS INTEGER) AS EXPR$0, EXPR$1])
+ +- GroupAggregate(groupBy=[c], select=[c, COUNT(*) AS EXPR$1])
+ +- Exchange(distribution=[hash[c]])
+ +- Calc(select=[c])
+ +- DataStreamScan(table=[[default_catalog, default_database,
MyTable]], fields=[a, b, c])
]]>
</Resource>
</TestCase>
diff --git
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeysTest.scala
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeysTest.scala
index ddb440bc320..41fd47cc62a 100644
---
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeysTest.scala
+++
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeysTest.scala
@@ -25,6 +25,7 @@ import org.apache.flink.table.planner.plan.utils.ExpandUtil
import com.google.common.collect.{ImmutableList, ImmutableSet}
import org.apache.calcite.prepare.CalciteCatalogReader
import org.apache.calcite.rel.hint.RelHint
+import org.apache.calcite.sql.`type`.SqlTypeName.VARCHAR
import org.apache.calcite.sql.fun.SqlStdOperatorTable.{EQUALS, LESS_THAN}
import org.apache.calcite.util.ImmutableBitSet
import org.junit.jupiter.api.Assertions._
@@ -92,10 +93,101 @@ class FlinkRelMdUniqueKeysTest extends
FlinkRelMdHandlerTestBase {
relBuilder.field(1)
)
val project1 = relBuilder.project(exprs).build()
- assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(project1).toSet)
+ // INT -> BIGINT is an injective cast, so position 2 is also a unique key
+ assertEquals(uniqueKeys(Array(1), Array(2)),
mq.getUniqueKeys(project1).toSet)
assertEquals(uniqueKeys(Array(1), Array(2)), mq.getUniqueKeys(project1,
true).toSet)
}
+ @Test
+ def testGetUniqueKeysOnProjectWithInjectiveCastToString(): Unit = {
+ // INT -> STRING is an injective cast (each distinct int maps to a
distinct string).
+ // When a unique key column is cast this way, the uniqueness is preserved.
+ relBuilder.push(studentLogicalScan)
+
+ val stringType = typeFactory.createSqlType(VARCHAR, 100)
+
+ // Project: CAST(id AS STRING), name
+ // id (position 0 in source) is the unique key
+ val exprs = List(
+ rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS
STRING)
+ relBuilder.field(1) // name
+ )
+ val project = relBuilder.project(exprs).build()
+
+ // The casted id at position 0 should still be recognized as unique
+ assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(project).toSet)
+ }
+
+ @Test
+ def testGetUniqueKeysOnProjectWithMultipleKeyReferences(): Unit = {
+ // When the same unique key column appears multiple times in a projection
+ // (either raw or via injective cast), all references are recognized as
keys.
+ relBuilder.push(studentLogicalScan)
+
+ val stringType = typeFactory.createSqlType(VARCHAR, 100)
+
+ // Project: CAST(id AS STRING), id, name
+ val exprs = List(
+ rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS
STRING) - injective
+ relBuilder.field(0), // id (raw reference)
+ relBuilder.field(1) // name
+ )
+ val project = relBuilder.project(exprs).build()
+
+ // Both position 0 (STRING cast of id) and position 1 (raw id) are unique
keys
+ assertEquals(uniqueKeys(Array(0), Array(1)),
mq.getUniqueKeys(project).toSet)
+ }
+
+ @Test
+ def testGetUniqueKeysOnProjectInjectiveCastOnlyPreservesExistingKeys(): Unit
= {
+ // Injective casts PRESERVE uniqueness but don't CREATE it.
+ // Casting a non-key column with an injective cast doesn't make it a key.
+ relBuilder.push(studentLogicalScan)
+
+ val stringType = typeFactory.createSqlType(VARCHAR, 100)
+
+ // Project: id, CAST(name AS STRING)
+ // id is the unique key; name is NOT a key (even after casting)
+ val exprs = List(
+ relBuilder.field(0), // id - the unique key
+ rexBuilder.makeCast(stringType, relBuilder.field(1)) // CAST(name AS
STRING) - not a key
+ )
+ val project = relBuilder.project(exprs).build()
+
+ // Only position 0 (id) is a unique key
+ // Position 1 (cast of name) is NOT a key because name wasn't a key to
begin with
+ assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(project).toSet)
+ }
+
+ @Test
+ def testGetUniqueKeysOnProjectNonInjectiveCastLosesKey(): Unit = {
+ // STRING -> INT is NOT an injective cast (e.g., "1" and "01" both become
1).
+ // When a unique key is cast this way, the uniqueness cannot be guaranteed.
+ relBuilder.push(studentLogicalScan)
+
+ val stringType = typeFactory.createSqlType(VARCHAR, 100)
+
+ // First, project id as STRING to simulate a STRING key column
+ val stringKeyExprs = List(
+ rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS
STRING)
+ relBuilder.field(1) // name
+ )
+ val stringKeyProject = relBuilder.project(stringKeyExprs).build()
+ // At this point, position 0 is a STRING that's still a unique key
+ assertEquals(uniqueKeys(Array(0)),
mq.getUniqueKeys(stringKeyProject).toSet)
+
+ // Now cast the STRING back to INT - this is a non-injective (narrowing)
cast
+ relBuilder.push(stringKeyProject)
+ val narrowedExprs = List(
+ rexBuilder.makeCast(intType, relBuilder.field(0)), // CAST(string_id AS
INT) - NOT injective
+ relBuilder.field(1) // name
+ )
+ val narrowedProject = relBuilder.project(narrowedExprs).build()
+
+ // The key is LOST because STRING->INT is not injective
+ assertEquals(uniqueKeys(), mq.getUniqueKeys(narrowedProject).toSet)
+ }
+
@Test
def testGetUniqueKeysOnFilter(): Unit = {
assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(logicalFilter).toSet)
@@ -133,7 +225,8 @@ class FlinkRelMdUniqueKeysTest extends
FlinkRelMdHandlerTestBase {
)
val rowType = relBuilder.project(exprs).build().getRowType
val calc2 = createLogicalCalc(studentLogicalScan, rowType, exprs,
List(expr))
- assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(calc2).toSet)
+ // INT -> BIGINT is an injective cast, so position 2 is also a unique key
+ assertEquals(uniqueKeys(Array(1), Array(2)), mq.getUniqueKeys(calc2).toSet)
assertEquals(uniqueKeys(Array(1), Array(2)), mq.getUniqueKeys(calc2,
true).toSet)
}
diff --git
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala
index 40e38dda161..3b7524e1bd3 100644
---
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala
+++
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala
@@ -26,6 +26,7 @@ import com.google.common.collect.{ImmutableList, ImmutableSet}
import org.apache.calcite.prepare.CalciteCatalogReader
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.hint.RelHint
+import org.apache.calcite.sql.`type`.SqlTypeName.VARCHAR
import org.apache.calcite.sql.fun.SqlStdOperatorTable.{EQUALS, LESS_THAN}
import org.apache.calcite.util.ImmutableBitSet
import org.junit.jupiter.api.Assertions._
@@ -88,12 +89,81 @@ class FlinkRelMdUpsertKeysTest extends
FlinkRelMdHandlerTestBase {
val exprs = List(
relBuilder.call(EQUALS, relBuilder.field(0), relBuilder.literal(1)),
relBuilder.field(0),
+ // INT -> BIGINT is an injective cast, so position 2 is now also an
upsert key
rexBuilder.makeCast(longType, relBuilder.field(0)),
rexBuilder.makeCast(intType, relBuilder.field(0)),
relBuilder.field(1)
)
val project1 = relBuilder.project(exprs).build()
- assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(project1).toSet)
+
+ assertEquals(toBitSet(Array(1), Array(2)),
mq.getUpsertKeys(project1).toSet)
+ }
+
+ @Test
+ def testGetUpsertKeysOnProjectWithInjectiveCastToString(): Unit = {
+ // INT -> STRING is an injective cast (each distinct int maps to a
distinct string).
+ // When an upsert key column is cast this way, the key property is
preserved.
+ relBuilder.push(studentLogicalScan)
+
+ val stringType = typeFactory.createSqlType(VARCHAR, 100)
+
+ // Project: CAST(id AS STRING), name
+ val exprs = List(
+ rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS
STRING)
+ relBuilder.field(1) // name
+ )
+ val project = relBuilder.project(exprs).build()
+
+ // The casted id at position 0 should still be recognized as upsert key
+ assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(project).toSet)
+ }
+
+ @Test
+ def testGetUpsertKeysOnProjectWithMultipleKeyReferences(): Unit = {
+ // When the same upsert key column appears multiple times in a projection
+ // (either raw or via injective cast), all references are recognized as
keys.
+ relBuilder.push(studentLogicalScan)
+
+ val stringType = typeFactory.createSqlType(VARCHAR, 100)
+
+ // Project: CAST(id AS STRING), id, name
+ val exprs = List(
+ rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS
STRING) - injective
+ relBuilder.field(0), // id (raw reference)
+ relBuilder.field(1) // name
+ )
+ val project = relBuilder.project(exprs).build()
+
+ // Both position 0 (STRING cast of id) and position 1 (raw id) are upsert
keys
+ assertEquals(toBitSet(Array(0), Array(1)), mq.getUpsertKeys(project).toSet)
+ }
+
+ @Test
+ def testGetUpsertKeysOnProjectNonInjectiveCastLosesKey(): Unit = {
+ // STRING -> INT is NOT an injective cast (e.g., "1" and "01" both become
1).
+ // When an upsert key is cast this way, the key property is lost.
+ relBuilder.push(studentLogicalScan)
+
+ val stringType = typeFactory.createSqlType(VARCHAR, 100)
+
+ // First, project id as STRING to simulate a STRING key column
+ val stringKeyExprs = List(
+ rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS
STRING)
+ relBuilder.field(1) // name
+ )
+ val stringKeyProject = relBuilder.project(stringKeyExprs).build()
+ assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(stringKeyProject).toSet)
+
+ // Now cast the STRING back to INT - this is a non-injective cast
+ relBuilder.push(stringKeyProject)
+ val narrowedExprs = List(
+ rexBuilder.makeCast(intType, relBuilder.field(0)), // CAST(string_id AS
INT) - NOT injective
+ relBuilder.field(1) // name
+ )
+ val narrowedProject = relBuilder.project(narrowedExprs).build()
+
+ // The key is LOST because STRING->INT is not injective
+ assertEquals(toBitSet(), mq.getUpsertKeys(narrowedProject).toSet)
}
@Test
@@ -133,7 +203,8 @@ class FlinkRelMdUpsertKeysTest extends
FlinkRelMdHandlerTestBase {
)
val rowType = relBuilder.project(exprs).build().getRowType
val calc2 = createLogicalCalc(studentLogicalScan, rowType, exprs,
List(expr))
- assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(calc2).toSet)
+ // INT -> BIGINT is an injective cast, so position 2 is now also an upsert
key
+ assertEquals(toBitSet(Array(1), Array(2)), mq.getUpsertKeys(calc2).toSet)
}
@Test
diff --git
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableSinkTest.scala
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableSinkTest.scala
index 6e0b5323964..ef9776caedc 100644
---
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableSinkTest.scala
+++
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableSinkTest.scala
@@ -947,6 +947,55 @@ class TableSinkTest extends TableTestBase {
util.verifyRelPlan(stmtSet, ExplainDetail.CHANGELOG_MODE)
}
+ @Test
+ def testInjectiveCastPreservesUpsertKey(): Unit = {
+ // Aggregation produces upsert stream with key (a).
+ // Sink expects STRING primary key.
+ // CAST(INT AS STRING) is injective, so the upsert key is preserved - no
materializer needed.
+ util.tableEnv.executeSql("""
+ |CREATE TABLE sink_agg_with_string_pk (
+ | id STRING NOT NULL,
+ | cnt BIGINT,
+ | PRIMARY KEY (id) NOT ENFORCED
+ |) WITH (
+ | 'connector' = 'values',
+ | 'sink-insert-only' = 'false'
+ |)
+ |""".stripMargin)
+ val stmtSet = util.tableEnv.createStatementSet()
+ // GROUP BY a produces upsert key (a), then CAST(a AS STRING) preserves it
+ stmtSet.addInsertSql(
+ "INSERT INTO sink_agg_with_string_pk SELECT CAST(a AS STRING), COUNT(*)
FROM MyTable GROUP BY a")
+ // The plan should NOT contain upsertMaterialize=[true] because
INT->STRING is injective
+ util.verifyRelPlan(stmtSet, ExplainDetail.CHANGELOG_MODE)
+ }
+
+ @Test
+ def testNonInjectiveCastLosesUpsertKey(): Unit = {
+ // Aggregation produces upsert stream with key (c) which is STRING.
+ // Sink expects INT primary key.
+ // CAST(STRING AS INT) is NOT injective ("1" and "01" both become 1), so
key is lost.
+ util.tableEnv.executeSql("""
+ |CREATE TABLE sink_agg_with_int_pk (
+ | id INT NOT NULL,
+ | cnt BIGINT,
+ | PRIMARY KEY (id) NOT ENFORCED
+ |) WITH (
+ | 'connector' = 'values',
+ | 'sink-insert-only' = 'false'
+ |)
+ |""".stripMargin)
+ // GROUP BY c (STRING) produces upsert key (c), then CAST(c AS INT) loses
it
+ util.verifyExplainInsert(
+ """
+ |INSERT INTO sink_agg_with_int_pk
+ |SELECT CAST(c AS INT), COUNT(*) FROM MyTable GROUP BY c
+ |ON CONFLICT DO DEDUPLICATE
+ |""".stripMargin,
+ ExplainDetail.CHANGELOG_MODE
+ )
+ }
+
}
/** tests table factory use ParallelSourceFunction which support parallelism
by env */
diff --git
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/ChangelogSourceITCase.scala
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/ChangelogSourceITCase.scala
index d2926b54325..f593875a5a3 100644
---
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/ChangelogSourceITCase.scala
+++
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/ChangelogSourceITCase.scala
@@ -46,8 +46,8 @@ import scala.collection.JavaConversions._
class ChangelogSourceITCase(
sourceMode: SourceMode,
miniBatch: MiniBatchMode,
- state: StateBackendMode)
- extends StreamingWithMiniBatchTestBase(miniBatch, state) {
+ state: StateBackendMode
+) extends StreamingWithMiniBatchTestBase(miniBatch, state) {
@RegisterExtension private val _: EachCallbackWrapper[LegacyRowExtension] =
new EachCallbackWrapper[LegacyRowExtension](new LegacyRowExtension)
@@ -192,12 +192,16 @@ class ChangelogSourceITCase(
| 'sink-insert-only' = 'false'
|)
|""".stripMargin
+ // Note: balance2 is computed as balance*2 which results in a higher
precision DECIMAL type.
+ // Since the sink's balance column has fixed precision, this requires ON
CONFLICT handling
+ // because narrowing DECIMAL casts (e.g., DECIMAL(28,2) -> DECIMAL(18,2))
are not injective.
val dml =
s"""
|INSERT INTO user_sink
|SELECT balance2, count(*), max(email)
|FROM users
|GROUP BY balance2
+ |ON CONFLICT DO DEDUPLICATE
|""".stripMargin
tEnv.executeSql(sinkDDL)
tEnv.executeSql(dml).await()