This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 7e206e2 refactor: Remove a few duplicated occurrences (#53)
7e206e2 is described below
commit 7e206e2019f31052743cb5d0809dc6983e5897b5
Author: Chao Sun <[email protected]>
AuthorDate: Tue Feb 20 09:01:55 2024 -0800
refactor: Remove a few duplicated occurrences (#53)
---
.../scala/org/apache/comet/CometExecIterator.scala | 49 +++++++++++++++-------
.../org/apache/spark/sql/comet/operators.scala | 44 ++++---------------
2 files changed, 41 insertions(+), 52 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index 029be29..0140582 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -19,12 +19,11 @@
package org.apache.comet
-import java.util.HashMap
-
import org.apache.spark._
import org.apache.spark.sql.comet.CometMetricNode
import org.apache.spark.sql.vectorized._
+import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION}
import org.apache.comet.vector.NativeUtil
/**
@@ -45,36 +44,31 @@ class CometExecIterator(
val id: Long,
inputs: Seq[Iterator[ColumnarBatch]],
protobufQueryPlan: Array[Byte],
- configs: HashMap[String, String],
nativeMetrics: CometMetricNode)
extends Iterator[ColumnarBatch] {
private val nativeLib = new Native()
- private val plan = nativeLib.createPlan(id, configs, protobufQueryPlan, nativeMetrics)
+ private val plan = {
+ val configs = createNativeConf
+ nativeLib.createPlan(id, configs, protobufQueryPlan, nativeMetrics)
+ }
private val nativeUtil = new NativeUtil
private var nextBatch: Option[ColumnarBatch] = None
private var currentBatch: ColumnarBatch = null
private var closed: Boolean = false
private def peekNext(): ExecutionState = {
- val result = nativeLib.peekNext(plan)
- val flag = result(0)
-
- if (flag == 0) Pending
- else if (flag == 1) {
- val numRows = result(1)
- val addresses = result.slice(2, result.length)
- Batch(numRows = numRows.toInt, addresses = addresses)
- } else {
- throw new IllegalStateException(s"Invalid native flag: $flag")
- }
+ convertNativeResult(nativeLib.peekNext(plan))
}
private def executeNative(
input: Array[Array[Long]],
finishes: Array[Boolean],
numRows: Int): ExecutionState = {
- val result = nativeLib.executePlan(plan, input, finishes, numRows)
+ convertNativeResult(nativeLib.executePlan(plan, input, finishes, numRows))
+ }
+
+ private def convertNativeResult(result: Array[Long]): ExecutionState = {
val flag = result(0)
if (flag == -1) EOF
else if (flag == 0) Pending
@@ -87,6 +81,29 @@ class CometExecIterator(
}
}
+ /**
+ * Creates a new configuration map to be passed to the native side.
+ */
+ private def createNativeConf: java.util.HashMap[String, String] = {
+ val result = new java.util.HashMap[String, String]()
+ val conf = SparkEnv.get.conf
+
+ val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(conf)
+ result.put("memory_limit", String.valueOf(maxMemory))
+ result.put("memory_fraction", String.valueOf(COMET_EXEC_MEMORY_FRACTION.get()))
+ result.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
+ result.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get()))
+
+ // Strip mandatory prefix spark. which is not required for DataFusion session params
+ conf.getAll.foreach {
+ case (k, v) if k.startsWith("spark.datafusion") =>
+ result.put(k.replaceFirst("spark\\.", ""), v)
+ case _ =>
+ }
+
+ result
+ }
+
/** Execution result from Comet native */
trait ExecutionState
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 7ac1084..4d8011e 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -23,7 +23,7 @@ import java.io.ByteArrayOutputStream
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
@@ -31,12 +31,12 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch
import com.google.common.base.Objects
-import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException, CometSparkSessionExtensions}
-import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION}
+import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException}
import org.apache.comet.serde.OperatorOuterClass.Operator
/**
@@ -83,17 +83,7 @@ object CometExec {
nativePlan.writeTo(outputStream)
outputStream.close()
val bytes = outputStream.toByteArray
-
- val configs = new java.util.HashMap[String, String]()
-
- val maxMemory =
- CometSparkSessionExtensions.getCometMemoryOverhead(SparkEnv.get.conf)
- configs.put("memory_limit", String.valueOf(maxMemory))
- configs.put("memory_fraction", String.valueOf(COMET_EXEC_MEMORY_FRACTION.get()))
- configs.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
- configs.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get()))
-
- new CometExecIterator(newIterId, inputs, bytes, configs, nativeMetrics)
+ new CometExecIterator(newIterId, inputs, bytes, nativeMetrics)
}
}
@@ -163,33 +153,15 @@ abstract class CometNativeExec extends CometExec {
case Some(serializedPlan) =>
// Switch to use Decimal128 regardless of precision, since Arrow native execution
// doesn't support Decimal32 and Decimal64 yet.
- conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
-
- // Populate native configurations
- val configs = new java.util.HashMap[String, String]()
- val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(sparkContext.getConf)
- configs.put("memory_limit", String.valueOf(maxMemory))
- configs.put("memory_fraction", String.valueOf(COMET_EXEC_MEMORY_FRACTION.get()))
- configs.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
- configs.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get()))
-
- // Strip mandatory prefix spark. which is not required for datafusion session params
- session.conf.getAll.foreach {
- case (k, v) if k.startsWith("spark.datafusion") =>
- configs.put(k.replaceFirst("spark\\.", ""), v)
- case _ =>
- }
+ SQLConf.get.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
+
val serializedPlanCopy = serializedPlan
// TODO: support native metrics for all operators.
val nativeMetrics = CometMetricNode.fromCometPlan(this)
def createCometExecIter(inputs: Seq[Iterator[ColumnarBatch]]): CometExecIterator = {
- val it = new CometExecIterator(
- CometExec.newIterId,
- inputs,
- serializedPlanCopy,
- configs,
- nativeMetrics)
+ val it =
+ new CometExecIterator(CometExec.newIterId, inputs, serializedPlanCopy, nativeMetrics)
setSubqueries(it.id, originalPlan)