beliefer commented on code in PR #37393:
URL: https://github.com/apache/spark/pull/37393#discussion_r940803579


##########
sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala:
##########
@@ -0,0 +1,667 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.time.{Instant, ZoneId}
+import java.time.temporal.ChronoUnit
+import java.util
+import java.util.OptionalLong
+
+import scala.collection.mutable
+
+import org.scalatest.Assertions._
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
JoinedRow}
+import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils}
+import org.apache.spark.sql.connector.distributions.{Distribution, 
Distributions}
+import org.apache.spark.sql.connector.expressions._
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
+import org.apache.spark.sql.connector.read._
+import 
org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, 
Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.connector.write._
+import 
org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, 
StreamingWrite}
+import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A simple in-memory table. Rows are stored as a buffered group produced by 
each output task.
+ */
+class InMemoryBaseTable(
+    val name: String,
+    val schema: StructType,
+    override val partitioning: Array[Transform],
+    override val properties: util.Map[String, String],
+    val distribution: Distribution = Distributions.unspecified(),
+    val ordering: Array[SortOrder] = Array.empty,
+    val numPartitions: Option[Int] = None,
+    val isDistributionStrictlyRequired: Boolean = true)
+  extends Table with SupportsRead with SupportsWrite with 
SupportsMetadataColumns {
+
+  /**
+   * Metadata column named `_partition`, typed as a string. Its comment describes it as the
+   * partition key used to store the row.
+   */
+  protected object PartitionKeyColumn extends MetadataColumn {
+    override def comment: String = "Partition key used to store the row"
+    override def dataType: DataType = StringType
+    override def name: String = "_partition"
+  }
+
+  /**
+   * Metadata column named `index`, typed as an integer. Per its own comment it exists so
+   * that a metadata column can collide with a data column.
+   */
+  private object IndexColumn extends MetadataColumn {
+    override def comment: String = "Metadata column used to conflict with a data column"
+    override def dataType: DataType = IntegerType
+    override def name: String = "index"
+  }
+
+  // Some tests need a metadata column whose name collides with a data column, so the
+  // colliding `index` column is exposed here on purpose.
+  override val metadataColumns: Array[MetadataColumn] = Array(IndexColumn, PartitionKeyColumn)
+
+  // Names of the metadata columns that are not shadowed by a data column of the same name.
+  private val metadataColumnNames = {
+    val dataColumnNames = schema.map(_.name)
+    metadataColumns.map(_.name).toSet -- dataColumnNames
+  }
+
+  // Opt-in switch read from the table properties; a missing property means "false".
+  private val allowUnsupportedTransforms = {
+    val flag = properties.getOrDefault("allow-unsupported-transforms", "false")
+    flag.toBoolean
+  }
+
+  // Validate the partition transforms eagerly, while the table is being constructed.
+  // The listed transforms are always accepted. Any other transform is rejected with an
+  // IllegalArgumentException unless the `allow-unsupported-transforms` property is true.
+  partitioning.foreach {
+    case _: IdentityTransform =>
+    case _: YearsTransform =>
+    case _: MonthsTransform =>
+    case _: DaysTransform =>
+    case _: HoursTransform =>
+    case _: BucketTransform =>
+    case _: SortedBucketTransform =>
+    case t if !allowUnsupportedTransforms =>
+      throw new IllegalArgumentException(s"Transform $t is not a supported transform")
+    case _ =>
+      // Unsupported transform that was explicitly allowed. Without this fallback the
+      // guarded case above would not match and the block would fail with a MatchError.
+  }
+
+  // Buffered rows of every partition; the `Seq[Any]` key holds the partition values.
+  val dataMap: mutable.Map[Seq[Any], BufferedRows] = mutable.Map.empty
+
+  /** Returns the row buffers of all partitions as an array. */
+  def data: Array[BufferedRows] = dataMap.values.toArray[BufferedRows]
+
+  /** Returns the rows of every partition flattened into a single sequence. */
+  def rows: Seq[InternalRow] = {
+    val allRows = for (buffer <- dataMap.values; row <- buffer.rows) yield row
+    allRows.toSeq
+  }
+
+  // Field-name paths of every column referenced by the partition transforms. Construction
+  // fails with an IllegalArgumentException when a referenced column is not in the schema.
+  val partCols: Array[Array[String]] = {
+    val referencedColumns = partitioning.flatMap(_.references)
+    referencedColumns.map { ref =>
+      if (schema.findNestedField(ref.fieldNames(), includeCollections = false).isEmpty) {
+        throw new IllegalArgumentException(s"${ref.describe()} does not exist.")
+      }
+      ref.fieldNames()
+    }
+  }
+
+  // Zone used when converting timestamps to calendar dates for the partition transforms.
+  private val UTC = ZoneId.of("UTC")
+  // 1970-01-01, the reference date for the years/months transforms.
+  private val EPOCH_LOCAL_DATE = java.time.LocalDate.ofEpochDay(0L)
+
+  /** Computes the partition key of `row`, interpreting the row with the table schema. */
+  protected def getKey(row: InternalRow): Seq[Any] = getKey(row, schema)
+
+  /**
+   * Computes the partition key of `row` by evaluating every transform in `partitioning`
+   * against it, in order.
+   *
+   * @param row the row to compute the key for
+   * @param rowSchema schema used to read `row`; char/varchar types are replaced by string
+   *                  before fields are looked up
+   * @return one value per partition transform
+   * @throws IllegalArgumentException if a transform is applied to a value/type combination
+   *                                  it does not handle, or a nested path crosses a
+   *                                  non-struct field
+   */
+  protected def getKey(row: InternalRow, rowSchema: StructType): Seq[Any] = {
+    // Follows a (possibly nested) field path and returns the leaf value with its data type.
+    @scala.annotation.tailrec
+    def resolve(
+        path: Array[String],
+        struct: StructType,
+        current: InternalRow): (Any, DataType) = {
+      val pos = struct.fieldIndex(path(0))
+      val fieldValue = current.toSeq(struct).apply(pos)
+      val fieldType = struct(pos).dataType
+      if (path.length <= 1) {
+        (fieldValue, fieldType)
+      } else {
+        (fieldValue, fieldType) match {
+          case (nested: InternalRow, nestedStruct: StructType) =>
+            resolve(path.drop(1), nestedStruct, nested)
+          case _ =>
+            throw new IllegalArgumentException(s"Unsupported type, ${fieldType.simpleString}")
+        }
+      }
+    }
+
+    val cleanedSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(rowSchema)
+
+    // Value and type of the column at `path` in the row being keyed.
+    def valueOf(path: Array[String]): (Any, DataType) = resolve(path, cleanedSchema, row)
+
+    // Shared failure for a value/type pair a transform cannot handle.
+    def unsupported(v: Any, t: DataType): Nothing =
+      throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
+
+    // Whole calendar `unit`s between 1970-01-01 and the date or timestamp stored at `path`.
+    def calendarUnitsSinceEpoch(unit: ChronoUnit, path: Array[String]): Long =
+      valueOf(path) match {
+        case (days: Int, DateType) =>
+          unit.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days))
+        case (micros: Long, TimestampType) =>
+          val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate
+          unit.between(EPOCH_LOCAL_DATE, localDate)
+        case (v, t) =>
+          unsupported(v, t)
+      }
+
+    // Whole `unit`s between the epoch instant and the given microsecond timestamp.
+    def instantUnitsSinceEpoch(unit: ChronoUnit, micros: Long): Long =
+      unit.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros))
+
+    partitioning.map {
+      case IdentityTransform(ref) =>
+        valueOf(ref.fieldNames)._1
+      case YearsTransform(ref) =>
+        calendarUnitsSinceEpoch(ChronoUnit.YEARS, ref.fieldNames)
+      case MonthsTransform(ref) =>
+        calendarUnitsSinceEpoch(ChronoUnit.MONTHS, ref.fieldNames)
+      case DaysTransform(ref) =>
+        valueOf(ref.fieldNames) match {
+          // A date value is passed through unchanged.
+          case (days, DateType) => days
+          case (micros: Long, TimestampType) => instantUnitsSinceEpoch(ChronoUnit.DAYS, micros)
+          case (v, t) => unsupported(v, t)
+        }
+      case HoursTransform(ref) =>
+        valueOf(ref.fieldNames) match {
+          case (micros: Long, TimestampType) => instantUnitsSinceEpoch(ChronoUnit.HOURS, micros)
+          case (v, t) => unsupported(v, t)
+        }
+      case BucketTransform(numBuckets, cols, _) =>
+        val pairs = cols.map(col => valueOf(col.fieldNames))
+        // Null values contribute nothing to the value hash.
+        val valueHash = pairs.collect { case (v, _) if v != null => v.hashCode() }.sum
+        val typeHash = pairs.map(_._2.hashCode()).sum
+        // Mask the sign bit so the bucket id is never negative.
+        ((valueHash + 31 * typeHash) & Integer.MAX_VALUE) % numBuckets
+    }
+  }
+
+  /** Hook invoked by `withData` for the key of every row written; does nothing here. */
+  protected def addPartitionKey(key: Seq[Any]): Unit = ()
+
+  /**
+   * Moves the rows stored under partition key `from` to partition key `to`, overwriting the
+   * partition columns of every moved row with the values of `to`.
+   *
+   * @param partitionSchema fields of the partition key; each is looked up by name in the
+   *                        table schema to find the column to overwrite
+   * @param from key currently holding the rows; a missing key is treated as an empty
+   *             partition
+   * @param to key the rows are moved to
+   * @return always `true`
+   * @throws IllegalStateException if another partition already exists under `to`; nothing
+   *                               is modified in that case
+   */
+  protected def renamePartitionKey(
+      partitionSchema: StructType,
+      from: Seq[Any],
+      to: Seq[Any]): Boolean = dataMap.synchronized {
+    // Detect the conflict before touching the map. Checking only after `remove` and `put`
+    // would discard the rows of both the source and the existing target partition.
+    if (from != to && dataMap.contains(to)) {
+      throw new IllegalStateException(
+        s"The ${to.mkString("[", ", ", "]")} partition exists already")
+    }
+    val rows = dataMap.remove(from).getOrElse(new BufferedRows(from))
+    val newRows = new BufferedRows(to)
+    rows.rows.foreach { r =>
+      // Copy the row, then overwrite its partition columns with the new key values.
+      val newRow = new GenericInternalRow(r.numFields)
+      for (i <- 0 until r.numFields) newRow.update(i, r.get(i, schema(i).dataType))
+      for (i <- 0 until partitionSchema.length) {
+        val j = schema.fieldIndex(partitionSchema(i).name)
+        newRow.update(j, to(i))
+      }
+      newRows.withRow(newRow)
+    }
+    dataMap.put(to, newRows)
+    true
+  }
+
+  /** Drops the partition stored under `key`, if any, while holding the `dataMap` lock. */
+  protected def removePartitionKey(key: Seq[Any]): Unit = {
+    dataMap.synchronized {
+      dataMap -= key
+    }
+  }
+
+  /**
+   * Registers a partition for `key` unless one already exists. When the key has as many
+   * values as the table schema has fields, the key itself is stored as the first row.
+   */
+  protected def createPartitionKey(key: Seq[Any]): Unit = dataMap.synchronized {
+    dataMap.getOrElseUpdate(key, {
+      val buffer = new BufferedRows(key)
+      if (key.length != schema.length) buffer else buffer.withRow(InternalRow.fromSeq(key))
+    })
+  }
+
+  /** Empties the partition stored under `key`; asserts that the partition exists. */
+  protected def clearPartition(key: Seq[Any]): Unit = dataMap.synchronized {
+    assert(dataMap.contains(key))
+    val partition = dataMap(key)
+    partition.clear()
+  }
+
+  /** Appends the rows of `data` to this table, reading them with the table schema. */
+  def withData(data: Array[BufferedRows]): InMemoryBaseTable = withData(data, schema)
+
+  /**
+   * Appends every row of `data` to the partition selected by its partition key and reports
+   * each key through `addPartitionKey`.
+   *
+   * @param data buffers holding the rows to add
+   * @param writeSchema schema used to read the rows when computing their partition keys
+   * @return this table
+   */
+  def withData(
+      data: Array[BufferedRows],
+      writeSchema: StructType): InMemoryBaseTable = dataMap.synchronized {
+    for (buffer <- data; row <- buffer.rows) {
+      val key = getKey(row, writeSchema)
+      // Reuse the partition's buffer when it exists, otherwise start a new one.
+      val target = dataMap.getOrElse(key, new BufferedRows(key))
+      dataMap.put(key, target.withRow(row))
+      addPartitionKey(key)
+    }
+    this
+  }
+
+  /** Capabilities advertised by this table: batch read/write, streaming write, overwrites
+   *  by filter and dynamic overwrites, and truncate. */
+  override def capabilities: util.Set[TableCapability] = {
+    import TableCapability._
+    util.EnumSet.of(
+      BATCH_READ,
+      BATCH_WRITE,
+      STREAMING_WRITE,
+      OVERWRITE_BY_FILTER,
+      OVERWRITE_DYNAMIC,
+      TRUNCATE)
+  }
+
+  /** Creates a scan builder over the table schema; `options` is not consulted. */
+  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
+    new InMemoryScanBuilder(schema)
+
+  /**
+   * Scan builder with column pruning. It starts from the full table schema and narrows it
+   * to the requested columns that are either table columns or non-shadowed metadata columns.
+   */
+  class InMemoryScanBuilder(tableSchema: StructType) extends ScanBuilder
+      with SupportsPushDownRequiredColumns {
+    // Schema the built scan will read; replaced by pruneColumns.
+    private var prunedSchema: StructType = tableSchema
+
+    override def pruneColumns(requiredSchema: StructType): Unit = {
+      val knownNames = metadataColumnNames ++ tableSchema.map(_.name)
+      val retained = requiredSchema.filter(field => knownNames.contains(field.name))
+      prunedSchema = StructType(retained)
+    }
+
+    override def build: Scan = {
+      val partitions = data.map(_.asInstanceOf[InputPartition])
+      InMemoryBatchScan(partitions, prunedSchema, tableSchema)
+    }
+  }
+
+  /** Statistics carrier holding an optional size in bytes and an optional row count. */
+  case class InMemoryStats(
+      sizeInBytes: OptionalLong,
+      numRows: OptionalLong)
+    extends Statistics
+
+  /**
+   * Base class for batch scans over this table's buffered partitions. It acts as its own
+   * `Batch`, reports size/row-count statistics and the output partitioning, and creates the
+   * reader factory.
+   *
+   * @param data input partitions to scan; each is cast to `BufferedRows` when statistics
+   *             are estimated
+   * @param readSchema columns to read, possibly including metadata columns
+   * @param tableSchema schema passed through to the reader factory
+   */
+  abstract class BatchScanBaseClass(
+      var data: Seq[InputPartition],
+      readSchema: StructType,
+      tableSchema: StructType)
+    extends Scan with Batch with SupportsReportStatistics with SupportsReportPartitioning {
+
+    override def toBatch: Batch = this
+
+    override def estimateStatistics(): Statistics = {
+      if (data.nonEmpty) {
+        // we assume an average object header is 12 bytes
+        val objectHeaderSizeInBytes = 12L
+        // NOTE(review): `schema` appears to resolve to the enclosing table's schema rather
+        // than `readSchema` -- confirm this is intended.
+        val rowSizeInBytes = objectHeaderSizeInBytes + schema.defaultSize
+        val numRows = data.map(_.asInstanceOf[BufferedRows].rows.size).sum
+        InMemoryStats(OptionalLong.of(numRows * rowSizeInBytes), OptionalLong.of(numRows))
+      } else {
+        InMemoryStats(OptionalLong.of(0L), OptionalLong.of(0L))
+      }
+    }
+
+    override def outputPartitioning(): Partitioning = {
+      val transforms = InMemoryBaseTable.this.partitioning
+      if (transforms.isEmpty) {
+        new UnknownPartitioning(data.size)
+      } else {
+        new KeyGroupedPartitioning(transforms.map(_.asInstanceOf[Expression]), data.size)
+      }
+    }
+
+    override def planInputPartitions(): Array[InputPartition] = data.toArray
+
+    override def createReaderFactory(): PartitionReaderFactory = {
+      // Split the requested fields into metadata columns and regular data columns.
+      val (metadataFields, nonMetadataColumns) =
+        readSchema.partition(field => metadataColumnNames.contains(field.name))
+      val metadataColumns = metadataFields.map(_.name)
+      new BufferedRowsReaderFactory(metadataColumns, nonMetadataColumns, tableSchema)
+    }
+  }
+
+  case class InMemoryBatchScan(
+      var _data: Seq[InputPartition],
+      readSchema: StructType,
+      tableSchema: StructType)
+    extends BatchScanBaseClass (_data, readSchema, tableSchema) with 
SupportsRuntimeFiltering {

Review Comment:
   ```suggestion
       extends BatchScanBaseClass(_data, readSchema, tableSchema) with 
SupportsRuntimeFiltering {
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to