This is an automated email from the ASF dual-hosted git repository.

twalthr pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 5bf35093944 [FLINK-39015][table] Fix key extractor for multi join by changing GenericRowData to BinaryRowData
5bf35093944 is described below

commit 5bf350939440956479dc87bd34ad954e23a4ca37
Author: Dmitriy Linevich <[email protected]>
AuthorDate: Thu Mar 12 22:55:10 2026 +0700

    [FLINK-39015][table] Fix key extractor for multi join by changing GenericRowData to BinaryRowData
    
    This closes #27508.
---
 .../planner/runtime/stream/sql/JoinITCase.scala    | 17 ++++++
 .../join/stream/StreamingMultiJoinOperator.java    | 14 +++++
 .../AttributeBasedJoinKeyExtractor.java            | 63 ++++++++++++++++++----
 .../join/stream/keyselector/JoinKeyExtractor.java  |  3 ++
 .../table/runtime/typeutils/RowDataSerializer.java |  6 ++-
 .../StreamingMultiJoinOperatorTestBase.java        |  1 +
 6 files changed, 93 insertions(+), 11 deletions(-)

diff --git 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/JoinITCase.scala
 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/JoinITCase.scala
index e55fa33070b..fa5ef6a2752 100644
--- 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/JoinITCase.scala
+++ 
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/JoinITCase.scala
@@ -20,6 +20,7 @@ package org.apache.flink.table.planner.runtime.stream.sql
 import org.apache.flink.table.api._
 import org.apache.flink.table.api.bridge.scala._
 import org.apache.flink.table.api.config.ExecutionConfigOptions
+import org.apache.flink.table.api.config.OptimizerConfigOptions
 import org.apache.flink.table.planner.expressions.utils.FuncWithOpen
 import org.apache.flink.table.planner.factories.TestValuesTableFactory
 import org.apache.flink.table.planner.runtime.utils._
@@ -623,6 +624,22 @@ class JoinITCase(miniBatch: MiniBatchMode, state: 
StateBackendMode, enableAsyncS
     assertThat(sink.getRetractResults.sorted).isEqualTo(expected.sorted)
   }
 
+  @TestTemplate
+  def testInnerMultiJoinWithEqualPk(): Unit = {
+    tEnv.getConfig.getConfiguration
+      
.setString(OptimizerConfigOptions.TABLE_OPTIMIZER_MULTI_JOIN_ENABLED.key(), 
"true")
+    val query1 = "SELECT SUM(a2) AS a2, a1 FROM A group by a1"
+    val query2 = "SELECT SUM(b2) AS b2, b1 FROM B group by b1"
+    val query = s"SELECT a1, b1 FROM ($query1) JOIN ($query2) ON a1 = b1"
+
+    val sink = new TestingRetractSink
+    tEnv.sqlQuery(query).toRetractStream[Row].addSink(sink).setParallelism(1)
+    env.execute()
+
+    val expected = Seq("1,1", "2,2", "3,3")
+    assertThat(sink.getRetractResults.sorted).isEqualTo(expected.sorted)
+  }
+
   @TestTemplate
   def testInnerJoinWithPk(): Unit = {
     val query1 = "SELECT SUM(a2) AS a2, a1 FROM A group by a1"
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingMultiJoinOperator.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingMultiJoinOperator.java
index 3764b886d00..77ab15e0b33 100644
--- 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingMultiJoinOperator.java
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/StreamingMultiJoinOperator.java
@@ -19,6 +19,7 @@
 package org.apache.flink.table.runtime.operators.join.stream;
 
 import org.apache.flink.api.common.functions.DefaultOpenContext;
+import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.streaming.api.operators.AbstractInput;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperatorV2;
 import org.apache.flink.streaming.api.operators.Input;
@@ -44,6 +45,7 @@ import org.apache.flink.types.RowKind;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 /**
  * Streaming multi-way join operator which supports inner join and left outer 
join, right joins are
@@ -339,6 +341,7 @@ public class StreamingMultiJoinOperator extends 
AbstractStreamOperatorV2<RowData
         implements MultipleInputStreamOperator<RowData> {
     private static final long serialVersionUID = 1L;
 
+    private static final Set<String> HEAP_STATE_BACKENDS = Set.of("hashmap");
     private final List<JoinInputSideSpec> inputSpecs;
     private final List<FlinkJoinType> joinTypes;
     private final List<RowType> inputTypes;
@@ -609,6 +612,12 @@ public class StreamingMultiJoinOperator extends 
AbstractStreamOperatorV2<RowData
         }
     }
 
+    private boolean isHeapBackend() {
+        KeyedStateBackend<?> backend = getKeyedStateBackend();
+        String backendName = backend.getBackendTypeIdentifier();
+        return HEAP_STATE_BACKENDS.contains(backendName);
+    }
+
     private static RowData newJoinedRowData(int depth, RowData joinedRowData, 
RowData record) {
         RowData newJoinedRowData;
         if (depth == 0) {
@@ -808,6 +817,11 @@ public class StreamingMultiJoinOperator extends 
AbstractStreamOperatorV2<RowData
                     "Keyed state store not found when initializing keyed state 
store handlers.");
         }
 
+        boolean prohibitReuseRow = isHeapBackend();
+        if (prohibitReuseRow) {
+            this.keyExtractor.requiresKeyDeepCopy();
+        }
+
         this.stateHandlers = new ArrayList<>(inputSpecs.size());
         for (int i = 0; i < inputSpecs.size(); i++) {
             MultiJoinStateView stateView;
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/AttributeBasedJoinKeyExtractor.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/AttributeBasedJoinKeyExtractor.java
index d5e073c8b9b..6de9e570adc 100644
--- 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/AttributeBasedJoinKeyExtractor.java
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/AttributeBasedJoinKeyExtractor.java
@@ -20,6 +20,7 @@ package 
org.apache.flink.table.runtime.operators.join.stream.keyselector;
 
 import org.apache.flink.table.data.GenericRowData;
 import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
 import org.apache.flink.table.types.logical.LogicalType;
 import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.table.utils.NoCommonJoinKeyException;
@@ -64,13 +65,22 @@ public class AttributeBasedJoinKeyExtractor implements 
JoinKeyExtractor, Seriali
     // leftKeyExtractorsMap: extractors that read the left side (joined row so 
far)
     //   using the same attribute order as in joinAttributeMap.
     private final Map<Integer, List<KeyExtractor>> leftKeyExtractorsMap;
+    // RowData serializers for left key.
+    private final Map<Integer, RowDataSerializer> leftKeySerializersMap;
     // rightKeyExtractorsMap: extractors to extract the right-side key from 
each input.
     private final Map<Integer, List<KeyExtractor>> rightKeyExtractorsMap;
+    // RowData serializers for right key.
+    private final Map<Integer, RowDataSerializer> rightKeySerializersMap;
 
     // Data structures for the "common join key" shared by all inputs.
     // Input 0 provides the canonical order and defines commonJoinKeyType.
     private final Map<Integer, List<KeyExtractor>> commonJoinKeyExtractors;
+    // RowData serializers for common join key.
+    private final Map<Integer, RowDataSerializer> commonJoinKeySerializersMap;
     private RowType commonJoinKeyType;
+    // Controls whether key rows built are serialized and copied. This is 
required for the
+    // Heap State Backend to prevent object reuse issues, which can lead to 
data corruption
+    private boolean requiresKeyDeepCopy;
 
     /**
      * Creates an AttributeBasedJoinKeyExtractor.
@@ -87,8 +97,12 @@ public class AttributeBasedJoinKeyExtractor implements 
JoinKeyExtractor, Seriali
         this.joinAttributeMap = joinAttributeMap;
         this.inputTypes = inputTypes;
         this.leftKeyExtractorsMap = new HashMap<>();
+        this.leftKeySerializersMap = new HashMap<>();
         this.rightKeyExtractorsMap = new HashMap<>();
+        this.rightKeySerializersMap = new HashMap<>();
         this.commonJoinKeyExtractors = new HashMap<>();
+        this.commonJoinKeySerializersMap = new HashMap<>();
+        this.requiresKeyDeepCopy = false;
 
         initializeCaches();
         initializeCommonJoinKeyStructures();
@@ -113,7 +127,7 @@ public class AttributeBasedJoinKeyExtractor implements 
JoinKeyExtractor, Seriali
             return null;
         }
 
-        return buildKeyRowFromSourceRow(row, keyExtractors);
+        return buildKeyRowFromSourceRow(row, keyExtractors, 
rightKeySerializersMap.get(inputId));
     }
 
     @Override
@@ -127,7 +141,8 @@ public class AttributeBasedJoinKeyExtractor implements 
JoinKeyExtractor, Seriali
             return null;
         }
 
-        return buildKeyRowFromJoinedRow(keyExtractors, joinedRowData);
+        return buildKeyRowFromJoinedRow(
+                keyExtractors, joinedRowData, 
leftKeySerializersMap.get(depth));
     }
 
     @Override
@@ -168,7 +183,7 @@ public class AttributeBasedJoinKeyExtractor implements 
JoinKeyExtractor, Seriali
             return null;
         }
 
-        return buildKeyRowFromSourceRow(row, extractors);
+        return buildKeyRowFromSourceRow(row, extractors, 
commonJoinKeySerializersMap.get(inputId));
     }
 
     @Override
@@ -181,13 +196,23 @@ public class AttributeBasedJoinKeyExtractor implements 
JoinKeyExtractor, Seriali
         return 
extractors.stream().mapToInt(KeyExtractor::getFieldIndexInSourceRow).toArray();
     }
 
+    @Override
+    public void requiresKeyDeepCopy() {
+        this.requiresKeyDeepCopy = true;
+    }
+
     // ==================== Initialization Methods ====================
 
     private void initializeCaches() {
         if (this.inputTypes != null) {
             for (int i = 0; i < this.inputTypes.size(); i++) {
-                this.leftKeyExtractorsMap.put(i, 
createLeftJoinKeyFieldExtractors(i));
-                this.rightKeyExtractorsMap.put(i, 
createRightJoinKeyExtractors(i));
+                List<KeyExtractor> leftJoinKeyExtractors = 
createLeftJoinKeyFieldExtractors(i);
+                this.leftKeyExtractorsMap.put(i, leftJoinKeyExtractors);
+                this.leftKeySerializersMap.put(i, 
createJoinKeySerializer(leftJoinKeyExtractors));
+
+                List<KeyExtractor> rightJoinKeyExtractors = 
createRightJoinKeyExtractors(i);
+                this.rightKeyExtractorsMap.put(i, rightJoinKeyExtractors);
+                this.rightKeySerializersMap.put(i, 
createJoinKeySerializer(rightJoinKeyExtractors));
             }
         }
     }
@@ -226,6 +251,11 @@ public class AttributeBasedJoinKeyExtractor implements 
JoinKeyExtractor, Seriali
         return keyExtractors;
     }
 
+    private RowDataSerializer createJoinKeySerializer(List<KeyExtractor> 
keyExtractors) {
+        return new RowDataSerializer(
+                keyExtractors.stream().map(e -> 
e.fieldType).toArray(LogicalType[]::new));
+    }
+
     private static AttributeRef getLeftAttributeRef(
             final int inputId, final ConditionAttributeRef entry) {
         final AttributeRef leftAttrRef = new AttributeRef(entry.leftInputId, 
entry.leftFieldIndex);
@@ -279,7 +309,9 @@ public class AttributeBasedJoinKeyExtractor implements 
JoinKeyExtractor, Seriali
     // ==================== Key Building Methods ====================
 
     private RowData buildKeyRowFromJoinedRow(
-            final List<KeyExtractor> keyExtractors, final RowData 
joinedRowData) {
+            final List<KeyExtractor> keyExtractors,
+            final RowData joinedRowData,
+            RowDataSerializer keySerializer) {
         if (keyExtractors.isEmpty()) {
             return null;
         }
@@ -288,11 +320,17 @@ public class AttributeBasedJoinKeyExtractor implements 
JoinKeyExtractor, Seriali
         for (int i = 0; i < keyExtractors.size(); i++) {
             keyRow.setField(i, 
keyExtractors.get(i).getLeftSideKey(joinedRowData));
         }
-        return keyRow;
+        if (requiresKeyDeepCopy) {
+            return keySerializer.toBinaryRow(keyRow, true);
+        } else {
+            return keyRow;
+        }
     }
 
-    private GenericRowData buildKeyRowFromSourceRow(
-            final RowData sourceRow, final List<KeyExtractor> keyExtractors) {
+    private RowData buildKeyRowFromSourceRow(
+            final RowData sourceRow,
+            final List<KeyExtractor> keyExtractors,
+            RowDataSerializer keySerializer) {
         if (keyExtractors.isEmpty()) {
             return null;
         }
@@ -301,7 +339,11 @@ public class AttributeBasedJoinKeyExtractor implements 
JoinKeyExtractor, Seriali
         for (int i = 0; i < keyExtractors.size(); i++) {
             keyRow.setField(i, 
keyExtractors.get(i).getRightSideKey(sourceRow));
         }
-        return keyRow;
+        if (requiresKeyDeepCopy) {
+            return keySerializer.toBinaryRow(keyRow, true);
+        } else {
+            return keyRow;
+        }
     }
 
     private RowType buildJoinKeyType(final int inputId, final 
List<KeyExtractor> keyExtractors) {
@@ -608,6 +650,7 @@ public class AttributeBasedJoinKeyExtractor implements 
JoinKeyExtractor, Seriali
 
         final LogicalType[] keyFieldTypes =
                 extractors.stream().map(e -> 
e.fieldType).toArray(LogicalType[]::new);
+        this.commonJoinKeySerializersMap.put(currentInputId, new 
RowDataSerializer(keyFieldTypes));
         if (currentInputId == 0 && !extractors.isEmpty()) {
             this.commonJoinKeyType = RowType.of(keyFieldTypes, keyFieldNames);
         }
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/JoinKeyExtractor.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/JoinKeyExtractor.java
index 7e6f91d6ff7..4fb96d28ed1 100644
--- 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/JoinKeyExtractor.java
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/JoinKeyExtractor.java
@@ -109,4 +109,7 @@ public interface JoinKeyExtractor extends Serializable {
      *     common join key.
      */
     int[] getCommonJoinKeyIndices(int inputId);
+
+    /** Enables copying of row data. */
+    void requiresKeyDeepCopy();
 }
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/typeutils/RowDataSerializer.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/typeutils/RowDataSerializer.java
index 3106d9e65d5..a83ea7a5e55 100644
--- 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/typeutils/RowDataSerializer.java
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/typeutils/RowDataSerializer.java
@@ -188,10 +188,14 @@ public class RowDataSerializer extends 
AbstractRowDataSerializer<RowData> {
     /** Convert {@link RowData} into {@link BinaryRowData}. TODO modify it to 
code gen. */
     @Override
     public BinaryRowData toBinaryRow(RowData row) {
+        return toBinaryRow(row, false);
+    }
+
+    public BinaryRowData toBinaryRow(RowData row, boolean requiresDeepCopy) {
         if (row instanceof BinaryRowData) {
             return (BinaryRowData) row;
         }
-        if (reuseRow == null) {
+        if (reuseRow == null || requiresDeepCopy) {
             reuseRow = new BinaryRowData(types.length);
             reuseWriter = new BinaryRowWriter(reuseRow);
         }
diff --git 
a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/stream/multijoin/StreamingMultiJoinOperatorTestBase.java
 
b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/stream/multijoin/StreamingMultiJoinOperatorTestBase.java
index b09c7f13a95..27861fafb11 100644
--- 
a/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/stream/multijoin/StreamingMultiJoinOperatorTestBase.java
+++ 
b/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/stream/multijoin/StreamingMultiJoinOperatorTestBase.java
@@ -362,6 +362,7 @@ public abstract class StreamingMultiJoinOperatorTestBase 
extends StateParameteri
 
         public SerializableKeySelector(JoinKeyExtractor keyExtractor, int 
inputIndex) {
             this.keyExtractor = keyExtractor;
+            this.keyExtractor.requiresKeyDeepCopy();
             this.inputIndex = inputIndex;
         }
 

Reply via email to