jingz-db commented on code in PR #48000:
URL: https://github.com/apache/spark/pull/48000#discussion_r1750971828


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala:
##########
@@ -81,6 +84,112 @@ object SchemaUtil {
     row
   }
 
+  /**
+   * For map state variables, state rows are stored as composite key.
+   * To return grouping key -> Map{user key -> value} as one state reader row to
+   * the users, we need to perform grouping on state rows by their grouping key,
+   * and construct a map for that grouping key.
+   *
+   * We traverse the iterator returned from state store,
+   * and will only return a row for `next()` only if the grouping key in the next row
+   * from state store is different (or there are no more rows)
+   *
+   * Note that all state rows with the same grouping key are co-located so they will
+   * appear consecutively during the iterator traversal.
+   */
+  def unifyMapStateRowPair(
+      stateRows: Iterator[UnsafeRowPair],
+      compositeKeySchema: StructType,
+      partitionId: Int): Iterator[InternalRow] = {
+    val groupingKeySchema = SchemaUtil.getSchemaAsDataType(
+      compositeKeySchema, "key"
+    ).asInstanceOf[StructType]
+    val userKeySchema = SchemaUtil.getSchemaAsDataType(
+      compositeKeySchema, "userKey"
+    ).asInstanceOf[StructType]
+
+    def appendKVPairToMap(
+        curMap: mutable.Map[Any, Any],
+        stateRowPair: UnsafeRowPair): Unit = {
+      curMap += (
+        stateRowPair.key.get(1, userKeySchema)
+          .asInstanceOf[UnsafeRow].copy() ->
+          stateRowPair.value.copy()
+        )
+    }
+
+    def updateDataRow(
+        groupingKey: Any,
+        curMap: mutable.Map[Any, Any]): GenericInternalRow = {
+      val row = new GenericInternalRow(3)
+      val mapData = new ArrayBasedMapData(
+        ArrayData.toArrayData(curMap.keys.toArray),
+        ArrayData.toArrayData(curMap.values.toArray)
+      )
+      row.update(0, groupingKey)
+      row.update(1, mapData)
+      row.update(2, partitionId)
+      row
+    }
+
+    // All of the rows with the same grouping key were co-located and were
+    // grouped together consecutively.
+    new Iterator[InternalRow] {
+      var curGroupingKey: UnsafeRow = _
+      var curStateRowPair: UnsafeRowPair = _
+      val curMap = mutable.Map.empty[Any, Any]
+
+      override def hasNext: Boolean =
+        stateRows.hasNext || !curMap.isEmpty
+
+      override def next(): InternalRow = {
+        var keepTraverse = true
+        while (stateRows.hasNext && keepTraverse) {
+          curStateRowPair = stateRows.next()
+          if (curGroupingKey == null) {
+            // First time in the iterator
+            // Need to make a copy because we need to keep the
+            // value across function calls
+            curGroupingKey = curStateRowPair.key
+              .get(0, groupingKeySchema).asInstanceOf[UnsafeRow].copy()
+            appendKVPairToMap(curMap, curStateRowPair)
+          } else {
+            val curPairGroupingKey =
+              curStateRowPair.key.get(0, groupingKeySchema)
+            if (curPairGroupingKey == curGroupingKey) {
+              appendKVPairToMap(curMap, curStateRowPair)
+            } else {
+              // find a different grouping key, exit loop and return a row
+              keepTraverse = false
+            }
+          }
+        }
+        if (!keepTraverse) {
+          // found a different grouping key
+          val row = updateDataRow(curGroupingKey, curMap)
+          // update vars
+          curGroupingKey =
+            curStateRowPair.key.get(0, groupingKeySchema)
+              .asInstanceOf[UnsafeRow].copy()
+          // empty the map, append current row
+          curMap.clear()
+          appendKVPairToMap(curMap, curStateRowPair)
+          // return map value of previous grouping key
+          row
+        } else {
+          // reach the end of the state rows
+          if (curMap.isEmpty) null.asInstanceOf[InternalRow]

Review Comment:
   Sure, that makes a lot of sense. I modified this to throw `NoSuchElementException`, as other iterators do.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala:
##########
@@ -81,6 +84,112 @@ object SchemaUtil {
     row
   }
 
+  /**
+   * For map state variables, state rows are stored as composite key.
+   * To return grouping key -> Map{user key -> value} as one state reader row to
+   * the users, we need to perform grouping on state rows by their grouping key,
+   * and construct a map for that grouping key.
+   *
+   * We traverse the iterator returned from state store,
+   * and will only return a row for `next()` only if the grouping key in the next row
+   * from state store is different (or there are no more rows)
+   *
+   * Note that all state rows with the same grouping key are co-located so they will
+   * appear consecutively during the iterator traversal.
+   */
+  def unifyMapStateRowPair(
+      stateRows: Iterator[UnsafeRowPair],
+      compositeKeySchema: StructType,
+      partitionId: Int): Iterator[InternalRow] = {
+    val groupingKeySchema = SchemaUtil.getSchemaAsDataType(
+      compositeKeySchema, "key"
+    ).asInstanceOf[StructType]
+    val userKeySchema = SchemaUtil.getSchemaAsDataType(
+      compositeKeySchema, "userKey"
+    ).asInstanceOf[StructType]
+
+    def appendKVPairToMap(
+        curMap: mutable.Map[Any, Any],
+        stateRowPair: UnsafeRowPair): Unit = {
+      curMap += (
+        stateRowPair.key.get(1, userKeySchema)
+          .asInstanceOf[UnsafeRow].copy() ->
+          stateRowPair.value.copy()
+        )
+    }
+
+    def updateDataRow(
+        groupingKey: Any,
+        curMap: mutable.Map[Any, Any]): GenericInternalRow = {
+      val row = new GenericInternalRow(3)
+      val mapData = new ArrayBasedMapData(
+        ArrayData.toArrayData(curMap.keys.toArray),
+        ArrayData.toArrayData(curMap.values.toArray)
+      )
+      row.update(0, groupingKey)
+      row.update(1, mapData)
+      row.update(2, partitionId)
+      row
+    }
+
+    // All of the rows with the same grouping key were co-located and were
+    // grouped together consecutively.
+    new Iterator[InternalRow] {
+      var curGroupingKey: UnsafeRow = _
+      var curStateRowPair: UnsafeRowPair = _
+      val curMap = mutable.Map.empty[Any, Any]
+
+      override def hasNext: Boolean =
+        stateRows.hasNext || !curMap.isEmpty
+
+      override def next(): InternalRow = {
+        var keepTraverse = true
+        while (stateRows.hasNext && keepTraverse) {
+          curStateRowPair = stateRows.next()
+          if (curGroupingKey == null) {
+            // First time in the iterator
+            // Need to make a copy because we need to keep the
+            // value across function calls
+            curGroupingKey = curStateRowPair.key
+              .get(0, groupingKeySchema).asInstanceOf[UnsafeRow].copy()
+            appendKVPairToMap(curMap, curStateRowPair)
+          } else {
+            val curPairGroupingKey =
+              curStateRowPair.key.get(0, groupingKeySchema)
+            if (curPairGroupingKey == curGroupingKey) {
+              appendKVPairToMap(curMap, curStateRowPair)
+            } else {
+              // find a different grouping key, exit loop and return a row
+              keepTraverse = false
+            }
+          }
+        }
+        if (!keepTraverse) {
+          // found a different grouping key
+          val row = updateDataRow(curGroupingKey, curMap)
+          // update vars
+          curGroupingKey =
+            curStateRowPair.key.get(0, groupingKeySchema)
+              .asInstanceOf[UnsafeRow].copy()
+          // empty the map, append current row
+          curMap.clear()
+          appendKVPairToMap(curMap, curStateRowPair)
+          // return map value of previous grouping key
+          row
+        } else {
+          // reach the end of the state rows
+          if (curMap.isEmpty) null.asInstanceOf[InternalRow]

Review Comment:
   Sure, that makes a lot of sense. I modified this to throw `NoSuchElementException`, as other iterators do.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to