waitinfuture commented on code in PR #2362:
URL: https://github.com/apache/celeborn/pull/2362#discussion_r1615073010


##########
client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala:
##########
@@ -107,60 +114,139 @@ class CelebornShuffleReader[K, C](
       }
     }
 
-    val streams = new ConcurrentHashMap[Integer, CelebornInputStream]()
-    (startPartition until endPartition).map(partitionId => {
+    val startTime = System.currentTimeMillis()
+    val fetchTimeoutMs = conf.clientFetchTimeoutMs
+    val localFetchEnabled = conf.enableReadLocalShuffleFile
+    val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId)
+    // startPartition is irrelevant
+    val fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
+    // host-port -> (TransportClient, PartitionLocation Array, 
PbOpenStreamList)
+    val workerRequestMap = new util.HashMap[
+      String,
+      (TransportClient, util.ArrayList[PartitionLocation], 
PbOpenStreamList.Builder)]()
+
+    var partCnt = 0
+
+    (startPartition until endPartition).foreach { partitionId =>
+      if (fileGroups.partitionGroups.containsKey(partitionId)) {
+        fileGroups.partitionGroups.get(partitionId).asScala.foreach { location 
=>
+          partCnt += 1
+          val hostPort = location.hostAndFetchPort
+          if (!workerRequestMap.containsKey(hostPort)) {
+            val client = shuffleClient.getDataClientFactory().createClient(
+              location.getHost,
+              location.getFetchPort)
+            val pbOpenStreamList = PbOpenStreamList.newBuilder()
+            pbOpenStreamList.setShuffleKey(shuffleKey)
+            workerRequestMap.put(
+              hostPort,
+              (client, new util.ArrayList[PartitionLocation], 
pbOpenStreamList))
+          }
+          val (_, locArr, pbOpenStreamListBuilder) = 
workerRequestMap.get(hostPort)
+
+          locArr.add(location)
+          pbOpenStreamListBuilder.addFileName(location.getFileName)
+            .addStartIndex(startMapIndex)
+            .addEndIndex(endMapIndex)
+          pbOpenStreamListBuilder.addReadLocalShuffle(localFetchEnabled)
+        }
+      }
+    }
+
+    val locationStreamHandlerMap: ConcurrentHashMap[PartitionLocation, 
PbStreamHandler] =
+      JavaUtils.newConcurrentHashMap()
+
+    val futures = workerRequestMap.values().asScala.map { entry =>
       streamCreatorPool.submit(new Runnable {
         override def run(): Unit = {
-          if (exceptionRef.get() == null) {
+          val (client, locArr, pbOpenStreamListBuilder) = entry
+          val msg = new TransportMessage(
+            MessageType.BATCH_OPEN_STREAM,
+            pbOpenStreamListBuilder.build().toByteArray)
+          val pbOpenStreamListResponse =
             try {
-              val inputStream = shuffleClient.readPartition(
-                shuffleId,
-                handle.shuffleId,
-                partitionId,
-                context.attemptNumber(),
-                startMapIndex,
-                endMapIndex,
-                if (throwsFetchFailure) exceptionMaker else null,
-                metricsCallback)
-              streams.put(partitionId, inputStream)
+              val response = client.sendRpcSync(msg.toByteBuffer, 
fetchTimeoutMs)
+              
TransportMessage.fromByteBuffer(response).getParsedPayload[PbOpenStreamListResponse]
             } catch {
-              case e: IOException =>
-                logError(s"Exception caught when readPartition $partitionId!", 
e)
-                exceptionRef.compareAndSet(null, e)
-              case e: Throwable =>
-                logError(s"Non IOException caught when readPartition 
$partitionId!", e)
-                exceptionRef.compareAndSet(null, new CelebornIOException(e))
+              case _: Exception => null
+            }
+          if (pbOpenStreamListResponse != null) {
+            0 until locArr.size() foreach { idx =>

Review Comment:
   Hi @s0nskar , in fact there is one stream for one partition id, which 
probably contains multiple PartitionLocations for that partition id, spread 
across multiple workers; see `CelebornShuffleReader#createInputStream`. I 
think it's a reasonable abstraction to write code like
   ```
    val recordIter = (startPartition until 
endPartition).iterator.map(partitionId => xxx
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to