This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 1c3b94150b44 [SPARK-46709][SS] Expose partition_id column for state data source
1c3b94150b44 is described below

commit 1c3b94150b44f51af4e23601fb6e7e51c4605712
Author: Chaoqin Li <chaoqin...@databricks.com>
AuthorDate: Mon Jan 15 08:21:19 2024 +0900

    [SPARK-46709][SS] Expose partition_id column for state data source

    ### What changes were proposed in this pull request?
    Expose the partition_id column of the state data source, which was hidden by default.

    ### Why are the changes needed?
    The partition_id column is useful to users.

    ### Does this PR introduce _any_ user-facing change?
    Yes. The partition_id column of the state data source, which was previously hidden by default, is now exposed, and the doc is modified accordingly.

    ### How was this patch tested?
    Modified an existing integration test.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #44717 from chaoqin-li1123/unhide_partition_id.

    Authored-by: Chaoqin Li <chaoqin...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 docs/structured-streaming-state-data-source.md     |  4 ++--
 .../datasources/v2/state/StateDataSource.scala     |  3 ++-
 .../v2/state/StatePartitionReader.scala            | 18 ++++------------
 .../datasources/v2/state/StateTable.scala          | 22 +++++--------------
 .../StreamStreamJoinStatePartitionReader.scala     | 18 ++++------------
 .../v2/state/StateDataSourceReadSuite.scala        | 25 ++++++----------------
 6 files changed, 24 insertions(+), 66 deletions(-)

diff --git a/docs/structured-streaming-state-data-source.md b/docs/structured-streaming-state-data-source.md
index ae323f6b0c14..986699130669 100644
--- a/docs/structured-streaming-state-data-source.md
+++ b/docs/structured-streaming-state-data-source.md
@@ -96,9 +96,9 @@
     <td></td>
   </tr>
   <tr>
-    <td>_partition_id</td>
+    <td>partition_id</td>
     <td>int</td>
-    <td>metadata column (hidden unless specified with SELECT)</td>
+    <td></td>
   </tr>
 </table>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 1192accaabef..1a8f444042c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DI
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
 import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
 import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap

 /**
@@ -83,6 +83,7 @@ class StateDataSource extends TableProvider with DataSourceRegister {
       new StructType()
         .add("key", keySchema)
         .add("value", valueSchema)
+        .add("partition_id", IntegerType)
     } catch {
       case NonFatal(e) =>
         throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index 1e5f7216e8bf..ef8d7bf628bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.state

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
@@ -99,18 +99,7 @@ class StatePartitionReader(
     }
   }

-  private val joinedRow = new JoinedRow()
-
-  private def addMetadata(row: InternalRow): InternalRow = {
-    val metadataRow = new GenericInternalRow(
-      StateTable.METADATA_COLUMNS.map(_.name()).map {
-        case "_partition_id" => partition.partition.asInstanceOf[Any]
-      }.toArray
-    )
-    joinedRow.withLeft(row).withRight(metadataRow)
-  }
-
-  override def get(): InternalRow = addMetadata(current)
+  override def get(): InternalRow = current

   override def close(): Unit = {
     current = null
@@ -118,9 +107,10 @@ class StatePartitionReader(
   }

   private def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow)): InternalRow = {
-    val row = new GenericInternalRow(2)
+    val row = new GenericInternalRow(3)
     row.update(0, pair._1)
     row.update(1, pair._2)
+    row.update(2, partition.partition)
     row
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
index 96c1c01cede2..824968e709ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.read.ScanBuilder
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.state.StateStoreConf
-import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
+import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.ArrayImplicits._
@@ -69,18 +69,20 @@ class StateTable(
   override def properties(): util.Map[String, String] = Map.empty[String, String].asJava

   private def isValidSchema(schema: StructType): Boolean = {
-    if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value")) {
+    if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value", "partition_id")) {
       false
     } else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) {
       false
     } else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) {
       false
+    } else if (!SchemaUtil.getSchemaAsDataType(schema, "partition_id").isInstanceOf[IntegerType]) {
+      false
     } else {
       true
     }
   }

-  override def metadataColumns(): Array[MetadataColumn] = METADATA_COLUMNS.toArray
+  override def metadataColumns(): Array[MetadataColumn] = Array.empty
 }

 /**
@@ -89,18 +91,4 @@ class StateTable(
  */
 object StateTable {
   private val CAPABILITY = Set(TableCapability.BATCH_READ).asJava
-
-  val METADATA_COLUMNS: Seq[MetadataColumn] = Seq(PartitionId)
-
-  private object PartitionId extends MetadataColumn {
-    override def name(): String = "_partition_id"
-
-    override def dataType(): DataType = IntegerType
-
-    override def isNullable: Boolean = false
-
-    override def comment(): String = {
-      "Represents an ID for a physical state partition this row belongs to."
-    }
-  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
index 26492f8790c4..d0dd6cb7d1b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.state

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, JoinedRow, Literal, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, Literal, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
@@ -148,18 +148,7 @@ class StreamStreamJoinStatePartitionReader(
     }
   }

-  private val joinedRow = new JoinedRow()
-
-  private def addMetadata(row: InternalRow): InternalRow = {
-    val metadataRow = new GenericInternalRow(
-      StateTable.METADATA_COLUMNS.map(_.name()).map {
-        case "_partition_id" => partition.partition.asInstanceOf[Any]
-      }.toArray
-    )
-    joinedRow.withLeft(row).withRight(metadataRow)
-  }
-
-  override def get(): InternalRow = addMetadata(current)
+  override def get(): InternalRow = current

   override def close(): Unit = {
     current = null
@@ -169,9 +158,10 @@ class StreamStreamJoinStatePartitionReader(
   }

   private def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow)): InternalRow = {
-    val row = new GenericInternalRow(2)
+    val row = new GenericInternalRow(3)
     row.update(0, pair._1)
     row.update(1, pair._2)
+    row.update(2, partition.partition)
     row
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
index 86c3ab70af68..c800168b507a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
@@ -687,7 +687,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
     }
   }

-  test("metadata column") {
+  test("partition_id column") {
     withTempDir { tempDir =>
       import testImplicits._
       val stream = MemoryStream[Int]
@@ -712,14 +712,11 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
         // skip version and operator ID to test out functionalities
         .load()

"_partition_id"), - "metadata column should not be exposed until it is explicitly specified!") - val numShufflePartitions = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) val resultDf = stateReadDf - .selectExpr("key.value AS key_value", "value.count AS value_count", "_partition_id") - .where("_partition_id % 2 = 0") + .selectExpr("key.value AS key_value", "value.count AS value_count", "partition_id") + .where("partition_id % 2 = 0") // NOTE: This is a hash function of distribution for stateful operator. val hash = HashPartitioning( @@ -738,17 +735,12 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass } } - test("metadata column with stream-stream join") { + test("partition_id column with stream-stream join") { val numShufflePartitions = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) withTempDir { tempDir => runStreamStreamJoinQueryWithOneThousandInputs(tempDir.getAbsolutePath) - def assertPartitionIdColumnIsNotExposedByDefault(df: DataFrame): Unit = { - assert(!df.schema.exists(_.name == "_partition_id"), - "metadata column should not be exposed until it is explicitly specified!") - } - def assertPartitionIdColumn(df: DataFrame): Unit = { // NOTE: This is a hash function of distribution for stateful operator. // stream-stream join uses the grouping key for the equality match in the join condition. @@ -759,8 +751,8 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass numShufflePartitions) val partIdExpr = hash.partitionIdExpression - val dfWithPartition = df.selectExpr("key.field0 As key_0", "_partition_id") - .where("_partition_id % 2 = 0") + val dfWithPartition = df.selectExpr("key.field0 As key_0", "partition_id") + .where("partition_id % 2 = 0") checkAnswer(dfWithPartition, Range.inclusive(2, 1000, 2).map { idx => @@ -778,8 +770,6 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) .option(StateSourceOptions.JOIN_SIDE, side) .load() - - assertPartitionIdColumnIsNotExposedByDefault(stateReaderForLeft) assertPartitionIdColumn(stateReaderForLeft) val stateReaderForKeyToNumValues = spark.read @@ -789,7 +779,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass s"$side-keyToNumValues") .load() - assertPartitionIdColumnIsNotExposedByDefault(stateReaderForKeyToNumValues) + assertPartitionIdColumn(stateReaderForKeyToNumValues) val stateReaderForKeyWithIndexToValue = spark.read @@ -799,7 +789,6 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass s"$side-keyWithIndexToValue") .load() - assertPartitionIdColumnIsNotExposedByDefault(stateReaderForKeyWithIndexToValue) assertPartitionIdColumn(stateReaderForKeyWithIndexToValue) } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org