allisonwang-db commented on code in PR #42422:
URL: https://github.com/apache/spark/pull/42422#discussion_r1293793008
##########
python/pyspark/sql/functions.py:
##########
@@ -15623,6 +15629,38 @@ def udtf(
| 1| x|
+---+---+
+ UDTF can use keyword arguments:
+
+ >>> @udtf
+ ... class TestUDTFWithKwargs:
+ ... @staticmethod
+ ... def analyze(
+ ... a: AnalyzeArgument, b: AnalyzeArgument, **kwargs:
AnalyzeArgument
+ ... ) -> AnalyzeResult:
+ ... return AnalyzeResult(
+ ... StructType().add("a", a.data_type)
+ ... .add("b", b.data_type)
+ ... .add("x", kwargs["x"].data_type)
+ ... )
+ ...
+ ... def eval(self, a, b, **kwargs):
+ ... yield a, b, kwargs["x"]
+ ...
+ >>> TestUDTFWithKwargs(lit(1), x=lit("x"), b=lit("b")).show()
+ +---+---+---+
+ | a| b| x|
+ +---+---+---+
+ | 1| b| x|
+ +---+---+---+
+
+ >>> _ = spark.udtf.register("test_udtf", TestUDTFWithKwargs)
+ >>> spark.sql("SELECT * FROM test_udtf(1, x=>'x', b=>'b')").show()
Review Comment:
Super nit: add spaces around `=>` for readability, e.g. `SELECT * FROM test_udtf(1, x => 'x', b => 'b')`
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala:
##########
@@ -273,9 +289,13 @@ object UserDefinedPythonTableFunction {
case eof: EOFException =>
throw new SparkException("Python worker exited unexpectedly
(crashed)", eof)
} finally {
- if (!releasedOrClosed) {
- // An error happened. Force to close the worker.
- env.destroyPythonWorker(pythonExec, workerModule,
envVars.asScala.toMap, worker)
+ try {
+ bufferStream.close()
+ } finally {
+ if (!releasedOrClosed) {
+ // An error happened. Force to close the worker.
+ env.destroyPythonWorker(pythonExec, workerModule,
envVars.asScala.toMap, worker)
+ }
Review Comment:
Just curious, why do we need to change this part?
##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -1795,6 +1796,93 @@ def terminate(self):
assertSchemaEqual(df.schema, StructType().add("col1",
IntegerType()))
assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])
+ def test_udtf_with_named_arguments(self):
+ @udtf(returnType="a: int")
+ class TestUDTF:
+ def eval(self, a, b):
+ yield a,
+
+ self.spark.udtf.register("test_udtf", TestUDTF)
+
+ for i, df in enumerate(
+ [
+ self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
+ self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
+ TestUDTF(a=lit(10), b=lit("x")),
+ TestUDTF(b=lit("x"), a=lit(10)),
+ ]
+ ):
+ with self.subTest(query_no=i):
+ assertDataFrameEqual(df, [Row(a=10)])
+
+ def test_udtf_with_named_arguments_negative(self):
+ @udtf(returnType="a: int")
+ class TestUDTF:
+ def eval(self, a, b):
Review Comment:
What if the UDTF already has default values for its parameters? For instance:
```
def eval(self, a, b = 1):
```
Maybe we can add one more test case for this.
##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -1795,6 +1796,93 @@ def terminate(self):
assertSchemaEqual(df.schema, StructType().add("col1",
IntegerType()))
assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])
+ def test_udtf_with_named_arguments(self):
+ @udtf(returnType="a: int")
+ class TestUDTF:
+ def eval(self, a, b):
+ yield a,
+
+ self.spark.udtf.register("test_udtf", TestUDTF)
+
+ for i, df in enumerate(
+ [
+ self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
+ self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
+ TestUDTF(a=lit(10), b=lit("x")),
+ TestUDTF(b=lit("x"), a=lit(10)),
+ ]
+ ):
+ with self.subTest(query_no=i):
+ assertDataFrameEqual(df, [Row(a=10)])
+
+ def test_udtf_with_named_arguments_negative(self):
+ @udtf(returnType="a: int")
+ class TestUDTF:
+ def eval(self, a, b):
+ yield a,
+
+ self.spark.udtf.register("test_udtf", TestUDTF)
+
+ with self.assertRaisesRegex(
+ AnalysisException,
+
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+ ):
+ self.spark.sql("SELECT * FROM test_udtf(a=>10, a=>100)").show()
+
+ with self.assertRaisesRegex(AnalysisException,
"UNEXPECTED_POSITIONAL_ARGUMENT"):
+ self.spark.sql("SELECT * FROM test_udtf(a=>10, 'x')").show()
+
+ with self.assertRaisesRegex(
+ PythonException, r"eval\(\) got an unexpected keyword argument 'c'"
+ ):
+ self.spark.sql("SELECT * FROM test_udtf(c=>'x')").show()
+
+ def test_udtf_with_kwargs(self):
+ @udtf(returnType="a: int, b: string")
+ class TestUDTF:
+ def eval(self, **kwargs):
+ yield kwargs["a"], kwargs["b"]
+
+ self.spark.udtf.register("test_udtf", TestUDTF)
+
+ for i, df in enumerate(
+ [
+ self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
+ self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
+ TestUDTF(a=lit(10), b=lit("x")),
+ TestUDTF(b=lit("x"), a=lit(10)),
+ ]
+ ):
+ with self.subTest(query_no=i):
+ assertDataFrameEqual(df, [Row(a=10, b="x")])
+
+ def test_udtf_with_analyze_kwargs(self):
+ @udtf
+ class TestUDTF:
+ @staticmethod
+ def analyze(**kwargs: AnalyzeArgument) -> AnalyzeResult:
+ return AnalyzeResult(
+ StructType(
+ [StructField(key, arg.data_type) for key, arg in
sorted(kwargs.items())]
+ )
+ )
+
+ def eval(self, **kwargs):
+ yield tuple(value for _, value in sorted(kwargs.items()))
+
+ self.spark.udtf.register("test_udtf", TestUDTF)
+
+ for i, df in enumerate(
+ [
+ self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
+ self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
Review Comment:
Do named arguments support lateral references? Can we add a few more tests
for lateral joins (e.g. passing a column from the left side of a LATERAL join as a named argument)?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]