RexXiong commented on code in PR #2362:
URL: 
https://github.com/apache/incubator-celeborn/pull/2362#discussion_r1517101230


##########
client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java:
##########
@@ -242,36 +248,37 @@ private boolean skipLocation(int startMapIndex, int 
endMapIndex, PartitionLocati
       return true;
     }
 
-    private PartitionLocation nextReadableLocation() {
-      int locationCount = locations.length;
+    private Tuple2<PartitionLocation, PbStreamHandler> nextReadableLocation() {
+      int locationCount = locations.size();
       if (fileIndex >= locationCount) {
         return null;
       }
-      PartitionLocation currentLocation = locations[fileIndex];
+      PartitionLocation currentLocation = locations.get(fileIndex);
       while (skipLocation(startMapIndex, endMapIndex, currentLocation)) {
         skipCount.increment();
         fileIndex++;
         if (fileIndex == locationCount) {
           return null;
         }
-        currentLocation = locations[fileIndex];
+        currentLocation = locations.get(fileIndex);
       }
 
       fetchChunkRetryCnt = 0;
 
-      return currentLocation;
+      return new Tuple2(
+          currentLocation, streamHandlers == null ? null : 
streamHandlers.get(fileIndex));

Review Comment:
   Use streamHandlers.remove(fileIndex) here, since the streamHandler is never used again after this point.



##########
client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala:
##########
@@ -107,33 +114,115 @@ class CelebornShuffleReader[K, C](
       }
     }
 
-    val streams = new ConcurrentHashMap[Integer, CelebornInputStream]()
-    (startPartition until endPartition).map(partitionId => {
+    val fetchTimeoutMs = conf.clientFetchTimeoutMs
+    val localFetchEnabled = conf.enableReadLocalShuffleFile
+    val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId)
+    // startPartition is irrelevant
+    val fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
+    // host-port -> (TransportClient, PartitionLocation Array, 
PbOpenStreamList)
+    val workerRequestMap = new util.HashMap[
+      String,
+      (TransportClient, util.ArrayList[PartitionLocation], 
PbOpenStreamList.Builder)]()
+
+    (startPartition until endPartition).foreach { partitionId =>
+      if (fileGroups.partitionGroups.containsKey(partitionId)) {
+        fileGroups.partitionGroups.get(partitionId).asScala.foreach { location 
=>
+          val hostPort = location.hostAndFetchPort
+          if (!workerRequestMap.containsKey(hostPort)) {
+            val client = shuffleClient.getDataClientFactory().createClient(
+              location.getHost,
+              location.getFetchPort)
+            val pbOpenStreamList = PbOpenStreamList.newBuilder()
+            pbOpenStreamList.setShuffleKey(shuffleKey)
+            workerRequestMap.put(
+              hostPort,
+              (client, new util.ArrayList[PartitionLocation], 
pbOpenStreamList))
+          }
+          val (_, locArr, pbOpenStreamListBuilder) = 
workerRequestMap.get(hostPort)
+
+          locArr.add(location)
+          pbOpenStreamListBuilder.addFileName(location.getFileName)
+            .addStartIndex(startMapIndex)
+            .addEndIndex(endMapIndex)
+          if (localFetchEnabled) {

Review Comment:
   pbOpenStreamListBuilder.addReadLocalShuffle(localFetchEnabled)



##########
worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala:
##########
@@ -301,6 +397,42 @@ class FetchHandler(
     }
   }
 
+  private def makeStreamHandler(
+      streamId: Long,
+      numChunks: Int,
+      offsets: util.List[java.lang.Long] = null,
+      filepath: String = ""): PbStreamHandler = {
+    val pbStreamHandlerBuilder = 
PbStreamHandler.newBuilder.setStreamId(streamId).setNumChunks(
+      numChunks)
+    if (offsets != null) {
+      pbStreamHandlerBuilder.addAllChunkOffsets(offsets)
+    }
+    if (filepath.nonEmpty) {
+      pbStreamHandlerBuilder.setFullPath(filepath)
+    }
+    pbStreamHandlerBuilder.build()
+  }
+
+  private def replyStreamHandler(
+      client: TransportClient,
+      requestId: Long,
+      pbStreamHandler: PbStreamHandler,
+      isLegacy: Boolean): Unit = {
+    if (isLegacy) {
+      client.getChannel.writeAndFlush(new RpcResponse(
+        requestId,
+        new NioManagedBuffer(new StreamHandle(
+          pbStreamHandler.getStreamId,
+          pbStreamHandler.getNumChunks).toByteBuffer)))
+    } else {
+      client.getChannel.writeAndFlush(new RpcResponse(
+        requestId,
+        new NioManagedBuffer(new TransportMessage(
+          MessageType.STREAM_HANDLER,
+          pbStreamHandler.toByteArray).toByteBuffer)))
+    }
+  }
+
   private def replyStreamHandler(

Review Comment:
   It would be better to eliminate the duplicated code shared by the two replyStreamHandler methods, e.g. by having the legacy overload delegate to this one.



##########
client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala:
##########
@@ -107,33 +114,121 @@ class CelebornShuffleReader[K, C](
       }
     }
 
-    val streams = new ConcurrentHashMap[Integer, CelebornInputStream]()
-    (startPartition until endPartition).map(partitionId => {
+    val startTime = System.currentTimeMillis()
+    val fetchTimeoutMs = conf.clientFetchTimeoutMs
+    val localFetchEnabled = conf.enableReadLocalShuffleFile
+    val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId)
+    // startPartition is irrelevant
+    val fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
+    // host-port -> (TransportClient, PartitionLocation Array, 
PbOpenStreamList)
+    val workerRequestMap = new util.HashMap[
+      String,
+      (TransportClient, util.ArrayList[PartitionLocation], 
PbOpenStreamList.Builder)]()
+
+    var partCnt = 0
+
+    (startPartition until endPartition).foreach { partitionId =>
+      if (fileGroups.partitionGroups.containsKey(partitionId)) {
+        fileGroups.partitionGroups.get(partitionId).asScala.foreach { location 
=>
+          partCnt += 1
+          val hostPort = location.hostAndFetchPort
+          if (!workerRequestMap.containsKey(hostPort)) {
+            val client = shuffleClient.getDataClientFactory().createClient(
+              location.getHost,
+              location.getFetchPort)
+            val pbOpenStreamList = PbOpenStreamList.newBuilder()
+            pbOpenStreamList.setShuffleKey(shuffleKey)
+            workerRequestMap.put(
+              hostPort,
+              (client, new util.ArrayList[PartitionLocation], 
pbOpenStreamList))
+          }
+          val (_, locArr, pbOpenStreamListBuilder) = 
workerRequestMap.get(hostPort)
+
+          locArr.add(location)
+          pbOpenStreamListBuilder.addFileName(location.getFileName)
+            .addStartIndex(startMapIndex)
+            .addEndIndex(endMapIndex)
+          if (localFetchEnabled) {
+            pbOpenStreamListBuilder.addReadLocalShuffle(true)
+          } else {
+            pbOpenStreamListBuilder.addReadLocalShuffle(false)
+          }
+        }
+      }
+    }
+
+    val locationStreamHandlerMap: ConcurrentHashMap[PartitionLocation, 
PbStreamHandler] =
+      JavaUtils.newConcurrentHashMap()
+
+    val futures = workerRequestMap.values().asScala.map { entry =>
       streamCreatorPool.submit(new Runnable {
         override def run(): Unit = {
-          if (exceptionRef.get() == null) {
+          val (client, locArr, pbOpenStreamListBuilder) = entry
+          val msg = new TransportMessage(
+            MessageType.OPEN_STREAM_LIST,
+            pbOpenStreamListBuilder.build().toByteArray)
+          val pbOpenStreamListResponse =
             try {
-              val inputStream = shuffleClient.readPartition(
-                shuffleId,
-                handle.shuffleId,
-                partitionId,
-                context.attemptNumber(),
-                startMapIndex,
-                endMapIndex,
-                if (throwsFetchFailure) exceptionMaker else null,
-                metricsCallback)
-              streams.put(partitionId, inputStream)
+              val response = client.sendRpcSync(msg.toByteBuffer, 
fetchTimeoutMs)
+              
TransportMessage.fromByteBuffer(response).getParsedPayload[PbOpenStreamListResponse]
             } catch {
-              case e: IOException =>
-                logError(s"Exception caught when readPartition $partitionId!", 
e)
-                exceptionRef.compareAndSet(null, e)
-              case e: Throwable =>
-                logError(s"Non IOException caught when readPartition 
$partitionId!", e)
-                exceptionRef.compareAndSet(null, new CelebornIOException(e))
+              case _: Exception => null
+            }
+          if (pbOpenStreamListResponse != null) {
+            0 until locArr.size() foreach { idx =>
+              val streamHandlerOpt = 
pbOpenStreamListResponse.getStreamHandlerOptList.get(idx)
+              if (streamHandlerOpt.getStatus == StatusCode.SUCCESS.getValue) {
+                locationStreamHandlerMap.put(locArr.get(idx), 
streamHandlerOpt.getStreamHandler)
+              }
             }
           }
         }
       })
+    }.toList
+    // wait for all futures to complete
+    futures.foreach(f => f.get())
+    val end = System.currentTimeMillis()
+    logInfo(s"openstreamlist for ${partCnt} cost ${end - startTime}ms")
+
+    val streams = new ConcurrentHashMap[Integer, CelebornInputStream]()
+    (startPartition until endPartition).map(partitionId => {
+      val locations =
+        if (fileGroups.partitionGroups.containsKey(partitionId)) {
+          new util.ArrayList(fileGroups.partitionGroups.get(partitionId))
+        } else new util.ArrayList[PartitionLocation]()
+      val streamHandlers =
+        if (locations != null) {
+          val streamHandlerArr = new 
util.ArrayList[PbStreamHandler](locations.size())
+          locations.asScala.foreach { loc =>
+            streamHandlerArr.add(locationStreamHandlerMap.get(loc))
+          }
+          streamHandlerArr
+        } else null
+      if (exceptionRef.get() == null) {
+        try {
+          val inputStream = shuffleClient.readPartition(

Review Comment:
   We could create the CelebornInputStreams lazily — only at the point where they are actually consumed — instead of eagerly building all of them here.



##########
common/src/main/proto/TransportMessages.proto:
##########
@@ -100,6 +100,8 @@ enum MessageType {
   WORKER_EVENT_RESPONSE = 77;
   APPLICATION_META = 78;
   APPLICATION_META_REQUEST = 79;
+  OPEN_STREAM_LIST = 80;
+  OPEN_STREAM_LIST_RESPONSE = 81;

Review Comment:
   BATCH_OPEN_STREAM



##########
client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala:
##########
@@ -107,33 +114,121 @@ class CelebornShuffleReader[K, C](
       }
     }
 
-    val streams = new ConcurrentHashMap[Integer, CelebornInputStream]()
-    (startPartition until endPartition).map(partitionId => {
+    val startTime = System.currentTimeMillis()
+    val fetchTimeoutMs = conf.clientFetchTimeoutMs
+    val localFetchEnabled = conf.enableReadLocalShuffleFile
+    val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId)
+    // startPartition is irrelevant
+    val fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
+    // host-port -> (TransportClient, PartitionLocation Array, 
PbOpenStreamList)
+    val workerRequestMap = new util.HashMap[
+      String,
+      (TransportClient, util.ArrayList[PartitionLocation], 
PbOpenStreamList.Builder)]()
+
+    var partCnt = 0
+
+    (startPartition until endPartition).foreach { partitionId =>
+      if (fileGroups.partitionGroups.containsKey(partitionId)) {
+        fileGroups.partitionGroups.get(partitionId).asScala.foreach { location 
=>
+          partCnt += 1
+          val hostPort = location.hostAndFetchPort
+          if (!workerRequestMap.containsKey(hostPort)) {
+            val client = shuffleClient.getDataClientFactory().createClient(
+              location.getHost,
+              location.getFetchPort)
+            val pbOpenStreamList = PbOpenStreamList.newBuilder()
+            pbOpenStreamList.setShuffleKey(shuffleKey)
+            workerRequestMap.put(
+              hostPort,
+              (client, new util.ArrayList[PartitionLocation], 
pbOpenStreamList))
+          }
+          val (_, locArr, pbOpenStreamListBuilder) = 
workerRequestMap.get(hostPort)
+
+          locArr.add(location)
+          pbOpenStreamListBuilder.addFileName(location.getFileName)
+            .addStartIndex(startMapIndex)
+            .addEndIndex(endMapIndex)
+          if (localFetchEnabled) {
+            pbOpenStreamListBuilder.addReadLocalShuffle(true)
+          } else {
+            pbOpenStreamListBuilder.addReadLocalShuffle(false)
+          }
+        }
+      }
+    }
+
+    val locationStreamHandlerMap: ConcurrentHashMap[PartitionLocation, 
PbStreamHandler] =
+      JavaUtils.newConcurrentHashMap()
+
+    val futures = workerRequestMap.values().asScala.map { entry =>
       streamCreatorPool.submit(new Runnable {
         override def run(): Unit = {
-          if (exceptionRef.get() == null) {
+          val (client, locArr, pbOpenStreamListBuilder) = entry
+          val msg = new TransportMessage(
+            MessageType.OPEN_STREAM_LIST,
+            pbOpenStreamListBuilder.build().toByteArray)
+          val pbOpenStreamListResponse =
             try {
-              val inputStream = shuffleClient.readPartition(
-                shuffleId,
-                handle.shuffleId,
-                partitionId,
-                context.attemptNumber(),
-                startMapIndex,
-                endMapIndex,
-                if (throwsFetchFailure) exceptionMaker else null,
-                metricsCallback)
-              streams.put(partitionId, inputStream)
+              val response = client.sendRpcSync(msg.toByteBuffer, 
fetchTimeoutMs)
+              
TransportMessage.fromByteBuffer(response).getParsedPayload[PbOpenStreamListResponse]
             } catch {
-              case e: IOException =>
-                logError(s"Exception caught when readPartition $partitionId!", 
e)
-                exceptionRef.compareAndSet(null, e)
-              case e: Throwable =>
-                logError(s"Non IOException caught when readPartition 
$partitionId!", e)
-                exceptionRef.compareAndSet(null, new CelebornIOException(e))
+              case _: Exception => null
+            }
+          if (pbOpenStreamListResponse != null) {
+            0 until locArr.size() foreach { idx =>

Review Comment:
   Why is it safe to ignore partitions whose status is not SUCCESS? We will likely hit an NPE later when getting the StreamHandler for such a partition from locationStreamHandlerMap.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to