This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1cdd5fa68715 [SPARK-48736][PYTHON] Support infra for additional includes for Python UDFs
1cdd5fa68715 is described below
commit 1cdd5fa68715970c9c465e2383a9445610b47f93
Author: Martin Grund <[email protected]>
AuthorDate: Fri Jun 28 08:39:16 2024 +0900
[SPARK-48736][PYTHON] Support infra for additional includes for Python UDFs
### What changes were proposed in this pull request?
The interface of SimplePythonFunction already supports specifying additional
includes. However, the corresponding proto field for the clients and the
handling in the planner have been missing. This patch adds the base
infrastructure for that.
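As a minimal sketch (not part of this patch), a client could populate the new
field when constructing the `PythonUDF` message; the eval type, command, and
include name below are placeholder values:

```python
from pyspark.sql.connect.proto import expressions_pb2

# Placeholder values for illustration only; a real client derives the
# command from the pickled function and the version from the interpreter.
udf = expressions_pb2.PythonUDF(
    eval_type=100,  # assumed to correspond to SQL_BATCHED_UDF
    command=b"<pickled udf>",
    python_ver="3.11",
    # New in this patch: extra files for the UDF worker's Python path,
    # merged server-side with the artifact manager's includes.
    additional_includes=["my_helpers.zip"],
)
```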
### Why are the changes needed?
Compatibility
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing UT
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #47120 from grundprinzip/SPARK-48736.
Authored-by: Martin Grund <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/protobuf/spark/connect/expressions.proto | 2 ++
.../sql/connect/planner/SparkConnectPlanner.scala | 4 +++-
python/pyspark/sql/connect/proto/expressions_pb2.py | 20 ++++++++++----------
python/pyspark/sql/connect/proto/expressions_pb2.pyi | 9 +++++++++
4 files changed, 24 insertions(+), 11 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 404a2fdcb2e8..257634813e74 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -373,6 +373,8 @@ message PythonUDF {
bytes command = 3;
// (Required) Python version being used in the client.
string python_ver = 4;
+ // (Optional) Additional includes for the Python UDF.
+ repeated string additional_includes = 5;
}
message ScalarScalaUDF {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index a7fc87d8b65d..eaeb1c775ddb 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1764,8 +1764,10 @@ class SparkConnectPlanner(
command = fun.getCommand.toByteArray.toImmutableArraySeq,
// Empty environment variables
envVars = Maps.newHashMap(),
- pythonIncludes = sessionHolder.artifactManager.getPythonIncludes.asJava,
pythonExec = pythonExec,
+ // Merge the user specified includes with the includes managed by the artifact manager.
+ pythonIncludes = (fun.getAdditionalIncludesList.asScala.toSeq ++
+ sessionHolder.artifactManager.getPythonIncludes).asJava,
pythonVer = fun.getPythonVer,
// Empty broadcast variables
broadcastVars = Lists.newArrayList(),
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 521e15d6950b..c8a183105fd1 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import common_pb2 as spark_dot_connect_dot_common
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\x97/\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12 \x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAtt [...]
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\x97/\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12 \x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAtt [...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -103,13 +103,13 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 6242
_COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6606
_PYTHONUDF._serialized_start = 6609
- _PYTHONUDF._serialized_end = 6764
- _SCALARSCALAUDF._serialized_start = 6767
- _SCALARSCALAUDF._serialized_end = 6981
- _JAVAUDF._serialized_start = 6984
- _JAVAUDF._serialized_end = 7133
- _CALLFUNCTION._serialized_start = 7135
- _CALLFUNCTION._serialized_end = 7243
- _NAMEDARGUMENTEXPRESSION._serialized_start = 7245
- _NAMEDARGUMENTEXPRESSION._serialized_end = 7337
+ _PYTHONUDF._serialized_end = 6813
+ _SCALARSCALAUDF._serialized_start = 6816
+ _SCALARSCALAUDF._serialized_end = 7030
+ _JAVAUDF._serialized_start = 7033
+ _JAVAUDF._serialized_end = 7182
+ _CALLFUNCTION._serialized_start = 7184
+ _CALLFUNCTION._serialized_end = 7292
+ _NAMEDARGUMENTEXPRESSION._serialized_start = 7294
+ _NAMEDARGUMENTEXPRESSION._serialized_end = 7386
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index eaf4059b2dbc..42031d47bb85 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1466,6 +1466,7 @@ class PythonUDF(google.protobuf.message.Message):
EVAL_TYPE_FIELD_NUMBER: builtins.int
COMMAND_FIELD_NUMBER: builtins.int
PYTHON_VER_FIELD_NUMBER: builtins.int
+ ADDITIONAL_INCLUDES_FIELD_NUMBER: builtins.int
@property
def output_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
"""(Required) Output type of the Python UDF"""
@@ -1475,6 +1476,11 @@ class PythonUDF(google.protobuf.message.Message):
"""(Required) The encoded commands of the Python UDF"""
python_ver: builtins.str
"""(Required) Python version being used in the client."""
+ @property
+ def additional_includes(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """(Optional) Additional includes for the Python UDF."""
def __init__(
self,
*,
@@ -1482,6 +1488,7 @@ class PythonUDF(google.protobuf.message.Message):
eval_type: builtins.int = ...,
command: builtins.bytes = ...,
python_ver: builtins.str = ...,
+ additional_includes: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["output_type", b"output_type"]
@@ -1489,6 +1496,8 @@ class PythonUDF(google.protobuf.message.Message):
def ClearField(
self,
field_name: typing_extensions.Literal[
+ "additional_includes",
+ b"additional_includes",
"command",
b"command",
"eval_type",