This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1cdd5fa68715 [SPARK-48736][PYTHON] Support infra for additional includes for Python UDFs
1cdd5fa68715 is described below

commit 1cdd5fa68715970c9c465e2383a9445610b47f93
Author: Martin Grund <[email protected]>
AuthorDate: Fri Jun 28 08:39:16 2024 +0900

    [SPARK-48736][PYTHON] Support infra for additional includes for Python UDFs
    
    ### What changes were proposed in this pull request?
    The SimplePythonFunction interface already supports specifying additional includes. However, the corresponding proto field for the clients and the handling in the planner have been missing. This patch adds the base infrastructure for that.
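    
    For orientation, here is a minimal sketch of the new field as seen from the Python client, using the stubs regenerated in this patch (the include path and payload below are made-up examples):
    
        from pyspark.sql.connect.proto import expressions_pb2
    
        udf = expressions_pb2.PythonUDF(
            python_ver="3.11",                         # (Required) Python version of the client
            command=b"<pickled udf payload>",          # (Required) encoded command; placeholder bytes
            additional_includes=["deps/helpers.zip"],  # (Optional) the new repeated field
        )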
    
    ### Why are the changes needed?
    Compatibility
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing UT
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #47120 from grundprinzip/SPARK-48736.
    
    Authored-by: Martin Grund <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../main/protobuf/spark/connect/expressions.proto    |  2 ++
 .../sql/connect/planner/SparkConnectPlanner.scala    |  4 +++-
 python/pyspark/sql/connect/proto/expressions_pb2.py  | 20 ++++++++++----------
 python/pyspark/sql/connect/proto/expressions_pb2.pyi |  9 +++++++++
 4 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 404a2fdcb2e8..257634813e74 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -373,6 +373,8 @@ message PythonUDF {
   bytes command = 3;
   // (Required) Python version being used in the client.
   string python_ver = 4;
+  // (Optional) Additional includes for the Python UDF.
+  repeated string additional_includes = 5;
 }
 
 message ScalarScalaUDF {
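
The new entry is a standard proto3 repeated string (field number 5), so it round-trips like any other length-delimited field. A quick sanity check against the regenerated Python bindings (file names are illustrative):

    from pyspark.sql.connect.proto import expressions_pb2

    msg = expressions_pb2.PythonUDF(additional_includes=["a.zip", "b.zip"])
    restored = expressions_pb2.PythonUDF.FromString(msg.SerializeToString())
    assert list(restored.additional_includes) == ["a.zip", "b.zip"]
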
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index a7fc87d8b65d..eaeb1c775ddb 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1764,8 +1764,10 @@ class SparkConnectPlanner(
       command = fun.getCommand.toByteArray.toImmutableArraySeq,
       // Empty environment variables
       envVars = Maps.newHashMap(),
-      pythonIncludes = sessionHolder.artifactManager.getPythonIncludes.asJava,
       pythonExec = pythonExec,
+      // Merge the user-specified includes with the includes managed by the artifact manager.
+      pythonIncludes = (fun.getAdditionalIncludesList.asScala.toSeq ++
+        sessionHolder.artifactManager.getPythonIncludes).asJava,
       pythonVer = fun.getPythonVer,
       // Empty broadcast variables
       broadcastVars = Lists.newArrayList(),
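
Note the merge order: the user-specified includes from the proto come first, followed by the includes managed by the session's artifact manager. A rough Python rendering of that merge (identifiers are illustrative; the actual change above is Scala):

    # User-specified includes first, then the session-managed ones.
    additional_includes = ["deps/helpers.zip"]    # fun.getAdditionalIncludesList
    session_includes = ["artifacts/session.zip"]  # artifactManager.getPythonIncludes
    python_includes = additional_includes + session_includes
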
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 521e15d6950b..c8a183105fd1 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import common_pb2 as spark_dot_connect_dot_common
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\x97/\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12 \x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAtt [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto\x1a\x1aspark/connect/common.proto"\x97/\n\nExpression\x12\x37\n\x06\x63ommon\x18\x12 \x01(\x0b\x32\x1f.spark.connect.ExpressionCommonR\x06\x63ommon\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAtt [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -103,13 +103,13 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 6242
     _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6606
     _PYTHONUDF._serialized_start = 6609
-    _PYTHONUDF._serialized_end = 6764
-    _SCALARSCALAUDF._serialized_start = 6767
-    _SCALARSCALAUDF._serialized_end = 6981
-    _JAVAUDF._serialized_start = 6984
-    _JAVAUDF._serialized_end = 7133
-    _CALLFUNCTION._serialized_start = 7135
-    _CALLFUNCTION._serialized_end = 7243
-    _NAMEDARGUMENTEXPRESSION._serialized_start = 7245
-    _NAMEDARGUMENTEXPRESSION._serialized_end = 7337
+    _PYTHONUDF._serialized_end = 6813
+    _SCALARSCALAUDF._serialized_start = 6816
+    _SCALARSCALAUDF._serialized_end = 7030
+    _JAVAUDF._serialized_start = 7033
+    _JAVAUDF._serialized_end = 7182
+    _CALLFUNCTION._serialized_start = 7184
+    _CALLFUNCTION._serialized_end = 7292
+    _NAMEDARGUMENTEXPRESSION._serialized_start = 7294
+    _NAMEDARGUMENTEXPRESSION._serialized_end = 7386
 # @@protoc_insertion_point(module_scope)
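
The offset changes above are mechanical: the new field enlarges the serialized PythonUDF descriptor by 49 bytes (6813 - 6764 = 49), so the start and end offsets of every following message shift by the same amount (e.g. _SCALARSCALAUDF._serialized_start moves from 6767 to 6816).
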
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index eaf4059b2dbc..42031d47bb85 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1466,6 +1466,7 @@ class PythonUDF(google.protobuf.message.Message):
     EVAL_TYPE_FIELD_NUMBER: builtins.int
     COMMAND_FIELD_NUMBER: builtins.int
     PYTHON_VER_FIELD_NUMBER: builtins.int
+    ADDITIONAL_INCLUDES_FIELD_NUMBER: builtins.int
     @property
     def output_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
         """(Required) Output type of the Python UDF"""
@@ -1475,6 +1476,11 @@ class PythonUDF(google.protobuf.message.Message):
     """(Required) The encoded commands of the Python UDF"""
     python_ver: builtins.str
     """(Required) Python version being used in the client."""
+    @property
+    def additional_includes(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+        """(Optional) Additional includes for the Python UDF."""
     def __init__(
         self,
         *,
@@ -1482,6 +1488,7 @@ class PythonUDF(google.protobuf.message.Message):
         eval_type: builtins.int = ...,
         command: builtins.bytes = ...,
         python_ver: builtins.str = ...,
+        additional_includes: collections.abc.Iterable[builtins.str] | None = ...,
     ) -> None: ...
     def HasField(
        self, field_name: typing_extensions.Literal["output_type", b"output_type"]
@@ -1489,6 +1496,8 @@ class PythonUDF(google.protobuf.message.Message):
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
+            "additional_includes",
+            b"additional_includes",
             "command",
             b"command",
             "eval_type",


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
