fhueske commented on code in PR #28212:
URL: https://github.com/apache/flink/pull/28212#discussion_r3316392321
##########
flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java:
##########
@@ -947,6 +948,61 @@ private void
validatePartitionConsistency(List<ArgumentInfo> arguments) {
}
}
+ private void validateInitialStateKeys(List<ArgumentInfo> arguments) {
+ if (stateArgs.isEmpty()) {
+ return;
+ }
+
+ // all partitioned tables share the same partition key shape, so
any one is
+ // sufficient for validation.
+ Optional<TableArgumentInfo> partitionedTable =
+ arguments.stream()
+ .filter(arg -> arg instanceof TableArgumentInfo)
+ .map(arg -> (TableArgumentInfo) arg)
+ .filter(t -> t.isSetSemantic &&
t.partitionColumnNames != null)
+ .findFirst();
+
+ if (partitionedTable.isEmpty()) {
+ return;
+ }
+
+ TableArgumentInfo table = partitionedTable.get();
+ int expectedArity = table.partitionColumnNames.length;
+ LogicalType[] expectedTypes =
+ Arrays.stream(table.partitionColumnNames)
+ .map(col -> extractPartitionColumnType(table,
col).getLogicalType())
+ .toArray(LogicalType[]::new);
+
+ for (Map.Entry<String, StateArgumentConfiguration> entry :
stateArgs.entrySet()) {
+ for (Row key : entry.getValue().initialValues.keySet()) {
+ if (key.getArity() != expectedArity) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Initial state key for state '%s' has
arity %d, "
+ + "but partition key has arity
%d.",
+ entry.getKey(), key.getArity(),
expectedArity));
+ }
+
+ for (int i = 0; i < expectedArity; i++) {
+ Object value = key.getField(i);
+ Class<?> expectedClass =
expectedTypes[i].getDefaultConversion();
+ if (value != null && !expectedClass.isInstance(value))
{
+ throw new IllegalArgumentException(
+ String.format(
+ "Initial state key for state '%s'
has type %s "
Review Comment:
Same here, include the key in the error message?
##########
flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java:
##########
@@ -947,6 +948,61 @@ private void
validatePartitionConsistency(List<ArgumentInfo> arguments) {
}
}
+ private void validateInitialStateKeys(List<ArgumentInfo> arguments) {
+ if (stateArgs.isEmpty()) {
+ return;
+ }
+
+ // all partitioned tables share the same partition key shape, so
any one is
+ // sufficient for validation.
Review Comment:
nit:
mention that this is ensured by calling `validatePartitionConsistency()`
earlier?
##########
flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java:
##########
@@ -876,17 +949,74 @@ private void
validatePartitionConsistency(List<ArgumentInfo> arguments) {
}
}
- private DataType extractPartitionColumnType(ArgumentInfo arg, String
columnName) {
- if (!(arg.dataType instanceof FieldsDataType)) {
+ private void validateInitialStateKeys(List<ArgumentInfo> arguments) {
+ if (stateArgs.isEmpty()) {
+ return;
+ }
+
+ // all partitioned tables share the same partition key shape, so
any one is
+ // sufficient for validation.
+ Optional<TableArgumentInfo> partitionedTable =
+ arguments.stream()
+ .filter(arg -> arg instanceof TableArgumentInfo)
+ .map(arg -> (TableArgumentInfo) arg)
+ .filter(t -> t.isSetSemantic &&
t.partitionColumnNames != null)
+ .findFirst();
+
+ if (partitionedTable.isEmpty()) {
+ return;
+ }
+
+ TableArgumentInfo table = partitionedTable.get();
+ int expectedArity = table.partitionColumnNames.length;
+ LogicalType[] expectedTypes =
+ Arrays.stream(table.partitionColumnNames)
+ .map(col -> extractPartitionColumnType(table,
col).getLogicalType())
+ .toArray(LogicalType[]::new);
+
+ for (Map.Entry<String, StateArgumentConfiguration> entry :
stateArgs.entrySet()) {
+ for (Row key : entry.getValue().initialValues.keySet()) {
+ if (key.getArity() != expectedArity) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Initial state key for state '%s' has
arity %d, "
+ + "but partition key has arity
%d.",
+ entry.getKey(), key.getArity(),
expectedArity));
Review Comment:
```suggestion
"Initial state key '%s' for state
'%s' has arity %d, "
+ "but partition key has
arity %d.",
key, entry.getKey(), key.getArity(),
expectedArity));
```
Also include the key in the error message?
##########
flink-table/flink-table-test-utils/src/test/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarnessTest.java:
##########
@@ -597,6 +699,56 @@ void testOptionalPartitionByWithPartition() throws
Exception {
}
}
+ @Test
+ void testOptionalPartitionByWithStateNoPartition() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+
ProcessTableFunctionTestHarness.ofClass(StatefulOptionalPartitionPTF.class)
+ .withTableArgument("input", DataTypes.of("ROW<key
STRING, value INT>"))
+ .build()) {
+
+ harness.processElement(Row.of("A", 10));
+ harness.processElement(Row.of("B", 20));
+ harness.processElement(Row.of("A", 30));
+
+ List<Row> output = harness.getOutput();
+ assertThat(output).hasSize(3);
+ assertThat(output.get(0)).isEqualTo(Row.of(1L));
+ assertThat(output.get(1)).isEqualTo(Row.of(2L));
+ assertThat(output.get(2)).isEqualTo(Row.of(3L));
+
+ StatefulOptionalPartitionPTF.CounterState state =
+ harness.getStateForKey("state", Row.of());
Review Comment:
So for stateful PTFs with optional partitioning and no `PARTITION BY`, the
partitionKey to access state is `Row.of()` (which makes sense!).
I think this should be added to the documentation and we should also check
that all APIs that require a `partitionKey` argument work (initialize, get,
set, clear).
That would ensure that the additional checks that you just added also work
for this special case.
##########
flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java:
##########
@@ -270,43 +273,69 @@ public void clearOutput() {
output.clear();
}
- /**
- * Given a target table argument and a row to process, construct the right
set of arguments for
- * the PTF's eval function and attempt to invoke it.
- */
- private void invokeEval(ArgumentInfo activeTableArg, Row activeRow) throws
Exception {
- // Set collector context so it can prepend columns if needed
- collector.setContext(activeTableArg, activeRow);
+ /** Get state for a specific partition key. */
+ public <T> T getStateForKey(String stateName, Row partitionKey) {
+ return stateManager.getStateForKey(stateName, partitionKey);
+ }
- Object[] args = new Object[arguments.size()];
+ /** Set state for a specific partition key. */
+ public void setStateForKey(String stateName, Row partitionKey, Object
state) throws Exception {
+ stateManager.setStateForKey(stateName, partitionKey, state);
+ }
- for (int i = 0; i < arguments.size(); i++) {
- ArgumentInfo arg = arguments.get(i);
+ /** Get all partition keys that have a specific state entry. */
+ public Set<Row> getKeysForState(String stateName) {
+ return stateManager.getKeysForState(stateName);
+ }
+
+ /** Get all state values for a state name across all partition keys. */
+ public <T> Map<Row, T> getStateForAllKeys(String stateName) {
+ return stateManager.getStateForAllKeys(stateName);
+ }
+
+ /** Clear all state for a given partition key. */
+ public void clearStateForKey(Row partitionKey) {
Review Comment:
```suggestion
public void clearAllStatesForKey(Row partitionKey) {
```
to make it more explicit that this affects all states for this key?
##########
flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java:
##########
@@ -1029,6 +1030,29 @@ private DataType
extractPartitionColumnType(TableArgumentInfo tableArg, String c
columnName, tableArg.name));
}
+ private TestHarnessStateManager.PartitionKeyInfo
extractPartitionKeyInfo(
+ List<ArgumentInfo> arguments) {
+ Optional<TableArgumentInfo> partitionedTable =
+ arguments.stream()
+ .filter(arg -> arg instanceof TableArgumentInfo)
+ .map(arg -> (TableArgumentInfo) arg)
+ .filter(t -> t.isSetSemantic &&
t.partitionColumnNames != null)
+ .findFirst();
+
Review Comment:
A short comment that any table arg is sufficient because all must have the
same partitionKey type would be good here (like in the other spot).
##########
flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java:
##########
@@ -947,6 +948,61 @@ private void
validatePartitionConsistency(List<ArgumentInfo> arguments) {
}
}
+ private void validateInitialStateKeys(List<ArgumentInfo> arguments) {
+ if (stateArgs.isEmpty()) {
+ return;
+ }
+
+ // all partitioned tables share the same partition key shape, so
any one is
+ // sufficient for validation.
+ Optional<TableArgumentInfo> partitionedTable =
+ arguments.stream()
+ .filter(arg -> arg instanceof TableArgumentInfo)
+ .map(arg -> (TableArgumentInfo) arg)
+ .filter(t -> t.isSetSemantic &&
t.partitionColumnNames != null)
+ .findFirst();
+
+ if (partitionedTable.isEmpty()) {
+ return;
Review Comment:
Wouldn't this be a failure case?
Even if this is checked before, it makes sense IMO to throw an exception
here to help readers understand this (even if we would never throw this
exception; but then why have this check at all?).
Or is this the case of having no `PARTITION BY` with set semantics and
optional partitioning?
##########
flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/StateConverter.java:
##########
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.functions;
+
+import org.apache.flink.annotation.Internal;
+
+/**
+ * Converter between external state representations (ListView, MapView & value
state) and internal
Review Comment:
```suggestion
* Converter between external state representations (ListView, MapView, Row
& Pojo state) and internal
```
?
##########
docs/content.zh/docs/dev/table/functions/ptfs.md:
##########
@@ -2275,6 +2275,189 @@ void testScalarOnly() throws Exception {
{{< /tab >}}
{{< /tabs >}}
+#### Testing with State
+
+The harness supports all PTF state types: value state, `Row`, `ListView`, and
`MapView`.
Review Comment:
```suggestion
The harness supports all PTF state types: value state (Pojo and `Row`), list
state (`ListView`), and map state (`MapView`).
```
In the DataStream API ValueState refers to an atomic state entity that is
always fully read and written (in contrast to List and Map state which can be
selectively read and written).
ValueState can have different data types, POJOs, Rows, even Maps and Lists
(which would be stupid because there are specialized states for those).
So `Row` is also a type of value state (a value state of type `Row`).
##########
flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ListViewStateConverter.java:
##########
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.functions;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.dataview.ListView;
+import org.apache.flink.table.data.ArrayData;
+import org.apache.flink.table.data.GenericArrayData;
+import org.apache.flink.table.data.conversion.DataStructureConverter;
+import org.apache.flink.table.types.logical.ArrayType;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Converter for ListView state.
+ *
+ * <p>Converts between external ListView objects and internal ArrayData
representation.
+ */
+@Internal
+class ListViewStateConverter implements StateConverter {
Review Comment:
Thanks for the details, that makes sense!
It's not nice, but also "just" an internal class without user exposure. So I
think this is fine.
##########
flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java:
##########
@@ -947,6 +948,61 @@ private void
validatePartitionConsistency(List<ArgumentInfo> arguments) {
}
}
+ private void validateInitialStateKeys(List<ArgumentInfo> arguments) {
Review Comment:
Nice check!
##########
docs/content.zh/docs/dev/table/functions/ptfs.md:
##########
@@ -2275,6 +2275,189 @@ void testScalarOnly() throws Exception {
{{< /tab >}}
{{< /tabs >}}
+#### Testing with State
+
+The harness supports all PTF state types: value state, `Row`, `ListView`, and
`MapView`.
+
+{{< tabs "state-testing" >}}
+{{< tab "Java" >}}
+```java
+// A PTF that uses all four state types: value state, Row, ListView, and
MapView.
+@DataTypeHint("ROW<count BIGINT>")
+public class StatefulPTF extends ProcessTableFunction<Row> {
+ public static class ValueState {
+ public long count = 0L;
+ }
+
+ public void eval(
+ @StateHint ValueState valueState,
+ @StateHint(type = @DataTypeHint("ROW<lastValue INT>")) Row rowState,
+ @StateHint(type = @DataTypeHint("ARRAY<INT>")) ListView<Integer> listState,
+ @StateHint MapView<String, Integer> mapState,
+ @ArgumentHint(ArgumentTrait.SET_SEMANTIC_TABLE) Row input) throws
Exception {
+ // Value state — increment counter
+ valueState.count++;
+
+ // Row state — track the last value seen
+ int value = input.getFieldAs("value");
+ rowState.setField("lastValue", value);
+
+ // ListView state — accumulate values
+ listState.add(value);
+
+ // MapView state — count occurrences by name
+ String name = input.getFieldAs("name");
+ Integer tagCount = mapState.get(name);
+ mapState.put(name, tagCount == null ? 1 : tagCount + 1);
+
+ collect(Row.of(valueState.count));
+ }
+}
+
+@Test
+void testWithState() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElement(Row.of("Alice", 10));
+ harness.processElement(Row.of("Alice", 20));
+
+ List<Row> output = harness.getOutput();
+ assertThat(output.get(0)).isEqualTo(Row.of("Alice", 1L));
+ assertThat(output.get(1)).isEqualTo(Row.of("Alice", 2L));
+ }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
+
+**Initial State Setup**: Use `.withInitialStateForKey()` to pre-populate state
before processing.
+State initialization is scoped per partition key:
+
+{{< tabs "initial-state" >}}
+{{< tab "Java" >}}
+```java
+@Test
+void testWithInitialState() throws Exception {
+ // Value state
+ StatefulPTF.ValueState initialValue = new StatefulPTF.ValueState();
+ initialValue.count = 100L;
+
+ // Row state
+ Row initialRow = Row.withNames();
+ initialRow.setField("lastValue", 42);
+
+ // ListView state
+ ListView<Integer> initialList = new ListView<>();
+ initialList.add(10);
+ initialList.add(20);
+
+ // MapView state
+ MapView<String, Integer> initialMap = new MapView<>();
+ initialMap.put("Alice", 5);
+
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ // Initial state is set per partition key
+ .withInitialStateForKey("valueState", Row.of("Alice"), initialValue)
+ .withInitialStateForKey("rowState", Row.of("Alice"), initialRow)
+ .withInitialStateForKey("listState", Row.of("Alice"), initialList)
+ .withInitialStateForKey("mapState", Row.of("Alice"), initialMap)
+ .build()) {
+
+ harness.processElement(Row.of("Alice", 10));
+
+ List<Row> output = harness.getOutput();
+ assertThat(output).containsExactly(Row.of("Alice", 101L));
+ }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
+
+**State Introspection**: Use `getStateForKey()`, `getKeysForState()`, and
`getStateForAllKeys()` to inspect state during tests:
+
+{{< tabs "state-introspection" >}}
+{{< tab "Java" >}}
+```java
+@Test
+void testStateIntrospection() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElement(Row.of("Alice", 10));
+ harness.processElement(Row.of("Bob", 20));
+
+ // Check value state
+ StatefulPTF.ValueState aliceState =
+ harness.getStateForKey("valueState", Row.of("Alice"));
+ assertThat(aliceState.count).isEqualTo(1L);
+
+ // Check Row state
+ Row aliceRowState = harness.getStateForKey("rowState", Row.of("Alice"));
+ assertThat(aliceRowState.getField("lastValue")).isEqualTo(10);
+
+ // Check ListView state
+ ListView<Integer> aliceList = harness.getStateForKey("listState",
Row.of("Alice"));
+ assertThat(aliceList.getList()).containsExactly(10);
+
+ // Check MapView state
+ MapView<String, Integer> aliceMap = harness.getStateForKey("mapState",
Row.of("Alice"));
+ assertThat(aliceMap.get("Alice")).isEqualTo(1);
+
+ // Get all partition keys with state
+ Set<Row> keys = harness.getKeysForState("valueState");
+ assertThat(keys).containsExactlyInAnyOrder(Row.of("Alice"), Row.of("Bob"));
+
+ // Get all state across partition keys
+ Map<Row, StatefulPTF.ValueState> allState =
+ harness.getStateForAllKeys("valueState");
+ assertThat(allState.get(Row.of("Bob")).count).isEqualTo(1L);
+ }
+}
+```
+{{< /tab >}}
+{{< /tabs >}}
+
+**State Mutation**: Use `setStateForKey()`, `clearStateForKey()`, and
`clearStateEntryForKey()` to modify state during tests:
+
+{{< tabs "state-mutation" >}}
+{{< tab "Java" >}}
+```java
+@Test
+void testStateMutation() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(StatefulPTF.class)
+ .withTableArgument("input", DataTypes.of("ROW<name STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build()) {
+
+ harness.processElement(Row.of("Alice", 10));
+
+ // Overwrite a specific state entry for a partition key
+ StatefulPTF.ValueState newState = new StatefulPTF.ValueState();
+ newState.count = 100L;
+ harness.setStateForKey("valueState", Row.of("Alice"), newState);
+
Review Comment:
Add an `assertThat` to check that the state was actually updated (or process
another element and check the output to see the effect of `setStateForKey()`)?
##########
flink-table/flink-table-test-utils/src/test/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarnessTest.java:
##########
@@ -680,6 +699,56 @@ void testOptionalPartitionByWithPartition() throws
Exception {
}
}
+ @Test
+ void testOptionalPartitionByWithStateNoPartition() throws Exception {
+ try (ProcessTableFunctionTestHarness<Row> harness =
+
ProcessTableFunctionTestHarness.ofClass(StatefulOptionalPartitionPTF.class)
+ .withTableArgument("input", DataTypes.of("ROW<key
STRING, value INT>"))
+ .build()) {
+
+ harness.processElement(Row.of("A", 10));
+ harness.processElement(Row.of("B", 20));
+ harness.processElement(Row.of("A", 30));
+
+ List<Row> output = harness.getOutput();
+ assertThat(output).hasSize(3);
+ assertThat(output.get(0)).isEqualTo(Row.of(1L));
+ assertThat(output.get(1)).isEqualTo(Row.of(2L));
+ assertThat(output.get(2)).isEqualTo(Row.of(3L));
+
+ StatefulOptionalPartitionPTF.CounterState state =
+ harness.getStateForKey("state", Row.of());
+ assertThat(state.counter).isEqualTo(3L);
Review Comment:
What are the expected semantics here?
Is the function executed with parallelism = 1 and a default key (all data
goes through the same function instance)?
##########
flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarness.java:
##########
@@ -270,43 +273,69 @@ public void clearOutput() {
output.clear();
}
- /**
- * Given a target table argument and a row to process, construct the right
set of arguments for
- * the PTF's eval function and attempt to invoke it.
- */
- private void invokeEval(ArgumentInfo activeTableArg, Row activeRow) throws
Exception {
- // Set collector context so it can prepend columns if needed
- collector.setContext(activeTableArg, activeRow);
+ /** Get state for a specific partition key. */
+ public <T> T getStateForKey(String stateName, Row partitionKey) {
+ return stateManager.getStateForKey(stateName, partitionKey);
+ }
- Object[] args = new Object[arguments.size()];
+ /** Set state for a specific partition key. */
+ public void setStateForKey(String stateName, Row partitionKey, Object
state) throws Exception {
+ stateManager.setStateForKey(stateName, partitionKey, state);
+ }
- for (int i = 0; i < arguments.size(); i++) {
- ArgumentInfo arg = arguments.get(i);
+ /** Get all partition keys that have a specific state entry. */
+ public Set<Row> getKeysForState(String stateName) {
+ return stateManager.getKeysForState(stateName);
+ }
+
+ /** Get all state values for a state name across all partition keys. */
+ public <T> Map<Row, T> getStateForAllKeys(String stateName) {
+ return stateManager.getStateForAllKeys(stateName);
+ }
+
+ /** Clear all state for a given partition key. */
+ public void clearStateForKey(Row partitionKey) {
+ stateManager.clearStateForKey(partitionKey);
+ }
- if (arg.isTableArgument && arg.name.equals(activeTableArg.name)) {
- // If the argument is the active table argument, first convert
the input row
- // to an internal RowData type, and then convert the RowData
to type that the
- // argument expects. For Rows, this will structure the Row
based on the table
- // argument structure. Otherwise, for POJOs, it will pass the
expected POJO to eval.
+ /** Clear specific state entry for a given partition key. */
+ public void clearStateEntryForKey(String stateName, Row partitionKey) {
Review Comment:
```suggestion
public void clearStateForKey(String stateName, Row partitionKey) {
```
the other state methods don't have `Entry` in their name.
This method name is more consistent with the other methods
(`getStateForKey()`, `setStateForKey()`) which also address one state for one
key.
##########
docs/content.zh/docs/dev/table/functions/ptfs.md:
##########
@@ -2275,6 +2275,189 @@ void testScalarOnly() throws Exception {
{{< /tab >}}
{{< /tabs >}}
+#### Testing with State
+
+The harness supports all PTF state types: value state, `Row`, `ListView`, and
`MapView`.
+
+{{< tabs "state-testing" >}}
+{{< tab "Java" >}}
+```java
+// A PTF that uses all four state types: value state, Row, ListView, and
MapView.
Review Comment:
```suggestion
// A PTF that uses all four state types: Pojo value state, Row value state,
ListView state, and MapView state.
```
In the Flink docs we typically refer to these custom classes for data
transport or state as Pojos.
##########
flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessStateManager.java:
##########
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.functions;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.types.Row;
+
+import javax.annotation.Nullable;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * State manager for {@link ProcessTableFunctionTestHarness}.
+ *
+ * <p>Handles state storage, lifecycle, and conversion between external and
internal storage
+ * formats.
+ */
+@Internal
+class TestHarnessStateManager {
+
+ private final Map<Row, Map<String, Object>> stateByKey = new HashMap<>();
+ private final List<ProcessTableFunctionTestHarness.StateArgumentInfo>
stateArguments;
+ private final Map<String, StateConverter> stateConverters;
+ private final PartitionKeyInfo partitionKeyInfo;
+
+ TestHarnessStateManager(
+ List<ProcessTableFunctionTestHarness.StateArgumentInfo>
stateArguments,
+ Map<String, StateConverter> stateConverters,
+ PartitionKeyInfo partitionKeyInfo) {
+ this.stateArguments = stateArguments;
+ this.stateConverters = stateConverters;
+ this.partitionKeyInfo = partitionKeyInfo;
+ }
+
+ static class PartitionKeyInfo {
+ final int arity;
+ @Nullable final String[] columnNames;
+ @Nullable final Class<?>[] columnTypes;
+
+ PartitionKeyInfo(
+ int arity,
+ @Nullable String[] columnNames,
+ @Nullable LogicalType[] columnLogicalTypes) {
+ this.arity = arity;
+ this.columnNames = columnNames;
+ this.columnTypes =
+ columnLogicalTypes != null
+ ? Arrays.stream(columnLogicalTypes)
+ .map(LogicalType::getDefaultConversion)
+ .toArray(Class<?>[]::new)
+ : null;
+ }
+
+ void validate(Row key) {
+ if (key.getArity() != arity) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Partition key has arity %d, but expected
arity %d.",
+ key.getArity(), arity));
+ }
+ if (columnTypes == null) {
+ return;
+ }
+ for (int i = 0; i < arity; i++) {
+ Object value = key.getField(i);
+ if (value != null && !columnTypes[i].isInstance(value)) {
+ String columnName = columnNames != null ? columnNames[i] :
"position " + i;
+ throw new IllegalArgumentException(
+ String.format(
+ "Partition key has type %s at position %d,
"
+ + "but partition column '%s'
expects %s.",
+ value.getClass().getSimpleName(),
+ i,
+ columnName,
+ columnTypes[i].getSimpleName()));
+ }
+ }
+ }
+ }
+
+ /**
+ * Load state for a partition key. Creates new state instances if none
exist. Converts internal
+ * storage to external objects (value state, ListView, MapView).
+ */
+ Map<String, Object> loadStateForKey(Row key) {
+ Map<String, Object> internalState =
+ stateByKey.computeIfAbsent(key, k -> createEmptyKeyState());
+
+ Map<String, Object> externalState = new HashMap<>();
+ for (ProcessTableFunctionTestHarness.StateArgumentInfo stateArg :
stateArguments) {
+ Object internalData = internalState.get(stateArg.name);
+ Object external = convertToExternal(internalData, stateArg);
+ externalState.put(stateArg.name, external);
+ }
+ return externalState;
+ }
+
+ /**
+ * Update mutated state after eval() invocation. Converts external objects
to internal format.
+ */
+ void updateStateForKey(Row key, Map<String, Object> externalState) throws
Exception {
+ Map<String, Object> internalState = new HashMap<>();
+ for (ProcessTableFunctionTestHarness.StateArgumentInfo stateArg :
stateArguments) {
+ Object external = externalState.get(stateArg.name);
+ Object internalData = convertToInternal(external, stateArg);
+ internalState.put(stateArg.name, internalData);
+ }
+ stateByKey.put(key, internalState);
+ }
+
+ /** Clear all state for a partition key. */
+ void clearStateForKey(Row key) {
+ partitionKeyInfo.validate(key);
+ stateByKey.remove(key);
+ }
+
+ /** Clear specific state entry for a given partition key, resetting it to
its default value. */
+ void clearStateEntryForKey(String stateName, Row key) {
Review Comment:
```suggestion
void clearStateForKey(String stateName, Row key) {
```
##########
flink-table/flink-table-test-utils/src/main/java/org/apache/flink/table/runtime/functions/TestHarnessStateManager.java:
##########
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.functions;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.types.Row;
+
+import javax.annotation.Nullable;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * State manager for {@link ProcessTableFunctionTestHarness}.
+ *
+ * <p>Handles state storage, lifecycle, and conversion between external and
internal storage
+ * formats.
+ */
+@Internal
+class TestHarnessStateManager {
+
+ private final Map<Row, Map<String, Object>> stateByKey = new HashMap<>();
+ private final List<ProcessTableFunctionTestHarness.StateArgumentInfo>
stateArguments;
+ private final Map<String, StateConverter> stateConverters;
+ private final PartitionKeyInfo partitionKeyInfo;
+
+ TestHarnessStateManager(
+ List<ProcessTableFunctionTestHarness.StateArgumentInfo>
stateArguments,
+ Map<String, StateConverter> stateConverters,
+ PartitionKeyInfo partitionKeyInfo) {
+ this.stateArguments = stateArguments;
+ this.stateConverters = stateConverters;
+ this.partitionKeyInfo = partitionKeyInfo;
+ }
+
+ static class PartitionKeyInfo {
+ final int arity;
+ @Nullable final String[] columnNames;
+ @Nullable final Class<?>[] columnTypes;
+
+ PartitionKeyInfo(
+ int arity,
+ @Nullable String[] columnNames,
+ @Nullable LogicalType[] columnLogicalTypes) {
+ this.arity = arity;
+ this.columnNames = columnNames;
+ this.columnTypes =
+ columnLogicalTypes != null
+ ? Arrays.stream(columnLogicalTypes)
+ .map(LogicalType::getDefaultConversion)
+ .toArray(Class<?>[]::new)
+ : null;
+ }
+
+ void validate(Row key) {
+ if (key.getArity() != arity) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Partition key has arity %d, but expected
arity %d.",
+ key.getArity(), arity));
+ }
+ if (columnTypes == null) {
+ return;
+ }
+ for (int i = 0; i < arity; i++) {
+ Object value = key.getField(i);
+ if (value != null && !columnTypes[i].isInstance(value)) {
+ String columnName = columnNames != null ? columnNames[i] :
"position " + i;
+ throw new IllegalArgumentException(
+ String.format(
+ "Partition key has type %s at position %d,
"
+ + "but partition column '%s'
expects %s.",
+ value.getClass().getSimpleName(),
+ i,
+ columnName,
+ columnTypes[i].getSimpleName()));
+ }
+ }
+ }
+ }
+
+ /**
+ * Load state for a partition key. Creates new state instances if none
exist. Converts internal
+ * storage to external objects (value state, ListView, MapView).
+ */
+ Map<String, Object> loadStateForKey(Row key) {
+ Map<String, Object> internalState =
+ stateByKey.computeIfAbsent(key, k -> createEmptyKeyState());
+
+ Map<String, Object> externalState = new HashMap<>();
+ for (ProcessTableFunctionTestHarness.StateArgumentInfo stateArg :
stateArguments) {
+ Object internalData = internalState.get(stateArg.name);
+ Object external = convertToExternal(internalData, stateArg);
+ externalState.put(stateArg.name, external);
+ }
+ return externalState;
+ }
+
+ /**
+ * Update mutated state after eval() invocation. Converts external objects
to internal format.
+ */
+ void updateStateForKey(Row key, Map<String, Object> externalState) throws
Exception {
+ Map<String, Object> internalState = new HashMap<>();
+ for (ProcessTableFunctionTestHarness.StateArgumentInfo stateArg :
stateArguments) {
+ Object external = externalState.get(stateArg.name);
+ Object internalData = convertToInternal(external, stateArg);
+ internalState.put(stateArg.name, internalData);
+ }
+ stateByKey.put(key, internalState);
+ }
+
+ /** Clear all state for a partition key. */
+ void clearStateForKey(Row key) {
Review Comment:
```suggestion
void clearAllStatesForKey(Row key) {
```
##########
flink-table/flink-table-test-utils/src/test/java/org/apache/flink/table/runtime/functions/ProcessTableFunctionTestHarnessTest.java:
##########
@@ -1062,4 +1110,309 @@ void testPartitionByDuplicateConfigThrows() {
assertThat(exception.getMessage()).contains("Partition config already
exists");
}
+
+ //
-------------------------------------------------------------------------
+ // State Tests
+ //
-------------------------------------------------------------------------
+
+ @Test
+ void testPojoState() throws Exception {
+ ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithPojoState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name
STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build();
+
+ harness.processElementForTable("input", Row.of("Alice", 10));
+ assertThat(harness.getOutput()).containsExactly(Row.of("Alice", 1L));
+
+ PTFWithPojoState.CounterState state = harness.getStateForKey("state",
Row.of("Alice"));
+ assertThat(state.counter).isEqualTo(1L);
+
+ harness.processElementForTable("input", Row.of("Alice", 15));
+ assertThat(harness.getOutput().get(1)).isEqualTo(Row.of("Alice", 2L));
+
+ state = harness.getStateForKey("state", Row.of("Alice"));
+ assertThat(state.counter).isEqualTo(2L);
+
+ harness.close();
+ }
+
+ @Test
+ void testPojoStatePartitionIsolation() throws Exception {
+ ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithPojoState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name
STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build();
+
+ harness.processElementForTable("input", Row.of("Alice", 10));
+ harness.processElementForTable("input", Row.of("Bob", 20));
+ harness.processElementForTable("input", Row.of("Alice", 15));
+
+ PTFWithPojoState.CounterState aliceState =
harness.getStateForKey("state", Row.of("Alice"));
+ PTFWithPojoState.CounterState bobState =
harness.getStateForKey("state", Row.of("Bob"));
+
+ assertThat(aliceState.counter).isEqualTo(2L);
+ assertThat(bobState.counter).isEqualTo(1L);
+
+ harness.close();
+ }
+
+ @Test
+ void testPojoStateWithInitialState() throws Exception {
+ PTFWithPojoState.CounterState initialState = new
PTFWithPojoState.CounterState();
+ initialState.counter = 100L;
+
+ ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithPojoState.class)
+ .withTableArgument("input", DataTypes.of("ROW<id
INT>"))
+ .withPartitionBy("input", "id")
+ .withInitialStateArgument("state", Row.of(1),
initialState)
+ .build();
+
+ PTFWithPojoState.CounterState state = harness.getStateForKey("state",
Row.of(1));
+ assertThat(state.counter).isEqualTo(100L);
+
+ harness.processElement(Row.of(1));
+ assertThat(harness.getOutput()).containsExactly(Row.of(1, 101L));
+
+ harness.processElement(Row.of(2));
+ assertThat(harness.getOutput().get(1)).isEqualTo(Row.of(2, 1L));
+
+ harness.close();
+ }
+
+ @Test
+ void testGetStateKeys() throws Exception {
+ ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithPojoState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name
STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build();
+
+ harness.processElementForTable("input", Row.of("Alice", 10));
+ harness.processElementForTable("input", Row.of("Bob", 20));
+ harness.processElementForTable("input", Row.of("Charlie", 30));
+
+ java.util.Set<Row> keys = harness.getStateKeys("state");
+ assertThat(keys)
+ .containsExactlyInAnyOrder(Row.of("Alice"), Row.of("Bob"),
Row.of("Charlie"));
+
+ harness.close();
+ }
+
+ @Test
+ void testGetAllState() throws Exception {
+ ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithPojoState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name
STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build();
+
+ harness.processElementForTable("input", Row.of("Alice", 10));
+ harness.processElementForTable("input", Row.of("Alice", 15));
+ harness.processElementForTable("input", Row.of("Bob", 20));
+
+ java.util.Map<Row, PTFWithPojoState.CounterState> allState =
harness.getAllState("state");
+
+ assertThat(allState).hasSize(2);
+ assertThat(allState.get(Row.of("Alice")).counter).isEqualTo(2L);
+ assertThat(allState.get(Row.of("Bob")).counter).isEqualTo(1L);
+
+ harness.close();
+ }
+
+ @Test
+ void testListViewState() throws Exception {
+ ProcessTableFunctionTestHarness<Row> harness =
+
ProcessTableFunctionTestHarness.ofClass(PTFWithListViewState.class)
+ .withTableArgument("input", DataTypes.of("ROW<key
STRING, value INT>"))
+ .withPartitionBy("input", "key")
+ .build();
+
+ harness.processElementForTable("input", Row.of("A", 1));
+ assertThat(harness.getOutput()).containsExactly(Row.of("A", new
Integer[] {1}));
+
+ harness.processElementForTable("input", Row.of("A", 2));
+ assertThat(harness.getOutput().get(1)).isEqualTo(Row.of("A", new
Integer[] {1, 2}));
+
+ org.apache.flink.table.api.dataview.ListView<Integer> listState =
+ harness.getStateForKey("listState", Row.of("A"));
+ assertThat(listState.get()).containsExactly(1, 2);
+
+ harness.close();
+ }
+
+ @Test
+ void testMapViewState() throws Exception {
+ ProcessTableFunctionTestHarness<Row> harness =
+
ProcessTableFunctionTestHarness.ofClass(PTFWithMapViewState.class)
+ .withTableArgument(
+ "input", DataTypes.of("ROW<partition STRING,
key STRING>"))
+ .withPartitionBy("input", "partition")
+ .build();
+
+ harness.processElementForTable("input", Row.of("P1", "foo"));
+ assertThat(harness.getOutput()).containsExactly(Row.of("P1", "foo",
1));
+
+ harness.processElementForTable("input", Row.of("P1", "foo"));
+ assertThat(harness.getOutput().get(1)).isEqualTo(Row.of("P1", "foo",
2));
+
+ harness.processElementForTable("input", Row.of("P1", "bar"));
+
+ org.apache.flink.table.api.dataview.MapView<String, Integer> mapState =
+ harness.getStateForKey("mapState", Row.of("P1"));
+ assertThat(mapState.get("foo")).isEqualTo(2);
+ assertThat(mapState.get("bar")).isEqualTo(1);
+
+ harness.close();
+ }
+
+ @Test
+ void testEmptyState() throws Exception {
+ ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithPojoState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name
STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build();
+
+ PTFWithPojoState.CounterState state = harness.getStateForKey("state",
Row.of("Alice"));
+
+ assertThat(state).isNull();
+
+ harness.close();
+ }
+
+ @Test
+ void testClearStateForPartition() throws Exception {
+ ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithPojoState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name
STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build();
+
+ harness.processElementForTable("input", Row.of("Alice", 10));
+ harness.processElementForTable("input", Row.of("Alice", 15));
+
+ PTFWithPojoState.CounterState state = harness.getStateForKey("state",
Row.of("Alice"));
+ assertThat(state.counter).isEqualTo(2L);
+
+ harness.clearStateForPartition(Row.of("Alice"));
+
+ state = harness.getStateForKey("state", Row.of("Alice"));
+ assertThat(state).isNull();
+
+ harness.close();
+ }
+
+ @Test
+ void testClearStateEntry() throws Exception {
+ ProcessTableFunctionTestHarness<Row> harness =
+ ProcessTableFunctionTestHarness.ofClass(PTFWithPojoState.class)
+ .withTableArgument("input", DataTypes.of("ROW<name
STRING, value INT>"))
+ .withPartitionBy("input", "name")
+ .build();
+
+ harness.processElementForTable("input", Row.of("Alice", 10));
+ harness.processElementForTable("input", Row.of("Alice", 15));
+
+ PTFWithPojoState.CounterState state = harness.getStateForKey("state",
Row.of("Alice"));
+ assertThat(state.counter).isEqualTo(2L);
+
+ harness.clearStateEntry(Row.of("Alice"), "state");
+
+ state = harness.getStateForKey("state", Row.of("Alice"));
+ assertThat(state.counter).isEqualTo(0L);
+
+ harness.close();
+ }
+
+ @Test
+ void testMultipleStateParameters() throws Exception {
+ ProcessTableFunctionTestHarness<Row> harness =
+
ProcessTableFunctionTestHarness.ofClass(PTFWithMultipleStates.class)
+ .withTableArgument("input", DataTypes.of("ROW<key
STRING, value INT>"))
+ .withPartitionBy("input", "key")
+ .build();
+
+ harness.processElementForTable("input", Row.of("A", 10));
+ harness.processElementForTable("input", Row.of("A", 20));
+ harness.processElementForTable("input", Row.of("B", 5));
+
+ assertThat(harness.getOutput())
+ .containsExactly(Row.of("A", 1L, 10), Row.of("A", 2L, 30),
Row.of("B", 1L, 5));
+
+ PTFWithMultipleStates.CounterState counterA =
+ harness.getStateForKey("counter", Row.of("A"));
+ assertThat(counterA.count).isEqualTo(2L);
+
+ ListView<Integer> historyA = harness.getStateForKey("history",
Row.of("A"));
+ assertThat(historyA.get()).containsExactly(10, 20);
+
+ harness.close();
+ }
+
+ @Test
+ void testInitialStateWithListView() throws Exception {
+ ListView<Integer> initialList = new ListView<>();
+ initialList.add(100);
+ initialList.add(200);
+
+ ProcessTableFunctionTestHarness<Row> harness =
+
ProcessTableFunctionTestHarness.ofClass(PTFWithListViewState.class)
+ .withTableArgument("input", DataTypes.of("ROW<key
STRING, value INT>"))
+ .withPartitionBy("input", "key")
+ .withInitialStateArgument("listState", Row.of("A"),
initialList)
+ .build();
+
+ ListView<Integer> listState = harness.getStateForKey("listState",
Row.of("A"));
+ assertThat(listState.get()).containsExactly(100, 200);
+
+ harness.processElementForTable("input", Row.of("A", 3));
+ assertThat(harness.getOutput()).containsExactly(Row.of("A", new
Integer[] {100, 200, 3}));
+
+ harness.close();
+ }
+
+ @Test
+ void testInitialStateWithMapView() throws Exception {
+ MapView<String, Integer> initialMap = new MapView<>();
+ initialMap.put("existing", 42);
+
+ ProcessTableFunctionTestHarness<Row> harness =
+
ProcessTableFunctionTestHarness.ofClass(PTFWithMapViewState.class)
+ .withTableArgument(
+ "input", DataTypes.of("ROW<partition STRING,
key STRING>"))
+ .withPartitionBy("input", "partition")
+ .withInitialStateArgument("mapState", Row.of("P1"),
initialMap)
+ .build();
+
+ MapView<String, Integer> mapState = harness.getStateForKey("mapState",
Row.of("P1"));
+ assertThat(mapState.get("existing")).isEqualTo(42);
+
+ harness.processElementForTable("input", Row.of("P1", "existing"));
+ assertThat(harness.getOutput()).containsExactly(Row.of("P1",
"existing", 43));
+
+ harness.close();
+ }
+
Review Comment:
Sounds good, thank you!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]