This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
The following commit(s) were added to refs/heads/main by this push: new 1369252 [Feature] [api] Introduce per action state consistency API (#138) 1369252 is described below commit 136925237498cdac3ac025a9adef8143ae301e91 Author: Letao Jiang <let...@users.noreply.github.com> AuthorDate: Sat Sep 20 19:30:06 2025 -0700 [Feature] [api] Introduce per action state consistency API (#138) --- .../api/configuration/AgentConfigOptions.java | 4 + .../flink/agents/api/context/MemoryUpdate.java | 75 ++++++ .../agents/runtime/actionstate/ActionState.java | 93 +++++++ .../runtime/actionstate/ActionStateStore.java | 92 +++++++ .../runtime/actionstate/ActionStateUtil.java | 90 +++++++ .../runtime/actionstate/KafkaActionStateStore.java | 70 +++++ .../agents/runtime/context/RunnerContextImpl.java | 19 +- .../agents/runtime/memory/MemoryObjectImpl.java | 28 +- .../runtime/operator/ActionExecutionOperator.java | 219 +++++++++++++++- .../operator/ActionExecutionOperatorFactory.java | 14 +- .../runtime/actionstate/ActionStateUtilTest.java | 227 ++++++++++++++++ .../actionstate/InMemoryActionStateStore.java | 77 ++++++ .../agents/runtime/memory/MemoryObjectTest.java | 27 +- .../flink/agents/runtime/memory/MemoryRefTest.java | 2 +- .../operator/ActionExecutionOperatorTest.java | 289 ++++++++++++++++++++- 15 files changed, 1304 insertions(+), 22 deletions(-) diff --git a/api/src/main/java/org/apache/flink/agents/api/configuration/AgentConfigOptions.java b/api/src/main/java/org/apache/flink/agents/api/configuration/AgentConfigOptions.java index 49e476d..814d91c 100644 --- a/api/src/main/java/org/apache/flink/agents/api/configuration/AgentConfigOptions.java +++ b/api/src/main/java/org/apache/flink/agents/api/configuration/AgentConfigOptions.java @@ -23,4 +23,8 @@ public class AgentConfigOptions { /** The config parameter specifies the directory for the FileEvent file. 
*/ public static final ConfigOption<String> BASE_LOG_DIR = new ConfigOption<>("baseLogDir", String.class, null); + + /** The config parameter specifies the backend for action state store. */ + public static final ConfigOption<String> ACTION_STATE_STORE_BACKEND = + new ConfigOption<>("actionStateStoreBackend", String.class, null); } diff --git a/api/src/main/java/org/apache/flink/agents/api/context/MemoryUpdate.java b/api/src/main/java/org/apache/flink/agents/api/context/MemoryUpdate.java new file mode 100644 index 0000000..7169fc2 --- /dev/null +++ b/api/src/main/java/org/apache/flink/agents/api/context/MemoryUpdate.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.api.context; + +import java.io.Serializable; +import java.util.Objects; + +public class MemoryUpdate implements Serializable { + private static final long serialVersionUID = 1L; + + private final String path; + private final Object value; + + /** + * Creates a new MemoryUpdate instance. + * + * @param path the absolute path of the data in Short-Term Memory. + * @param value the new value to set at the specified path. 
+ */ + public MemoryUpdate(String path, Object value) { + this.path = path; + this.value = value; + } + + /** + * Gets the path of the memory update. + * + * @return the absolute path of the data in Short-Term Memory. + */ + public String getPath() { + return path; + } + + /** + * Gets the value of the memory update. + * + * @return the new value to set at the specified path. + */ + public Object getValue() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof MemoryUpdate)) return false; + MemoryUpdate that = (MemoryUpdate) o; + return Objects.equals(path, that.path) && Objects.equals(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hash(path, value); + } + + @Override + public String toString() { + return "MemoryUpdate{" + "path='" + path + '\'' + ", value=" + value + '}'; + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java new file mode 100644 index 0000000..258fb9a --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionState.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.actionstate; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.context.MemoryUpdate; +import org.apache.flink.agents.runtime.operator.ActionTask; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +public class ActionState { + private final Event taskEvent; + private final List<MemoryUpdate> memoryUpdates; + private final List<Event> outputEvents; + private ActionTask generatedActionTask; + + /** Constructs a new TaskActionState instance. */ + public ActionState(final Event taskEvent) { + this.taskEvent = taskEvent; + memoryUpdates = new ArrayList<>(); + outputEvents = new ArrayList<>(); + } + + /** Getters for the fields */ + public Event getTaskEvent() { + return taskEvent; + } + + public List<MemoryUpdate> getMemoryUpdates() { + return memoryUpdates; + } + + public List<Event> getOutputEvents() { + return outputEvents; + } + + public Optional<ActionTask> getGeneratedActionTask() { + return Optional.ofNullable(generatedActionTask); + } + + /** Setters for the fields */ + public void addMemoryUpdate(MemoryUpdate memoryUpdate) { + memoryUpdates.add(memoryUpdate); + } + + public void addEvent(Event event) { + outputEvents.add(event); + } + + public void setGeneratedActionTask(ActionTask generatedActionTask) { + this.generatedActionTask = generatedActionTask; + } + + @Override + public int hashCode() { + int result = taskEvent != null ? taskEvent.hashCode() : 0; + result = 31 * result + (memoryUpdates != null ? memoryUpdates.hashCode() : 0); + result = 31 * result + (outputEvents != null ? outputEvents.hashCode() : 0); + result = 31 * result + (generatedActionTask != null ? 
generatedActionTask.hashCode() : 0); + return result; + } + + @Override + public String toString() { + return "TaskActionState{" + + "taskEvent=" + + taskEvent + + ", memoryUpdates=" + + memoryUpdates + + ", outputEvents=" + + outputEvents + + ", generatedActionTask=" + + generatedActionTask + + '}'; + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionStateStore.java b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionStateStore.java new file mode 100644 index 0000000..e19450d --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionStateStore.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.actionstate; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.plan.Action; + +import java.io.IOException; +import java.util.List; + +/** Interface for storing and retrieving the state of actions performed by agents. 
*/ +public interface ActionStateStore { + enum BackendType { + KAFKA("kafka"); + + private final String type; + + BackendType(String type) { + this.type = type; + } + + public String getType() { + return type; + } + } + + /** + * Store the state of a specific action associated with a given key to the backend storage. + * + * @param key the key associate with the message + * @param seqNum the sequence number of the key + * @param action the action the agent is taking + * @param event the event that triggered the action + * @param state the current state of the whole task + * @throws IOException when key generation failed + */ + void put(Object key, long seqNum, Action action, Event event, ActionState state) + throws IOException; + + /** + * Retrieve the state of a specific action associated with a given key from the backend storage. + * It any of the sequence number for a key can't be found, all the states associated with the + * key after the sequence number should be ignored and null will be returned. + * + * @param key the key associated with the message + * @param seqNum the sequence number of the key + * @param action the action the agent is taking + * @param event the event that triggered the action + * @return the state of the action, or null if not found + * @throws IOException when key generation failed + */ + ActionState get(Object key, long seqNum, Action action, Event event) throws IOException; + + /** + * Rebuild the in-memory state from the backend storage using the provided recovery markers. + * + * @param recoveryMarkers a list of markers representing the recovery points + */ + void rebuildState(List<Object> recoveryMarkers); + + /** + * Prune the state for a given key. + * + * @param key the key whose state should be pruned + * @param seqNum the sequence number up to which the state should be pruned + */ + void pruneState(Object key, long seqNum); + + /** + * Get a marker object representing the current recovery point in the state store. 
+ * + * @return a marker object, or null if not supported + */ + default Object getRecoveryMarker() { + return null; + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionStateUtil.java b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionStateUtil.java new file mode 100644 index 0000000..daed7c3 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/ActionStateUtil.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.actionstate; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.plan.Action; +import org.apache.flink.agents.runtime.python.event.PythonEvent; +import org.apache.flink.shaded.guava31.com.google.common.base.Preconditions; +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; + +import javax.annotation.Nonnull; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.UUID; + +/** Utility class for action state related operations. 
*/ +public class ActionStateUtil { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private static final String KEY_SEPARATOR = "_"; + + public static String generateKey( + @Nonnull Object key, long seqNum, @Nonnull Action action, @Nonnull Event event) + throws IOException { + Preconditions.checkNotNull(key, "key cannot be null."); + Preconditions.checkNotNull(action, "action cannot be null."); + Preconditions.checkNotNull(event, "event cannot be null."); + return String.join( + KEY_SEPARATOR, + key.toString(), + String.valueOf(seqNum), + generateUUIDForEvent(event), + generateUUIDForAction(action)); + } + + public static List<String> parseKey(String key) { + Preconditions.checkNotNull(key, "key cannot be null."); + String[] parts = key.split(KEY_SEPARATOR); + Preconditions.checkArgument(parts.length == 4, "Invalid key format."); + return List.of(parts); + } + + private static String generateUUIDForEvent(Event event) throws IOException { + if (event instanceof InputEvent) { + InputEvent inputEvent = (InputEvent) event; + byte[] inputEventBytes = + MAPPER.writeValueAsBytes( + new Object[] {inputEvent.getInput(), inputEvent.getAttributes()}); + return String.valueOf(UUID.nameUUIDFromBytes(inputEventBytes)); + } else if (event instanceof PythonEvent) { + PythonEvent pythonEvent = (PythonEvent) event; + byte[] pythonEventBytes = + MAPPER.writeValueAsBytes( + new Object[] { + pythonEvent.getEvent(), + pythonEvent.getEventType(), + pythonEvent.getAttributes() + }); + return String.valueOf(UUID.nameUUIDFromBytes(pythonEventBytes)); + } else { + return String.valueOf( + UUID.nameUUIDFromBytes( + event.getAttributes().toString().getBytes(StandardCharsets.UTF_8))); + } + } + + private static String generateUUIDForAction(Action action) throws IOException { + return String.valueOf( + UUID.nameUUIDFromBytes( + String.valueOf(action.hashCode()).getBytes(StandardCharsets.UTF_8))); + } +} diff --git 
a/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/KafkaActionStateStore.java b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/KafkaActionStateStore.java new file mode 100644 index 0000000..6b6ced8 --- /dev/null +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/KafkaActionStateStore.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.actionstate; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.plan.Action; +import org.apache.flink.annotation.VisibleForTesting; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * An implementation of ActionStateStore that uses Kafka as the backend storage for action states. + * This class provides methods to put, get, and retrieve all action states associated with a given + * key and action. 
+ */ +public class KafkaActionStateStore implements ActionStateStore { + + // In memory action state for quick state retrival, this map is only used during recovery + private final Map<String, Map<String, ActionState>> keyedActionStates; + + @VisibleForTesting + KafkaActionStateStore(Map<String, Map<String, ActionState>> keyedActionStates) { + this.keyedActionStates = keyedActionStates; + } + + /** Constructs a new KafkaActionStateStore with an empty in-memory action state map. */ + public KafkaActionStateStore() { + this(new HashMap<>()); + } + + @Override + public void put(Object key, long seqNum, Action action, Event event, ActionState state) + throws IOException { + // TODO: Implement me + } + + @Override + public ActionState get(Object key, long seqNum, Action action, Event event) throws IOException { + // TODO: Implement me + return null; + } + + @Override + public void rebuildState(List<Object> recoveryMarker) { + // TODO: implement me + } + + @Override + public void pruneState(Object key, long seqNum) { + // TODO: implement me + } +} diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java index 53e8509..653c5cc 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/context/RunnerContextImpl.java @@ -20,6 +20,7 @@ package org.apache.flink.agents.runtime.context; import org.apache.flink.agents.api.Event; import org.apache.flink.agents.api.configuration.ReadableConfiguration; import org.apache.flink.agents.api.context.MemoryObject; +import org.apache.flink.agents.api.context.MemoryUpdate; import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceType; @@ -32,6 +33,7 @@ import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessin import org.apache.flink.util.Preconditions; import java.util.ArrayList; +import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -46,6 +48,7 @@ public class RunnerContextImpl implements RunnerContext { protected final FlinkAgentsMetricGroupImpl agentMetricGroup; protected final Runnable mailboxThreadChecker; protected final AgentPlan agentPlan; + protected final List<MemoryUpdate> memoryUpdates; protected String actionName; public RunnerContextImpl( @@ -57,6 +60,7 @@ public class RunnerContextImpl implements RunnerContext { this.agentMetricGroup = agentMetricGroup; this.mailboxThreadChecker = mailboxThreadChecker; this.agentPlan = agentPlan; + this.memoryUpdates = new LinkedList<>(); } public void setActionName(String actionName) { @@ -98,10 +102,23 @@ public class RunnerContextImpl implements RunnerContext { this.pendingEvents.isEmpty(), "There are pending events remaining in the context."); } + /** + * Gets all the updates made to this MemoryObject since it was created or the last time this + * method was called. This method lives here because it is internally used by the ActionTask to + * persist memory updates after an action is executed. 
+ * + * @return list of memory updates + */ + public List<MemoryUpdate> getAllMemoryUpdates() { + mailboxThreadChecker.run(); + return List.copyOf(memoryUpdates); + } + @Override public MemoryObject getShortTermMemory() throws Exception { mailboxThreadChecker.run(); - return new MemoryObjectImpl(store, MemoryObjectImpl.ROOT_KEY, mailboxThreadChecker); + return new MemoryObjectImpl( + store, MemoryObjectImpl.ROOT_KEY, mailboxThreadChecker, memoryUpdates); } @Override diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java index 6999e07..70856eb 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/memory/MemoryObjectImpl.java @@ -19,10 +19,17 @@ package org.apache.flink.agents.runtime.memory; import org.apache.flink.agents.api.context.MemoryObject; import org.apache.flink.agents.api.context.MemoryRef; +import org.apache.flink.agents.api.context.MemoryUpdate; import org.apache.flink.api.common.state.MapState; import java.io.Serializable; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; public class MemoryObjectImpl implements MemoryObject { @@ -35,15 +42,21 @@ public class MemoryObjectImpl implements MemoryObject { private static final String SEPARATOR = "."; private final MapState<String, MemoryItem> store; + private final List<MemoryUpdate> memoryUpdates; private final String prefix; private final Runnable mailboxThreadChecker; - public MemoryObjectImpl(MapState<String, MemoryItem> store, String prefix) throws Exception { - this(store, prefix, () -> {}); + public MemoryObjectImpl( + MapState<String, MemoryItem> store, String prefix, List<MemoryUpdate> memoryUpdates) + throws Exception { + 
this(store, prefix, () -> {}, memoryUpdates); } public MemoryObjectImpl( - MapState<String, MemoryItem> store, String prefix, Runnable mailboxThreadChecker) + MapState<String, MemoryItem> store, + String prefix, + Runnable mailboxThreadChecker, + List<MemoryUpdate> memoryUpdates) throws Exception { this.store = store; this.prefix = prefix; @@ -51,6 +64,7 @@ public class MemoryObjectImpl implements MemoryObject { if (!store.contains(ROOT_KEY)) { store.put(ROOT_KEY, new MemoryItem()); } + this.memoryUpdates = memoryUpdates; } @Override @@ -58,7 +72,7 @@ public class MemoryObjectImpl implements MemoryObject { mailboxThreadChecker.run(); String absPath = fullPath(path); if (store.contains(absPath)) { - return new MemoryObjectImpl(store, absPath); + return new MemoryObjectImpl(store, absPath, memoryUpdates); } return null; } @@ -90,6 +104,7 @@ public class MemoryObjectImpl implements MemoryObject { MemoryItem val = new MemoryItem(value); store.put(absPath, val); + memoryUpdates.add(new MemoryUpdate(absPath, value)); return MemoryRef.create(absPath); } @@ -114,6 +129,7 @@ public class MemoryObjectImpl implements MemoryObject { } else { store.put(absPath, new MemoryItem()); } + memoryUpdates.add(new MemoryUpdate(absPath, null)); String parent = absPath.contains(SEPARATOR) @@ -123,7 +139,7 @@ public class MemoryObjectImpl implements MemoryObject { parentItem.getSubKeys().add(parts[parts.length - 1]); store.put(parent, parentItem); - return new MemoryObjectImpl(store, absPath); + return new MemoryObjectImpl(store, absPath, memoryUpdates); } @Override diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index cfda788..46ece97 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ 
-21,6 +21,7 @@ import org.apache.flink.agents.api.Event; import org.apache.flink.agents.api.EventContext; import org.apache.flink.agents.api.InputEvent; import org.apache.flink.agents.api.OutputEvent; +import org.apache.flink.agents.api.context.MemoryUpdate; import org.apache.flink.agents.api.listener.EventListener; import org.apache.flink.agents.api.logger.EventLogger; import org.apache.flink.agents.api.logger.EventLoggerConfig; @@ -30,6 +31,9 @@ import org.apache.flink.agents.plan.Action; import org.apache.flink.agents.plan.AgentPlan; import org.apache.flink.agents.plan.JavaFunction; import org.apache.flink.agents.plan.PythonFunction; +import org.apache.flink.agents.runtime.actionstate.ActionState; +import org.apache.flink.agents.runtime.actionstate.ActionStateStore; +import org.apache.flink.agents.runtime.actionstate.KafkaActionStateStore; import org.apache.flink.agents.runtime.context.RunnerContextImpl; import org.apache.flink.agents.runtime.env.PythonEnvironmentManager; import org.apache.flink.agents.runtime.memory.MemoryObjectImpl; @@ -46,10 +50,20 @@ import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.MapState; import org.apache.flink.api.common.state.MapStateDescriptor; +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.python.env.PythonDependencyInfo; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.state.VoidNamespace; +import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; -import org.apache.flink.streaming.api.operators.*; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import 
org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.ChainingStrategy; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService; import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxExecutorImpl; @@ -59,12 +73,16 @@ import org.apache.flink.util.ExceptionUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.IOException; import java.lang.reflect.Field; import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; +import static org.apache.flink.agents.api.configuration.AgentConfigOptions.ACTION_STATE_STORE_BACKEND; +import static org.apache.flink.agents.runtime.actionstate.ActionStateStore.BackendType.KAFKA; import static org.apache.flink.agents.runtime.utils.StateUtil.*; import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; @@ -86,6 +104,9 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT private static final Logger LOG = LoggerFactory.getLogger(ActionExecutionOperator.class); + private static final String RECOVERY_MARKER_STATE_NAME = "recoveryMarker"; + private static final String MESSAGE_SEQUENCE_NUMBER_STATE_NAME = "messageSequenceNumber"; + private final AgentPlan agentPlan; private final Boolean inputIsJava; @@ -126,11 +147,16 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT private final transient EventLogger eventLogger; private final transient List<EventListener> eventListeners; + private transient ActionStateStore actionStateStore; + private transient ValueState<Long> sequenceNumberKState; + private transient Map<Long, Map<Object, Long>> 
checkpointIdToSeqNums; + public ActionExecutionOperator( AgentPlan agentPlan, Boolean inputIsJava, ProcessingTimeService processingTimeService, - MailboxExecutor mailboxExecutor) { + MailboxExecutor mailboxExecutor, + ActionStateStore actionStateStore) { this.agentPlan = agentPlan; this.inputIsJava = inputIsJava; this.processingTimeService = processingTimeService; @@ -138,6 +164,8 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT this.mailboxExecutor = mailboxExecutor; this.eventLogger = EventLoggerFactory.createLogger(EventLoggerConfig.builder().build()); this.eventListeners = new ArrayList<>(); + this.actionStateStore = actionStateStore; + this.checkpointIdToSeqNums = new HashMap<>(); } @Override @@ -157,6 +185,21 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT metricGroup = new FlinkAgentsMetricGroupImpl(getMetricGroup()); builtInMetrics = new BuiltInMetrics(metricGroup, agentPlan); + // init the action state store with proper implementation + if (actionStateStore == null + && KAFKA.getType() + .equalsIgnoreCase(agentPlan.getConfig().get(ACTION_STATE_STORE_BACKEND))) { + LOG.info("Using Kafka as backend of action state store."); + actionStateStore = new KafkaActionStateStore(); + } + + // init sequence number state for per key message ordering + sequenceNumberKState = + getRuntimeContext() + .getState( + new ValueStateDescriptor<>( + MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class)); + // init agent processing related state actionTasksKState = getRuntimeContext() @@ -235,6 +278,7 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT if (isInputEvent) { // If the event is an InputEvent, we mark that the key is currently being processed. currentProcessingKeysOpState.add(key); + initOrIncSequenceNumber(); } // We then obtain the triggered action and add ActionTasks to the waiting processing // queue. 
@@ -287,6 +331,7 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT private void processActionTaskForKey(Object key) throws Exception { // 1. Get an action task for the key. setCurrentKey(key); + ActionTask actionTask = pollFromListState(actionTasksKState); if (actionTask == null) { int removedCount = removeFromListState(currentProcessingKeysOpState, key); @@ -301,20 +346,51 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT // 2. Invoke the action task. createAndSetRunnerContext(actionTask); - ActionTask.ActionTaskResult actionTaskResult = actionTask.invoke(); - for (Event actionOutputEvent : actionTaskResult.getOutputEvents()) { + + long sequenceNumber = sequenceNumberKState.value(); + boolean isFinished; + List<Event> outputEvents; + Optional<ActionTask> generatedActionTaskOpt; + ActionState actionState = + maybeGetActionState(key, sequenceNumber, actionTask.action, actionTask.event); + if (actionState != null && actionState.getGeneratedActionTask().isEmpty()) { + isFinished = true; + outputEvents = actionState.getOutputEvents(); + generatedActionTaskOpt = actionState.getGeneratedActionTask(); + for (MemoryUpdate memoryUpdate : actionState.getMemoryUpdates()) { + actionTask + .getRunnerContext() + .getShortTermMemory() + .set(memoryUpdate.getPath(), memoryUpdate.getValue()); + } + } else { + maybeInitActionState(key, sequenceNumber, actionTask.action, actionTask.event); + ActionTask.ActionTaskResult actionTaskResult = actionTask.invoke(); + maybePersistTaskResult( + key, + sequenceNumber, + actionTask.action, + actionTask.event, + actionTask.getRunnerContext(), + actionTaskResult); + isFinished = actionTaskResult.isFinished(); + outputEvents = actionTaskResult.getOutputEvents(); + generatedActionTaskOpt = actionTaskResult.getGeneratedActionTask(); + } + + for (Event actionOutputEvent : outputEvents) { processEvent(key, actionOutputEvent); } boolean currentInputEventFinished = false; - if 
(actionTaskResult.isFinished()) { + if (isFinished) { builtInMetrics.markActionExecuted(actionTask.action.getName()); currentInputEventFinished = !currentKeyHasMoreActionTask(); } else { - // If the action task not finished, we should get a new action task to execute continue. - Optional<ActionTask> generatedActionTaskOpt = actionTaskResult.getGeneratedActionTask(); + // If the action task is not finished, we should get a new action task to continue the + // execution. checkNotNull( - generatedActionTaskOpt.isPresent(), + generatedActionTaskOpt.get(), "ActionTask not finished, but the generated action task is null."); actionTasksKState.add(generatedActionTaskOpt.get()); } @@ -324,6 +400,7 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT // Once all sub-events and actions related to the current InputEvent are completed, // we can proceed to process the next InputEvent. int removedCount = removeFromListState(currentProcessingKeysOpState, key); + maybePruneState(key, sequenceNumber); checkState( removedCount == 1, "Current processing key count for key " @@ -392,6 +469,72 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT super.close(); } + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + if (actionStateStore != null) { + List<Object> markers = new ArrayList<>(); + + // We use union list state here so that every task can access all the + // recovery markers, even after the operator's parallelism has + // been modified. + // The ActionStateStore will decide how to use the recovery markers.
+ ListState<Object> recoveryMarkerOpState = + getOperatorStateBackend() + .getUnionListState( + new ListStateDescriptor<>( + RECOVERY_MARKER_STATE_NAME, + TypeInformation.of(Object.class))); + + Iterable<Object> recoveryMarkers = recoveryMarkerOpState.get(); + if (recoveryMarkers != null) { + recoveryMarkers.forEach(markers::add); + } + actionStateStore.rebuildState(markers); + } + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + if (actionStateStore != null) { + Object recoveryMarker = actionStateStore.getRecoveryMarker(); + if (recoveryMarker != null) { + ListState<Object> recoveryMarkerOpState = + getOperatorStateBackend() + .getListState( + new ListStateDescriptor<>( + RECOVERY_MARKER_STATE_NAME, + TypeInformation.of(Object.class))); + recoveryMarkerOpState.update(List.of(recoveryMarker)); + } + } + + HashMap<Object, Long> keyToSeqNum = new HashMap<>(); + getKeyedStateBackend() + .applyToAllKeys( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + new ValueStateDescriptor<>(MESSAGE_SEQUENCE_NUMBER_STATE_NAME, Long.class), + (key, state) -> keyToSeqNum.put(key, state.value())); + checkpointIdToSeqNums.put(context.getCheckpointId(), keyToSeqNum); + + super.snapshotState(context); + } + + @Override + public void notifyCheckpointComplete(long checkpointId) throws Exception { + if (actionStateStore != null) { + Map<Object, Long> keyToSeqNum = + checkpointIdToSeqNums.getOrDefault(checkpointId, new HashMap<>()); + for (Map.Entry<Object, Long> entry : keyToSeqNum.entrySet()) { + actionStateStore.pruneState(entry.getKey(), entry.getValue()); + } + checkpointIdToSeqNums.remove(checkpointId); + } + super.notifyCheckpointComplete(checkpointId); + } + private Event wrapToInputEvent(IN input) { if (inputIsJava) { return new InputEvent(input); @@ -485,6 +628,66 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT } } + private void initOrIncSequenceNumber() throws Exception { + // 
Initialize the sequence number state if it does not exist. + Long sequenceNumber = sequenceNumberKState.value(); + if (sequenceNumber == null) { + sequenceNumberKState.update(0L); + } else { + sequenceNumberKState.update(sequenceNumber + 1); + } + } + + private ActionState maybeGetActionState( + Object key, long sequenceNum, Action action, Event event) throws IOException { + return actionStateStore == null + ? null + : actionStateStore.get(key.toString(), sequenceNum, action, event); + } + + private void maybeInitActionState(Object key, long sequenceNum, Action action, Event event) + throws IOException { + if (actionStateStore != null) { + // Initialize the action state if it does not exist. It already exists when the + // action is an async action whose state was persisted before the action task + // finished. + if (actionStateStore.get(key, sequenceNum, action, event) == null) { + actionStateStore.put(key, sequenceNum, action, event, new ActionState(event)); + } + } + } + + private void maybePersistTaskResult( + Object key, + long sequenceNum, + Action action, + Event event, + RunnerContextImpl context, + ActionTask.ActionTaskResult actionTaskResult) + throws IOException { + if (actionStateStore == null) { + return; + } + + ActionState actionState = actionStateStore.get(key, sequenceNum, action, event); + actionState.setGeneratedActionTask(actionTaskResult.getGeneratedActionTask().orElse(null)); + + for (MemoryUpdate memoryUpdate : context.getAllMemoryUpdates()) { + actionState.addMemoryUpdate(memoryUpdate); + } + + for (Event outputEvent : actionTaskResult.getOutputEvents()) { + actionState.addEvent(outputEvent); + } + actionStateStore.put(key, sequenceNum, action, event, actionState); + } + + private void maybePruneState(Object key, long sequenceNum) throws IOException { + if (actionStateStore != null) { + actionStateStore.pruneState(key, sequenceNum); + } + } + /** Failed to execute Action task.
*/ public static class ActionTaskExecutionException extends Exception { public ActionTaskExecutionException(String message, Throwable cause) { diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorFactory.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorFactory.java index 92b70a5..1b863ff 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorFactory.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorFactory.java @@ -18,6 +18,8 @@ package org.apache.flink.agents.runtime.operator; import org.apache.flink.agents.plan.AgentPlan; +import org.apache.flink.agents.runtime.actionstate.ActionStateStore; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.OneInputStreamOperatorFactory; import org.apache.flink.streaming.api.operators.StreamOperator; @@ -31,9 +33,18 @@ public class ActionExecutionOperatorFactory<IN, OUT> private final Boolean inputIsJava; + private final ActionStateStore actionStateStore; + public ActionExecutionOperatorFactory(AgentPlan agentPlan, Boolean inputIsJava) { + this(agentPlan, inputIsJava, null); + } + + @VisibleForTesting + protected ActionExecutionOperatorFactory( + AgentPlan agentPlan, Boolean inputIsJava, ActionStateStore actionStateStore) { this.agentPlan = agentPlan; this.inputIsJava = inputIsJava; + this.actionStateStore = actionStateStore; } @Override @@ -44,7 +55,8 @@ public class ActionExecutionOperatorFactory<IN, OUT> agentPlan, inputIsJava, parameters.getProcessingTimeService(), - parameters.getMailboxExecutor()); + parameters.getMailboxExecutor(), + actionStateStore); op.setup( parameters.getContainingTask(), parameters.getStreamConfig(), diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateUtilTest.java 
b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateUtilTest.java new file mode 100644 index 0000000..b66e7e4 --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/ActionStateUtilTest.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.actionstate; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.api.InputEvent; +import org.apache.flink.agents.api.context.RunnerContext; +import org.apache.flink.agents.plan.Action; +import org.apache.flink.agents.plan.JavaFunction; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** Test class for {@link ActionStateUtil}. 
*/ +public class ActionStateUtilTest { + + @Test + public void testGenerateKeyConsistency() throws Exception { + // Create test data + Object key = "consistency-test"; + Action action = new TestAction("consistency-action"); + InputEvent inputEvent = new InputEvent("same-input"); + InputEvent inputEvent2 = new InputEvent("same-input"); + + // Generate keys multiple times + String key1 = ActionStateUtil.generateKey(key, 1, action, inputEvent); + String key2 = ActionStateUtil.generateKey(key, 1, action, inputEvent2); + + // Keys should be the same for the same input + assertEquals(key1, key2); + } + + @Test + public void testGenerateKeyDifferentInputs() throws Exception { + // Create test data + Object key = "diff-test"; + Action action = new TestAction("diff-action"); + InputEvent inputEvent1 = new InputEvent("input1"); + InputEvent inputEvent2 = new InputEvent("input2"); + + // Generate keys + String key1 = ActionStateUtil.generateKey(key, 1, action, inputEvent1); + String key2 = ActionStateUtil.generateKey(key, 1, action, inputEvent2); + + // Keys should be different for different inputs + assertNotEquals(key1, key2); + } + + @Test + public void testGenerateKeyWithNullKey() throws Exception { + Action action = new TestAction("test-action"); + InputEvent inputEvent = new InputEvent("test-input"); + + assertThrows( + NullPointerException.class, + () -> { + ActionStateUtil.generateKey(null, 1, action, inputEvent); + }); + } + + @Test + public void testGenerateKeyWithNullAction() { + Object key = "test-key"; + InputEvent inputEvent = new InputEvent("test-input"); + + assertThrows( + NullPointerException.class, + () -> { + ActionStateUtil.generateKey(key, 1, null, inputEvent); + }); + } + + @Test + public void testGenerateKeyWithNullEvent() throws Exception { + Object key = "test-key"; + Action action = new TestAction("test-action"); + + assertThrows( + NullPointerException.class, + () -> { + ActionStateUtil.generateKey(key, 1, action, null); + }); + } + + @Test + 
public void testParseKeyValidKey() throws Exception { + // Create test data and generate a key + Object key = "test-key"; + Action action = new TestAction("test-action"); + InputEvent inputEvent = new InputEvent("test-input"); + long seqNum = 123; + + String generatedKey = ActionStateUtil.generateKey(key, seqNum, action, inputEvent); + + // Parse the generated key + List<String> parsedParts = ActionStateUtil.parseKey(generatedKey); + + // Verify the parsed components + assertEquals(4, parsedParts.size()); + assertEquals(key.toString(), parsedParts.get(0)); + assertEquals(String.valueOf(seqNum), parsedParts.get(1)); + // The third and fourth parts are UUIDs - just verify they're non-empty + assertTrue(parsedParts.get(2).length() > 0); + assertTrue(parsedParts.get(3).length() > 0); + } + + @Test + public void testParseKeyRoundTrip() throws Exception { + // Test that generate -> parse -> values match original inputs + Object originalKey = "round-trip-test"; + Action action = new TestAction("round-trip-action"); + InputEvent inputEvent = new InputEvent("round-trip-input"); + long seqNum = 456; + + String generatedKey = ActionStateUtil.generateKey(originalKey, seqNum, action, inputEvent); + List<String> parsedParts = ActionStateUtil.parseKey(generatedKey); + + assertEquals(originalKey.toString(), parsedParts.get(0)); + assertEquals(String.valueOf(seqNum), parsedParts.get(1)); + } + + @Test + public void testParseKeyWithNullInput() { + assertThrows( + NullPointerException.class, + () -> { + ActionStateUtil.parseKey(null); + }); + } + + @Test + public void testParseKeyWithInvalidFormat() { + // Test with too few parts + assertThrows( + IllegalArgumentException.class, + () -> { + ActionStateUtil.parseKey("only_three_parts"); + }); + + // Test with too many parts + assertThrows( + IllegalArgumentException.class, + () -> { + ActionStateUtil.parseKey("one_two_three_four_five_six"); + }); + + // Test with empty string + assertThrows( + IllegalArgumentException.class, + () -> { 
+ ActionStateUtil.parseKey(""); + }); + } + + @Test + public void testParseKeyWithSpecialCharacters() throws Exception { + // Test with keys containing special characters (but not the separator) + Object key = "key-with-special@chars#123"; + Action action = new TestAction("action-with-special@chars"); + InputEvent inputEvent = new InputEvent("input-with-special@chars"); + long seqNum = 789; + + String generatedKey = ActionStateUtil.generateKey(key, seqNum, action, inputEvent); + List<String> parsedParts = ActionStateUtil.parseKey(generatedKey); + + assertEquals(key.toString(), parsedParts.get(0)); + assertEquals(String.valueOf(seqNum), parsedParts.get(1)); + } + + @Test + public void testParseKeyConsistencyWithDifferentKeys() throws Exception { + // Generate keys with different inputs and verify parsing consistency + Action action = new TestAction("consistency-action"); + InputEvent inputEvent = new InputEvent("consistency-input"); + + String key1 = ActionStateUtil.generateKey("key1", 100, action, inputEvent); + String key2 = ActionStateUtil.generateKey("key2", 200, action, inputEvent); + + List<String> parsed1 = ActionStateUtil.parseKey(key1); + List<String> parsed2 = ActionStateUtil.parseKey(key2); + + // Keys should be different + assertNotEquals(parsed1.get(0), parsed2.get(0)); + assertNotEquals(parsed1.get(1), parsed2.get(1)); + + // But event and action UUIDs should be the same (same event and action) + assertEquals(parsed1.get(2), parsed2.get(2)); // Event UUID + assertEquals(parsed1.get(3), parsed2.get(3)); // Action UUID + } + + private static class TestAction extends Action { + + public static void doNothing(Event event, RunnerContext context) { + // No operation + } + + public TestAction(String name) throws Exception { + super( + name, + new JavaFunction( + TestAction.class.getName(), + "doNothing", + new Class[] {Event.class, RunnerContext.class}), + List.of(InputEvent.class.getName())); + } + } +} diff --git 
a/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/InMemoryActionStateStore.java b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/InMemoryActionStateStore.java new file mode 100644 index 0000000..0a5dc9c --- /dev/null +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/actionstate/InMemoryActionStateStore.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.agents.runtime.actionstate; + +import org.apache.flink.agents.api.Event; +import org.apache.flink.agents.plan.Action; +import org.apache.flink.annotation.VisibleForTesting; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.agents.runtime.actionstate.ActionStateUtil.generateKey; + +/** + * An in-memory implementation of {@link ActionStateStore} for testing and local execution purposes. + * This implementation does not persist state across restarts. 
+ */ +public class InMemoryActionStateStore implements ActionStateStore { + + private final Map<String, Map<String, ActionState>> keyedActionStates; + private final boolean doCleanup; + + public InMemoryActionStateStore(boolean doCleanup) { + this.keyedActionStates = new HashMap<>(); + this.doCleanup = doCleanup; + } + + @Override + public void put(Object key, long seqNum, Action action, Event event, ActionState state) + throws IOException { + Map<String, ActionState> actionStates = + keyedActionStates.getOrDefault(key.toString(), new HashMap<>()); + actionStates.put(generateKey(key.toString(), seqNum, action, event), state); + keyedActionStates.put(key.toString(), actionStates); + } + + @Override + public ActionState get(Object key, long seqNum, Action action, Event event) throws IOException { + return keyedActionStates + .getOrDefault(key.toString(), new HashMap<>()) + .get(generateKey(key.toString(), seqNum, action, event)); + } + + @Override + public void rebuildState(List<Object> recoveryMarker) { + // No-op for in-memory store as it does not persist state; + } + + @Override + public void pruneState(Object key, long seqNum) { + if (doCleanup) { + keyedActionStates.remove(key.toString()); + } + } + + @VisibleForTesting + public Map<String, Map<String, ActionState>> getKeyedActionStates() { + return keyedActionStates; + } +} diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryObjectTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryObjectTest.java index 821110c..c6942c9 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryObjectTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryObjectTest.java @@ -18,18 +18,21 @@ package org.apache.flink.agents.runtime.memory; import org.apache.flink.agents.api.context.MemoryObject; +import org.apache.flink.agents.api.context.MemoryUpdate; import org.apache.flink.api.common.state.MapState; import 
org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import java.util.*; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.*; /** Tests for {@link MemoryObject}. */ public class MemoryObjectTest { - private MemoryObjectImpl memory; + private MemoryObject memory; + private List<MemoryUpdate> memoryUpdates; /** Simple POJO example. */ static class Person { @@ -58,7 +61,8 @@ public class MemoryObjectTest { @BeforeEach void setUp() throws Exception { ForTestMemoryMapState<MemoryObjectImpl.MemoryItem> mapState = new ForTestMemoryMapState<>(); - memory = new MemoryObjectImpl(mapState, MemoryObjectImpl.ROOT_KEY); + memoryUpdates = new LinkedList<>(); + memory = new MemoryObjectImpl(mapState, MemoryObjectImpl.ROOT_KEY, memoryUpdates); } @Test @@ -156,6 +160,25 @@ public class MemoryObjectTest { assertTrue(memory.isExist("exist")); assertFalse(memory.isExist("not.exist")); } + + @Test + void testMemoryUpdates() throws Exception { + memory.set("str", "hello"); + memory = memory.newObject("str", true); + memory.set("test", 100); + memory = memory.newObject("new_str", false); + memory.set("int", 42); + memory.set("str", "world"); + + assertThat(memoryUpdates) + .containsExactlyInAnyOrder( + new MemoryUpdate("str", "hello"), + new MemoryUpdate("str", null), + new MemoryUpdate("str.test", 100), + new MemoryUpdate("str.new_str", null), + new MemoryUpdate("str.new_str.int", 42), + new MemoryUpdate("str.new_str.str", "world")); + } } /** Simple, non-serialized HashMap implementation. 
*/ diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java index 462a0af..a997a32 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/memory/MemoryRefTest.java @@ -110,7 +110,7 @@ public class MemoryRefTest { @BeforeEach void setUp() throws Exception { ForTestMemoryMapState<MemoryObjectImpl.MemoryItem> mapState = new ForTestMemoryMapState<>(); - memory = new MemoryObjectImpl(mapState, MemoryObjectImpl.ROOT_KEY); + memory = new MemoryObjectImpl(mapState, MemoryObjectImpl.ROOT_KEY, new LinkedList<>()); } @Test diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java index 90c558c..ca39a17 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperatorTest.java @@ -25,6 +25,8 @@ import org.apache.flink.agents.api.context.RunnerContext; import org.apache.flink.agents.plan.Action; import org.apache.flink.agents.plan.AgentPlan; import org.apache.flink.agents.plan.JavaFunction; +import org.apache.flink.agents.runtime.actionstate.ActionState; +import org.apache.flink.agents.runtime.actionstate.InMemoryActionStateStore; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; @@ -33,6 +35,7 @@ import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; import org.apache.flink.util.ExceptionUtils; import org.junit.jupiter.api.Test; +import java.lang.reflect.Field; import java.util.Collections; import java.util.HashMap; import 
java.util.List; @@ -119,8 +122,6 @@ public class ActionExecutionOperatorTest { (KeySelector<Long, Long>) value -> value, TypeInformation.of(Long.class))) { testHarness.open(); - ActionExecutionOperator<Long, Object> operator = - (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); // Process input data 1 with key 0 testHarness.processElement(new StreamRecord<>(0L)); @@ -163,6 +164,288 @@ public class ActionExecutionOperatorTest { } } + @Test + void testInMemoryActionStateStoreIntegration() throws Exception { + AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false); + + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory<>( + agentPlanWithStateStore, true, new InMemoryActionStateStore(false)), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + // Use reflection to access the action state store for validation + Field actionStateStoreField = + ActionExecutionOperator.class.getDeclaredField("actionStateStore"); + actionStateStoreField.setAccessible(true); + InMemoryActionStateStore actionStateStore = + (InMemoryActionStateStore) actionStateStoreField.get(operator); + + assertThat(actionStateStore).isNotNull(); + assertThat(actionStateStore.getKeyedActionStates()).isEmpty(); + + // Process an element and verify action state is created and managed + testHarness.processElement(new StreamRecord<>(5L)); + operator.waitInFlightEventsFinished(); + + // Verify that action states were created during processing + Map<String, Map<String, ActionState>> actionStates = + actionStateStore.getKeyedActionStates(); + assertThat(actionStates).isNotEmpty(); + + // Verify the content of stored action states + assertThat(actionStates.size()).isEqualTo(1); + + // Verify each action 
state contains expected information + for (Map.Entry<String, Map<String, ActionState>> outerEntry : actionStates.entrySet()) { + for (Map.Entry<String, ActionState> entry : outerEntry.getValue().entrySet()) { + ActionState state = entry.getValue(); + assertThat(state).isNotNull(); + assertThat(state.getTaskEvent()).isNotNull(); + + // Check that output events were captured + assertThat(state.getOutputEvents()).isNotEmpty(); + + // Verify the generated action task is empty (action completed) + assertThat(state.getGeneratedActionTask()).isEmpty(); + } + } + + // Verify output + List<StreamRecord<Object>> recordOutput = + (List<StreamRecord<Object>>) testHarness.getRecordOutput(); + assertThat(recordOutput.size()).isEqualTo(1); + assertThat(recordOutput.get(0).getValue()).isEqualTo(12L); + + // Test checkpoint complete triggers cleanup + testHarness.notifyOfCompletedCheckpoint(1L); + } + } + + @Test + void testActionStateStoreContentVerification() throws Exception { + AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false); + + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory<>( + agentPlanWithStateStore, true, new InMemoryActionStateStore(false)), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + // Use reflection to access the action state store for validation + Field actionStateStoreField = + ActionExecutionOperator.class.getDeclaredField("actionStateStore"); + actionStateStoreField.setAccessible(true); + InMemoryActionStateStore actionStateStore = + (InMemoryActionStateStore) actionStateStoreField.get(operator); + + Long inputValue = 3L; + testHarness.processElement(new StreamRecord<>(inputValue)); + operator.waitInFlightEventsFinished(); + + Map<String, Map<String, 
ActionState>> actionStates = + actionStateStore.getKeyedActionStates(); + assertThat(actionStates).hasSize(1); + + // Verify specific action states by examining the keys + for (Map.Entry<String, Map<String, ActionState>> outerEntry : actionStates.entrySet()) { + for (Map.Entry<String, ActionState> entry : outerEntry.getValue().entrySet()) { + String stateKey = entry.getKey(); + ActionState state = entry.getValue(); + + // Verify the state key contains the expected key and action information + assertThat(stateKey).contains(inputValue.toString()); + + // Verify task event is properly stored + Event taskEvent = state.getTaskEvent(); + assertThat(taskEvent).isNotNull(); + + // Verify memory updates contain expected data + if (!state.getMemoryUpdates().isEmpty()) { + // For action1, memory should contain input + 1 + assertThat(state.getMemoryUpdates().get(0).getPath()).isEqualTo("tmp"); + assertThat(state.getMemoryUpdates().get(0).getValue()) + .isEqualTo(inputValue + 1); + } + + // Verify output events are captured + assertThat(state.getOutputEvents()).isNotEmpty(); + + // Check the type of events in the output + Event outputEvent = state.getOutputEvents().get(0); + assertThat(outputEvent).isNotNull(); + if (outputEvent instanceof TestAgent.MiddleEvent) { + TestAgent.MiddleEvent middleEvent = (TestAgent.MiddleEvent) outputEvent; + assertThat(middleEvent.getNum()).isEqualTo(inputValue + 1); + } else if (outputEvent instanceof OutputEvent) { + OutputEvent finalOutput = (OutputEvent) outputEvent; + assertThat(finalOutput.getOutput()).isEqualTo((inputValue + 1) * 2); + } + } + } + + // Verify final output + List<StreamRecord<Object>> recordOutput = + (List<StreamRecord<Object>>) testHarness.getRecordOutput(); + assertThat(recordOutput.size()).isEqualTo(1); + assertThat(recordOutput.get(0).getValue()).isEqualTo((inputValue + 1) * 2); + } + } + + @Test + void testActionStateStoreStateManagement() throws Exception { + AgentPlan agentPlanWithStateStore = 
TestAgent.getAgentPlan(false); + + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory<>( + agentPlanWithStateStore, true, new InMemoryActionStateStore(false)), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + // Access the action state store + java.lang.reflect.Field actionStateStoreField = + ActionExecutionOperator.class.getDeclaredField("actionStateStore"); + actionStateStoreField.setAccessible(true); + InMemoryActionStateStore actionStateStore = + (InMemoryActionStateStore) actionStateStoreField.get(operator); + + // Process multiple elements with same key to test state persistence + testHarness.processElement(new StreamRecord<>(1L)); + operator.waitInFlightEventsFinished(); + + // Verify initial state creation + Map<String, Map<String, ActionState>> actionStates = + actionStateStore.getKeyedActionStates(); + assertThat(actionStates).isNotEmpty(); + int initialStateCount = actionStates.size(); + + testHarness.processElement(new StreamRecord<>(1L)); + operator.waitInFlightEventsFinished(); + + // Verify state persists and grows for same key processing + actionStates = actionStateStore.getKeyedActionStates(); + assertThat(actionStates.size()).isGreaterThanOrEqualTo(initialStateCount); + + // Process element with different key + testHarness.processElement(new StreamRecord<>(2L)); + operator.waitInFlightEventsFinished(); + + // Verify new states created for different key + actionStates = actionStateStore.getKeyedActionStates(); + assertThat(actionStates.size()).isGreaterThan(initialStateCount); + + // Verify outputs + List<StreamRecord<Object>> recordOutput = + (List<StreamRecord<Object>>) testHarness.getRecordOutput(); + assertThat(recordOutput.size()).isEqualTo(3); + } + } + + 
@Test + void testActionStateStoreCleanupAfterOutputEvent() throws Exception { + AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false); + + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory<>( + agentPlanWithStateStore, true, new InMemoryActionStateStore(true)), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + // Process elements with distinct keys, one after another + testHarness.processElement(new StreamRecord<>(1L)); + operator.waitInFlightEventsFinished(); + + testHarness.processElement(new StreamRecord<>(2L)); + operator.waitInFlightEventsFinished(); + + // Process element with different key + testHarness.processElement(new StreamRecord<>(3L)); + operator.waitInFlightEventsFinished(); + + // Verify outputs + List<StreamRecord<Object>> recordOutput = + (List<StreamRecord<Object>>) testHarness.getRecordOutput(); + assertThat(recordOutput.size()).isEqualTo(3); + + // Access the action state store; no keyed state should remain + Field actionStateStoreField = + ActionExecutionOperator.class.getDeclaredField("actionStateStore"); + actionStateStoreField.setAccessible(true); + InMemoryActionStateStore actionStateStore = + (InMemoryActionStateStore) actionStateStoreField.get(operator); + assertThat(actionStateStore.getKeyedActionStates()).isEmpty(); + } + } + + @Test + void testActionStateStoreReplayIncurNoFunctionCall() throws Exception { + AgentPlan agentPlanWithStateStore = TestAgent.getAgentPlan(false); + InMemoryActionStateStore actionStateStore; + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory<>( + agentPlanWithStateStore, true, new InMemoryActionStateStore(false)), + 
(KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + // Access the action state store + Field actionStateStoreField = + ActionExecutionOperator.class.getDeclaredField("actionStateStore"); + actionStateStoreField.setAccessible(true); + actionStateStore = (InMemoryActionStateStore) actionStateStoreField.get(operator); + + Long inputValue = 7L; + + // First processing - this will execute the actual functions and store state + testHarness.processElement(new StreamRecord<>(inputValue)); + operator.waitInFlightEventsFinished(); + } + try (KeyedOneInputStreamOperatorTestHarness<Long, Long, Object> testHarness = + new KeyedOneInputStreamOperatorTestHarness<>( + new ActionExecutionOperatorFactory<>( + agentPlanWithStateStore, true, actionStateStore), + (KeySelector<Long, Long>) value -> value, + TypeInformation.of(Long.class))) { + testHarness.open(); + ActionExecutionOperator<Long, Object> operator = + (ActionExecutionOperator<Long, Object>) testHarness.getOperator(); + + Long inputValue = 7L; + + // Second run: replay the same input against the store populated above + testHarness.processElement(new StreamRecord<>(inputValue)); + operator.waitInFlightEventsFinished(); + // Verify replayed output; NOTE(review): action re-execution is not asserted + List<StreamRecord<Object>> recordOutput = + (List<StreamRecord<Object>>) testHarness.getRecordOutput(); + assertThat(recordOutput.size()).isEqualTo(1); + assertThat(recordOutput.get(0).getValue()).isEqualTo((inputValue + 1) * 2); + + // The store should hold exactly two action states for this key + assertThat(actionStateStore.getKeyedActionStates().get(String.valueOf(inputValue))) + .hasSize(2); + } + } + public static class TestAgent { public static class MiddleEvent extends Event { @@ -254,7 +537,7 @@ public class ActionExecutionOperatorTest { actions.put(action3.getName(), action3); } - return new 
AgentPlan(actions, actionsByEvent); + return new AgentPlan(actions, actionsByEvent, new HashMap<>()); } catch (Exception e) { ExceptionUtils.rethrow(e); }