aokolnychyi commented on a change in pull request #32921:
URL: https://github.com/apache/spark/pull/32921#discussion_r652183144
##########
File path:
sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
##########
@@ -245,21 +246,63 @@ class InMemoryTable(
}
}
- class InMemoryBatchScan(
- data: Array[InputPartition],
+ case class InMemoryStats(sizeInBytes: OptionalLong, numRows: OptionalLong)
extends Statistics
+
+ case class InMemoryBatchScan(
+ var data: Seq[InputPartition],
readSchema: StructType,
- tableSchema: StructType) extends Scan with Batch {
- override def readSchema(): StructType = readSchema
+ tableSchema: StructType)
+ extends Scan with Batch with SupportsDynamicFiltering with
SupportsReportStatistics {
override def toBatch: Batch = this
- override def planInputPartitions(): Array[InputPartition] = data
+ override def estimateStatistics(): Statistics = {
+ val inputPartitions = data.map(_.asInstanceOf[BufferedRows])
+ val numRows = inputPartitions.map(_.rows.size).sum
+ val rowSizeInBytes = schema.fields.zipWithIndex.map { case (field,
index) =>
+ field.dataType match {
+ case IntegerType => 4L
+ case LongType => 8L
+ case StringType =>
+ val numChars = inputPartitions
+ .map(_.rows.map(_.getString(index).length).sum)
+ .sum
+ numChars / numRows
+ case _ => 8L
+ }
+ }.sum
+ val sizeInBytes = numRows * rowSizeInBytes
+ InMemoryStats(OptionalLong.of(sizeInBytes), OptionalLong.of(numRows))
+ }
+
+ override def planInputPartitions(): Array[InputPartition] = data.toArray
override def createReaderFactory(): PartitionReaderFactory = {
val metadataColumns =
readSchema.map(_.name).filter(metadataColumnNames.contains)
val nonMetadataColumns = readSchema.filterNot(f =>
metadataColumns.contains(f.name))
new BufferedRowsReaderFactory(metadataColumns, nonMetadataColumns,
tableSchema)
}
+
+ override def filterAttributes(): Array[NamedReference] = {
+ val scanFields = readSchema.fields.map(_.name).toSet
+ partitioning.flatMap(_.references)
+ .filter(ref => scanFields.contains(ref.fieldNames.mkString(".")))
+ }
+
+ override def filter(filters: Array[Filter]): Unit = {
+ if (partitioning.length == 1) {
Review comment:
This is just a temporary, trivial implementation.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]