[
https://issues.apache.org/jira/browse/SPARK-53426?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel
]
Dongjoon Hyun closed SPARK-53426.
---------------------------------
> Support named table argument with asTable() API
> -----------------------------------------------
>
> Key: SPARK-53426
> URL: https://issues.apache.org/jira/browse/SPARK-53426
> Project: Spark
> Issue Type: Sub-task
> Components: PySpark
> Affects Versions: 4.1.0
> Reporter: Allison Wang
> Priority: Major
>
> Named table arguments do not work with the Python UDTF table-argument API:
> {code:java}
> def test_arrow_udtf_with_named_arguments(self):
>     @arrow_udtf(returnType="result_id bigint, multiplier_used int")
>     class NamedArgsUDTF:
>         def eval(
>             self,
>             table_data: "pa.RecordBatch",
>             multiplier: "pa.Array",
>         ) -> Iterator["pa.Table"]:
>             assert isinstance(
>                 table_data, pa.RecordBatch
>             ), f"Expected pa.RecordBatch for table_data, got {type(table_data)}"
>             assert isinstance(
>                 multiplier, pa.Array
>             ), f"Expected pa.Array for multiplier, got {type(multiplier)}"
>             multiplier_val = multiplier[0].as_py()
>             # Convert record batch to table
>             table = pa.table(table_data)
>             id_column = table.column("id")
>             # Multiply each id by the multiplier
>             multiplied_ids = pa.compute.multiply(id_column, pa.scalar(multiplier_val))
>             result_table = pa.table({
>                 "result_id": multiplied_ids,
>                 "multiplier_used": pa.array(
>                     [multiplier_val] * table.num_rows, type=pa.int32()
>                 ),
>             })
>             yield result_table
>     # Test with DataFrame API using named arguments
>     input_df = self.spark.range(3)  # [0, 1, 2]
> > result_df = NamedArgsUDTF(table_data=input_df.asTable(),
> > multiplier=lit(5))python/pyspark/sql/tests/arrow/test_arrow_udtf.py:812:
> _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> _ _ _ _ _ _ _ _ _ _ _ _ _
> python/pyspark/sql/udtf.py:450: in __call__
> j_named_arg = sc._jvm.PythonSQLUtils.namedArgumentExpression(key, j_arg)
> ../../.virtualenvs/spark/lib/python3.11/site-packages/py4j/java_gateway.py:1322:
> in __call__
> return_value = get_return_value(
> python/pyspark/errors/exceptions/captured.py:288: in deco
> return f(*a, **kw)
> _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> _ _ _ _ _ _ _ _ _ _ _ _ _answer = 'xspy4j.Py4JException: Method
> namedArgumentExpression([class java.lang.String, class
> org.apache.spark.sql.TableArg])
> d....ClientServerConnection.run(ClientServerConnection.java:108)\\n\tat
> java.base/java.lang.Thread.run(Thread.java:840)\\n'
> gateway_client = <py4j.clientserver.JavaClient object at 0x109e20090>
> target_id = 'z:org.apache.spark.sql.api.python.PythonSQLUtils', name =
> 'namedArgumentExpression' def get_return_value(answer, gateway_client,
> target_id=None, name=None):
> """Converts an answer received from the Java gateway into a Python
> object. For example, string representation of integers are converted
> to Python
> integer, string representation of objects are converted to JavaObject
> instances, etc. :param answer: the string returned by the Java
> gateway
> :param gateway_client: the gateway client used to communicate with
> the Java
> Gateway. Only necessary if the answer is a reference (e.g.,
> object,
> list, map)
> :param target_id: the name of the object from which the answer comes
> from
> (e.g., *object1* in `object1.hello()`). Optional.
> :param name: the name of the member from which the answer comes from
> (e.g., *hello* in `object1.hello()`). Optional.
> """
> if is_error(answer)[0]:
> if len(answer) > 1:
> type = answer[1]
> value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
> if answer[1] == REFERENCE_TYPE:
> raise Py4JJavaError(
> "An error occurred while calling {0}{1}{2}.\n".
> format(target_id, ".", name), value)
> else:
> > raise Py4JError(
> "An error occurred while calling {0}{1}{2}.
> Trace:\n{3}\n".
> format(target_id, ".", name, value))
> E py4j.protocol.Py4JError: An error occurred while calling
> z:org.apache.spark.sql.api.python.PythonSQLUtils.namedArgumentExpression.
> Trace:
> E py4j.Py4JException: Method namedArgumentExpression([class
> java.lang.String, class org.apache.spark.sql.TableArg]) does not exist
> E at
> py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:321)
> E at
> py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:342)
> E at py4j.Gateway.invoke(Gateway.java:276)
> E at
> py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
> E at
> py4j.commands.CallCommand.execute(CallCommand.java:79)
> E at
> py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:184)
> E at
> py4j.ClientServerConnection.run(ClientServerConnection.java:108)
> E at java.base/java.lang.Thread.run(Thread.java:840)
> {code}
>
>
--
This message was sent by Atlassian Jira
(v8.20.10#820010)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]