This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-agents.git

commit 4ed5fc5928f8d959e4e1a2d729047cb9e32006f3
Author: Xu Huang <zuosi...@alibaba-inc.com>
AuthorDate: Thu Jul 31 15:23:25 2025 +0800

    [runtime] Add rescaling test case
---
 runtime/pom.xml                                    |   6 +
 .../flink/agents/runtime/RescalingITCase.java      | 516 +++++++++++++++++++++
 2 files changed, 522 insertions(+)

diff --git a/runtime/pom.xml b/runtime/pom.xml
index be95de6..387db96 100644
--- a/runtime/pom.xml
+++ b/runtime/pom.xml
@@ -85,6 +85,12 @@ under the License.
                 </exclusion>
             </exclusions>
         </dependency>
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-test-utils</artifactId>
+            <version>${flink.version}</version>
+            <scope>test</scope>
+        </dependency>
         <dependency>
             <groupId>com.alibaba</groupId>
             <artifactId>pemja</artifactId>
diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/RescalingITCase.java b/runtime/src/test/java/org/apache/flink/agents/runtime/RescalingITCase.java
new file mode 100644
index 0000000..64b0a6f
--- /dev/null
+++ b/runtime/src/test/java/org/apache/flink/agents/runtime/RescalingITCase.java
@@ -0,0 +1,516 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.agents.runtime;
+
+import org.apache.flink.agents.api.Agent;
+import org.apache.flink.agents.api.Event;
+import org.apache.flink.agents.api.InputEvent;
+import org.apache.flink.agents.api.OutputEvent;
+import org.apache.flink.agents.api.context.RunnerContext;
+import org.apache.flink.agents.plan.AgentPlan;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.time.Deadline;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.client.program.ClusterClient;
+import org.apache.flink.configuration.CheckpointingOptions;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestartStrategyOptions;
+import org.apache.flink.configuration.StateBackendOptions;
+import org.apache.flink.core.execution.SavepointFormatType;
+import org.apache.flink.runtime.client.JobStatusMessage;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.jobgraph.JobVertex;
+import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.test.util.MiniClusterWithClientResource;
+import org.apache.flink.testutils.TestingUtils;
+import org.apache.flink.testutils.executor.TestExecutorResource;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.TestLogger;
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.io.File;
+import java.time.Duration;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.runtime.testutils.CommonTestUtils.waitForAllTaskRunning;
+import static org.apache.flink.test.util.TestUtils.submitJobAndWaitForResult;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Verifies that a job embedding an agent topology can be rescaled via savepoints.
+ * Derived from Flink's own RescalingITCase.
+ */
+@RunWith(Parameterized.class)
+public class RescalingITCase extends TestLogger {
+
+    @ClassRule
+    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorResource();
+
+    private static final int numTaskManagers = 2;
+    private static final int slotsPerTaskManager = 2;
+    private static final int numSlots = numTaskManagers * slotsPerTaskManager;
+
+    @Parameterized.Parameters(name = "backend = {0}")
+    public static Collection<Object[]> data() {
+        return Arrays.asList(new Object[][] {{"hashmap"}, {"rocksdb"}});
+    }
+
+    public RescalingITCase(String backend) {
+        this.backend = backend;
+    }
+
+    private final String backend;
+
+    private String currentBackend = null;
+
+    private static MiniClusterWithClientResource cluster;
+
+    @ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+    @Before
+    public void setup() throws Exception {
+        // detect parameter change
+        if (!backend.equals(currentBackend)) {
+            shutDownExistingCluster();
+
+            currentBackend = backend;
+
+            Configuration config = new Configuration();
+
+            final File checkpointDir = temporaryFolder.newFolder();
+            final File savepointDir = temporaryFolder.newFolder();
+
+            config.set(StateBackendOptions.STATE_BACKEND, currentBackend);
+            config.set(
+                    CheckpointingOptions.CHECKPOINTS_DIRECTORY, checkpointDir.toURI().toString());
+            config.set(CheckpointingOptions.SAVEPOINT_DIRECTORY, savepointDir.toURI().toString());
+
+            cluster =
+                    new MiniClusterWithClientResource(
+                            new MiniClusterResourceConfiguration.Builder()
+                                    .setConfiguration(config)
+                                    .setNumberTaskManagers(numTaskManagers)
+                                    .setNumberSlotsPerTaskManager(slotsPerTaskManager)
+                                    .build());
+            cluster.before();
+        }
+
+        TestAgent.numProcessedEvent.set(0);
+    }
+
+    @AfterClass
+    public static void shutDownExistingCluster() {
+        if (cluster != null) {
+            cluster.after();
+            cluster = null;
+        }
+    }
+
+    @Test
+    public void testSavepointRescalingInKeyedState() throws Exception {
+        testSavepointRescalingKeyedState(false, false);
+    }
+
+    @Test
+    public void testSavepointRescalingOutKeyedState() throws Exception {
+        testSavepointRescalingKeyedState(true, false);
+    }
+
+    @Test
+    public void testSavepointRescalingInKeyedStateDerivedMaxParallelism() throws Exception {
+        testSavepointRescalingKeyedState(false, true);
+    }
+
+    @Test
+    public void testSavepointRescalingOutKeyedStateDerivedMaxParallelism() throws Exception {
+        testSavepointRescalingKeyedState(true, true);
+    }
+
+    /**
+     * Tests that a job with purely keyed state can be restarted from a savepoint with a different
+     * parallelism.
+     */
+    public void testSavepointRescalingKeyedState(boolean scaleOut, boolean deriveMaxParallelism)
+            throws Exception {
+        final int numberKeys = 42;
+        final int numberElements = 1000;
+        final int numberElements2 = 500;
+        final int parallelism = scaleOut ? numSlots / 2 : numSlots;
+        final int parallelism2 = scaleOut ? numSlots : numSlots / 2;
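+        // A fixed max parallelism pins the number of key groups, so the keyed state
+        // can be redistributed across either parallelism when restoring the savepoint.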
+        final int maxParallelism = 13;
+
+        Duration timeout = Duration.ofMinutes(3);
+        Deadline deadline = Deadline.now().plus(timeout);
+
+        ClusterClient<?> client = cluster.getClusterClient();
+
+        try {
+            JobGraph jobGraph =
+                    createJobGraphWithKeyedState(
+                            parallelism, maxParallelism, numberKeys, numberElements, false, 100);
+
+            final JobID jobID = jobGraph.getJobID();
+
+            client.submitJob(jobGraph).get();
+
+            // wait until the sources have emitted numberElements for each key and
+            // completed a checkpoint
+            assertTrue(
+                    SubtaskIndexFlatMapper.workCompletedLatch.await(
+                            deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS));
+
+            // verify the current state
+
+            Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet();
+
+            Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
+
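+            // For each key, the expected tuple is (subtask that owns the key's key
+            // group, numberElements * key), since the source emits the value 'key'
+            // numberElements times.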
+            for (int key = 0; key < numberKeys; key++) {
+                int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
+
+                expectedResult.add(
+                        Tuple2.of(
+                                KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
+                                        maxParallelism, parallelism, keyGroupIndex),
+                                numberElements * key));
+            }
+
+            assertEquals(expectedResult, actualResult);
+
+            // clear the CollectionSink set for the restarted job
+            CollectionSink.clearElementsSet();
+
+            waitForAllTaskRunning(cluster.getMiniCluster(), jobGraph.getJobID(), false);
+            CompletableFuture<String> savepointPathFuture =
+                    client.triggerSavepoint(jobID, null, SavepointFormatType.CANONICAL);
+
+            final String savepointPath =
+                    savepointPathFuture.get(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+
+            client.cancel(jobID).get();
+
+            while (!getRunningJobs(client).isEmpty()) {
+                Thread.sleep(50);
+            }
+
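+            // MAX_PARALLELISM_DEFAULT leaves max parallelism unset, so the restored
+            // job derives it instead of pinning it explicitly.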
+            int restoreMaxParallelism =
+                    deriveMaxParallelism ? JobVertex.MAX_PARALLELISM_DEFAULT : maxParallelism;
+
+            JobGraph scaledJobGraph =
+                    createJobGraphWithKeyedState(
+                            parallelism2,
+                            restoreMaxParallelism,
+                            numberKeys,
+                            numberElements2,
+                            true,
+                            100);
+
+            scaledJobGraph.setSavepointRestoreSettings(
+                    SavepointRestoreSettings.forPath(savepointPath));
+
+            submitJobAndWaitForResult(client, scaledJobGraph, getClass().getClassLoader());
+
+            Set<Tuple2<Integer, Integer>> actualResult2 = CollectionSink.getElementsSet();
+
+            Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>();
+
+            for (int key = 0; key < numberKeys; key++) {
+                int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
+                expectedResult2.add(
+                        Tuple2.of(
+                                KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
+                                        maxParallelism, parallelism2, keyGroupIndex),
+                                key * (numberElements + numberElements2)));
+            }
+
+            assertEquals(expectedResult2, actualResult2);
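+            // Every element triggers two agent actions (handleInputEvent and
+            // handleTestEvent), hence the factor of 2.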
+            assertEquals(
+                    numberKeys * (numberElements + numberElements2) * 2,
+                    TestAgent.numProcessedEvent.get());
+
+        } finally {
+            // clean up the static CollectionSink set
+            CollectionSink.clearElementsSet();
+        }
+    }
+
+    private static JobGraph createJobGraphWithKeyedState(
+            int parallelism,
+            int maxParallelism,
+            int numberKeys,
+            int numberElements,
+            boolean terminateAfterEmission,
+            int checkpointingInterval)
+            throws Exception {
+
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(parallelism);
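+        // Only set max parallelism when explicitly given; a non-positive value
+        // leaves it to be derived (the DerivedMaxParallelism test variants).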
+        if (0 < maxParallelism) {
+            env.getConfig().setMaxParallelism(maxParallelism);
+        }
+        env.enableCheckpointing(checkpointingInterval);
+        env.configure(new Configuration().set(RestartStrategyOptions.RESTART_STRATEGY, "none"));
+        env.getConfig().setUseSnapshotCompression(true);
+
+        DataStream<Integer> input =
+                env.addSource(
+                                new SubtaskIndexSource(
+                                        numberKeys, numberElements, terminateAfterEmission))
+                        .keyBy(
+                                new KeySelector<Integer, Integer>() {
+                                    private static final long serialVersionUID =
+                                            -7952298871120320940L;
+
+                                    @Override
+                                    public Integer getKey(Integer value) throws Exception {
+                                        return value;
+                                    }
+                                });
+
+        // insert agent topology
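+        // connectToAgent routes each keyed element through the TestAgent's actions;
+        // the result is mapped back to Integer and re-keyed for the stateful flat map.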
+        input =
+                CompileUtils.connectToAgent(
+                                (KeyedStream<Integer, Integer>) input,
+                                new AgentPlan(new TestAgent()))
+                        .map(
+                                new MapFunction<Object, Integer>() {
+                                    @Override
+                                    public Integer map(Object value) throws Exception {
+                                        return (Integer) value;
+                                    }
+                                })
+                        .keyBy(
+                                new KeySelector<Integer, Integer>() {
+                                    private static final long serialVersionUID =
+                                            -7952298871120320940L;
+
+                                    @Override
+                                    public Integer getKey(Integer value) throws Exception {
+                                        return value;
+                                    }
+                                });
+
+        SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);
+
+        DataStream<Tuple2<Integer, Integer>> result =
+                input.flatMap(new SubtaskIndexFlatMapper(numberElements));
+
+        result.addSink(new CollectionSink<Tuple2<Integer, Integer>>());
+
+        return env.getStreamGraph().getJobGraph();
+    }
+
+    private static class SubtaskIndexSource extends RichParallelSourceFunction<Integer> {
+
+        private static final long serialVersionUID = -400066323594122516L;
+
+        private final int numberKeys;
+        private final int numberElements;
+        private final boolean terminateAfterEmission;
+
+        protected int counter = 0;
+
+        private boolean running = true;
+
+        SubtaskIndexSource(int numberKeys, int numberElements, boolean terminateAfterEmission) {
+
+            this.numberKeys = numberKeys;
+            this.numberElements = numberElements;
+            this.terminateAfterEmission = terminateAfterEmission;
+        }
+
+        @Override
+        public void run(SourceContext<Integer> ctx) throws Exception {
+            final Object lock = ctx.getCheckpointLock();
+            final int subtaskIndex = getRuntimeContext().getTaskInfo().getIndexOfThisSubtask();
+
+            while (running) {
+
+                if (counter < numberElements) {
+                    synchronized (lock) {
+                        for (int value = subtaskIndex;
+                                value < numberKeys;
+                                value +=
+                                        getRuntimeContext()
+                                                .getTaskInfo()
+                                                .getNumberOfParallelSubtasks()) {
+
+                            ctx.collect(value);
+                        }
+
+                        counter++;
+                    }
+                } else {
+                    if (terminateAfterEmission) {
+                        running = false;
+                    } else {
+                        Thread.sleep(100);
+                    }
+                }
+            }
+        }
+
+        @Override
+        public void cancel() {
+            running = false;
+        }
+    }
+
+    private static class SubtaskIndexFlatMapper
+            extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>>
+            implements CheckpointedFunction {
+
+        private static final long serialVersionUID = 5273172591283191348L;
+
+        private static CountDownLatch workCompletedLatch = new CountDownLatch(1);
+
+        private transient ValueState<Integer> counter;
+        private transient ValueState<Integer> sum;
+
+        private final int numberElements;
+
+        SubtaskIndexFlatMapper(int numberElements) {
+            this.numberElements = numberElements;
+        }
+
+        @Override
+        public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out)
+                throws Exception {
+            int count = counter.value() + 1;
+            counter.update(count);
+
+            int s = sum.value() + value;
+            sum.update(s);
+
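+            // Emit (subtaskIndex, running sum) once this key has seen a full batch
+            // of numberElements elements, and count down the completion latch.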
+            if (count % numberElements == 0) {
+                out.collect(
+                        Tuple2.of(getRuntimeContext().getTaskInfo().getIndexOfThisSubtask(), s));
+                workCompletedLatch.countDown();
+            }
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext context) throws Exception {
+            // all managed, nothing to do.
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            counter =
+                    context.getKeyedStateStore()
+                            .getState(new ValueStateDescriptor<>("counter", 
Integer.class, 0));
+            sum =
+                    context.getKeyedStateStore()
+                            .getState(new ValueStateDescriptor<>("sum", 
Integer.class, 0));
+        }
+    }
+
+    private static class CollectionSink<IN> implements SinkFunction<IN> {
+
+        private static Set<Object> elements =
+                Collections.newSetFromMap(new ConcurrentHashMap<Object, Boolean>());
+
+        private static final long serialVersionUID = -1652452958040267745L;
+
+        @SuppressWarnings("unchecked")
+        public static <IN> Set<IN> getElementsSet() {
+            return (Set<IN>) elements;
+        }
+
+        public static void clearElementsSet() {
+            elements.clear();
+        }
+
+        @Override
+        public void invoke(IN value) throws Exception {
+            elements.add(value);
+        }
+    }
+
+    private static List<JobID> getRunningJobs(ClusterClient<?> client) throws Exception {
+        Collection<JobStatusMessage> statusMessages = client.listJobs().get();
+        return statusMessages.stream()
+                .filter(status -> !status.getJobState().isGloballyTerminalState())
+                .map(JobStatusMessage::getJobId)
+                .collect(Collectors.toList());
+    }
+
+    /** Simple event carrying an integer payload, forwarded by {@link TestAgent}. */
+    public static class TestEvent extends Event {
+        private final int data;
+
+        public TestEvent(int data) {
+            this.data = data;
+        }
+
+        public int getData() {
+            return data;
+        }
+    }
+
+    /** Agent that counts every processed event and forwards the input to the output. */
+    public static class TestAgent extends Agent {
+
+        public static final AtomicInteger numProcessedEvent = new AtomicInteger(0);
+
+        @org.apache.flink.agents.api.Action(listenEvents = {InputEvent.class})
+        public static void handleInputEvent(InputEvent event, RunnerContext context) {
+            // Count the event and forward its payload as a TestEvent.
+            numProcessedEvent.incrementAndGet();
+            context.sendEvent(new TestEvent((Integer) event.getInput()));
+        }
+
+        @org.apache.flink.agents.api.Action(listenEvents = {TestEvent.class})
+        public static void handleTestEvent(TestEvent event, RunnerContext context) {
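+            // Count the event and emit the OutputEvent that becomes the final output.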
+            numProcessedEvent.incrementAndGet();
+            context.sendEvent(new OutputEvent(event.getData()));
+        }
+    }
+}
