This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3f333a0820a9 [SPARK-50642][CONNECT][SS] Fix the state schema for
FlatMapGroupsWithState in spark connect when there is no initial state
3f333a0820a9 is described below
commit 3f333a0820a991a7642632a49e430843840b75ee
Author: huanliwang-db <[email protected]>
AuthorDate: Thu Jan 2 12:01:47 2025 +0900
[SPARK-50642][CONNECT][SS] Fix the state schema for FlatMapGroupsWithState
in spark connect when there is no initial state
In Spark Connect, when there is no initial state, we derive the state
schema from the input:
We create the `initialDs` from the original input:
https://github.com/apache/spark/blob/master/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala#L679-L689
We then derive the state expression encoder from this `initialDs`, which is
incorrect:
https://github.com/apache/spark/blob/master/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala#L729
Our unit tests fail to cover this case because the tested function doesn't
perform a state update:
https://github.com/apache/spark/blob/master/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala#L55-L59
After changing the `stateFunc` to the following:
```
val stateFunc =
(key: String, values: Iterator[ClickEvent], state:
GroupState[ClickState]) => {
if (state.exists) throw new IllegalArgumentException("state.exists
should be false")
val newState = ClickState(key, values.size)
state.update(newState)
Iterator(newState)
}
```
the test actually fails with:
```
Cause: org.apache.spark.SparkException: Job aborted due to stage failure:
Task 122 in stage 2.0 failed 1 times, most recent failure: Lost task 122.0 in
stage 2.0 (TID 12) (192.168.68.84 executor driver):
java.lang.ClassCastException: class org.apache.spark.sql.streaming.ClickState
cannot be cast to class org.apache.spark.sql.streaming.ClickEvent
(org.apache.spark.sql.streaming.ClickState and
org.apache.spark.sql.streaming.ClickEvent are in unnamed module of loader 'app')
at
org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.CreateNamedStruct_0$(Unknown
Source)
```
### What changes were proposed in this pull request?
* introduce a new `state_schema` proto field
* pass the state agnostic encoder to the serialized udf
* pass the state schema to query proto for spark connect
* rebuild the state expression encoder based on the state agnostic encoder
and state schema.
### Why are the changes needed?
Fix the broken behavior of `flatMapGroupsWithState` on Spark Connect.
### Does this PR introduce _any_ user-facing change?
yes
### How was this patch tested?
modified the existing unit tests.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #49260 from huanliwang-db/huanliwang-db/fmgws-client.
Lead-authored-by: huanliwang-db <[email protected]>
Co-authored-by: Huanli Wang <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../apache/spark/sql/KeyValueGroupedDataset.scala | 21 +++++++++-
.../FlatMapGroupsWithStateStreamingSuite.scala | 16 ++++++--
python/pyspark/sql/connect/proto/relations_pb2.py | 48 +++++++++++-----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 17 ++++++++
.../main/protobuf/spark/connect/relations.proto | 3 ++
.../sql/connect/planner/SparkConnectPlanner.scala | 31 ++++++++++----
6 files changed, 99 insertions(+), 37 deletions(-)
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 63b5f27c4745..d5505d2222c4 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -27,7 +27,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor,
ProductEncoder}
import org.apache.spark.sql.connect.ConnectConversions._
-import org.apache.spark.sql.connect.common.UdfUtils
+import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfUtils}
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.toExpr
@@ -502,6 +502,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
}
val outputEncoder = agnosticEncoderFor[U]
+ val stateEncoder = agnosticEncoderFor[S]
val nf = UDFAdaptors.flatMapGroupsWithStateWithMappedValues(func,
valueMapFunc)
sparkSession.newDataset[U](outputEncoder) { builder =>
@@ -509,11 +510,12 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
groupMapBuilder
.setInput(plan.getRoot)
.addAllGroupingExpressions(groupingExprs)
- .setFunc(getUdf(nf, outputEncoder)(ivEncoder))
+ .setFunc(getUdf(nf, outputEncoder, stateEncoder)(ivEncoder))
.setIsMapGroupsWithState(isMapGroupWithState)
.setOutputMode(if (outputMode.isEmpty) OutputMode.Update.toString
else outputMode.get.toString)
.setTimeoutConf(timeoutConf.toString)
+
.setStateSchema(DataTypeProtoConverter.toConnectProtoType(stateEncoder.schema))
if (initialStateImpl != null) {
groupMapBuilder
@@ -533,6 +535,21 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV](
udf.apply(inputEncoders.map(_ => col("*")):
_*).expr.getCommonInlineUserDefinedFunction
}
+ private def getUdf[U: Encoder, S: Encoder](
+ nf: AnyRef,
+ outputEncoder: AgnosticEncoder[U],
+ stateEncoder: AgnosticEncoder[S])(
+ inEncoders: AgnosticEncoder[_]*): proto.CommonInlineUserDefinedFunction
= {
+ // Apply keyAs changes by setting kEncoder
+ // Add the state encoder to the inputEncoders.
+ val inputEncoders = kEncoder +: stateEncoder +: inEncoders
+ val udf = SparkUserDefinedFunction(
+ function = nf,
+ inputEncoders = inputEncoders,
+ outputEncoder = outputEncoder)
+ udf.apply(inputEncoders.map(_ => col("*")):
_*).expr.getCommonInlineUserDefinedFunction
+ }
+
/**
* We cannot deserialize a connect [[KeyValueGroupedDataset]] because of a
class clash on the
* server side. We null out the instance for now.
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala
index dc74463f1a25..9bd6614028cb 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala
@@ -55,7 +55,9 @@ class FlatMapGroupsWithStateStreamingSuite extends QueryTest
with RemoteSparkSes
val stateFunc =
(key: String, values: Iterator[ClickEvent], state:
GroupState[ClickState]) => {
if (state.exists) throw new IllegalArgumentException("state.exists
should be false")
- Iterator(ClickState(key, values.size))
+ val newState = ClickState(key, values.size)
+ state.update(newState)
+ Iterator(newState)
}
spark.sql("DROP TABLE IF EXISTS my_sink")
@@ -96,7 +98,9 @@ class FlatMapGroupsWithStateStreamingSuite extends QueryTest
with RemoteSparkSes
val stateFunc =
(key: String, values: Iterator[ClickEvent], state:
GroupState[ClickState]) => {
val currState = state.getOption.getOrElse(ClickState(key, 0))
- Iterator(ClickState(key, currState.count + values.size))
+ val newState = ClickState(key, currState.count + values.size)
+ state.update(newState)
+ Iterator(newState)
}
val initialState = flatMapGroupsWithStateInitialStateData
.toDS()
@@ -141,7 +145,9 @@ class FlatMapGroupsWithStateStreamingSuite extends
QueryTest with RemoteSparkSes
val stateFunc =
(key: String, values: Iterator[ClickEvent], state:
GroupState[ClickState]) => {
if (state.exists) throw new IllegalArgumentException("state.exists
should be false")
- ClickState(key, values.size)
+ val newState = ClickState(key, values.size)
+ state.update(newState)
+ newState
}
spark.sql("DROP TABLE IF EXISTS my_sink")
@@ -183,7 +189,9 @@ class FlatMapGroupsWithStateStreamingSuite extends
QueryTest with RemoteSparkSes
val stateFunc =
(key: String, values: Iterator[ClickEvent], state:
GroupState[ClickState]) => {
val currState = state.getOption.getOrElse(ClickState(key, 0))
- ClickState(key, currState.count + values.size)
+ val newState = ClickState(key, currState.count + values.size)
+ state.update(newState)
+ newState
}
val initialState = flatMapGroupsWithStateInitialStateData
.toDS()
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py
b/python/pyspark/sql/connect/proto/relations_pb2.py
index 506b266f6014..b7248d4b1708 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -42,7 +42,7 @@ from pyspark.sql.connect.proto import common_pb2 as
spark_dot_connect_dot_common
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto"\xdd\x1c\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x01(\x0b\x32\x16.spark.connect.Project [...]
+
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto"\xdd\x1c\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x01(\x0b\x32\x16.spark.connect.Project [...]
)
_globals = globals()
@@ -208,29 +208,29 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_MAPPARTITIONS"]._serialized_start = 12857
_globals["_MAPPARTITIONS"]._serialized_end = 13089
_globals["_GROUPMAP"]._serialized_start = 13092
- _globals["_GROUPMAP"]._serialized_end = 13727
- _globals["_COGROUPMAP"]._serialized_start = 13730
- _globals["_COGROUPMAP"]._serialized_end = 14256
- _globals["_APPLYINPANDASWITHSTATE"]._serialized_start = 14259
- _globals["_APPLYINPANDASWITHSTATE"]._serialized_end = 14616
- _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_start = 14619
- _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_end = 14863
- _globals["_PYTHONUDTF"]._serialized_start = 14866
- _globals["_PYTHONUDTF"]._serialized_end = 15043
- _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_start = 15046
- _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_end = 15197
- _globals["_PYTHONDATASOURCE"]._serialized_start = 15199
- _globals["_PYTHONDATASOURCE"]._serialized_end = 15274
- _globals["_COLLECTMETRICS"]._serialized_start = 15277
- _globals["_COLLECTMETRICS"]._serialized_end = 15413
- _globals["_PARSE"]._serialized_start = 15416
- _globals["_PARSE"]._serialized_end = 15804
+ _globals["_GROUPMAP"]._serialized_end = 13809
+ _globals["_COGROUPMAP"]._serialized_start = 13812
+ _globals["_COGROUPMAP"]._serialized_end = 14338
+ _globals["_APPLYINPANDASWITHSTATE"]._serialized_start = 14341
+ _globals["_APPLYINPANDASWITHSTATE"]._serialized_end = 14698
+ _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_start = 14701
+ _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_end = 14945
+ _globals["_PYTHONUDTF"]._serialized_start = 14948
+ _globals["_PYTHONUDTF"]._serialized_end = 15125
+ _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_start = 15128
+ _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_end = 15279
+ _globals["_PYTHONDATASOURCE"]._serialized_start = 15281
+ _globals["_PYTHONDATASOURCE"]._serialized_end = 15356
+ _globals["_COLLECTMETRICS"]._serialized_start = 15359
+ _globals["_COLLECTMETRICS"]._serialized_end = 15495
+ _globals["_PARSE"]._serialized_start = 15498
+ _globals["_PARSE"]._serialized_end = 15886
_globals["_PARSE_OPTIONSENTRY"]._serialized_start = 4941
_globals["_PARSE_OPTIONSENTRY"]._serialized_end = 4999
- _globals["_PARSE_PARSEFORMAT"]._serialized_start = 15705
- _globals["_PARSE_PARSEFORMAT"]._serialized_end = 15793
- _globals["_ASOFJOIN"]._serialized_start = 15807
- _globals["_ASOFJOIN"]._serialized_end = 16282
- _globals["_LATERALJOIN"]._serialized_start = 16285
- _globals["_LATERALJOIN"]._serialized_end = 16515
+ _globals["_PARSE_PARSEFORMAT"]._serialized_start = 15787
+ _globals["_PARSE_PARSEFORMAT"]._serialized_end = 15875
+ _globals["_ASOFJOIN"]._serialized_start = 15889
+ _globals["_ASOFJOIN"]._serialized_end = 16364
+ _globals["_LATERALJOIN"]._serialized_start = 16367
+ _globals["_LATERALJOIN"]._serialized_end = 16597
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi
b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index e5a6bff9e430..371d735b9e87 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -3409,6 +3409,7 @@ class GroupMap(google.protobuf.message.Message):
IS_MAP_GROUPS_WITH_STATE_FIELD_NUMBER: builtins.int
OUTPUT_MODE_FIELD_NUMBER: builtins.int
TIMEOUT_CONF_FIELD_NUMBER: builtins.int
+ STATE_SCHEMA_FIELD_NUMBER: builtins.int
@property
def input(self) -> global___Relation:
"""(Required) Input relation for Group Map API: apply,
applyInPandas."""
@@ -3447,6 +3448,9 @@ class GroupMap(google.protobuf.message.Message):
"""(Optional) The output mode of the function."""
timeout_conf: builtins.str
"""(Optional) Timeout configuration for groups that do not receive data
for a while."""
+ @property
+ def state_schema(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+ """(Optional) The schema for the grouped state."""
def __init__(
self,
*,
@@ -3469,6 +3473,7 @@ class GroupMap(google.protobuf.message.Message):
is_map_groups_with_state: builtins.bool | None = ...,
output_mode: builtins.str | None = ...,
timeout_conf: builtins.str | None = ...,
+ state_schema: pyspark.sql.connect.proto.types_pb2.DataType | None =
...,
) -> None: ...
def HasField(
self,
@@ -3477,6 +3482,8 @@ class GroupMap(google.protobuf.message.Message):
b"_is_map_groups_with_state",
"_output_mode",
b"_output_mode",
+ "_state_schema",
+ b"_state_schema",
"_timeout_conf",
b"_timeout_conf",
"func",
@@ -3489,6 +3496,8 @@ class GroupMap(google.protobuf.message.Message):
b"is_map_groups_with_state",
"output_mode",
b"output_mode",
+ "state_schema",
+ b"state_schema",
"timeout_conf",
b"timeout_conf",
],
@@ -3500,6 +3509,8 @@ class GroupMap(google.protobuf.message.Message):
b"_is_map_groups_with_state",
"_output_mode",
b"_output_mode",
+ "_state_schema",
+ b"_state_schema",
"_timeout_conf",
b"_timeout_conf",
"func",
@@ -3518,6 +3529,8 @@ class GroupMap(google.protobuf.message.Message):
b"output_mode",
"sorting_expressions",
b"sorting_expressions",
+ "state_schema",
+ b"state_schema",
"timeout_conf",
b"timeout_conf",
],
@@ -3534,6 +3547,10 @@ class GroupMap(google.protobuf.message.Message):
self, oneof_group: typing_extensions.Literal["_output_mode",
b"_output_mode"]
) -> typing_extensions.Literal["output_mode"] | None: ...
@typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_state_schema",
b"_state_schema"]
+ ) -> typing_extensions.Literal["state_schema"] | None: ...
+ @typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_timeout_conf",
b"_timeout_conf"]
) -> typing_extensions.Literal["timeout_conf"] | None: ...
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
b/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
index 7a86db279914..5ab9f64149f5 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -975,6 +975,9 @@ message GroupMap {
// (Optional) Timeout configuration for groups that do not receive data for
a while.
optional string timeout_conf = 9;
+
+ // (Optional) The schema for the grouped state.
+ optional DataType state_schema = 10;
}
message CoGroupMap {
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 628b758dd4e2..f4be1d17b0e9 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -667,7 +667,8 @@ class SparkConnectPlanner(
private def transformTypedGroupMap(
rel: proto.GroupMap,
commonUdf: proto.CommonInlineUserDefinedFunction): LogicalPlan = {
- val udf = TypedScalaUdf(commonUdf)
+ val unpackedUdf = unpackUdf(commonUdf)
+ val udf = TypedScalaUdf(unpackedUdf, None)
val ds = UntypedKeyValueGroupedDataset(
rel.getInput,
rel.getGroupingExpressionsList,
@@ -697,6 +698,18 @@ class SparkConnectPlanner(
InternalOutputModes(rel.getOutputMode)
}
+ val stateSchema =
DataTypeProtoConverter.toCatalystType(rel.getStateSchema) match {
+ case s: StructType => s
+ case other =>
+ throw InvalidPlanInput(
+ s"Invalid state schema dataType $other for flatMapGroupsWithState")
+ }
+ val stateEncoder = TypedScalaUdf.encoderFor(
+ // the state agnostic encoder is the second element in the input
encoders.
+ unpackedUdf.inputEncoders.tail.head,
+ "state",
+ Some(DataTypeUtils.toAttributes(stateSchema)))
+
val flatMapGroupsWithState = if (hasInitialState) {
new FlatMapGroupsWithState(
udf.function
@@ -706,7 +719,7 @@ class SparkConnectPlanner(
ds.groupingAttributes,
ds.dataAttributes,
udf.outputObjAttr,
- initialDs.vEncoder.asInstanceOf[ExpressionEncoder[Any]],
+ stateEncoder.asInstanceOf[ExpressionEncoder[Any]],
outputMode,
rel.getIsMapGroupsWithState,
timeoutConf,
@@ -725,7 +738,7 @@ class SparkConnectPlanner(
ds.groupingAttributes,
ds.dataAttributes,
udf.outputObjAttr,
- initialDs.vEncoder.asInstanceOf[ExpressionEncoder[Any]],
+ stateEncoder.asInstanceOf[ExpressionEncoder[Any]],
outputMode,
rel.getIsMapGroupsWithState,
timeoutConf,
@@ -947,10 +960,7 @@ class SparkConnectPlanner(
}
}
- def apply(
- commonUdf: proto.CommonInlineUserDefinedFunction,
- inputAttrs: Option[Seq[Attribute]] = None): TypedScalaUdf = {
- val udf = unpackUdf(commonUdf)
+ def apply(udf: UdfPacket, inputAttrs: Option[Seq[Attribute]]):
TypedScalaUdf = {
// There might be more than one inputs, but we only interested in the
first one.
// Most typed API takes one UDF input.
// For the few that takes more than one inputs, e.g. grouping function
mapping UDFs,
@@ -960,6 +970,13 @@ class SparkConnectPlanner(
TypedScalaUdf(udf.function, udf.outputEncoder, inEnc, inputAttrs)
}
+ def apply(
+ commonUdf: proto.CommonInlineUserDefinedFunction,
+ inputAttrs: Option[Seq[Attribute]] = None): TypedScalaUdf = {
+ val udf = unpackUdf(commonUdf)
+ apply(udf, inputAttrs)
+ }
+
def encoderFor(
encoder: AgnosticEncoder[_],
errorType: String,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]