This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 3b9dd9b0f [spark] Support report partitioning to eliminate shuffle exchange (#3912)
3b9dd9b0f is described below

commit 3b9dd9b0fc00566ccdb8037a33702e7d4b19e3fb
Author: Xiduo You <[email protected]>
AuthorDate: Thu Aug 8 22:18:16 2024 +0800

    [spark] Support report partitioning to eliminate shuffle exchange (#3912)
---
 .../java/org/apache/paimon/schema/TableSchema.java |  22 +++-
 .../paimon/table/AbstractFileStoreTable.java       |   5 -
 .../java/org/apache/paimon/table/BucketSpec.java   |  65 ++++++++++
 .../paimon/table/DelegatedFileStoreTable.java      |   5 -
 .../org/apache/paimon/table/FileStoreTable.java    |  10 +-
 .../paimon/spark/catalog/SparkBaseCatalog.java     |   0
 .../apache/paimon/spark/PaimonInputPartition.scala |  16 +--
 .../paimon/spark/catalog/SparkBaseCatalog.java     |   0
 .../apache/paimon/spark/PaimonInputPartition.scala |  16 +--
 .../scala/org/apache/paimon/spark/PaimonScan.scala |   0
 .../paimon/spark/catalog/SparkBaseCatalog.java     |  43 ++++++-
 .../spark/catalog/functions/PaimonFunctions.java   |  88 +++++++++++++
 .../org/apache/paimon/spark/PaimonBaseScan.scala   |   4 +-
 .../apache/paimon/spark/PaimonInputPartition.scala |  22 +++-
 .../scala/org/apache/paimon/spark/PaimonScan.scala |  48 +++++++-
 .../org/apache/paimon/spark/PaimonStatistics.scala |   2 +-
 .../paimon/spark/sql/BucketedTableQueryTest.scala  | 136 +++++++++++++++++++++
 .../apache/paimon/spark/sql/PaimonMetricTest.scala |   2 +-
 18 files changed, 428 insertions(+), 56 deletions(-)

diff --git 
a/paimon-core/src/main/java/org/apache/paimon/schema/TableSchema.java 
b/paimon-core/src/main/java/org/apache/paimon/schema/TableSchema.java
index bcad8e92b..18bf3c893 100644
--- a/paimon-core/src/main/java/org/apache/paimon/schema/TableSchema.java
+++ b/paimon-core/src/main/java/org/apache/paimon/schema/TableSchema.java
@@ -18,6 +18,7 @@
 
 package org.apache.paimon.schema;
 
+import org.apache.paimon.CoreOptions;
 import org.apache.paimon.fs.FileIO;
 import org.apache.paimon.fs.Path;
 import org.apache.paimon.types.DataField;
@@ -66,6 +67,10 @@ public class TableSchema implements Serializable {
 
     private final List<String> primaryKeys;
 
+    private final List<String> bucketKeys;
+
+    private final int numBucket;
+
     private final Map<String, String> options;
 
     private final @Nullable String comment;
@@ -115,8 +120,13 @@ public class TableSchema implements Serializable {
         // try to trim to validate primary keys
         trimmedPrimaryKeys();
 
-        // try to validate bucket keys
-        originalBucketKeys();
+        // try to validate and initialize the bucket keys
+        List<String> tmpBucketKeys = originalBucketKeys();
+        if (tmpBucketKeys.isEmpty()) {
+            tmpBucketKeys = trimmedPrimaryKeys();
+        }
+        bucketKeys = tmpBucketKeys;
+        numBucket = CoreOptions.fromMap(options).bucket();
     }
 
     public int version() {
@@ -171,11 +181,11 @@ public class TableSchema implements Serializable {
         return options;
     }
 
+    public int numBuckets() {
+        return numBucket;
+    }
+
     public List<String> bucketKeys() {
-        List<String> bucketKeys = originalBucketKeys();
-        if (bucketKeys.isEmpty()) {
-            bucketKeys = trimmedPrimaryKeys();
-        }
         return bucketKeys;
     }
 
diff --git 
a/paimon-core/src/main/java/org/apache/paimon/table/AbstractFileStoreTable.java 
b/paimon-core/src/main/java/org/apache/paimon/table/AbstractFileStoreTable.java
index d30bd11ef..6e3c79d4d 100644
--- 
a/paimon-core/src/main/java/org/apache/paimon/table/AbstractFileStoreTable.java
+++ 
b/paimon-core/src/main/java/org/apache/paimon/table/AbstractFileStoreTable.java
@@ -150,11 +150,6 @@ abstract class AbstractFileStoreTable implements 
FileStoreTable {
         return Optional.empty();
     }
 
-    @Override
-    public BucketMode bucketMode() {
-        return store().bucketMode();
-    }
-
     @Override
     public Optional<WriteSelector> newWriteSelector() {
         switch (bucketMode()) {
diff --git a/paimon-core/src/main/java/org/apache/paimon/table/BucketSpec.java 
b/paimon-core/src/main/java/org/apache/paimon/table/BucketSpec.java
new file mode 100644
index 000000000..99ca53e04
--- /dev/null
+++ b/paimon-core/src/main/java/org/apache/paimon/table/BucketSpec.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.table;
+
+import java.util.List;
+
+/**
+ * Bucket spec holds all bucket information, which enables plan optimization during table scan.
+ *
+ * <p>If the {@code bucketMode} is {@link BucketMode#HASH_DYNAMIC}, then {@code numBuckets} is -1.
+ *
+ * @since 0.9
+ */
+public class BucketSpec {
+
+    private final BucketMode bucketMode;
+    private final List<String> bucketKeys;
+    private final int numBuckets;
+
+    public BucketSpec(BucketMode bucketMode, List<String> bucketKeys, int 
numBuckets) {
+        this.bucketMode = bucketMode;
+        this.bucketKeys = bucketKeys;
+        this.numBuckets = numBuckets;
+    }
+
+    public BucketMode getBucketMode() {
+        return bucketMode;
+    }
+
+    public List<String> getBucketKeys() {
+        return bucketKeys;
+    }
+
+    public int getNumBuckets() {
+        return numBuckets;
+    }
+
+    @Override
+    public String toString() {
+        return "BucketSpec{"
+                + "bucketMode="
+                + bucketMode
+                + ", bucketKeys="
+                + bucketKeys
+                + ", numBuckets="
+                + numBuckets
+                + '}';
+    }
+}
diff --git 
a/paimon-core/src/main/java/org/apache/paimon/table/DelegatedFileStoreTable.java
 
b/paimon-core/src/main/java/org/apache/paimon/table/DelegatedFileStoreTable.java
index 58d1caaac..243ffb754 100644
--- 
a/paimon-core/src/main/java/org/apache/paimon/table/DelegatedFileStoreTable.java
+++ 
b/paimon-core/src/main/java/org/apache/paimon/table/DelegatedFileStoreTable.java
@@ -113,11 +113,6 @@ public abstract class DelegatedFileStoreTable implements 
FileStoreTable {
         return wrapped.store();
     }
 
-    @Override
-    public BucketMode bucketMode() {
-        return wrapped.bucketMode();
-    }
-
     @Override
     public CatalogEnvironment catalogEnvironment() {
         return wrapped.catalogEnvironment();
diff --git 
a/paimon-core/src/main/java/org/apache/paimon/table/FileStoreTable.java 
b/paimon-core/src/main/java/org/apache/paimon/table/FileStoreTable.java
index 61fe816ac..ed1ba1da5 100644
--- a/paimon-core/src/main/java/org/apache/paimon/table/FileStoreTable.java
+++ b/paimon-core/src/main/java/org/apache/paimon/table/FileStoreTable.java
@@ -59,6 +59,14 @@ public interface FileStoreTable extends DataTable {
         return schema().primaryKeys();
     }
 
+    default BucketSpec bucketSpec() {
+        return new BucketSpec(bucketMode(), schema().bucketKeys(), 
schema().numBuckets());
+    }
+
+    default BucketMode bucketMode() {
+        return store().bucketMode();
+    }
+
     @Override
     default Map<String, String> options() {
         return schema().options();
@@ -73,8 +81,6 @@ public interface FileStoreTable extends DataTable {
 
     FileStore<?> store();
 
-    BucketMode bucketMode();
-
     CatalogEnvironment catalogEnvironment();
 
     @Override
diff --git 
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/SparkBaseCatalog.java
 
b/paimon-spark/paimon-spark-3.1/src/main/java/org/apache/paimon/spark/catalog/SparkBaseCatalog.java
similarity index 100%
copy from 
paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/SparkBaseCatalog.java
copy to 
paimon-spark/paimon-spark-3.1/src/main/java/org/apache/paimon/spark/catalog/SparkBaseCatalog.java
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala
 
b/paimon-spark/paimon-spark-3.1/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala
similarity index 72%
copy from 
paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala
copy to 
paimon-spark/paimon-spark-3.1/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala
index 894405fd7..49bc71e93 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala
+++ 
b/paimon-spark/paimon-spark-3.1/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala
@@ -20,16 +20,6 @@ package org.apache.paimon.spark
 
 import org.apache.paimon.table.source.Split
 
-import org.apache.spark.sql.connector.read.InputPartition
-
-case class PaimonInputPartition(splits: Seq[Split]) extends InputPartition {
-  def rowCount(): Long = {
-    splits.map(_.rowCount()).sum
-  }
-}
-
-object PaimonInputPartition {
-  def apply(split: Split): PaimonInputPartition = {
-    PaimonInputPartition(Seq(split))
-  }
-}
+// Never used in this Spark version.
+case class PaimonBucketedInputPartition(splits: Seq[Split], bucket: Int)
+  extends PaimonInputPartition
diff --git 
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/SparkBaseCatalog.java
 
b/paimon-spark/paimon-spark-3.2/src/main/java/org/apache/paimon/spark/catalog/SparkBaseCatalog.java
similarity index 100%
copy from 
paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/SparkBaseCatalog.java
copy to 
paimon-spark/paimon-spark-3.2/src/main/java/org/apache/paimon/spark/catalog/SparkBaseCatalog.java
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala
 
b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala
similarity index 72%
copy from 
paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala
copy to 
paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala
index 894405fd7..49bc71e93 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala
+++ 
b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala
@@ -20,16 +20,6 @@ package org.apache.paimon.spark
 
 import org.apache.paimon.table.source.Split
 
-import org.apache.spark.sql.connector.read.InputPartition
-
-case class PaimonInputPartition(splits: Seq[Split]) extends InputPartition {
-  def rowCount(): Long = {
-    splits.map(_.rowCount()).sum
-  }
-}
-
-object PaimonInputPartition {
-  def apply(split: Split): PaimonInputPartition = {
-    PaimonInputPartition(Seq(split))
-  }
-}
+// Never used in this Spark version.
+case class PaimonBucketedInputPartition(splits: Seq[Split], bucket: Int)
+  extends PaimonInputPartition
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
 
b/paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
similarity index 100%
copy from 
paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
copy to 
paimon-spark/paimon-spark-3.2/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
diff --git 
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/SparkBaseCatalog.java
 
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/SparkBaseCatalog.java
index 55670a594..b5a513564 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/SparkBaseCatalog.java
+++ 
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/SparkBaseCatalog.java
@@ -22,16 +22,32 @@ import org.apache.paimon.catalog.Catalog;
 import org.apache.paimon.spark.SparkProcedures;
 import org.apache.paimon.spark.SparkSource;
 import org.apache.paimon.spark.analysis.NoSuchProcedureException;
+import org.apache.paimon.spark.catalog.functions.PaimonFunctions;
 import org.apache.paimon.spark.procedure.Procedure;
 import org.apache.paimon.spark.procedure.ProcedureBuilder;
 
+import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableMap;
+
+import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException;
+import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
+import org.apache.spark.sql.connector.catalog.FunctionCatalog;
 import org.apache.spark.sql.connector.catalog.Identifier;
 import org.apache.spark.sql.connector.catalog.SupportsNamespaces;
 import org.apache.spark.sql.connector.catalog.TableCatalog;
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
+
+import java.util.Arrays;
+import java.util.Map;
+
+import scala.Option;
 
 /** Spark base catalog. */
 public abstract class SparkBaseCatalog
-        implements TableCatalog, SupportsNamespaces, ProcedureCatalog, 
WithPaimonCatalog {
+        implements TableCatalog,
+                FunctionCatalog,
+                SupportsNamespaces,
+                ProcedureCatalog,
+                WithPaimonCatalog {
 
     protected String catalogName;
 
@@ -54,4 +70,29 @@ public abstract class SparkBaseCatalog
     public boolean usePaimon(String provider) {
         return provider == null || 
SparkSource.NAME().equalsIgnoreCase(provider);
     }
+
+    // --------------------- Function Catalog Methods 
----------------------------
+    private static final Map<String, UnboundFunction> FUNCTIONS =
+            ImmutableMap.of("bucket", new PaimonFunctions.BucketFunction());
+
+    @Override
+    public UnboundFunction loadFunction(Identifier ident) throws 
NoSuchFunctionException {
+        UnboundFunction func = FUNCTIONS.get(ident.name());
+        if (func == null) {
+            throw new NoSuchFunctionException(
+                    "Function " + ident + " is not a paimon function", 
Option.empty());
+        }
+        return func;
+    }
+
+    @Override
+    public Identifier[] listFunctions(String[] namespace) throws 
NoSuchNamespaceException {
+        if (namespace.length != 0) {
+            throw new NoSuchNamespaceException(
+                    "Namespace " + Arrays.toString(namespace) + " is not 
valid", Option.empty());
+        }
+        return FUNCTIONS.keySet().stream()
+                .map(name -> Identifier.of(namespace, name))
+                .toArray(Identifier[]::new);
+    }
 }
diff --git 
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/PaimonFunctions.java
 
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/PaimonFunctions.java
new file mode 100644
index 000000000..d346f7f24
--- /dev/null
+++ 
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/catalog/functions/PaimonFunctions.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.catalog.functions;
+
+import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+import static org.apache.paimon.utils.Preconditions.checkArgument;
+import static org.apache.spark.sql.types.DataTypes.IntegerType;
+
+/**
+ * Functions here should only be used for resolution, e.g., for {@link
+ * org.apache.spark.sql.connector.read.SupportsReportPartitioning}.
+ */
+public class PaimonFunctions {
+    /**
+     * For now, we only support reporting bucket partitioning for table scans, so a query like
+     * {@code SELECT bucket(10, col)} would fail since we do not implement {@link
+     * org.apache.spark.sql.connector.catalog.functions.ScalarFunction}.
+     */
+    public static class BucketFunction implements UnboundFunction {
+        @Override
+        public BoundFunction bind(StructType inputType) {
+            if (inputType.size() != 2) {
+                throw new UnsupportedOperationException(
+                        "Wrong number of inputs (expected numBuckets and 
value)");
+            }
+
+            StructField numBucket = inputType.fields()[0];
+            StructField bucketField = inputType.fields()[1];
+            checkArgument(
+                    numBucket.dataType() == IntegerType,
+                    "bucket number field must be integer type");
+
+            return new BoundFunction() {
+                @Override
+                public DataType[] inputTypes() {
+                    return new DataType[] {IntegerType, 
bucketField.dataType()};
+                }
+
+                @Override
+                public DataType resultType() {
+                    return IntegerType;
+                }
+
+                @Override
+                public String name() {
+                    return "bucket";
+                }
+
+                @Override
+                public String canonicalName() {
+                    // Spark uses canonical names to check whether two functions are equivalent
+                    return "paimon.bucket(" + 
bucketField.dataType().catalogString() + ", int)";
+                }
+            };
+        }
+
+        @Override
+        public String description() {
+            return name();
+        }
+
+        @Override
+        public String name() {
+            return "bucket";
+        }
+    }
+}
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
index ac0018297..6fdfe4f6a 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonBaseScan.scala
@@ -105,7 +105,7 @@ abstract class PaimonBaseScan(
       .toArray
   }
 
-  def getInputPartitions: Seq[PaimonInputPartition] = {
+  final def lazyInputPartitions: Seq[PaimonInputPartition] = {
     if (inputPartitions == null) {
       inputPartitions = getInputPartitions(getOriginSplits)
     }
@@ -118,7 +118,7 @@ abstract class PaimonBaseScan(
 
   override def toBatch: Batch = {
     val metadataColumns = metadataFields.map(field => 
PaimonMetadataColumn.get(field.name))
-    PaimonBatch(getInputPartitions, readBuilder, metadataColumns)
+    PaimonBatch(lazyInputPartitions, readBuilder, metadataColumns)
   }
 
   override def toMicroBatchStream(checkpointLocation: String): 
MicroBatchStream = {
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala
index 894405fd7..a7c33b21d 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonInputPartition.scala
@@ -20,16 +20,32 @@ package org.apache.paimon.spark
 
 import org.apache.paimon.table.source.Split
 
-import org.apache.spark.sql.connector.read.InputPartition
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition, 
SupportsReportPartitioning}
+
+trait PaimonInputPartition extends InputPartition {
+  def splits: Seq[Split]
 
-case class PaimonInputPartition(splits: Seq[Split]) extends InputPartition {
   def rowCount(): Long = {
     splits.map(_.rowCount()).sum
   }
 }
 
+case class SimplePaimonInputPartition(splits: Seq[Split]) extends 
PaimonInputPartition
 object PaimonInputPartition {
   def apply(split: Split): PaimonInputPartition = {
-    PaimonInputPartition(Seq(split))
+    SimplePaimonInputPartition(Seq(split))
+  }
+
+  def apply(splits: Seq[Split]): PaimonInputPartition = {
+    SimplePaimonInputPartition(splits)
   }
 }
+
+/** Bucketed input partitions should be used together with [[SupportsReportPartitioning]]. */
+case class PaimonBucketedInputPartition(splits: Seq[Split], bucket: Int)
+  extends PaimonInputPartition
+  with HasPartitionKey {
+  override def partitionKey(): InternalRow = new 
GenericInternalRow(Array(bucket.asInstanceOf[Any]))
+}
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
index f0476cf70..f34a24991 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonScan.scala
@@ -19,11 +19,13 @@
 package org.apache.paimon.spark
 
 import org.apache.paimon.predicate.Predicate
-import org.apache.paimon.table.Table
+import org.apache.paimon.table.{BucketMode, FileStoreTable, Table}
+import org.apache.paimon.table.source.{DataSplit, Split}
 
 import org.apache.spark.sql.PaimonUtils.fieldReference
-import org.apache.spark.sql.connector.expressions.NamedReference
-import org.apache.spark.sql.connector.read.SupportsRuntimeFiltering
+import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference}
+import org.apache.spark.sql.connector.read.{SupportsReportPartitioning, 
SupportsRuntimeFiltering}
+import 
org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, 
Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.sources.{Filter, In}
 import org.apache.spark.sql.types.StructType
 
@@ -36,7 +38,45 @@ case class PaimonScan(
     reservedFilters: Seq[Filter],
     pushDownLimit: Option[Int])
   extends PaimonBaseScan(table, requiredSchema, filters, reservedFilters, 
pushDownLimit)
-  with SupportsRuntimeFiltering {
+  with SupportsRuntimeFiltering
+  with SupportsReportPartitioning {
+
+  override def outputPartitioning(): Partitioning = {
+    table match {
+      case fileStoreTable: FileStoreTable =>
+        val bucketSpec = fileStoreTable.bucketSpec()
+        if (bucketSpec.getBucketMode != BucketMode.HASH_FIXED) {
+          new UnknownPartitioning(0)
+        } else if (bucketSpec.getBucketKeys.size() > 1) {
+          new UnknownPartitioning(0)
+        } else {
+          // Spark does not support bucketing on several input attributes,
+          // so we only support the single-bucket-key case.
+          assert(bucketSpec.getNumBuckets > 0)
+          assert(bucketSpec.getBucketKeys.size() == 1)
+          val key = Expressions.bucket(bucketSpec.getNumBuckets, 
bucketSpec.getBucketKeys.get(0))
+          new KeyGroupedPartitioning(Array(key), lazyInputPartitions.size)
+        }
+
+      case _ =>
+        new UnknownPartitioning(0)
+    }
+  }
+
+  override def getInputPartitions(splits: Array[Split]): 
Seq[PaimonInputPartition] = {
+    if (!conf.v2BucketingEnabled || splits.exists(!_.isInstanceOf[DataSplit])) 
{
+      return super.getInputPartitions(splits)
+    }
+
+    splits
+      .map(_.asInstanceOf[DataSplit])
+      .groupBy(_.bucket())
+      .map {
+        case (bucket, groupedSplits) =>
+          PaimonBucketedInputPartition(groupedSplits, bucket)
+      }
+      .toSeq
+  }
 
   override def filterAttributes(): Array[NamedReference] = {
     val requiredFields = readBuilder.readType().getFieldNames.asScala
diff --git 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala
 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala
index b1d66c90f..da9239dd0 100644
--- 
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala
+++ 
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/PaimonStatistics.scala
@@ -34,7 +34,7 @@ import scala.collection.JavaConverters._
 
 case class PaimonStatistics[T <: PaimonBaseScan](scan: T) extends Statistics {
 
-  private lazy val rowCount: Long = 
scan.getInputPartitions.map(_.rowCount()).sum
+  private lazy val rowCount: Long = 
scan.lazyInputPartitions.map(_.rowCount()).sum
 
   private lazy val scannedTotalSize: Long = rowCount * 
scan.readSchema().defaultSize
 
diff --git 
a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/BucketedTableQueryTest.scala
 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/BucketedTableQueryTest.scala
new file mode 100644
index 000000000..f7faeabdc
--- /dev/null
+++ 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/BucketedTableQueryTest.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.spark.PaimonSparkTestBase
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
+
+class BucketedTableQueryTest extends PaimonSparkTestBase with 
AdaptiveSparkPlanHelper {
+  private def checkAnswerAndShuffle(query: String, numShuffle: Int): Unit = {
+    var expectedResult: Array[Row] = null
+    // Set these configs explicitly so future changes to their defaults do not affect the test
+    withSQLConf(
+      "spark.sql.sources.v2.bucketing.enabled" -> "false",
+      "spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+      expectedResult = spark.sql(query).collect()
+    }
+    withSQLConf(
+      "spark.sql.sources.v2.bucketing.enabled" -> "true",
+      "spark.sql.autoBroadcastJoinThreshold" -> "-1") {
+      val df = spark.sql(query)
+      checkAnswer(df, expectedResult.toSeq)
+      assert(collect(df.queryExecution.executedPlan) {
+        case shuffle: ShuffleExchangeLike => shuffle
+      }.size == numShuffle)
+    }
+  }
+
+  test("Query on a bucketed table - join - positive case") {
+    assume(gteqSpark3_3)
+
+    withTable("t1", "t2", "t3", "t4") {
+      spark.sql(
+        "CREATE TABLE t1 (id INT, c STRING) TBLPROPERTIES ('primary-key' = 
'id', 'bucket'='10')")
+      spark.sql("INSERT INTO t1 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 
'x4'), (5, 'x5')")
+
+      // all matched
+      spark.sql(
+        "CREATE TABLE t2 (id INT, c STRING) TBLPROPERTIES ('primary-key' = 
'id', 'bucket'='10')")
+      spark.sql("INSERT INTO t2 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 
'x4'), (5, 'x5')")
+      checkAnswerAndShuffle("SELECT * FROM t1 JOIN t2 on t1.id = t2.id", 0)
+
+      // a different primary-key name does not matter
+      spark.sql(
+        "CREATE TABLE t3 (id2 INT, c STRING) TBLPROPERTIES ('primary-key' = 
'id2', 'bucket'='10')")
+      spark.sql("INSERT INTO t3 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 
'x4'), (5, 'x5')")
+      checkAnswerAndShuffle("SELECT * FROM t1 JOIN t3 on t1.id = t3.id2", 0)
+
+      // one primary-key table and one bucketed table
+      spark.sql(
+        "CREATE TABLE t4 (id INT, c STRING) TBLPROPERTIES ('bucket-key' = 
'id', 'bucket'='10')")
+      spark.sql("INSERT INTO t4 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 
'x4'), (5, 'x5')")
+      checkAnswerAndShuffle("SELECT * FROM t1 JOIN t4 on t1.id = t4.id", 0)
+    }
+  }
+
+  test("Query on a bucketed table - join - negative case") {
+    assume(gteqSpark3_3)
+
+    withTable("t1", "t2", "t3", "t4", "t5", "t6") {
+      spark.sql(
+        "CREATE TABLE t1 (id INT, c STRING) TBLPROPERTIES ('primary-key' = 
'id', 'bucket'='10')")
+      spark.sql("INSERT INTO t1 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 
'x4'), (5, 'x5')")
+
+      // dynamic bucket number
+      spark.sql("CREATE TABLE t2 (id INT, c STRING) TBLPROPERTIES 
('primary-key' = 'id')")
+      spark.sql("INSERT INTO t2 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 
'x4'), (5, 'x5')")
+      checkAnswerAndShuffle("SELECT * FROM t1 JOIN t2 on t1.id = t2.id", 2)
+
+      // different bucket number
+      spark.sql(
+        "CREATE TABLE t3 (id INT, c STRING) TBLPROPERTIES ('primary-key' = 
'id', 'bucket'='2')")
+      spark.sql("INSERT INTO t3 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 
'x4'), (5, 'x5')")
+      checkAnswerAndShuffle("SELECT * FROM t1 JOIN t3 on t1.id = t3.id", 2)
+
+      // different primary-key data type
+      spark.sql(
+        "CREATE TABLE t4 (id STRING, c STRING) TBLPROPERTIES ('primary-key' = 
'id', 'bucket'='10')")
+      spark.sql("INSERT INTO t4 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 
'x4'), (5, 'x5')")
+      checkAnswerAndShuffle("SELECT * FROM t1 JOIN t4 on t1.id = t4.id", 2)
+
+      // different input partition number
+      spark.sql(
+        "CREATE TABLE t5 (id INT, c STRING) TBLPROPERTIES ('primary-key' = 
'id', 'bucket'='10')")
+      spark.sql("INSERT INTO t5 VALUES (1, 'x1')")
+      checkAnswerAndShuffle("SELECT * FROM t1 JOIN t5 on t1.id = t5.id", 2)
+
+      // more than one bucket key
+      spark.sql(
+        "CREATE TABLE t6 (id1 INT, id2 INT, c STRING) TBLPROPERTIES 
('bucket-key' = 'id1,id2', 'bucket'='10')")
+      spark.sql(
+        "INSERT INTO t6 VALUES (1, 1, 'x1'), (2, 2, 'x3'), (3, 3, 'x3'), (4, 
4, 'x4'), (5, 5, 'x5')")
+      checkAnswerAndShuffle("SELECT * FROM t1 JOIN t6 on t1.id = t6.id1", 2)
+    }
+  }
+
+  test("Query on a bucketed table - other operators") {
+    assume(gteqSpark3_3)
+
+    withTable("t1") {
+      spark.sql(
+        "CREATE TABLE t1 (id INT, c STRING) TBLPROPERTIES ('primary-key' = 
'id', 'bucket'='10')")
+      spark.sql("INSERT INTO t1 VALUES (1, 'x1'), (2, 'x3'), (3, 'x3'), (4, 
'x4'), (5, 'x5')")
+
+      checkAnswerAndShuffle("SELECT id, count(*) FROM t1 GROUP BY id", 0)
+      checkAnswerAndShuffle("SELECT c, count(*) FROM t1 GROUP BY c", 1)
+      checkAnswerAndShuffle("select sum(c) OVER (PARTITION BY id ORDER BY c) 
from t1", 0)
+      checkAnswerAndShuffle("select sum(id) OVER (PARTITION BY c ORDER BY id) 
from t1", 1)
+
+      withSQLConf("spark.sql.requireAllClusterKeysForDistribution" -> "false") 
{
+        checkAnswerAndShuffle("SELECT id, c, count(*) FROM t1 GROUP BY id, c", 
0)
+      }
+      withSQLConf("spark.sql.requireAllClusterKeysForDistribution" -> "true") {
+        checkAnswerAndShuffle("SELECT id, c, count(*) FROM t1 GROUP BY id, c", 
1)
+      }
+    }
+  }
+}
diff --git 
a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonMetricTest.scala
 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonMetricTest.scala
index 99ba335a7..f223dabdd 100644
--- 
a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonMetricTest.scala
+++ 
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/PaimonMetricTest.scala
@@ -42,7 +42,7 @@ class PaimonMetricTest extends PaimonSparkTestBase {
       def checkMetrics(s: String, skippedTableFiles: Long, resultedTableFiles: 
Long): Unit = {
         val scan = getPaimonScan(s)
         // call getInputPartitions to trigger scan
-        scan.getInputPartitions
+        scan.lazyInputPartitions
         val metrics = scan.reportDriverMetrics()
         Assertions.assertEquals(skippedTableFiles, metric(metrics, 
SKIPPED_TABLE_FILES))
         Assertions.assertEquals(resultedTableFiles, metric(metrics, 
RESULTED_TABLE_FILES))

Reply via email to