dawidwys commented on code in PR #26313: URL: https://github.com/apache/flink/pull/26313#discussion_r2081393700
########## flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/AttributeBasedJoinKeyExtractor.java: ########## @@ -0,0 +1,299 @@ +package org.apache.flink.table.runtime.operators.join.stream.keyselector; + +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.VarCharType; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.TreeMap; +import java.util.stream.Collectors; + +/** + * A {@link JoinKeyExtractor} that derives keys based on {@link AttributeRef} mappings provided in + * {@code joinAttributeMap}. It defines how attributes from different input streams are related + * through equi-join conditions, assuming input 0 is the base and subsequent inputs join to + * preceding ones. + */ +public class AttributeBasedJoinKeyExtractor implements JoinKeyExtractor { + private static final long serialVersionUID = 1L; + + // Default key/type used when no specific join keys are applicable (e.g., input 0, cross joins). + private static final GenericRowData DEFAULT_KEY = new GenericRowData(1); + + static { + DEFAULT_KEY.setField(0, "__DEFAULT_MULTI_JOIN_SATE_KEY__"); Review Comment: ```suggestion DEFAULT_KEY.setField(0, "__DEFAULT_MULTI_JOIN_STATE_KEY__"); ``` ########## flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/stream/StreamingMultiJoinOperatorTest.java: ########## @@ -0,0 +1,1257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.join.stream; + +import org.apache.flink.table.runtime.generated.GeneratedMultiJoinCondition; +import org.apache.flink.table.runtime.operators.join.stream.StreamingMultiJoinOperator.JoinType; +import org.apache.flink.table.runtime.operators.join.stream.keyselector.AttributeBasedJoinKeyExtractor.AttributeRef; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +class StreamingTwoWayInnerMultiJoinOperatorTest extends StreamingMultiJoinOperatorTestBase { + + public StreamingTwoWayInnerMultiJoinOperatorTest() { + // For inner join test, set outerJoinFlags to false for all inputs + super(2, List.of(JoinType.INNER, JoinType.INNER), defaultConditions(), false); + } + + /** SELECT u.*, o.* FROM Users u INNER JOIN Orders o ON u.id = o.user_id. 
*/ + @Test + void testTwoWayInnerJoin() throws Exception { + /* -------- APPEND TESTS ----------- */ + + // Users without orders aren't emitted + insertUser("1", "Gus", "User 1 Details"); + emitsNothing(); + + // User joins with matching order + insertOrder("1", "order_1", "Order 1 Details"); + emits(INSERT, "1", "Gus", "User 1 Details", "1", "order_1", "Order 1 Details"); + + // Orders without users aren't emitted + insertOrder("2", "order_2", "Order 2 Details"); + emitsNothing(); + + // Adding matching user triggers join + insertUser("2", "Bob", "User 2 Details"); + emits(INSERT, "2", "Bob", "User 2 Details", "2", "order_2", "Order 2 Details"); + } + + /** + * SELECT u.*, o.* FROM Users u INNER JOIN Orders o ON u.id = o.user_id -- Test updates and + * deletes on both sides. + */ + @Test + void testTwoWayInnerJoinUpdating() throws Exception { + /* -------- SETUP BASE DATA ----------- */ + insertUser("1", "Gus", "User 1 Details"); + emitsNothing(); + + insertOrder("1", "order_1", "Order 1 Details"); + emits(INSERT, "1", "Gus", "User 1 Details", "1", "order_1", "Order 1 Details"); + + /* -------- UPDATE TESTS ----------- */ + + // +U on user.details emits +U + updateAfterUser("1", "Gus", "User 1 Details Updated"); + emits( + UPDATE_AFTER, + "1", + "Gus", + "User 1 Details Updated", + "1", + "order_1", + "Order 1 Details"); + + // +U on order.details emits +U + updateAfterOrder("1", "order_1", "Order 1 Details Updated"); + emits( + UPDATE_AFTER, + "1", + "Gus", + "User 1 Details Updated", + "1", + "order_1", + "Order 1 Details Updated"); + + /* -------- DELETE TESTS ----------- */ + + // -D on order emits -D + deleteOrder("1", "order_1", "Order 1 Details Updated"); + emits( + DELETE, + "1", + "Gus", + "User 1 Details Updated", + "1", + "order_1", + "Order 1 Details Updated"); + + // Re-insert order emits +I + insertOrder("1", "order_1", "Order 1 New Details"); + emits(INSERT, "1", "Gus", "User 1 Details Updated", "1", "order_1", "Order 1 New Details"); + } +} + 
+class StreamingTwoWayOuterMultiJoinOperatorTest extends StreamingMultiJoinOperatorTestBase { Review Comment: Please put the tests in separate files ########## flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/state/MultiJoinStateViews.java: ########## @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.operators.join.stream.state; + +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.StateTtlConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.operators.join.stream.utils.JoinInputSideSpec; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.util.IterableIterator; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; + +import static org.apache.flink.table.runtime.util.StateConfigUtil.createTtlConfig; +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Factory class to create different implementations of {@link MultiJoinStateView} based on the + * characteristics described in {@link JoinInputSideSpec}. + * + * <p>Each state view uses a {@link MapState} where the primary key is the `mapKey` derived from the + * join conditions (via {@link + * org.apache.flink.table.runtime.operators.join.stream.keyselector.JoinKeyExtractor}). The value + * stored within this map depends on whether the input side has a unique key and how it relates to + * the join key, optimizing storage and access patterns. + */ +public final class MultiJoinStateViews { + + /** Creates a {@link MultiJoinStateView} depends on {@link JoinInputSideSpec}. 
*/ + public static MultiJoinStateView create( + RuntimeContext ctx, + String stateName, + JoinInputSideSpec inputSideSpec, + InternalTypeInfo<RowData> mapKeyType, // Type info for the outer map key + InternalTypeInfo<RowData> recordType, + long retentionTime) { + StateTtlConfig ttlConfig = createTtlConfig(retentionTime); + + if (inputSideSpec.hasUniqueKey()) { + if (inputSideSpec.joinKeyContainsUniqueKey()) { + return new JoinKeyContainsUniqueKey( + ctx, stateName, mapKeyType, recordType, ttlConfig); + } else { + return new InputSideHasUniqueKey( + ctx, + stateName, + mapKeyType, + recordType, + inputSideSpec.getUniqueKeyType(), + inputSideSpec.getUniqueKeySelector(), + ttlConfig); + } + } else { + return new InputSideHasNoUniqueKey(ctx, stateName, mapKeyType, recordType, ttlConfig); + } + } + + /** + * Creates a {@link MapStateDescriptor} with the given parameters and applies TTL configuration. + * + * @param <K> Key type + * @param <V> Value type + * @param stateName Unique name for the state + * @param keyTypeInfo Type information for the key + * @param valueTypeInfo Type information for the value + * @param ttlConfig State TTL configuration + * @return Configured MapStateDescriptor + */ + private static <K, V> MapStateDescriptor<K, V> createStateDescriptor( + String stateName, + TypeInformation<K> keyTypeInfo, + TypeInformation<V> valueTypeInfo, + StateTtlConfig ttlConfig) { + MapStateDescriptor<K, V> descriptor = + new MapStateDescriptor<>(stateName, keyTypeInfo, valueTypeInfo); + if (ttlConfig.isEnabled()) { + descriptor.enableTimeToLive(ttlConfig); + } + return descriptor; + } + + // ------------------------------------------------------------------------------------ + // Multi Join State View Implementations + // ------------------------------------------------------------------------------------ + + /** + * State view for input sides where the unique key is fully contained within the join key. + * + * <p>Stores data as {@code MapState<MapKey, Record>}. 
+ */ + private static final class JoinKeyContainsUniqueKey implements MultiJoinStateView { Review Comment: If `joinKey` contains `uniqueKey` then, if I understand correctly, the `uniqueKey` will be equal to the shuffle key and thus: 1. The state is already scoped by the `mapKey`. 2. There will be at most 1 record per key. ########## flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/stream/StreamingMultiJoinOperatorTest.java: ########## @@ -0,0 +1,1257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.operators.join.stream; + +import org.apache.flink.table.runtime.generated.GeneratedMultiJoinCondition; +import org.apache.flink.table.runtime.operators.join.stream.StreamingMultiJoinOperator.JoinType; +import org.apache.flink.table.runtime.operators.join.stream.keyselector.AttributeBasedJoinKeyExtractor.AttributeRef; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +class StreamingTwoWayInnerMultiJoinOperatorTest extends StreamingMultiJoinOperatorTestBase { Review Comment: Could you please add tests for: * non unique key conditions * non equi-join conditions (if they're supported) ########## flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/state/MultiJoinStateViews.java: ########## @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.operators.join.stream.state; + +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.StateTtlConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.operators.join.stream.utils.JoinInputSideSpec; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.util.IterableIterator; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; + +import static org.apache.flink.table.runtime.util.StateConfigUtil.createTtlConfig; +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Factory class to create different implementations of {@link MultiJoinStateView} based on the + * characteristics described in {@link JoinInputSideSpec}. + * + * <p>Each state view uses a {@link MapState} where the primary key is the `mapKey` derived from the + * join conditions (via {@link + * org.apache.flink.table.runtime.operators.join.stream.keyselector.JoinKeyExtractor}). The value + * stored within this map depends on whether the input side has a unique key and how it relates to + * the join key, optimizing storage and access patterns. + */ +public final class MultiJoinStateViews { + + /** Creates a {@link MultiJoinStateView} depends on {@link JoinInputSideSpec}. 
*/ + public static MultiJoinStateView create( + RuntimeContext ctx, + String stateName, + JoinInputSideSpec inputSideSpec, + InternalTypeInfo<RowData> mapKeyType, // Type info for the outer map key + InternalTypeInfo<RowData> recordType, + long retentionTime) { + StateTtlConfig ttlConfig = createTtlConfig(retentionTime); + + if (inputSideSpec.hasUniqueKey()) { + if (inputSideSpec.joinKeyContainsUniqueKey()) { + return new JoinKeyContainsUniqueKey( + ctx, stateName, mapKeyType, recordType, ttlConfig); + } else { + return new InputSideHasUniqueKey( + ctx, + stateName, + mapKeyType, + recordType, + inputSideSpec.getUniqueKeyType(), + inputSideSpec.getUniqueKeySelector(), + ttlConfig); + } + } else { + return new InputSideHasNoUniqueKey(ctx, stateName, mapKeyType, recordType, ttlConfig); + } + } + + /** + * Creates a {@link MapStateDescriptor} with the given parameters and applies TTL configuration. + * + * @param <K> Key type + * @param <V> Value type + * @param stateName Unique name for the state + * @param keyTypeInfo Type information for the key + * @param valueTypeInfo Type information for the value + * @param ttlConfig State TTL configuration + * @return Configured MapStateDescriptor + */ + private static <K, V> MapStateDescriptor<K, V> createStateDescriptor( + String stateName, + TypeInformation<K> keyTypeInfo, + TypeInformation<V> valueTypeInfo, + StateTtlConfig ttlConfig) { + MapStateDescriptor<K, V> descriptor = + new MapStateDescriptor<>(stateName, keyTypeInfo, valueTypeInfo); + if (ttlConfig.isEnabled()) { + descriptor.enableTimeToLive(ttlConfig); + } + return descriptor; + } + + // ------------------------------------------------------------------------------------ + // Multi Join State View Implementations + // ------------------------------------------------------------------------------------ + + /** + * State view for input sides where the unique key is fully contained within the join key. + * + * <p>Stores data as {@code MapState<MapKey, Record>}. 
+ */ + private static final class JoinKeyContainsUniqueKey implements MultiJoinStateView { + + // stores record in the mapping <MapKey, Record> + private final MapState<RowData, RowData> recordState; + private final List<RowData> reusedList; + + private JoinKeyContainsUniqueKey( + RuntimeContext ctx, + final String stateName, + final InternalTypeInfo<RowData> mapKeyType, + final InternalTypeInfo<RowData> recordType, + final StateTtlConfig ttlConfig) { + + MapStateDescriptor<RowData, RowData> recordStateDesc = + createStateDescriptor(stateName, mapKeyType, recordType, ttlConfig); + + this.recordState = ctx.getMapState(recordStateDesc); + // the result records always not more than 1 per mapKey + this.reusedList = new ArrayList<>(1); + } + + @Override + public void addRecord(RowData mapKey, RowData record) throws Exception { + recordState.put(mapKey, record); + } + + @Override + public void retractRecord(RowData mapKey, RowData record) throws Exception { + // Only one record is kept per mapKey, remove it directly. + recordState.remove(mapKey); + } + + @Override + public Iterable<RowData> getRecords(RowData mapKey) throws Exception { + reusedList.clear(); + RowData record = recordState.get(mapKey); + if (record != null) { + reusedList.add(record); + } + return reusedList; + } + + @Override + public void cleanup(RowData mapKey) throws Exception { + recordState.remove(mapKey); + } + } + + /** + * State view for input sides that have a unique key, but it differs from the join key. + * + * <p>Stores data as {@code MapState<MapKey, Map<UK, Record>>}. + */ + private static final class InputSideHasUniqueKey implements MultiJoinStateView { + + // stores map in the mapping <MapKey, Map<UK, Record>> + private final MapState<RowData, Map<RowData, RowData>> recordState; Review Comment: Very quick idea/note. Maybe we could do: ``` MapState<RowData, List<RowData, RowData>> mapKeyToUniqueKey; MapState<RowData, RowData> records // uniqueKey. 
-> record ``` ########## flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/state/MultiJoinStateViews.java: ########## @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.join.stream.state; + +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.StateTtlConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.operators.join.stream.utils.JoinInputSideSpec; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.util.IterableIterator; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; + +import static 
org.apache.flink.table.runtime.util.StateConfigUtil.createTtlConfig; +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Factory class to create different implementations of {@link MultiJoinStateView} based on the + * characteristics described in {@link JoinInputSideSpec}. + * + * <p>Each state view uses a {@link MapState} where the primary key is the `mapKey` derived from the + * join conditions (via {@link + * org.apache.flink.table.runtime.operators.join.stream.keyselector.JoinKeyExtractor}). The value + * stored within this map depends on whether the input side has a unique key and how it relates to + * the join key, optimizing storage and access patterns. + */ +public final class MultiJoinStateViews { + + /** Creates a {@link MultiJoinStateView} depends on {@link JoinInputSideSpec}. */ + public static MultiJoinStateView create( + RuntimeContext ctx, + String stateName, + JoinInputSideSpec inputSideSpec, + InternalTypeInfo<RowData> mapKeyType, // Type info for the outer map key + InternalTypeInfo<RowData> recordType, + long retentionTime) { + StateTtlConfig ttlConfig = createTtlConfig(retentionTime); + + if (inputSideSpec.hasUniqueKey()) { + if (inputSideSpec.joinKeyContainsUniqueKey()) { + return new JoinKeyContainsUniqueKey( + ctx, stateName, mapKeyType, recordType, ttlConfig); + } else { + return new InputSideHasUniqueKey( + ctx, + stateName, + mapKeyType, + recordType, + inputSideSpec.getUniqueKeyType(), + inputSideSpec.getUniqueKeySelector(), + ttlConfig); + } + } else { + return new InputSideHasNoUniqueKey(ctx, stateName, mapKeyType, recordType, ttlConfig); + } + } + + /** + * Creates a {@link MapStateDescriptor} with the given parameters and applies TTL configuration. 
+ * + * @param <K> Key type + * @param <V> Value type + * @param stateName Unique name for the state + * @param keyTypeInfo Type information for the key + * @param valueTypeInfo Type information for the value + * @param ttlConfig State TTL configuration + * @return Configured MapStateDescriptor + */ + private static <K, V> MapStateDescriptor<K, V> createStateDescriptor( + String stateName, + TypeInformation<K> keyTypeInfo, + TypeInformation<V> valueTypeInfo, + StateTtlConfig ttlConfig) { + MapStateDescriptor<K, V> descriptor = + new MapStateDescriptor<>(stateName, keyTypeInfo, valueTypeInfo); + if (ttlConfig.isEnabled()) { + descriptor.enableTimeToLive(ttlConfig); + } + return descriptor; + } + + // ------------------------------------------------------------------------------------ + // Multi Join State View Implementations + // ------------------------------------------------------------------------------------ + + /** + * State view for input sides where the unique key is fully contained within the join key. + * + * <p>Stores data as {@code MapState<MapKey, Record>}. 
+ */ + private static final class JoinKeyContainsUniqueKey implements MultiJoinStateView { + + // stores record in the mapping <MapKey, Record> + private final MapState<RowData, RowData> recordState; + private final List<RowData> reusedList; + + private JoinKeyContainsUniqueKey( + RuntimeContext ctx, + final String stateName, + final InternalTypeInfo<RowData> mapKeyType, + final InternalTypeInfo<RowData> recordType, + final StateTtlConfig ttlConfig) { + + MapStateDescriptor<RowData, RowData> recordStateDesc = + createStateDescriptor(stateName, mapKeyType, recordType, ttlConfig); + + this.recordState = ctx.getMapState(recordStateDesc); + // the result records always not more than 1 per mapKey + this.reusedList = new ArrayList<>(1); + } + + @Override + public void addRecord(RowData mapKey, RowData record) throws Exception { + recordState.put(mapKey, record); + } + + @Override + public void retractRecord(RowData mapKey, RowData record) throws Exception { + // Only one record is kept per mapKey, remove it directly. + recordState.remove(mapKey); + } + + @Override + public Iterable<RowData> getRecords(RowData mapKey) throws Exception { + reusedList.clear(); + RowData record = recordState.get(mapKey); + if (record != null) { + reusedList.add(record); + } + return reusedList; + } + + @Override + public void cleanup(RowData mapKey) throws Exception { + recordState.remove(mapKey); + } + } + + /** + * State view for input sides that have a unique key, but it differs from the join key. + * + * <p>Stores data as {@code MapState<MapKey, Map<UK, Record>>}. 
+ */ + private static final class InputSideHasUniqueKey implements MultiJoinStateView { + + // stores map in the mapping <MapKey, Map<UK, Record>> + private final MapState<RowData, Map<RowData, RowData>> recordState; + private final KeySelector<RowData, RowData> uniqueKeySelector; + + private InputSideHasUniqueKey( + RuntimeContext ctx, + final String stateName, + final InternalTypeInfo<RowData> mapKeyType, + final InternalTypeInfo<RowData> recordType, + final InternalTypeInfo<RowData> uniqueKeyType, + final KeySelector<RowData, RowData> uniqueKeySelector, + final StateTtlConfig ttlConfig) { + checkNotNull(uniqueKeyType); + checkNotNull(uniqueKeySelector); + this.uniqueKeySelector = uniqueKeySelector; + + TypeInformation<Map<RowData, RowData>> mapValueTypeInfo = + Types.MAP(uniqueKeyType, recordType); // UK is the key in the inner map + + MapStateDescriptor<RowData, Map<RowData, RowData>> recordStateDesc = + createStateDescriptor(stateName, mapKeyType, mapValueTypeInfo, ttlConfig); + + this.recordState = ctx.getMapState(recordStateDesc); + } + + @Override + public void addRecord(RowData mapKey, RowData record) throws Exception { + RowData uniqueKey = uniqueKeySelector.getKey(record); + Map<RowData, RowData> uniqueKeyToRecordMap = recordState.get(mapKey); + if (uniqueKeyToRecordMap == null) { + uniqueKeyToRecordMap = new HashMap<>(); + } + uniqueKeyToRecordMap.put(uniqueKey, record); + recordState.put(mapKey, uniqueKeyToRecordMap); + } + + @Override + public void retractRecord(RowData mapKey, RowData record) throws Exception { + RowData uniqueKey = uniqueKeySelector.getKey(record); + Map<RowData, RowData> uniqueKeyToRecordMap = recordState.get(mapKey); + if (uniqueKeyToRecordMap != null) { + uniqueKeyToRecordMap.remove(uniqueKey); + if (uniqueKeyToRecordMap.isEmpty()) { + // Clean up the entry for mapKey if the inner map becomes empty + recordState.remove(mapKey); + } else { + recordState.put(mapKey, uniqueKeyToRecordMap); + } + } + // ignore uniqueKeyToRecordMap == 
null + } + + @Override + public Iterable<RowData> getRecords(RowData mapKey) throws Exception { + Map<RowData, RowData> uniqueKeyToRecordMap = recordState.get(mapKey); Review Comment: Imo, it should be possible to add a test case which produces wrong results if `mapKey` is not equal to the `uniqueKey` ########## flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/keyselector/AttributeBasedJoinKeyExtractor.java: ########## @@ -0,0 +1,299 @@ +package org.apache.flink.table.runtime.operators.join.stream.keyselector; + +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.VarCharType; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.TreeMap; +import java.util.stream.Collectors; + +/** + * A {@link JoinKeyExtractor} that derives keys based on {@link AttributeRef} mappings provided in + * {@code joinAttributeMap}. It defines how attributes from different input streams are related + * through equi-join conditions, assuming input 0 is the base and subsequent inputs join to + * preceding ones. + */ +public class AttributeBasedJoinKeyExtractor implements JoinKeyExtractor { + private static final long serialVersionUID = 1L; + + // Default key/type used when no specific join keys are applicable (e.g., input 0, cross joins). 
+ private static final GenericRowData DEFAULT_KEY = new GenericRowData(1); + + static { + DEFAULT_KEY.setField(0, "__DEFAULT_MULTI_JOIN_SATE_KEY__"); + } + + private static final InternalTypeInfo<RowData> DEFAULT_KEY_TYPE = + InternalTypeInfo.of( + RowType.of( + new LogicalType[] { + // Fixed type for the default key. Length matches the static key + // value. + new VarCharType(false, 31) + }, + new String[] {"default_key"})); + + private final transient Map<Integer, Map<AttributeRef, AttributeRef>> + joinAttributeMap; // Transient as it's configuration + private final List<InternalTypeInfo<RowData>> inputTypes; + + /** + * Creates an AttributeBasedJoinKeyExtractor. + * + * @param joinAttributeMap Map defining equi-join conditions. Outer key: inputId (>= 1). Inner + * key: {@link AttributeRef} to a field in a *previous* input. Inner value: {@link + * AttributeRef} to the corresponding field in the *current* input (inputId == outer key). + * @param inputTypes Type information for all input streams (indexed 0 to N-1). + */ + public AttributeBasedJoinKeyExtractor( + final Map<Integer, Map<AttributeRef, AttributeRef>> joinAttributeMap, + final List<InternalTypeInfo<RowData>> inputTypes) { + this.joinAttributeMap = joinAttributeMap; + this.inputTypes = inputTypes; + } + + @Override + public RowData getKeyForStateStorage(RowData row, int inputId) { + if (inputId == 0) { + // Input 0 uses the fixed default key as it's the start of the join chain. + return DEFAULT_KEY; + } + + // For inputs > 0, storage key derived from current row's equi-join fields. + final Map<AttributeRef, AttributeRef> attributeMapping = joinAttributeMap.get(inputId); + if (attributeMapping == null || attributeMapping.isEmpty()) { + // No equi-join conditions defined for this input, use default key. + return DEFAULT_KEY; + } + + // Indices of fields in the *current* input (inputId) used as the *right* side of joins. 
+ final List<Integer> keyFieldIndices = determineKeyFieldIndices(inputId); + + if (keyFieldIndices.isEmpty()) { + // Mappings exist, but none point to fields *within* this inputId (config error?), use + // default key. + return DEFAULT_KEY; + } + + return buildKeyRow(row, inputId, keyFieldIndices); + } + + @Override + public RowData getKeyForStateLookup(int depth, RowData[] currentRows) { + if (depth == 0) { + // Input 0 lookup always uses the fixed default key. + return DEFAULT_KEY; + } + + // For depths > 0, lookup key derived from *previous* rows (indices < depth) + // using the *left* side of equi-join conditions for the *current* depth. + final Map<AttributeRef, AttributeRef> attributeMapping = joinAttributeMap.get(depth); + if (attributeMapping == null || attributeMapping.isEmpty()) { + // No equi-join conditions link previous inputs to this depth (e.g. cross join). + // Use default key. + return DEFAULT_KEY; + } + + // TreeMap ensures deterministic key structure: left inputId -> left fieldIndex + final Map<Integer, Map<Integer, Object>> sortedKeyComponents = new TreeMap<>(); + + // Iterate through join attributes for the current depth. + // Key (leftAttrRef) points to previous input (< depth). + // Value (rightAttrRef) points to current input (== depth). + for (Map.Entry<AttributeRef, AttributeRef> entry : attributeMapping.entrySet()) { Review Comment: Why do we need to iterate over all the entries? Can't we precompute indices for a given join depth? I'd imagine in a pseudocode: ``` val mask = joinMasks.get(depth); RowData value = currentRows[depth]; RowData key = mask.apply(value); ``` I am a bit afraid of doing all the expensive computations for every record. ########## flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/state/MultiJoinStateViews.java: ########## @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.operators.join.stream.state; + +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.StateTtlConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.operators.join.stream.utils.JoinInputSideSpec; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.util.IterableIterator; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; + +import static org.apache.flink.table.runtime.util.StateConfigUtil.createTtlConfig; +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Factory class to create different implementations of {@link MultiJoinStateView} based on the + * characteristics described in {@link JoinInputSideSpec}. 
+ * + * <p>Each state view uses a {@link MapState} where the primary key is the `mapKey` derived from the + * join conditions (via {@link + * org.apache.flink.table.runtime.operators.join.stream.keyselector.JoinKeyExtractor}). The value + * stored within this map depends on whether the input side has a unique key and how it relates to + * the join key, optimizing storage and access patterns. + */ +public final class MultiJoinStateViews { + + /** Creates a {@link MultiJoinStateView} depends on {@link JoinInputSideSpec}. */ + public static MultiJoinStateView create( + RuntimeContext ctx, + String stateName, + JoinInputSideSpec inputSideSpec, + InternalTypeInfo<RowData> mapKeyType, // Type info for the outer map key + InternalTypeInfo<RowData> recordType, + long retentionTime) { + StateTtlConfig ttlConfig = createTtlConfig(retentionTime); + + if (inputSideSpec.hasUniqueKey()) { + if (inputSideSpec.joinKeyContainsUniqueKey()) { + return new JoinKeyContainsUniqueKey( + ctx, stateName, mapKeyType, recordType, ttlConfig); + } else { + return new InputSideHasUniqueKey( + ctx, + stateName, + mapKeyType, + recordType, + inputSideSpec.getUniqueKeyType(), + inputSideSpec.getUniqueKeySelector(), + ttlConfig); + } + } else { + return new InputSideHasNoUniqueKey(ctx, stateName, mapKeyType, recordType, ttlConfig); + } + } + + /** + * Creates a {@link MapStateDescriptor} with the given parameters and applies TTL configuration. 
+ * + * @param <K> Key type + * @param <V> Value type + * @param stateName Unique name for the state + * @param keyTypeInfo Type information for the key + * @param valueTypeInfo Type information for the value + * @param ttlConfig State TTL configuration + * @return Configured MapStateDescriptor + */ + private static <K, V> MapStateDescriptor<K, V> createStateDescriptor( + String stateName, + TypeInformation<K> keyTypeInfo, + TypeInformation<V> valueTypeInfo, + StateTtlConfig ttlConfig) { + MapStateDescriptor<K, V> descriptor = + new MapStateDescriptor<>(stateName, keyTypeInfo, valueTypeInfo); + if (ttlConfig.isEnabled()) { + descriptor.enableTimeToLive(ttlConfig); + } + return descriptor; + } + + // ------------------------------------------------------------------------------------ + // Multi Join State View Implementations + // ------------------------------------------------------------------------------------ + + /** + * State view for input sides where the unique key is fully contained within the join key. + * + * <p>Stores data as {@code MapState<MapKey, Record>}. 
+ */ + private static final class JoinKeyContainsUniqueKey implements MultiJoinStateView { + + // stores record in the mapping <MapKey, Record> + private final MapState<RowData, RowData> recordState; + private final List<RowData> reusedList; + + private JoinKeyContainsUniqueKey( + RuntimeContext ctx, + final String stateName, + final InternalTypeInfo<RowData> mapKeyType, + final InternalTypeInfo<RowData> recordType, + final StateTtlConfig ttlConfig) { + + MapStateDescriptor<RowData, RowData> recordStateDesc = + createStateDescriptor(stateName, mapKeyType, recordType, ttlConfig); + + this.recordState = ctx.getMapState(recordStateDesc); + // the result records always not more than 1 per mapKey + this.reusedList = new ArrayList<>(1); + } + + @Override + public void addRecord(RowData mapKey, RowData record) throws Exception { + recordState.put(mapKey, record); + } + + @Override + public void retractRecord(RowData mapKey, RowData record) throws Exception { + // Only one record is kept per mapKey, remove it directly. + recordState.remove(mapKey); + } + + @Override + public Iterable<RowData> getRecords(RowData mapKey) throws Exception { + reusedList.clear(); + RowData record = recordState.get(mapKey); + if (record != null) { + reusedList.add(record); + } + return reusedList; + } + + @Override + public void cleanup(RowData mapKey) throws Exception { + recordState.remove(mapKey); + } + } + + /** + * State view for input sides that have a unique key, but it differs from the join key. + * + * <p>Stores data as {@code MapState<MapKey, Map<UK, Record>>}. + */ + private static final class InputSideHasUniqueKey implements MultiJoinStateView { + + // stores map in the mapping <MapKey, Map<UK, Record>> + private final MapState<RowData, Map<RowData, RowData>> recordState; Review Comment: I am really concerned about the state design for this case. If I understand correctly the join key may be completely different from the records `uniqueKey`. 
Having the `uniqueKey` as the main key in the state renders the optimisation to store records per joinKey mostly useless, as we need to retrieve all entries and then filter the maps using the `mapKey`. ########## flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/stream/state/MultiJoinStateViews.java: ########## @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.operators.join.stream.state; + +import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.api.common.state.MapState; +import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.StateTtlConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.operators.join.stream.utils.JoinInputSideSpec; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.util.IterableIterator; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; + +import static org.apache.flink.table.runtime.util.StateConfigUtil.createTtlConfig; +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Factory class to create different implementations of {@link MultiJoinStateView} based on the + * characteristics described in {@link JoinInputSideSpec}. + * + * <p>Each state view uses a {@link MapState} where the primary key is the `mapKey` derived from the + * join conditions (via {@link + * org.apache.flink.table.runtime.operators.join.stream.keyselector.JoinKeyExtractor}). The value + * stored within this map depends on whether the input side has a unique key and how it relates to + * the join key, optimizing storage and access patterns. + */ +public final class MultiJoinStateViews { + + /** Creates a {@link MultiJoinStateView} depends on {@link JoinInputSideSpec}. 
*/ + public static MultiJoinStateView create( + RuntimeContext ctx, + String stateName, + JoinInputSideSpec inputSideSpec, + InternalTypeInfo<RowData> mapKeyType, // Type info for the outer map key + InternalTypeInfo<RowData> recordType, + long retentionTime) { + StateTtlConfig ttlConfig = createTtlConfig(retentionTime); + + if (inputSideSpec.hasUniqueKey()) { + if (inputSideSpec.joinKeyContainsUniqueKey()) { + return new JoinKeyContainsUniqueKey( + ctx, stateName, mapKeyType, recordType, ttlConfig); + } else { + return new InputSideHasUniqueKey( + ctx, + stateName, + mapKeyType, + recordType, + inputSideSpec.getUniqueKeyType(), + inputSideSpec.getUniqueKeySelector(), + ttlConfig); + } + } else { + return new InputSideHasNoUniqueKey(ctx, stateName, mapKeyType, recordType, ttlConfig); + } + } + + /** + * Creates a {@link MapStateDescriptor} with the given parameters and applies TTL configuration. + * + * @param <K> Key type + * @param <V> Value type + * @param stateName Unique name for the state + * @param keyTypeInfo Type information for the key + * @param valueTypeInfo Type information for the value + * @param ttlConfig State TTL configuration + * @return Configured MapStateDescriptor + */ + private static <K, V> MapStateDescriptor<K, V> createStateDescriptor( + String stateName, + TypeInformation<K> keyTypeInfo, + TypeInformation<V> valueTypeInfo, + StateTtlConfig ttlConfig) { + MapStateDescriptor<K, V> descriptor = + new MapStateDescriptor<>(stateName, keyTypeInfo, valueTypeInfo); + if (ttlConfig.isEnabled()) { + descriptor.enableTimeToLive(ttlConfig); + } + return descriptor; + } + + // ------------------------------------------------------------------------------------ + // Multi Join State View Implementations + // ------------------------------------------------------------------------------------ + + /** + * State view for input sides where the unique key is fully contained within the join key. + * + * <p>Stores data as {@code MapState<MapKey, Record>}. 
+ */ + private static final class JoinKeyContainsUniqueKey implements MultiJoinStateView { + + // stores record in the mapping <MapKey, Record> + private final MapState<RowData, RowData> recordState; + private final List<RowData> reusedList; + + private JoinKeyContainsUniqueKey( + RuntimeContext ctx, + final String stateName, + final InternalTypeInfo<RowData> mapKeyType, + final InternalTypeInfo<RowData> recordType, + final StateTtlConfig ttlConfig) { + + MapStateDescriptor<RowData, RowData> recordStateDesc = + createStateDescriptor(stateName, mapKeyType, recordType, ttlConfig); + + this.recordState = ctx.getMapState(recordStateDesc); + // the result records always not more than 1 per mapKey + this.reusedList = new ArrayList<>(1); + } + + @Override + public void addRecord(RowData mapKey, RowData record) throws Exception { + recordState.put(mapKey, record); + } + + @Override + public void retractRecord(RowData mapKey, RowData record) throws Exception { + // Only one record is kept per mapKey, remove it directly. + recordState.remove(mapKey); + } + + @Override + public Iterable<RowData> getRecords(RowData mapKey) throws Exception { + reusedList.clear(); + RowData record = recordState.get(mapKey); + if (record != null) { + reusedList.add(record); + } + return reusedList; + } + + @Override + public void cleanup(RowData mapKey) throws Exception { + recordState.remove(mapKey); + } + } + + /** + * State view for input sides that have a unique key, but it differs from the join key. + * + * <p>Stores data as {@code MapState<MapKey, Map<UK, Record>>}. 
+ */ + private static final class InputSideHasUniqueKey implements MultiJoinStateView { + + // stores map in the mapping <MapKey, Map<UK, Record>> + private final MapState<RowData, Map<RowData, RowData>> recordState; + private final KeySelector<RowData, RowData> uniqueKeySelector; + + private InputSideHasUniqueKey( + RuntimeContext ctx, + final String stateName, + final InternalTypeInfo<RowData> mapKeyType, + final InternalTypeInfo<RowData> recordType, + final InternalTypeInfo<RowData> uniqueKeyType, + final KeySelector<RowData, RowData> uniqueKeySelector, + final StateTtlConfig ttlConfig) { + checkNotNull(uniqueKeyType); + checkNotNull(uniqueKeySelector); + this.uniqueKeySelector = uniqueKeySelector; + + TypeInformation<Map<RowData, RowData>> mapValueTypeInfo = + Types.MAP(uniqueKeyType, recordType); // UK is the key in the inner map + + MapStateDescriptor<RowData, Map<RowData, RowData>> recordStateDesc = + createStateDescriptor(stateName, mapKeyType, mapValueTypeInfo, ttlConfig); + + this.recordState = ctx.getMapState(recordStateDesc); + } + + @Override + public void addRecord(RowData mapKey, RowData record) throws Exception { + RowData uniqueKey = uniqueKeySelector.getKey(record); + Map<RowData, RowData> uniqueKeyToRecordMap = recordState.get(mapKey); + if (uniqueKeyToRecordMap == null) { + uniqueKeyToRecordMap = new HashMap<>(); + } + uniqueKeyToRecordMap.put(uniqueKey, record); + recordState.put(mapKey, uniqueKeyToRecordMap); + } + + @Override + public void retractRecord(RowData mapKey, RowData record) throws Exception { + RowData uniqueKey = uniqueKeySelector.getKey(record); + Map<RowData, RowData> uniqueKeyToRecordMap = recordState.get(mapKey); + if (uniqueKeyToRecordMap != null) { + uniqueKeyToRecordMap.remove(uniqueKey); + if (uniqueKeyToRecordMap.isEmpty()) { + // Clean up the entry for mapKey if the inner map becomes empty + recordState.remove(mapKey); + } else { + recordState.put(mapKey, uniqueKeyToRecordMap); + } + } + // ignore uniqueKeyToRecordMap == 
null + } + + @Override + public Iterable<RowData> getRecords(RowData mapKey) throws Exception { + Map<RowData, RowData> uniqueKeyToRecordMap = recordState.get(mapKey); Review Comment: `mapKey` != `uniqueKey` You use the `uniqueKey` as the key of `recordState` ########## flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/join/stream/StreamingMultiJoinOperatorTestBase.java: ########## @@ -0,0 +1,644 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.operators.join.stream; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory; +import org.apache.flink.streaming.api.operators.StreamOperator; +import org.apache.flink.streaming.api.operators.StreamOperatorParameters; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.util.KeyedMultiInputStreamOperatorTestHarness; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.runtime.generated.GeneratedMultiJoinCondition; +import org.apache.flink.table.runtime.generated.MultiJoinCondition; +import org.apache.flink.table.runtime.keyselector.RowDataKeySelector; +import org.apache.flink.table.runtime.operators.join.stream.StreamingMultiJoinOperator.JoinType; +import org.apache.flink.table.runtime.operators.join.stream.keyselector.AttributeBasedJoinKeyExtractor; +import org.apache.flink.table.runtime.operators.join.stream.keyselector.AttributeBasedJoinKeyExtractor.AttributeRef; +import org.apache.flink.table.runtime.operators.join.stream.keyselector.JoinKeyExtractor; +import org.apache.flink.table.runtime.operators.join.stream.utils.JoinInputSideSpec; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.table.runtime.util.RowDataHarnessAssertor; +import org.apache.flink.table.runtime.util.StreamRecordUtils; +import org.apache.flink.table.types.logical.CharType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.VarCharType; +import org.apache.flink.table.utils.HandwrittenSelectorUtil; +import org.apache.flink.types.RowKind; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +import 
java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Base class for testing the StreamingMultiJoinOperator. Provides common functionality and helper + * methods for testing multi-way joins. + */ +public abstract class StreamingMultiJoinOperatorTestBase { + + // ========================================================================== + // Constants + // ========================================================================== + + protected static final RowKind INSERT = RowKind.INSERT; + protected static final RowKind UPDATE_BEFORE = RowKind.UPDATE_BEFORE; + protected static final RowKind UPDATE_AFTER = RowKind.UPDATE_AFTER; + protected static final RowKind DELETE = RowKind.DELETE; + + // ========================================================================== + // Test Configuration + // ========================================================================== + + protected final List<InternalTypeInfo<RowData>> inputTypeInfos; + protected final List<RowDataKeySelector> keySelectors; + protected final List<JoinInputSideSpec> inputSpecs; + protected final List<JoinType> joinTypes; + protected final List<GeneratedMultiJoinCondition> joinConditions; + protected final boolean isFullOuterJoin; + protected final Map<Integer, Map<AttributeRef, AttributeRef>> joinAttributeMap; + protected final InternalTypeInfo<RowData> joinKeyTypeInfo; + protected final JoinKeyExtractor keyExtractor; + + // ========================================================================== + // Test State + // ========================================================================== + + protected RowDataHarnessAssertor assertor; + protected KeyedMultiInputStreamOperatorTestHarness<RowData, RowData> testHarness; + + // ========================================================================== + // Constructor + // ========================================================================== + + protected 
StreamingMultiJoinOperatorTestBase( + int numInputs, + List<JoinType> joinTypes, + List<GeneratedMultiJoinCondition> joinConditions, + boolean isFullOuterJoin) { + this.inputTypeInfos = new ArrayList<>(numInputs); + this.keySelectors = new ArrayList<>(numInputs); + this.inputSpecs = new ArrayList<>(numInputs); + this.joinTypes = joinTypes; + this.isFullOuterJoin = isFullOuterJoin; + this.joinConditions = joinConditions; + this.joinAttributeMap = new HashMap<>(); + + initializeInputs(numInputs); + initializeJoinConditions(); + + this.joinKeyTypeInfo = InternalTypeInfo.of(new CharType(false, 20)); + this.keyExtractor = + new AttributeBasedJoinKeyExtractor(this.joinAttributeMap, this.inputTypeInfos); + } + + /** Constructor allowing explicit provision of joinAttributeMap for custom conditions. */ + protected StreamingMultiJoinOperatorTestBase( + int numInputs, + List<JoinType> joinTypes, + List<GeneratedMultiJoinCondition> joinConditions, + Map<Integer, Map<AttributeRef, AttributeRef>> joinAttributeMap, + boolean isFullOuterJoin) { + this.inputTypeInfos = new ArrayList<>(numInputs); + this.keySelectors = new ArrayList<>(numInputs); + this.inputSpecs = new ArrayList<>(numInputs); + this.joinTypes = joinTypes; + this.isFullOuterJoin = isFullOuterJoin; + this.joinConditions = joinConditions; // Use provided conditions + this.joinAttributeMap = joinAttributeMap; // Use provided map + + initializeInputs(numInputs); + + this.joinKeyTypeInfo = InternalTypeInfo.of(new CharType(false, 20)); + this.keyExtractor = + new AttributeBasedJoinKeyExtractor(this.joinAttributeMap, this.inputTypeInfos); + } + + // ========================================================================== + // Test Lifecycle + // ========================================================================== + + @BeforeEach + void beforeEach() throws Exception { + testHarness = createTestHarness(); + setupKeySelectorsForTestHarness(testHarness); + testHarness.setup(); + testHarness.open(); + assertor = + new 
RowDataHarnessAssertor( + getOutputType().getChildren().toArray(new LogicalType[0])); + } + + @AfterEach + void afterEach() throws Exception { + if (testHarness != null) { + testHarness.close(); + } + } + + // ========================================================================== + // Helper Methods for Test Data + // ========================================================================== + + protected void insertUser(String userId, String userName, String details) throws Exception { + processRecord(0, INSERT, userId, userName, details); + } + + protected void insertOrder(String userId, String orderId, String details) throws Exception { + processRecord(1, INSERT, userId, orderId, details); + } + + protected void insertPayment(String userId, String paymentId, String details) throws Exception { + processRecord(2, INSERT, userId, paymentId, details); + } + + protected void updateBeforeUser(String userId, String userName, String details) + throws Exception { + processRecord(0, UPDATE_BEFORE, userId, userName, details); + } + + protected void updateAfterUser(String userId, String userName, String details) + throws Exception { + processRecord(0, UPDATE_AFTER, userId, userName, details); + } + + protected void updateBeforeOrder(String userId, String orderId, String details) + throws Exception { + processRecord(1, UPDATE_BEFORE, userId, orderId, details); + } + + protected void updateAfterOrder(String userId, String orderId, String details) + throws Exception { + processRecord(1, UPDATE_AFTER, userId, orderId, details); + } + + protected void updateBeforePayment(String userId, String paymentId, String details) + throws Exception { + processRecord(2, UPDATE_BEFORE, userId, paymentId, details); + } + + protected void updateAfterPayment(String userId, String paymentId, String details) + throws Exception { + processRecord(2, UPDATE_AFTER, userId, paymentId, details); + } + + protected void deleteUser(String userId, String userName, String details) throws Exception { + 
processRecord(0, DELETE, userId, userName, details); + } + + protected void deleteOrder(String userId, String orderId, String details) throws Exception { + processRecord(1, DELETE, userId, orderId, details); + } + + protected void deletePayment(String userId, String paymentId, String details) throws Exception { + processRecord(2, DELETE, userId, paymentId, details); + } + + protected static List<GeneratedMultiJoinCondition> defaultConditions() { + return new ArrayList<>(); + } + + // ========================================================================== + // Assertion Methods + // ========================================================================== + + protected void emits(RowKind kind, String... fields) throws Exception { + assertor.shouldEmit(testHarness, rowOfKind(kind, fields)); + } + + protected void emitsNothing() { + assertor.shouldEmitNothing(testHarness); + } + + protected void emits(RowKind kind1, String[] fields1, RowKind kind2, String[] fields2) + throws Exception { + assertor.shouldEmitAll(testHarness, rowOfKind(kind1, fields1), rowOfKind(kind2, fields2)); + } + + protected void emits( + RowKind kind1, + String[] fields1, + RowKind kind2, + String[] fields2, + RowKind kind3, + String[] fields3) + throws Exception { + assertor.shouldEmitAll( + testHarness, + rowOfKind(kind1, fields1), + rowOfKind(kind2, fields2), + rowOfKind(kind3, fields3)); + } + + protected void emits( + RowKind kind1, + String[] fields1, + RowKind kind2, + String[] fields2, + RowKind kind3, + String[] fields3, + RowKind kind4, + String[] fields4) + throws Exception { + assertor.shouldEmitAll( + testHarness, + rowOfKind(kind1, fields1), + rowOfKind(kind2, fields2), + rowOfKind(kind3, fields3), + rowOfKind(kind4, fields4)); + } + + // ========================================================================== + // Private Helper Methods + // ========================================================================== + + private void initializeInputs(int numInputs) { + if 
(numInputs < 2) { + throw new IllegalArgumentException("Number of inputs must be a" + "t least 2"); + } + + // In our test, the first input is always the one with the unique key as a join key + inputTypeInfos.add(createInputTypeInfo(0)); + keySelectors.add(createKeySelector(0)); + inputSpecs.add( + JoinInputSideSpec.withUniqueKeyContainedByJoinKey( + createUniqueKeyType(0), keySelectors.get(0))); + + // Following tables contain a unique key but are not contained in the join key + for (int i = 1; i < numInputs; i++) { + inputTypeInfos.add(createInputTypeInfo(i)); + keySelectors.add(createKeySelector(i)); + inputSpecs.add( + JoinInputSideSpec.withUniqueKey(createUniqueKeyType(i), keySelectors.get(i))); + } + } + + private void initializeJoinConditions() { + // If the map is already populated, it means conditions and map were provided explicitly. + // Validation happened in the specific constructor. + if (!joinAttributeMap.isEmpty()) { + return; + } + + // Proceed with default generation only if conditions AND map were not provided. 
+ if (joinConditions.isEmpty()) { + // First input doesn't have a left input to join with + joinConditions.add(null); + for (int i = 0; i < inputSpecs.size(); i++) { // Iterate based on number of inputs + // Add the join condition comparing current input (i) with previous (i-1) + if (i > 0) { + GeneratedMultiJoinCondition condition = createJoinCondition(i, i - 1); + joinConditions.add(condition); + } + + // Populate the attribute map based on the condition's logic (field 0 <-> field 0) + Map<AttributeRef, AttributeRef> currentJoinMap = new HashMap<>(); + // Left side attribute (previous input, field 0) + AttributeRef leftAttr = new AttributeRef(i - 1, 0); + // Right side attribute (current input, field 0) + AttributeRef rightAttr = new AttributeRef(i, 0); + // Map: right_input_id -> { left_attr -> right_attr } + currentJoinMap.put(leftAttr, rightAttr); + joinAttributeMap.put(i, currentJoinMap); + } + + } else if (joinConditions.size() != inputSpecs.size()) { + throw new IllegalArgumentException( + "The number of provided join conditions must match the number of inputs (" + + inputSpecs.size() + + "), but got " + + joinConditions.size()); + } + } + + private void processRecord(int inputIndex, RowKind kind, String... 
fields) throws Exception { + StreamRecord<RowData> record; + switch (kind) { + case INSERT: + record = StreamRecordUtils.insertRecord((Object[]) fields); + break; + case UPDATE_BEFORE: + record = StreamRecordUtils.updateBeforeRecord((Object[]) fields); + break; + case UPDATE_AFTER: + record = StreamRecordUtils.updateAfterRecord((Object[]) fields); + break; + case DELETE: + record = StreamRecordUtils.deleteRecord((Object[]) fields); + break; + default: + throw new IllegalArgumentException("Unsupported RowKind: " + kind); + } + testHarness.processElement(inputIndex, record); + } + + private void setupKeySelectorsForTestHarness( + KeyedMultiInputStreamOperatorTestHarness<RowData, RowData> harness) { + for (int i = 0; i < this.inputSpecs.size(); i++) { + /* Testcase: our join key is always the first key for all tables and that's why 0 */ + KeySelector<RowData, RowData> keySelector = row -> GenericRowData.of(row.getString(0)); + harness.setKeySelector(i, keySelector); + } + } + + protected KeyedMultiInputStreamOperatorTestHarness<RowData, RowData> createTestHarness() + throws Exception { + KeyedMultiInputStreamOperatorTestHarness<RowData, RowData> harness = + new KeyedMultiInputStreamOperatorTestHarness<>( + new MultiStreamingJoinOperatorFactory( + inputSpecs, + inputTypeInfos, + joinTypes, + joinConditions, + joinAttributeMap), + TypeInformation.of(RowData.class)); + + // Setup key selectors for each input + setupKeySelectorsForTestHarness(harness); + return harness; + } + + protected RowType getOutputType() { + var typesStream = + inputTypeInfos.stream() + .flatMap(typeInfo -> typeInfo.toRowType().getChildren().stream()); + var namesStream = + inputTypeInfos.stream() + .flatMap(typeInfo -> typeInfo.toRowType().getFieldNames().stream()); + + return RowType.of( + typesStream.toArray(LogicalType[]::new), namesStream.toArray(String[]::new)); + } + + protected RowData rowOfKind(RowKind kind, String... 
fields) { + return StreamRecordUtils.rowOfKind(kind, (Object[]) fields); + } + + protected String[] r(String... values) { + return values; + } + + // ========================================================================== + // Factory Class + // ========================================================================== + + private static class MultiStreamingJoinOperatorFactory + extends AbstractStreamOperatorFactory<RowData> { + + private static final long serialVersionUID = 1L; + private final List<JoinInputSideSpec> inputSpecs; + private final List<InternalTypeInfo<RowData>> inputTypeInfos; + private final List<JoinType> joinTypes; + private final List<GeneratedMultiJoinCondition> joinConditions; + private final Map<Integer, Map<AttributeRef, AttributeRef>> joinAttributeMap; + private final JoinKeyExtractor keyExtractor; + + public MultiStreamingJoinOperatorFactory( + List<JoinInputSideSpec> inputSpecs, + List<InternalTypeInfo<RowData>> inputTypeInfos, + List<JoinType> joinTypes, + List<GeneratedMultiJoinCondition> joinConditions, + Map<Integer, Map<AttributeRef, AttributeRef>> joinAttributeMap) { + this.inputSpecs = inputSpecs; + this.inputTypeInfos = inputTypeInfos; + this.joinTypes = joinTypes; + this.joinConditions = joinConditions; + this.joinAttributeMap = joinAttributeMap; + this.keyExtractor = + new AttributeBasedJoinKeyExtractor(joinAttributeMap, inputTypeInfos); + } + + @Override + public <T extends StreamOperator<RowData>> T createStreamOperator( + StreamOperatorParameters<RowData> parameters) { + StreamingMultiJoinOperator op = + createJoinOperator( + parameters, + inputSpecs, + inputTypeInfos, + joinTypes, + joinConditions, + keyExtractor); + + @SuppressWarnings("unchecked") + T operator = (T) op; + return operator; + } + + @Override + public Class<? 
extends StreamOperator<RowData>> getStreamOperatorClass( + ClassLoader classLoader) { + return StreamingMultiJoinOperator.class; + } + + private StreamingMultiJoinOperator createJoinOperator( + StreamOperatorParameters<RowData> parameters, + List<JoinInputSideSpec> inputSpecs, + List<InternalTypeInfo<RowData>> inputTypeInfos, + List<JoinType> joinTypes, + List<GeneratedMultiJoinCondition> joinConditions, + JoinKeyExtractor keyExtractor) { + + long[] retentionTime = new long[inputSpecs.size()]; + Arrays.fill(retentionTime, 9999999L); + + MultiJoinCondition multiJoinCondition = + createMultiJoinCondition(inputSpecs.size()) + .newInstance(getClass().getClassLoader()); + MultiJoinCondition[] createdJoinConditions = createJoinConditions(joinConditions); + + return new StreamingMultiJoinOperator( + parameters, + inputTypeInfos, + inputSpecs, + joinTypes, + multiJoinCondition, + retentionTime, + createdJoinConditions, + keyExtractor); + } + + private MultiJoinCondition[] createJoinConditions( + List<GeneratedMultiJoinCondition> generatedJoinConditions) { + MultiJoinCondition[] conditions = new MultiJoinCondition[inputSpecs.size()]; + // We expect generatedJoinConditions size to match inputSpecs size (or joinTypes size) + if (generatedJoinConditions.size() != inputSpecs.size()) { + throw new IllegalArgumentException( + "The number of generated join conditions must match the number of inputs/joins."); + } + for (int i = 0; i < inputSpecs.size(); i++) { + GeneratedMultiJoinCondition generatedCondition = generatedJoinConditions.get(i); + if (generatedCondition + != null) { // Allow null conditions (e.g., for INNER joins without specific + // cond) + try { + conditions[i] = generatedCondition.newInstance(getClass().getClassLoader()); + } catch (Exception e) { + throw new RuntimeException( + "Failed to instantiate join condition for input " + i, e); + } + } else { + conditions[i] = null; // Explicitly set to null if no condition provided + } + } + return conditions; + } + } + + 
// ========================================================================== + // Type Creation Methods + // ========================================================================== + + protected InternalTypeInfo<RowData> createInputTypeInfo(int inputIndex) { + return InternalTypeInfo.of( + RowType.of( + new LogicalType[] { + new CharType(false, 20), + new CharType(false, 20), + VarCharType.STRING_TYPE + }, + new String[] { + String.format("user_id_%d", inputIndex), + String.format("id_%d", inputIndex), + String.format("details_%d", inputIndex) + })); + } + + protected InternalTypeInfo<RowData> createUniqueKeyType(int inputIndex) { + return InternalTypeInfo.of( + RowType.of( + new LogicalType[] { + new CharType(false, 20), + }, + new String[] { + String.format("id_%d", inputIndex), + })); + } + + protected RowDataKeySelector createKeySelector(int inputIndex) { + return HandwrittenSelectorUtil.getRowDataSelector( + /* Testcase: primary key is 0 for the first table and 1 for all others */ + new int[] {inputIndex == 0 ? 0 : 1}, + inputTypeInfos + .get(inputIndex) + .toRowType() + .getChildren() + .toArray(new LogicalType[0])); + } + + protected static GeneratedMultiJoinCondition createMultiJoinCondition(int numInputs) { Review Comment: Instead of hardcoding the string you can try this approach: https://github.com/apache/flink/blob/d57b89db6b2e177d131ce2376ee17672bd42dfa8/flink-table/flink-table-runtime/src/test/java/org/apache/flink/table/runtime/operators/over/ProcTimeRangeBoundedPrecedingFunctionTest.java#L43 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org