This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit ce0b61f376b1be8e9733fe06972dafc634fedb3c
Author: Piotr Nowojski <[email protected]>
AuthorDate: Tue May 28 17:40:24 2024 +0200

    [FLINK-35351][checkpoint] Clean up and unify code for the custom partitioner test case
---
 .../checkpointing/UnalignedCheckpointITCase.java   |   2 +-
 .../UnalignedCheckpointRescaleITCase.java          | 104 +++++++++------------
 .../checkpointing/UnalignedCheckpointTestBase.java |  12 ++-
 3 files changed, 54 insertions(+), 64 deletions(-)

diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java
index f649b8f7df3..6e94a7b06df 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java
@@ -311,7 +311,7 @@ public class UnalignedCheckpointITCase extends 
UnalignedCheckpointTestBase {
                 // shifts records from one partition to another evenly to 
retain order
                 .partitionCustom(new ShiftingPartitioner(), l -> l)
                 .map(
-                        new FailingMapper(
+                        new FailingMapper<>(
                                 state ->
                                         state.completedCheckpoints >= 
minCheckpoints / 4
                                                         && state.runNumber == 0
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java
index bb217b5a8cd..7a21cab7c1d 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java
@@ -22,6 +22,7 @@ package org.apache.flink.test.checkpointing;
 import org.apache.flink.api.common.JobExecutionResult;
 import org.apache.flink.api.common.accumulators.LongCounter;
 import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.OpenContext;
 import org.apache.flink.api.common.functions.Partitioner;
 import org.apache.flink.api.common.functions.RichMapFunction;
@@ -54,10 +55,8 @@ import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
 import java.io.File;
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.BitSet;
-import java.util.Collection;
 import java.util.Collections;
 
 import static 
org.apache.flink.api.common.eventtime.WatermarkStrategy.noWatermarks;
@@ -336,38 +335,57 @@ public class UnalignedCheckpointRescaleITCase extends 
UnalignedCheckpointTestBas
         },
         CUSTOM_PARTITIONER {
             final int sinkParallelism = 3;
-            final int numberElements = 1000;
 
             @Override
             public void create(
-                    StreamExecutionEnvironment environment,
+                    StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedFailuresUntilSourceFinishes,
+                    int expectedRestarts,
                     long sourceSleepMs) {
-                int parallelism = environment.getParallelism();
-                environment
-                        .fromData(generateStrings(numberElements / 
parallelism, sinkParallelism))
+                int parallelism = env.getParallelism();
+
+                env.fromSource(
+                                new LongSource(
+                                        minCheckpoints,
+                                        parallelism,
+                                        expectedRestarts,
+                                        env.getCheckpointInterval(),
+                                        sourceSleepMs),
+                                noWatermarks(),
+                                "source")
                         .name("source")
+                        .uid("source")
+                        .map(
+                                new MapFunction<Long, String>() {
+                                    @Override
+                                    public String map(Long value) throws 
Exception {
+                                        value = withoutHeader(value);
+                                        return buildString(
+                                                value % sinkParallelism, value 
/ sinkParallelism);
+                                    }
+                                })
+                        .name("long-to-string-map")
+                        .uid("long-to-string-map")
+                        .map(
+                                new FailingMapper<>(
+                                        state -> false,
+                                        state ->
+                                                state.completedCheckpoints >= 
minCheckpoints / 2
+                                                        && state.runNumber == 
0,
+                                        state -> false,
+                                        state -> false))
+                        .name("failing-map")
+                        .uid("failing-map")
                         .setParallelism(parallelism)
                         .partitionCustom(new StringPartitioner(), str -> 
str.split(" ")[0])
-                        .addSink(new StringSink(numberElements / 
sinkParallelism))
+                        .addSink(new BackPressureInducingSink())
                         .name("sink")
+                        .uid("sink")
                         .setParallelism(sinkParallelism);
             }
 
-            private Collection<String> generateStrings(
-                    int producePerPartition, int partitionCount) {
-                Collection<String> list = new ArrayList<>();
-                for (int i = 0; i < producePerPartition; i++) {
-                    for (int partition = 0; partition < partitionCount; 
partition++) {
-                        list.add(buildString(partition, i));
-                    }
-                }
-                return list;
-            }
-
-            private String buildString(int partition, int index) {
+            private String buildString(long partition, long index) {
                 String longStr = new String(new char[3713]).replace('\0', 
'\uFFFF');
                 return partition + " " + index + " " + longStr;
             }
@@ -378,7 +396,7 @@ public class UnalignedCheckpointRescaleITCase extends 
UnalignedCheckpointTestBas
             combinedSource
                     .shuffle()
                     .map(
-                            new FailingMapper(
+                            new FailingMapper<>(
                                     state -> false,
                                     state ->
                                             state.completedCheckpoints >= 
minCheckpoints / 2
@@ -604,7 +622,6 @@ public class UnalignedCheckpointRescaleITCase extends 
UnalignedCheckpointTestBas
 
     @Test
     public void shouldRescaleUnalignedCheckpoint() throws Exception {
-        StringSink.failed = false;
         final UnalignedSettings prescaleSettings =
                 new UnalignedSettings(topology)
                         .setParallelism(oldParallelism)
@@ -761,43 +778,12 @@ public class UnalignedCheckpointRescaleITCase extends 
UnalignedCheckpointTestBas
         }
     }
 
-    private static class StringSink implements SinkFunction<String>, 
CheckpointedFunction {
-
-        static volatile boolean failed = false;
-
-        int checkpointConsumed = 0;
-
-        int recordsConsumed = 0;
-
-        final int numberElements;
-
-        public StringSink(int numberElements) {
-            this.numberElements = numberElements;
-        }
-
-        @Override
-        public void invoke(String value, Context ctx) throws Exception {
-            if (!failed && checkpointConsumed > 1) {
-                failed = true;
-                throw new TestException("FAIL");
-            }
-            recordsConsumed++;
-            if (!failed && recordsConsumed == (numberElements / 3)) {
-                Thread.sleep(1000);
-            }
-            if (recordsConsumed == (numberElements - 1)) {
-                Thread.sleep(1000);
-            }
-        }
-
-        @Override
-        public void snapshotState(FunctionSnapshotContext context) {
-            checkpointConsumed++;
-        }
-
+    private static class BackPressureInducingSink<T> implements 
SinkFunction<T> {
         @Override
-        public void initializeState(FunctionInitializationContext context) {
-            // do  nothing
+        public void invoke(T value, Context ctx) throws Exception {
+            // TODO: maybe similarly to VerifyingSink, we should back pressure 
only until some point
+            // but currently it doesn't seem to be needed (test runs quickly 
enough)
+            Thread.sleep(1);
         }
     }
 }
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java
index a936f2171b7..985b687a0dc 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java
@@ -838,7 +838,7 @@ public abstract class UnalignedCheckpointTestBase extends 
TestLogger {
     }
 
     /** A mapper that fails in particular situations/attempts. */
-    protected static class FailingMapper extends RichMapFunction<Long, Long>
+    protected static class FailingMapper<T> extends RichMapFunction<T, T>
             implements CheckpointedFunction, CheckpointListener {
         private static final ListStateDescriptor<FailingMapperState>
                 FAILING_MAPPER_STATE_DESCRIPTOR =
@@ -849,7 +849,7 @@ public abstract class UnalignedCheckpointTestBase extends 
TestLogger {
         private final FilterFunction<FailingMapperState> failDuringSnapshot;
         private final FilterFunction<FailingMapperState> failDuringRecovery;
         private final FilterFunction<FailingMapperState> failDuringClose;
-        private long lastValue;
+        private transient Object lastValue;
 
         protected FailingMapper(
                 FilterFunction<FailingMapperState> failDuringMap,
@@ -863,8 +863,12 @@ public abstract class UnalignedCheckpointTestBase extends 
TestLogger {
         }
 
         @Override
-        public Long map(Long value) throws Exception {
-            lastValue = withoutHeader(value);
+        public T map(T value) throws Exception {
+            if (value instanceof Long) {
+                lastValue = withoutHeader((Long) value);
+            } else {
+                lastValue = value;
+            }
             checkFail(failDuringMap, "map");
             return value;
         }

Reply via email to