This is an automated email from the ASF dual-hosted git repository.
etudenhoefner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/main by this push:
new d23c4902eb Spark: Backport tests for struct aggregation pushdown to
3.3/3.4, cleanup assertion (#10333)
d23c4902eb is described below
commit d23c4902eb7ed319176bb7d74e04b6f2175e3593
Author: Amogh Jahagirdar <[email protected]>
AuthorDate: Tue May 14 01:10:43 2024 -0600
Spark: Backport tests for struct aggregation pushdown to 3.3/3.4, cleanup
assertion (#10333)
---
.../iceberg/spark/sql/TestAggregatePushDown.java | 122 +++++++++++++++++++++
.../iceberg/spark/sql/TestAggregatePushDown.java | 122 +++++++++++++++++++++
.../iceberg/spark/sql/TestAggregatePushDown.java | 6 +-
3 files changed, 247 insertions(+), 3 deletions(-)
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
index 37ae96a248..9ea1a563ef 100644
--- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
@@ -21,6 +21,7 @@ package org.apache.iceberg.spark.sql;
import java.math.BigDecimal;
import java.sql.Date;
import java.sql.Timestamp;
+import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
@@ -35,6 +36,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.spark.SparkCatalogTestBase;
import org.apache.iceberg.spark.SparkTestBase;
import org.apache.spark.sql.SparkSession;
+import org.assertj.core.api.Assertions;
import org.junit.After;
import org.junit.Assert;
import org.junit.BeforeClass;
@@ -470,6 +472,156 @@ public class TestAggregatePushDown extends SparkCatalogTestBase {
     Assert.assertFalse("max not pushed down for complex types", explainContainsPushDownAggregates);
   }
 
+  /**
+   * COUNT/MAX/MIN over a BIGINT field of a struct column. The NULL value is skipped, and each
+   * aggregate must appear in the EXPLAIN output.
+   */
+  @Test
+  public void testAggregationPushdownStructInteger() {
+    sql("CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:BIGINT>) USING iceberg", tableName);
+    sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName);
+    sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2))", tableName);
+    sql("INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", 3))", tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "struct_with_int.c1";
+    assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 3L, 2L);
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(struct_with_int.c1)",
+        "max(struct_with_int.c1)",
+        "min(struct_with_int.c1)");
+  }
+
+  /**
+   * Same data as {@link #testAggregationPushdownStructInteger()}, but the BIGINT field is nested
+   * four structs deep ({@code struct_with_int.c1.c2.c3.c4}).
+   */
+  @Test
+  public void testAggregationPushdownNestedStruct() {
+    sql(
+        "CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:STRUCT<c2:STRUCT<c3:STRUCT<c4:BIGINT>>>>) USING iceberg",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", NULL)))))",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", 2)))))",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", 3)))))",
+        tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "struct_with_int.c1.c2.c3.c4";
+
+    assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 3L, 2L);
+
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(struct_with_int.c1.c2.c3.c4)",
+        "max(struct_with_int.c1.c2.c3.c4)",
+        "min(struct_with_int.c1.c2.c3.c4)");
+  }
+
+  /**
+   * COUNT/MAX/MIN over a TIMESTAMP field of a struct column. The expected MAX and MIN are the
+   * instants 2023-01-30T22:23:23Z and 2023-01-30T22:22:22Z, given as epoch milliseconds.
+   */
+  @Test
+  public void testAggregationPushdownStructTimestamp() {
+    sql(
+        "CREATE TABLE %s (id BIGINT, struct_with_ts STRUCT<c1:TIMESTAMP>) USING iceberg",
+        tableName);
+    sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", timestamp('2023-01-30T22:22:22Z')))",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", timestamp('2023-01-30T22:23:23Z')))",
+        tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "struct_with_ts.c1";
+
+    assertAggregates(
+        sql(query, aggField, aggField, aggField, tableName),
+        2L,
+        new Timestamp(1675117403000L),
+        new Timestamp(1675117342000L));
+
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(struct_with_ts.c1)",
+        "max(struct_with_ts.c1)",
+        "min(struct_with_ts.c1)");
+  }
+
+  /**
+   * COUNT/MAX/MIN over {@code id}, the source column of a {@code bucket(8, id)} partition
+   * transform. The row with a NULL id is skipped.
+   */
+  @Test
+  public void testAggregationPushdownOnBucketedColumn() {
+    sql(
+        "CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:INT>) USING iceberg PARTITIONED BY (bucket(8, id))",
+        tableName);
+
+    sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName);
+    sql("INSERT INTO TABLE %s VALUES (null, named_struct(\"c1\", 2))", tableName);
+    sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 3))", tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "id";
+    assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 2L, 1L);
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(id)",
+        "max(id)",
+        "min(id)");
+  }
+
+  /**
+   * Asserts the first row of a {@code SELECT COUNT(x), MAX(x), MIN(x)} result.
+   *
+   * @param actual rows returned by {@code sql(...)}; only the first row is inspected
+   * @param expectedCount expected value of column 0 (COUNT)
+   * @param expectedMax expected value of column 1 (MAX)
+   * @param expectedMin expected value of column 2 (MIN)
+   */
+  private void assertAggregates(
+      List<Object[]> actual, Object expectedCount, Object expectedMax, Object expectedMin) {
+    Object actualCount = actual.get(0)[0];
+    Object actualMax = actual.get(0)[1];
+    Object actualMin = actual.get(0)[2];
+
+    Assertions.assertThat(actualCount)
+        .as("Expected and actual count should equal")
+        .isEqualTo(expectedCount);
+    Assertions.assertThat(actualMax)
+        .as("Expected and actual max should equal")
+        .isEqualTo(expectedMax);
+    Assertions.assertThat(actualMin)
+        .as("Expected and actual min should equal")
+        .isEqualTo(expectedMin);
+  }
+
+  /**
+   * Asserts that the lower-cased first cell of an EXPLAIN result contains every given fragment.
+   *
+   * @param explain rows returned by {@code sql("EXPLAIN ...")}
+   * @param expectedFragments lower-case plan fragments that must all be present
+   */
+  private void assertExplainContains(List<Object[]> explain, String... expectedFragments) {
+    String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT);
+    Arrays.stream(expectedFragments)
+        .forEach(
+            fragment ->
+                Assertions.assertThat(explainString)
+                    .as("Expected to find plan fragment in explain plan")
+                    .contains(fragment));
+  }
+
   @Test
   public void testAggregatePushDownInDeleteCopyOnWrite() {
     sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName);
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
index 37ae96a248..9ea1a563ef 100644
--- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
@@ -21,6 +21,7 @@ package org.apache.iceberg.spark.sql;
import java.math.BigDecimal;
import java.sql.Date;
import java.sql.Timestamp;
+import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
@@ -35,6 +36,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.spark.SparkCatalogTestBase;
import org.apache.iceberg.spark.SparkTestBase;
import org.apache.spark.sql.SparkSession;
+import org.assertj.core.api.Assertions;
import org.junit.After;
import org.junit.Assert;
import org.junit.BeforeClass;
@@ -470,6 +472,156 @@ public class TestAggregatePushDown extends SparkCatalogTestBase {
     Assert.assertFalse("max not pushed down for complex types", explainContainsPushDownAggregates);
   }
 
+  /**
+   * COUNT/MAX/MIN over a BIGINT field of a struct column. The NULL value is skipped, and each
+   * aggregate must appear in the EXPLAIN output.
+   */
+  @Test
+  public void testAggregationPushdownStructInteger() {
+    sql("CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:BIGINT>) USING iceberg", tableName);
+    sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName);
+    sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2))", tableName);
+    sql("INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", 3))", tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "struct_with_int.c1";
+    assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 3L, 2L);
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(struct_with_int.c1)",
+        "max(struct_with_int.c1)",
+        "min(struct_with_int.c1)");
+  }
+
+  /**
+   * Same data as {@link #testAggregationPushdownStructInteger()}, but the BIGINT field is nested
+   * four structs deep ({@code struct_with_int.c1.c2.c3.c4}).
+   */
+  @Test
+  public void testAggregationPushdownNestedStruct() {
+    sql(
+        "CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:STRUCT<c2:STRUCT<c3:STRUCT<c4:BIGINT>>>>) USING iceberg",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", NULL)))))",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", 2)))))",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", 3)))))",
+        tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "struct_with_int.c1.c2.c3.c4";
+
+    assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 3L, 2L);
+
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(struct_with_int.c1.c2.c3.c4)",
+        "max(struct_with_int.c1.c2.c3.c4)",
+        "min(struct_with_int.c1.c2.c3.c4)");
+  }
+
+  /**
+   * COUNT/MAX/MIN over a TIMESTAMP field of a struct column. The expected MAX and MIN are the
+   * instants 2023-01-30T22:23:23Z and 2023-01-30T22:22:22Z, given as epoch milliseconds.
+   */
+  @Test
+  public void testAggregationPushdownStructTimestamp() {
+    sql(
+        "CREATE TABLE %s (id BIGINT, struct_with_ts STRUCT<c1:TIMESTAMP>) USING iceberg",
+        tableName);
+    sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", timestamp('2023-01-30T22:22:22Z')))",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", timestamp('2023-01-30T22:23:23Z')))",
+        tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "struct_with_ts.c1";
+
+    assertAggregates(
+        sql(query, aggField, aggField, aggField, tableName),
+        2L,
+        new Timestamp(1675117403000L),
+        new Timestamp(1675117342000L));
+
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(struct_with_ts.c1)",
+        "max(struct_with_ts.c1)",
+        "min(struct_with_ts.c1)");
+  }
+
+  /**
+   * COUNT/MAX/MIN over {@code id}, the source column of a {@code bucket(8, id)} partition
+   * transform. The row with a NULL id is skipped.
+   */
+  @Test
+  public void testAggregationPushdownOnBucketedColumn() {
+    sql(
+        "CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:INT>) USING iceberg PARTITIONED BY (bucket(8, id))",
+        tableName);
+
+    sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", tableName);
+    sql("INSERT INTO TABLE %s VALUES (null, named_struct(\"c1\", 2))", tableName);
+    sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 3))", tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "id";
+    assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 2L, 1L);
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(id)",
+        "max(id)",
+        "min(id)");
+  }
+
+  /**
+   * Asserts the first row of a {@code SELECT COUNT(x), MAX(x), MIN(x)} result.
+   *
+   * @param actual rows returned by {@code sql(...)}; only the first row is inspected
+   * @param expectedCount expected value of column 0 (COUNT)
+   * @param expectedMax expected value of column 1 (MAX)
+   * @param expectedMin expected value of column 2 (MIN)
+   */
+  private void assertAggregates(
+      List<Object[]> actual, Object expectedCount, Object expectedMax, Object expectedMin) {
+    Object actualCount = actual.get(0)[0];
+    Object actualMax = actual.get(0)[1];
+    Object actualMin = actual.get(0)[2];
+
+    Assertions.assertThat(actualCount)
+        .as("Expected and actual count should equal")
+        .isEqualTo(expectedCount);
+    Assertions.assertThat(actualMax)
+        .as("Expected and actual max should equal")
+        .isEqualTo(expectedMax);
+    Assertions.assertThat(actualMin)
+        .as("Expected and actual min should equal")
+        .isEqualTo(expectedMin);
+  }
+
+  /**
+   * Asserts that the lower-cased first cell of an EXPLAIN result contains every given fragment.
+   *
+   * @param explain rows returned by {@code sql("EXPLAIN ...")}
+   * @param expectedFragments lower-case plan fragments that must all be present
+   */
+  private void assertExplainContains(List<Object[]> explain, String... expectedFragments) {
+    String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT);
+    Arrays.stream(expectedFragments)
+        .forEach(
+            fragment ->
+                Assertions.assertThat(explainString)
+                    .as("Expected to find plan fragment in explain plan")
+                    .contains(fragment));
+  }
+
   @Test
   public void testAggregatePushDownInDeleteCopyOnWrite() {
     sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName);
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
index 603b1a1cfb..7fdd5163f1 100644
--- a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
@@ -595,9 +595,9 @@ public class TestAggregatePushDown extends CatalogTestBase {
Arrays.stream(expectedFragments)
.forEach(
fragment ->
- Assertions.assertThat(explainString.contains(fragment))
- .isTrue()
- .as("Expected to find plan fragment in explain plan"));
+ Assertions.assertThat(explainString)
+ .as("Expected to find plan fragment in explain plan")
+ .contains(fragment));
}
@TestTemplate