This is an automated email from the ASF dual-hosted git repository.

aokolnychyi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/main by this push:
     new 2f6e7e6371 API, Spark 3.5: Action to compute table stats (#10288)
2f6e7e6371 is described below

commit 2f6e7e6371902bcb72f21deeaea8889d4768004e
Author: Karuppayya <[email protected]>
AuthorDate: Wed Aug 21 20:23:03 2024 -0700

    API, Spark 3.5: Action to compute table stats (#10288)
---
 .../apache/iceberg/actions/ActionsProvider.java    |   6 +
 .../apache/iceberg/actions/ComputeTableStats.java  |  47 +++
 .../org/apache/iceberg/GenericBlobMetadata.java    |   8 +
 .../iceberg/actions/BaseComputeTableStats.java     |  39 ++
 spark/v3.5/build.gradle                            |   2 +
 .../actions/ComputeTableStatsSparkAction.java      | 179 +++++++++
 .../iceberg/spark/actions/NDVSketchUtil.java       |  93 +++++
 .../apache/iceberg/spark/actions/SparkActions.java |   6 +
 .../apache/spark/sql/stats/ThetaSketchAgg.scala    | 121 ++++++
 .../spark/actions/TestComputeTableStatsAction.java | 417 +++++++++++++++++++++
 10 files changed, 918 insertions(+)

diff --git a/api/src/main/java/org/apache/iceberg/actions/ActionsProvider.java 
b/api/src/main/java/org/apache/iceberg/actions/ActionsProvider.java
index 2d6ff2679a..85773febae 100644
--- a/api/src/main/java/org/apache/iceberg/actions/ActionsProvider.java
+++ b/api/src/main/java/org/apache/iceberg/actions/ActionsProvider.java
@@ -70,4 +70,10 @@ public interface ActionsProvider {
     throw new UnsupportedOperationException(
         this.getClass().getName() + " does not implement 
rewritePositionDeletes");
   }
+
+  /**
+   * Creates an action that computes statistics for the given table.
+   *
+   * <p>This default implementation always throws {@link UnsupportedOperationException} naming the
+   * concrete provider class.
+   */
+  default ComputeTableStats computeTableStats(Table table) {
+    String message = this.getClass().getName() + " does not implement computeTableStats";
+    throw new UnsupportedOperationException(message);
+  }
 }
diff --git 
a/api/src/main/java/org/apache/iceberg/actions/ComputeTableStats.java 
b/api/src/main/java/org/apache/iceberg/actions/ComputeTableStats.java
new file mode 100644
index 0000000000..04449d5916
--- /dev/null
+++ b/api/src/main/java/org/apache/iceberg/actions/ComputeTableStats.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.actions;
+
+import org.apache.iceberg.StatisticsFile;
+
+/**
+ * An action that collects statistics of an Iceberg table and writes them to Puffin files.
+ *
+ * <p>Both configuration methods return this action so that calls can be chained.
+ */
+public interface ComputeTableStats extends Action<ComputeTableStats, 
ComputeTableStats.Result> {
+  /**
+   * Chooses the set of columns to collect stats for; by default all columns are chosen.
+   *
+   * @param columns a set of column names to be analyzed
+   * @return this for method chaining
+   */
+  ComputeTableStats columns(String... columns);
+
+  /**
+   * Chooses the table snapshot to compute stats for; by default the current snapshot is used.
+   *
+   * @param snapshotId long ID of the snapshot for which stats need to be computed
+   * @return this for method chaining
+   */
+  ComputeTableStats snapshot(long snapshotId);
+
+  /** The result of table statistics collection. */
+  interface Result {
+
+    /** Returns the statistics file, or {@code null} if no statistics were collected. */
+    StatisticsFile statisticsFile();
+  }
+}
diff --git a/core/src/main/java/org/apache/iceberg/GenericBlobMetadata.java 
b/core/src/main/java/org/apache/iceberg/GenericBlobMetadata.java
index 46bedfa017..d3ac399556 100644
--- a/core/src/main/java/org/apache/iceberg/GenericBlobMetadata.java
+++ b/core/src/main/java/org/apache/iceberg/GenericBlobMetadata.java
@@ -18,6 +18,7 @@
  */
 package org.apache.iceberg;
 
+import java.util.Collection;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -37,6 +38,13 @@ public class GenericBlobMetadata implements BlobMetadata {
         puffinMetadata.properties());
   }
 
+  /**
+   * Converts a collection of Puffin blob metadata entries, one by one, using the single-element
+   * {@code from} conversion.
+   *
+   * @param puffinMetadataList Puffin blob metadata entries to convert
+   * @return an immutable list holding one converted entry per input element, in iteration order
+   */
+  public static List<BlobMetadata> from(
+      Collection<org.apache.iceberg.puffin.BlobMetadata> puffinMetadataList) {
+    ImmutableList.Builder<BlobMetadata> converted = ImmutableList.builder();
+    for (org.apache.iceberg.puffin.BlobMetadata puffinMetadata : puffinMetadataList) {
+      converted.add(from(puffinMetadata));
+    }
+    return converted.build();
+  }
+
   private final String type;
   private final long sourceSnapshotId;
   private final long sourceSnapshotSequenceNumber;
diff --git 
a/core/src/main/java/org/apache/iceberg/actions/BaseComputeTableStats.java 
b/core/src/main/java/org/apache/iceberg/actions/BaseComputeTableStats.java
new file mode 100644
index 0000000000..71941af1d7
--- /dev/null
+++ b/core/src/main/java/org/apache/iceberg/actions/BaseComputeTableStats.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.actions;
+
+import javax.annotation.Nullable;
+import org.apache.iceberg.StatisticsFile;
+import org.immutables.value.Value;
+
[email protected]
+@SuppressWarnings("ImmutablesStyle")
[email protected](
+    typeImmutableEnclosing = "ImmutableComputeTableStats",
+    visibilityString = "PUBLIC",
+    builderVisibilityString = "PUBLIC")
+interface BaseComputeTableStats extends ComputeTableStats {
+
+  @Value.Immutable
+  interface Result extends ComputeTableStats.Result {
+    @Override
+    @Nullable
+    StatisticsFile statisticsFile();
+  }
+}
diff --git a/spark/v3.5/build.gradle b/spark/v3.5/build.gradle
index 2ba5d493c6..c8d8bbf396 100644
--- a/spark/v3.5/build.gradle
+++ b/spark/v3.5/build.gradle
@@ -59,6 +59,7 @@ 
project(":iceberg-spark:iceberg-spark-${sparkMajorVersion}_${scalaVersion}") {
     implementation project(':iceberg-parquet')
     implementation project(':iceberg-arrow')
     
implementation("org.scala-lang.modules:scala-collection-compat_${scalaVersion}:${libs.versions.scala.collection.compat.get()}")
+    
implementation("org.apache.datasketches:datasketches-java:${libs.versions.datasketches.get()}")
     if (scalaVersion == '2.12') {
       // scala-collection-compat_2.12 pulls scala 2.12.17 and we need 2.12.18 
for JDK 21 support
       implementation 'org.scala-lang:scala-library:2.12.18'
@@ -292,6 +293,7 @@ 
project(":iceberg-spark:iceberg-spark-runtime-${sparkMajorVersion}_${scalaVersio
     relocate 'com.carrotsearch', 'org.apache.iceberg.shaded.com.carrotsearch'
     relocate 'org.threeten.extra', 
'org.apache.iceberg.shaded.org.threeten.extra'
     relocate 'org.roaringbitmap', 'org.apache.iceberg.shaded.org.roaringbitmap'
+    relocate 'org.apache.datasketches', 
'org.apache.iceberg.shaded.org.apache.datasketches'
 
     archiveClassifier.set(null)
   }
diff --git 
a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/actions/ComputeTableStatsSparkAction.java
 
b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/actions/ComputeTableStatsSparkAction.java
new file mode 100644
index 0000000000..a508021c10
--- /dev/null
+++ 
b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/actions/ComputeTableStatsSparkAction.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.actions;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.UUID;
+import java.util.stream.Collectors;
+import org.apache.iceberg.GenericBlobMetadata;
+import org.apache.iceberg.GenericStatisticsFile;
+import org.apache.iceberg.HasTableOperations;
+import org.apache.iceberg.IcebergBuild;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.Snapshot;
+import org.apache.iceberg.StatisticsFile;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.TableOperations;
+import org.apache.iceberg.actions.ComputeTableStats;
+import org.apache.iceberg.actions.ImmutableComputeTableStats;
+import org.apache.iceberg.exceptions.RuntimeIOException;
+import org.apache.iceberg.io.OutputFile;
+import org.apache.iceberg.puffin.Blob;
+import org.apache.iceberg.puffin.Puffin;
+import org.apache.iceberg.puffin.PuffinWriter;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
+import org.apache.iceberg.spark.JobGroupInfo;
+import org.apache.iceberg.types.Types;
+import org.apache.spark.sql.SparkSession;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Computes the statistics of the given columns and stores it as Puffin 
files. */
+public class ComputeTableStatsSparkAction extends 
BaseSparkAction<ComputeTableStatsSparkAction>
+    implements ComputeTableStats {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(ComputeTableStatsSparkAction.class);
+  private static final Result EMPTY_RESULT = 
ImmutableComputeTableStats.Result.builder().build();
+
+  private final Table table;
+  private List<String> columns;
+  private Snapshot snapshot;
+
+  ComputeTableStatsSparkAction(SparkSession spark, Table table) {
+    super(spark);
+    this.table = table;
+    this.snapshot = table.currentSnapshot();
+  }
+
+  @Override
+  protected ComputeTableStatsSparkAction self() {
+    return this;
+  }
+
+  @Override
+  public ComputeTableStats columns(String... newColumns) {
+    Preconditions.checkArgument(
+        newColumns != null && newColumns.length > 0, "Columns cannot be 
null/empty");
+    this.columns = ImmutableList.copyOf(ImmutableSet.copyOf(newColumns));
+    return this;
+  }
+
+  @Override
+  public ComputeTableStats snapshot(long newSnapshotId) {
+    Snapshot newSnapshot = table.snapshot(newSnapshotId);
+    Preconditions.checkArgument(newSnapshot != null, "Snapshot not found: %s", 
newSnapshotId);
+    this.snapshot = newSnapshot;
+    return this;
+  }
+
+  @Override
+  public Result execute() {
+    if (snapshot == null) {
+      LOG.info("No snapshot to compute stats for table {}", table.name());
+      return EMPTY_RESULT;
+    }
+    validateColumns();
+    JobGroupInfo info = newJobGroupInfo("COMPUTE-TABLE-STATS", jobDesc());
+    return withJobGroupInfo(info, this::doExecute);
+  }
+
+  private Result doExecute() {
+    LOG.info(
+        "Computing stats for columns {} in {} (snapshot {})",
+        columns(),
+        table.name(),
+        snapshotId());
+    List<Blob> blobs = generateNDVBlobs();
+    StatisticsFile statisticsFile = writeStatsFile(blobs);
+    table.updateStatistics().setStatistics(snapshotId(), 
statisticsFile).commit();
+    return 
ImmutableComputeTableStats.Result.builder().statisticsFile(statisticsFile).build();
+  }
+
+  private StatisticsFile writeStatsFile(List<Blob> blobs) {
+    LOG.info("Writing stats for table {} for snapshot {}", table.name(), 
snapshotId());
+    OutputFile outputFile = table.io().newOutputFile(outputPath());
+    try (PuffinWriter writer = 
Puffin.write(outputFile).createdBy(appIdentifier()).build()) {
+      blobs.forEach(writer::add);
+      writer.finish();
+      return new GenericStatisticsFile(
+          snapshotId(),
+          outputFile.location(),
+          writer.fileSize(),
+          writer.footerSize(),
+          GenericBlobMetadata.from(writer.writtenBlobsMetadata()));
+    } catch (IOException e) {
+      throw new RuntimeIOException(e);
+    }
+  }
+
+  private List<Blob> generateNDVBlobs() {
+    return NDVSketchUtil.generateBlobs(spark(), table, snapshot, columns());
+  }
+
+  private List<String> columns() {
+    if (columns == null) {
+      Schema schema = table.schemas().get(snapshot.schemaId());
+      this.columns =
+          schema.columns().stream()
+              .filter(nestedField -> nestedField.type().isPrimitiveType())
+              .map(Types.NestedField::name)
+              .collect(Collectors.toList());
+    }
+    return columns;
+  }
+
+  private void validateColumns() {
+    Schema schema = table.schemas().get(snapshot.schemaId());
+    Preconditions.checkArgument(!columns().isEmpty(), "No columns found to 
compute stats");
+    for (String columnName : columns()) {
+      Types.NestedField field = schema.findField(columnName);
+      Preconditions.checkArgument(field != null, "Can't find column %s in %s", 
columnName, schema);
+      Preconditions.checkArgument(
+          field.type().isPrimitiveType(),
+          "Can't compute stats on non-primitive type column: %s (%s)",
+          columnName,
+          field.type());
+    }
+  }
+
+  private String appIdentifier() {
+    String icebergVersion = IcebergBuild.fullVersion();
+    String sparkVersion = spark().version();
+    return String.format("Iceberg %s Spark %s", icebergVersion, sparkVersion);
+  }
+
+  private long snapshotId() {
+    return snapshot.snapshotId();
+  }
+
+  private String jobDesc() {
+    return String.format(
+        "Computing table stats for %s (snapshot_id=%s, columns=%s)",
+        table.name(), snapshotId(), columns());
+  }
+
+  private String outputPath() {
+    TableOperations operations = ((HasTableOperations) table).operations();
+    String fileName = String.format("%s-%s.stats", snapshotId(), 
UUID.randomUUID());
+    return operations.metadataFileLocation(fileName);
+  }
+}
diff --git 
a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/actions/NDVSketchUtil.java
 
b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/actions/NDVSketchUtil.java
new file mode 100644
index 0000000000..22055a161e
--- /dev/null
+++ 
b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/actions/NDVSketchUtil.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.actions;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+import org.apache.datasketches.memory.Memory;
+import org.apache.datasketches.theta.CompactSketch;
+import org.apache.datasketches.theta.Sketch;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.Snapshot;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.puffin.Blob;
+import org.apache.iceberg.puffin.PuffinCompressionCodec;
+import org.apache.iceberg.puffin.StandardBlobTypes;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.spark.SparkReadOptions;
+import org.apache.iceberg.types.Types;
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.stats.ThetaSketchAgg;
+
+public class NDVSketchUtil {
+
+  private NDVSketchUtil() {}
+
+  public static final String APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY = "ndv";
+
+  static List<Blob> generateBlobs(
+      SparkSession spark, Table table, Snapshot snapshot, List<String> 
columns) {
+    Row sketches = computeNDVSketches(spark, table, snapshot, columns);
+    Schema schema = table.schemas().get(snapshot.schemaId());
+    List<Blob> blobs = Lists.newArrayList();
+    for (int i = 0; i < columns.size(); i++) {
+      Types.NestedField field = schema.findField(columns.get(i));
+      Sketch sketch = CompactSketch.wrap(Memory.wrap((byte[]) 
sketches.get(i)));
+      blobs.add(toBlob(field, sketch, snapshot));
+    }
+    return blobs;
+  }
+
+  private static Blob toBlob(Types.NestedField field, Sketch sketch, Snapshot 
snapshot) {
+    return new Blob(
+        StandardBlobTypes.APACHE_DATASKETCHES_THETA_V1,
+        ImmutableList.of(field.fieldId()),
+        snapshot.snapshotId(),
+        snapshot.sequenceNumber(),
+        ByteBuffer.wrap(sketch.toByteArray()),
+        PuffinCompressionCodec.ZSTD,
+        ImmutableMap.of(
+            APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY,
+            String.valueOf((long) sketch.getEstimate())));
+  }
+
+  private static Row computeNDVSketches(
+      SparkSession spark, Table table, Snapshot snapshot, List<String> 
colNames) {
+    return spark
+        .read()
+        .format("iceberg")
+        .option(SparkReadOptions.SNAPSHOT_ID, snapshot.snapshotId())
+        .load(table.name())
+        .select(toAggColumns(colNames))
+        .first();
+  }
+
+  private static Column[] toAggColumns(List<String> colNames) {
+    return 
colNames.stream().map(NDVSketchUtil::toAggColumn).toArray(Column[]::new);
+  }
+
+  private static Column toAggColumn(String colName) {
+    ThetaSketchAgg agg = new ThetaSketchAgg(colName);
+    return new Column(agg.toAggregateExpression());
+  }
+}
diff --git 
a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java
 
b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java
index fb67ded96e..f845386d30 100644
--- 
a/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java
+++ 
b/spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/actions/SparkActions.java
@@ -20,6 +20,7 @@ package org.apache.iceberg.spark.actions;
 
 import org.apache.iceberg.Table;
 import org.apache.iceberg.actions.ActionsProvider;
+import org.apache.iceberg.actions.ComputeTableStats;
 import org.apache.iceberg.spark.Spark3Util;
 import org.apache.iceberg.spark.Spark3Util.CatalogAndIdentifier;
 import org.apache.spark.sql.SparkSession;
@@ -96,4 +97,9 @@ public class SparkActions implements ActionsProvider {
   public RewritePositionDeleteFilesSparkAction rewritePositionDeletes(Table 
table) {
     return new RewritePositionDeleteFilesSparkAction(spark, table);
   }
+
+  @Override
+  public ComputeTableStats computeTableStats(Table table) {
+    return new ComputeTableStatsSparkAction(spark, table);
+  }
 }
diff --git 
a/spark/v3.5/spark/src/main/scala/org/apache/spark/sql/stats/ThetaSketchAgg.scala
 
b/spark/v3.5/spark/src/main/scala/org/apache/spark/sql/stats/ThetaSketchAgg.scala
new file mode 100644
index 0000000000..cca16960f4
--- /dev/null
+++ 
b/spark/v3.5/spark/src/main/scala/org/apache/spark/sql/stats/ThetaSketchAgg.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.stats
+
+import java.nio.ByteBuffer
+import org.apache.datasketches.common.Family
+import org.apache.datasketches.memory.Memory
+import org.apache.datasketches.theta.CompactSketch
+import org.apache.datasketches.theta.SetOperationBuilder
+import org.apache.datasketches.theta.Sketch
+import org.apache.datasketches.theta.UpdateSketch
+import org.apache.iceberg.spark.SparkSchemaUtil
+import org.apache.iceberg.types.Conversions
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
+import org.apache.spark.sql.catalyst.trees.UnaryLike
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types.BinaryType
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.Decimal
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Aggregate expression that builds a Theta sketch over the values of its child expression.
+ *
+ * The aggregation buffer starts as an Alpha family update sketch with the default seed.
+ * Every non-null value is converted to bytes using Iceberg's single value serialization
+ * before it is fed to the sketch. The aggregate evaluates to the byte array of a Compact
+ * Theta sketch of the Datasketches library, which should be deserialized to a Compact sketch
+ * before using.
+ *
+ * See [[https://iceberg.apache.org/puffin-spec/]] for more information.
+ */
+case class ThetaSketchAgg(
+    child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[Sketch] with UnaryLike[Expression] {
+
+  private lazy val icebergType = SparkSchemaUtil.convert(child.dataType)
+
+  /** Builds the aggregate over the column with the given name, using zero buffer offsets. */
+  def this(colName: String) = this(col(colName).expr, 0, 0)
+
+  override def dataType: DataType = BinaryType
+
+  override def nullable: Boolean = false
+
+  override def createAggregationBuffer(): Sketch =
+    UpdateSketch.builder.setFamily(Family.ALPHA).build()
+
+  // NOTE(review): merge and deserialize hand back sketches that may not be UpdateSketch
+  // instances, while this method casts the buffer to UpdateSketch - confirm that update is
+  // never invoked on a merged or deserialized buffer.
+  override def update(buffer: Sketch, input: InternalRow): Sketch = {
+    // null inputs are skipped and leave the sketch untouched
+    Option(child.eval(input)).foreach { value =>
+      val serialized = Conversions.toByteBuffer(icebergType, toIcebergValue(value))
+      buffer.asInstanceOf[UpdateSketch].update(serialized)
+    }
+    buffer
+  }
+
+  /** Maps Spark internal values to the Java objects passed to Iceberg's conversion. */
+  private def toIcebergValue(value: Any): Any = value match {
+    case s: UTF8String => s.toString
+    case d: Decimal => d.toJavaBigDecimal
+    case b: Array[Byte] => ByteBuffer.wrap(b)
+    case _ => value
+  }
+
+  override def merge(buffer: Sketch, input: Sketch): Sketch = {
+    val union = new SetOperationBuilder().buildUnion
+    union.union(buffer, input)
+  }
+
+  override def eval(buffer: Sketch): Any = toBytes(buffer)
+
+  override def serialize(buffer: Sketch): Array[Byte] = toBytes(buffer)
+
+  override def deserialize(storageFormat: Array[Byte]): Sketch =
+    CompactSketch.wrap(Memory.wrap(storageFormat))
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override protected def withNewChildInternal(newChild: Expression): Expression =
+    copy(child = newChild)
+
+  /** Compacts the sketch and returns its serialized form. */
+  private def toBytes(sketch: Sketch): Array[Byte] = sketch.compact().toByteArray
+}
diff --git 
a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/actions/TestComputeTableStatsAction.java
 
b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/actions/TestComputeTableStatsAction.java
new file mode 100644
index 0000000000..588bb29f47
--- /dev/null
+++ 
b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/actions/TestComputeTableStatsAction.java
@@ -0,0 +1,417 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.actions;
+
+import static 
org.apache.iceberg.spark.actions.NDVSketchUtil.APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY;
+import static org.apache.iceberg.types.Types.NestedField.optional;
+import static org.apache.iceberg.types.Types.NestedField.required;
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.io.IOException;
+import java.util.List;
+import org.apache.iceberg.BlobMetadata;
+import org.apache.iceberg.DataFile;
+import org.apache.iceberg.Files;
+import org.apache.iceberg.PartitionSpec;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.StatisticsFile;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.actions.ComputeTableStats;
+import org.apache.iceberg.data.FileHelpers;
+import org.apache.iceberg.data.GenericRecord;
+import org.apache.iceberg.data.Record;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.spark.CatalogTestBase;
+import org.apache.iceberg.spark.Spark3Util;
+import org.apache.iceberg.spark.SparkSchemaUtil;
+import org.apache.iceberg.spark.SparkWriteOptions;
+import org.apache.iceberg.spark.data.RandomData;
+import org.apache.iceberg.spark.source.SimpleRecord;
+import org.apache.iceberg.types.Types;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.catalyst.parser.ParseException;
+import org.apache.spark.sql.types.StructType;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.TestTemplate;
+
+public class TestComputeTableStatsAction extends CatalogTestBase {
+
+  // Innermost struct: two optional primitive leaves (a long and a double).
+  private static final Types.StructType LEAF_STRUCT_TYPE =
+      Types.StructType.of(
+          optional(1, "leafLongCol", Types.LongType.get()),
+          optional(2, "leafDoubleCol", Types.DoubleType.get()));
+
+  // Struct whose single required field is the leaf struct above.
+  private static final Types.StructType NESTED_STRUCT_TYPE =
+      Types.StructType.of(required(3, "leafStructCol", LEAF_STRUCT_TYPE));
+
+  // Schema whose only top-level column is the nested struct.
+  private static final Schema NESTED_SCHEMA =
+      new Schema(required(4, "nestedStructCol", NESTED_STRUCT_TYPE));
+
+  // Same nested struct column plus one required top-level string column.
+  private static final Schema SCHEMA_WITH_NESTED_COLUMN =
+      new Schema(
+          required(4, "nestedStructCol", NESTED_STRUCT_TYPE),
+          required(5, "stringCol", Types.StringType.get()));
+
+  /**
+   * Computes stats for the explicitly listed columns {@code id} and {@code data} on a table
+   * configured with very small split and row-group sizes, then verifies that exactly one
+   * statistics file with two blobs is registered and that the first blob reports an NDV of 4.
+   */
+  @TestTemplate
+  public void testComputeTableStatsAction() throws NoSuchTableException, ParseException {
+    sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName);
+    Table table = Spark3Util.loadIcebergTable(spark, tableName);
+
+    // To create multiple splits on the mapper
+    table
+        .updateProperties()
+        .set("read.split.target-size", "100")
+        .set("write.parquet.row-group-size-bytes", "100")
+        .commit();
+    List<SimpleRecord> records =
+        Lists.newArrayList(
+            new SimpleRecord(1, "a"),
+            new SimpleRecord(1, "a"),
+            new SimpleRecord(2, "b"),
+            new SimpleRecord(3, "c"),
+            new SimpleRecord(4, "d"));
+    spark.createDataset(records, Encoders.bean(SimpleRecord.class)).writeTo(tableName).append();
+    SparkActions actions = SparkActions.get();
+    ComputeTableStats.Result results =
+        actions.computeTableStats(table).columns("id", "data").execute();
+    assertNotNull(results);
+
+    // JUnit assertions take (expected, actual); this order keeps failure messages accurate.
+    List<StatisticsFile> statisticsFiles = table.statisticsFiles();
+    Assertions.assertEquals(1, statisticsFiles.size());
+
+    StatisticsFile statisticsFile = statisticsFiles.get(0);
+    assertNotEquals(0, statisticsFile.fileSizeInBytes());
+    Assertions.assertEquals(2, statisticsFile.blobMetadata().size());
+
+    BlobMetadata blobMetadata = statisticsFile.blobMetadata().get(0);
+    Assertions.assertEquals(
+        String.valueOf(4),
+        blobMetadata.properties().get(APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY));
+  }
+
+  /**
+   * Runs the action without naming columns on a two-column table holding four distinct rows and
+   * verifies that one statistics file with two blobs is registered, each reporting an NDV of 4.
+   */
+  @TestTemplate
+  public void testComputeTableStatsActionWithoutExplicitColumns()
+      throws NoSuchTableException, ParseException {
+    sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName);
+
+    List<SimpleRecord> rows =
+        Lists.newArrayList(
+            new SimpleRecord(1, "a"),
+            new SimpleRecord(2, "b"),
+            new SimpleRecord(3, "c"),
+            new SimpleRecord(4, "d"));
+    spark
+        .createDataset(rows, Encoders.bean(SimpleRecord.class))
+        .coalesce(1)
+        .writeTo(tableName)
+        .append();
+
+    Table table = Spark3Util.loadIcebergTable(spark, tableName);
+    ComputeTableStats.Result results = SparkActions.get().computeTableStats(table).execute();
+    assertNotNull(results);
+
+    List<StatisticsFile> statisticsFiles = table.statisticsFiles();
+    Assertions.assertEquals(1, statisticsFiles.size());
+
+    StatisticsFile statisticsFile = statisticsFiles.get(0);
+    List<BlobMetadata> blobs = statisticsFile.blobMetadata();
+    Assertions.assertEquals(2, blobs.size());
+    assertNotEquals(0, statisticsFile.fileSizeInBytes());
+
+    // Both blobs (one per column) must report four distinct values.
+    for (BlobMetadata blob : blobs) {
+      String ndv = blob.properties().get(APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY);
+      Assertions.assertEquals(4, Long.parseLong(ndv));
+    }
+  }
+
+  /**
+   * Requesting stats for a column name that is not in the table must fail with an
+   * {@link IllegalArgumentException} whose message names the missing column.
+   */
+  @TestTemplate
+  public void testComputeTableStatsForInvalidColumns() throws NoSuchTableException, ParseException {
+    sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName);
+    // Append data to create snapshot
+    sql("INSERT into %s values(1, 'abcd')", tableName);
+
+    Table table = Spark3Util.loadIcebergTable(spark, tableName);
+    SparkActions actions = SparkActions.get();
+
+    IllegalArgumentException thrown =
+        assertThrows(
+            IllegalArgumentException.class,
+            () -> actions.computeTableStats(table).columns("id1").execute());
+    assertTrue(thrown.getMessage().contains("Can't find column id1 in table"));
+  }
+
+  /** On a freshly created table (no data written) the result must carry no statistics file. */
+  @TestTemplate
+  public void testComputeTableStatsWithNoSnapshots() throws NoSuchTableException, ParseException {
+    sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName);
+    Table emptyTable = Spark3Util.loadIcebergTable(spark, tableName);
+
+    ComputeTableStats.Result result =
+        SparkActions.get().computeTableStats(emptyTable).columns("id").execute();
+
+    Assertions.assertNull(result.statisticsFile());
+  }
+
+  /**
+   * Computes stats for the {@code data} column when one of the five rows holds a null in it and
+   * verifies that the single resulting blob reports an NDV of 4, i.e. the null is not counted.
+   */
+  @TestTemplate
+  public void testComputeTableStatsWithNullValues() throws NoSuchTableException, ParseException {
+    sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName);
+    List<SimpleRecord> records =
+        Lists.newArrayList(
+            new SimpleRecord(1, null),
+            new SimpleRecord(1, "a"),
+            new SimpleRecord(2, "b"),
+            new SimpleRecord(3, "c"),
+            new SimpleRecord(4, "d"));
+    spark
+        .createDataset(records, Encoders.bean(SimpleRecord.class))
+        .coalesce(1)
+        .writeTo(tableName)
+        .append();
+    Table table = Spark3Util.loadIcebergTable(spark, tableName);
+    SparkActions actions = SparkActions.get();
+    ComputeTableStats.Result results = actions.computeTableStats(table).columns("data").execute();
+    assertNotNull(results);
+
+    // JUnit assertions take (expected, actual); this order keeps failure messages accurate.
+    List<StatisticsFile> statisticsFiles = table.statisticsFiles();
+    Assertions.assertEquals(1, statisticsFiles.size());
+
+    StatisticsFile statisticsFile = statisticsFiles.get(0);
+    assertNotEquals(0, statisticsFile.fileSizeInBytes());
+    Assertions.assertEquals(1, statisticsFile.blobMetadata().size());
+
+    BlobMetadata blobMetadata = statisticsFile.blobMetadata().get(0);
+    Assertions.assertEquals(
+        String.valueOf(4),
+        blobMetadata.properties().get(APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY));
+  }
+
+  /**
+   * Exercises the action across a schema change: stats on {@code data} succeed for the current
+   * snapshot and for the snapshot taken before the column was dropped, but fail with an
+   * {@link IllegalArgumentException} for the snapshot written after the drop.
+   */
+  @TestTemplate
+  public void testComputeTableStatsWithSnapshotHavingDifferentSchemas()
+      throws NoSuchTableException, ParseException {
+    SparkActions actions = SparkActions.get();
+    sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName);
+    // Append data to create snapshot
+    sql("INSERT into %s values(1, 'abcd')", tableName);
+    long beforeDrop = Spark3Util.loadIcebergTable(spark, tableName).currentSnapshot().snapshotId();
+
+    // Snapshot id not specified
+    Table table = Spark3Util.loadIcebergTable(spark, tableName);
+    assertDoesNotThrow(() -> actions.computeTableStats(table).columns("data").execute());
+
+    sql("ALTER TABLE %s DROP COLUMN %s", tableName, "data");
+    // Append data to create snapshot
+    sql("INSERT into %s values(1)", tableName);
+    table.refresh();
+    long afterDrop = Spark3Util.loadIcebergTable(spark, tableName).currentSnapshot().snapshotId();
+
+    // Snapshot id specified
+    assertDoesNotThrow(
+        () -> actions.computeTableStats(table).snapshot(beforeDrop).columns("data").execute());
+
+    IllegalArgumentException thrown =
+        assertThrows(
+            IllegalArgumentException.class,
+            () -> actions.computeTableStats(table).snapshot(afterDrop).columns("data").execute());
+    assertTrue(thrown.getMessage().contains("Can't find column data in table"));
+  }
+
+  /**
+   * Runs the action without choosing a snapshot on a table holding a single row and verifies
+   * that one statistics file with one blob is registered and that the blob reports an NDV of 1.
+   */
+  @TestTemplate
+  public void testComputeTableStatsWhenSnapshotIdNotSpecified()
+      throws NoSuchTableException, ParseException {
+    sql("CREATE TABLE %s (id int, data string) USING iceberg", tableName);
+    // Append data to create snapshot
+    sql("INSERT into %s values(1, 'abcd')", tableName);
+    Table table = Spark3Util.loadIcebergTable(spark, tableName);
+    SparkActions actions = SparkActions.get();
+    ComputeTableStats.Result results = actions.computeTableStats(table).columns("data").execute();
+
+    assertNotNull(results);
+
+    // JUnit assertions take (expected, actual); this order keeps failure messages accurate.
+    List<StatisticsFile> statisticsFiles = table.statisticsFiles();
+    Assertions.assertEquals(1, statisticsFiles.size());
+
+    StatisticsFile statisticsFile = statisticsFiles.get(0);
+    assertNotEquals(0, statisticsFile.fileSizeInBytes());
+    Assertions.assertEquals(1, statisticsFile.blobMetadata().size());
+
+    BlobMetadata blobMetadata = statisticsFile.blobMetadata().get(0);
+    Assertions.assertEquals(
+        String.valueOf(1),
+        blobMetadata.properties().get(APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY));
+  }
+
+  /**
+   * Runs the action without naming columns on a table that has one nested struct column and one
+   * top-level string column, and verifies that the resulting statistics file holds exactly one
+   * blob.
+   */
+  @TestTemplate
+  public void testComputeTableStatsWithNestedSchema()
+      throws NoSuchTableException, ParseException, IOException {
+    List<Record> records = Lists.newArrayList(createNestedRecord());
+    Table table =
+        validationCatalog.createTable(
+            tableIdent,
+            SCHEMA_WITH_NESTED_COLUMN,
+            PartitionSpec.unpartitioned(),
+            ImmutableMap.of());
+    DataFile dataFile = FileHelpers.writeDataFile(table, Files.localOutput(temp.toFile()), records);
+    table.newAppend().appendFile(dataFile).commit();
+
+    Table tbl = Spark3Util.loadIcebergTable(spark, tableName);
+    SparkActions actions = SparkActions.get();
+    actions.computeTableStats(tbl).execute();
+
+    tbl.refresh();
+    // JUnit assertions take (expected, actual); this order keeps failure messages accurate.
+    List<StatisticsFile> statisticsFiles = tbl.statisticsFiles();
+    Assertions.assertEquals(1, statisticsFiles.size());
+    StatisticsFile statisticsFile = statisticsFiles.get(0);
+    assertNotEquals(0, statisticsFile.fileSizeInBytes());
+    Assertions.assertEquals(1, statisticsFile.blobMetadata().size());
+  }
+
+  @TestTemplate
+  public void testComputeTableStatsWithNoComputableColumns() throws 
IOException {
+    List<Record> records = Lists.newArrayList(createNestedRecord());
+    Table table =
+        validationCatalog.createTable(
+            tableIdent, NESTED_SCHEMA, PartitionSpec.unpartitioned(), 
ImmutableMap.of());
+    DataFile dataFile = FileHelpers.writeDataFile(table, 
Files.localOutput(temp.toFile()), records);
+    table.newAppend().appendFile(dataFile).commit();
+
+    table.refresh();
+    SparkActions actions = SparkActions.get();
+    IllegalArgumentException exception =
+        assertThrows(
+            IllegalArgumentException.class, () -> 
actions.computeTableStats(table).execute());
+    Assertions.assertEquals(exception.getMessage(), "No columns found to 
compute stats");
+  }
+
+  /** Runs the shared single-column check against a TINYINT column. */
+  @TestTemplate
+  public void testComputeTableStatsOnByteColumn() throws NoSuchTableException, ParseException {
+    testComputeTableStats("byte_col", "TINYINT");
+  }
+
+  /** Runs the shared single-column check against a SMALLINT column. */
+  @TestTemplate
+  public void testComputeTableStatsOnShortColumn() throws NoSuchTableException, ParseException {
+    testComputeTableStats("short_col", "SMALLINT");
+  }
+
+  /** Runs the shared single-column check against an INT column. */
+  @TestTemplate
+  public void testComputeTableStatsOnIntColumn() throws NoSuchTableException, ParseException {
+    testComputeTableStats("int_col", "INT");
+  }
+
+  /** Runs the shared single-column check against a BIGINT column. */
+  @TestTemplate
+  public void testComputeTableStatsOnLongColumn() throws NoSuchTableException, ParseException {
+    testComputeTableStats("long_col", "BIGINT");
+  }
+
+  /** Runs the shared single-column check against a TIMESTAMP column. */
+  @TestTemplate
+  public void testComputeTableStatsOnTimestampColumn() throws NoSuchTableException, ParseException {
+    testComputeTableStats("timestamp_col", "TIMESTAMP");
+  }
+
+  /**
+   * Runs the shared single-column check against a TIMESTAMP_NTZ column. The column is named after
+   * its type, like the sibling per-type tests, so it is not confused with the TIMESTAMP case.
+   */
+  @TestTemplate
+  public void testComputeTableStatsOnTimestampNtzColumn()
+      throws NoSuchTableException, ParseException {
+    testComputeTableStats("timestamp_ntz_col", "TIMESTAMP_NTZ");
+  }
+
+  /** Runs the shared single-column check against a DATE column. */
+  @TestTemplate
+  public void testComputeTableStatsOnDateColumn() throws NoSuchTableException, ParseException {
+    testComputeTableStats("date_col", "DATE");
+  }
+
+  /** Runs the shared single-column check against a DECIMAL(20, 2) column. */
+  @TestTemplate
+  public void testComputeTableStatsOnDecimalColumn() throws NoSuchTableException, ParseException {
+    testComputeTableStats("decimal_col", "DECIMAL(20, 2)");
+  }
+
+  /** Runs the shared single-column check against a BINARY column. */
+  @TestTemplate
+  public void testComputeTableStatsOnBinaryColumn() throws NoSuchTableException, ParseException {
+    testComputeTableStats("binary_col", "BINARY");
+  }
+
+  /**
+   * Shared body of the per-type tests: creates a table with an {@code id} column plus one column
+   * of the given SQL type, appends generated rows, computes stats for that column only and
+   * verifies that one statistics file with one blob carrying an NDV property is registered.
+   *
+   * @param columnName name of the column to create and compute stats for
+   * @param type SQL type of that column as written in the CREATE TABLE statement
+   */
+  public void testComputeTableStats(String columnName, String type)
+      throws NoSuchTableException, ParseException {
+    sql("CREATE TABLE %s (id int, %s %s) USING iceberg", tableName, columnName, type);
+    Table table = Spark3Util.loadIcebergTable(spark, tableName);
+
+    Dataset<Row> dataDF = randomDataDF(table.schema());
+    append(tableName, dataDF);
+
+    SparkActions actions = SparkActions.get();
+    table.refresh();
+    ComputeTableStats.Result results =
+        actions.computeTableStats(table).columns(columnName).execute();
+    assertNotNull(results);
+
+    // JUnit assertions take (expected, actual); this order keeps failure messages accurate.
+    List<StatisticsFile> statisticsFiles = table.statisticsFiles();
+    Assertions.assertEquals(1, statisticsFiles.size());
+
+    StatisticsFile statisticsFile = statisticsFiles.get(0);
+    assertNotEquals(0, statisticsFile.fileSizeInBytes());
+    Assertions.assertEquals(1, statisticsFile.blobMetadata().size());
+
+    BlobMetadata blobMetadata = statisticsFile.blobMetadata().get(0);
+    Assertions.assertNotNull(
+        blobMetadata.properties().get(APACHE_DATASKETCHES_THETA_V1_NDV_PROPERTY));
+  }
+
+  // Builds a single record shaped like SCHEMA_WITH_NESTED_COLUMN, assigning fields by position:
+  // the leaf struct holds (0L, 0.0), it is wrapped in the nested struct, and the top-level
+  // record holds that nested struct followed by the string "data".
+  private GenericRecord createNestedRecord() {
+    GenericRecord leafStruct = GenericRecord.create(LEAF_STRUCT_TYPE);
+    leafStruct.set(0, 0L);
+    leafStruct.set(1, 0.0);
+
+    GenericRecord nestedStruct = GenericRecord.create(NESTED_STRUCT_TYPE);
+    nestedStruct.set(0, leafStruct);
+
+    GenericRecord topLevel = GenericRecord.create(SCHEMA_WITH_NESTED_COLUMN);
+    topLevel.set(0, nestedStruct);
+    topLevel.set(1, "data");
+    return topLevel;
+  }
+
+  // Generates rows for the schema with RandomData.generateSpark(schema, 10, 0) (presumably 10
+  // rows with seed 0; confirm against RandomData) and exposes them as a DataFrame whose Spark
+  // schema is derived from the Iceberg schema.
+  private Dataset<Row> randomDataDF(Schema schema) {
+    List<InternalRow> generatedRows = Lists.newArrayList(RandomData.generateSpark(schema, 10, 0));
+    StructType sparkSchema = SparkSchemaUtil.convert(schema);
+    JavaRDD<InternalRow> javaRdd = sparkContext.parallelize(generatedRows);
+    return spark.internalCreateDataFrame(JavaRDD.toRDD(javaRdd), sparkSchema, false);
+  }
+
+  // Appends the DataFrame to the named table after coalescing it to one partition. Fanout
+  // writes are enabled because write-time clustering is not supported without Spark extensions.
+  private void append(String table, Dataset<Row> df) throws NoSuchTableException {
+    Dataset<Row> singlePartition = df.coalesce(1);
+    singlePartition.writeTo(table).option(SparkWriteOptions.FANOUT_ENABLED, "true").append();
+  }
+
+  /** Drops the table used by the test, if it exists, after every test method. */
+  @AfterEach
+  public void removeTable() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+}


Reply via email to