comphead commented on code in PR #2538:
URL: https://github.com/apache/datafusion-comet/pull/2538#discussion_r2421756210


##########
spark/src/main/scala/org/apache/comet/CometExecIterator.scala:
##########
@@ -318,3 +262,74 @@ class CometExecIterator(
     nativeLib.logMemoryUsage(s"task_memory_spark_$threadId", sparkTaskMemory)
   }
 }
+
+object CometExecIterator extends Logging {
+
+  def getMemoryConfig(conf: SparkConf): MemoryConfig = {
+    val numCores = numDriverOrExecutorCores(conf).toFloat
+    val coresPerTask = conf.get("spark.task.cpus", "1").toFloat
+    // there are different paths for on-heap vs off-heap mode
+    val offHeapMode = CometSparkSessionExtensions.isOffHeapEnabled(conf)
+    if (offHeapMode) {
+      // in off-heap mode, Comet uses unified memory management to share 
off-heap memory with Spark
+      val offHeapSize = 
ByteUnit.MiB.toBytes(conf.getSizeAsMb("spark.memory.offHeap.size"))
+      val memoryFraction = CometConf.COMET_EXEC_MEMORY_POOL_FRACTION.get()
+      val memoryLimit = (offHeapSize * memoryFraction).toLong
+      val memoryLimitPerTask = (memoryLimit.toFloat * coresPerTask / 
numCores).toLong
+      val memoryPoolType = getMemoryPoolType(defaultValue = "fair_unified")
+      logInfo(
+        s"memoryPoolType=$memoryPoolType, " +
+          s"offHeapSize=${toMB(offHeapSize)}, " +
+          s"memoryFraction=$memoryFraction, " +
+          s"memoryLimit=${toMB(memoryLimit)}, " +
+          s"memoryLimitPerTask=${toMB(memoryLimitPerTask)}")
+      MemoryConfig(offHeapMode, memoryPoolType = memoryPoolType, memoryLimit, 
memoryLimitPerTask)
+    } else {
+      // we'll use the built-in memory pool from DF, and initializes with 
`memory_limit`
+      // and `memory_fraction` below.
+      val memoryLimit = 
CometSparkSessionExtensions.getCometMemoryOverhead(conf)
+      // example 16GB maxMemory * 16 cores with 4 cores per task results
+      // in memory_limit_per_task = 16 GB * 4 / 16 = 16 GB / 4 = 4GB
+      val memoryLimitPerTask = (memoryLimit.toFloat * coresPerTask / 
numCores).toLong
+      val memoryPoolType = getMemoryPoolType(defaultValue = 
"greedy_task_shared")
+      logInfo(
+        s"memoryPoolType=$memoryPoolType, " +
+          s"memoryLimit=${toMB(memoryLimit)}, " +
+          s"memoryLimitPerTask=${toMB(memoryLimitPerTask)}")
+      MemoryConfig(offHeapMode, memoryPoolType = memoryPoolType, memoryLimit, 
memoryLimitPerTask)
+    }
+  }
+
+  private def getMemoryPoolType(defaultValue: String): String = {
+    COMET_EXEC_MEMORY_POOL_TYPE.get() match {
+      case "default" => defaultValue
+      case other => other
+    }
+  }
+
+  private def numDriverOrExecutorCores(conf: SparkConf): Int = {
+    def convertToInt(threads: String): Int = {
+      if (threads == "*") Runtime.getRuntime.availableProcessors() else 
threads.toInt
+    }
+
+    val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r

Review Comment:
   Would be nice to add a comment describing what this regex matches, e.g. example master strings such as `local[4]` or `local[*]`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to