This is an automated email from the ASF dual-hosted git repository.

blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 3cb841927f Spark: Rule for converting StaticInvoke to ApplyFunctionExpression for V2 filter push down (#8088)
3cb841927f is described below

commit 3cb841927f4ed58c5d5a4cafb0a0b1a017939808
Author: Xianyang Liu <[email protected]>
AuthorDate: Wed Sep 13 07:21:36 2023 +0800

    Spark: Rule for converting StaticInvoke to ApplyFunctionExpression for V2 filter push down (#8088)
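
    For context, a minimal sketch of what this enables (a hedged example: it
    assumes a SparkSession `spark` configured with the Iceberg extensions and
    an existing Iceberg table `db.tbl`; both names are illustrative):

        // Previously, system function calls in filters analyzed to StaticInvoke,
        // which the V2 push-down path cannot translate, so the comparison stayed
        // in a post-scan Filter. With this rule, the call becomes an
        // ApplyFunctionExpression and the predicate can reach the Iceberg scan.
        val df = spark.sql("SELECT * FROM db.tbl WHERE system.bucket(5, id) = 2")
        df.queryExecution.optimizedPlan // the predicate is now a pushed filter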
---
 .../extensions/IcebergSparkSessionExtensions.scala |   2 +
 .../catalyst/optimizer/ReplaceStaticInvoke.scala   |  93 ++++++
 .../extensions/TestSystemFunctionPushDownDQL.java  | 314 +++++++++++++++++++++
 .../org/apache/iceberg/spark/source/PlanUtils.java |  99 +++++++
 .../iceberg/spark/functions/SparkFunctions.java    |  18 ++
 .../spark/SystemFunctionPushDownHelper.java        |  21 +-
 .../spark/functions/TestSparkFunctions.java        | 157 +++++++++++
 .../apache/iceberg/spark/source/TestSparkScan.java |  54 +++-
 8 files changed, 739 insertions(+), 19 deletions(-)

diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
index 90fc5af18d..4322d07007 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.RewriteMergeIntoTable
 import org.apache.spark.sql.catalyst.analysis.RewriteUpdateTable
 import org.apache.spark.sql.catalyst.optimizer.ExtendedReplaceNullWithFalseInPredicate
 import org.apache.spark.sql.catalyst.optimizer.ExtendedSimplifyConditionalsInPredicate
+import org.apache.spark.sql.catalyst.optimizer.ReplaceStaticInvoke
 import org.apache.spark.sql.catalyst.parser.extensions.IcebergSparkSqlExtensionsParser
 import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy
 import org.apache.spark.sql.execution.datasources.v2.ExtendedV2Writes
@@ -58,6 +59,7 @@ class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
     // optimizer extensions
     extensions.injectOptimizerRule { _ => ExtendedSimplifyConditionalsInPredicate }
     extensions.injectOptimizerRule { _ => ExtendedReplaceNullWithFalseInPredicate }
+    extensions.injectOptimizerRule { _ => ReplaceStaticInvoke }
     // pre-CBO rules run only once and the order of the rules is important
     // - dynamic filters should be added before replacing commands with rewrite plans
     // - scans must be planned before building writes
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala
new file mode 100644
index 0000000000..1f0e164d84
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.iceberg.spark.functions.SparkFunctions
+import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression
+import org.apache.spark.sql.catalyst.expressions.BinaryComparison
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
+import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.BINARY_COMPARISON
+import org.apache.spark.sql.catalyst.trees.TreePattern.FILTER
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
+import org.apache.spark.sql.types.StructField
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Spark analyzes calls to Iceberg system functions into {@link StaticInvoke} expressions,
+ * which cannot be pushed down to the data source. This rule replaces {@link StaticInvoke}
+ * with {@link ApplyFunctionExpression} for Iceberg system functions in filter conditions.
+ */
+object ReplaceStaticInvoke extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan =
+    plan.transformWithPruning (_.containsAllPatterns(BINARY_COMPARISON, FILTER)) {
+      case filter @ Filter(condition, _) =>
+        val newCondition = condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) {
+          case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable =>
+            c.withNewChildren(Seq(replaceStaticInvoke(left), right))
+
+          case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable =>
+            c.withNewChildren(Seq(left, replaceStaticInvoke(right)))
+        }
+
+        if (newCondition fastEquals condition) {
+          filter
+        } else {
+          filter.copy(condition = newCondition)
+        }
+  }
+
+  private def replaceStaticInvoke(invoke: StaticInvoke): Expression = {
+    // Adapted from `resolveV2Function` in org.apache.spark.sql.catalyst.analysis.ResolveFunctions
+    val unbound = SparkFunctions.loadFunctionByClass(invoke.staticObject)
+    if (unbound == null) {
+      return invoke
+    }
+
+    val inputType = StructType(invoke.arguments.zipWithIndex.map {
+      case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
+    })
+
+    val bound = try {
+      unbound.bind(inputType)
+    } catch {
+      case _: Exception =>
+        return invoke
+    }
+
+    if (bound.inputTypes().length != invoke.arguments.length) {
+      return invoke
+    }
+
+    bound match {
+      case scalarFunc: ScalarFunction[_] =>
+        ApplyFunctionExpression(scalarFunc, invoke.arguments)
+      case _ => invoke
+    }
+  }
+
+  @inline
+  private def canReplace(invoke: StaticInvoke): Boolean = {
+    invoke.functionName == ScalarFunction.MAGIC_METHOD_NAME && !invoke.foldable
+  }
+}
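
A hedged way to confirm the rule fired, mirroring the checks in the new test
below (assumes a session with the extensions enabled and an illustrative
Iceberg table `db.tbl` with a timestamptz column `ts`):

    import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke

    val plan = spark.sql("SELECT * FROM db.tbl WHERE system.years(ts) = 47")
      .queryExecution.optimizedPlan
    // If the rule ran, no StaticInvoke survives anywhere in the optimized plan.
    val leftover = plan.collect { case node => node.expressions }.flatten
      .flatMap(_.collect { case s: StaticInvoke => s })
    assert(leftover.isEmpty)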
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java
new file mode 100644
index 0000000000..7f2857cce0
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.extensions;
+
+import static org.apache.iceberg.expressions.Expressions.bucket;
+import static org.apache.iceberg.expressions.Expressions.day;
+import static org.apache.iceberg.expressions.Expressions.equal;
+import static org.apache.iceberg.expressions.Expressions.greaterThan;
+import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual;
+import static org.apache.iceberg.expressions.Expressions.hour;
+import static org.apache.iceberg.expressions.Expressions.lessThan;
+import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual;
+import static org.apache.iceberg.expressions.Expressions.month;
+import static org.apache.iceberg.expressions.Expressions.notEqual;
+import static org.apache.iceberg.expressions.Expressions.truncate;
+import static org.apache.iceberg.expressions.Expressions.year;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.STRUCT;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.createPartitionedTable;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.createUnpartitionedTable;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToDayOrdinal;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToHourOrdinal;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToMonthOrdinal;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToYearOrdinal;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.expressions.ExpressionUtil;
+import org.apache.iceberg.spark.SparkCatalogConfig;
+import org.apache.iceberg.spark.source.PlanUtils;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.assertj.core.api.Assertions;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runners.Parameterized;
+
+public class TestSystemFunctionPushDownDQL extends SparkExtensionsTestBase {
+  public TestSystemFunctionPushDownDQL(
+      String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
+  public static Object[][] parameters() {
+    return new Object[][] {
+      {
+        SparkCatalogConfig.HIVE.catalogName(),
+        SparkCatalogConfig.HIVE.implementation(),
+        SparkCatalogConfig.HIVE.properties(),
+      },
+    };
+  }
+
+  @Before
+  public void before() {
+    sql("USE %s", catalogName);
+  }
+
+  @After
+  public void removeTables() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+
+  @Test
+  public void testYearsFunctionOnUnpartitionedTable() {
+    createUnpartitionedTable(spark, tableName);
+    testYearsFunction(false);
+  }
+
+  @Test
+  public void testYearsFunctionOnPartitionedTable() {
+    createPartitionedTable(spark, tableName, "years(ts)");
+    testYearsFunction(true);
+  }
+
+  private void testYearsFunction(boolean partitioned) {
+    int targetYears = timestampStrToYearOrdinal("2017-11-22T00:00:00.000000+00:00");
+    String query =
+        String.format(
+            "SELECT * FROM %s WHERE system.years(ts) = %s ORDER BY id", tableName, targetYears);
+
+    Dataset<Row> df = spark.sql(query);
+    LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();
+
+    checkExpressions(optimizedPlan, partitioned, "years");
+    checkPushedFilters(optimizedPlan, equal(year("ts"), targetYears));
+
+    List<Object[]> actual = rowsToJava(df.collectAsList());
+    Assertions.assertThat(actual.size()).isEqualTo(5);
+  }
+
+  @Test
+  public void testMonthsFunctionOnUnpartitionedTable() {
+    createUnpartitionedTable(spark, tableName);
+    testMonthsFunction(false);
+  }
+
+  @Test
+  public void testMonthsFunctionOnPartitionedTable() {
+    createPartitionedTable(spark, tableName, "months(ts)");
+    testMonthsFunction(true);
+  }
+
+  private void testMonthsFunction(boolean partitioned) {
+    int targetMonths = timestampStrToMonthOrdinal("2017-11-22T00:00:00.000000+00:00");
+    String query =
+        String.format(
+            "SELECT * FROM %s WHERE system.months(ts) > %s ORDER BY id", tableName, targetMonths);
+
+    Dataset<Row> df = spark.sql(query);
+    LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();
+
+    checkExpressions(optimizedPlan, partitioned, "months");
+    checkPushedFilters(optimizedPlan, greaterThan(month("ts"), targetMonths));
+
+    List<Object[]> actual = rowsToJava(df.collectAsList());
+    Assertions.assertThat(actual.size()).isEqualTo(5);
+  }
+
+  @Test
+  public void testDaysFunctionOnUnpartitionedTable() {
+    createUnpartitionedTable(spark, tableName);
+    testDaysFunction(false);
+  }
+
+  @Test
+  public void testDaysFunctionOnPartitionedTable() {
+    createPartitionedTable(spark, tableName, "days(ts)");
+    testDaysFunction(true);
+  }
+
+  private void testDaysFunction(boolean partitioned) {
+    String timestamp = "2018-11-20T00:00:00.000000+00:00";
+    int targetDays = timestampStrToDayOrdinal(timestamp);
+    String query =
+        String.format(
+            "SELECT * FROM %s WHERE system.days(ts) < date('%s') ORDER BY id",
+            tableName, timestamp);
+
+    Dataset<Row> df = spark.sql(query);
+    LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();
+
+    checkExpressions(optimizedPlan, partitioned, "days");
+    checkPushedFilters(optimizedPlan, lessThan(day("ts"), targetDays));
+
+    List<Object[]> actual = rowsToJava(df.collectAsList());
+    Assertions.assertThat(actual.size()).isEqualTo(5);
+  }
+
+  @Test
+  public void testHoursFunctionOnUnpartitionedTable() {
+    createUnpartitionedTable(spark, tableName);
+    testHoursFunction(false);
+  }
+
+  @Test
+  public void testHoursFunctionOnPartitionedTable() {
+    createPartitionedTable(spark, tableName, "hours(ts)");
+    testHoursFunction(true);
+  }
+
+  private void testHoursFunction(boolean partitioned) {
+    int targetHours = timestampStrToHourOrdinal("2017-11-22T06:02:09.243857+00:00");
+    String query =
+        String.format(
+            "SELECT * FROM %s WHERE system.hours(ts) >= %s ORDER BY id", tableName, targetHours);
+
+    Dataset<Row> df = spark.sql(query);
+    LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();
+
+    checkExpressions(optimizedPlan, partitioned, "hours");
+    checkPushedFilters(optimizedPlan, greaterThanOrEqual(hour("ts"), targetHours));
+
+    List<Object[]> actual = rowsToJava(df.collectAsList());
+    Assertions.assertThat(actual.size()).isEqualTo(8);
+  }
+
+  @Test
+  public void testBucketLongFunctionOnUnpartitionedTable() {
+    createUnpartitionedTable(spark, tableName);
+    testBucketLongFunction(false);
+  }
+
+  @Test
+  public void testBucketLongFunctionOnPartitionedTable() {
+    createPartitionedTable(spark, tableName, "bucket(5, id)");
+    testBucketLongFunction(true);
+  }
+
+  private void testBucketLongFunction(boolean partitioned) {
+    int target = 2;
+    String query =
+        String.format(
+            "SELECT * FROM %s WHERE system.bucket(5, id) <= %s ORDER BY id", tableName, target);
+
+    Dataset<Row> df = spark.sql(query);
+    LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();
+
+    checkExpressions(optimizedPlan, partitioned, "bucket");
+    checkPushedFilters(optimizedPlan, lessThanOrEqual(bucket("id", 5), target));
+
+    List<Object[]> actual = rowsToJava(df.collectAsList());
+    Assertions.assertThat(actual.size()).isEqualTo(5);
+  }
+
+  @Test
+  public void testBucketStringFunctionOnUnpartitionedTable() {
+    createUnpartitionedTable(spark, tableName);
+    testBucketStringFunction(false);
+  }
+
+  @Test
+  public void testBucketStringFunctionOnPartitionedTable() {
+    createPartitionedTable(spark, tableName, "bucket(5, data)");
+    testBucketStringFunction(true);
+  }
+
+  private void testBucketStringFunction(boolean partitioned) {
+    int target = 2;
+    String query =
+        String.format(
+            "SELECT * FROM %s WHERE system.bucket(5, data) != %s ORDER BY id", tableName, target);
+
+    Dataset<Row> df = spark.sql(query);
+    LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();
+
+    checkExpressions(optimizedPlan, partitioned, "bucket");
+    checkPushedFilters(optimizedPlan, notEqual(bucket("data", 5), target));
+
+    List<Object[]> actual = rowsToJava(df.collectAsList());
+    Assertions.assertThat(actual.size()).isEqualTo(8);
+  }
+
+  @Test
+  public void testTruncateFunctionOnUnpartitionedTable() {
+    createUnpartitionedTable(spark, tableName);
+    testTruncateFunction(false);
+  }
+
+  @Test
+  public void testTruncateFunctionOnPartitionedTable() {
+    createPartitionedTable(spark, tableName, "truncate(4, data)");
+    testTruncateFunction(true);
+  }
+
+  private void testTruncateFunction(boolean partitioned) {
+    String target = "data";
+    String query =
+        String.format(
+            "SELECT * FROM %s WHERE system.truncate(4, data) = '%s' ORDER BY id",
+            tableName, target);
+
+    Dataset<Row> df = spark.sql(query);
+    LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();
+
+    checkExpressions(optimizedPlan, partitioned, "truncate");
+    checkPushedFilters(optimizedPlan, equal(truncate("data", 4), target));
+
+    List<Object[]> actual = rowsToJava(df.collectAsList());
+    Assertions.assertThat(actual.size()).isEqualTo(5);
+  }
+
+  private void checkExpressions(
+      LogicalPlan optimizedPlan, boolean partitioned, String expectedFunctionName) {
+    List<Expression> staticInvokes =
+        PlanUtils.collectSparkExpressions(
+            optimizedPlan, expression -> expression instanceof StaticInvoke);
+    Assertions.assertThat(staticInvokes).isEmpty();
+
+    List<Expression> applyExpressions =
+        PlanUtils.collectSparkExpressions(
+            optimizedPlan, expression -> expression instanceof ApplyFunctionExpression);
+
+    if (partitioned) {
+      Assertions.assertThat(applyExpressions).isEmpty();
+    } else {
+      Assertions.assertThat(applyExpressions.size()).isEqualTo(1);
+      ApplyFunctionExpression expression = (ApplyFunctionExpression) applyExpressions.get(0);
+      Assertions.assertThat(expression.name()).isEqualTo(expectedFunctionName);
+    }
+  }
+
+  private void checkPushedFilters(
+      LogicalPlan optimizedPlan, org.apache.iceberg.expressions.Expression expected) {
+    List<org.apache.iceberg.expressions.Expression> pushedFilters =
+        PlanUtils.collectPushDownFilters(optimizedPlan);
+    Assertions.assertThat(pushedFilters.size()).isEqualTo(1);
+    org.apache.iceberg.expressions.Expression actual = pushedFilters.get(0);
+    Assertions.assertThat(ExpressionUtil.equivalent(expected, actual, STRUCT, true))
+        .as("Pushed filter should match")
+        .isTrue();
+  }
+}
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/source/PlanUtils.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/source/PlanUtils.java
new file mode 100644
index 0000000000..148717e142
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/source/PlanUtils.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation;
+import scala.PartialFunction;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
+
+public class PlanUtils {
+  private PlanUtils() {}
+
+  public static List<org.apache.iceberg.expressions.Expression> collectPushDownFilters(
+      LogicalPlan logicalPlan) {
+    return JavaConverters.asJavaCollection(logicalPlan.collectLeaves()).stream()
+        .flatMap(
+            plan -> {
+              if (!(plan instanceof DataSourceV2ScanRelation)) {
+                return Stream.empty();
+              }
+
+              DataSourceV2ScanRelation scanRelation = (DataSourceV2ScanRelation) plan;
+              if (!(scanRelation.scan() instanceof SparkBatchQueryScan)) {
+                return Stream.empty();
+              }
+
+              SparkBatchQueryScan batchQueryScan = (SparkBatchQueryScan) scanRelation.scan();
+              return batchQueryScan.filterExpressions().stream();
+            })
+        .collect(Collectors.toList());
+  }
+
+  public static List<Expression> collectSparkExpressions(
+      LogicalPlan logicalPlan, Predicate<Expression> predicate) {
+    Seq<List<Expression>> list =
+        logicalPlan.collect(
+            new PartialFunction<LogicalPlan, List<Expression>>() {
+
+              @Override
+              public List<Expression> apply(LogicalPlan plan) {
+                return JavaConverters.asJavaCollection(plan.expressions()).stream()
+                    .flatMap(expr -> collectSparkExpressions(expr, predicate).stream())
+                    .collect(Collectors.toList());
+              }
+
+              @Override
+              public boolean isDefinedAt(LogicalPlan plan) {
+                return true;
+              }
+            });
+
+    return JavaConverters.asJavaCollection(list).stream()
+        .flatMap(Collection::stream)
+        .collect(Collectors.toList());
+  }
+
+  private static List<Expression> collectSparkExpressions(
+      Expression expression, Predicate<Expression> predicate) {
+    Seq<Expression> list =
+        expression.collect(
+            new PartialFunction<Expression, Expression>() {
+              @Override
+              public Expression apply(Expression expr) {
+                return expr;
+              }
+
+              @Override
+              public boolean isDefinedAt(Expression expr) {
+                return predicate.test(expr);
+              }
+            });
+
+    return Lists.newArrayList(JavaConverters.asJavaCollection(list));
+  }
+}
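
A short usage sketch for these test helpers (Scala for consistency with the
rule above; `df` stands for an illustrative DataFrame over an Iceberg table):

    import org.apache.iceberg.spark.source.PlanUtils
    import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke

    val plan = df.queryExecution.optimizedPlan
    // Collect any StaticInvoke expressions left anywhere in the plan ...
    val invokes = PlanUtils.collectSparkExpressions(plan, _.isInstanceOf[StaticInvoke])
    // ... and the Iceberg filter expressions pushed into the batch query scan.
    val pushed = PlanUtils.collectPushDownFilters(plan)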
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java
index d14bd45831..6d9cadec57 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java
@@ -39,6 +39,15 @@ public class SparkFunctions {
           "bucket", new BucketFunction(),
           "truncate", new TruncateFunction());
 
+  private static final Map<Class<?>, UnboundFunction> CLASS_TO_FUNCTIONS =
+      ImmutableMap.of(
+          YearsFunction.class, new YearsFunction(),
+          MonthsFunction.class, new MonthsFunction(),
+          DaysFunction.class, new DaysFunction(),
+          HoursFunction.class, new HoursFunction(),
+          BucketFunction.class, new BucketFunction(),
+          TruncateFunction.class, new TruncateFunction());
+
   private static final List<String> FUNCTION_NAMES = ImmutableList.copyOf(FUNCTIONS.keySet());
 
   // Functions that are added to all Iceberg catalogs should be accessed with the `system`
@@ -54,4 +63,13 @@ public class SparkFunctions {
     // function resolution is case-insensitive to match the existing Spark behavior for functions
     return FUNCTIONS.get(name.toLowerCase(Locale.ROOT));
   }
+
+  public static UnboundFunction loadFunctionByClass(Class<?> functionClass) {
+    Class<?> declaringClass = functionClass.getDeclaringClass();
+    if (declaringClass == null) {
+      return null;
+    }
+
+    return CLASS_TO_FUNCTIONS.get(declaringClass);
+  }
 }
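
The new lookup keys on the declaring class because each bound implementation
(e.g. YearsFunction.TimestampToYearsFunction) is a nested class of its unbound
function. A hedged sketch of the round trip:

    import org.apache.iceberg.spark.functions.{SparkFunctions, YearsFunction}

    // A nested bound class resolves to its enclosing UnboundFunction ...
    val unbound = SparkFunctions.loadFunctionByClass(
      classOf[YearsFunction.TimestampToYearsFunction])
    assert(unbound != null && unbound.name() == "years")
    // ... while a class with no declaring class yields null.
    assert(SparkFunctions.loadFunctionByClass(classOf[String]) == null)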
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SystemFunctionPushDownHelper.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SystemFunctionPushDownHelper.java
index 9258bb4f0e..059325e02a 100644
--- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SystemFunctionPushDownHelper.java
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SystemFunctionPushDownHelper.java
@@ -18,10 +18,17 @@
  */
 package org.apache.iceberg.spark;
 
+import org.apache.iceberg.types.Types;
 import org.apache.iceberg.util.DateTimeUtil;
 import org.apache.spark.sql.SparkSession;
 
 public class SystemFunctionPushDownHelper {
+  public static final Types.StructType STRUCT =
+      Types.StructType.of(
+          Types.NestedField.optional(1, "id", Types.LongType.get()),
+          Types.NestedField.optional(2, "ts", Types.TimestampType.withZone()),
+          Types.NestedField.optional(3, "data", Types.StringType.get()));
+
   private SystemFunctionPushDownHelper() {}
 
   public static void createUnpartitionedTable(SparkSession spark, String tableName) {
@@ -98,19 +105,19 @@ public class SystemFunctionPushDownHelper {
         "(9, CAST('2018-12-21T15:02:15.230570+00:00' AS TIMESTAMP), 'material-9')");
   }
 
-  public static int years(String date) {
-    return DateTimeUtil.daysToYears(DateTimeUtil.isoDateToDays(date));
+  public static int timestampStrToYearOrdinal(String timestamp) {
+    return DateTimeUtil.microsToYears(DateTimeUtil.isoTimestamptzToMicros(timestamp));
   }
 
-  public static int months(String date) {
-    return DateTimeUtil.daysToMonths(DateTimeUtil.isoDateToDays(date));
+  public static int timestampStrToMonthOrdinal(String timestamp) {
+    return DateTimeUtil.microsToMonths(DateTimeUtil.isoTimestamptzToMicros(timestamp));
   }
 
-  public static int days(String date) {
-    return DateTimeUtil.isoDateToDays(date);
+  public static int timestampStrToDayOrdinal(String timestamp) {
+    return DateTimeUtil.microsToDays(DateTimeUtil.isoTimestamptzToMicros(timestamp));
   }
 
-  public static int hours(String timestamp) {
+  public static int timestampStrToHourOrdinal(String timestamp) {
     return DateTimeUtil.microsToHours(DateTimeUtil.isoTimestamptzToMicros(timestamp));
   }
 
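
The renamed helpers make the conversion explicit: each maps an ISO timestamptz
string to the ordinal of its transform granularity relative to the 1970-01-01
epoch. A small sketch using the same DateTimeUtil calls as the diff:

    import org.apache.iceberg.util.DateTimeUtil

    val micros = DateTimeUtil.isoTimestamptzToMicros("2017-11-22T00:00:00.000000+00:00")
    DateTimeUtil.microsToYears(micros)  // 47  = 2017 - 1970
    DateTimeUtil.microsToMonths(micros) // 574 = 47 * 12 + 10 (November)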
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/functions/TestSparkFunctions.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/functions/TestSparkFunctions.java
new file mode 100644
index 0000000000..34308e77d2
--- /dev/null
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/functions/TestSparkFunctions.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.functions;
+
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.DecimalType;
+import org.assertj.core.api.Assertions;
+import org.junit.Test;
+
+public class TestSparkFunctions {
+
+  @Test
+  public void testBuildYearsFunctionFromClass() {
+    UnboundFunction expected = new YearsFunction();
+
+    YearsFunction.DateToYearsFunction dateToYearsFunc = new YearsFunction.DateToYearsFunction();
+    checkBuildFunc(dateToYearsFunc, expected);
+
+    YearsFunction.TimestampToYearsFunction tsToYearsFunc =
+        new YearsFunction.TimestampToYearsFunction();
+    checkBuildFunc(tsToYearsFunc, expected);
+
+    YearsFunction.TimestampNtzToYearsFunction tsNtzToYearsFunc =
+        new YearsFunction.TimestampNtzToYearsFunction();
+    checkBuildFunc(tsNtzToYearsFunc, expected);
+  }
+
+  @Test
+  public void testBuildMonthsFunctionFromClass() {
+    UnboundFunction expected = new MonthsFunction();
+
+    MonthsFunction.DateToMonthsFunction dateToMonthsFunc =
+        new MonthsFunction.DateToMonthsFunction();
+    checkBuildFunc(dateToMonthsFunc, expected);
+
+    MonthsFunction.TimestampToMonthsFunction tsToMonthsFunc =
+        new MonthsFunction.TimestampToMonthsFunction();
+    checkBuildFunc(tsToMonthsFunc, expected);
+
+    MonthsFunction.TimestampNtzToMonthsFunction tsNtzToMonthsFunc =
+        new MonthsFunction.TimestampNtzToMonthsFunction();
+    checkBuildFunc(tsNtzToMonthsFunc, expected);
+  }
+
+  @Test
+  public void testBuildDaysFunctionFromClass() {
+    UnboundFunction expected = new DaysFunction();
+
+    DaysFunction.DateToDaysFunction dateToDaysFunc = new DaysFunction.DateToDaysFunction();
+    checkBuildFunc(dateToDaysFunc, expected);
+
+    DaysFunction.TimestampToDaysFunction tsToDaysFunc = new DaysFunction.TimestampToDaysFunction();
+    checkBuildFunc(tsToDaysFunc, expected);
+
+    DaysFunction.TimestampNtzToDaysFunction tsNtzToDaysFunc =
+        new DaysFunction.TimestampNtzToDaysFunction();
+    checkBuildFunc(tsNtzToDaysFunc, expected);
+  }
+
+  @Test
+  public void testBuildHoursFunctionFromClass() {
+    UnboundFunction expected = new HoursFunction();
+
+    HoursFunction.TimestampToHoursFunction tsToHoursFunc =
+        new HoursFunction.TimestampToHoursFunction();
+    checkBuildFunc(tsToHoursFunc, expected);
+
+    HoursFunction.TimestampNtzToHoursFunction tsNtzToHoursFunc =
+        new HoursFunction.TimestampNtzToHoursFunction();
+    checkBuildFunc(tsNtzToHoursFunc, expected);
+  }
+
+  @Test
+  public void testBuildBucketFunctionFromClass() {
+    UnboundFunction expected = new BucketFunction();
+
+    BucketFunction.BucketInt bucketDateFunc = new BucketFunction.BucketInt(DataTypes.DateType);
+    checkBuildFunc(bucketDateFunc, expected);
+
+    BucketFunction.BucketInt bucketIntFunc = new BucketFunction.BucketInt(DataTypes.IntegerType);
+    checkBuildFunc(bucketIntFunc, expected);
+
+    BucketFunction.BucketLong bucketLongFunc = new BucketFunction.BucketLong(DataTypes.LongType);
+    checkBuildFunc(bucketLongFunc, expected);
+
+    BucketFunction.BucketLong bucketTsFunc = new BucketFunction.BucketLong(DataTypes.TimestampType);
+    checkBuildFunc(bucketTsFunc, expected);
+
+    BucketFunction.BucketLong bucketTsNtzFunc =
+        new BucketFunction.BucketLong(DataTypes.TimestampNTZType);
+    checkBuildFunc(bucketTsNtzFunc, expected);
+
+    BucketFunction.BucketDecimal bucketDecimalFunc =
+        new BucketFunction.BucketDecimal(new DecimalType());
+    checkBuildFunc(bucketDecimalFunc, expected);
+
+    BucketFunction.BucketString bucketStringFunc = new BucketFunction.BucketString();
+    checkBuildFunc(bucketStringFunc, expected);
+
+    BucketFunction.BucketBinary bucketBinary = new BucketFunction.BucketBinary();
+    checkBuildFunc(bucketBinary, expected);
+  }
+
+  @Test
+  public void testBuildTruncateFunctionFromClass() {
+    UnboundFunction expected = new TruncateFunction();
+
+    TruncateFunction.TruncateTinyInt truncateTinyIntFunc = new TruncateFunction.TruncateTinyInt();
+    checkBuildFunc(truncateTinyIntFunc, expected);
+
+    TruncateFunction.TruncateSmallInt truncateSmallIntFunc =
+        new TruncateFunction.TruncateSmallInt();
+    checkBuildFunc(truncateSmallIntFunc, expected);
+
+    TruncateFunction.TruncateInt truncateIntFunc = new TruncateFunction.TruncateInt();
+    checkBuildFunc(truncateIntFunc, expected);
+
+    TruncateFunction.TruncateBigInt truncateBigIntFunc = new TruncateFunction.TruncateBigInt();
+    checkBuildFunc(truncateBigIntFunc, expected);
+
+    TruncateFunction.TruncateDecimal truncateDecimalFunc =
+        new TruncateFunction.TruncateDecimal(10, 9);
+    checkBuildFunc(truncateDecimalFunc, expected);
+
+    TruncateFunction.TruncateString truncateStringFunc = new TruncateFunction.TruncateString();
+    checkBuildFunc(truncateStringFunc, expected);
+
+    TruncateFunction.TruncateBinary truncateBinaryFunc = new TruncateFunction.TruncateBinary();
+    checkBuildFunc(truncateBinaryFunc, expected);
+  }
+
+  private void checkBuildFunc(ScalarFunction<?> function, UnboundFunction expected) {
+    UnboundFunction actual = SparkFunctions.loadFunctionByClass(function.getClass());
+
+    Assertions.assertThat(actual).isNotNull();
+    Assertions.assertThat(actual.name()).isEqualTo(expected.name());
+    Assertions.assertThat(actual.description()).isEqualTo(expected.description());
+  }
+}
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkScan.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkScan.java
index b3373e908b..78d169bf73 100644
--- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkScan.java
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkScan.java
@@ -20,10 +20,10 @@ package org.apache.iceberg.spark.source;
 
 import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.createPartitionedTable;
 import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.createUnpartitionedTable;
-import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.days;
-import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.hours;
-import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.months;
-import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.years;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToDayOrdinal;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToHourOrdinal;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToMonthOrdinal;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToYearOrdinal;
 import static org.apache.spark.sql.functions.date_add;
 import static org.apache.spark.sql.functions.expr;
 
@@ -120,7 +120,11 @@ public class TestSparkScan extends SparkTestBaseWithCatalog {
 
     YearsFunction.TimestampToYearsFunction function = new YearsFunction.TimestampToYearsFunction();
     UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts")));
-    Predicate predicate = new Predicate("=", expressions(udf, intLit(years("2017-11-22"))));
+    Predicate predicate =
+        new Predicate(
+            "=",
+            expressions(
+                udf, intLit(timestampStrToYearOrdinal("2017-11-22T00:00:00.000000+00:00"))));
     pushFilters(builder, predicate);
     Batch scan = builder.build().toBatch();
 
@@ -144,7 +148,11 @@ public class TestSparkScan extends SparkTestBaseWithCatalog {
 
     YearsFunction.TimestampToYearsFunction function = new YearsFunction.TimestampToYearsFunction();
     UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts")));
-    Predicate predicate = new Predicate("=", expressions(udf, intLit(years("2017-11-22"))));
+    Predicate predicate =
+        new Predicate(
+            "=",
+            expressions(
+                udf, intLit(timestampStrToYearOrdinal("2017-11-22T00:00:00.000000+00:00"))));
     pushFilters(builder, predicate);
     Batch scan = builder.build().toBatch();
 
@@ -169,7 +177,11 @@ public class TestSparkScan extends SparkTestBaseWithCatalog {
     MonthsFunction.TimestampToMonthsFunction function =
         new MonthsFunction.TimestampToMonthsFunction();
     UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts")));
-    Predicate predicate = new Predicate(">", expressions(udf, intLit(months("2017-11-22"))));
+    Predicate predicate =
+        new Predicate(
+            ">",
+            expressions(
+                udf, intLit(timestampStrToMonthOrdinal("2017-11-22T00:00:00.000000+00:00"))));
     pushFilters(builder, predicate);
     Batch scan = builder.build().toBatch();
 
@@ -194,7 +206,11 @@ public class TestSparkScan extends SparkTestBaseWithCatalog {
     MonthsFunction.TimestampToMonthsFunction function =
         new MonthsFunction.TimestampToMonthsFunction();
     UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts")));
-    Predicate predicate = new Predicate(">", expressions(udf, intLit(months("2017-11-22"))));
+    Predicate predicate =
+        new Predicate(
+            ">",
+            expressions(
+                udf, intLit(timestampStrToMonthOrdinal("2017-11-22T00:00:00.000000+00:00"))));
     pushFilters(builder, predicate);
     Batch scan = builder.build().toBatch();
 
@@ -218,7 +234,11 @@ public class TestSparkScan extends SparkTestBaseWithCatalog {
 
     DaysFunction.TimestampToDaysFunction function = new DaysFunction.TimestampToDaysFunction();
     UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts")));
-    Predicate predicate = new Predicate("<", expressions(udf, dateLit(days("2018-11-20"))));
+    Predicate predicate =
+        new Predicate(
+            "<",
+            expressions(
+                udf, dateLit(timestampStrToDayOrdinal("2018-11-20T00:00:00.000000+00:00"))));
     pushFilters(builder, predicate);
     Batch scan = builder.build().toBatch();
 
@@ -242,7 +262,11 @@ public class TestSparkScan extends SparkTestBaseWithCatalog {
 
     DaysFunction.TimestampToDaysFunction function = new DaysFunction.TimestampToDaysFunction();
    UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts")));
-    Predicate predicate = new Predicate("<", expressions(udf, dateLit(days("2018-11-20"))));
+    Predicate predicate =
+        new Predicate(
+            "<",
+            expressions(
+                udf, dateLit(timestampStrToDayOrdinal("2018-11-20T00:00:00.000000+00:00"))));
     pushFilters(builder, predicate);
     Batch scan = builder.build().toBatch();
 
@@ -267,7 +291,10 @@ public class TestSparkScan extends SparkTestBaseWithCatalog {
     HoursFunction.TimestampToHoursFunction function = new HoursFunction.TimestampToHoursFunction();
     UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts")));
     Predicate predicate =
-        new Predicate(">=", expressions(udf, intLit(hours("2017-11-22T06:02:09.243857+00:00"))));
+        new Predicate(
+            ">=",
+            expressions(
+                udf, intLit(timestampStrToHourOrdinal("2017-11-22T06:02:09.243857+00:00"))));
     pushFilters(builder, predicate);
     Batch scan = builder.build().toBatch();
 
@@ -292,7 +319,10 @@ public class TestSparkScan extends SparkTestBaseWithCatalog {
     HoursFunction.TimestampToHoursFunction function = new HoursFunction.TimestampToHoursFunction();
     UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts")));
     Predicate predicate =
-        new Predicate(">=", expressions(udf, intLit(hours("2017-11-22T06:02:09.243857+00:00"))));
+        new Predicate(
+            ">=",
+            expressions(
+                udf, intLit(timestampStrToHourOrdinal("2017-11-22T06:02:09.243857+00:00"))));
     pushFilters(builder, predicate);
     Batch scan = builder.build().toBatch();
 

