This is an automated email from the ASF dual-hosted git repository.
blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new 3cb841927f Spark: Rule for converting StaticInvoke to ApplyFunctionExpression for V2 filter push down (#8088)
3cb841927f is described below
commit 3cb841927f4ed58c5d5a4cafb0a0b1a017939808
Author: Xianyang Liu <[email protected]>
AuthorDate: Wed Sep 13 07:21:36 2023 +0800
Spark: Rule for converting StaticInvoke to ApplyFunctionExpression for V2 filter push down (#8088)
---
.../extensions/IcebergSparkSessionExtensions.scala | 2 +
.../catalyst/optimizer/ReplaceStaticInvoke.scala | 93 ++++++
.../extensions/TestSystemFunctionPushDownDQL.java | 314 +++++++++++++++++++++
.../org/apache/iceberg/spark/source/PlanUtils.java | 99 +++++++
.../iceberg/spark/functions/SparkFunctions.java | 18 ++
.../spark/SystemFunctionPushDownHelper.java | 21 +-
.../spark/functions/TestSparkFunctions.java | 157 +++++++++++
.../apache/iceberg/spark/source/TestSparkScan.java | 54 +++-
8 files changed, 739 insertions(+), 19 deletions(-)
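For context on what this enables: once the rule is injected, a filter on an Iceberg system function can be pushed down to the scan instead of being evaluated row by row in Spark. A minimal sketch (the catalog and table names in `local.db.tbl` are illustrative, not part of this patch):

    // Before this rule, system.bucket(5, id) = 2 stays in the plan as a
    // StaticInvoke, which DataSource V2 push-down cannot translate; with the
    // rule, it becomes an ApplyFunctionExpression and reaches the Iceberg scan.
    val df = spark.sql("SELECT * FROM local.db.tbl WHERE system.bucket(5, id) = 2")
    df.explain(true) // pushed filters should now include bucket(5, id) = 2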
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
index 90fc5af18d..4322d07007 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.analysis.RewriteMergeIntoTable
 import org.apache.spark.sql.catalyst.analysis.RewriteUpdateTable
 import org.apache.spark.sql.catalyst.optimizer.ExtendedReplaceNullWithFalseInPredicate
 import org.apache.spark.sql.catalyst.optimizer.ExtendedSimplifyConditionalsInPredicate
+import org.apache.spark.sql.catalyst.optimizer.ReplaceStaticInvoke
 import org.apache.spark.sql.catalyst.parser.extensions.IcebergSparkSqlExtensionsParser
 import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy
 import org.apache.spark.sql.execution.datasources.v2.ExtendedV2Writes
@@ -58,6 +59,7 @@ class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
     // optimizer extensions
     extensions.injectOptimizerRule { _ => ExtendedSimplifyConditionalsInPredicate }
     extensions.injectOptimizerRule { _ => ExtendedReplaceNullWithFalseInPredicate }
+    extensions.injectOptimizerRule { _ => ReplaceStaticInvoke }
     // pre-CBO rules run only once and the order of the rules is important
     // - dynamic filters should be added before replacing commands with rewrite plans
     // - scans must be planned before building writes
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala
new file mode 100644
index 0000000000..1f0e164d84
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.iceberg.spark.functions.SparkFunctions
+import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression
+import org.apache.spark.sql.catalyst.expressions.BinaryComparison
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
+import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.BINARY_COMPARISON
+import org.apache.spark.sql.catalyst.trees.TreePattern.FILTER
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
+import org.apache.spark.sql.types.StructField
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Spark analyzes calls to Iceberg system functions into {@link StaticInvoke}, which cannot be
+ * pushed down to the data source. This rule replaces {@link StaticInvoke} with
+ * {@link ApplyFunctionExpression} for Iceberg system functions in a filter condition.
+ */
+object ReplaceStaticInvoke extends Rule[LogicalPlan] {
+
+ override def apply(plan: LogicalPlan): LogicalPlan =
+    plan.transformWithPruning (_.containsAllPatterns(BINARY_COMPARISON, FILTER)) {
+      case filter @ Filter(condition, _) =>
+        val newCondition = condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) {
+          case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable =>
+            c.withNewChildren(Seq(replaceStaticInvoke(left), right))
+
+          case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable =>
+            c.withNewChildren(Seq(left, replaceStaticInvoke(right)))
+        }
+
+ if (newCondition fastEquals condition) {
+ filter
+ } else {
+ filter.copy(condition = newCondition)
+ }
+ }
+
+ private def replaceStaticInvoke(invoke: StaticInvoke): Expression = {
+    // Adapted from `resolveV2Function` in org.apache.spark.sql.catalyst.analysis.ResolveFunctions
+ val unbound = SparkFunctions.loadFunctionByClass(invoke.staticObject)
+ if (unbound == null) {
+ return invoke
+ }
+
+ val inputType = StructType(invoke.arguments.zipWithIndex.map {
+ case (exp, pos) => StructField(s"_$pos", exp.dataType, exp.nullable)
+ })
+
+ val bound = try {
+ unbound.bind(inputType)
+ } catch {
+ case _: Exception =>
+ return invoke
+ }
+
+ if (bound.inputTypes().length != invoke.arguments.length) {
+ return invoke
+ }
+
+ bound match {
+ case scalarFunc: ScalarFunction[_] =>
+ ApplyFunctionExpression(scalarFunc, invoke.arguments)
+ case _ => invoke
+ }
+ }
+
+ @inline
+ private def canReplace(invoke: StaticInvoke): Boolean = {
+ invoke.functionName == ScalarFunction.MAGIC_METHOD_NAME && !invoke.foldable
+ }
+}
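A rough sketch of the rule in isolation (assumes the `spark` session and the illustrative table from the earlier example; `47` stands in for a year ordinal):

    import org.apache.spark.sql.catalyst.optimizer.ReplaceStaticInvoke

    val plan = spark.sql("SELECT * FROM local.db.tbl WHERE system.years(ts) = 47")
      .queryExecution.optimizedPlan
    // Rule[LogicalPlan] is callable, so the rewrite can be applied directly.
    val rewritten = ReplaceStaticInvoke(plan)
    // The rewritten filter condition should contain no StaticInvoke nodes.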
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java
new file mode 100644
index 0000000000..7f2857cce0
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.extensions;
+
+import static org.apache.iceberg.expressions.Expressions.bucket;
+import static org.apache.iceberg.expressions.Expressions.day;
+import static org.apache.iceberg.expressions.Expressions.equal;
+import static org.apache.iceberg.expressions.Expressions.greaterThan;
+import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual;
+import static org.apache.iceberg.expressions.Expressions.hour;
+import static org.apache.iceberg.expressions.Expressions.lessThan;
+import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual;
+import static org.apache.iceberg.expressions.Expressions.month;
+import static org.apache.iceberg.expressions.Expressions.notEqual;
+import static org.apache.iceberg.expressions.Expressions.truncate;
+import static org.apache.iceberg.expressions.Expressions.year;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.STRUCT;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.createPartitionedTable;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.createUnpartitionedTable;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToDayOrdinal;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToHourOrdinal;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToMonthOrdinal;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToYearOrdinal;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.expressions.ExpressionUtil;
+import org.apache.iceberg.spark.SparkCatalogConfig;
+import org.apache.iceberg.spark.source.PlanUtils;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.assertj.core.api.Assertions;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runners.Parameterized;
+
+public class TestSystemFunctionPushDownDQL extends SparkExtensionsTestBase {
+ public TestSystemFunctionPushDownDQL(
+ String catalogName, String implementation, Map<String, String> config) {
+ super(catalogName, implementation, config);
+ }
+
+  @Parameterized.Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
+ public static Object[][] parameters() {
+ return new Object[][] {
+ {
+ SparkCatalogConfig.HIVE.catalogName(),
+ SparkCatalogConfig.HIVE.implementation(),
+ SparkCatalogConfig.HIVE.properties(),
+ },
+ };
+ }
+
+ @Before
+ public void before() {
+ sql("USE %s", catalogName);
+ }
+
+ @After
+ public void removeTables() {
+ sql("DROP TABLE IF EXISTS %s", tableName);
+ }
+
+ @Test
+ public void testYearsFunctionOnUnpartitionedTable() {
+ createUnpartitionedTable(spark, tableName);
+ testYearsFunction(false);
+ }
+
+ @Test
+ public void testYearsFunctionOnPartitionedTable() {
+ createPartitionedTable(spark, tableName, "years(ts)");
+ testYearsFunction(true);
+ }
+
+ private void testYearsFunction(boolean partitioned) {
+    int targetYears = timestampStrToYearOrdinal("2017-11-22T00:00:00.000000+00:00");
+    String query =
+        String.format(
+            "SELECT * FROM %s WHERE system.years(ts) = %s ORDER BY id", tableName, targetYears);
+
+ Dataset<Row> df = spark.sql(query);
+ LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();
+
+ checkExpressions(optimizedPlan, partitioned, "years");
+ checkPushedFilters(optimizedPlan, equal(year("ts"), targetYears));
+
+ List<Object[]> actual = rowsToJava(df.collectAsList());
+ Assertions.assertThat(actual.size()).isEqualTo(5);
+ }
+
+ @Test
+ public void testMonthsFunctionOnUnpartitionedTable() {
+ createUnpartitionedTable(spark, tableName);
+ testMonthsFunction(false);
+ }
+
+ @Test
+ public void testMonthsFunctionOnPartitionedTable() {
+ createPartitionedTable(spark, tableName, "months(ts)");
+ testMonthsFunction(true);
+ }
+
+ private void testMonthsFunction(boolean partitioned) {
+    int targetMonths = timestampStrToMonthOrdinal("2017-11-22T00:00:00.000000+00:00");
+    String query =
+        String.format(
+            "SELECT * FROM %s WHERE system.months(ts) > %s ORDER BY id", tableName, targetMonths);
+
+ Dataset<Row> df = spark.sql(query);
+ LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();
+
+ checkExpressions(optimizedPlan, partitioned, "months");
+ checkPushedFilters(optimizedPlan, greaterThan(month("ts"), targetMonths));
+
+ List<Object[]> actual = rowsToJava(df.collectAsList());
+ Assertions.assertThat(actual.size()).isEqualTo(5);
+ }
+
+ @Test
+ public void testDaysFunctionOnUnpartitionedTable() {
+ createUnpartitionedTable(spark, tableName);
+ testDaysFunction(false);
+ }
+
+ @Test
+ public void testDaysFunctionOnPartitionedTable() {
+ createPartitionedTable(spark, tableName, "days(ts)");
+ testDaysFunction(true);
+ }
+
+ private void testDaysFunction(boolean partitioned) {
+ String timestamp = "2018-11-20T00:00:00.000000+00:00";
+ int targetDays = timestampStrToDayOrdinal(timestamp);
+ String query =
+ String.format(
+ "SELECT * FROM %s WHERE system.days(ts) < date('%s') ORDER BY id",
+ tableName, timestamp);
+
+ Dataset<Row> df = spark.sql(query);
+ LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();
+
+ checkExpressions(optimizedPlan, partitioned, "days");
+ checkPushedFilters(optimizedPlan, lessThan(day("ts"), targetDays));
+
+ List<Object[]> actual = rowsToJava(df.collectAsList());
+ Assertions.assertThat(actual.size()).isEqualTo(5);
+ }
+
+ @Test
+ public void testHoursFunctionOnUnpartitionedTable() {
+ createUnpartitionedTable(spark, tableName);
+ testHoursFunction(false);
+ }
+
+ @Test
+ public void testHoursFunctionOnPartitionedTable() {
+ createPartitionedTable(spark, tableName, "hours(ts)");
+ testHoursFunction(true);
+ }
+
+ private void testHoursFunction(boolean partitioned) {
+    int targetHours = timestampStrToHourOrdinal("2017-11-22T06:02:09.243857+00:00");
+    String query =
+        String.format(
+            "SELECT * FROM %s WHERE system.hours(ts) >= %s ORDER BY id", tableName, targetHours);
+
+ Dataset<Row> df = spark.sql(query);
+ LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();
+
+ checkExpressions(optimizedPlan, partitioned, "hours");
+    checkPushedFilters(optimizedPlan, greaterThanOrEqual(hour("ts"), targetHours));
+
+ List<Object[]> actual = rowsToJava(df.collectAsList());
+ Assertions.assertThat(actual.size()).isEqualTo(8);
+ }
+
+ @Test
+ public void testBucketLongFunctionOnUnpartitionedTable() {
+ createUnpartitionedTable(spark, tableName);
+ testBucketLongFunction(false);
+ }
+
+ @Test
+ public void testBucketLongFunctionOnPartitionedTable() {
+ createPartitionedTable(spark, tableName, "bucket(5, id)");
+ testBucketLongFunction(true);
+ }
+
+ private void testBucketLongFunction(boolean partitioned) {
+ int target = 2;
+ String query =
+ String.format(
+ "SELECT * FROM %s WHERE system.bucket(5, id) <= %s ORDER BY id",
tableName, target);
+
+ Dataset<Row> df = spark.sql(query);
+ LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();
+
+ checkExpressions(optimizedPlan, partitioned, "bucket");
+    checkPushedFilters(optimizedPlan, lessThanOrEqual(bucket("id", 5), target));
+
+ List<Object[]> actual = rowsToJava(df.collectAsList());
+ Assertions.assertThat(actual.size()).isEqualTo(5);
+ }
+
+ @Test
+ public void testBucketStringFunctionOnUnpartitionedTable() {
+ createUnpartitionedTable(spark, tableName);
+ testBucketStringFunction(false);
+ }
+
+ @Test
+ public void testBucketStringFunctionOnPartitionedTable() {
+ createPartitionedTable(spark, tableName, "bucket(5, data)");
+ testBucketStringFunction(true);
+ }
+
+ private void testBucketStringFunction(boolean partitioned) {
+ int target = 2;
+ String query =
+ String.format(
+ "SELECT * FROM %s WHERE system.bucket(5, data) != %s ORDER BY id",
tableName, target);
+
+ Dataset<Row> df = spark.sql(query);
+ LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();
+
+ checkExpressions(optimizedPlan, partitioned, "bucket");
+ checkPushedFilters(optimizedPlan, notEqual(bucket("data", 5), target));
+
+ List<Object[]> actual = rowsToJava(df.collectAsList());
+ Assertions.assertThat(actual.size()).isEqualTo(8);
+ }
+
+ @Test
+ public void testTruncateFunctionOnUnpartitionedTable() {
+ createUnpartitionedTable(spark, tableName);
+ testTruncateFunction(false);
+ }
+
+ @Test
+ public void testTruncateFunctionOnPartitionedTable() {
+ createPartitionedTable(spark, tableName, "truncate(4, data)");
+ testTruncateFunction(true);
+ }
+
+ private void testTruncateFunction(boolean partitioned) {
+ String target = "data";
+ String query =
+ String.format(
+ "SELECT * FROM %s WHERE system.truncate(4, data) = '%s' ORDER BY
id",
+ tableName, target);
+
+ Dataset<Row> df = spark.sql(query);
+ LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();
+
+ checkExpressions(optimizedPlan, partitioned, "truncate");
+ checkPushedFilters(optimizedPlan, equal(truncate("data", 4), target));
+
+ List<Object[]> actual = rowsToJava(df.collectAsList());
+ Assertions.assertThat(actual.size()).isEqualTo(5);
+ }
+
+ private void checkExpressions(
+      LogicalPlan optimizedPlan, boolean partitioned, String expectedFunctionName) {
+ List<Expression> staticInvokes =
+ PlanUtils.collectSparkExpressions(
+ optimizedPlan, expression -> expression instanceof StaticInvoke);
+ Assertions.assertThat(staticInvokes).isEmpty();
+
+ List<Expression> applyExpressions =
+ PlanUtils.collectSparkExpressions(
+            optimizedPlan, expression -> expression instanceof ApplyFunctionExpression);
+
+ if (partitioned) {
+ Assertions.assertThat(applyExpressions).isEmpty();
+ } else {
+ Assertions.assertThat(applyExpressions.size()).isEqualTo(1);
+      ApplyFunctionExpression expression = (ApplyFunctionExpression) applyExpressions.get(0);
+ Assertions.assertThat(expression.name()).isEqualTo(expectedFunctionName);
+ }
+ }
+
+ private void checkPushedFilters(
+      LogicalPlan optimizedPlan, org.apache.iceberg.expressions.Expression expected) {
+ List<org.apache.iceberg.expressions.Expression> pushedFilters =
+ PlanUtils.collectPushDownFilters(optimizedPlan);
+ Assertions.assertThat(pushedFilters.size()).isEqualTo(1);
+ org.apache.iceberg.expressions.Expression actual = pushedFilters.get(0);
+    Assertions.assertThat(ExpressionUtil.equivalent(expected, actual, STRUCT, true))
+ .as("Pushed filter should match")
+ .isTrue();
+ }
+}
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/source/PlanUtils.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/source/PlanUtils.java
new file mode 100644
index 0000000000..148717e142
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/source/PlanUtils.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.source;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation;
+import scala.PartialFunction;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
+
+public class PlanUtils {
+ private PlanUtils() {}
+
+  public static List<org.apache.iceberg.expressions.Expression> collectPushDownFilters(
+      LogicalPlan logicalPlan) {
+    return JavaConverters.asJavaCollection(logicalPlan.collectLeaves()).stream()
+ .flatMap(
+ plan -> {
+ if (!(plan instanceof DataSourceV2ScanRelation)) {
+ return Stream.empty();
+ }
+
+              DataSourceV2ScanRelation scanRelation = (DataSourceV2ScanRelation) plan;
+ if (!(scanRelation.scan() instanceof SparkBatchQueryScan)) {
+ return Stream.empty();
+ }
+
+              SparkBatchQueryScan batchQueryScan = (SparkBatchQueryScan) scanRelation.scan();
+ return batchQueryScan.filterExpressions().stream();
+ })
+ .collect(Collectors.toList());
+ }
+
+ public static List<Expression> collectSparkExpressions(
+ LogicalPlan logicalPlan, Predicate<Expression> predicate) {
+ Seq<List<Expression>> list =
+ logicalPlan.collect(
+ new PartialFunction<LogicalPlan, List<Expression>>() {
+
+ @Override
+ public List<Expression> apply(LogicalPlan plan) {
+            return JavaConverters.asJavaCollection(plan.expressions()).stream()
+                .flatMap(expr -> collectSparkExpressions(expr, predicate).stream())
+ .collect(Collectors.toList());
+ }
+
+ @Override
+ public boolean isDefinedAt(LogicalPlan plan) {
+ return true;
+ }
+ });
+
+ return JavaConverters.asJavaCollection(list).stream()
+ .flatMap(Collection::stream)
+ .collect(Collectors.toList());
+ }
+
+ private static List<Expression> collectSparkExpressions(
+ Expression expression, Predicate<Expression> predicate) {
+ Seq<Expression> list =
+ expression.collect(
+ new PartialFunction<Expression, Expression>() {
+ @Override
+ public Expression apply(Expression expr) {
+ return expr;
+ }
+
+ @Override
+ public boolean isDefinedAt(Expression expr) {
+ return predicate.test(expr);
+ }
+ });
+
+ return Lists.newArrayList(JavaConverters.asJavaCollection(list));
+ }
+}
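A sketch of how this helper is exercised (mirroring checkPushedFilters in the test above; the query and day ordinal are illustrative):

    import org.apache.iceberg.spark.source.PlanUtils

    val optimized = spark.sql("SELECT * FROM local.db.tbl WHERE system.days(ts) < 17855")
      .queryExecution.optimizedPlan
    // Collects Iceberg filter expressions from each SparkBatchQueryScan leaf.
    val pushed = PlanUtils.collectPushDownFilters(optimized)
    assert(pushed.size() == 1)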
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java
index d14bd45831..6d9cadec57 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/SparkFunctions.java
@@ -39,6 +39,15 @@ public class SparkFunctions {
"bucket", new BucketFunction(),
"truncate", new TruncateFunction());
+ private static final Map<Class<?>, UnboundFunction> CLASS_TO_FUNCTIONS =
+ ImmutableMap.of(
+ YearsFunction.class, new YearsFunction(),
+ MonthsFunction.class, new MonthsFunction(),
+ DaysFunction.class, new DaysFunction(),
+ HoursFunction.class, new HoursFunction(),
+ BucketFunction.class, new BucketFunction(),
+ TruncateFunction.class, new TruncateFunction());
+
   private static final List<String> FUNCTION_NAMES = ImmutableList.copyOf(FUNCTIONS.keySet());
   // Functions that are added to all Iceberg catalogs should be accessed with the `system`
@@ -54,4 +63,13 @@ public class SparkFunctions {
     // function resolution is case-insensitive to match the existing Spark behavior for functions
return FUNCTIONS.get(name.toLowerCase(Locale.ROOT));
}
+
+ public static UnboundFunction loadFunctionByClass(Class<?> functionClass) {
+ Class<?> declaringClass = functionClass.getDeclaringClass();
+ if (declaringClass == null) {
+ return null;
+ }
+
+ return CLASS_TO_FUNCTIONS.get(declaringClass);
+ }
}
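A sketch of the new lookup: the rule passes `invoke.staticObject`, the bound implementation class, and gets back the declaring unbound function (the expected name below follows the test assertions):

    import org.apache.iceberg.spark.functions.{SparkFunctions, YearsFunction}

    val unbound = SparkFunctions.loadFunctionByClass(
      classOf[YearsFunction.TimestampToYearsFunction])
    // Returns null for classes that are not Iceberg function implementations.
    assert(unbound != null && unbound.name() == "years")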
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SystemFunctionPushDownHelper.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SystemFunctionPushDownHelper.java
index 9258bb4f0e..059325e02a 100644
--- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SystemFunctionPushDownHelper.java
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/SystemFunctionPushDownHelper.java
@@ -18,10 +18,17 @@
*/
package org.apache.iceberg.spark;
+import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.DateTimeUtil;
import org.apache.spark.sql.SparkSession;
public class SystemFunctionPushDownHelper {
+ public static final Types.StructType STRUCT =
+ Types.StructType.of(
+ Types.NestedField.optional(1, "id", Types.LongType.get()),
+ Types.NestedField.optional(2, "ts", Types.TimestampType.withZone()),
+ Types.NestedField.optional(3, "data", Types.StringType.get()));
+
private SystemFunctionPushDownHelper() {}
   public static void createUnpartitionedTable(SparkSession spark, String tableName) {
@@ -98,19 +105,19 @@ public class SystemFunctionPushDownHelper {
"(9, CAST('2018-12-21T15:02:15.230570+00:00' AS TIMESTAMP),
'material-9')");
}
-  public static int years(String date) {
-    return DateTimeUtil.daysToYears(DateTimeUtil.isoDateToDays(date));
+  public static int timestampStrToYearOrdinal(String timestamp) {
+    return DateTimeUtil.microsToYears(DateTimeUtil.isoTimestamptzToMicros(timestamp));
   }
-  public static int months(String date) {
-    return DateTimeUtil.daysToMonths(DateTimeUtil.isoDateToDays(date));
+  public static int timestampStrToMonthOrdinal(String timestamp) {
+    return DateTimeUtil.microsToMonths(DateTimeUtil.isoTimestamptzToMicros(timestamp));
   }
-  public static int days(String date) {
-    return DateTimeUtil.isoDateToDays(date);
+  public static int timestampStrToDayOrdinal(String timestamp) {
+    return DateTimeUtil.microsToDays(DateTimeUtil.isoTimestamptzToMicros(timestamp));
   }
-  public static int hours(String timestamp) {
+  public static int timestampStrToHourOrdinal(String timestamp) {
     return DateTimeUtil.microsToHours(DateTimeUtil.isoTimestamptzToMicros(timestamp));
   }
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/functions/TestSparkFunctions.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/functions/TestSparkFunctions.java
new file mode 100644
index 0000000000..34308e77d2
--- /dev/null
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/functions/TestSparkFunctions.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.functions;
+
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.DecimalType;
+import org.assertj.core.api.Assertions;
+import org.junit.Test;
+
+public class TestSparkFunctions {
+
+ @Test
+ public void testBuildYearsFunctionFromClass() {
+ UnboundFunction expected = new YearsFunction();
+
+    YearsFunction.DateToYearsFunction dateToYearsFunc = new YearsFunction.DateToYearsFunction();
+ checkBuildFunc(dateToYearsFunc, expected);
+
+ YearsFunction.TimestampToYearsFunction tsToYearsFunc =
+ new YearsFunction.TimestampToYearsFunction();
+ checkBuildFunc(tsToYearsFunc, expected);
+
+ YearsFunction.TimestampNtzToYearsFunction tsNtzToYearsFunc =
+ new YearsFunction.TimestampNtzToYearsFunction();
+ checkBuildFunc(tsNtzToYearsFunc, expected);
+ }
+
+ @Test
+ public void testBuildMonthsFunctionFromClass() {
+ UnboundFunction expected = new MonthsFunction();
+
+ MonthsFunction.DateToMonthsFunction dateToMonthsFunc =
+ new MonthsFunction.DateToMonthsFunction();
+ checkBuildFunc(dateToMonthsFunc, expected);
+
+ MonthsFunction.TimestampToMonthsFunction tsToMonthsFunc =
+ new MonthsFunction.TimestampToMonthsFunction();
+ checkBuildFunc(tsToMonthsFunc, expected);
+
+ MonthsFunction.TimestampNtzToMonthsFunction tsNtzToMonthsFunc =
+ new MonthsFunction.TimestampNtzToMonthsFunction();
+ checkBuildFunc(tsNtzToMonthsFunc, expected);
+ }
+
+ @Test
+ public void testBuildDaysFunctionFromClass() {
+ UnboundFunction expected = new DaysFunction();
+
+    DaysFunction.DateToDaysFunction dateToDaysFunc = new DaysFunction.DateToDaysFunction();
+    checkBuildFunc(dateToDaysFunc, expected);
+
+    DaysFunction.TimestampToDaysFunction tsToDaysFunc = new DaysFunction.TimestampToDaysFunction();
+ checkBuildFunc(tsToDaysFunc, expected);
+
+ DaysFunction.TimestampNtzToDaysFunction tsNtzToDaysFunc =
+ new DaysFunction.TimestampNtzToDaysFunction();
+ checkBuildFunc(tsNtzToDaysFunc, expected);
+ }
+
+ @Test
+ public void testBuildHoursFunctionFromClass() {
+ UnboundFunction expected = new HoursFunction();
+
+ HoursFunction.TimestampToHoursFunction tsToHoursFunc =
+ new HoursFunction.TimestampToHoursFunction();
+ checkBuildFunc(tsToHoursFunc, expected);
+
+ HoursFunction.TimestampNtzToHoursFunction tsNtzToHoursFunc =
+ new HoursFunction.TimestampNtzToHoursFunction();
+ checkBuildFunc(tsNtzToHoursFunc, expected);
+ }
+
+ @Test
+ public void testBuildBucketFunctionFromClass() {
+ UnboundFunction expected = new BucketFunction();
+
+    BucketFunction.BucketInt bucketDateFunc = new BucketFunction.BucketInt(DataTypes.DateType);
+    checkBuildFunc(bucketDateFunc, expected);
+
+    BucketFunction.BucketInt bucketIntFunc = new BucketFunction.BucketInt(DataTypes.IntegerType);
+    checkBuildFunc(bucketIntFunc, expected);
+
+    BucketFunction.BucketLong bucketLongFunc = new BucketFunction.BucketLong(DataTypes.LongType);
+    checkBuildFunc(bucketLongFunc, expected);
+
+    BucketFunction.BucketLong bucketTsFunc = new BucketFunction.BucketLong(DataTypes.TimestampType);
+    checkBuildFunc(bucketTsFunc, expected);
+
+    BucketFunction.BucketLong bucketTsNtzFunc =
+        new BucketFunction.BucketLong(DataTypes.TimestampNTZType);
+    checkBuildFunc(bucketTsNtzFunc, expected);
+
+    BucketFunction.BucketDecimal bucketDecimalFunc =
+        new BucketFunction.BucketDecimal(new DecimalType());
+    checkBuildFunc(bucketDecimalFunc, expected);
+
+    BucketFunction.BucketString bucketStringFunc = new BucketFunction.BucketString();
+    checkBuildFunc(bucketStringFunc, expected);
+
+    BucketFunction.BucketBinary bucketBinary = new BucketFunction.BucketBinary();
+    checkBuildFunc(bucketBinary, expected);
+ }
+
+ @Test
+ public void testBuildTruncateFunctionFromClass() {
+ UnboundFunction expected = new TruncateFunction();
+
+    TruncateFunction.TruncateTinyInt truncateTinyIntFunc = new TruncateFunction.TruncateTinyInt();
+    checkBuildFunc(truncateTinyIntFunc, expected);
+
+    TruncateFunction.TruncateSmallInt truncateSmallIntFunc =
+        new TruncateFunction.TruncateSmallInt();
+    checkBuildFunc(truncateSmallIntFunc, expected);
+
+    TruncateFunction.TruncateInt truncateIntFunc = new TruncateFunction.TruncateInt();
+    checkBuildFunc(truncateIntFunc, expected);
+
+    TruncateFunction.TruncateBigInt truncateBigIntFunc = new TruncateFunction.TruncateBigInt();
+    checkBuildFunc(truncateBigIntFunc, expected);
+
+    TruncateFunction.TruncateDecimal truncateDecimalFunc =
+        new TruncateFunction.TruncateDecimal(10, 9);
+    checkBuildFunc(truncateDecimalFunc, expected);
+
+    TruncateFunction.TruncateString truncateStringFunc = new TruncateFunction.TruncateString();
+    checkBuildFunc(truncateStringFunc, expected);
+
+    TruncateFunction.TruncateBinary truncateBinaryFunc = new TruncateFunction.TruncateBinary();
+    checkBuildFunc(truncateBinaryFunc, expected);
+ }
+
+  private void checkBuildFunc(ScalarFunction<?> function, UnboundFunction expected) {
+    UnboundFunction actual = SparkFunctions.loadFunctionByClass(function.getClass());
+
+    Assertions.assertThat(actual).isNotNull();
+    Assertions.assertThat(actual.name()).isEqualTo(expected.name());
+    Assertions.assertThat(actual.description()).isEqualTo(expected.description());
+ }
+}
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkScan.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkScan.java
index b3373e908b..78d169bf73 100644
--- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkScan.java
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkScan.java
@@ -20,10 +20,10 @@ package org.apache.iceberg.spark.source;
 import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.createPartitionedTable;
 import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.createUnpartitionedTable;
-import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.days;
-import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.hours;
-import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.months;
-import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.years;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToDayOrdinal;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToHourOrdinal;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToMonthOrdinal;
+import static org.apache.iceberg.spark.SystemFunctionPushDownHelper.timestampStrToYearOrdinal;
import static org.apache.spark.sql.functions.date_add;
import static org.apache.spark.sql.functions.expr;
@@ -120,7 +120,11 @@ public class TestSparkScan extends SparkTestBaseWithCatalog {
     YearsFunction.TimestampToYearsFunction function = new YearsFunction.TimestampToYearsFunction();
     UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts")));
-    Predicate predicate = new Predicate("=", expressions(udf, intLit(years("2017-11-22"))));
+    Predicate predicate =
+        new Predicate(
+            "=",
+            expressions(
+                udf, intLit(timestampStrToYearOrdinal("2017-11-22T00:00:00.000000+00:00"))));
pushFilters(builder, predicate);
Batch scan = builder.build().toBatch();
@@ -144,7 +148,11 @@ public class TestSparkScan extends SparkTestBaseWithCatalog {
     YearsFunction.TimestampToYearsFunction function = new YearsFunction.TimestampToYearsFunction();
     UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts")));
-    Predicate predicate = new Predicate("=", expressions(udf, intLit(years("2017-11-22"))));
+    Predicate predicate =
+        new Predicate(
+            "=",
+            expressions(
+                udf, intLit(timestampStrToYearOrdinal("2017-11-22T00:00:00.000000+00:00"))));
pushFilters(builder, predicate);
Batch scan = builder.build().toBatch();
@@ -169,7 +177,11 @@ public class TestSparkScan extends SparkTestBaseWithCatalog {
     MonthsFunction.TimestampToMonthsFunction function =
         new MonthsFunction.TimestampToMonthsFunction();
     UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts")));
-    Predicate predicate = new Predicate(">", expressions(udf, intLit(months("2017-11-22"))));
+    Predicate predicate =
+        new Predicate(
+            ">",
+            expressions(
+                udf, intLit(timestampStrToMonthOrdinal("2017-11-22T00:00:00.000000+00:00"))));
pushFilters(builder, predicate);
Batch scan = builder.build().toBatch();
@@ -194,7 +206,11 @@ public class TestSparkScan extends SparkTestBaseWithCatalog {
     MonthsFunction.TimestampToMonthsFunction function =
         new MonthsFunction.TimestampToMonthsFunction();
     UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts")));
-    Predicate predicate = new Predicate(">", expressions(udf, intLit(months("2017-11-22"))));
+    Predicate predicate =
+        new Predicate(
+            ">",
+            expressions(
+                udf, intLit(timestampStrToMonthOrdinal("2017-11-22T00:00:00.000000+00:00"))));
pushFilters(builder, predicate);
Batch scan = builder.build().toBatch();
@@ -218,7 +234,11 @@ public class TestSparkScan extends SparkTestBaseWithCatalog {
     DaysFunction.TimestampToDaysFunction function = new DaysFunction.TimestampToDaysFunction();
     UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts")));
-    Predicate predicate = new Predicate("<", expressions(udf, dateLit(days("2018-11-20"))));
+    Predicate predicate =
+        new Predicate(
+            "<",
+            expressions(
+                udf, dateLit(timestampStrToDayOrdinal("2018-11-20T00:00:00.000000+00:00"))));
pushFilters(builder, predicate);
Batch scan = builder.build().toBatch();
@@ -242,7 +262,11 @@ public class TestSparkScan extends SparkTestBaseWithCatalog {
     DaysFunction.TimestampToDaysFunction function = new DaysFunction.TimestampToDaysFunction();
     UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts")));
-    Predicate predicate = new Predicate("<", expressions(udf, dateLit(days("2018-11-20"))));
+    Predicate predicate =
+        new Predicate(
+            "<",
+            expressions(
+                udf, dateLit(timestampStrToDayOrdinal("2018-11-20T00:00:00.000000+00:00"))));
pushFilters(builder, predicate);
Batch scan = builder.build().toBatch();
@@ -267,7 +291,10 @@ public class TestSparkScan extends SparkTestBaseWithCatalog {
     HoursFunction.TimestampToHoursFunction function = new HoursFunction.TimestampToHoursFunction();
     UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts")));
     Predicate predicate =
-        new Predicate(">=", expressions(udf, intLit(hours("2017-11-22T06:02:09.243857+00:00"))));
+        new Predicate(
+            ">=",
+            expressions(
+                udf, intLit(timestampStrToHourOrdinal("2017-11-22T06:02:09.243857+00:00"))));
pushFilters(builder, predicate);
Batch scan = builder.build().toBatch();
@@ -292,7 +319,10 @@ public class TestSparkScan extends SparkTestBaseWithCatalog {
     HoursFunction.TimestampToHoursFunction function = new HoursFunction.TimestampToHoursFunction();
     UserDefinedScalarFunc udf = toUDF(function, expressions(fieldRef("ts")));
     Predicate predicate =
-        new Predicate(">=", expressions(udf, intLit(hours("2017-11-22T06:02:09.243857+00:00"))));
+        new Predicate(
+            ">=",
+            expressions(
+                udf, intLit(timestampStrToHourOrdinal("2017-11-22T06:02:09.243857+00:00"))));
pushFilters(builder, predicate);
Batch scan = builder.build().toBatch();