This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new 948873420e9 [SPARK-42556][CONNECT] Dataset.colregex should link a plan_id when it only matches a single column 948873420e9 is described below commit 948873420e99f728e18a25890ad375cdd39afe59 Author: Jiaan Geng <belie...@163.com> AuthorDate: Sat Mar 4 13:45:56 2023 +0800 [SPARK-42556][CONNECT] Dataset.colregex should link a plan_id when it only matches a single column ### What changes were proposed in this pull request? When colregex returns a single column it should link the plan's plan_id. For reference, here is the non-connect Dataset code that does this: https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L1512 This also needs to be fixed for the Python client. ### Why are the changes needed? Let the `UnresolvedAttribute` link plan_id if it exists. ### Does this PR introduce _any_ user-facing change? 'No'. New feature. ### How was this patch tested? New test cases. Closes #40265 from beliefer/SPARK-42556. 
Authored-by: Jiaan Geng <belie...@163.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> (cherry picked from commit c99a632fea74136964b27b28563115fe2d7667b3) Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../main/scala/org/apache/spark/sql/Dataset.scala | 21 +++++++----- .../org/apache/spark/sql/ClientE2ETestSuite.scala | 5 +++ .../main/protobuf/spark/connect/expressions.proto | 3 ++ .../resources/query-tests/queries/colRegex.json | 3 +- .../query-tests/queries/colRegex.proto.bin | Bin 60 -> 62 bytes .../sql/connect/planner/SparkConnectPlanner.scala | 6 +++- python/pyspark/sql/connect/dataframe.py | 5 ++- python/pyspark/sql/connect/expressions.py | 10 +++--- .../pyspark/sql/connect/proto/expressions_pb2.py | 38 ++++++++++----------- .../pyspark/sql/connect/proto/expressions_pb2.pyi | 16 ++++++++- 10 files changed, 72 insertions(+), 35 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 1cd3c541950..e264f1c0c0c 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -917,6 +917,13 @@ class Dataset[T] private[sql] ( .addAllParameters(parameters.map(p => functions.lit(p).expr).asJava) } + private def getPlanId: Option[Long] = + if (plan.getRoot.hasCommon && plan.getRoot.getCommon.hasPlanId) { + Option(plan.getRoot.getCommon.getPlanId) + } else { + None + } + /** * Selects column based on the column name and returns it as a [[Column]]. 
* @@ -927,12 +934,7 @@ class Dataset[T] private[sql] ( * @since 3.4.0 */ def col(colName: String): Column = { - val planId = if (plan.getRoot.hasCommon && plan.getRoot.getCommon.hasPlanId) { - Option(plan.getRoot.getCommon.getPlanId) - } else { - None - } - Column.apply(colName, planId) + Column.apply(colName, getPlanId) } /** @@ -940,8 +942,11 @@ class Dataset[T] private[sql] ( * @group untypedrel * @since 3.4.0 */ - def colRegex(colName: String): Column = Column { builder => - builder.getUnresolvedRegexBuilder.setColName(colName) + def colRegex(colName: String): Column = { + Column { builder => + val unresolvedRegexBuilder = builder.getUnresolvedRegexBuilder.setColName(colName) + getPlanId.foreach(unresolvedRegexBuilder.setPlanId) + } } /** diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index a3f1de55892..5c35ef448be 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -493,6 +493,11 @@ class ClientE2ETestSuite extends RemoteSparkSession { val right = spark.range(100).select(col("id"), rand(12).as("a")) val joined = left.join(right, left("id") === right("id")).select(left("id"), right("a")) assert(joined.schema.catalogString === "struct<id:bigint,a:double>") + + val joined2 = left + .join(right, left.colRegex("id") === right.colRegex("id")) + .select(left("id"), right("a")) + assert(joined2.schema.catalogString === "struct<id:bigint,a:double>") } test("test temp view") { diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto index 1929d9cdca3..e37a13ee959 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto +++ 
b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto @@ -242,6 +242,9 @@ message Expression { message UnresolvedRegex { // (Required) The column name used to extract column with regex. string col_name = 1; + + // (Optional) The id of corresponding connect plan. + optional int64 plan_id = 2; } // Extracts a value or values from an Expression diff --git a/connector/connect/common/src/test/resources/query-tests/queries/colRegex.json b/connector/connect/common/src/test/resources/query-tests/queries/colRegex.json index 56021594c88..3a7508b63a9 100644 --- a/connector/connect/common/src/test/resources/query-tests/queries/colRegex.json +++ b/connector/connect/common/src/test/resources/query-tests/queries/colRegex.json @@ -13,7 +13,8 @@ }, "expressions": [{ "unresolvedRegex": { - "colName": "`a|id`" + "colName": "`a|id`", + "planId": "1" } }] } diff --git a/connector/connect/common/src/test/resources/query-tests/queries/colRegex.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/colRegex.proto.bin index 2f3ab10233e..ce518b35fbd 100644 Binary files a/connector/connect/common/src/test/resources/query-tests/queries/colRegex.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/colRegex.proto.bin differ diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index c8b1b3125f9..76a4c7faaa2 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1179,7 +1179,11 @@ class SparkConnectPlanner(val session: SparkSession) { case ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex) => UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive) case _ => - 
UnresolvedAttribute.quotedString(regex.getColName) + val expr = UnresolvedAttribute.quotedString(regex.getColName) + if (regex.hasPlanId) { + expr.setTagValue(LogicalPlan.PLAN_ID_TAG, regex.getPlanId) + } + expr } } diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 955186787a6..471dbf89582 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -153,7 +153,10 @@ class DataFrame: error_class="NOT_STR", message_parameters={"arg_name": "colName", "arg_type": type(colName).__name__}, ) - return Column(UnresolvedRegex(colName)) + if self._plan is not None: + return Column(UnresolvedRegex(colName, self._plan._plan_id)) + else: + return Column(UnresolvedRegex(colName)) colRegex.__doc__ = PySparkDataFrame.colRegex.__doc__ diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index f3c9e2c70c4..2b1901167c1 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -642,18 +642,20 @@ class UnresolvedExtractValue(Expression): class UnresolvedRegex(Expression): - def __init__( - self, - col_name: str, - ) -> None: + def __init__(self, col_name: str, plan_id: Optional[int] = None) -> None: super().__init__() assert isinstance(col_name, str) self.col_name = col_name + assert plan_id is None or isinstance(plan_id, int) + self._plan_id = plan_id + def to_plan(self, session: "SparkConnectClient") -> proto.Expression: expr = proto.Expression() expr.unresolved_regex.col_name = self.col_name + if self._plan_id is not None: + expr.unresolved_regex.plan_id = self._plan_id return expr def __repr__(self) -> str: diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index 891be5ea9ea..6e515235c7d 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -34,7 +34,7 @@ from 
pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xbc%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...] + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xe6%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...] 
) @@ -300,7 +300,7 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _EXPRESSION._serialized_start = 105 - _EXPRESSION._serialized_end = 4901 + _EXPRESSION._serialized_end = 4943 _EXPRESSION_WINDOW._serialized_start = 1475 _EXPRESSION_WINDOW._serialized_end = 2258 _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1765 @@ -332,21 +332,21 @@ if _descriptor._USE_C_DESCRIPTORS == False: _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4088 _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4170 _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4172 - _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4216 - _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4219 - _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4351 - _EXPRESSION_UPDATEFIELDS._serialized_start = 4354 - _EXPRESSION_UPDATEFIELDS._serialized_end = 4541 - _EXPRESSION_ALIAS._serialized_start = 4543 - _EXPRESSION_ALIAS._serialized_end = 4663 - _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4666 - _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4824 - _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4826 - _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4888 - _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4904 - _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5215 - _PYTHONUDF._serialized_start = 5218 - _PYTHONUDF._serialized_end = 5348 - _SCALARSCALAUDF._serialized_start = 5351 - _SCALARSCALAUDF._serialized_end = 5535 + _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4258 + _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4261 + _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4393 + _EXPRESSION_UPDATEFIELDS._serialized_start = 4396 + _EXPRESSION_UPDATEFIELDS._serialized_end = 4583 + _EXPRESSION_ALIAS._serialized_start = 4585 + _EXPRESSION_ALIAS._serialized_end = 4705 + _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4708 + _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4866 + 
_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4868 + _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4930 + _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4946 + _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5257 + _PYTHONUDF._serialized_start = 5260 + _PYTHONUDF._serialized_end = 5390 + _SCALARSCALAUDF._serialized_start = 5393 + _SCALARSCALAUDF._serialized_end = 5577 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index 88b1fd8ef7e..996de7fef2d 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -753,16 +753,30 @@ class Expression(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor COL_NAME_FIELD_NUMBER: builtins.int + PLAN_ID_FIELD_NUMBER: builtins.int col_name: builtins.str """(Required) The column name used to extract column with regex.""" + plan_id: builtins.int + """(Optional) The id of corresponding connect plan.""" def __init__( self, *, col_name: builtins.str = ..., + plan_id: builtins.int | None = ..., ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal["_plan_id", b"_plan_id", "plan_id", b"plan_id"], + ) -> builtins.bool: ... def ClearField( - self, field_name: typing_extensions.Literal["col_name", b"col_name"] + self, + field_name: typing_extensions.Literal[ + "_plan_id", b"_plan_id", "col_name", b"col_name", "plan_id", b"plan_id" + ], ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_plan_id", b"_plan_id"] + ) -> typing_extensions.Literal["plan_id"] | None: ... 
class UnresolvedExtractValue(google.protobuf.message.Message): """Extracts a value or values from an Expression""" --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org