
Wenchen Fan reassigned SPARK-24721:
-----------------------------------

    Assignee: Li Jin

> Failed to use PythonUDF with literal inputs in filter with data sources
> -----------------------------------------------------------------------
>
>                 Key: SPARK-24721
>                 URL: https://issues.apache.org/jira/browse/SPARK-24721
>             Project: Spark
>          Issue Type: Bug
>          Components: PySpark, SQL
>    Affects Versions: 2.3.1
>            Reporter: Xiao Li
>            Assignee: Li Jin
>            Priority: Major
>             Fix For: 2.4.0
>
>
> {code}
> import random
>
> from pyspark.sql.functions import col, lit, udf
> from pyspark.sql.types import DoubleType
>
> def random_probability(label):
>     # Pick a probability consistent with the label: >= 0.5 for positives,
>     # < 0.5 for negatives.
>     if label == 1.0:
>         return random.uniform(0.5, 1.0)
>     else:
>         return random.uniform(0.0, 0.4999)
>
> def randomize_label(ratio):
>     # Return 1.0 with probability (1 - ratio), otherwise 0.0.
>     if random.random() >= ratio:
>         return 1.0
>     else:
>         return 0.0
>
> random_probability = udf(random_probability, DoubleType())
> randomize_label = udf(randomize_label, DoubleType())
>
> # Write a small data set and read it back through a file data source.
> spark.range(10).write.mode("overwrite").format('csv').save("/tmp/tab3")
> babydf = spark.read.csv("/tmp/tab3")
>
> # randomize_label is applied to a literal only, so the resulting
> # PythonUDF expression references no input columns.
> data_modified_label = babydf.withColumn(
>     'random_label', randomize_label(lit(1 - 0.1))
> )
> data_modified_random = data_modified_label.withColumn(
>     'random_probability',
>     random_probability(col('random_label'))
> )
> data_modified_label.filter(col('random_label') == 0).show()
> {code}
> Running the code above raises the following exception:
> {code}
> Py4JJavaError: An error occurred while calling o446.showString.
> : java.lang.RuntimeException: Invalid PythonUDF randomize_label(0.9), requires attributes from more than one child.
>       at scala.sys.package$.error(package.scala:27)
>       at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$org$apache$spark$sql$execution$python$ExtractPythonUDFs$$extract$2.apply(ExtractPythonUDFs.scala:166)
>       at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$org$apache$spark$sql$execution$python$ExtractPythonUDFs$$extract$2.apply(ExtractPythonUDFs.scala:165)
>       at scala.collection.immutable.Stream.foreach(Stream.scala:594)
>       at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.org$apache$spark$sql$execution$python$ExtractPythonUDFs$$extract(ExtractPythonUDFs.scala:165)
>       at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$apply$2.applyOrElse(ExtractPythonUDFs.scala:116)
>       at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$apply$2.applyOrElse(ExtractPythonUDFs.scala:112)
>       at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:310)
>       at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:310)
>       at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:77)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:309)
>       at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307)
>       at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307)
>       at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:327)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:208)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:325)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:307)
>       at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307)
>       at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307)
>       at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:327)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:208)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:325)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:307)
>       at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307)
>       at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307)
>       at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:327)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:208)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:325)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:307)
>       at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307)
>       at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:307)
>       at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:327)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:208)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:325)
>       at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:307)
>       at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.apply(ExtractPythonUDFs.scala:112)
>       at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.apply(ExtractPythonUDFs.scala:92)
>       at org.apache.spark.sql.execution.QueryExecution$$anonfun$prepareForExecution$1.apply(QueryExecution.scala:119)
>       at org.apache.spark.sql.execution.QueryExecution$$anonfun$prepareForExecution$1.apply(QueryExecution.scala:119)
>       at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:124)
>       at scala.collection.immutable.List.foldLeft(List.scala:84)
>       at org.apache.spark.sql.execution.QueryExecution.prepareForExecution(QueryExecution.scala:119)
>       at org.apache.spark.sql.execution.QueryExecution.executedPlan$lzycompute(QueryExecution.scala:109)
>       at org.apache.spark.sql.execution.QueryExecution.executedPlan(QueryExecution.scala:109)
>       at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3016)
>       at org.apache.spark.sql.Dataset.head(Dataset.scala:2216)
>       at org.apache.spark.sql.Dataset.take(Dataset.scala:2429)
>       at org.apache.spark.sql.Dataset.showString(Dataset.scala:248)
>       at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
>       at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
>       at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
>       at java.lang.reflect.Method.invoke(Method.java:498)
>       at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
>       at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
>       at py4j.Gateway.invoke(Gateway.java:293)
>       at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
>       at py4j.commands.CallCommand.execute(CallCommand.java:79)
>       at py4j.GatewayConnection.run(GatewayConnection.java:226)
>       at java.lang.Thread.run(Thread.java:748)
> {code}
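> Judging from the trace, the physical-plan rule ExtractPythonUDFs fails while assigning each PythonUDF to the child plan that produces its input attributes; a UDF whose inputs are all literals references no attributes at all, so once the enclosing filter is pushed toward the file-source scan the rule finds no child to attach it to. A minimal sketch of the same pattern (untested here, assuming an affected 2.3.x build with a live {{spark}} session; {{const_udf}} and the {{flag}} column are made-up names for illustration):
> {code}
> from pyspark.sql.functions import col, lit, udf
> from pyspark.sql.types import DoubleType
>
> # A Python UDF applied only to a literal: the resulting expression
> # references no input columns.
> const_udf = udf(lambda x: x, DoubleType())
>
> spark.range(10).write.mode("overwrite").format('csv').save("/tmp/tab3")
> df = spark.read.csv("/tmp/tab3").withColumn('flag', const_udf(lit(0.9)))
>
> # On affected versions this is expected to fail with the same
> # "Invalid PythonUDF ... requires attributes from more than one child".
> df.filter(col('flag') == 0).show()
> {code}
> Per the Fix For field above, upgrading to 2.4.0 resolves this.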


