This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new f63c6a6aa feat: implement framework to support multiple pyspark 
benchmarks (#3080)
f63c6a6aa is described below

commit f63c6a6aab1b61e7fa9436e3ed0325675b4694ba
Author: Andy Grove <[email protected]>
AuthorDate: Tue Jan 20 14:58:55 2026 -0700

    feat: implement framework to support multiple pyspark benchmarks (#3080)
---
 .gitignore                                |   1 +
 benchmarks/pyspark/README.md              | 111 +++++++++++++++++++++----
 benchmarks/pyspark/benchmarks/__init__.py |  79 ++++++++++++++++++
 benchmarks/pyspark/benchmarks/base.py     | 127 +++++++++++++++++++++++++++++
 benchmarks/pyspark/benchmarks/shuffle.py  | 130 ++++++++++++++++++++++++++++++
 benchmarks/pyspark/run_benchmark.py       | 109 +++++++++++++------------
 6 files changed, 491 insertions(+), 66 deletions(-)

diff --git a/.gitignore b/.gitignore
index 82f8de95d..9978e37bd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,3 +20,4 @@ dev/release/comet-rm/workdir
 spark/benchmarks
 .DS_Store
 comet-event-trace.json
+__pycache__
diff --git a/benchmarks/pyspark/README.md b/benchmarks/pyspark/README.md
index 130870081..3fc55123f 100644
--- a/benchmarks/pyspark/README.md
+++ b/benchmarks/pyspark/README.md
@@ -17,9 +17,16 @@ specific language governing permissions and limitations
 under the License.
 -->
 
-# Shuffle Size Comparison Benchmark
+# PySpark Benchmarks
 
-Compares shuffle file sizes between Spark, Comet JVM, and Comet Native shuffle 
implementations.
+A suite of PySpark benchmarks for comparing performance between Spark, Comet JVM, and Comet Native implementations.
+
+## Available Benchmarks
+
+Run `python run_benchmark.py --list-benchmarks` to see all available benchmarks:
+
+- **shuffle-hash** - Shuffle all columns using hash partitioning on group_key
+- **shuffle-roundrobin** - Shuffle all columns using round-robin partitioning
 
 ## Prerequisites
 
@@ -56,42 +63,116 @@ spark-submit \
 | `--rows`, `-r`       | 10000000   | Number of rows               |
 | `--partitions`, `-p` | 200        | Number of output partitions  |
 
-## Step 2: Run Benchmark
+## Step 2: Run Benchmarks
 
-Run benchmarks and check Spark UI for shuffle sizes:
+### List Available Benchmarks
 
 ```bash
-SPARK_MASTER=spark://master:7077 \
-EXECUTOR_MEMORY=16g \
-./run_all_benchmarks.sh /tmp/shuffle-benchmark-data
+python run_benchmark.py --list-benchmarks
 ```
 
-Or run individual modes:
+### Run Individual Benchmarks
+
+You can run specific benchmarks by name:
 
 ```bash
-# Spark baseline
+# Hash partitioning shuffle - Spark baseline
 spark-submit --master spark://master:7077 \
-  run_benchmark.py --data /tmp/shuffle-benchmark-data --mode spark
+  run_benchmark.py --data /tmp/shuffle-benchmark-data --mode spark --benchmark shuffle-hash
 
-# Comet JVM shuffle
+# Round-robin shuffle - Spark baseline
+spark-submit --master spark://master:7077 \
+  run_benchmark.py --data /tmp/shuffle-benchmark-data --mode spark --benchmark shuffle-roundrobin
+
+# Hash partitioning - Comet JVM shuffle
 spark-submit --master spark://master:7077 \
   --jars /path/to/comet.jar \
   --conf spark.comet.enabled=true \
   --conf spark.comet.exec.shuffle.enabled=true \
-  --conf spark.comet.shuffle.mode=jvm \
+  --conf spark.comet.exec.shuffle.mode=jvm \
   --conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \
-  run_benchmark.py --data /tmp/shuffle-benchmark-data --mode jvm
+  run_benchmark.py --data /tmp/shuffle-benchmark-data --mode jvm --benchmark shuffle-hash
 
-# Comet Native shuffle
+# Round-robin - Comet Native shuffle
 spark-submit --master spark://master:7077 \
   --jars /path/to/comet.jar \
   --conf spark.comet.enabled=true \
   --conf spark.comet.exec.shuffle.enabled=true \
-  --conf spark.comet.shuffle.mode=native \
+  --conf spark.comet.exec.shuffle.mode=native \
   --conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \
-  run_benchmark.py --data /tmp/shuffle-benchmark-data --mode native
+  run_benchmark.py --data /tmp/shuffle-benchmark-data --mode native --benchmark shuffle-roundrobin
+```
+
+### Run All Benchmarks
+
+Use the provided script to run all benchmarks across all modes:
+
+```bash
+SPARK_MASTER=spark://master:7077 \
+EXECUTOR_MEMORY=16g \
+./run_all_benchmarks.sh /tmp/shuffle-benchmark-data
 ```
 
 ## Checking Results
 
 Open the Spark UI (default: http://localhost:4040) during each benchmark run to compare shuffle write sizes in the Stages tab.
+
+## Adding New Benchmarks
+
+The benchmark framework makes it easy to add new benchmarks:
+
+1. **Create a benchmark class** in the `benchmarks/` directory (or add it to an existing file):
+
+```python
+from benchmarks.base import Benchmark
+
+class MyBenchmark(Benchmark):
+    @classmethod
+    def name(cls) -> str:
+        return "my-benchmark"
+
+    @classmethod
+    def description(cls) -> str:
+        return "Description of what this benchmark does"
+
+    def run(self) -> dict:
+        # Read data
+        df = self.spark.read.parquet(self.data_path)
+
+        # Run your benchmark operation
+        def benchmark_operation():
+            result = df.filter(...).groupBy(...).agg(...)
+            result.write.mode("overwrite").parquet("/tmp/output")
+
+        # Time it
+        duration_ms = self._time_operation(benchmark_operation)
+
+        return {
+            'duration_ms': duration_ms,
+            # Add any other metrics you want to track
+        }
+```
+
+2. **Register the benchmark** in `benchmarks/__init__.py`:
+
+```python
+from .my_module import MyBenchmark
+
+_BENCHMARK_REGISTRY = {
+    # ... existing benchmarks
+    MyBenchmark.name(): MyBenchmark,
+}
+```
+
+3. **Run your new benchmark**:
+
+```bash
+python run_benchmark.py --data /path/to/data --mode spark --benchmark my-benchmark
+```
+
+The base `Benchmark` class provides:
+
+- Automatic timing via `_time_operation()`
+- Standard output formatting via `execute_timed()`
+- Access to SparkSession, data path, and mode
+- Spark configuration printing
diff --git a/benchmarks/pyspark/benchmarks/__init__.py 
b/benchmarks/pyspark/benchmarks/__init__.py
new file mode 100644
index 000000000..7d913a7d6
--- /dev/null
+++ b/benchmarks/pyspark/benchmarks/__init__.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Benchmark registry for PySpark benchmarks.
+
+This module provides a central registry for discovering and running benchmarks.
+"""
+
+from typing import Dict, Type, List
+
+from .base import Benchmark
+from .shuffle import ShuffleHashBenchmark, ShuffleRoundRobinBenchmark
+
+
+# Registry of all available benchmarks, keyed by each class's name() (the CLI --benchmark value)
+_BENCHMARK_REGISTRY: Dict[str, Type[Benchmark]] = {
+    ShuffleHashBenchmark.name(): ShuffleHashBenchmark,
+    ShuffleRoundRobinBenchmark.name(): ShuffleRoundRobinBenchmark,
+}
+
+
+def get_benchmark(name: str) -> Type[Benchmark]:
+    """
+    Look up a registered benchmark class by its CLI name.
+
+    Args:
+        name: Value returned by the benchmark class's name() classmethod
+
+    Returns:
+        The registered Benchmark subclass (a class, not an instance)
+
+    Raises:
+        KeyError: If no benchmark is registered under ``name``
+    """
+    benchmark_cls = _BENCHMARK_REGISTRY.get(name)
+    if benchmark_cls is None:
+        known = ", ".join(sorted(_BENCHMARK_REGISTRY))
+        message = f"Unknown benchmark: {name}. Available benchmarks: {known}"
+        raise KeyError(message)
+    return benchmark_cls
+
+
+def list_benchmarks() -> List[tuple[str, str]]:
+    """
+    Describe every registered benchmark.
+
+    Returns:
+        List of (name, description) tuples, sorted by name
+        (the order in which ``run_benchmark.py --list-benchmarks`` prints them)
+    """
+    return [
+        (bench_name, _BENCHMARK_REGISTRY[bench_name].description())
+        for bench_name in sorted(_BENCHMARK_REGISTRY)
+    ]
+
+
+__all__ = [  # public API: the names exported by "from benchmarks import *"
+    'Benchmark',
+    'get_benchmark',
+    'list_benchmarks',
+    'ShuffleHashBenchmark',
+    'ShuffleRoundRobinBenchmark',
+]
diff --git a/benchmarks/pyspark/benchmarks/base.py 
b/benchmarks/pyspark/benchmarks/base.py
new file mode 100644
index 000000000..7e8e8db5a
--- /dev/null
+++ b/benchmarks/pyspark/benchmarks/base.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Base benchmark class providing common functionality for all benchmarks.
+"""
+
+import time
+from abc import ABC, abstractmethod
+from typing import Dict, Any
+
+from pyspark.sql import SparkSession
+
+
+class Benchmark(ABC):
+    """Abstract base for PySpark benchmarks: subclasses supply name, description and run()."""
+
+    def __init__(self, spark: SparkSession, data_path: str, mode: str):
+        """
+        Initialize benchmark.
+
+        Args:
+            spark: SparkSession instance
+            data_path: Path to input data
+            mode: Execution mode label (spark, jvm, native); printed in the header
+        """
+        self.spark = spark
+        self.data_path = data_path
+        self.mode = mode
+
+    @classmethod
+    @abstractmethod
+    def name(cls) -> str:
+        """Return the unique benchmark name (the CLI ``--benchmark`` value)."""
+        pass
+
+    @classmethod
+    @abstractmethod
+    def description(cls) -> str:
+        """Return a one-line description (shown by ``--list-benchmarks``)."""
+        pass
+
+    @abstractmethod
+    def run(self) -> Dict[str, Any]:
+        """
+        Run the benchmark and return results.
+
+        Returns:
+            Metrics dict; must include 'duration_ms'. Any other keys are printed as extra metrics.
+        """
+        pass
+
+    def execute_timed(self) -> Dict[str, Any]:
+        """
+        Print a header and Spark config, clear the cache, call run(), then print its metrics.
+
+        Returns:
+            The dictionary returned by run(), unchanged
+        """
+        print(f"\n{'=' * 80}")
+        print(f"Benchmark: {self.name()}")
+        print(f"Mode: {self.mode.upper()}")
+        print(f"{'=' * 80}")
+        print(f"Data path: {self.data_path}")
+
+        # Print relevant Spark configuration
+        self._print_spark_config()
+
+        # Drop cached tables/DataFrames so earlier work in this session can't skew timing
+        self.spark.catalog.clearCache()
+
+        # Run the benchmark
+        print("\nRunning benchmark...")
+        results = self.run()
+
+        # Print results ('duration_ms' is required by the run() contract)
+        print(f"\nDuration: {results['duration_ms']:,} ms")
+        if 'row_count' in results:
+            print(f"Rows processed: {results['row_count']:,}")
+
+        # Print any additional metrics
+        for key, value in results.items():
+            if key not in ['duration_ms', 'row_count']:
+                print(f"{key}: {value}")
+
+        print(f"{'=' * 80}\n")
+
+        return results
+
+    def _print_spark_config(self):
+        """Print the shuffle/Comet settings that distinguish the spark, jvm and native modes."""
+        conf = self.spark.sparkContext.getConf()
+        print(f"Shuffle manager: {conf.get('spark.shuffle.manager', 'default')}")
+        print(f"Comet enabled: {conf.get('spark.comet.enabled', 'false')}")
+        print(f"Comet shuffle enabled: {conf.get('spark.comet.exec.shuffle.enabled', 'false')}")
+        print(f"Comet shuffle mode: {conf.get('spark.comet.exec.shuffle.mode', 'not set')}")
+        print(f"Spark UI: {self.spark.sparkContext.uiWebUrl}")
+
+    def _time_operation(self, operation_fn):
+        """
+        Time an operation and return duration in milliseconds.
+
+        Args:
+            operation_fn: Zero-argument callable to run and time
+
+        Returns:
+            Elapsed time in whole milliseconds (fractional part truncated)
+        """
+        start_time = time.perf_counter()  # monotonic; unaffected by system clock changes
+        operation_fn()
+        duration_ms = int((time.perf_counter() - start_time) * 1000)
+        return duration_ms
diff --git a/benchmarks/pyspark/benchmarks/shuffle.py 
b/benchmarks/pyspark/benchmarks/shuffle.py
new file mode 100644
index 000000000..0facd2340
--- /dev/null
+++ b/benchmarks/pyspark/benchmarks/shuffle.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Shuffle benchmarks for comparing shuffle file sizes and performance.
+
+These benchmarks test different partitioning strategies (hash, round-robin)
+across Spark, Comet JVM, and Comet Native shuffle implementations.
+"""
+
+from typing import Dict, Any
+from pyspark.sql import DataFrame
+
+from .base import Benchmark
+
+
+class ShuffleBenchmark(Benchmark):
+    """Shared driver for shuffle benchmarks; subclasses choose the partitioning."""
+
+    def __init__(self, spark, data_path: str, mode: str, num_partitions: int = 200):
+        """
+        Set up a shuffle benchmark.
+
+        Args:
+            spark: Active SparkSession
+            data_path: Location of the input parquet dataset
+            mode: Execution mode label (spark, jvm, native)
+            num_partitions: Partition count to shuffle into
+        """
+        super().__init__(spark, data_path, mode)
+        self.num_partitions = num_partitions
+
+    def _read_and_count(self) -> tuple[DataFrame, int]:
+        """Load the parquet input; return it along with its row count."""
+        source_df = self.spark.read.parquet(self.data_path)
+        total = source_df.count()
+        return source_df, total
+
+    def _repartition(self, df: DataFrame) -> DataFrame:
+        """
+        Apply the subclass-specific partitioning strategy.
+
+        Args:
+            df: DataFrame to redistribute
+
+        Returns:
+            The repartitioned DataFrame (lazy until an action such as a write)
+        """
+        raise NotImplementedError("Subclasses must implement _repartition")
+
+    def _write_output(self, df: DataFrame, output_path: str):
+        """Write ``df`` as parquet to ``output_path``, overwriting earlier output."""
+        df.write.mode("overwrite").parquet(output_path)
+
+    def run(self) -> Dict[str, Any]:
+        """
+        Count the input rows, then time a repartition followed by a parquet write.
+
+        Output goes to /tmp, in a directory named after the mode and benchmark.
+
+        Returns:
+            Dict with 'duration_ms', 'row_count' and 'num_partitions'
+        """
+        source_df, total_rows = self._read_and_count()  # not part of the timing
+        print(f"Number of rows: {total_rows:,}")
+
+        # Only this closure runs inside the timer
+        def shuffle_and_write():
+            # The concrete subclass decides how rows are redistributed
+            shuffled = self._repartition(source_df)
+
+            # Writing is the action that forces the shuffle to actually execute
+            target_dir = f"/tmp/shuffle-benchmark-output-{self.mode}-{self.name()}"
+            self._write_output(shuffled, target_dir)
+            print(f"Wrote repartitioned data to: {target_dir}")
+
+        elapsed_ms = self._time_operation(shuffle_and_write)
+
+        return dict(
+            duration_ms=elapsed_ms,
+            row_count=total_rows,
+            num_partitions=self.num_partitions,
+        )
+
+
+class ShuffleHashBenchmark(ShuffleBenchmark):
+    """Shuffle benchmark that hash-partitions every row on ``group_key``."""
+
+    @classmethod
+    def name(cls) -> str:
+        return "shuffle-hash"
+
+    @classmethod
+    def description(cls) -> str:
+        return "Shuffle all columns using hash partitioning on group_key"
+
+    def _repartition(self, df: DataFrame) -> DataFrame:
+        key_column = "group_key"  # rows sharing a key land in the same partition
+        return df.repartition(self.num_partitions, key_column)
+
+
+class ShuffleRoundRobinBenchmark(ShuffleBenchmark):
+    """Shuffle benchmark that distributes rows with round-robin partitioning."""
+
+    @classmethod
+    def name(cls) -> str:
+        return "shuffle-roundrobin"
+
+    @classmethod
+    def description(cls) -> str:
+        return "Shuffle all columns using round-robin partitioning"
+
+    def _repartition(self, df: DataFrame) -> DataFrame:
+        # Omitting partition columns makes Spark distribute rows round-robin
+        return df.repartition(numPartitions=self.num_partitions)
diff --git a/benchmarks/pyspark/run_benchmark.py 
b/benchmarks/pyspark/run_benchmark.py
index 3f40b7c93..6713f0ff2 100755
--- a/benchmarks/pyspark/run_benchmark.py
+++ b/benchmarks/pyspark/run_benchmark.py
@@ -17,88 +17,95 @@
 # under the License.
 
 """
-Run shuffle size comparison benchmark.
+Run PySpark benchmarks.
 
-Run this script once per mode (spark, jvm, native) with appropriate spark-submit configs.
-Check the Spark UI to compare shuffle sizes between modes.
+Run benchmarks by name with appropriate spark-submit configs for different modes
+(spark, jvm, native). Check the Spark UI to compare results between modes.
 """
 
 import argparse
-import time
-import json
+import sys
 
 from pyspark.sql import SparkSession
 
-
-def run_benchmark(spark: SparkSession, data_path: str, mode: str) -> int:
-    """Run the benchmark query and return duration in ms."""
-
-    spark.catalog.clearCache()
-
-    df = spark.read.parquet(data_path)
-    row_count = df.count()
-    print(f"Number of rows: {row_count:,}")
-
-    start_time = time.time()
-
-    # Repartition by a different key to force full shuffle of all columns
-    # This shuffles all 50 columns including nested structs, arrays, maps
-    repartitioned = df.repartition(200, "group_key")
-
-    # Write to parquet to force materialization
-    output_path = f"/tmp/shuffle-benchmark-output-{mode}"
-    repartitioned.write.mode("overwrite").parquet(output_path)
-    print(f"Wrote repartitioned data to: {output_path}")
-
-    duration_ms = int((time.time() - start_time) * 1000)
-    return duration_ms
+from benchmarks import get_benchmark, list_benchmarks
 
 
 def main():
     parser = argparse.ArgumentParser(
-        description="Run shuffle benchmark for a single mode"
+        description="Run PySpark benchmarks",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+  # Run hash partitioning shuffle benchmark in Spark mode
+  python run_benchmark.py --data /path/to/data --mode spark --benchmark shuffle-hash
+
+  # Run round-robin shuffle benchmark in Comet native mode
+  python run_benchmark.py --data /path/to/data --mode native --benchmark shuffle-roundrobin
+
+  # List all available benchmarks
+  python run_benchmark.py --list-benchmarks
+        """
     )
     parser.add_argument(
         "--data", "-d",
-        required=True,
         help="Path to input parquet data"
     )
     parser.add_argument(
         "--mode", "-m",
-        required=True,
         choices=["spark", "jvm", "native"],
         help="Shuffle mode being tested"
     )
+    parser.add_argument(
+        "--benchmark", "-b",
+        default="shuffle-hash",
+        help="Benchmark to run (default: shuffle-hash)"
+    )
+    parser.add_argument(
+        "--list-benchmarks",
+        action="store_true",
+        help="List all available benchmarks and exit"
+    )

     args = parser.parse_args()

-    spark = SparkSession.builder \
-        .appName(f"ShuffleBenchmark-{args.mode.upper()}") \
-        .getOrCreate()
+    # --list-benchmarks only reads the registry, so answer it before starting Spark
+    if args.list_benchmarks:
+        print("Available benchmarks:\n")
+        for name, description in list_benchmarks():
+            print(f"  {name:25s} - {description}")
+        return 0

-    print("\n" + "=" * 80)
-    print(f"Shuffle Benchmark: {args.mode.upper()}")
-    print("=" * 80)
-    print(f"Data path: {args.data}")
+    # --data/--mode are not argparse-required (so --list-benchmarks works alone); enforce here
+    if not args.data:
+        parser.error("--data is required when running a benchmark")
+    if not args.mode:
+        parser.error("--mode is required when running a benchmark")

-    # Print shuffle configuration
-    conf = spark.sparkContext.getConf()
-    print(f"Shuffle manager: {conf.get('spark.shuffle.manager', 'default')}")
-    print(f"Comet enabled: {conf.get('spark.comet.enabled', 'false')}")
-    print(f"Comet shuffle enabled: {conf.get('spark.comet.exec.shuffle.enabled', 'false')}")
-    print(f"Comet shuffle mode: {conf.get('spark.comet.shuffle.mode', 'not set')}")
-    print(f"Spark UI: {spark.sparkContext.uiWebUrl}")
+    # Resolve the benchmark name before starting Spark so an unknown name fails fast
+    try:
+        benchmark_cls = get_benchmark(args.benchmark)
+    except KeyError as e:
+        print(f"Error: {e}", file=sys.stderr)
+        print("\nUse --list-benchmarks to see available benchmarks", file=sys.stderr)
+        return 1
+
+    # Create (or reuse) the Spark session; Comet settings come from spark-submit --conf
+    spark = SparkSession.builder \
+        .appName(f"{benchmark_cls.name()}-{args.mode.upper()}") \
+        .getOrCreate()

     try:
-        duration_ms = run_benchmark(spark, args.data, args.mode)
-        print(f"\nDuration: {duration_ms:,} ms")
-        print("\nCheck Spark UI for shuffle sizes")
+        # Only spark/data/mode are passed; other constructor args keep their defaults
+        benchmark = benchmark_cls(spark, args.data, args.mode)
+        benchmark.execute_timed()  # prints its own header, config and metrics
+
+        print("\nCheck Spark UI for shuffle sizes and detailed metrics")
+        return 0

     finally:
         spark.stop()

-    print("=" * 80 + "\n")
-

 if __name__ == "__main__":
-    main()
+    sys.exit(main())


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to