This is an automated email from the ASF dual-hosted git repository.

nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new c12e8881a [CELEBORN-1490][CIP-6] Add Flink Hybrid Shuffle IT test cases
c12e8881a is described below

commit c12e8881ab5b03e39a8b53d038a283351bf5906c
Author: Weijie Guo <[email protected]>
AuthorDate: Fri Nov 1 17:27:24 2024 +0800

    [CELEBORN-1490][CIP-6] Add Flink Hybrid Shuffle IT test cases
    
    ### What changes were proposed in this pull request?
    1. Add Flink Hybrid Shuffle IT test cases
    2. Fix bug in open stream.
    
    ### Why are the changes needed?
    
    Test coverage for celeborn + hybrid shuffle
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    Closes #2859 from reswqa/10-itcase-10month.
    
    Authored-by: Weijie Guo <[email protected]>
    Signed-off-by: SteNicholas <[email protected]>
---
 .../plugin/flink/RemoteBufferStreamReader.java     |   4 +-
 .../flink/readclient/CelebornBufferStream.java     |  45 ++++--
 .../flink/readclient/FlinkShuffleClientImpl.java   |   3 +-
 .../flink/tiered/CelebornChannelBufferReader.java  |  11 +-
 .../flink/tiered/CelebornTierConsumerAgent.java    |   5 +-
 .../common/network/client/TransportClient.java     |  42 +++++
 pom.xml                                            |   2 +
 .../apache/celeborn/tests/flink/FlinkVersion.java  |  69 ++++++++
 .../tests/flink/JobGraphRunningHelper.java         |  70 ++++++++
 .../tests/flink/HybridShuffleWordCountTest.scala   | 180 +++++++++++++++++++++
 .../service/deploy/MiniClusterFeature.scala        |   2 +-
 11 files changed, 408 insertions(+), 25 deletions(-)

diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
index 46b1f1ff8..0bea1452d 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteBufferStreamReader.java
@@ -87,7 +87,7 @@ public class RemoteBufferStreamReader extends CreditListener {
           client.readBufferedPartition(
               shuffleId, partitionId, subPartitionIndexStart, 
subPartitionIndexEnd, false);
       bufferStream.open(
-          RemoteBufferStreamReader.this::requestBuffer, initialCredit, 
messageConsumer);
+          RemoteBufferStreamReader.this::requestBuffer, initialCredit, 
messageConsumer, false);
     } catch (Exception e) {
       logger.warn("Failed to open stream and report to flink framework. ", e);
       messageConsumer.accept(new TransportableError(0L, e));
@@ -158,6 +158,6 @@ public class RemoteBufferStreamReader extends 
CreditListener {
   public void onStreamEnd(BufferStreamEnd streamEnd) {
     long streamId = streamEnd.getStreamId();
     logger.debug("Buffer stream reader get stream end for {}", streamId);
-    bufferStream.moveToNextPartitionIfPossible(streamId);
+    bufferStream.moveToNextPartitionIfPossible(streamId, null, false);
   }
 }
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
index 1fda95c01..061a918c2 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java
@@ -53,6 +53,7 @@ public class CelebornBufferStream {
   private PartitionLocation[] locations;
   private int subIndexStart;
   private int subIndexEnd;
+  private long pushDataTimeoutMs;
   private TransportClient client;
   private AtomicInteger currentLocationIndex = new AtomicInteger(0);
   private long streamId = 0;
@@ -72,23 +73,26 @@ public class CelebornBufferStream {
       String shuffleKey,
       PartitionLocation[] locations,
       int subIndexStart,
-      int subIndexEnd) {
+      int subIndexEnd,
+      long pushDataTimeoutMs) {
     this.mapShuffleClient = mapShuffleClient;
     this.clientFactory = dataClientFactory;
     this.shuffleKey = shuffleKey;
     this.locations = locations;
     this.subIndexStart = subIndexStart;
     this.subIndexEnd = subIndexEnd;
+    this.pushDataTimeoutMs = pushDataTimeoutMs;
   }
 
   public void open(
       Supplier<ByteBuf> bufferSupplier,
       int initialCredit,
-      Consumer<RequestMessage> messageConsumer) {
+      Consumer<RequestMessage> messageConsumer,
+      boolean sync) {
     this.bufferSupplier = bufferSupplier;
     this.initialCredit = initialCredit;
     this.messageConsumer = messageConsumer;
-    moveToNextPartitionIfPossible(0);
+    moveToNextPartitionIfPossible(0, null, sync);
   }
 
   public void addCredit(PbReadAddCredit pbReadAddCredit) {
@@ -156,12 +160,19 @@ public class CelebornBufferStream {
       String shuffleKey,
       PartitionLocation[] locations,
       int subIndexStart,
-      int subIndexEnd) {
+      int subIndexEnd,
+      long pushDataTimeoutMs) {
     if (locations == null || locations.length == 0) {
       return empty();
     } else {
       return new CelebornBufferStream(
-          client, dataClientFactory, shuffleKey, locations, subIndexStart, 
subIndexEnd);
+          client,
+          dataClientFactory,
+          shuffleKey,
+          locations,
+          subIndexStart,
+          subIndexEnd,
+          pushDataTimeoutMs);
     }
   }
 
@@ -198,12 +209,10 @@ public class CelebornBufferStream {
     }
   }
 
-  public void moveToNextPartitionIfPossible(long endedStreamId) {
-    moveToNextPartitionIfPossible(endedStreamId, null);
-  }
-
   public void moveToNextPartitionIfPossible(
-      long endedStreamId, @Nullable BiConsumer<Long, Integer> 
requiredSegmentIdConsumer) {
+      long endedStreamId,
+      @Nullable BiConsumer<Long, Integer> requiredSegmentIdConsumer,
+      boolean sync) {
     logger.debug(
         "MoveToNextPartitionIfPossible in this:{},  endedStreamId: {}, 
currentLocationIndex: {}, currentSteamId:{}, locationsLength:{}",
         this,
@@ -218,7 +227,7 @@ public class CelebornBufferStream {
 
     if (currentLocationIndex.get() < locations.length) {
       try {
-        openStreamInternal(requiredSegmentIdConsumer);
+        openStreamInternal(requiredSegmentIdConsumer, sync);
         logger.debug(
             "MoveToNextPartitionIfPossible after openStream this:{},  
endedStreamId: {}, currentLocationIndex: {}, currentSteamId:{}, 
locationsLength:{}",
             this,
@@ -237,7 +246,8 @@ public class CelebornBufferStream {
    * Open the stream, note that if the openReaderFuture is not null, 
requiredSegmentIdConsumer will
    * be invoked for every subPartition when open stream success.
    */
-  private void openStreamInternal(@Nullable BiConsumer<Long, Integer> 
requiredSegmentIdConsumer)
+  private void openStreamInternal(
+      @Nullable BiConsumer<Long, Integer> requiredSegmentIdConsumer, boolean 
sync)
       throws IOException, InterruptedException {
     this.client =
         clientFactory.createClientWithRetry(
@@ -255,8 +265,7 @@ public class CelebornBufferStream {
                 .setInitialCredit(initialCredit)
                 .build()
                 .toByteArray());
-    client.sendRpc(
-        openStream.toByteBuffer(),
+    RpcResponseCallback rpcResponseCallback =
         new RpcResponseCallback() {
 
           @Override
@@ -313,7 +322,13 @@ public class CelebornBufferStream {
                 NettyUtils.getRemoteAddress(client.getChannel()));
             messageConsumer.accept(new TransportableError(streamId, e));
           }
-        });
+        };
+
+    if (sync) {
+      client.sendRpcSync(openStream.toByteBuffer(), rpcResponseCallback, 
pushDataTimeoutMs);
+    } else {
+      client.sendRpc(openStream.toByteBuffer(), rpcResponseCallback);
+    }
   }
 
   public TransportClient getClient() {
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
index 650ed1a3a..efbf343ce 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/FlinkShuffleClientImpl.java
@@ -199,7 +199,8 @@ public class FlinkShuffleClientImpl extends 
ShuffleClientImpl {
           shuffleKey,
           partitionLocations,
           subPartitionIndexStart,
-          subPartitionIndexEnd);
+          subPartitionIndexEnd,
+          conf.pushDataTimeoutMs());
     }
   }
 
diff --git 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornChannelBufferReader.java
 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornChannelBufferReader.java
index 527617c96..a387cf9fb 100644
--- 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornChannelBufferReader.java
+++ 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornChannelBufferReader.java
@@ -129,13 +129,12 @@ public class CelebornChannelBufferReader {
     }
   }
 
-  public void open(int initialCredit) {
+  public void open(int initialCredit, boolean sync) {
     try {
       bufferStream =
           client.readBufferedPartition(
               shuffleId, partitionId, subPartitionIndexStart, 
subPartitionIndexEnd, true);
-      bufferStream.open(this::requestBuffer, initialCredit, messageConsumer);
-      this.isOpened = bufferStream.isOpened();
+      bufferStream.open(this::requestBuffer, initialCredit, messageConsumer, 
sync);
     } catch (Exception e) {
       messageConsumer.accept(new TransportableError(0L, e));
       LOG.error("Failed to open reader", e);
@@ -178,6 +177,10 @@ public class CelebornChannelBufferReader {
     return isOpened;
   }
 
+  public void setOpened(boolean opened) {
+    isOpened = opened;
+  }
+
   boolean isClosed() {
     return closed;
   }
@@ -306,7 +309,7 @@ public class CelebornChannelBufferReader {
     if (!closed && !CelebornBufferStream.isEmptyStream(bufferStream)) {
       // TOOD: Update the partition locations here if support reading and 
writing shuffle data
       // simultaneously
-      bufferStream.moveToNextPartitionIfPossible(streamId, 
this::sendRequireSegmentId);
+      bufferStream.moveToNextPartitionIfPossible(streamId, 
this::sendRequireSegmentId, true);
     }
   }
 
diff --git 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
index d858ae891..0febd8bd3 100644
--- 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
+++ 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierConsumerAgent.java
@@ -375,7 +375,7 @@ public class CelebornTierConsumerAgent implements 
TierConsumerAgent {
   private boolean openReader(CelebornChannelBufferReader bufferReader) {
     if (!bufferReader.isOpened()) {
       try {
-        bufferReader.open(0);
+        bufferReader.open(0, true);
       } catch (Exception e) {
         // may throw PartitionUnRetryAbleException
         recycleAllResources();
@@ -383,7 +383,8 @@ public class CelebornTierConsumerAgent implements 
TierConsumerAgent {
       }
     }
 
-    return bufferReader.isOpened();
+    bufferReader.setOpened(true);
+    return true;
   }
 
   private void initBufferReaders() {
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
 
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
index 2c335b350..3fc7c43b2 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
@@ -325,6 +325,48 @@ public class TransportClient implements Closeable {
     }
   }
 
+  /**
+   * Synchronously sends an opaque message to the RpcHandler on the 
server-side, waiting for up to a
+   * specified timeout for a response. The callback will be invoked with the 
server's response or
+   * upon any failure.
+   */
+  public void sendRpcSync(ByteBuffer message, RpcResponseCallback callback, 
long timeoutMs)
+      throws IOException {
+    final SettableFuture<Void> result = SettableFuture.create();
+
+    sendRpc(
+        message,
+        new RpcResponseCallback() {
+          @Override
+          public void onSuccess(ByteBuffer response) {
+            try {
+              ByteBuffer copy = ByteBuffer.allocate(response.remaining());
+              copy.put(response);
+              // flip "copy" to make it readable
+              copy.flip();
+              callback.onSuccess(copy);
+              result.set(null);
+            } catch (Throwable t) {
+              logger.warn("Error in responding RPC callback", t);
+              callback.onFailure(t);
+              result.set(null);
+            }
+          }
+
+          @Override
+          public void onFailure(Throwable e) {
+            callback.onFailure(e);
+            result.set(null);
+          }
+        });
+
+    try {
+      result.get(timeoutMs, TimeUnit.MILLISECONDS);
+    } catch (Exception e) {
+      throw new IOException("Exception in sendRpcSync to: " + 
this.getSocketAddress(), e);
+    }
+  }
+
   /**
    * Sends an opaque message to the RpcHandler on the server-side. No reply is 
expected for the
    * message, and no delivery guarantees are made.
diff --git a/pom.xml b/pom.xml
index 94042ad8f..9c57f6e12 100644
--- a/pom.xml
+++ b/pom.xml
@@ -889,6 +889,7 @@
             </systemProperties>
             <environmentVariables>
               <CELEBORN_LOCAL_HOSTNAME>localhost</CELEBORN_LOCAL_HOSTNAME>
+              <FLINK_VERSION>${flink.version}</FLINK_VERSION>
             </environmentVariables>
             <forkCount>1</forkCount>
             <reuseForks>false</reuseForks>
@@ -927,6 +928,7 @@
             </systemProperties>
             <environmentVariables>
               <CELEBORN_LOCAL_HOSTNAME>localhost</CELEBORN_LOCAL_HOSTNAME>
+              <FLINK_VERSION>${flink.version}</FLINK_VERSION>
             </environmentVariables>
           </configuration>
           <executions>
diff --git 
a/tests/flink-it/src/test/java/org/apache/celeborn/tests/flink/FlinkVersion.java
 
b/tests/flink-it/src/test/java/org/apache/celeborn/tests/flink/FlinkVersion.java
new file mode 100644
index 000000000..81eb39271
--- /dev/null
+++ 
b/tests/flink-it/src/test/java/org/apache/celeborn/tests/flink/FlinkVersion.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.flink;
+
+import org.apache.flink.annotation.Public;
+
+/** All supported flink versions. */
+@Public
+public enum FlinkVersion {
+  v1_14("1.14"),
+  v1_15("1.15"),
+  v1_16("1.16"),
+  v1_17("1.17"),
+  v1_18("1.18"),
+  v1_19("1.19"),
+  v1_20("1.20");
+
+  private final String versionStr;
+
+  FlinkVersion(String versionStr) {
+    this.versionStr = versionStr;
+  }
+
+  public static FlinkVersion fromVersionStr(String versionStr) {
+    switch (versionStr) {
+      case "1.14":
+        return v1_14;
+      case "1.15":
+        return v1_15;
+      case "1.16":
+        return v1_16;
+      case "1.17":
+        return v1_17;
+      case "1.18":
+        return v1_18;
+      case "1.19":
+        return v1_19;
+      case "1.20":
+        return v1_20;
+      default:
+        throw new IllegalArgumentException("Unsupported flink version: " + 
versionStr);
+    }
+  }
+
+  @Override
+  public String toString() {
+    return versionStr;
+  }
+
+  public boolean isNewerOrEqualVersionThan(FlinkVersion otherVersion) {
+    return this.ordinal() >= otherVersion.ordinal();
+  }
+}
diff --git 
a/tests/flink-it/src/test/java/org/apache/celeborn/tests/flink/JobGraphRunningHelper.java
 
b/tests/flink-it/src/test/java/org/apache/celeborn/tests/flink/JobGraphRunningHelper.java
new file mode 100644
index 000000000..2bb642b8a
--- /dev/null
+++ 
b/tests/flink-it/src/test/java/org/apache/celeborn/tests/flink/JobGraphRunningHelper.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.flink;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.client.program.MiniClusterClient;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.MemorySize;
+import org.apache.flink.configuration.TaskManagerOptions;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.jobmaster.JobResult;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+
+/** Utils to run {@link JobGraph} on {@link MiniCluster}. */
+public class JobGraphRunningHelper {
+
+  public static void execute(
+      JobGraph jobGraph,
+      Configuration configuration,
+      int numTaskManagers,
+      int numSlotsPerTaskManager)
+      throws Exception {
+    configuration.set(TaskManagerOptions.TOTAL_FLINK_MEMORY, 
MemorySize.parse("1g"));
+
+    // use random ports
+    if (!configuration.containsKey("jobmanager.rpc.port")) {
+      configuration.setString("jobmanager.rpc.port", "0");
+    }
+    if (!configuration.containsKey("rest.bind-port")) {
+      configuration.setString("rest.bind-port", "0");
+    }
+
+    final MiniClusterConfiguration miniClusterConfiguration =
+        new MiniClusterConfiguration.Builder()
+            .setConfiguration(configuration)
+            .setNumTaskManagers(numTaskManagers)
+            .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+            .build();
+
+    try (MiniCluster miniCluster = new MiniCluster(miniClusterConfiguration)) {
+      miniCluster.start();
+
+      MiniClusterClient miniClusterClient = new 
MiniClusterClient(configuration, miniCluster);
+      // wait for the submission to succeed
+      JobID jobID = miniClusterClient.submitJob(jobGraph).get();
+
+      JobResult jobResult = miniClusterClient.requestJobResult(jobID).get();
+      if (jobResult.getSerializedThrowable().isPresent()) {
+        throw new AssertionError(jobResult.getSerializedThrowable().get());
+      }
+    }
+  }
+}
diff --git 
a/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HybridShuffleWordCountTest.scala
 
b/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HybridShuffleWordCountTest.scala
new file mode 100644
index 000000000..b174914ef
--- /dev/null
+++ 
b/tests/flink-it/src/test/scala/org/apache/celeborn/tests/flink/HybridShuffleWordCountTest.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.flink
+
+import java.io.File
+
+import scala.collection.JavaConverters._
+
+import org.apache.flink.api.common.RuntimeExecutionMode
+import org.apache.flink.api.common.restartstrategy.RestartStrategies
+import org.apache.flink.configuration.{Configuration, ExecutionOptions}
+import org.apache.flink.runtime.jobgraph.JobType
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
+import org.apache.flink.streaming.api.graph.StreamingJobGraphGenerator
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.celeborn.common.internal.Logging
+import org.apache.celeborn.service.deploy.MiniClusterFeature
+import org.apache.celeborn.service.deploy.worker.Worker
+
+class HybridShuffleWordCountTest extends AnyFunSuite with Logging with 
MiniClusterFeature
+  with BeforeAndAfterAll {
+  var workers: collection.Set[Worker] = null
+
+  val NUM_PARALLELISM = 8
+
+  val NUM_TASK_MANAGERS = 2
+
+  val NUM_SLOTS_PER_TASK_MANAGER = 10
+
+  override def beforeAll(): Unit = {
+    logInfo("test initialized , setup celeborn mini cluster")
+    val masterConf = Map(
+      "celeborn.master.host" -> "localhost",
+      "celeborn.master.port" -> "9097")
+    val workerConf = Map("celeborn.master.endpoints" -> "localhost:9097")
+    workers = setUpMiniCluster(masterConf, workerConf)._2
+  }
+
+  override def afterAll(): Unit = {
+    logInfo("all test complete , stop celeborn mini cluster")
+    shutdownMiniCluster()
+  }
+
+  test("Celeborn Flink Hybrid Shuffle Integration test(Local) - word count") {
+    assumeFlinkVersion()
+    testLocalEnv()
+  }
+
+  test(
+    "Celeborn Flink Hybrid Shuffle Integration test(Flink mini cluster) single 
tier - word count") {
+    assumeFlinkVersion()
+    testInMiniCluster()
+  }
+
+  private def assumeFlinkVersion(): Unit = {
+    // Celeborn supports flink hybrid shuffle starting from flink 1.20
+    val flinkVersion = sys.env.getOrElse("FLINK_VERSION", "")
+    assume(
+      flinkVersion.nonEmpty && FlinkVersion.fromVersionStr(
+        
flinkVersion.split("\\.").take(2).mkString(".")).isNewerOrEqualVersionThan(
+        FlinkVersion.v1_20))
+  }
+
+  private def testLocalEnv(): Unit = {
+    // set up execution environment
+    val configuration = new Configuration
+    val parallelism = NUM_PARALLELISM
+    configuration.setString(
+      "shuffle-service-factory.class",
+      "org.apache.flink.runtime.io.network.NettyShuffleServiceFactory")
+    configuration.setString(
+      "taskmanager.network.hybrid-shuffle.external-remote-tier-factory.class",
+      "org.apache.celeborn.plugin.flink.tiered.CelebornTierFactory")
+    configuration.setString("celeborn.master.endpoints", "localhost:9097")
+    configuration.set(ExecutionOptions.RUNTIME_MODE, 
RuntimeExecutionMode.BATCH)
+    configuration.setString(
+      "execution.batch-shuffle-mode",
+      "ALL_EXCHANGES_HYBRID_FULL")
+    configuration.setString("taskmanager.memory.network.min", "1024m")
+    configuration.setString(
+      "execution.batch.adaptive.auto-parallelism.min-parallelism",
+      "" + parallelism)
+    configuration.setString("restart-strategy.type", "fixed-delay")
+    configuration.setString("restart-strategy.fixed-delay.attempts", "50")
+    configuration.setString("restart-strategy.fixed-delay.delay", "5s")
+    configuration.setString(
+      "jobmanager.partition.hybrid.partition-data-consume-constraint",
+      "ALL_PRODUCERS_FINISHED")
+    configuration.setString("rest.bind-port", "8081-8099")
+
+    val env = 
StreamExecutionEnvironment.createLocalEnvironmentWithWebUI(configuration)
+    env.getConfig.setParallelism(parallelism)
+    env.disableOperatorChaining()
+    // define the word count pipeline on env; the job is run via env.execute below
+    WordCountHelper.execute(env, parallelism)
+
+    val graph = env.getStreamGraph
+    env.execute(graph)
+    checkFlushingFileLength()
+  }
+
+  private def testInMiniCluster(): Unit = {
+    // set up execution environment
+    val configuration = new Configuration
+    val parallelism = NUM_PARALLELISM
+    configuration.setString(
+      "shuffle-service-factory.class",
+      "org.apache.flink.runtime.io.network.NettyShuffleServiceFactory")
+    configuration.setString(
+      "taskmanager.network.hybrid-shuffle.external-remote-tier-factory.class",
+      "org.apache.celeborn.plugin.flink.tiered.CelebornTierFactory")
+    configuration.setString("celeborn.master.endpoints", "localhost:9097")
+    configuration.set(ExecutionOptions.RUNTIME_MODE, 
RuntimeExecutionMode.BATCH)
+    configuration.setString(
+      "execution.batch-shuffle-mode",
+      "ALL_EXCHANGES_HYBRID_FULL")
+    configuration.setString("taskmanager.memory.network.min", "256m")
+    configuration.setString(
+      "execution.batch.adaptive.auto-parallelism.min-parallelism",
+      "" + parallelism)
+    configuration.setString("restart-strategy.type", "fixed-delay")
+    configuration.setString("restart-strategy.fixed-delay.attempts", "50")
+    configuration.setString("restart-strategy.fixed-delay.delay", "5s")
+    configuration.setString(
+      "jobmanager.partition.hybrid.partition-data-consume-constraint",
+      "ALL_PRODUCERS_FINISHED")
+    configuration.setString("rest.bind-port", "8081-8099")
+    val env = getEnvironment(configuration);
+    env.getConfig.setParallelism(parallelism)
+    env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 0L))
+    env.disableOperatorChaining()
+    // define the word count pipeline on env; the job graph is built and run below
+    WordCountHelper.execute(env, parallelism)
+
+    val graph = env.getStreamGraph
+    graph.setJobType(JobType.BATCH)
+    val jobGraph = StreamingJobGraphGenerator.createJobGraph(graph)
+    JobGraphRunningHelper.execute(
+      jobGraph,
+      configuration,
+      NUM_TASK_MANAGERS,
+      NUM_SLOTS_PER_TASK_MANAGER)
+    checkFlushingFileLength()
+  }
+
+  def getEnvironment(configuration: Configuration): StreamExecutionEnvironment 
= {
+    
configuration.setBoolean("taskmanager.network.hybrid-shuffle.enable-new-mode", 
true)
+    
configuration.setBoolean("execution.batch.adaptive.auto-parallelism.enabled", 
true)
+    val env = StreamExecutionEnvironment.getExecutionEnvironment(configuration)
+    env.setRestartStrategy(RestartStrategies.fixedDelayRestart(10, 0L))
+    env
+  }
+
+  private def checkFlushingFileLength(): Unit = {
+    workers.map(worker => {
+      worker.storageManager.workingDirWriters.values().asScala.map(writers => {
+        writers.forEach((fileName, fileWriter) => {
+          assert(new File(fileName).length() == 
fileWriter.getDiskFileInfo.getFileLength)
+        })
+      })
+    })
+  }
+}
diff --git 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
index 7c1280f13..c66a36d51 100644
--- 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
+++ 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/MiniClusterFeature.scala
@@ -247,7 +247,7 @@ trait MiniClusterFeature extends Logging {
     workerInfos.keySet
   }
 
-  private def setUpMiniCluster(
+  def setUpMiniCluster(
       masterConf: Map[String, String] = null,
       workerConf: Map[String, String] = null,
       workerNum: Int = 3): (Master, collection.Set[Worker]) = {

Reply via email to