This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1c3b94150b44 [SPARK-46709][SS] Expose partition_id column for state data source
1c3b94150b44 is described below

commit 1c3b94150b44f51af4e23601fb6e7e51c4605712
Author: Chaoqin Li <chaoqin...@databricks.com>
AuthorDate: Mon Jan 15 08:21:19 2024 +0900

    [SPARK-46709][SS] Expose partition_id column for state data source
    
    ### What changes were proposed in this pull request?
    Expose the partition_id column of the state data source, which was previously hidden by default.
    
    ### Why are the changes needed?
    The partition_id column is useful to users, e.g. for checking which physical state partition a state row belongs to.
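    For illustration, a minimal usage sketch (the `statestore` format name and the `path` option follow the state data source documentation and the test changes below; the checkpoint location is a placeholder, and an active SparkSession `spark` is assumed):

    ```scala
    // Read the state of a streaming query; partition_id is now part of the
    // default schema (key, value, partition_id) without any explicit SELECT.
    val stateDf = spark.read
      .format("statestore")
      .option("path", "/path/to/checkpoint")  // placeholder checkpoint location
      .load()

    // The column can be selected and filtered like any regular column.
    stateDf
      .selectExpr("key", "value", "partition_id")
      .where("partition_id % 2 = 0")
      .show()
    ```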
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. The partition_id column of the state data source, which was previously hidden by default, is now exposed, and the doc is modified accordingly.
    
    ### How was this patch tested?
    Modified the existing integration tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #44717 from chaoqin-li1123/unhide_partition_id.
    
    Authored-by: Chaoqin Li <chaoqin...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 docs/structured-streaming-state-data-source.md     |  4 ++--
 .../datasources/v2/state/StateDataSource.scala     |  3 ++-
 .../v2/state/StatePartitionReader.scala            | 18 ++++------------
 .../datasources/v2/state/StateTable.scala          | 22 +++++--------------
 .../StreamStreamJoinStatePartitionReader.scala     | 18 ++++------------
 .../v2/state/StateDataSourceReadSuite.scala        | 25 ++++++----------------
 6 files changed, 24 insertions(+), 66 deletions(-)

diff --git a/docs/structured-streaming-state-data-source.md b/docs/structured-streaming-state-data-source.md
index ae323f6b0c14..986699130669 100644
--- a/docs/structured-streaming-state-data-source.md
+++ b/docs/structured-streaming-state-data-source.md
@@ -96,9 +96,9 @@ Each row in the source has the following schema:
   <td></td>
 </tr>
 <tr>
-  <td>_partition_id</td>
+  <td>partition_id</td>
   <td>int</td>
-  <td>metadata column (hidden unless specified with SELECT)</td>
+  <td></td>
 </tr>
 </table>
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 1192accaabef..1a8f444042c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DI
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
 import org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
 import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 /**
@@ -83,6 +83,7 @@ class StateDataSource extends TableProvider with DataSourceRegister {
       new StructType()
         .add("key", keySchema)
         .add("value", valueSchema)
+        .add("partition_id", IntegerType)
     } catch {
       case NonFatal(e) =>
         throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index 1e5f7216e8bf..ef8d7bf628bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.state
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
@@ -99,18 +99,7 @@ class StatePartitionReader(
     }
   }
 
-  private val joinedRow = new JoinedRow()
-
-  private def addMetadata(row: InternalRow): InternalRow = {
-    val metadataRow = new GenericInternalRow(
-      StateTable.METADATA_COLUMNS.map(_.name()).map {
-        case "_partition_id" => partition.partition.asInstanceOf[Any]
-      }.toArray
-    )
-    joinedRow.withLeft(row).withRight(metadataRow)
-  }
-
-  override def get(): InternalRow = addMetadata(current)
+  override def get(): InternalRow = current
 
   override def close(): Unit = {
     current = null
@@ -118,9 +107,10 @@ class StatePartitionReader(
   }
 
   private def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow)): InternalRow = {
-    val row = new GenericInternalRow(2)
+    val row = new GenericInternalRow(3)
     row.update(0, pair._1)
     row.update(1, pair._2)
+    row.update(2, partition.partition)
     row
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
index 96c1c01cede2..824968e709ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.read.ScanBuilder
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import org.apache.spark.sql.execution.streaming.state.StateStoreConf
-import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
+import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.ArrayImplicits._
 
@@ -69,18 +69,20 @@ class StateTable(
   override def properties(): util.Map[String, String] = Map.empty[String, String].asJava
 
   private def isValidSchema(schema: StructType): Boolean = {
-    if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value")) {
+    if (schema.fieldNames.toImmutableArraySeq != Seq("key", "value", "partition_id")) {
       false
     } else if (!SchemaUtil.getSchemaAsDataType(schema, "key").isInstanceOf[StructType]) {
       false
     } else if (!SchemaUtil.getSchemaAsDataType(schema, "value").isInstanceOf[StructType]) {
       false
+    } else if (!SchemaUtil.getSchemaAsDataType(schema, "partition_id").isInstanceOf[IntegerType]) {
+      false
     } else {
       true
     }
   }
 
-  override def metadataColumns(): Array[MetadataColumn] = METADATA_COLUMNS.toArray
+  override def metadataColumns(): Array[MetadataColumn] = Array.empty
 }
 
 /**
@@ -89,18 +91,4 @@ class StateTable(
  */
 object StateTable {
   private val CAPABILITY = Set(TableCapability.BATCH_READ).asJava
-
-  val METADATA_COLUMNS: Seq[MetadataColumn] = Seq(PartitionId)
-
-  private object PartitionId extends MetadataColumn {
-    override def name(): String = "_partition_id"
-
-    override def dataType(): DataType = IntegerType
-
-    override def isNullable: Boolean = false
-
-    override def comment(): String = {
-      "Represents an ID for a physical state partition this row belongs to."
-    }
-  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
index 26492f8790c4..d0dd6cb7d1b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.state
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, JoinedRow, Literal, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, Literal, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
@@ -148,18 +148,7 @@ class StreamStreamJoinStatePartitionReader(
     }
   }
 
-  private val joinedRow = new JoinedRow()
-
-  private def addMetadata(row: InternalRow): InternalRow = {
-    val metadataRow = new GenericInternalRow(
-      StateTable.METADATA_COLUMNS.map(_.name()).map {
-        case "_partition_id" => partition.partition.asInstanceOf[Any]
-      }.toArray
-    )
-    joinedRow.withLeft(row).withRight(metadataRow)
-  }
-
-  override def get(): InternalRow = addMetadata(current)
+  override def get(): InternalRow = current
 
   override def close(): Unit = {
     current = null
@@ -169,9 +158,10 @@ class StreamStreamJoinStatePartitionReader(
   }
 
   private def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow)): InternalRow = {
-    val row = new GenericInternalRow(2)
+    val row = new GenericInternalRow(3)
     row.update(0, pair._1)
     row.update(1, pair._2)
+    row.update(2, partition.partition)
     row
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
index 86c3ab70af68..c800168b507a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
@@ -687,7 +687,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
     }
   }
 
-  test("metadata column") {
+  test("partition_id column") {
     withTempDir { tempDir =>
       import testImplicits._
       val stream = MemoryStream[Int]
@@ -712,14 +712,11 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
         // skip version and operator ID to test out functionalities
         .load()
 
-      assert(!stateReadDf.schema.exists(_.name == "_partition_id"),
-      "metadata column should not be exposed until it is explicitly 
specified!")
-
       val numShufflePartitions = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS)
 
       val resultDf = stateReadDf
-        .selectExpr("key.value AS key_value", "value.count AS value_count", 
"_partition_id")
-        .where("_partition_id % 2 = 0")
+        .selectExpr("key.value AS key_value", "value.count AS value_count", "partition_id")
+        .where("partition_id % 2 = 0")
 
       // NOTE: This is a hash function of distribution for stateful operator.
       val hash = HashPartitioning(
@@ -738,17 +735,12 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
     }
   }
 
-  test("metadata column with stream-stream join") {
+  test("partition_id column with stream-stream join") {
     val numShufflePartitions = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS)
 
     withTempDir { tempDir =>
       runStreamStreamJoinQueryWithOneThousandInputs(tempDir.getAbsolutePath)
 
-      def assertPartitionIdColumnIsNotExposedByDefault(df: DataFrame): Unit = {
-        assert(!df.schema.exists(_.name == "_partition_id"),
-          "metadata column should not be exposed until it is explicitly 
specified!")
-      }
-
       def assertPartitionIdColumn(df: DataFrame): Unit = {
         // NOTE: This is a hash function of distribution for stateful operator.
         // stream-stream join uses the grouping key for the equality match in the join condition.
@@ -759,8 +751,8 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
           numShufflePartitions)
         val partIdExpr = hash.partitionIdExpression
 
-        val dfWithPartition = df.selectExpr("key.field0 As key_0", "_partition_id")
-          .where("_partition_id % 2 = 0")
+        val dfWithPartition = df.selectExpr("key.field0 As key_0", "partition_id")
+          .where("partition_id % 2 = 0")
 
         checkAnswer(dfWithPartition,
           Range.inclusive(2, 1000, 2).map { idx =>
@@ -778,8 +770,6 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
           .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
           .option(StateSourceOptions.JOIN_SIDE, side)
           .load()
-
-        assertPartitionIdColumnIsNotExposedByDefault(stateReaderForLeft)
         assertPartitionIdColumn(stateReaderForLeft)
 
         val stateReaderForKeyToNumValues = spark.read
@@ -789,7 +779,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass
             s"$side-keyToNumValues")
           .load()
 
-        assertPartitionIdColumnIsNotExposedByDefault(stateReaderForKeyToNumValues)
+
         assertPartitionIdColumn(stateReaderForKeyToNumValues)
 
         val stateReaderForKeyWithIndexToValue = spark.read
@@ -799,7 +789,6 @@ abstract class StateDataSourceReadSuite extends 
StateDataSourceTestBase with Ass
             s"$side-keyWithIndexToValue")
           .load()
 
-        assertPartitionIdColumnIsNotExposedByDefault(stateReaderForKeyWithIndexToValue)
         assertPartitionIdColumn(stateReaderForKeyWithIndexToValue)
       }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
