This is an automated email from the ASF dual-hosted git repository.

etudenhoefner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/main by this push:
     new d23c4902eb Spark: Backport tests for struct aggregation pushdown to 3.3/3.4, cleanup assertion (#10333)
d23c4902eb is described below

commit d23c4902eb7ed319176bb7d74e04b6f2175e3593
Author: Amogh Jahagirdar <[email protected]>
AuthorDate: Tue May 14 01:10:43 2024 -0600

    Spark: Backport tests for struct aggregation pushdown to 3.3/3.4, cleanup assertion (#10333)
---
 .../iceberg/spark/sql/TestAggregatePushDown.java   | 122 +++++++++++++++++++++
 .../iceberg/spark/sql/TestAggregatePushDown.java   | 122 +++++++++++++++++++++
 .../iceberg/spark/sql/TestAggregatePushDown.java   |   6 +-
 3 files changed, 247 insertions(+), 3 deletions(-)

diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
index 37ae96a248..9ea1a563ef 100644
--- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
@@ -21,6 +21,7 @@ package org.apache.iceberg.spark.sql;
 import java.math.BigDecimal;
 import java.sql.Date;
 import java.sql.Timestamp;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -35,6 +36,7 @@ import 
org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.spark.SparkCatalogTestBase;
 import org.apache.iceberg.spark.SparkTestBase;
 import org.apache.spark.sql.SparkSession;
+import org.assertj.core.api.Assertions;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.BeforeClass;
@@ -470,6 +472,126 @@ public class TestAggregatePushDown extends 
SparkCatalogTestBase {
     Assert.assertFalse("max not pushed down for complex types", 
explainContainsPushDownAggregates);
   }
 
+  @Test
+  public void testAggregationPushdownStructInteger() {
+    sql("CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:BIGINT>) USING 
iceberg", tableName);
+    sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", 
tableName);
+    sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2))", tableName);
+    sql("INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", 3))", tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "struct_with_int.c1";
+    assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 
3L, 2L);
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(struct_with_int.c1)",
+        "max(struct_with_int.c1)",
+        "min(struct_with_int.c1)");
+  }
+
+  @Test
+  public void testAggregationPushdownNestedStruct() {
+    sql(
+        "CREATE TABLE %s (id BIGINT, struct_with_int 
STRUCT<c1:STRUCT<c2:STRUCT<c3:STRUCT<c4:BIGINT>>>>) USING iceberg",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 
named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", NULL)))))",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 
named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", 2)))))",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", 
named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", 3)))))",
+        tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "struct_with_int.c1.c2.c3.c4";
+
+    assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 
3L, 2L);
+
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(struct_with_int.c1.c2.c3.c4)",
+        "max(struct_with_int.c1.c2.c3.c4)",
+        "min(struct_with_int.c1.c2.c3.c4)");
+  }
+
+  @Test
+  public void testAggregationPushdownStructTimestamp() {
+    sql(
+        "CREATE TABLE %s (id BIGINT, struct_with_ts STRUCT<c1:TIMESTAMP>) 
USING iceberg",
+        tableName);
+    sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", 
tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 
timestamp('2023-01-30T22:22:22Z')))",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", 
timestamp('2023-01-30T22:23:23Z')))",
+        tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "struct_with_ts.c1";
+
+    assertAggregates(
+        sql(query, aggField, aggField, aggField, tableName),
+        2L,
+        new Timestamp(1675117403000L),
+        new Timestamp(1675117342000L));
+
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(struct_with_ts.c1)",
+        "max(struct_with_ts.c1)",
+        "min(struct_with_ts.c1)");
+  }
+
+  @Test
+  public void testAggregationPushdownOnBucketedColumn() {
+    sql(
+        "CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:INT>) USING 
iceberg PARTITIONED BY (bucket(8, id))",
+        tableName);
+
+    sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", 
tableName);
+    sql("INSERT INTO TABLE %s VALUES (null, named_struct(\"c1\", 2))", 
tableName);
+    sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 3))", tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "id";
+    assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 
2L, 1L);
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(id)",
+        "max(id)",
+        "min(id)");
+  }
+
+  private void assertAggregates(
+      List<Object[]> actual, Object expectedCount, Object expectedMax, Object 
expectedMin) {
+    Object actualCount = actual.get(0)[0];
+    Object actualMax = actual.get(0)[1];
+    Object actualMin = actual.get(0)[2];
+
+    Assertions.assertThat(actualCount)
+        .as("Expected and actual count should equal")
+        .isEqualTo(expectedCount);
+    Assertions.assertThat(actualMax)
+        .as("Expected and actual max should equal")
+        .isEqualTo(expectedMax);
+    Assertions.assertThat(actualMin)
+        .as("Expected and actual min should equal")
+        .isEqualTo(expectedMin);
+  }
+
+  private void assertExplainContains(List<Object[]> explain, String... 
expectedFragments) {
+    String explainString = 
explain.get(0)[0].toString().toLowerCase(Locale.ROOT);
+    Arrays.stream(expectedFragments)
+        .forEach(
+            fragment ->
+                Assertions.assertThat(explainString)
+                    .as("Expected to find plan fragment in explain plan")
+                    .contains(fragment));
+  }
+
   @Test
   public void testAggregatePushDownInDeleteCopyOnWrite() {
     sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName);
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
index 37ae96a248..9ea1a563ef 100644
--- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
@@ -21,6 +21,7 @@ package org.apache.iceberg.spark.sql;
 import java.math.BigDecimal;
 import java.sql.Date;
 import java.sql.Timestamp;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -35,6 +36,7 @@ import 
org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.spark.SparkCatalogTestBase;
 import org.apache.iceberg.spark.SparkTestBase;
 import org.apache.spark.sql.SparkSession;
+import org.assertj.core.api.Assertions;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.BeforeClass;
@@ -470,6 +472,126 @@ public class TestAggregatePushDown extends 
SparkCatalogTestBase {
     Assert.assertFalse("max not pushed down for complex types", 
explainContainsPushDownAggregates);
   }
 
+  @Test
+  public void testAggregationPushdownStructInteger() {
+    sql("CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:BIGINT>) USING 
iceberg", tableName);
+    sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", 
tableName);
+    sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2))", tableName);
+    sql("INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", 3))", tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "struct_with_int.c1";
+    assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 
3L, 2L);
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(struct_with_int.c1)",
+        "max(struct_with_int.c1)",
+        "min(struct_with_int.c1)");
+  }
+
+  @Test
+  public void testAggregationPushdownNestedStruct() {
+    sql(
+        "CREATE TABLE %s (id BIGINT, struct_with_int 
STRUCT<c1:STRUCT<c2:STRUCT<c3:STRUCT<c4:BIGINT>>>>) USING iceberg",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 
named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", NULL)))))",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 
named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", 2)))))",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", 
named_struct(\"c2\", named_struct(\"c3\", named_struct(\"c4\", 3)))))",
+        tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "struct_with_int.c1.c2.c3.c4";
+
+    assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 
3L, 2L);
+
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(struct_with_int.c1.c2.c3.c4)",
+        "max(struct_with_int.c1.c2.c3.c4)",
+        "min(struct_with_int.c1.c2.c3.c4)");
+  }
+
+  @Test
+  public void testAggregationPushdownStructTimestamp() {
+    sql(
+        "CREATE TABLE %s (id BIGINT, struct_with_ts STRUCT<c1:TIMESTAMP>) 
USING iceberg",
+        tableName);
+    sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", 
tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 
timestamp('2023-01-30T22:22:22Z')))",
+        tableName);
+    sql(
+        "INSERT INTO TABLE %s VALUES (3, named_struct(\"c1\", 
timestamp('2023-01-30T22:23:23Z')))",
+        tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "struct_with_ts.c1";
+
+    assertAggregates(
+        sql(query, aggField, aggField, aggField, tableName),
+        2L,
+        new Timestamp(1675117403000L),
+        new Timestamp(1675117342000L));
+
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(struct_with_ts.c1)",
+        "max(struct_with_ts.c1)",
+        "min(struct_with_ts.c1)");
+  }
+
+  @Test
+  public void testAggregationPushdownOnBucketedColumn() {
+    sql(
+        "CREATE TABLE %s (id BIGINT, struct_with_int STRUCT<c1:INT>) USING 
iceberg PARTITIONED BY (bucket(8, id))",
+        tableName);
+
+    sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", NULL))", 
tableName);
+    sql("INSERT INTO TABLE %s VALUES (null, named_struct(\"c1\", 2))", 
tableName);
+    sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 3))", tableName);
+
+    String query = "SELECT COUNT(%s), MAX(%s), MIN(%s) FROM %s";
+    String aggField = "id";
+    assertAggregates(sql(query, aggField, aggField, aggField, tableName), 2L, 
2L, 1L);
+    assertExplainContains(
+        sql("EXPLAIN " + query, aggField, aggField, aggField, tableName),
+        "count(id)",
+        "max(id)",
+        "min(id)");
+  }
+
+  private void assertAggregates(
+      List<Object[]> actual, Object expectedCount, Object expectedMax, Object 
expectedMin) {
+    Object actualCount = actual.get(0)[0];
+    Object actualMax = actual.get(0)[1];
+    Object actualMin = actual.get(0)[2];
+
+    Assertions.assertThat(actualCount)
+        .as("Expected and actual count should equal")
+        .isEqualTo(expectedCount);
+    Assertions.assertThat(actualMax)
+        .as("Expected and actual max should equal")
+        .isEqualTo(expectedMax);
+    Assertions.assertThat(actualMin)
+        .as("Expected and actual min should equal")
+        .isEqualTo(expectedMin);
+  }
+
+  private void assertExplainContains(List<Object[]> explain, String... 
expectedFragments) {
+    String explainString = 
explain.get(0)[0].toString().toLowerCase(Locale.ROOT);
+    Arrays.stream(expectedFragments)
+        .forEach(
+            fragment ->
+                Assertions.assertThat(explainString)
+                    .as("Expected to find plan fragment in explain plan")
+                    .contains(fragment));
+  }
+
   @Test
   public void testAggregatePushDownInDeleteCopyOnWrite() {
     sql("CREATE TABLE %s (id LONG, data INT) USING iceberg", tableName);
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
index 603b1a1cfb..7fdd5163f1 100644
--- a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAggregatePushDown.java
@@ -595,9 +595,9 @@ public class TestAggregatePushDown extends CatalogTestBase {
     Arrays.stream(expectedFragments)
         .forEach(
             fragment ->
-                Assertions.assertThat(explainString.contains(fragment))
-                    .isTrue()
-                    .as("Expected to find plan fragment in explain plan"));
+                Assertions.assertThat(explainString)
+                    .as("Expected to find plan fragment in explain plan")
+                    .contains(fragment));
   }
 
   @TestTemplate

Reply via email to