This is an automated email from the ASF dual-hosted git repository.

twalthr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 3832967cb73 [FLINK-39088][table] Fix upsert key preservation by 
introducing injective cast checks for CAST
3832967cb73 is described below

commit 3832967cb736055b369be037c6394686e4e3b3ab
Author: Gustavo de Morais <[email protected]>
AuthorDate: Thu Feb 19 15:27:44 2026 +0100

    [FLINK-39088][table] Fix upsert key preservation by introducing injective 
cast checks for CAST
    
    This closes #27603.
---
 .../types/logical/utils/LogicalTypeCasts.java      | 266 +++++++++++++++++++--
 .../flink/table/types/LogicalTypeCastsTest.java    | 219 +++++++++++++++++
 .../plan/metadata/FlinkRelMdUniqueKeys.scala       |  15 +-
 .../planner/plan/stream/sql/TableSinkTest.xml      |  48 ++++
 .../plan/metadata/FlinkRelMdUniqueKeysTest.scala   |  97 +++++++-
 .../plan/metadata/FlinkRelMdUpsertKeysTest.scala   |  75 +++++-
 .../planner/plan/stream/sql/TableSinkTest.scala    |  49 ++++
 .../runtime/stream/sql/ChangelogSourceITCase.scala |   8 +-
 8 files changed, 750 insertions(+), 27 deletions(-)

diff --git 
a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeCasts.java
 
b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeCasts.java
index 310bc71e29b..307f47f76d3 100644
--- 
a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeCasts.java
+++ 
b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeCasts.java
@@ -31,6 +31,7 @@ import org.apache.flink.table.types.logical.VarBinaryType;
 import org.apache.flink.table.types.logical.VarCharType;
 import org.apache.flink.table.types.logical.YearMonthIntervalType;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -38,6 +39,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.function.BiFunction;
+import java.util.function.BiPredicate;
 import java.util.stream.Collectors;
 
 import static 
org.apache.flink.table.types.logical.LogicalTypeFamily.BINARY_STRING;
@@ -79,6 +81,8 @@ import static 
org.apache.flink.table.types.logical.LogicalTypeRoot.VARCHAR;
 import static 
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getDayPrecision;
 import static 
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getFractionalPrecision;
 import static 
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getLength;
+import static 
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getPrecision;
+import static 
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getScale;
 import static 
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getYearPrecision;
 import static 
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.isSingleFieldInterval;
 
@@ -110,116 +114,203 @@ public final class LogicalTypeCasts {
 
     private static final Map<LogicalTypeRoot, Set<LogicalTypeRoot>> 
explicitCastingRules;
 
+    private static final Map<LogicalTypeRoot, List<InjectiveRule>> 
injectiveRules;
+
+    // Types with deterministic, unique string representations (for injective 
casts to STRING)
+    private static final LogicalTypeRoot[] STRING_INJECTIVE_SOURCES = {
+        TINYINT,
+        SMALLINT,
+        INTEGER,
+        BIGINT,
+        FLOAT,
+        DOUBLE,
+        BOOLEAN,
+        DATE,
+        TIME_WITHOUT_TIME_ZONE,
+        TIMESTAMP_WITHOUT_TIME_ZONE,
+        TIMESTAMP_WITH_LOCAL_TIME_ZONE
+    };
+
+    // ----- Injective cast conditions -----
+
+    /** Injective when the target length can hold any value of the source 
length. */
+    private static final BiPredicate<LogicalType, LogicalType> 
WHEN_LENGTH_FITS =
+            (source, target) -> getLength(target) >= getLength(source);
+
+    /** Injective when the target length can hold the source type's max string 
representation. */
+    private static final BiPredicate<LogicalType, LogicalType> 
WHEN_MAX_CHAR_LENGTH_FITS =
+            (source, target) -> getLength(target) >= 
maxStringRepresentationLength(source);
+
+    /** Injective when source and target share identical precision. */
+    private static final BiPredicate<LogicalType, LogicalType> 
WHEN_PRECISION_MATCHES =
+            (source, target) -> getPrecision(source) == getPrecision(target);
+
+    /** Injective when source and target share identical precision and scale 
(DECIMAL). */
+    private static final BiPredicate<LogicalType, LogicalType> 
WHEN_PRECISION_AND_SCALE_MATCH =
+            (source, target) ->
+                    getPrecision(source) == getPrecision(target)
+                            && getScale(source) == getScale(target);
+
     static {
         implicitCastingRules = new HashMap<>();
         explicitCastingRules = new HashMap<>();
+        injectiveRules = new HashMap<>();
 
-        // identity casts
-
+        // Identity casts: all types can be implicitly cast to themselves.
+        // Injective identity for parameterized types is declared per-type 
with conditions.
         for (LogicalTypeRoot typeRoot : allTypes()) {
             castTo(typeRoot).implicitFrom(typeRoot).build();
         }
 
-        // cast specification
+        // 
-----------------------------------------------------------------------------------------
+        // Character string types
+        // 
-----------------------------------------------------------------------------------------
 
         castTo(CHAR)
                 .implicitFrom(CHAR)
                 .explicitFromFamily(PREDEFINED, CONSTRUCTED)
                 .explicitFrom(RAW, NULL, STRUCTURED_TYPE)
+                .injectiveFrom(WHEN_LENGTH_FITS, CHAR)
+                .injectiveFrom(WHEN_MAX_CHAR_LENGTH_FITS, 
STRING_INJECTIVE_SOURCES)
                 .build();
 
         castTo(VARCHAR)
                 .implicitFromFamily(CHARACTER_STRING)
                 .explicitFromFamily(PREDEFINED, CONSTRUCTED)
                 .explicitFrom(RAW, NULL, STRUCTURED_TYPE)
+                .injectiveFrom(WHEN_LENGTH_FITS, CHAR, VARCHAR)
+                .injectiveFrom(WHEN_MAX_CHAR_LENGTH_FITS, 
STRING_INJECTIVE_SOURCES)
                 .build();
 
-        castTo(BOOLEAN)
-                .implicitFrom(BOOLEAN)
-                .explicitFromFamily(CHARACTER_STRING, INTEGER_NUMERIC)
-                .build();
+        // 
-----------------------------------------------------------------------------------------
+        // Binary string types
+        // 
-----------------------------------------------------------------------------------------
 
         castTo(BINARY)
                 .implicitFrom(BINARY)
                 .explicitFromFamily(CHARACTER_STRING)
-                .explicitFrom(VARBINARY)
-                .explicitFrom(RAW)
+                .explicitFrom(VARBINARY, RAW)
+                .injectiveFrom(WHEN_LENGTH_FITS, BINARY)
                 .build();
 
         castTo(VARBINARY)
                 .implicitFromFamily(BINARY_STRING)
                 .explicitFromFamily(CHARACTER_STRING)
-                .explicitFrom(BINARY)
-                .explicitFrom(RAW)
+                .explicitFrom(BINARY, RAW)
+                .injectiveFrom(WHEN_LENGTH_FITS, BINARY, VARBINARY)
                 .build();
 
-        castTo(DECIMAL)
-                .implicitFromFamily(NUMERIC)
-                .explicitFromFamily(CHARACTER_STRING, INTERVAL)
-                .explicitFrom(BOOLEAN, TIMESTAMP_WITHOUT_TIME_ZONE, 
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
-                .build();
+        // 
-----------------------------------------------------------------------------------------
+        // Exact numeric types
+        // 
-----------------------------------------------------------------------------------------
 
         castTo(TINYINT)
                 .implicitFrom(TINYINT)
                 .explicitFromFamily(NUMERIC, CHARACTER_STRING, INTERVAL)
                 .explicitFrom(BOOLEAN, TIMESTAMP_WITHOUT_TIME_ZONE, 
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
+                .injectiveFrom(TINYINT)
                 .build();
 
         castTo(SMALLINT)
                 .implicitFrom(TINYINT, SMALLINT)
                 .explicitFromFamily(NUMERIC, CHARACTER_STRING, INTERVAL)
                 .explicitFrom(BOOLEAN, TIMESTAMP_WITHOUT_TIME_ZONE, 
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
+                .injectiveFrom(TINYINT, SMALLINT)
                 .build();
 
         castTo(INTEGER)
                 .implicitFrom(TINYINT, SMALLINT, INTEGER)
                 .explicitFromFamily(NUMERIC, CHARACTER_STRING, INTERVAL)
                 .explicitFrom(BOOLEAN, TIMESTAMP_WITHOUT_TIME_ZONE, 
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
+                .injectiveFrom(TINYINT, SMALLINT, INTEGER)
                 .build();
 
         castTo(BIGINT)
                 .implicitFrom(TINYINT, SMALLINT, INTEGER, BIGINT)
                 .explicitFromFamily(NUMERIC, CHARACTER_STRING, INTERVAL)
                 .explicitFrom(BOOLEAN, TIMESTAMP_WITHOUT_TIME_ZONE, 
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
+                .injectiveFrom(TINYINT, SMALLINT, INTEGER, BIGINT)
                 .build();
 
+        castTo(DECIMAL)
+                .implicitFromFamily(NUMERIC)
+                .explicitFromFamily(CHARACTER_STRING, INTERVAL)
+                .explicitFrom(BOOLEAN, TIMESTAMP_WITHOUT_TIME_ZONE, 
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
+                .injectiveFrom(WHEN_PRECISION_AND_SCALE_MATCH, DECIMAL)
+                .build();
+
+        // 
-----------------------------------------------------------------------------------------
+        // Approximate numeric types
+        // 
-----------------------------------------------------------------------------------------
+
         castTo(FLOAT)
                 .implicitFrom(TINYINT, SMALLINT, INTEGER, BIGINT, FLOAT, 
DECIMAL)
                 .explicitFromFamily(NUMERIC, CHARACTER_STRING)
                 .explicitFrom(BOOLEAN, TIMESTAMP_WITHOUT_TIME_ZONE, 
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
+                .injectiveFrom(FLOAT)
                 .build();
 
         castTo(DOUBLE)
                 .implicitFromFamily(NUMERIC)
                 .explicitFromFamily(CHARACTER_STRING)
                 .explicitFrom(BOOLEAN, TIMESTAMP_WITHOUT_TIME_ZONE, 
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
+                .injectiveFrom(DOUBLE)
                 .build();
 
+        // 
-----------------------------------------------------------------------------------------
+        // Boolean type
+        // 
-----------------------------------------------------------------------------------------
+
+        castTo(BOOLEAN)
+                .implicitFrom(BOOLEAN)
+                .explicitFromFamily(CHARACTER_STRING, INTEGER_NUMERIC)
+                .injectiveFrom(BOOLEAN)
+                .build();
+
+        // 
-----------------------------------------------------------------------------------------
+        // Date and time types
+        // 
-----------------------------------------------------------------------------------------
+
         castTo(DATE)
                 .implicitFrom(DATE, TIMESTAMP_WITHOUT_TIME_ZONE)
                 .explicitFromFamily(TIMESTAMP, CHARACTER_STRING)
+                .injectiveFrom(DATE)
                 .build();
 
         castTo(TIME_WITHOUT_TIME_ZONE)
                 .implicitFrom(TIME_WITHOUT_TIME_ZONE, 
TIMESTAMP_WITHOUT_TIME_ZONE)
                 .explicitFromFamily(TIME, TIMESTAMP, CHARACTER_STRING)
+                .injectiveFrom(WHEN_PRECISION_MATCHES, TIME_WITHOUT_TIME_ZONE)
                 .build();
 
         castTo(TIMESTAMP_WITHOUT_TIME_ZONE)
                 .implicitFrom(TIMESTAMP_WITHOUT_TIME_ZONE, 
TIMESTAMP_WITH_LOCAL_TIME_ZONE)
                 .explicitFromFamily(DATETIME, CHARACTER_STRING, NUMERIC)
+                .injectiveFrom(
+                        WHEN_PRECISION_MATCHES,
+                        TIMESTAMP_WITHOUT_TIME_ZONE,
+                        TIMESTAMP_WITH_LOCAL_TIME_ZONE)
                 .build();
 
         castTo(TIMESTAMP_WITH_TIME_ZONE)
                 .implicitFrom(TIMESTAMP_WITH_TIME_ZONE)
                 .explicitFromFamily(DATETIME, CHARACTER_STRING)
+                .injectiveFrom(WHEN_PRECISION_MATCHES, 
TIMESTAMP_WITH_TIME_ZONE)
                 .build();
 
         castTo(TIMESTAMP_WITH_LOCAL_TIME_ZONE)
                 .implicitFrom(TIMESTAMP_WITH_LOCAL_TIME_ZONE, 
TIMESTAMP_WITHOUT_TIME_ZONE)
                 .explicitFromFamily(DATETIME, CHARACTER_STRING, NUMERIC)
+                .injectiveFrom(
+                        WHEN_PRECISION_MATCHES,
+                        TIMESTAMP_WITH_LOCAL_TIME_ZONE,
+                        TIMESTAMP_WITHOUT_TIME_ZONE)
                 .build();
 
+        // 
-----------------------------------------------------------------------------------------
+        // Interval types
+        // 
-----------------------------------------------------------------------------------------
+
         castTo(INTERVAL_YEAR_MONTH)
                 .implicitFrom(INTERVAL_YEAR_MONTH)
                 .explicitFromFamily(EXACT_NUMERIC, CHARACTER_STRING)
@@ -293,6 +384,118 @@ public final class LogicalTypeCasts {
         return supportsCasting(sourceType, targetType, true);
     }
 
+    /**
+     * Returns whether the cast from source type to target type is injective
+     * (uniqueness-preserving).
+     *
+     * <p>An injective cast guarantees that every distinct input value maps to 
a distinct output
+     * value. This property is useful for upsert key tracking through 
projections: if a key column
+     * is cast using an injective conversion, the uniqueness of the key is 
preserved.
+     *
+     * <p>Injective casts are explicitly defined in the casting rules, 
separate from implicit casts.
+     * Not all implicit casts are injective (e.g., TIMESTAMP → DATE loses time 
information).
+     *
+     * <p>For constructed types (ROW), this method recursively checks if all 
field casts are
+     * injective.
+     *
+     * @param sourceType the source type
+     * @param targetType the target type
+     * @return {@code true} if the cast preserves uniqueness
+     */
+    public static boolean supportsInjectiveCast(
+            final LogicalType sourceType, final LogicalType targetType) {
+        final LogicalTypeRoot sourceRoot = sourceType.getTypeRoot();
+        final LogicalTypeRoot targetRoot = targetType.getTypeRoot();
+
+        // Handle DISTINCT types by unwrapping
+        if (sourceRoot == DISTINCT_TYPE) {
+            return supportsInjectiveCast(((DistinctType) 
sourceType).getSourceType(), targetType);
+        }
+        if (targetRoot == DISTINCT_TYPE) {
+            return supportsInjectiveCast(sourceType, ((DistinctType) 
targetType).getSourceType());
+        }
+
+        // Handle NULL type
+        if (sourceRoot == NULL) {
+            return true;
+        }
+
+        // Handle constructed types (ROW, STRUCTURED_TYPE) with recursive 
field checks
+        final boolean isSourceConstructed = sourceRoot == ROW || sourceRoot == 
STRUCTURED_TYPE;
+        final boolean isTargetConstructed = targetRoot == ROW || targetRoot == 
STRUCTURED_TYPE;
+        if (isSourceConstructed && isTargetConstructed) {
+            return supportsInjectiveConstructedCast(sourceType, targetType);
+        }
+
+        // Check declarative injective rules
+        final List<InjectiveRule> rules = injectiveRules.get(targetRoot);
+        if (rules == null) {
+            return false;
+        }
+        for (final InjectiveRule rule : rules) {
+            if (rule.test(sourceType, targetType)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    private static boolean supportsInjectiveConstructedCast(
+            final LogicalType sourceType, final LogicalType targetType) {
+        final List<LogicalType> sourceChildren = sourceType.getChildren();
+        final List<LogicalType> targetChildren = targetType.getChildren();
+
+        if (sourceChildren.size() != targetChildren.size()) {
+            return false;
+        }
+
+        for (int i = 0; i < sourceChildren.size(); i++) {
+            if (!supportsInjectiveCast(sourceChildren.get(i), 
targetChildren.get(i))) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /**
+     * Returns the maximum number of characters needed to represent any value 
of the given type as a
+     * string, or -1 if the type does not have a bounded deterministic string 
representation that
+     * qualifies for injective casts.
+     */
+    private static int maxStringRepresentationLength(final LogicalType type) {
+        switch (type.getTypeRoot()) {
+            case BOOLEAN:
+                return 5; // "false"
+            case TINYINT:
+                return 4; // "-128"
+            case SMALLINT:
+                return 6; // "-32768"
+            case INTEGER:
+                return 11; // "-2147483648"
+            case BIGINT:
+                return 20; // "-9223372036854775808"
+            case FLOAT:
+                return 15; // e.g. "-3.4028235E38"
+            case DOUBLE:
+                return 24; // e.g. "-2.2250738585072014E-308"
+            case DATE:
+                return 10; // "9999-12-31"
+            case TIME_WITHOUT_TIME_ZONE:
+                {
+                    final int p = getPrecision(type);
+                    return p > 0 ? 9 + p : 8; // "HH:MM:SS" + optional 
".fff..."
+                }
+            case TIMESTAMP_WITHOUT_TIME_ZONE:
+            case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
+                {
+                    final int p = getPrecision(type);
+                    return p > 0 ? 20 + p : 19; // "YYYY-MM-DD HH:MM:SS" + 
optional ".fff..."
+                }
+            default:
+                return -1;
+        }
+    }
+
     /**
      * Returns whether the source type can be reinterpreted as the target type.
      *
@@ -484,11 +687,30 @@ public final class LogicalTypeCasts {
         return LogicalTypeRoot.values();
     }
 
+    /** A declarative rule describing when a cast from certain source types is 
injective. */
+    private static final class InjectiveRule {
+
+        private final Set<LogicalTypeRoot> sourceRoots;
+        private final BiPredicate<LogicalType, LogicalType> condition;
+
+        InjectiveRule(
+                Set<LogicalTypeRoot> sourceRoots, BiPredicate<LogicalType, 
LogicalType> condition) {
+            this.sourceRoots = sourceRoots;
+            this.condition = condition;
+        }
+
+        boolean test(LogicalType sourceType, LogicalType targetType) {
+            return sourceRoots.contains(sourceType.getTypeRoot())
+                    && condition.test(sourceType, targetType);
+        }
+    }
+
     private static class CastingRuleBuilder {
 
         private final LogicalTypeRoot targetType;
         private final Set<LogicalTypeRoot> implicitSourceTypes = new 
HashSet<>();
         private final Set<LogicalTypeRoot> explicitSourceTypes = new 
HashSet<>();
+        private final List<InjectiveRule> injectiveRuleList = new 
ArrayList<>();
 
         CastingRuleBuilder(LogicalTypeRoot targetType) {
             this.targetType = targetType;
@@ -526,9 +748,21 @@ public final class LogicalTypeCasts {
             return this;
         }
 
+        CastingRuleBuilder injectiveFrom(LogicalTypeRoot... sourceTypes) {
+            return injectiveFrom((s, t) -> true, sourceTypes);
+        }
+
+        CastingRuleBuilder injectiveFrom(
+                BiPredicate<LogicalType, LogicalType> condition, 
LogicalTypeRoot... sourceTypes) {
+            injectiveRuleList.add(
+                    new InjectiveRule(new 
HashSet<>(Arrays.asList(sourceTypes)), condition));
+            return this;
+        }
+
         void build() {
             implicitCastingRules.put(targetType, implicitSourceTypes);
             explicitCastingRules.put(targetType, explicitSourceTypes);
+            injectiveRules.put(targetType, injectiveRuleList);
         }
     }
 
diff --git 
a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastsTest.java
 
b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastsTest.java
index 383874a29c5..7428e5adc10 100644
--- 
a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastsTest.java
+++ 
b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastsTest.java
@@ -26,7 +26,10 @@ import org.apache.flink.table.types.logical.ArrayType;
 import org.apache.flink.table.types.logical.BigIntType;
 import org.apache.flink.table.types.logical.BinaryType;
 import org.apache.flink.table.types.logical.BooleanType;
+import org.apache.flink.table.types.logical.CharType;
+import org.apache.flink.table.types.logical.DateType;
 import org.apache.flink.table.types.logical.DecimalType;
+import org.apache.flink.table.types.logical.DoubleType;
 import org.apache.flink.table.types.logical.FloatType;
 import org.apache.flink.table.types.logical.IntType;
 import org.apache.flink.table.types.logical.LocalZonedTimestampType;
@@ -38,10 +41,13 @@ import 
org.apache.flink.table.types.logical.RowType.RowField;
 import org.apache.flink.table.types.logical.SmallIntType;
 import org.apache.flink.table.types.logical.StructuredType;
 import org.apache.flink.table.types.logical.StructuredType.StructuredAttribute;
+import org.apache.flink.table.types.logical.TimeType;
 import org.apache.flink.table.types.logical.TimestampType;
 import org.apache.flink.table.types.logical.TinyIntType;
+import org.apache.flink.table.types.logical.VarBinaryType;
 import org.apache.flink.table.types.logical.VarCharType;
 import org.apache.flink.table.types.logical.YearMonthIntervalType;
+import org.apache.flink.table.types.logical.ZonedTimestampType;
 import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;
 
 import org.junit.jupiter.api.parallel.Execution;
@@ -273,4 +279,217 @@ class LogicalTypeCastsTest {
                 .as("Supports explicit casting")
                 .isEqualTo(supportsExplicit);
     }
+
+    /**
+     * Test data for injective cast tests. Each argument contains: 
(sourceType, targetType,
+     * expectedInjective).
+     */
+    private static Stream<Arguments> injectiveCastTestData() {
+        return Stream.of(
+                // Integer widenings are injective
+                Arguments.of(new SmallIntType(), new BigIntType(), true),
+                Arguments.of(new IntType(), new BigIntType(), true),
+                Arguments.of(new TinyIntType(), new IntType(), true),
+                Arguments.of(new TinyIntType(), new SmallIntType(), true),
+
+                // Explicit casts to STRING from integer types are injective
+                Arguments.of(new TinyIntType(), VarCharType.STRING_TYPE, true),
+                Arguments.of(new SmallIntType(), VarCharType.STRING_TYPE, 
true),
+                Arguments.of(new IntType(), VarCharType.STRING_TYPE, true),
+                Arguments.of(new BigIntType(), VarCharType.STRING_TYPE, true),
+
+                // FLOAT/DOUBLE to STRING are injective
+                Arguments.of(new FloatType(), VarCharType.STRING_TYPE, true),
+                Arguments.of(new DoubleType(), VarCharType.STRING_TYPE, true),
+
+                // Explicit casts to STRING from boolean are injective
+                Arguments.of(new BooleanType(), VarCharType.STRING_TYPE, true),
+
+                // Explicit casts to STRING from date/time types are injective
+                Arguments.of(new DateType(), VarCharType.STRING_TYPE, true),
+                Arguments.of(new TimeType(3), VarCharType.STRING_TYPE, true),
+                Arguments.of(new TimestampType(3), VarCharType.STRING_TYPE, 
true),
+                Arguments.of(new TimestampType(9), VarCharType.STRING_TYPE, 
true),
+                Arguments.of(new LocalZonedTimestampType(3), 
VarCharType.STRING_TYPE, true),
+
+                // Casts to CHAR are injective if the target length is 
sufficient
+                Arguments.of(new IntType(), new CharType(100), true),
+                Arguments.of(new BigIntType(), new CharType(100), true),
+                Arguments.of(new IntType(), new CharType(11), true), // exact 
minimum
+                Arguments.of(new IntType(), new CharType(3), false), // too 
short for "-2147483648"
+                Arguments.of(new BigIntType(), new CharType(20), true), // 
exact minimum
+                Arguments.of(new BigIntType(), new CharType(19), false), // 
too short
+                Arguments.of(new BooleanType(), new CharType(5), true), // 
exact minimum for "false"
+                Arguments.of(new BooleanType(), new CharType(4), false), // 
too short
+
+                // CHAR → VARCHAR widening is injective
+                Arguments.of(new CharType(10), VarCharType.STRING_TYPE, true),
+
+                // BINARY → VARBINARY widening is injective
+                Arguments.of(new BinaryType(10), new VarBinaryType(100), true),
+
+                // Narrowing casts are NOT injective (lossy)
+                Arguments.of(VarCharType.STRING_TYPE, new IntType(), false),
+                Arguments.of(new BigIntType(), new IntType(), false),
+                Arguments.of(new DoubleType(), new FloatType(), false),
+
+                // TIMESTAMP → DATE is NOT injective (loses time-of-day 
information)
+                // even though it is an implicit cast
+                Arguments.of(new TimestampType(3), new DateType(), false),
+
+                // DECIMAL to STRING is NOT considered injective
+                Arguments.of(new DecimalType(10, 2), VarCharType.STRING_TYPE, 
false),
+
+                // BYTES to STRING is NOT injective (invalid UTF-8 sequences 
collapse)
+                Arguments.of(new VarBinaryType(100), VarCharType.STRING_TYPE, 
false),
+                Arguments.of(new BinaryType(100), VarCharType.STRING_TYPE, 
false),
+
+                // TIMESTAMP_WITH_TIME_ZONE to STRING is NOT injective
+                // (theory: two timestamps with different zones could produce 
same string
+                // depending on the implementation)
+                Arguments.of(new ZonedTimestampType(3), 
VarCharType.STRING_TYPE, false),
+
+                // INT → FLOAT/DOUBLE are theoretically injective
+                // However, we decided not to support decimal, float and double
+                // injective conversions at first since it's not a practical 
use case
+                Arguments.of(new IntType(), new FloatType(), false),
+                Arguments.of(new IntType(), new DoubleType(), false),
+
+                // STRING → BOOLEAN is NOT injective
+                Arguments.of(VarCharType.STRING_TYPE, new BooleanType(), 
false),
+
+                // DOUBLE → INT is NOT injective
+                Arguments.of(new DoubleType(), new IntType(), false),
+
+                // DECIMAL → DECIMAL: only identity casts are injective
+                // (changing precision/scale can lose data in various ways)
+                Arguments.of(new DecimalType(10, 2), new DecimalType(10, 2), 
true), // identity
+                Arguments.of(new DecimalType(10, 2), new DecimalType(20, 4), 
false), // not identity
+                Arguments.of(
+                        new DecimalType(10, 2), new DecimalType(15, 2), 
false), // precision change
+                Arguments.of(new DecimalType(10, 2), new DecimalType(10, 4), 
false), // scale change
+                Arguments.of(new DecimalType(20, 4), new DecimalType(10, 2), 
false), // narrowing
+                Arguments.of(
+                        new DecimalType(10, 4), new DecimalType(10, 2), 
false), // scale narrowing
+
+                // Timestamp conversions between variants are injective
+                Arguments.of(new TimestampType(3), new 
LocalZonedTimestampType(3), true),
+                Arguments.of(new LocalZonedTimestampType(3), new 
TimestampType(3), true),
+
+                // ROW types with injective field casts
+                Arguments.of(
+                        new RowType(
+                                Arrays.asList(
+                                        new RowField("id", new IntType()),
+                                        new RowField("name", 
VarCharType.STRING_TYPE))),
+                        new RowType(
+                                Arrays.asList(
+                                        new RowField("id", new BigIntType()),
+                                        new RowField("name", 
VarCharType.STRING_TYPE))),
+                        true),
+
+                // ROW types with non-injective field cast (TIMESTAMP → DATE)
+                Arguments.of(
+                        new RowType(
+                                Arrays.asList(
+                                        new RowField("id", new IntType()),
+                                        new RowField("ts", new 
TimestampType(3)))),
+                        new RowType(
+                                Arrays.asList(
+                                        new RowField("id", new IntType()),
+                                        new RowField("ts", new DateType()))),
+                        false),
+
+                // ROW types with mixed casts (one injective, one not)
+                Arguments.of(
+                        new RowType(
+                                Arrays.asList(
+                                        new RowField("id", new IntType()),
+                                        new RowField("val", new 
DoubleType()))),
+                        new RowType(
+                                Arrays.asList(
+                                        new RowField("id", 
VarCharType.STRING_TYPE),
+                                        new RowField("val", new IntType()))),
+                        false),
+
+                // ---- Parameter-aware injective cast checks ----
+
+                // CHAR length checks
+                Arguments.of(new CharType(10), new CharType(10), true), // 
identity
+                Arguments.of(new CharType(10), new CharType(20), true), // 
widening
+                Arguments.of(new CharType(20), new CharType(10), false), // 
truncation
+
+                // VARCHAR length checks
+                Arguments.of(new VarCharType(10), new VarCharType(100), true), 
// widening
+                Arguments.of(new VarCharType(100), new VarCharType(10), 
false), // truncation
+
+                // CHAR → VARCHAR length checks
+                Arguments.of(new CharType(10), new VarCharType(20), true), // 
widening
+                Arguments.of(new CharType(20), new VarCharType(10), false), // 
truncation
+
+                // BINARY length checks
+                Arguments.of(new BinaryType(10), new BinaryType(10), true), // 
identity
+                Arguments.of(new BinaryType(10), new BinaryType(20), true), // 
widening
+                Arguments.of(new BinaryType(20), new BinaryType(10), false), 
// truncation
+
+                // VARBINARY length checks
+                Arguments.of(new VarBinaryType(10), new VarBinaryType(100), 
true), // widening
+                Arguments.of(new VarBinaryType(100), new VarBinaryType(10), 
false), // truncation
+
+                // BINARY → VARBINARY length checks
+                Arguments.of(new BinaryType(10), new VarBinaryType(20), true), 
// widening
+                Arguments.of(new BinaryType(20), new VarBinaryType(10), 
false), // truncation
+
+                // TIMESTAMP precision checks (identity only)
+                Arguments.of(new TimestampType(3), new TimestampType(3), 
true), // identity
+                Arguments.of(
+                        new TimestampType(3), new TimestampType(6), false), // 
widening rejected
+                Arguments.of(
+                        new TimestampType(6), new TimestampType(3), false), // 
narrowing rejected
+
+                // TIMESTAMP ↔ TIMESTAMP_LTZ precision checks (identity only)
+                Arguments.of(
+                        new TimestampType(3), new LocalZonedTimestampType(6), 
false), // rejected
+                Arguments.of(
+                        new LocalZonedTimestampType(6), new TimestampType(3), 
false), // rejected
+
+                // TIMESTAMP_LTZ precision checks (identity only)
+                Arguments.of(
+                        new LocalZonedTimestampType(3),
+                        new LocalZonedTimestampType(3),
+                        true), // identity
+                Arguments.of(
+                        new LocalZonedTimestampType(3),
+                        new LocalZonedTimestampType(6),
+                        false), // rejected
+
+                // TIMESTAMP_TZ precision checks (identity only)
+                Arguments.of(
+                        new ZonedTimestampType(3), new ZonedTimestampType(3), 
true), // identity
+                Arguments.of(
+                        new ZonedTimestampType(3), new ZonedTimestampType(6), 
false), // rejected
+
+                // TIME precision checks (identity only)
+                Arguments.of(new TimeType(0), new TimeType(0), true), // 
identity
+                Arguments.of(new TimeType(0), new TimeType(3), false), // 
widening rejected
+                Arguments.of(new TimeType(3), new TimeType(0), false), // 
narrowing rejected
+
+                // Cross-family casts to VARCHAR whose target length is too
+                // short to hold every formatted source value
+                Arguments.of(new IntType(), new VarCharType(3), false), // too 
+                Arguments.of(new TimestampType(3), new VarCharType(5), false), 
// too short
+
+                // DOUBLE identity
+                Arguments.of(new DoubleType(), new DoubleType(), true));
+    }
+
+    @ParameterizedTest(name = "{index}: [From: {0}, To: {1}, Injective: {2}]")
+    @MethodSource("injectiveCastTestData")
+    void testInjectiveCast(
+            LogicalType sourceType, LogicalType targetType, boolean 
expectedInjective) {
+        assertThat(LogicalTypeCasts.supportsInjectiveCast(sourceType, 
targetType))
+                .as(
+                        "Cast from %s to %s should %s injective",
+                        sourceType, targetType, expectedInjective ? "be" : 
"not be")
+                .isEqualTo(expectedInjective);
+    }
 }
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala
index 046c131edf5..c445d8b2949 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala
@@ -173,9 +173,9 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
                 }
               case _ => // ignore
             }
-          // rename or cast
+          // rename or key-preserving cast (fidelity or injective)
           case a: RexCall
-              if (a.getKind.equals(SqlKind.AS) || isFidelityCast(a)) &&
+              if (a.getKind.equals(SqlKind.AS) || isKeyPreservingCast(a)) &&
                 a.getOperands.get(0).isInstanceOf[RexInputRef] =>
             
appendMapInToOutPos(a.getOperands.get(0).asInstanceOf[RexInputRef].getIndex, i)
           case _ => // ignore
@@ -214,14 +214,19 @@ class FlinkRelMdUniqueKeys private extends 
MetadataHandler[BuiltInMetadata.Uniqu
     projUniqueKeySet
   }
 
-  /** Whether the [[RexCall]] is a cast that doesn't lose any information. */
-  private def isFidelityCast(call: RexCall): Boolean = {
+  /**
+   * Whether the [[RexCall]] is a cast that preserves key uniqueness 
(injective cast).
+   *
+   * An injective cast is one where each distinct input maps to a distinct 
output, ensuring that
+   * unique keys remain unique after the cast.
+   */
+  private def isKeyPreservingCast(call: RexCall): Boolean = {
     if (call.getKind != SqlKind.CAST) {
       return false
     }
     val originalType = 
FlinkTypeFactory.toLogicalType(call.getOperands.get(0).getType)
     val newType = FlinkTypeFactory.toLogicalType(call.getType)
-    LogicalTypeCasts.supportsImplicitCast(originalType, newType)
+    LogicalTypeCasts.supportsInjectiveCast(originalType, newType)
   }
 
   def getUniqueKeys(
diff --git 
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableSinkTest.xml
 
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableSinkTest.xml
index edf9e924f6b..179e09c8ad3 100644
--- 
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableSinkTest.xml
+++ 
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/TableSinkTest.xml
@@ -485,6 +485,27 @@ Sink(table=[default_catalog.default_database.sink], 
fields=[a, b])
    +- Exchange(distribution=[single])
       +- Calc(select=[a, b])
          +- DataStreamScan(table=[[default_catalog, default_database, 
MyTable]], fields=[a, b, c])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testInjectiveCastPreservesUpsertKey">
+    <Resource name="ast">
+      <![CDATA[
+LogicalSink(table=[default_catalog.default_database.sink_agg_with_string_pk], 
fields=[EXPR$0, EXPR$1])
++- LogicalProject(EXPR$0=[CAST($0):VARCHAR(2147483647) CHARACTER SET 
"UTF-16LE"], EXPR$1=[$1])
+   +- LogicalAggregate(group=[{0}], EXPR$1=[COUNT()])
+      +- LogicalProject(a=[$0])
+         +- LogicalTableScan(table=[[default_catalog, default_database, 
MyTable]])
+]]>
+    </Resource>
+    <Resource name="optimized rel plan">
+      <![CDATA[
+Sink(table=[default_catalog.default_database.sink_agg_with_string_pk], 
fields=[EXPR$0, EXPR$1], changelogMode=[NONE])
++- Calc(select=[CAST(a AS VARCHAR(2147483647)) AS EXPR$0, EXPR$1], 
changelogMode=[I,UA])
+   +- GroupAggregate(groupBy=[a], select=[a, COUNT(*) AS EXPR$1], 
changelogMode=[I,UA])
+      +- Exchange(distribution=[hash[a]], changelogMode=[I])
+         +- Calc(select=[a], changelogMode=[I])
+            +- DataStreamScan(table=[[default_catalog, default_database, 
MyTable]], fields=[a, b, c], changelogMode=[I])
 ]]>
     </Resource>
   </TestCase>
@@ -596,6 +617,33 @@ 
LogicalSink(table=[default_catalog.default_database.MetadataTable], fields=[meta
       <![CDATA[
 Sink(table=[default_catalog.default_database.MetadataTable], 
fields=[metadata_1, metadata_2, other, m_2])
 +- TableSourceScan(table=[[default_catalog, default_database, MetadataTable, 
project=[metadata_1, metadata_2, other], metadata=[metadata_2]]], 
fields=[metadata_1, metadata_2, other, m_2])
+]]>
+    </Resource>
+  </TestCase>
+  <TestCase name="testNonInjectiveCastLosesUpsertKey">
+    <Resource name="explain">
+      <![CDATA[== Abstract Syntax Tree ==
+LogicalSink(table=[default_catalog.default_database.sink_agg_with_int_pk], 
fields=[EXPR$0, EXPR$1])
++- LogicalProject(EXPR$0=[CAST($0):INTEGER], EXPR$1=[$1])
+   +- LogicalAggregate(group=[{0}], EXPR$1=[COUNT()])
+      +- LogicalProject(c=[$2])
+         +- LogicalTableScan(table=[[default_catalog, default_database, 
MyTable]])
+
+== Optimized Physical Plan ==
+Sink(table=[default_catalog.default_database.sink_agg_with_int_pk], 
fields=[EXPR$0, EXPR$1], upsertMaterialize=[true], 
conflictStrategy=[DEDUPLICATE], changelogMode=[NONE])
++- Calc(select=[CAST(c AS INTEGER) AS EXPR$0, EXPR$1], changelogMode=[I,UB,UA])
+   +- GroupAggregate(groupBy=[c], select=[c, COUNT(*) AS EXPR$1], 
changelogMode=[I,UB,UA])
+      +- Exchange(distribution=[hash[c]], changelogMode=[I])
+         +- Calc(select=[c], changelogMode=[I])
+            +- DataStreamScan(table=[[default_catalog, default_database, 
MyTable]], fields=[a, b, c], changelogMode=[I])
+
+== Optimized Execution Plan ==
+Sink(table=[default_catalog.default_database.sink_agg_with_int_pk], 
fields=[EXPR$0, EXPR$1], upsertMaterialize=[true], 
conflictStrategy=[DEDUPLICATE])
++- Calc(select=[CAST(c AS INTEGER) AS EXPR$0, EXPR$1])
+   +- GroupAggregate(groupBy=[c], select=[c, COUNT(*) AS EXPR$1])
+      +- Exchange(distribution=[hash[c]])
+         +- Calc(select=[c])
+            +- DataStreamScan(table=[[default_catalog, default_database, 
MyTable]], fields=[a, b, c])
 ]]>
     </Resource>
   </TestCase>
diff --git 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeysTest.scala
 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeysTest.scala
index ddb440bc320..41fd47cc62a 100644
--- 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeysTest.scala
+++ 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeysTest.scala
@@ -25,6 +25,7 @@ import org.apache.flink.table.planner.plan.utils.ExpandUtil
 import com.google.common.collect.{ImmutableList, ImmutableSet}
 import org.apache.calcite.prepare.CalciteCatalogReader
 import org.apache.calcite.rel.hint.RelHint
+import org.apache.calcite.sql.`type`.SqlTypeName.VARCHAR
 import org.apache.calcite.sql.fun.SqlStdOperatorTable.{EQUALS, LESS_THAN}
 import org.apache.calcite.util.ImmutableBitSet
 import org.junit.jupiter.api.Assertions._
@@ -92,10 +93,101 @@ class FlinkRelMdUniqueKeysTest extends 
FlinkRelMdHandlerTestBase {
       relBuilder.field(1)
     )
     val project1 = relBuilder.project(exprs).build()
-    assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(project1).toSet)
+    // INT -> BIGINT is an injective cast, so position 2 is also a unique key
+    assertEquals(uniqueKeys(Array(1), Array(2)), 
mq.getUniqueKeys(project1).toSet)
     assertEquals(uniqueKeys(Array(1), Array(2)), mq.getUniqueKeys(project1, 
true).toSet)
   }
 
+  @Test
+  def testGetUniqueKeysOnProjectWithInjectiveCastToString(): Unit = {
+    // INT -> STRING is an injective cast (each distinct int maps to a 
distinct string).
+    // When a unique key column is cast this way, the uniqueness is preserved.
+    relBuilder.push(studentLogicalScan)
+
+    val stringType = typeFactory.createSqlType(VARCHAR, 100)
+
+    // Project: CAST(id AS STRING), name
+    // id (position 0 in source) is the unique key
+    val exprs = List(
+      rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS 
STRING)
+      relBuilder.field(1) // name
+    )
+    val project = relBuilder.project(exprs).build()
+
+    // The casted id at position 0 should still be recognized as unique
+    assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(project).toSet)
+  }
+
+  @Test
+  def testGetUniqueKeysOnProjectWithMultipleKeyReferences(): Unit = {
+    // When the same unique key column appears multiple times in a projection
+    // (either raw or via injective cast), all references are recognized as 
keys.
+    relBuilder.push(studentLogicalScan)
+
+    val stringType = typeFactory.createSqlType(VARCHAR, 100)
+
+    // Project: CAST(id AS STRING), id, name
+    val exprs = List(
+      rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS 
STRING) - injective
+      relBuilder.field(0), // id (raw reference)
+      relBuilder.field(1) // name
+    )
+    val project = relBuilder.project(exprs).build()
+
+    // Both position 0 (STRING cast of id) and position 1 (raw id) are unique 
keys
+    assertEquals(uniqueKeys(Array(0), Array(1)), 
mq.getUniqueKeys(project).toSet)
+  }
+
+  @Test
+  def testGetUniqueKeysOnProjectInjectiveCastOnlyPreservesExistingKeys(): Unit 
= {
+    // Injective casts PRESERVE uniqueness but don't CREATE it.
+    // Casting a non-key column with an injective cast doesn't make it a key.
+    relBuilder.push(studentLogicalScan)
+
+    val stringType = typeFactory.createSqlType(VARCHAR, 100)
+
+    // Project: id, CAST(name AS STRING)
+    // id is the unique key; name is NOT a key (even after casting)
+    val exprs = List(
+      relBuilder.field(0), // id - the unique key
+      rexBuilder.makeCast(stringType, relBuilder.field(1)) // CAST(name AS 
STRING) - not a key
+    )
+    val project = relBuilder.project(exprs).build()
+
+    // Only position 0 (id) is a unique key
+    // Position 1 (cast of name) is NOT a key because name wasn't a key to 
begin with
+    assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(project).toSet)
+  }
+
+  @Test
+  def testGetUniqueKeysOnProjectNonInjectiveCastLosesKey(): Unit = {
+    // STRING -> INT is NOT an injective cast (e.g., "1" and "01" both become 
1).
+    // When a unique key is cast this way, the uniqueness cannot be guaranteed.
+    relBuilder.push(studentLogicalScan)
+
+    val stringType = typeFactory.createSqlType(VARCHAR, 100)
+
+    // First, project id as STRING to simulate a STRING key column
+    val stringKeyExprs = List(
+      rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS 
STRING)
+      relBuilder.field(1) // name
+    )
+    val stringKeyProject = relBuilder.project(stringKeyExprs).build()
+    // At this point, position 0 is a STRING that's still a unique key
+    assertEquals(uniqueKeys(Array(0)), 
mq.getUniqueKeys(stringKeyProject).toSet)
+
+    // Now cast the STRING back to INT - this is a non-injective (narrowing) 
cast
+    relBuilder.push(stringKeyProject)
+    val narrowedExprs = List(
+      rexBuilder.makeCast(intType, relBuilder.field(0)), // CAST(string_id AS 
INT) - NOT injective
+      relBuilder.field(1) // name
+    )
+    val narrowedProject = relBuilder.project(narrowedExprs).build()
+
+    // The key is LOST because STRING->INT is not injective
+    assertEquals(uniqueKeys(), mq.getUniqueKeys(narrowedProject).toSet)
+  }
+
   @Test
   def testGetUniqueKeysOnFilter(): Unit = {
     assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(logicalFilter).toSet)
@@ -133,7 +225,8 @@ class FlinkRelMdUniqueKeysTest extends 
FlinkRelMdHandlerTestBase {
     )
     val rowType = relBuilder.project(exprs).build().getRowType
     val calc2 = createLogicalCalc(studentLogicalScan, rowType, exprs, 
List(expr))
-    assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(calc2).toSet)
+    // INT -> BIGINT is an injective cast, so position 2 is also a unique key
+    assertEquals(uniqueKeys(Array(1), Array(2)), mq.getUniqueKeys(calc2).toSet)
     assertEquals(uniqueKeys(Array(1), Array(2)), mq.getUniqueKeys(calc2, 
true).toSet)
   }
 
diff --git 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala
 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala
index 40e38dda161..3b7524e1bd3 100644
--- 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala
+++ 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala
@@ -26,6 +26,7 @@ import com.google.common.collect.{ImmutableList, ImmutableSet}
 import org.apache.calcite.prepare.CalciteCatalogReader
 import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rel.hint.RelHint
+import org.apache.calcite.sql.`type`.SqlTypeName.VARCHAR
 import org.apache.calcite.sql.fun.SqlStdOperatorTable.{EQUALS, LESS_THAN}
 import org.apache.calcite.util.ImmutableBitSet
 import org.junit.jupiter.api.Assertions._
@@ -88,12 +89,81 @@ class FlinkRelMdUpsertKeysTest extends 
FlinkRelMdHandlerTestBase {
     val exprs = List(
       relBuilder.call(EQUALS, relBuilder.field(0), relBuilder.literal(1)),
       relBuilder.field(0),
+      // INT -> BIGINT is an injective cast, so position 2 is now also an 
upsert key
       rexBuilder.makeCast(longType, relBuilder.field(0)),
       rexBuilder.makeCast(intType, relBuilder.field(0)),
       relBuilder.field(1)
     )
     val project1 = relBuilder.project(exprs).build()
-    assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(project1).toSet)
+
+    assertEquals(toBitSet(Array(1), Array(2)), 
mq.getUpsertKeys(project1).toSet)
+  }
+
+  @Test
+  def testGetUpsertKeysOnProjectWithInjectiveCastToString(): Unit = {
+    // INT -> STRING is an injective cast (each distinct int maps to a 
distinct string).
+    // When an upsert key column is cast this way, the key property is 
preserved.
+    relBuilder.push(studentLogicalScan)
+
+    val stringType = typeFactory.createSqlType(VARCHAR, 100)
+
+    // Project: CAST(id AS STRING), name
+    val exprs = List(
+      rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS 
STRING)
+      relBuilder.field(1) // name
+    )
+    val project = relBuilder.project(exprs).build()
+
+    // The casted id at position 0 should still be recognized as upsert key
+    assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(project).toSet)
+  }
+
+  @Test
+  def testGetUpsertKeysOnProjectWithMultipleKeyReferences(): Unit = {
+    // When the same upsert key column appears multiple times in a projection
+    // (either raw or via injective cast), all references are recognized as 
keys.
+    relBuilder.push(studentLogicalScan)
+
+    val stringType = typeFactory.createSqlType(VARCHAR, 100)
+
+    // Project: CAST(id AS STRING), id, name
+    val exprs = List(
+      rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS 
STRING) - injective
+      relBuilder.field(0), // id (raw reference)
+      relBuilder.field(1) // name
+    )
+    val project = relBuilder.project(exprs).build()
+
+    // Both position 0 (STRING cast of id) and position 1 (raw id) are upsert 
keys
+    assertEquals(toBitSet(Array(0), Array(1)), mq.getUpsertKeys(project).toSet)
+  }
+
+  @Test
+  def testGetUpsertKeysOnProjectNonInjectiveCastLosesKey(): Unit = {
+    // STRING -> INT is NOT an injective cast (e.g., "1" and "01" both become 
1).
+    // When an upsert key is cast this way, the key property is lost.
+    relBuilder.push(studentLogicalScan)
+
+    val stringType = typeFactory.createSqlType(VARCHAR, 100)
+
+    // First, project id as STRING to simulate a STRING key column
+    val stringKeyExprs = List(
+      rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS 
STRING)
+      relBuilder.field(1) // name
+    )
+    val stringKeyProject = relBuilder.project(stringKeyExprs).build()
+    assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(stringKeyProject).toSet)
+
+    // Now cast the STRING back to INT - this is a non-injective cast
+    relBuilder.push(stringKeyProject)
+    val narrowedExprs = List(
+      rexBuilder.makeCast(intType, relBuilder.field(0)), // CAST(string_id AS 
INT) - NOT injective
+      relBuilder.field(1) // name
+    )
+    val narrowedProject = relBuilder.project(narrowedExprs).build()
+
+    // The key is LOST because STRING->INT is not injective
+    assertEquals(toBitSet(), mq.getUpsertKeys(narrowedProject).toSet)
   }
 
   @Test
@@ -133,7 +203,8 @@ class FlinkRelMdUpsertKeysTest extends 
FlinkRelMdHandlerTestBase {
     )
     val rowType = relBuilder.project(exprs).build().getRowType
     val calc2 = createLogicalCalc(studentLogicalScan, rowType, exprs, 
List(expr))
-    assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(calc2).toSet)
+    // INT -> BIGINT is an injective cast, so position 2 is now also an upsert 
key
+    assertEquals(toBitSet(Array(1), Array(2)), mq.getUpsertKeys(calc2).toSet)
   }
 
   @Test
diff --git 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableSinkTest.scala
 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableSinkTest.scala
index 6e0b5323964..ef9776caedc 100644
--- 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableSinkTest.scala
+++ 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/TableSinkTest.scala
@@ -947,6 +947,55 @@ class TableSinkTest extends TableTestBase {
     util.verifyRelPlan(stmtSet, ExplainDetail.CHANGELOG_MODE)
   }
 
+  @Test
+  def testInjectiveCastPreservesUpsertKey(): Unit = {
+    // Aggregation produces upsert stream with key (a).
+    // Sink expects STRING primary key.
+    // CAST(INT AS STRING) is injective, so the upsert key is preserved - no 
materializer needed.
+    util.tableEnv.executeSql("""
+                               |CREATE TABLE sink_agg_with_string_pk (
+                               |  id STRING NOT NULL,
+                               |  cnt BIGINT,
+                               |  PRIMARY KEY (id) NOT ENFORCED
+                               |) WITH (
+                               |  'connector' = 'values',
+                               |  'sink-insert-only' = 'false'
+                               |)
+                               |""".stripMargin)
+    val stmtSet = util.tableEnv.createStatementSet()
+    // GROUP BY a produces upsert key (a), then CAST(a AS STRING) preserves it
+    stmtSet.addInsertSql(
+      "INSERT INTO sink_agg_with_string_pk SELECT CAST(a AS STRING), COUNT(*) 
FROM MyTable GROUP BY a")
+    // The plan should NOT contain upsertMaterialize=[true] because 
INT->STRING is injective
+    util.verifyRelPlan(stmtSet, ExplainDetail.CHANGELOG_MODE)
+  }
+
+  @Test
+  def testNonInjectiveCastLosesUpsertKey(): Unit = {
+    // Aggregation produces an upsert stream with key (c), which is a STRING.
+    // Sink expects INT primary key.
+    // CAST(STRING AS INT) is NOT injective ("1" and "01" both become 1), so 
key is lost.
+    util.tableEnv.executeSql("""
+                               |CREATE TABLE sink_agg_with_int_pk (
+                               |  id INT NOT NULL,
+                               |  cnt BIGINT,
+                               |  PRIMARY KEY (id) NOT ENFORCED
+                               |) WITH (
+                               |  'connector' = 'values',
+                               |  'sink-insert-only' = 'false'
+                               |)
+                               |""".stripMargin)
+    // GROUP BY c (STRING) produces upsert key (c), then CAST(c AS INT) loses 
it
+    util.verifyExplainInsert(
+      """
+        |INSERT INTO sink_agg_with_int_pk
+        |SELECT CAST(c AS INT), COUNT(*) FROM MyTable GROUP BY c
+        |ON CONFLICT DO DEDUPLICATE
+        |""".stripMargin,
+      ExplainDetail.CHANGELOG_MODE
+    )
+  }
+
 }
 
 /** tests table factory use ParallelSourceFunction which support parallelism 
by env */
diff --git 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/ChangelogSourceITCase.scala
 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/ChangelogSourceITCase.scala
index d2926b54325..f593875a5a3 100644
--- 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/ChangelogSourceITCase.scala
+++ 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/ChangelogSourceITCase.scala
@@ -46,8 +46,8 @@ import scala.collection.JavaConversions._
 class ChangelogSourceITCase(
     sourceMode: SourceMode,
     miniBatch: MiniBatchMode,
-    state: StateBackendMode)
-  extends StreamingWithMiniBatchTestBase(miniBatch, state) {
+    state: StateBackendMode
+) extends StreamingWithMiniBatchTestBase(miniBatch, state) {
 
   @RegisterExtension private val _: EachCallbackWrapper[LegacyRowExtension] =
     new EachCallbackWrapper[LegacyRowExtension](new LegacyRowExtension)
@@ -192,12 +192,16 @@ class ChangelogSourceITCase(
          | 'sink-insert-only' = 'false'
          |)
          |""".stripMargin
+    // Note: balance2 is computed as balance*2 which results in a higher 
precision DECIMAL type.
+    // Since the sink's balance column has fixed precision, this requires ON 
CONFLICT handling
+    // because narrowing DECIMAL casts (e.g., DECIMAL(28,2) -> DECIMAL(18,2)) 
are not injective.
     val dml =
       s"""
          |INSERT INTO user_sink
          |SELECT balance2, count(*), max(email)
          |FROM users
          |GROUP BY balance2
+         |ON CONFLICT DO DEDUPLICATE
          |""".stripMargin
     tEnv.executeSql(sinkDDL)
     tEnv.executeSql(dml).await()

Reply via email to