joeyutong commented on code in PR #628:
URL: https://github.com/apache/flink-agents/pull/628#discussion_r3212625590


##########
runtime/src/main/java/org/apache/flink/agents/runtime/actionstate/FlussActionStateStore.java:
##########
@@ -0,0 +1,496 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.agents.runtime.actionstate;
+
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.plan.AgentConfiguration;
+import org.apache.flink.agents.plan.actions.Action;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.fluss.client.Connection;
+import org.apache.fluss.client.ConnectionFactory;
+import org.apache.fluss.client.admin.Admin;
+import org.apache.fluss.client.admin.OffsetSpec;
+import org.apache.fluss.client.table.Table;
+import org.apache.fluss.client.table.scanner.ScanRecord;
+import org.apache.fluss.client.table.scanner.log.LogScanner;
+import org.apache.fluss.client.table.scanner.log.ScanRecords;
+import org.apache.fluss.client.table.writer.AppendWriter;
+import org.apache.fluss.config.ConfigOptions;
+import org.apache.fluss.config.Configuration;
+import org.apache.fluss.metadata.DatabaseDescriptor;
+import org.apache.fluss.metadata.Schema;
+import org.apache.fluss.metadata.TableBucket;
+import org.apache.fluss.metadata.TableDescriptor;
+import org.apache.fluss.metadata.TablePath;
+import org.apache.fluss.row.BinaryString;
+import org.apache.fluss.row.GenericRow;
+import org.apache.fluss.row.InternalRow;
+import org.apache.fluss.types.DataTypes;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static 
org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_DATABASE;
+import static 
org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_TABLE;
+import static 
org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_ACTION_STATE_TABLE_BUCKETS;
+import static 
org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_BOOTSTRAP_SERVERS;
+import static 
org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_JAAS_CONFIG;
+import static 
org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_MECHANISM;
+import static 
org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_PASSWORD;
+import static 
org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SASL_USERNAME;
+import static 
org.apache.flink.agents.api.configuration.AgentConfigOptions.FLUSS_SECURITY_PROTOCOL;
+import static 
org.apache.flink.agents.runtime.actionstate.ActionStateUtil.generateKey;
+import static org.apache.fluss.config.ConfigOptions.BOOTSTRAP_SERVERS;
+import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_CONFIG;
+import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_PASSWORD;
+import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_JAAS_USERNAME;
+import static org.apache.fluss.config.ConfigOptions.CLIENT_SASL_MECHANISM;
+import static org.apache.fluss.config.ConfigOptions.CLIENT_SECURITY_PROTOCOL;
+
+/**
+ * An implementation of {@link ActionStateStore} that uses an Apache Fluss log table as the backend.
+ * All state is maintained in an in-memory map for fast lookups, with the Fluss log table providing
+ * durability and recovery support.
+ */
+public class FlussActionStateStore implements ActionStateStore {
+
+    private static final Logger LOG = LoggerFactory.getLogger(FlussActionStateStore.class);
+
+    // Timeout for log-scanner polls; presumably used when replaying the log during rebuild —
+    // confirm.
+    private static final Duration POLL_TIMEOUT = Duration.ofSeconds(1);
+
+    // Protocol value for which the constructor does not forward SASL options to the Fluss client.
+    private static final String SECURITY_PROTOCOL_PLAINTEXT = "PLAINTEXT";
+
+    // Column names in the Fluss table schema
+    private static final String COL_NAME_STATE_KEY = "state_key";
+    private static final String COL_NAME_STATE_PAYLOAD = "state_payload";
+    private static final String COL_NAME_AGENT_KEY = "agent_key";
+
+    // Column indices in the Fluss table schema. put() writes rows in the order
+    // (state_key, state_payload, agent_key); the agent key column (index 2) has no constant here.
+    private static final int COL_STATE_KEY = 0;
+    private static final int COL_STATE_PAYLOAD = 1;
+
+    // Configuration-derived table location; all null when built via the test-only constructor.
+    private final AgentConfiguration agentConfiguration;
+    private final String databaseName;
+    private final String tableName;
+    private final TablePath tablePath;
+
+    // Fluss client handles; `writer` appends one row per put() call.
+    private final Connection connection;
+    private final Table table;
+    private final AppendWriter writer;
+
+    /** In-memory cache for O(1) state lookups; rebuilt from Fluss log on recovery. */
+    private final Map<String, ActionState> actionStates;
+
+    /**
+     * Test-only constructor that wires in pre-built Fluss client objects and a state map.
+     *
+     * <p>The given map is used directly (not copied), so callers can observe the cached states.
+     * The configuration-derived fields (agent configuration, database name, table name and table
+     * path) are left {@code null} in this mode.
+     */
+    @VisibleForTesting
+    FlussActionStateStore(
+            Map<String, ActionState> actionStates,
+            Connection connection,
+            Table table,
+            AppendWriter writer) {
+        // Collaborators injected by the caller.
+        this.connection = connection;
+        this.table = table;
+        this.writer = writer;
+        this.actionStates = actionStates;
+
+        // No configuration is available in this mode.
+        this.tablePath = null;
+        this.tableName = null;
+        this.databaseName = null;
+        this.agentConfiguration = null;
+    }
+
+    /**
+     * Creates a store backed by the Fluss log table configured in {@code agentConfiguration}.
+     *
+     * <p>Opens a Fluss connection, calls {@code maybeCreateDatabaseAndTable()}, then opens the table
+     * and an append writer. If any step after the connection has been opened fails, the connection
+     * is closed before the failure is rethrown, so a failed construction does not leak it.
+     *
+     * @param agentConfiguration source of the Fluss bootstrap servers, security settings, and the
+     *     database/table that holds the action state
+     */
+    public FlussActionStateStore(AgentConfiguration agentConfiguration) {
+        this.agentConfiguration = agentConfiguration;
+        this.databaseName = agentConfiguration.get(FLUSS_ACTION_STATE_DATABASE);
+        this.tableName = agentConfiguration.get(FLUSS_ACTION_STATE_TABLE);
+        this.tablePath = TablePath.of(databaseName, tableName);
+        this.actionStates = new HashMap<>();
+
+        Configuration flussConf = toFlussClientConfiguration(agentConfiguration);
+        this.connection = ConnectionFactory.createConnection(flussConf);
+        try {
+            maybeCreateDatabaseAndTable();
+            this.table = connection.getTable(tablePath);
+            this.writer = table.newAppend().createWriter();
+        } catch (Throwable t) {
+            // The object is never handed to a caller, so nobody else could close this connection.
+            try {
+                connection.close();
+            } catch (Exception closeException) {
+                t.addSuppressed(closeException);
+            }
+            // Precise rethrow: the try block throws only unchecked exceptions.
+            throw t;
+        }
+
+        LOG.info(
+                "Initialized FlussActionStateStore (log table) with table: {}.{}",
+                databaseName,
+                tableName);
+    }
+
+    /**
+     * Translates the agent-level Fluss options into a Fluss client {@link Configuration}.
+     *
+     * <p>The writer batch timeout is set to zero to minimize latency for the synchronous {@code
+     * put()}: each append is sent immediately instead of waiting for more records to batch.
+     * Security/SASL parameters are only forwarded when the protocol is not {@code PLAINTEXT}. With
+     * PLAINTEXT (the default), SASL parameters are semantically invalid and may cause the Fluss
+     * client to attempt an unwanted SASL handshake.
+     */
+    private static Configuration toFlussClientConfiguration(AgentConfiguration agentConfiguration) {
+        Configuration flussConf = new Configuration();
+        flussConf.setString(
+                BOOTSTRAP_SERVERS.key(), agentConfiguration.get(FLUSS_BOOTSTRAP_SERVERS));
+        flussConf.set(ConfigOptions.CLIENT_WRITER_BATCH_TIMEOUT, Duration.ZERO);
+
+        String securityProtocol = agentConfiguration.get(FLUSS_SECURITY_PROTOCOL);
+        flussConf.setString(CLIENT_SECURITY_PROTOCOL, securityProtocol);
+        if (!SECURITY_PROTOCOL_PLAINTEXT.equalsIgnoreCase(securityProtocol)) {
+            flussConf.setString(
+                    CLIENT_SASL_MECHANISM, agentConfiguration.get(FLUSS_SASL_MECHANISM));
+
+            String jaasConfig = agentConfiguration.get(FLUSS_SASL_JAAS_CONFIG);
+            if (jaasConfig != null) {
+                flussConf.setString(CLIENT_SASL_JAAS_CONFIG, jaasConfig);
+            }
+            String username = agentConfiguration.get(FLUSS_SASL_USERNAME);
+            if (username != null) {
+                flussConf.setString(CLIENT_SASL_JAAS_USERNAME, username);
+            }
+            String password = agentConfiguration.get(FLUSS_SASL_PASSWORD);
+            if (password != null) {
+                flussConf.setString(CLIENT_SASL_JAAS_PASSWORD, password);
+            }
+        }
+        return flussConf;
+    }
+
+    /**
+     * Persists {@code state} for the given key, sequence number, action and event, then caches it.
+     *
+     * <p>The row is appended to the Fluss log, and this method blocks until the append future
+     * completes. The in-memory map is updated only after that.
+     */
+    @Override
+    public void put(Object key, long seqNum, Action action, Event event, ActionState state)
+            throws Exception {
+        final String stateKey = generateKey(key, seqNum, action, event);
+        final byte[] serializedState = ActionStateSerde.serialize(state);
+        final String agentKey = key.toString();
+
+        // Column order: state key, serialized payload, agent key.
+        final GenericRow row =
+                GenericRow.of(
+                        BinaryString.fromString(stateKey),
+                        serializedState,
+                        BinaryString.fromString(agentKey));
+
+        // Wait for the append so the record is durable before the cache reflects it.
+        // TODO: Optimize throughput via batching + flush() once Fluss supports it
+        //  (see
+        //  https://github.com/apache/fluss/blob/5850c837/fluss-client/src/main/java/org/apache/fluss/client/write/Sender.java#L234-L241).
+        //  Note: steps affecting recovery correctness must remain synchronous.
+        writer.append(row).get();
+        actionStates.put(stateKey, state);
+
+        LOG.debug("Stored action state: key={}, isCompleted={}", stateKey, state.isCompleted());
+    }
+
+    /**
+     * Looks up the cached state for the given key, sequence number, action and event.
+     *
+     * <p>If the exact entry is missing, or more than one cached entry shares this key and sequence
+     * number (as reported by {@code checkDivergence}), every cached entry whose sequence number is
+     * greater than {@code seqNum} is evicted before the lookup.
+     *
+     * @return the cached state, or {@code null} if none is present
+     */
+    @Override
+    public ActionState get(Object key, long seqNum, Action action, Event event) throws Exception {
+        String stateKey = generateKey(key, seqNum, action, event);
+
+        boolean hasDivergence = checkDivergence(key.toString(), seqNum);
+
+        if (!actionStates.containsKey(stateKey) || hasDivergence) {
+            // NOTE(review): eviction is not restricted to `key`. Cached entries of other keys with
+            // a larger sequence number are dropped too; confirm this is intended.
+            actionStates
+                    .entrySet()
+                    .removeIf(
+                            entry -> {
+                                try {
+                                    // The second key component is treated as the sequence number.
+                                    List<String> parts = ActionStateUtil.parseKey(entry.getKey());
+                                    if (parts.size() >= 2) {
+                                        long stateSeqNum = Long.parseLong(parts.get(1));
+                                        return stateSeqNum > seqNum;
+                                    }
+                                } catch (NumberFormatException e) {
+                                    // Report the cached key that failed to parse, not the lookup key.
+                                    LOG.warn(
+                                            "Failed to parse sequence number from state key: {}",
+                                            entry.getKey());
+                                }
+                                return false;
+                            });
+        }
+
+        ActionState state = actionStates.get(stateKey);
+        LOG.debug("Lookup action state: key={}, found={}", stateKey, state != null);
+        return state;
+    }
+
+    /**
+     * Returns whether more than one cached state key starts with {@code key + "_" + seqNum + "_"}.
+     *
+     * <p>NOTE(review): this assumes {@code generateKey} joins its components with {@code '_'}. A
+     * key that itself contains such a pattern could produce false matches; verify against
+     * {@code ActionStateUtil}.
+     */
+    private boolean checkDivergence(String key, long seqNum) {
+        // Build the prefix once instead of once per cached entry.
+        String prefix = key + "_" + seqNum + "_";
+        // Two matches already decide the answer, so stop scanning after the second one.
+        return actionStates.keySet().stream()
+                        .filter(k -> k.startsWith(prefix))
+                        .limit(2)
+                        .count()
+                > 1;
+    }
+
+    /**
+     * Rebuilds in-memory state by scanning the Fluss log table. If recovery 
markers are provided,
+     * computes the minimum offset per bucket across all markers and 
subscribes from those offsets.
+     * Otherwise, skips rebuild since there is no checkpointed position to 
recover from. Reads from
+     * the start offset up to the latest offset captured at rebuild start. For 
the same state key
+     * appearing multiple times in the log, the latest record wins 
(last-write-wins).
+     */
+    @Override
+    public void rebuildState(List<Object> recoveryMarkers) {
+        LOG.info(
+                "Rebuilding action state from Fluss log table with {} recovery 
markers",
+                recoveryMarkers.size());
+
+        if (recoveryMarkers.isEmpty()) {
+            LOG.info("No recovery markers, skipping state rebuild");
+            return;
+        }
+
+        actionStates.clear();
+
+        Map<Integer, Long> bucketStartOffsets = 
mergeRecoveryMarkerOffsets(recoveryMarkers);
+        if (bucketStartOffsets.isEmpty()) {
+            LOG.info("No valid bucket offsets in recovery markers, skipping 
state rebuild");
+            return;
+        }
+
+        Map<Integer, Long> bucketEndOffsets = getBucketEndOffsets();
+        Map<Integer, Long> bucketEarliestOffsets = getBucketEarliestOffsets();
+        LOG.debug(
+                "Rebuild window: startOffsets={}, earliestOffsets={}, 
endOffsets={}",
+                bucketStartOffsets,
+                bucketEarliestOffsets,
+                bucketEndOffsets);
+
+        try (LogScanner scanner = table.newScan().createLogScanner()) {
+            Map<Integer, Long> remainingBuckets =
+                    subscribeEffectiveOffsets(
+                            scanner, bucketStartOffsets, bucketEndOffsets, 
bucketEarliestOffsets);
+            LOG.debug("Subscribed buckets for rebuild: {}", remainingBuckets);
+
+            pollAndReplay(scanner, remainingBuckets);
+        } catch (Exception e) {
+            throw new RuntimeException("Failed to rebuild state from Fluss log 
table", e);
+        }
+
+        LOG.info("Completed rebuilding state, recovered {} states", 
actionStates.size());
+    }
+
+    /**
+     * Merges recovery markers into a per-bucket start offset map.
+     *
+     * <p>For each bucket, the minimum offset across all markers is used to cover the widest
+     * recovery window. Markers are expected to be maps from bucket id to log offset. Keys and values
+     * may be any {@link Number} (for example {@code Integer} vs {@code Long}) and are normalized to
+     * {@code int} bucket ids and {@code long} offsets. A {@code null} marker is ignored; any other
+     * non-map marker is logged and ignored.
+     *
+     * @throws IllegalArgumentException if a marker map contains a {@code null} or non-numeric bucket
+     *     id or offset
+     */
+    private Map<Integer, Long> mergeRecoveryMarkerOffsets(List<Object> recoveryMarkers) {
+        Map<Integer, Long> bucketStartOffsets = new HashMap<>();
+        for (Object marker : recoveryMarkers) {
+            if (marker instanceof Map) {
+                // No unchecked cast: a Long key or Integer value would otherwise pollute the
+                // result map and fail much later with an NPE or ClassCastException.
+                for (Map.Entry<?, ?> entry : ((Map<?, ?>) marker).entrySet()) {
+                    Object bucket = entry.getKey();
+                    Object offset = entry.getValue();
+                    if (!(bucket instanceof Number) || !(offset instanceof Number)) {
+                        // Skipping the entry would silently leave a bucket unreplayed, so fail.
+                        throw new IllegalArgumentException(
+                                String.format(
+                                        "Malformed recovery marker entry: bucket=%s, offset=%s",
+                                        bucket, offset));
+                    }
+                    bucketStartOffsets.merge(
+                            ((Number) bucket).intValue(), ((Number) offset).longValue(), Math::min);
+                }
+            } else if (marker != null) {
+                LOG.warn(
+                        "Ignoring unrecognized recovery marker type: {}",
+                        marker.getClass().getName());
+            }
+        }
+        return bucketStartOffsets;
+    }
+
+    /**
+     * Validates effective offsets for each bucket and subscribes the scanner. 
Buckets with no new
+     * data are skipped; buckets with data loss (retention cleaned the 
recovery window) cause an
+     * immediate failure.
+     *
+     * @return a map of bucket-id to end-offset for buckets that need to be 
scanned
+     */
+    private Map<Integer, Long> subscribeEffectiveOffsets(
+            LogScanner scanner,
+            Map<Integer, Long> bucketStartOffsets,
+            Map<Integer, Long> bucketEndOffsets,
+            Map<Integer, Long> bucketEarliestOffsets) {
+        Map<Integer, Long> remainingBuckets = new HashMap<>();
+        for (Map.Entry<Integer, Long> entry : bucketStartOffsets.entrySet()) {
+            int bucket = entry.getKey();
+            long startOffset = entry.getValue();
+            long endOffset = bucketEndOffsets.get(bucket);
+            long earliestOffset = bucketEarliestOffsets.get(bucket);

Review Comment:
   @Sxnan Good catch. The NPE is a real issue here, but I think the bigger 
problem is the bucket layout mismatch behind it.
   If the table bucket count changes, the offsets in the recovery marker are no 
longer safe to apply to the current table. Some buckets may be missing, or 
existing bucket ids may refer to a different key distribution.
   Maybe rebuild should validate the recovery marker bucket set against the 
current table buckets before applying offsets. If they do not match, it would 
be safer to fail fast with a clear error.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to