This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 56dec397e26d [SPARK-48666][SQL] Do not push down filter if it contains 
PythonUDFs
56dec397e26d is described below

commit 56dec397e26d44e2b578ecea92be4c5e343c2a50
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Fri Jul 12 19:08:24 2024 +0900

    [SPARK-48666][SQL] Do not push down filter if it contains PythonUDFs
    
    This PR proposes to prevent pushing down filters that contain Python UDFs. 
It uses the same approach as https://github.com/apache/spark/pull/47033 (so 
its author is added as a co-author), but simplifies the change.
    
    Extracting filters to push down happens first
    
    
https://github.com/apache/spark/blob/cbe6846c477bc8b6d94385ddd0097c4e97b05d41/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala#L46
    
    
https://github.com/apache/spark/blob/cbe6846c477bc8b6d94385ddd0097c4e97b05d41/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala#L211
    
    
https://github.com/apache/spark/blob/cbe6846c477bc8b6d94385ddd0097c4e97b05d41/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala#L51
    
    That happens before Python UDFs are extracted, so the filters may still contain Python UDFs:
    
    
https://github.com/apache/spark/blob/cbe6846c477bc8b6d94385ddd0097c4e97b05d41/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala#L80
    
    Here is the full stacktrace:
    
    ```
    [INTERNAL_ERROR] Cannot evaluate expression: pyUDF(cast(input[0, bigint, 
true] as string)) SQLSTATE: XX000
    org.apache.spark.SparkException: [INTERNAL_ERROR] Cannot evaluate 
expression: pyUDF(cast(input[0, bigint, true] as string)) SQLSTATE: XX000
            at 
org.apache.spark.SparkException$.internalError(SparkException.scala:92)
            at 
org.apache.spark.SparkException$.internalError(SparkException.scala:96)
            at 
org.apache.spark.sql.errors.QueryExecutionErrors$.cannotEvaluateExpressionError(QueryExecutionErrors.scala:65)
            at 
org.apache.spark.sql.catalyst.expressions.FoldableUnevaluable.eval(Expression.scala:387)
            at 
org.apache.spark.sql.catalyst.expressions.FoldableUnevaluable.eval$(Expression.scala:386)
            at 
org.apache.spark.sql.catalyst.expressions.PythonUDF.eval(PythonUDF.scala:72)
            at 
org.apache.spark.sql.catalyst.expressions.UnaryExpression.eval(Expression.scala:563)
            at 
org.apache.spark.sql.catalyst.expressions.IsNotNull.eval(nullExpressions.scala:403)
            at 
org.apache.spark.sql.catalyst.expressions.InterpretedPredicate.eval(predicates.scala:53)
            at 
org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils$.$anonfun$prunePartitionsByFilter$1(ExternalCatalogUtils.scala:189)
            at 
org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils$.$anonfun$prunePartitionsByFilter$1$adapted(ExternalCatalogUtils.scala:188)
            at scala.collection.immutable.List.filter(List.scala:516)
            at scala.collection.immutable.List.filter(List.scala:79)
            at 
org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils$.prunePartitionsByFilter(ExternalCatalogUtils.scala:188)
            at 
org.apache.spark.sql.catalyst.catalog.InMemoryCatalog.listPartitionsByFilter(InMemoryCatalog.scala:604)
            at 
org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener.listPartitionsByFilter(ExternalCatalogWithListener.scala:262)
            at 
org.apache.spark.sql.catalyst.catalog.SessionCatalog.listPartitionsByFilter(SessionCatalog.scala:1358)
            at 
org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils$.listPartitionsByFilter(ExternalCatalogUtils.scala:168)
            at 
org.apache.spark.sql.execution.datasources.CatalogFileIndex.filterPartitions(CatalogFileIndex.scala:74)
            at 
org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions$$anonfun$apply$1.applyOrElse(PruneFileSourcePartitions.scala:72)
            at 
org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions$$anonfun$apply$1.applyOrElse(PruneFileSourcePartitions.scala:50)
            at 
org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:470)
            at 
org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(origin.scala:84)
            at 
org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:470)
            at 
org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.org$apache$spark$sql$catalyst$plans$logical$AnalysisHelper$$super$transformDownWithPruning(LogicalPlan.scala:37)
            at 
org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning(AnalysisHelper.scala:330)
            at 
org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning$(AnalysisHelper.scala:326)
            at 
org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:37)
            at 
org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:37)
            at 
org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:475)
            at 
org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren(TreeNode.scala:1251)
            at 
org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren$(TreeNode.scala:1250)
            at 
org.apache.spark.sql.catalyst.plans.logical.Join.mapChildren(basicLogicalOperators.scala:552)
            at 
org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:475)
            at 
org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.org$apache$spark$sql$catalyst$plans$logical$AnalysisHelper$$super$transformDownWithPruning(LogicalPlan.scala:37)
            at 
org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning(AnalysisHelper.scala:330)
            at 
org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning$(AnalysisHelper.scala:326)
            at 
org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:37)
            at 
org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:37)
            at 
org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:446)
            at 
org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions$.apply(PruneFileSourcePartitions.scala:50)
            at 
org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions$.apply(PruneFileSourcePartitions.scala:35)
            at 
org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$2(RuleExecutor.scala:226)
            at scala.collection.LinearSeqOps.foldLeft(LinearSeq.scala:183)
            at scala.collection.LinearSeqOps.foldLeft$(LinearSeq.scala:179)
            at scala.collection.immutable.List.foldLeft(List.scala:79)
            at 
org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1(RuleExecutor.scala:223)
            at 
org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$execute$1$adapted(RuleExecutor.scala:215)
            at scala.collection.immutable.List.foreach(List.scala:334)
            at 
org.apache.spark.sql.catalyst.rules.RuleExecutor.execute(RuleExecutor.scala:215)
            at 
org.apache.spark.sql.catalyst.rules.RuleExecutor.$anonfun$executeAndTrack$1(RuleExecutor.scala:186)
            at 
org.apache.spark.sql.catalyst.QueryPlanningTracker$.withTracker(QueryPlanningTracker.scala:89)
            at 
org.apache.spark.sql.catalyst.rules.RuleExecutor.executeAndTrack(RuleExecutor.scala:186)
            at 
org.apache.spark.sql.execution.QueryExecution.$anonfun$optimizedPlan$1(QueryExecution.scala:167)
            at 
org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:138)
            at 
org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$2(QueryExecution.scala:234)
            at 
org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:608)
            at 
org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:234)
            at 
org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:923)
            at 
org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:233)
            at 
org.apache.spark.sql.execution.QueryExecution.optimizedPlan$lzycompute(QueryExecution.scala:163)
            at 
org.apache.spark.sql.execution.QueryExecution.optimizedPlan(QueryExecution.scala:159)
            at 
org.apache.spark.sql.execution.python.PythonUDFSuite.$anonfun$new$19(PythonUDFSuite.scala:136)
            at 
scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.scala:18)
            at 
org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
            at 
org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
            at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:99)
            at 
org.apache.spark.sql.test.SQLTestUtilsBase.withTable(SQLTestUtils.scala:307)
            at 
org.apache.spark.sql.test.SQLTestUtilsBase.withTable$(SQLTestUtils.scala:305)
            at 
org.apache.spark.sql.execution.python.PythonUDFSuite.withTable(PythonUDFSuite.scala:25)
            at 
org.apache.spark.sql.execution.python.PythonUDFSuite.$anonfun$new$18(PythonUDFSuite.scala:130)
            at 
scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.scala:18)
            at 
org.scalatest.enablers.Timed$$anon$1.timeoutAfter(Timed.scala:127)
            at 
org.scalatest.concurrent.TimeLimits$.failAfterImpl(TimeLimits.scala:282)
            at 
org.scalatest.concurrent.TimeLimits.failAfter(TimeLimits.scala:231)
            at 
org.scalatest.concurrent.TimeLimits.failAfter$(TimeLimits.scala:230)
            at org.apache.spark.SparkFunSuite.failAfter(SparkFunSuite.scala:69)
            at 
org.apache.spark.SparkFunSuite.$anonfun$test$2(SparkFunSuite.scala:155)
            at org.scalatest.OutcomeOf.outcomeOf(OutcomeOf.scala:85)
            at org.scalatest.OutcomeOf.outcomeOf$(OutcomeOf.scala:83)
            at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104)
            at org.scalatest.Transformer.apply(Transformer.scala:22)
            at org.scalatest.Transformer.apply(Transformer.scala:20)
            at 
org.scalatest.funsuite.AnyFunSuiteLike$$anon$1.apply(AnyFunSuiteLike.scala:226)
            at 
org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:227)
            at 
org.scalatest.funsuite.AnyFunSuiteLike.invokeWithFixture$1(AnyFunSuiteLike.scala:224)
            at 
org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTest$1(AnyFunSuiteLike.scala:236)
            at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306)
            at 
org.scalatest.funsuite.AnyFunSuiteLike.runTest(AnyFunSuiteLike.scala:236)
            at 
org.scalatest.funsuite.AnyFunSuiteLike.runTest$(AnyFunSuiteLike.scala:218)
            at 
org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterEach$$super$runTest(SparkFunSuite.scala:69)
            at 
org.scalatest.BeforeAndAfterEach.runTest(BeforeAndAfterEach.scala:234)
            at 
org.scalatest.BeforeAndAfterEach.runTest$(BeforeAndAfterEach.scala:227)
            at org.apache.spark.SparkFunSuite.runTest(SparkFunSuite.scala:69)
            at 
org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTests$1(AnyFunSuiteLike.scala:269)
            at 
org.scalatest.SuperEngine.$anonfun$runTestsInBranch$1(Engine.scala:413)
            at scala.collection.immutable.List.foreach(List.scala:334)
            at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401)
            at org.scalatest.SuperEngine.runTestsInBranch(Engine.scala:396)
            at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:475)
            at 
org.scalatest.funsuite.AnyFunSuiteLike.runTests(AnyFunSuiteLike.scala:269)
            at 
org.scalatest.funsuite.AnyFunSuiteLike.runTests$(AnyFunSuiteLike.scala:268)
            at 
org.scalatest.funsuite.AnyFunSuite.runTests(AnyFunSuite.scala:1564)
            at org.scalatest.Suite.run(Suite.scala:1114)
            at org.scalatest.Suite.run$(Suite.scala:1096)
            at 
org.scalatest.funsuite.AnyFunSuite.org$scalatest$funsuite$AnyFunSuiteLike$$super$run(AnyFunSuite.scala:1564)
            at 
org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$run$1(AnyFunSuiteLike.scala:273)
            at org.scalatest.SuperEngine.runImpl(Engine.scala:535)
            at 
org.scalatest.funsuite.AnyFunSuiteLike.run(AnyFunSuiteLike.scala:273)
            at 
org.scalatest.funsuite.AnyFunSuiteLike.run$(AnyFunSuiteLike.scala:272)
            at 
org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:69)
            at 
org.scalatest.BeforeAndAfterAll.liftedTree1$1(BeforeAndAfterAll.scala:213)
            at org.scalatest.BeforeAndAfterAll.run(BeforeAndAfterAll.scala:210)
            at org.scalatest.BeforeAndAfterAll.run$(BeforeAndAfterAll.scala:208)
            at org.apache.spark.SparkFunSuite.run(SparkFunSuite.scala:69)
            at org.scalatest.tools.SuiteRunner.run(SuiteRunner.scala:47)
            at 
org.scalatest.tools.Runner$.$anonfun$doRunRunRunDaDoRunRun$13(Runner.scala:1321)
            at 
org.scalatest.tools.Runner$.$anonfun$doRunRunRunDaDoRunRun$13$adapted(Runner.scala:1315)
            at scala.collection.immutable.List.foreach(List.scala:334)
            at 
org.scalatest.tools.Runner$.doRunRunRunDaDoRunRun(Runner.scala:1315)
            at 
org.scalatest.tools.Runner$.$anonfun$runOptionallyWithPassFailReporter$24(Runner.scala:992)
            at 
org.scalatest.tools.Runner$.$anonfun$runOptionallyWithPassFailReporter$24$adapted(Runner.scala:970)
            at 
org.scalatest.tools.Runner$.withClassLoaderAndDispatchReporter(Runner.scala:1481)
            at 
org.scalatest.tools.Runner$.runOptionallyWithPassFailReporter(Runner.scala:970)
            at org.scalatest.tools.Runner$.run(Runner.scala:798)
            at org.scalatest.tools.Runner.run(Runner.scala)
            at 
org.jetbrains.plugins.scala.testingSupport.scalaTest.ScalaTestRunner.runScalaTest2or3(ScalaTestRunner.java:43)
            at 
org.jetbrains.plugins.scala.testingSupport.scalaTest.ScalaTestRunner.main(ScalaTestRunner.java:26)
    ```
    
    This change is needed so that end users can use Python UDFs against partitioned columns.
    
    Yes, this fixes a bug: this PR allows Python UDFs to be used on partitioned 
columns.
    
    A unit test was added.
    
    No.
    
    Closes #47033
    
    Closes #47313 from HyukjinKwon/SPARK-48666.
    
    Lead-authored-by: Hyukjin Kwon <[email protected]>
    Co-authored-by: Wei Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
    (cherry picked from commit d74785359c50bf966cfe892d3a9eae1a06341db2)
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../datasources/PruneFileSourcePartitions.scala          |  7 ++++++-
 .../sql/execution/datasources/v2/FileScanBuilder.scala   |  7 +++++--
 .../spark/sql/execution/python/PythonUDFSuite.scala      | 16 ++++++++++++++--
 .../sql/hive/execution/PruneHiveTablePartitions.scala    |  9 +++++++--
 4 files changed, 32 insertions(+), 7 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index 1dffea4e1bc8..d5923a577daa 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -63,7 +63,12 @@ private[sql] object PruneFileSourcePartitions extends 
Rule[LogicalPlan] {
             _))
         if filters.nonEmpty && fsRelation.partitionSchema.nonEmpty =>
       val normalizedFilters = DataSourceStrategy.normalizeExprs(
-        filters.filter(f => f.deterministic && 
!SubqueryExpression.hasSubquery(f)),
+        filters.filter { f =>
+          f.deterministic &&
+            !SubqueryExpression.hasSubquery(f) &&
+            // Python UDFs might exist because this rule is applied before 
``ExtractPythonUDFs``.
+            !f.exists(_.isInstanceOf[PythonUDF])
+        },
         logicalRelation.output)
       val (partitionKeyFilters, _) = DataSourceUtils
         .getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
index 447a36fe622c..7e0bc25a9a1e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 import scala.collection.mutable
 
 import org.apache.spark.sql.{sources, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF, 
SubqueryExpression}
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{ScanBuilder, 
SupportsPushDownRequiredColumns}
 import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, 
DataSourceUtils, PartitioningAwareFileIndex, PartitioningUtils}
@@ -73,7 +73,10 @@ abstract class FileScanBuilder(
     val (deterministicFilters, nonDeterminsticFilters) = 
filters.partition(_.deterministic)
     val (partitionFilters, dataFilters) =
       DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, 
deterministicFilters)
-    this.partitionFilters = partitionFilters
+    this.partitionFilters = partitionFilters.filter { f =>
+      // Python UDFs might exist because this rule is applied before 
``ExtractPythonUDFs``.
+      !SubqueryExpression.hasSubquery(f) && 
!f.exists(_.isInstanceOf[PythonUDF])
+    }
     this.dataFilters = dataFilters
     val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter]
     for (filterExpr <- dataFilters) {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
index d86faec1a7bb..9a168dc80a03 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
@@ -17,8 +17,8 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest}
-import org.apache.spark.sql.functions.count
+import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest, Row}
+import org.apache.spark.sql.functions.{col, count}
 import org.apache.spark.sql.test.SharedSparkSession
 
 class PythonUDFSuite extends QueryTest with SharedSparkSession {
@@ -111,4 +111,16 @@ class PythonUDFSuite extends QueryTest with 
SharedSparkSession {
     val pandasTestUDF = TestGroupedAggPandasUDF(name = udfName)
     
assert(df.agg(pandasTestUDF(df("id"))).schema.fieldNames.exists(_.startsWith(udfName)))
   }
+
+  test("SPARK-48666: Python UDF execution against partitioned column") {
+    assume(shouldTestPythonUDFs)
+    withTable("t") {
+      spark.range(1).selectExpr("id AS t", "(id + 1) AS 
p").write.partitionBy("p").saveAsTable("t")
+      val table = spark.table("t")
+      val newTable = table.withColumn("new_column", pythonTestUDF(table("p")))
+      val df = newTable.as("t1").join(
+        newTable.as("t2"), col("t1.new_column") === col("t2.new_column"))
+      checkAnswer(df, Row(0, 1, 1, 0, 1, 1))
+    }
+  }
 }
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
index 395ee86579e5..779562bed5b0 100644
--- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
+++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
@@ -22,7 +22,7 @@ import org.apache.hadoop.hive.common.StatsSetupConst
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.analysis.CastSupport
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, 
Expression, ExpressionSet, PredicateHelper, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, 
Expression, ExpressionSet, PredicateHelper, PythonUDF, SubqueryExpression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, 
Project}
 import 
org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation
@@ -50,7 +50,12 @@ private[sql] class PruneHiveTablePartitions(session: 
SparkSession)
       filters: Seq[Expression],
       relation: HiveTableRelation): ExpressionSet = {
     val normalizedFilters = DataSourceStrategy.normalizeExprs(
-      filters.filter(f => f.deterministic && 
!SubqueryExpression.hasSubquery(f)), relation.output)
+      filters.filter { f =>
+        f.deterministic &&
+          !SubqueryExpression.hasSubquery(f) &&
+          // Python UDFs might exist because this rule is applied before 
``ExtractPythonUDFs``.
+          !f.exists(_.isInstanceOf[PythonUDF])
+      }, relation.output)
     val partitionColumnSet = AttributeSet(relation.partitionCols)
     ExpressionSet(
       normalizedFilters.flatMap(extractPredicatesWithinOutputSet(_, 
partitionColumnSet)))


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to