This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new f63c6a6aa feat: implement framework to support multiple pyspark benchmarks (#3080)
f63c6a6aa is described below
commit f63c6a6aab1b61e7fa9436e3ed0325675b4694ba
Author: Andy Grove <[email protected]>
AuthorDate: Tue Jan 20 14:58:55 2026 -0700
feat: implement framework to support multiple pyspark benchmarks (#3080)
---
.gitignore | 1 +
benchmarks/pyspark/README.md | 111 +++++++++++++++++++++----
benchmarks/pyspark/benchmarks/__init__.py | 79 ++++++++++++++++++
benchmarks/pyspark/benchmarks/base.py | 127 +++++++++++++++++++++++++++++
benchmarks/pyspark/benchmarks/shuffle.py | 130 ++++++++++++++++++++++++++++++
benchmarks/pyspark/run_benchmark.py | 109 +++++++++++++------------
6 files changed, 491 insertions(+), 66 deletions(-)
diff --git a/.gitignore b/.gitignore
index 82f8de95d..9978e37bd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,3 +20,4 @@ dev/release/comet-rm/workdir
spark/benchmarks
.DS_Store
comet-event-trace.json
+__pycache__
diff --git a/benchmarks/pyspark/README.md b/benchmarks/pyspark/README.md
index 130870081..3fc55123f 100644
--- a/benchmarks/pyspark/README.md
+++ b/benchmarks/pyspark/README.md
@@ -17,9 +17,16 @@ specific language governing permissions and limitations
under the License.
-->
-# Shuffle Size Comparison Benchmark
+# PySpark Benchmarks
-Compares shuffle file sizes between Spark, Comet JVM, and Comet Native shuffle implementations.
+A suite of PySpark benchmarks for comparing performance between Spark, Comet JVM, and Comet Native implementations.
+
+## Available Benchmarks
+
+Run `python run_benchmark.py --list-benchmarks` to see all available benchmarks:
+
+- **shuffle-hash** - Shuffle all columns using hash partitioning on group_key
+- **shuffle-roundrobin** - Shuffle all columns using round-robin partitioning
## Prerequisites
@@ -56,42 +63,116 @@ spark-submit \
| `--rows`, `-r` | 10000000 | Number of rows |
| `--partitions`, `-p` | 200 | Number of output partitions |
-## Step 2: Run Benchmark
+## Step 2: Run Benchmarks
-Run benchmarks and check Spark UI for shuffle sizes:
+### List Available Benchmarks
```bash
-SPARK_MASTER=spark://master:7077 \
-EXECUTOR_MEMORY=16g \
-./run_all_benchmarks.sh /tmp/shuffle-benchmark-data
+python run_benchmark.py --list-benchmarks
```
-Or run individual modes:
+### Run Individual Benchmarks
+
+You can run specific benchmarks by name:
```bash
-# Spark baseline
+# Hash partitioning shuffle - Spark baseline
spark-submit --master spark://master:7077 \
- run_benchmark.py --data /tmp/shuffle-benchmark-data --mode spark
+ run_benchmark.py --data /tmp/shuffle-benchmark-data --mode spark --benchmark shuffle-hash
-# Comet JVM shuffle
+# Round-robin shuffle - Spark baseline
+spark-submit --master spark://master:7077 \
+ run_benchmark.py --data /tmp/shuffle-benchmark-data --mode spark --benchmark shuffle-roundrobin
+
+# Hash partitioning - Comet JVM shuffle
spark-submit --master spark://master:7077 \
--jars /path/to/comet.jar \
--conf spark.comet.enabled=true \
--conf spark.comet.exec.shuffle.enabled=true \
--conf spark.comet.shuffle.mode=jvm \
--conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \
- run_benchmark.py --data /tmp/shuffle-benchmark-data --mode jvm
+ run_benchmark.py --data /tmp/shuffle-benchmark-data --mode jvm --benchmark shuffle-hash
-# Comet Native shuffle
+# Round-robin - Comet Native shuffle
spark-submit --master spark://master:7077 \
--jars /path/to/comet.jar \
--conf spark.comet.enabled=true \
--conf spark.comet.exec.shuffle.enabled=true \
- --conf spark.comet.shuffle.mode=native \
+ --conf spark.comet.exec.shuffle.mode=native \
--conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \
- run_benchmark.py --data /tmp/shuffle-benchmark-data --mode native
+ run_benchmark.py --data /tmp/shuffle-benchmark-data --mode native --benchmark shuffle-roundrobin
+```
+
+### Run All Benchmarks
+
+Use the provided script to run all benchmarks across all modes:
+
+```bash
+SPARK_MASTER=spark://master:7077 \
+EXECUTOR_MEMORY=16g \
+./run_all_benchmarks.sh /tmp/shuffle-benchmark-data
```
## Checking Results
Open the Spark UI (default: http://localhost:4040) during each benchmark run to compare shuffle write sizes in the Stages tab.
+
+## Adding New Benchmarks
+
+The benchmark framework makes it easy to add new benchmarks:
+
+1. **Create a benchmark class** in the `benchmarks/` directory (or add to an existing file):
+
+```python
+from typing import Dict, Any
+
+from benchmarks.base import Benchmark
+
+class MyBenchmark(Benchmark):
+ @classmethod
+ def name(cls) -> str:
+ return "my-benchmark"
+
+ @classmethod
+ def description(cls) -> str:
+ return "Description of what this benchmark does"
+
+ def run(self) -> Dict[str, Any]:
+ # Read data
+ df = self.spark.read.parquet(self.data_path)
+
+ # Run your benchmark operation
+ def benchmark_operation():
+ result = df.filter(...).groupBy(...).agg(...)
+ result.write.mode("overwrite").parquet("/tmp/output")
+
+ # Time it
+ duration_ms = self._time_operation(benchmark_operation)
+
+ return {
+ 'duration_ms': duration_ms,
+ # Add any other metrics you want to track
+ }
+```
+
+2. **Register the benchmark** in `benchmarks/__init__.py`:
+
+```python
+from .my_module import MyBenchmark
+
+_BENCHMARK_REGISTRY = {
+ # ... existing benchmarks
+ MyBenchmark.name(): MyBenchmark,
+}
+```
+
+3. **Run your new benchmark**:
+
+```bash
+python run_benchmark.py --data /path/to/data --mode spark --benchmark my-benchmark
+```
+
+The base `Benchmark` class provides:
+
+- Automatic timing via `_time_operation()`
+- Standard output formatting via `execute_timed()`
+- Access to SparkSession, data path, and mode
+- Spark configuration printing
diff --git a/benchmarks/pyspark/benchmarks/__init__.py b/benchmarks/pyspark/benchmarks/__init__.py
new file mode 100644
index 000000000..7d913a7d6
--- /dev/null
+++ b/benchmarks/pyspark/benchmarks/__init__.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Benchmark registry for PySpark benchmarks.
+
+This module provides a central registry for discovering and running benchmarks.
+"""
+
+from typing import Dict, Type, List
+
+from .base import Benchmark
+from .shuffle import ShuffleHashBenchmark, ShuffleRoundRobinBenchmark
+
+
+# Registry of all available benchmarks
+_BENCHMARK_REGISTRY: Dict[str, Type[Benchmark]] = {
+ ShuffleHashBenchmark.name(): ShuffleHashBenchmark,
+ ShuffleRoundRobinBenchmark.name(): ShuffleRoundRobinBenchmark,
+}
+
+
+def get_benchmark(name: str) -> Type[Benchmark]:
+ """
+ Get a benchmark class by name.
+
+ Args:
+ name: Benchmark name
+
+ Returns:
+ Benchmark class
+
+ Raises:
+ KeyError: If benchmark name is not found
+ """
+ if name not in _BENCHMARK_REGISTRY:
+ available = ", ".join(sorted(_BENCHMARK_REGISTRY.keys()))
+ raise KeyError(
+ f"Unknown benchmark: {name}. Available benchmarks: {available}"
+ )
+ return _BENCHMARK_REGISTRY[name]
+
+
+def list_benchmarks() -> List[tuple[str, str]]:
+ """
+ List all available benchmarks.
+
+ Returns:
+ List of (name, description) tuples
+ """
+ benchmarks = []
+ for name in sorted(_BENCHMARK_REGISTRY.keys()):
+ benchmark_cls = _BENCHMARK_REGISTRY[name]
+ benchmarks.append((name, benchmark_cls.description()))
+ return benchmarks
+
+
+__all__ = [
+ 'Benchmark',
+ 'get_benchmark',
+ 'list_benchmarks',
+ 'ShuffleHashBenchmark',
+ 'ShuffleRoundRobinBenchmark',
+]
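
For readers skimming the diff, here is a minimal sketch (not part of this commit) of how the registry above can be used programmatically. It assumes the `benchmarks` package is importable and that a `SparkSession` named `spark` already exists (for example, in a `pyspark` shell), and it reuses the data path from the README:

```python
# Illustrative only; mirrors the API added in benchmarks/__init__.py above.
from benchmarks import get_benchmark, list_benchmarks

# Print every registered benchmark with its description
for name, description in list_benchmarks():
    print(f"{name}: {description}")

# Look up a benchmark class by name (raises KeyError for unknown names) and run it
benchmark_cls = get_benchmark("shuffle-hash")
benchmark = benchmark_cls(spark, "/tmp/shuffle-benchmark-data", "spark")
results = benchmark.execute_timed()
print(results["duration_ms"])
```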
diff --git a/benchmarks/pyspark/benchmarks/base.py b/benchmarks/pyspark/benchmarks/base.py
new file mode 100644
index 000000000..7e8e8db5a
--- /dev/null
+++ b/benchmarks/pyspark/benchmarks/base.py
@@ -0,0 +1,127 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Base benchmark class providing common functionality for all benchmarks.
+"""
+
+import time
+from abc import ABC, abstractmethod
+from typing import Dict, Any
+
+from pyspark.sql import SparkSession
+
+
+class Benchmark(ABC):
+ """Base class for all PySpark benchmarks."""
+
+ def __init__(self, spark: SparkSession, data_path: str, mode: str):
+ """
+ Initialize benchmark.
+
+ Args:
+ spark: SparkSession instance
+ data_path: Path to input data
+ mode: Execution mode (spark, jvm, native)
+ """
+ self.spark = spark
+ self.data_path = data_path
+ self.mode = mode
+
+ @classmethod
+ @abstractmethod
+ def name(cls) -> str:
+ """Return the benchmark name (used for CLI)."""
+ pass
+
+ @classmethod
+ @abstractmethod
+ def description(cls) -> str:
+ """Return a short description of the benchmark."""
+ pass
+
+ @abstractmethod
+ def run(self) -> Dict[str, Any]:
+ """
+ Run the benchmark and return results.
+
+ Returns:
+            Dictionary containing benchmark results (must include 'duration_ms')
+ """
+ pass
+
+ def execute_timed(self) -> Dict[str, Any]:
+ """
+ Execute the benchmark with timing and standard output.
+
+ Returns:
+ Dictionary containing benchmark results
+ """
+ print(f"\n{'=' * 80}")
+ print(f"Benchmark: {self.name()}")
+ print(f"Mode: {self.mode.upper()}")
+ print(f"{'=' * 80}")
+ print(f"Data path: {self.data_path}")
+
+ # Print relevant Spark configuration
+ self._print_spark_config()
+
+ # Clear cache before running
+ self.spark.catalog.clearCache()
+
+ # Run the benchmark
+ print(f"\nRunning benchmark...")
+ results = self.run()
+
+ # Print results
+ print(f"\nDuration: {results['duration_ms']:,} ms")
+ if 'row_count' in results:
+ print(f"Rows processed: {results['row_count']:,}")
+
+ # Print any additional metrics
+ for key, value in results.items():
+ if key not in ['duration_ms', 'row_count']:
+ print(f"{key}: {value}")
+
+ print(f"{'=' * 80}\n")
+
+ return results
+
+ def _print_spark_config(self):
+ """Print relevant Spark configuration."""
+ conf = self.spark.sparkContext.getConf()
+        print(f"Shuffle manager: {conf.get('spark.shuffle.manager', 'default')}")
+ print(f"Comet enabled: {conf.get('spark.comet.enabled', 'false')}")
+        print(f"Comet shuffle enabled: {conf.get('spark.comet.exec.shuffle.enabled', 'false')}")
+        print(f"Comet shuffle mode: {conf.get('spark.comet.shuffle.mode', 'not set')}")
+ print(f"Spark UI: {self.spark.sparkContext.uiWebUrl}")
+
+ def _time_operation(self, operation_fn):
+ """
+ Time an operation and return duration in milliseconds.
+
+ Args:
+ operation_fn: Function to time (takes no arguments)
+
+ Returns:
+ Duration in milliseconds
+ """
+ start_time = time.time()
+ operation_fn()
+ duration_ms = int((time.time() - start_time) * 1000)
+ return duration_ms
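
As a sanity check of the contract that `execute_timed()` relies on, here is a small, self-contained sketch (not part of this commit): `run()` must return a dict containing `duration_ms`, and `_time_operation()` handles the timing. The `NoopBenchmark` class and the local-mode session are hypothetical and exist only to illustrate the base class:

```python
from typing import Any, Dict

from pyspark.sql import SparkSession

from benchmarks.base import Benchmark


class NoopBenchmark(Benchmark):
    """Hypothetical benchmark that just counts a tiny generated DataFrame."""

    @classmethod
    def name(cls) -> str:
        return "noop"

    @classmethod
    def description(cls) -> str:
        return "Counts rows of a generated DataFrame"

    def run(self) -> Dict[str, Any]:
        df = self.spark.range(1000)  # ignores self.data_path in this sketch
        duration_ms = self._time_operation(lambda: df.count())
        return {"duration_ms": duration_ms, "row_count": df.count()}


if __name__ == "__main__":
    spark = SparkSession.builder.master("local[1]").appName("noop-demo").getOrCreate()
    try:
        NoopBenchmark(spark, data_path="unused", mode="spark").execute_timed()
    finally:
        spark.stop()
```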
diff --git a/benchmarks/pyspark/benchmarks/shuffle.py b/benchmarks/pyspark/benchmarks/shuffle.py
new file mode 100644
index 000000000..0facd2340
--- /dev/null
+++ b/benchmarks/pyspark/benchmarks/shuffle.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Shuffle benchmarks for comparing shuffle file sizes and performance.
+
+These benchmarks test different partitioning strategies (hash, round-robin)
+across Spark, Comet JVM, and Comet Native shuffle implementations.
+"""
+
+from typing import Dict, Any
+from pyspark.sql import DataFrame
+
+from .base import Benchmark
+
+
+class ShuffleBenchmark(Benchmark):
+ """Base class for shuffle benchmarks with common repartitioning logic."""
+
+    def __init__(self, spark, data_path: str, mode: str, num_partitions: int = 200):
+ """
+ Initialize shuffle benchmark.
+
+ Args:
+ spark: SparkSession instance
+ data_path: Path to input parquet data
+ mode: Execution mode (spark, jvm, native)
+ num_partitions: Number of partitions to shuffle to
+ """
+ super().__init__(spark, data_path, mode)
+ self.num_partitions = num_partitions
+
+ def _read_and_count(self) -> tuple[DataFrame, int]:
+ """Read input data and count rows."""
+ df = self.spark.read.parquet(self.data_path)
+ row_count = df.count()
+ return df, row_count
+
+ def _repartition(self, df: DataFrame) -> DataFrame:
+ """
+ Repartition dataframe using the strategy defined by subclass.
+
+ Args:
+ df: Input dataframe
+
+ Returns:
+ Repartitioned dataframe
+ """
+ raise NotImplementedError("Subclasses must implement _repartition")
+
+ def _write_output(self, df: DataFrame, output_path: str):
+ """Write repartitioned data to parquet."""
+ df.write.mode("overwrite").parquet(output_path)
+
+ def run(self) -> Dict[str, Any]:
+ """
+ Run the shuffle benchmark.
+
+ Returns:
+ Dictionary with duration_ms and row_count
+ """
+ # Read input data
+ df, row_count = self._read_and_count()
+ print(f"Number of rows: {row_count:,}")
+
+ # Define the benchmark operation
+ def benchmark_operation():
+ # Repartition using the specific strategy
+ repartitioned = self._repartition(df)
+
+ # Write to parquet to force materialization
+            output_path = f"/tmp/shuffle-benchmark-output-{self.mode}-{self.name()}"
+ self._write_output(repartitioned, output_path)
+ print(f"Wrote repartitioned data to: {output_path}")
+
+ # Time the operation
+ duration_ms = self._time_operation(benchmark_operation)
+
+ return {
+ 'duration_ms': duration_ms,
+ 'row_count': row_count,
+ 'num_partitions': self.num_partitions,
+ }
+
+
+class ShuffleHashBenchmark(ShuffleBenchmark):
+ """Shuffle benchmark using hash partitioning on a key column."""
+
+ @classmethod
+ def name(cls) -> str:
+ return "shuffle-hash"
+
+ @classmethod
+ def description(cls) -> str:
+ return "Shuffle all columns using hash partitioning on group_key"
+
+ def _repartition(self, df: DataFrame) -> DataFrame:
+ """Repartition using hash partitioning on group_key."""
+ return df.repartition(self.num_partitions, "group_key")
+
+
+class ShuffleRoundRobinBenchmark(ShuffleBenchmark):
+ """Shuffle benchmark using round-robin partitioning."""
+
+ @classmethod
+ def name(cls) -> str:
+ return "shuffle-roundrobin"
+
+ @classmethod
+ def description(cls) -> str:
+ return "Shuffle all columns using round-robin partitioning"
+
+ def _repartition(self, df: DataFrame) -> DataFrame:
+ """Repartition using round-robin (no partition columns specified)."""
+ return df.repartition(self.num_partitions)
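
To make the difference between the two strategies concrete, here is a rough local sketch (again, not part of this commit) that inspects how rows land in partitions; the demo DataFrame with a `group_key` column is only a stand-in for the generated benchmark data:

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[2]").appName("repartition-demo").getOrCreate()

# Stand-in for the benchmark data: 1000 rows with a low-cardinality group_key.
df = spark.range(1000).withColumn("group_key", F.col("id") % 10)

# Hash partitioning: rows sharing a group_key value always land in the same
# partition, so partition sizes can be skewed (and some partitions may be empty).
hashed = df.repartition(8, "group_key")

# Round-robin partitioning: rows are spread evenly regardless of any key.
round_robin = df.repartition(8)

for label, part_df in [("hash", hashed), ("round-robin", round_robin)]:
    print(label)
    (part_df.groupBy(F.spark_partition_id().alias("partition"))
        .count()
        .orderBy("partition")
        .show())

spark.stop()
```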
diff --git a/benchmarks/pyspark/run_benchmark.py b/benchmarks/pyspark/run_benchmark.py
index 3f40b7c93..6713f0ff2 100755
--- a/benchmarks/pyspark/run_benchmark.py
+++ b/benchmarks/pyspark/run_benchmark.py
@@ -17,88 +17,95 @@
# under the License.
"""
-Run shuffle size comparison benchmark.
+Run PySpark benchmarks.
-Run this script once per mode (spark, jvm, native) with appropriate spark-submit configs.
-Check the Spark UI to compare shuffle sizes between modes.
+Run benchmarks by name with appropriate spark-submit configs for different modes
+(spark, jvm, native). Check the Spark UI to compare results between modes.
"""
import argparse
-import time
-import json
+import sys
from pyspark.sql import SparkSession
-
-def run_benchmark(spark: SparkSession, data_path: str, mode: str) -> int:
- """Run the benchmark query and return duration in ms."""
-
- spark.catalog.clearCache()
-
- df = spark.read.parquet(data_path)
- row_count = df.count()
- print(f"Number of rows: {row_count:,}")
-
- start_time = time.time()
-
- # Repartition by a different key to force full shuffle of all columns
- # This shuffles all 50 columns including nested structs, arrays, maps
- repartitioned = df.repartition(200, "group_key")
-
- # Write to parquet to force materialization
- output_path = f"/tmp/shuffle-benchmark-output-{mode}"
- repartitioned.write.mode("overwrite").parquet(output_path)
- print(f"Wrote repartitioned data to: {output_path}")
-
- duration_ms = int((time.time() - start_time) * 1000)
- return duration_ms
+from benchmarks import get_benchmark, list_benchmarks
def main():
parser = argparse.ArgumentParser(
- description="Run shuffle benchmark for a single mode"
+ description="Run PySpark benchmarks",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ # Run hash partitioning shuffle benchmark in Spark mode
+  python run_benchmark.py --data /path/to/data --mode spark --benchmark shuffle-hash
+
+ # Run round-robin shuffle benchmark in Comet native mode
+  python run_benchmark.py --data /path/to/data --mode native --benchmark shuffle-roundrobin
+
+ # List all available benchmarks
+ python run_benchmark.py --list-benchmarks
+ """
)
parser.add_argument(
"--data", "-d",
- required=True,
help="Path to input parquet data"
)
parser.add_argument(
"--mode", "-m",
- required=True,
choices=["spark", "jvm", "native"],
help="Shuffle mode being tested"
)
+ parser.add_argument(
+ "--benchmark", "-b",
+ default="shuffle-hash",
+ help="Benchmark to run (default: shuffle-hash)"
+ )
+ parser.add_argument(
+ "--list-benchmarks",
+ action="store_true",
+ help="List all available benchmarks and exit"
+ )
args = parser.parse_args()
- spark = SparkSession.builder \
- .appName(f"ShuffleBenchmark-{args.mode.upper()}") \
- .getOrCreate()
+ # Handle --list-benchmarks
+ if args.list_benchmarks:
+ print("Available benchmarks:\n")
+ for name, description in list_benchmarks():
+ print(f" {name:25s} - {description}")
+ return 0
- print("\n" + "=" * 80)
- print(f"Shuffle Benchmark: {args.mode.upper()}")
- print("=" * 80)
- print(f"Data path: {args.data}")
+ # Validate required arguments
+ if not args.data:
+ parser.error("--data is required when running a benchmark")
+ if not args.mode:
+ parser.error("--mode is required when running a benchmark")
- # Print shuffle configuration
- conf = spark.sparkContext.getConf()
- print(f"Shuffle manager: {conf.get('spark.shuffle.manager', 'default')}")
- print(f"Comet enabled: {conf.get('spark.comet.enabled', 'false')}")
-    print(f"Comet shuffle enabled: {conf.get('spark.comet.exec.shuffle.enabled', 'false')}")
-    print(f"Comet shuffle mode: {conf.get('spark.comet.shuffle.mode', 'not set')}")
- print(f"Spark UI: {spark.sparkContext.uiWebUrl}")
+ # Get the benchmark class
+ try:
+ benchmark_cls = get_benchmark(args.benchmark)
+ except KeyError as e:
+ print(f"Error: {e}", file=sys.stderr)
+        print("\nUse --list-benchmarks to see available benchmarks", file=sys.stderr)
+ return 1
+
+ # Create Spark session
+ spark = SparkSession.builder \
+ .appName(f"{benchmark_cls.name()}-{args.mode.upper()}") \
+ .getOrCreate()
try:
- duration_ms = run_benchmark(spark, args.data, args.mode)
- print(f"\nDuration: {duration_ms:,} ms")
- print("\nCheck Spark UI for shuffle sizes")
+ # Create and run the benchmark
+ benchmark = benchmark_cls(spark, args.data, args.mode)
+ results = benchmark.execute_timed()
+
+ print("\nCheck Spark UI for shuffle sizes and detailed metrics")
+ return 0
finally:
spark.stop()
- print("=" * 80 + "\n")
-
if __name__ == "__main__":
- main()
+ sys.exit(main())
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]