This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 7e206e2  refactor: Remove a few duplicated occurrences (#53)
7e206e2 is described below

commit 7e206e2019f31052743cb5d0809dc6983e5897b5
Author: Chao Sun <[email protected]>
AuthorDate: Tue Feb 20 09:01:55 2024 -0800

    refactor: Remove a few duplicated occurrences (#53)
---
 .../scala/org/apache/comet/CometExecIterator.scala | 49 +++++++++++++++-------
 .../org/apache/spark/sql/comet/operators.scala     | 44 ++++---------------
 2 files changed, 41 insertions(+), 52 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index 029be29..0140582 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -19,12 +19,11 @@
 
 package org.apache.comet
 
-import java.util.HashMap
-
 import org.apache.spark._
 import org.apache.spark.sql.comet.CometMetricNode
 import org.apache.spark.sql.vectorized._
 
+import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION}
 import org.apache.comet.vector.NativeUtil
 
 /**
@@ -45,36 +44,31 @@ class CometExecIterator(
     val id: Long,
     inputs: Seq[Iterator[ColumnarBatch]],
     protobufQueryPlan: Array[Byte],
-    configs: HashMap[String, String],
     nativeMetrics: CometMetricNode)
     extends Iterator[ColumnarBatch] {
 
   private val nativeLib = new Native()
-  private val plan = nativeLib.createPlan(id, configs, protobufQueryPlan, nativeMetrics)
+  private val plan = {
+    val configs = createNativeConf
+    nativeLib.createPlan(id, configs, protobufQueryPlan, nativeMetrics)
+  }
   private val nativeUtil = new NativeUtil
   private var nextBatch: Option[ColumnarBatch] = None
   private var currentBatch: ColumnarBatch = null
   private var closed: Boolean = false
 
   private def peekNext(): ExecutionState = {
-    val result = nativeLib.peekNext(plan)
-    val flag = result(0)
-
-    if (flag == 0) Pending
-    else if (flag == 1) {
-      val numRows = result(1)
-      val addresses = result.slice(2, result.length)
-      Batch(numRows = numRows.toInt, addresses = addresses)
-    } else {
-      throw new IllegalStateException(s"Invalid native flag: $flag")
-    }
+    convertNativeResult(nativeLib.peekNext(plan))
   }
 
   private def executeNative(
       input: Array[Array[Long]],
       finishes: Array[Boolean],
       numRows: Int): ExecutionState = {
-    val result = nativeLib.executePlan(plan, input, finishes, numRows)
+    convertNativeResult(nativeLib.executePlan(plan, input, finishes, numRows))
+  }
+
+  private def convertNativeResult(result: Array[Long]): ExecutionState = {
     val flag = result(0)
     if (flag == -1) EOF
     else if (flag == 0) Pending
@@ -87,6 +81,29 @@ class CometExecIterator(
     }
   }
 
+  /**
+   * Creates a new configuration map to be passed to the native side.
+   */
+  private def createNativeConf: java.util.HashMap[String, String] = {
+    val result = new java.util.HashMap[String, String]()
+    val conf = SparkEnv.get.conf
+
+    val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(conf)
+    result.put("memory_limit", String.valueOf(maxMemory))
+    result.put("memory_fraction", String.valueOf(COMET_EXEC_MEMORY_FRACTION.get()))
+    result.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
+    result.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get()))
+
+    // Strip mandatory prefix spark. which is not required for DataFusion session params
+    conf.getAll.foreach {
+      case (k, v) if k.startsWith("spark.datafusion") =>
+        result.put(k.replaceFirst("spark\\.", ""), v)
+      case _ =>
+    }
+
+    result
+  }
+
   /** Execution result from Comet native */
   trait ExecutionState
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 7ac1084..4d8011e 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -23,7 +23,7 @@ import java.io.ByteArrayOutputStream
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, NamedExpression, SortOrder}
@@ -31,12 +31,12 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import com.google.common.base.Objects
 
-import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException, CometSparkSessionExtensions}
-import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION}
+import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 
 /**
@@ -83,17 +83,7 @@ object CometExec {
     nativePlan.writeTo(outputStream)
     outputStream.close()
     val bytes = outputStream.toByteArray
-
-    val configs = new java.util.HashMap[String, String]()
-
-    val maxMemory =
-      CometSparkSessionExtensions.getCometMemoryOverhead(SparkEnv.get.conf)
-    configs.put("memory_limit", String.valueOf(maxMemory))
-    configs.put("memory_fraction", String.valueOf(COMET_EXEC_MEMORY_FRACTION.get()))
-    configs.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
-    configs.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get()))
-
-    new CometExecIterator(newIterId, inputs, bytes, configs, nativeMetrics)
+    new CometExecIterator(newIterId, inputs, bytes, nativeMetrics)
   }
 }
 
@@ -163,33 +153,15 @@ abstract class CometNativeExec extends CometExec {
       case Some(serializedPlan) =>
         // Switch to use Decimal128 regardless of precision, since Arrow native execution
         // doesn't support Decimal32 and Decimal64 yet.
-        conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
-
-        // Populate native configurations
-        val configs = new java.util.HashMap[String, String]()
-        val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(sparkContext.getConf)
-        configs.put("memory_limit", String.valueOf(maxMemory))
-        configs.put("memory_fraction", String.valueOf(COMET_EXEC_MEMORY_FRACTION.get()))
-        configs.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
-        configs.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get()))
-
-        // Strip mandatory prefix spark. which is not required for datafusion session params
-        session.conf.getAll.foreach {
-          case (k, v) if k.startsWith("spark.datafusion") =>
-            configs.put(k.replaceFirst("spark\\.", ""), v)
-          case _ =>
-        }
+        SQLConf.get.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
+
         val serializedPlanCopy = serializedPlan
         // TODO: support native metrics for all operators.
         val nativeMetrics = CometMetricNode.fromCometPlan(this)
 
         def createCometExecIter(inputs: Seq[Iterator[ColumnarBatch]]): CometExecIterator = {
-          val it = new CometExecIterator(
-            CometExec.newIterId,
-            inputs,
-            serializedPlanCopy,
-            configs,
-            nativeMetrics)
+          val it =
+            new CometExecIterator(CometExec.newIterId, inputs, serializedPlanCopy, nativeMetrics)
 
           setSubqueries(it.id, originalPlan)
 

Reply via email to