hvanhovell commented on code in PR #53497:
URL: https://github.com/apache/spark/pull/53497#discussion_r2625481560
##########
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/LocalRelationDatasource.scala:
##########
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.execution
+
+import java.io.InputStream
+import java.util
+
+import scala.util.control.NonFatal
+
+import org.apache.arrow.memory.BufferAllocator
+
+import org.apache.spark.{SparkEnv, SparkException}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.catalog.{Column, SupportsRead, Table, TableCapability}
+import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.{ArrowUtils, CaseInsensitiveStringMap, ConcatenatingArrowStreamReader, MessageIterator}
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
+import org.apache.spark.storage.CacheId
+
+/**
+ * Datasource implementation that can read a Connect Chunked Datasource directly from the
+ * block manager.
+ *
+ * @param id of the LocalRelation.
+ * @param sessionId of the LocalRelation.
+ * @param schema of the LocalRelation.
+ * @param dataHashes ids of the blocks stored in the BlockManager.
+ */
+class LocalRelationTable(
+    id: Long,
+    sessionId: String,
+    schema: StructType,
+    dataHashes: Seq[String])
+  extends Table with SupportsRead {
+
+  override def name(): String = s"LocalRelation[session=$sessionId, id=$id]"
+
+  override def columns(): Array[Column] = schema.fields.map { field =>
+    Column.create(field.name, field.dataType, field.nullable)
+  }
+
+  override def capabilities(): util.Set[TableCapability] = util.EnumSet.of(
+    TableCapability.BATCH_READ,
+    TableCapability.MICRO_BATCH_READ)
+
+  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder {
+    override def build(): Scan = new LocalRelationScan(schema, sessionId, dataHashes)
+  }
+}
+
+case class LocalRelationScanPartition(sessionId: String, dataHash: String) extends InputPartition
+
+class LocalRelationScan(schema: StructType, sessionId: String, dataHashes: Seq[String])
+  extends Scan with Batch {
+  override def readSchema(): StructType = schema
+
+  override def toBatch: Batch = this
+
+  override def planInputPartitions(): Array[InputPartition] =
+    dataHashes.map(hash => LocalRelationScanPartition(sessionId, hash)).toArray
+
+  override def createReaderFactory(): PartitionReaderFactory =
+    LocalRelationScanPartitionReaderFactory
+
+  override def columnarSupportMode(): Scan.ColumnarSupportMode =
+    Scan.ColumnarSupportMode.SUPPORTED
+}
+
+object LocalRelationScanPartitionReaderFactory extends PartitionReaderFactory {
+  override def supportColumnarReads(partition: InputPartition): Boolean = true
+
+  override def createReader(partition: InputPartition): PartitionReader[InternalRow] =
+    throw new IllegalStateException("Row based reads are not supported (yet)")
+
+  override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = {
+    val localRelationScanPartition = partition.asInstanceOf[LocalRelationScanPartition]
+    new LocalRelationPartitionReader(localRelationScanPartition)
+  }
+}
+
+class LocalRelationPartitionReader(partition: LocalRelationScanPartition)
+  extends PartitionReader[ColumnarBatch] with Logging {
+  private var input: InputStream = _
+  private var allocator: BufferAllocator = _
+  private var reader: ConcatenatingArrowStreamReader = _
+
+  private def init(): Unit = {
+    if (reader == null) {
+      val blockManager = SparkEnv.get.blockManager
+      val blockId = CacheId(partition.sessionId, partition.dataHash)
+      input = blockManager.getLocalBytes(blockId).map(_.toInputStream())

Review Comment:
   I think this can be improved. I don't think we replicate the block to this executor when we do the remote read like this. If we did replicate, we could also use locality to make re-reads more effective.
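   A rough, untested sketch of what I mean (it assumes BlockManager.getRemoteBytes/putBytes and BlockManagerMaster.getLocations keep their current shapes, picks MEMORY_AND_DISK as an arbitrary storage level, and introduces a hypothetical openBlock helper):

   ```scala
   import java.io.InputStream

   import org.apache.spark.{SparkEnv, SparkException}
   import org.apache.spark.sql.connector.read.InputPartition
   import org.apache.spark.storage.{CacheId, StorageLevel}

   // Locality: compute preferred hosts on the driver while planning, so the
   // scheduler favors executors that already hold the block locally.
   case class LocalRelationScanPartition(
       sessionId: String,
       dataHash: String,
       hosts: Array[String]) extends InputPartition {
     override def preferredLocations(): Array[String] = hosts
   }

   override def planInputPartitions(): Array[InputPartition] = {
     val master = SparkEnv.get.blockManager.master
     dataHashes.map { hash =>
       val hosts = master.getLocations(CacheId(sessionId, hash)).map(_.host).toArray
       LocalRelationScanPartition(sessionId, hash, hosts)
     }.toArray
   }

   // Replication: when the block is not local, fetch it remotely and cache the
   // bytes on this executor so a re-read of the same partition stays local.
   private def openBlock(blockId: CacheId): InputStream = {
     val blockManager = SparkEnv.get.blockManager
     blockManager.getLocalBytes(blockId).map(_.toInputStream()).getOrElse {
       val remote = blockManager.getRemoteBytes(blockId).getOrElse {
         throw SparkException.internalError(s"Block $blockId not found")
       }
       blockManager.putBytes[Array[Byte]](blockId, remote, StorageLevel.MEMORY_AND_DISK)
       // Read back the now-local copy instead of reusing `remote`, since
       // putBytes may take ownership of (and dispose) the buffer.
       blockManager.getLocalBytes(blockId)
         .map(_.toInputStream())
         .getOrElse(throw SparkException.internalError(s"Failed to cache $blockId"))
     }
   }
   ```

   Whether re-caching on every remote read is worth the extra memory is debatable; an alternative is to keep the single-copy behavior and rely on preferred locations alone.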
