This is an automated email from the ASF dual-hosted git repository.

liuneng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new fb36dad06 [CH] A simple job scheduler for merge tree cache sync load 
(#6842)
fb36dad06 is described below

commit fb36dad0648a6be605e160096c37f74ffb48d235
Author: LiuNeng <[email protected]>
AuthorDate: Mon Aug 19 15:51:10 2024 +0800

    [CH] A simple job scheduler for merge tree cache sync load (#6842)
    
    What changes were proposed in this pull request?
    When the cache is loaded synchronously, the load may take longer than the 
Spark RPC timeout. This patch introduces an asynchronous job mechanism: the 
executor schedules the load as a background job and the driver polls the job 
status, which implements synchronous loading without a long-running RPC. 
Unified exception handling is also added.
    
    How was this patch tested?
    unit tests
    
    (If this patch involves UI changes, please attach a screenshot; otherwise, 
remove this)
---
 .../gluten/execution/CHNativeCacheManager.java     |  12 +-
 ...{CHNativeCacheManager.java => CacheResult.java} |  43 ++++-
 .../apache/spark/rpc/GlutenExecutorEndpoint.scala  |  12 +-
 .../org/apache/spark/rpc/GlutenRpcMessages.scala   |   6 +-
 .../commands/GlutenCHCacheDataCommand.scala        | 188 ++++++++++++---------
 cpp-ch/local-engine/Common/CHUtil.cpp              |   1 +
 cpp-ch/local-engine/Common/ConcurrentMap.h         |   2 +-
 cpp-ch/local-engine/Common/GlutenConfig.h          |  14 ++
 .../local-engine/Storages/Cache/CacheManager.cpp   |  78 ++++++---
 cpp-ch/local-engine/Storages/Cache/CacheManager.h  |  21 ++-
 .../local-engine/Storages/Cache/JobScheduler.cpp   | 163 ++++++++++++++++++
 cpp-ch/local-engine/Storages/Cache/JobScheduler.h  | 132 +++++++++++++++
 cpp-ch/local-engine/local_engine_jni.cpp           |  14 +-
 13 files changed, 559 insertions(+), 127 deletions(-)

diff --git 
a/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java
 
b/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java
index f5f75dc1d..7b765924f 100644
--- 
a/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java
+++ 
b/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java
@@ -19,9 +19,15 @@ package org.apache.gluten.execution;
 import java.util.Set;
 
 public class CHNativeCacheManager {
-  public static void cacheParts(String table, Set<String> columns, boolean 
async) {
-    nativeCacheParts(table, String.join(",", columns), async);
+  /**
+   * Hands a cache-load request to the native library.
+   *
+   * @param table table description string, forwarded unchanged to native code
+   * @param columns column names; passed to native code as one comma-joined string
+   * @return the string returned by the native call (callers use it as the job id)
+   */
+  public static String cacheParts(String table, Set<String> columns) {
+    final String joinedColumns = String.join(",", columns);
+    return nativeCacheParts(table, joinedColumns);
   }
 
-  private static native void nativeCacheParts(String table, String columns, 
boolean async);
+  private static native String nativeCacheParts(String table, String columns);
+
+  /**
+   * Asks the native library for the state of a previously submitted job.
+   *
+   * @param jobId id returned by {@link #cacheParts(String, Set)}
+   * @return the result object built by native code
+   */
+  public static CacheResult getCacheStatus(String jobId) {
+    final CacheResult result = nativeGetCacheStatus(jobId);
+    return result;
+  }
+
+  private static native CacheResult nativeGetCacheStatus(String jobId);
 }
diff --git 
a/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java
 
b/backends-clickhouse/src/main/java/org/apache/gluten/execution/CacheResult.java
similarity index 51%
copy from 
backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java
copy to 
backends-clickhouse/src/main/java/org/apache/gluten/execution/CacheResult.java
index f5f75dc1d..0fa69e0d0 100644
--- 
a/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java
+++ 
b/backends-clickhouse/src/main/java/org/apache/gluten/execution/CacheResult.java
@@ -16,12 +16,45 @@
  */
 package org.apache.gluten.execution;
 
-import java.util.Set;
+/**
+ * Result of polling a native cache-load job: a {@link Status} plus a message.
+ *
+ * <p>Instances are created by native code through the {@code (int, String)} constructor and are
+ * sent from executor to driver as an RPC reply, so the class is {@link java.io.Serializable}.
+ */
+public class CacheResult implements java.io.Serializable {
+  private static final long serialVersionUID = 1L;
+
+  /** Job state; each constant carries the integer code used by the native side. */
+  public enum Status {
+    RUNNING(0),
+    SUCCESS(1),
+    ERROR(2);
 
-public class CHNativeCacheManager {
-  public static void cacheParts(String table, Set<String> columns, boolean 
async) {
-    nativeCacheParts(table, String.join(",", columns), async);
+    private final int value;
+
+    Status(int value) {
+      this.value = value;
+    }
+
+    /** Returns the integer code of this status. */
+    public int getValue() {
+      return value;
+    }
+
+    /**
+     * Maps an integer code back to its constant.
+     *
+     * @throws IllegalArgumentException if no constant carries {@code value}
+     */
+    public static Status fromInt(int value) {
+      for (Status candidate : Status.values()) {
+        if (candidate.getValue() == value) {
+          return candidate;
+        }
+      }
+      throw new IllegalArgumentException("No enum constant for value: " + value);
+    }
   }
 
-  private static native void nativeCacheParts(String table, String columns, 
boolean async);
+  private final Status status;
+  private final String message;
+
+  /**
+   * @param status integer status code, converted with {@link Status#fromInt(int)}
+   * @param message detail text; carries the failure messages when the job failed
+   * @throws IllegalArgumentException if {@code status} is not a known code
+   */
+  public CacheResult(int status, String message) {
+    this.status = Status.fromInt(status);
+    this.message = message;
+  }
+
+  public Status getStatus() {
+    return status;
+  }
+
+  public String getMessage() {
+    return message;
+  }
 }
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala
 
b/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala
index 4d90ab653..8a3bde235 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala
@@ -64,8 +64,6 @@ class GlutenExecutorEndpoint(val executorId: String, val 
conf: SparkConf)
         hashIds.forEach(
           resource_id => 
CHBroadcastBuildSideCache.invalidateBroadcastHashtable(resource_id))
       }
-    case GlutenMergeTreeCacheLoad(mergeTreeTable, columns) =>
-      CHNativeCacheManager.cacheParts(mergeTreeTable, columns, true)
 
     case e =>
       logError(s"Received unexpected message. $e")
@@ -74,12 +72,16 @@ class GlutenExecutorEndpoint(val executorId: String, val 
conf: SparkConf)
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, 
Unit] = {
     case GlutenMergeTreeCacheLoad(mergeTreeTable, columns) =>
       try {
-        CHNativeCacheManager.cacheParts(mergeTreeTable, columns, false)
-        context.reply(CacheLoadResult(true))
+        // Only triggers the native job; the driver polls its state with
+        // GlutenMergeTreeCacheLoadStatus using the returned job id.
+        val jobId = CHNativeCacheManager.cacheParts(mergeTreeTable, columns)
+        context.reply(CacheJobInfo(status = true, jobId))
       } catch {
-        case _: Exception =>
-          context.reply(CacheLoadResult(false, s"executor: $executorId cache 
data failed."))
+        case e: Exception =>
+          // Do not swallow the cause: log it here and forward its message to the driver.
+          logError(s"executor: $executorId cache data failed.", e)
+          context.reply(
+            CacheJobInfo(
+              status = false,
+              "",
+              s"executor: $executorId cache data failed. ${e.getMessage}"))
       }
+    case GlutenMergeTreeCacheLoadStatus(jobId) =>
+      val status = CHNativeCacheManager.getCacheStatus(jobId)
+      context.reply(status)
     case e =>
       logError(s"Received unexpected message. $e")
   }
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala
 
b/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala
index d675d705f..800b15b99 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala
@@ -35,8 +35,12 @@ object GlutenRpcMessages {
   case class GlutenCleanExecutionResource(executionId: String, 
broadcastHashIds: util.Set[String])
     extends GlutenRpcMessage
 
+  // for mergetree cache
   case class GlutenMergeTreeCacheLoad(mergeTreeTable: String, columns: 
util.Set[String])
     extends GlutenRpcMessage
 
-  case class CacheLoadResult(success: Boolean, reason: String = "") extends 
GlutenRpcMessage
+  // Poll request for the state of a cache job started by GlutenMergeTreeCacheLoad.
+  case class GlutenMergeTreeCacheLoadStatus(jobId: String) extends GlutenRpcMessage
+
+  case class CacheJobInfo(status: Boolean, jobId: String, reason: String = "")
+    extends GlutenRpcMessage
 }
diff --git 
a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/commands/GlutenCHCacheDataCommand.scala
 
b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/commands/GlutenCHCacheDataCommand.scala
index 1e6b02406..f32d22d5e 100644
--- 
a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/commands/GlutenCHCacheDataCommand.scala
+++ 
b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/commands/GlutenCHCacheDataCommand.scala
@@ -17,18 +17,20 @@
 package org.apache.spark.sql.execution.commands
 
 import org.apache.gluten.exception.GlutenException
+import org.apache.gluten.execution.CacheResult
+import org.apache.gluten.execution.CacheResult.Status
 import org.apache.gluten.expression.ConverterUtils
 import org.apache.gluten.substrait.rel.ExtensionTableBuilder
 
 import org.apache.spark.affinity.CHAffinity
 import org.apache.spark.rpc.GlutenDriverEndpoint
-import org.apache.spark.rpc.GlutenRpcMessages.{CacheLoadResult, 
GlutenMergeTreeCacheLoad}
+import org.apache.spark.rpc.GlutenRpcMessages.{CacheJobInfo, 
GlutenMergeTreeCacheLoad, GlutenMergeTreeCacheLoadStatus}
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, GreaterThanOrEqual, IsNotNull, Literal}
 import org.apache.spark.sql.delta._
 import org.apache.spark.sql.execution.command.LeafRunnableCommand
-import 
org.apache.spark.sql.execution.commands.GlutenCHCacheDataCommand.toExecutorId
+import 
org.apache.spark.sql.execution.commands.GlutenCHCacheDataCommand.{checkExecutorId,
 collectJobTriggerResult, toExecutorId, waitAllJobFinish, waitRpcResults}
 import 
org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts
 import org.apache.spark.sql.types.{BooleanType, StringType}
 import org.apache.spark.util.ThreadUtils
@@ -106,7 +108,8 @@ case class GlutenCHCacheDataCommand(
     }
 
     val selectedAddFiles = if (tsfilter.isDefined) {
-      val allParts = DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, 
Seq.empty, false)
+      val allParts =
+        DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, 
keepNumRecords = false)
       allParts.files.filter(_.modificationTime >= tsfilter.get.toLong).toSeq
     } else if (partitionColumn.isDefined && partitionValue.isDefined) {
       val partitionColumns = snapshot.metadata.partitionSchema.fieldNames
@@ -126,10 +129,12 @@ case class GlutenCHCacheDataCommand(
           snapshot,
           Seq(partitionColumnAttr),
           Seq(isNotNullExpr, greaterThanOrEqual),
-          false)
+          keepNumRecords = false)
         .files
     } else {
-      DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, 
false).files
+      DeltaAdapter
+        .snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, keepNumRecords = 
false)
+        .files
     }
 
     val executorIdsToAddFiles =
@@ -151,9 +156,7 @@ case class GlutenCHCacheDataCommand(
 
         if (locations.isEmpty) {
           // non soft affinity
-          executorIdsToAddFiles
-            .get(GlutenCHCacheDataCommand.ALL_EXECUTORS)
-            .get
+          executorIdsToAddFiles(GlutenCHCacheDataCommand.ALL_EXECUTORS)
             .append(mergeTreePart)
         } else {
           locations.foreach(
@@ -161,7 +164,7 @@ case class GlutenCHCacheDataCommand(
               if (!executorIdsToAddFiles.contains(executor)) {
                 executorIdsToAddFiles.put(executor, new 
ArrayBuffer[AddMergeTreeParts]())
               }
-              executorIdsToAddFiles.get(executor).get.append(mergeTreePart)
+              executorIdsToAddFiles(executor).append(mergeTreePart)
             })
         }
       })
@@ -201,87 +204,112 @@ case class GlutenCHCacheDataCommand(
           executorIdsToParts.put(executorId, 
extensionTableNode.getExtensionTableStr)
         }
       })
-
-    // send rpc call
+    val futureList = ArrayBuffer[(String, Future[CacheJobInfo])]()
     if (executorIdsToParts.contains(GlutenCHCacheDataCommand.ALL_EXECUTORS)) {
       // send all parts to all executors
-      val tableMessage = 
executorIdsToParts.get(GlutenCHCacheDataCommand.ALL_EXECUTORS).get
-      if (asynExecute) {
-        GlutenDriverEndpoint.executorDataMap.forEach(
-          (executorId, executor) => {
-            executor.executorEndpointRef.send(
-              GlutenMergeTreeCacheLoad(tableMessage, 
selectedColumns.toSet.asJava))
-          })
-        Seq(Row(true, ""))
-      } else {
-        val futureList = ArrayBuffer[Future[CacheLoadResult]]()
-        val resultList = ArrayBuffer[CacheLoadResult]()
-        GlutenDriverEndpoint.executorDataMap.forEach(
-          (executorId, executor) => {
-            futureList.append(
-              executor.executorEndpointRef.ask[CacheLoadResult](
+      val tableMessage = 
executorIdsToParts(GlutenCHCacheDataCommand.ALL_EXECUTORS)
+      GlutenDriverEndpoint.executorDataMap.forEach(
+        (executorId, executor) => {
+          futureList.append(
+            (
+              executorId,
+              executor.executorEndpointRef.ask[CacheJobInfo](
                 GlutenMergeTreeCacheLoad(tableMessage, 
selectedColumns.toSet.asJava)
-              ))
-          })
-        futureList.foreach(
-          f => {
-            resultList.append(ThreadUtils.awaitResult(f, Duration.Inf))
-          })
-        if (resultList.exists(!_.success)) {
-          Seq(Row(false, 
resultList.filter(!_.success).map(_.reason).mkString(";")))
-        } else {
-          Seq(Row(true, ""))
-        }
-      }
+              )))
+        })
     } else {
-      if (asynExecute) {
-        executorIdsToParts.foreach(
-          value => {
-            val executorData = 
GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1))
-            if (executorData != null) {
-              executorData.executorEndpointRef.send(
-                GlutenMergeTreeCacheLoad(value._2, 
selectedColumns.toSet.asJava))
-            } else {
-              throw new GlutenException(
-                s"executor ${value._1} not found," +
-                  s" all executors are 
${GlutenDriverEndpoint.executorDataMap.toString}")
-            }
-          })
-        Seq(Row(true, ""))
-      } else {
-        val futureList = ArrayBuffer[Future[CacheLoadResult]]()
-        val resultList = ArrayBuffer[CacheLoadResult]()
-        executorIdsToParts.foreach(
-          value => {
-            val executorData = 
GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1))
-            if (executorData != null) {
-              futureList.append(
-                executorData.executorEndpointRef.ask[CacheLoadResult](
-                  GlutenMergeTreeCacheLoad(value._2, 
selectedColumns.toSet.asJava)
-                ))
-            } else {
-              throw new GlutenException(
-                s"executor ${value._1} not found," +
-                  s" all executors are 
${GlutenDriverEndpoint.executorDataMap.toString}")
-            }
-          })
-        futureList.foreach(
-          f => {
-            resultList.append(ThreadUtils.awaitResult(f, Duration.Inf))
-          })
-        if (resultList.exists(!_.success)) {
-          Seq(Row(false, 
resultList.filter(!_.success).map(_.reason).mkString(";")))
-        } else {
-          Seq(Row(true, ""))
-        }
-      }
+      executorIdsToParts.foreach(
+        value => {
+          checkExecutorId(value._1)
+          val executorData = 
GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1))
+          futureList.append(
+            (
+              value._1,
+              executorData.executorEndpointRef.ask[CacheJobInfo](
+                GlutenMergeTreeCacheLoad(value._2, 
selectedColumns.toSet.asJava)
+              )))
+        })
+    }
+    val resultList = waitRpcResults(futureList)
+    if (asynExecute) {
+      val res = collectJobTriggerResult(resultList)
+      Seq(Row(res._1, res._2.mkString(";")))
+    } else {
+      val res = waitAllJobFinish(resultList)
+      Seq(Row(res._1, res._2))
     }
   }
+
 }
 
 object GlutenCHCacheDataCommand {
-  val ALL_EXECUTORS = "allExecutors"
+  private val ALL_EXECUTORS = "allExecutors"
 
   private def toExecutorId(executorId: String): String =
     executorId.split("_").last
+
+  /**
+   * Blocks until the triggered cache jobs have finished, polling each executor every 5 seconds.
+   *
+   * Polling stops at the first failure (a failed trigger or a failed job).
+   *
+   * @param jobs
+   *   (executor key, trigger reply) pairs
+   * @return
+   *   (true, "") if every job succeeded, otherwise (false, failure messages joined by ";")
+   */
+  def waitAllJobFinish(jobs: ArrayBuffer[(String, CacheJobInfo)]): (Boolean, String) = {
+    val res = collectJobTriggerResult(jobs)
+    var status = res._1
+    val messages = res._2
+    jobs.foreach(
+      job => {
+        if (status) {
+          var complete = false
+          while (!complete) {
+            Thread.sleep(5000)
+            val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(job._1))
+            if (executorData == null) {
+              // The executor is no longer registered; report it instead of failing with an NPE.
+              status = false
+              messages.append(s"executor : ${job._1} not found, job ${job._2.jobId} is lost")
+              complete = true
+            } else {
+              val futureResult = executorData.executorEndpointRef
+                .ask[CacheResult](GlutenMergeTreeCacheLoadStatus(job._2.jobId))
+              val result = ThreadUtils.awaitResult(futureResult, Duration.Inf)
+              result.getStatus match {
+                case Status.ERROR =>
+                  status = false
+                  // Interpolate the values. append is varargs, so passing them as extra
+                  // arguments would add them as separate entries and keep literal "{}".
+                  messages.append(
+                    s"executor : ${job._1}, failed with message: ${result.getMessage}")
+                  complete = true
+                case Status.SUCCESS =>
+                  complete = true
+                case _ =>
+                // still running
+              }
+            }
+          }
+        }
+      })
+    (status, messages.mkString(";"))
+  }
+
+  /**
+   * Gathers the reasons of all jobs whose trigger reply reports a failure.
+   *
+   * @return
+   *   (true, empty buffer) if every trigger succeeded, otherwise (false, failure reasons)
+   */
+  private def collectJobTriggerResult(
+      jobs: ArrayBuffer[(String, CacheJobInfo)]): (Boolean, ArrayBuffer[String]) = {
+    val failureReasons = jobs.collect { case (_, info) if !info.status => info.reason }
+    (failureReasons.isEmpty, failureReasons)
+  }
+
+  /**
+   * Awaits every pending trigger RPC in order and pairs each reply with its executor key.
+   *
+   * Written as a regular method instead of a def that returns a function literal; call sites of
+   * the form waitRpcResults(list) are unaffected.
+   */
+  private def waitRpcResults(
+      futureList: ArrayBuffer[(String, Future[CacheJobInfo])])
+      : ArrayBuffer[(String, CacheJobInfo)] = {
+    futureList.map {
+      case (executor, future) => (executor, ThreadUtils.awaitResult(future, Duration.Inf))
+    }
+  }
+
+  /** Throws a GlutenException when the executor key is not present in executorDataMap. */
+  private def checkExecutorId(executorId: String): Unit = {
+    val executors = GlutenDriverEndpoint.executorDataMap
+    val known = executors.containsKey(toExecutorId(executorId))
+    if (!known) {
+      val allExecutors = executors.toString
+      throw new GlutenException(
+        s"executor $executorId not found," +
+          s" all executors are $allExecutors")
+    }
+  }
+
 }
diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp 
b/cpp-ch/local-engine/Common/CHUtil.cpp
index 0409b66bd..8e07eea01 100644
--- a/cpp-ch/local-engine/Common/CHUtil.cpp
+++ b/cpp-ch/local-engine/Common/CHUtil.cpp
@@ -979,6 +979,7 @@ void BackendInitializerUtil::init(const std::string_view 
plan)
     // Init the table metadata cache map
     StorageMergeTreeFactory::init_cache_map();
 
+    JobScheduler::initialize(SerializedPlanParser::global_context);
     CacheManager::initialize(SerializedPlanParser::global_context);
 
     std::call_once(
diff --git a/cpp-ch/local-engine/Common/ConcurrentMap.h 
b/cpp-ch/local-engine/Common/ConcurrentMap.h
index 1719d9b25..2db351022 100644
--- a/cpp-ch/local-engine/Common/ConcurrentMap.h
+++ b/cpp-ch/local-engine/Common/ConcurrentMap.h
@@ -16,7 +16,7 @@
  */
 #pragma once
 
-#include <mutex>
+#include <shared_mutex>
 #include <unordered_map>
 
 namespace local_engine
diff --git a/cpp-ch/local-engine/Common/GlutenConfig.h 
b/cpp-ch/local-engine/Common/GlutenConfig.h
index 84744dab2..ac82b0fff 100644
--- a/cpp-ch/local-engine/Common/GlutenConfig.h
+++ b/cpp-ch/local-engine/Common/GlutenConfig.h
@@ -183,5 +183,19 @@ struct MergeTreeConfig
         return config;
     }
 };
+
+/// Settings for the job scheduler, read from context->getConfigRef().
+struct GlutenJobSchedulerConfig
+{
+    /// Configuration key of the scheduler's thread limit.
+    inline static const String JOB_SCHEDULER_MAX_THREADS = "job_scheduler_max_threads";
+
+    size_t job_scheduler_max_threads = 10;
+
+    /// Builds a config from `context`; a missing key keeps the default declared above.
+    static GlutenJobSchedulerConfig loadFromContext(DB::ContextPtr context)
+    {
+        GlutenJobSchedulerConfig loaded;
+        const auto & settings = context->getConfigRef();
+        loaded.job_scheduler_max_threads
+            = settings.getUInt64(JOB_SCHEDULER_MAX_THREADS, loaded.job_scheduler_max_threads);
+        return loaded;
+    }
+};
 }
 
diff --git a/cpp-ch/local-engine/Storages/Cache/CacheManager.cpp 
b/cpp-ch/local-engine/Storages/Cache/CacheManager.cpp
index d2c7b0681..a97f0c72a 100644
--- a/cpp-ch/local-engine/Storages/Cache/CacheManager.cpp
+++ b/cpp-ch/local-engine/Storages/Cache/CacheManager.cpp
@@ -26,12 +26,13 @@
 #include <Parser/MergeTreeRelParser.h>
 #include <Processors/Executors/PipelineExecutor.h>
 #include <Processors/QueryPlan/Optimizations/QueryPlanOptimizationSettings.h>
-#include <Processors/Sinks/NullSink.h>
 #include <QueryPipeline/QueryPipelineBuilder.h>
 #include <Common/Logger.h>
 #include <Common/logger_useful.h>
 #include <ranges>
 
+#include <jni/jni_common.h>
+
 namespace DB
 {
 namespace ErrorCodes
@@ -49,6 +50,16 @@ extern const Metric LocalThreadScheduled;
 
 namespace local_engine
 {
+
+jclass CacheManager::cache_result_class = nullptr;
+jmethodID CacheManager::cache_result_constructor = nullptr;
+
+/// Resolves and stores the Java CacheResult class and its (int, String) constructor
+/// in the static members used by getCacheStatus.
+void CacheManager::initJNI(JNIEnv * env)
+{
+    static constexpr auto result_class_name = "Lorg/apache/gluten/execution/CacheResult;";
+    static constexpr auto result_ctor_signature = "(ILjava/lang/String;)V";
+
+    cache_result_class = CreateGlobalClassReference(env, result_class_name);
+    cache_result_constructor = GetMethodID(env, cache_result_class, "<init>", result_ctor_signature);
+}
+
 CacheManager & CacheManager::instance()
 {
     static CacheManager cache_manager;
@@ -59,13 +70,6 @@ void CacheManager::initialize(DB::ContextMutablePtr context_)
 {
     auto & manager = instance();
     manager.context = context_;
-    manager.thread_pool = std::make_unique<ThreadPool>(
-        CurrentMetrics::LocalThread,
-        CurrentMetrics::LocalThreadActive,
-        CurrentMetrics::LocalThreadScheduled,
-        manager.context->getConfigRef().getInt("cache_sync_max_threads", 10),
-        0,
-        0);
 }
 
 struct CacheJobContext
@@ -73,17 +77,16 @@ struct CacheJobContext
     MergeTreeTable table;
 };
 
-void CacheManager::cachePart(const MergeTreeTable& table, const MergeTreePart& 
part, const std::unordered_set<String> & columns, std::shared_ptr<std::latch> 
latch)
+Task CacheManager::cachePart(const MergeTreeTable& table, const MergeTreePart& 
part, const std::unordered_set<String> & columns)
 {
     CacheJobContext job_context{table};
     job_context.table.parts.clear();
     job_context.table.parts.push_back(part);
     job_context.table.snapshot_id = "";
-    auto job = [job_detail = job_context, context = this->context, 
read_columns = columns, latch = latch]()
+    Task task = [job_detail = job_context, context = this->context, 
read_columns = columns]()
     {
         try
         {
-            SCOPE_EXIT({ if (latch) latch->count_down();});
             auto storage = MergeTreeRelParser::parseStorage(job_detail.table, 
context, true);
             auto storage_snapshot = 
std::make_shared<StorageSnapshot>(*storage, storage->getInMemoryMetadataPtr());
             NamesAndTypesList names_and_types_list;
@@ -113,8 +116,7 @@ void CacheManager::cachePart(const MergeTreeTable& table, 
const MergeTreePart& p
             PullingPipelineExecutor executor(pipeline);
             while (true)
             {
-                Chunk chunk;
-                if (!executor.pull(chunk))
+                if (Chunk chunk; !executor.pull(chunk))
                     break;
             }
             LOG_INFO(getLogger("CacheManager"), "Load cache of table {}.{} 
part {} success.", job_detail.table.database, job_detail.table.table, 
job_detail.table.parts.front().name);
@@ -122,22 +124,58 @@ void CacheManager::cachePart(const MergeTreeTable& table, 
const MergeTreePart& p
         catch (std::exception& e)
         {
             LOG_ERROR(getLogger("CacheManager"), "Load cache of table {}.{} 
part {} failed.\n {}", job_detail.table.database, job_detail.table.table, 
job_detail.table.parts.front().name, e.what());
+            std::rethrow_exception(std::current_exception());
         }
     };
     LOG_INFO(getLogger("CacheManager"), "Loading cache of table {}.{} part 
{}", job_context.table.database, job_context.table.table, 
job_context.table.parts.front().name);
-    thread_pool->scheduleOrThrowOnError(std::move(job));
+    return std::move(task);
 }
 
-void CacheManager::cacheParts(const String& table_def, const 
std::unordered_set<String>& columns, bool async)
+/// Parses `table_def`, wraps the loading of each part into a task, schedules all tasks as one
+/// job with a freshly generated UUID and returns that id without waiting for the job.
+JobId CacheManager::cacheParts(const String& table_def, const std::unordered_set<String>& columns)
 {
     auto table = parseMergeTreeTableString(table_def);
-    std::shared_ptr<std::latch> latch = nullptr;
-    if (!async) latch = std::make_shared<std::latch>(table.parts.size());
+    JobId job_id = toString(UUIDHelpers::generateV4());
+    Job cache_job(job_id);
     for (const auto & part : table.parts)
     {
-        cachePart(table, part, columns, latch);
+        cache_job.addTask(cachePart(table, part, columns));
+    }
+    JobScheduler::instance().scheduleJob(std::move(cache_job));
+    return job_id;
+}
+
+/// Builds a Java CacheResult for job `jobId`.
+/// Codes passed to the constructor: 0 = running, 1 = finished, 2 = failed or job not found.
+jobject CacheManager::getCacheStatus(JNIEnv * env, const String & jobId)
+{
+    constexpr int running_code = 0;
+    constexpr int success_code = 1;
+    constexpr int error_code = 2;
+
+    int status = running_code;
+    String message;
+    const auto job_status = JobScheduler::instance().getJobSatus(jobId);
+    if (!job_status.has_value())
+    {
+        status = error_code;
+        message = fmt::format("job {} not found", jobId);
+    }
+    else if (job_status->status == JobSatus::FINISHED)
+    {
+        status = success_code;
+    }
+    else if (job_status->status == JobSatus::FAILED)
+    {
+        status = error_code;
+        // Concatenate all task failure messages, each terminated by ';'.
+        for (const auto & msg : job_status->messages)
+        {
+            message.append(msg);
+            message.append(";");
+        }
     }
-    if (latch)
-        latch->wait();
+    return env->NewObject(cache_result_class, cache_result_constructor, status, charTojstring(env, message.c_str()));
 }
 }
\ No newline at end of file
diff --git a/cpp-ch/local-engine/Storages/Cache/CacheManager.h 
b/cpp-ch/local-engine/Storages/Cache/CacheManager.h
index a303b7b7f..b88a3ea03 100644
--- a/cpp-ch/local-engine/Storages/Cache/CacheManager.h
+++ b/cpp-ch/local-engine/Storages/Cache/CacheManager.h
@@ -16,29 +16,32 @@
  */
 #pragma once
 #include <Disks/IDisk.h>
-#include <latch>
-
+#include <Storages/Cache/JobScheduler.h>
+#include <jni.h>
 
 namespace local_engine
 {
 struct MergeTreePart;
 struct MergeTreeTable;
+
+
+
 /***
  * Manage the cache of the MergeTree, mainly including meta.bin, data.bin, 
metadata.gluten
  */
 class CacheManager {
 public:
+    static jclass cache_result_class;
+    static jmethodID cache_result_constructor;
+    static void initJNI(JNIEnv* env);
+
     static CacheManager & instance();
     static void initialize(DB::ContextMutablePtr context);
-    void cachePart(const MergeTreeTable& table, const MergeTreePart& part, 
const std::unordered_set<String>& columns, std::shared_ptr<std::latch> latch = 
nullptr);
-    void cacheParts(const String& table_def, const std::unordered_set<String>& 
columns, bool async = true);
+    Task cachePart(const MergeTreeTable& table, const MergeTreePart& part, 
const std::unordered_set<String>& columns);
+    JobId cacheParts(const String& table_def, const 
std::unordered_set<String>& columns);
+    static jobject getCacheStatus(JNIEnv * env, const String& jobId);
 private:
     CacheManager() = default;
-
-    std::unique_ptr<ThreadPool> thread_pool;
     DB::ContextMutablePtr context;
-    std::unordered_map<String, DB::DiskPtr> policy_to_disk;
-    std::unordered_map<DB::DiskPtr, DB::DiskPtr> disk_to_metadisk;
-    std::unordered_map<String, DB::FileCachePtr> policy_to_cache;
 };
 }
\ No newline at end of file
diff --git a/cpp-ch/local-engine/Storages/Cache/JobScheduler.cpp 
b/cpp-ch/local-engine/Storages/Cache/JobScheduler.cpp
new file mode 100644
index 000000000..6a43ad644
--- /dev/null
+++ b/cpp-ch/local-engine/Storages/Cache/JobScheduler.cpp
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+#include "JobScheduler.h"
+
+#include <Common/GlutenConfig.h>
+#include <Common/ThreadPool.h>
+#include <Interpreters/Context.h>
+#include <Common/logger_useful.h>
+
+namespace DB
+{
+namespace ErrorCodes
+{
+extern const int BAD_ARGUMENTS;
+}
+}
+
+namespace CurrentMetrics
+{
+extern const Metric LocalThread;
+extern const Metric LocalThreadActive;
+extern const Metric LocalThreadScheduled;
+}
+
+namespace local_engine
+{
+std::shared_ptr<JobScheduler> global_job_scheduler = nullptr;
+
+/// Creates the scheduler's thread pool, using the thread limit from GlutenJobSchedulerConfig.
+void JobScheduler::initialize(DB::ContextPtr context)
+{
+    const auto scheduler_config = GlutenJobSchedulerConfig::loadFromContext(context);
+    auto & scheduler = instance();
+    scheduler.thread_pool = std::make_unique<ThreadPool>(
+        CurrentMetrics::LocalThread,
+        CurrentMetrics::LocalThreadActive,
+        CurrentMetrics::LocalThreadScheduled,
+        scheduler_config.job_scheduler_max_threads,
+        0,
+        0);
+}
+
+/// Registers `job` and submits each of its tasks to the thread pool.
+/// Returns the job id. Throws BAD_ARGUMENTS if a job with the same id is already registered.
+JobId JobScheduler::scheduleJob(Job&& job)
+{
+    // Drop bookkeeping of long-finished jobs first. This must happen before job_details_mutex
+    // is taken below, because the cleanup path modifies job_details as well.
+    cleanFinishedJobs();
+
+    const size_t task_num = job.tasks.size();
+    const JobId job_id = job.id;
+
+    JobContext * registered = nullptr;
+    {
+        // The existence check and the insertion form one critical section, so two threads can
+        // neither register the same id twice nor read the map while it is being modified.
+        std::lock_guard lock(job_details_mutex);
+        if (job_details.contains(job_id))
+        {
+            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "job {} exists.", job_id);
+        }
+        std::vector<TaskResult> task_results;
+        // Workers keep references to individual results, so the vector must never reallocate.
+        task_results.reserve(task_num);
+        JobContext job_context = {std::move(job), std::make_unique<std::atomic_uint32_t>(task_num), std::move(task_results)};
+        job_details.emplace(job_id, std::move(job_context));
+        registered = &job_details.at(job_id);
+    }
+    LOG_INFO(logger, "schedule job {}", job_id);
+
+    auto & job_detail = *registered;
+
+    if (task_num == 0)
+    {
+        // No worker will ever report completion, so queue the empty job for cleanup right away.
+        addFinishedJob(job_id);
+        return job_id;
+    }
+
+    for (auto & task : job_detail.job.tasks)
+    {
+        job_detail.task_results.emplace_back(TaskResult());
+        auto & task_result = job_detail.task_results.back();
+        thread_pool->scheduleOrThrow(
+            [this, &job_detail, &task, &task_result]()
+            {
+                // acq_rel: the release half publishes this task's result, the acquire half lets
+                // the worker doing the last decrement see the others. Only that worker (previous
+                // value 1) queues the job for cleanup, so it is queued exactly once.
+                SCOPE_EXIT({
+                    if (job_detail.remain_tasks->fetch_sub(1, std::memory_order::acq_rel) == 1)
+                    {
+                        addFinishedJob(job_detail.job.id);
+                    }
+                });
+                try
+                {
+                    task();
+                    task_result.status = TaskResult::Status::SUCCESS;
+                }
+                catch (std::exception & e)
+                {
+                    task_result.status = TaskResult::Status::FAILED;
+                    task_result.message = e.what();
+                }
+                catch (...)
+                {
+                    // A non-std exception must still be recorded as a failure of this task.
+                    task_result.status = TaskResult::Status::FAILED;
+                    task_result.message = "unknown exception";
+                }
+            });
+    }
+    return job_id;
+}
+
+/// Returns the state of job `job_id`, or std::nullopt if no such job is registered.
+/// A finished job is reported as failed when any task failed; the failure messages are returned.
+std::optional<JobSatus> JobScheduler::getJobSatus(const JobId & job_id)
+{
+    // job_details is modified by other threads (scheduleJob inserts, cleanup erases), so the
+    // lookup and every read through the found entry happen under the map's mutex.
+    std::lock_guard lock(job_details_mutex);
+    if (!job_details.contains(job_id))
+    {
+        return std::nullopt;
+    }
+    auto & job_context = job_details.at(job_id);
+    if (!job_context.isFinished())
+        return JobSatus::running();
+
+    std::vector<String> messages;
+    for (const auto & task_result : job_context.task_results)
+    {
+        if (task_result.status == TaskResult::Status::FAILED)
+        {
+            messages.push_back(task_result.message);
+        }
+    }
+    if (messages.empty())
+        return JobSatus::success();
+    return JobSatus::failed(messages);
+}
+
+/// Removes job `job_id` from job_details; later status queries for it yield std::nullopt.
+void JobScheduler::cleanupJob(const JobId & job_id)
+{
+    LOG_INFO(logger, "clean job {}", job_id);
+    // Erase under the same mutex that guards insertion, so the map is never modified
+    // concurrently from two threads.
+    std::lock_guard lock(job_details_mutex);
+    job_details.erase(job_id);
+}
+
+/// Records `job_id` in finished_job together with a new Stopwatch; cleanFinishedJobs reads
+/// that Stopwatch to decide when the job can be dropped.
+void JobScheduler::addFinishedJob(const JobId & job_id)
+{
+    std::lock_guard guard(finished_job_mutex);
+    finished_job.emplace_back(std::make_pair(job_id, Stopwatch()));
+}
+
+/// Drops every job that finished more than five minutes ago, both from finished_job and,
+/// through cleanupJob, from job_details.
+void JobScheduler::cleanFinishedJobs()
+{
+    constexpr auto retention_seconds = 60 * 5;
+
+    std::lock_guard guard(finished_job_mutex);
+    auto it = finished_job.begin();
+    while (it != finished_job.end())
+    {
+        const bool expired = it->second.elapsedSeconds() > retention_seconds;
+        if (!expired)
+        {
+            ++it;
+            continue;
+        }
+        cleanupJob(it->first);
+        it = finished_job.erase(it);
+    }
+}
+}
\ No newline at end of file
diff --git a/cpp-ch/local-engine/Storages/Cache/JobScheduler.h 
b/cpp-ch/local-engine/Storages/Cache/JobScheduler.h
new file mode 100644
index 000000000..b5c2f601a
--- /dev/null
+++ b/cpp-ch/local-engine/Storages/Cache/JobScheduler.h
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include <utility>
+
+#include <base/types.h>
+#include <Common/ThreadPool_fwd.h>
+#include <Interpreters/Context_fwd.h>
+#include <Common/Stopwatch.h>
+
+namespace local_engine
+{
+
+using JobId = String;
+using Task = std::function<void()>;
+
+/// A batch of tasks identified by a JobId. JobScheduler is a friend and reads
+/// the private id and task list directly.
+class Job
+{
+    friend class JobScheduler;
+public:
+    /// Creates an empty job identified by `id`.
+    explicit Job(const JobId& id)
+        : id(id)
+    {
+    }
+
+    /// Appends `task` to this job, taking ownership of the callable.
+    void addTask(Task&& task)
+    {
+        // A named rvalue reference is an lvalue; without std::move the
+        // std::function and everything it captured would be copied.
+        tasks.emplace_back(std::move(task));
+    }
+
+private:
+    JobId id;
+    std::vector<Task> tasks;
+};
+
+
+
+/// Status value returned by JobScheduler::getJobSatus: a state plus, for a
+/// failed job, the messages of the tasks that failed.
+struct JobSatus
+{
+    enum Status
+    {
+        RUNNING,
+        FINISHED,
+        FAILED
+    };
+
+    Status status;
+    /// Left empty by running() and success(); filled by failed().
+    std::vector<String> messages;
+
+    /// Builds a RUNNING status with no messages.
+    static JobSatus running()
+    {
+        return JobSatus{RUNNING, {}};
+    }
+
+    /// Builds a FINISHED status with no messages.
+    static JobSatus success()
+    {
+        return JobSatus{FINISHED, {}};
+    }
+
+    /// Builds a FAILED status that carries a copy of `messages`.
+    static JobSatus failed(const std::vector<std::string> & messages)
+    {
+        return JobSatus{FAILED, messages};
+    }
+};
+
+struct TaskResult
+{
+    enum Status
+    {
+        SUCCESS,
+        FAILED,
+        RUNNING
+    };
+    Status status = RUNNING;
+    String message;
+};
+
+/// Per-job bookkeeping stored in JobScheduler::job_details: the job itself,
+/// a counter that isFinished() compares against zero, and one TaskResult per
+/// task. Member order is left untouched so positional initialisation keeps
+/// working.
+class JobContext
+{
+public:
+    Job job;
+    /// Heap-allocated counter; the default value-initialises it to zero.
+    std::unique_ptr<std::atomic_uint32_t> remain_tasks = std::make_unique<std::atomic_uint32_t>();
+    std::vector<TaskResult> task_results;
+
+    /// Returns true once `remain_tasks` reads as zero.
+    /// Const-qualified because it only reads the counter, so it can also be
+    /// called through a const reference.
+    /// NOTE(review): the load is relaxed and therefore does not by itself
+    /// order later reads of `task_results` — confirm that the decrement side
+    /// provides the required synchronisation.
+    bool isFinished() const
+    {
+        return remain_tasks->load(std::memory_order::relaxed) == 0;
+    }
+};
+
+/// Singleton that schedules Jobs and lets callers poll their status.
+/// The constructor is private; use instance().
+class JobScheduler
+{
+public:
+    /// Returns the single shared scheduler, held as a function-local static.
+    static JobScheduler & instance()
+    {
+        static JobScheduler scheduler;
+        return scheduler;
+    }
+
+    static void initialize(DB::ContextPtr context);
+
+    /// Schedules `job` and returns its id.
+    JobId scheduleJob(Job && job);
+
+    /// Status of `job_id`, or std::nullopt when the id has no entry.
+    std::optional<JobSatus> getJobSatus(const JobId & job_id);
+
+    /// Erases the entry for `job_id` from `job_details`.
+    void cleanupJob(const JobId & job_id);
+
+    /// Records `job_id` in the finished list together with a new Stopwatch.
+    void addFinishedJob(const JobId & job_id);
+
+    /// Cleans up jobs recorded as finished more than five minutes ago.
+    void cleanFinishedJobs();
+
+private:
+    JobScheduler() = default;
+
+    std::unique_ptr<ThreadPool> thread_pool;
+    std::unordered_map<JobId, JobContext> job_details;
+    std::mutex job_details_mutex;
+
+    /// Finished job ids, each paired with the Stopwatch created when it was recorded.
+    std::vector<std::pair<JobId, Stopwatch>> finished_job;
+    /// Guards `finished_job`.
+    std::mutex finished_job_mutex;
+    LoggerPtr logger = getLogger("JobScheduler");
+};
+}
diff --git a/cpp-ch/local-engine/local_engine_jni.cpp 
b/cpp-ch/local-engine/local_engine_jni.cpp
index 828556b4a..3c3d6d4f8 100644
--- a/cpp-ch/local-engine/local_engine_jni.cpp
+++ b/cpp-ch/local-engine/local_engine_jni.cpp
@@ -163,6 +163,7 @@ JNIEXPORT jint JNI_OnLoad(JavaVM * vm, void * /*reserved*/)
         env, local_engine::SparkRowToCHColumn::spark_row_interator_class, 
"nextBatch", "()Ljava/nio/ByteBuffer;");
 
     local_engine::BroadCastJoinBuilder::init(env);
+    local_engine::CacheManager::initJNI(env);
 
     local_engine::JNIUtils::vm = vm;
     return JNI_VERSION_1_8;
@@ -1269,7 +1270,7 @@ JNIEXPORT void 
Java_org_apache_gluten_utils_TestExceptionUtils_generateNativeExc
 
 
 
-JNIEXPORT void 
Java_org_apache_gluten_execution_CHNativeCacheManager_nativeCacheParts(JNIEnv * 
env, jobject, jstring table_, jstring columns_, jboolean async_)
+JNIEXPORT jstring 
Java_org_apache_gluten_execution_CHNativeCacheManager_nativeCacheParts(JNIEnv * 
env, jobject, jstring table_, jstring columns_)
 {
     LOCAL_ENGINE_JNI_METHOD_START
     auto table_def = jstring2string(env, table_);
@@ -1280,10 +1281,17 @@ JNIEXPORT void 
Java_org_apache_gluten_execution_CHNativeCacheManager_nativeCache
     {
         column_set.insert(col);
     }
-    local_engine::CacheManager::instance().cacheParts(table_def, column_set, 
async_);
-    LOCAL_ENGINE_JNI_METHOD_END(env, );
+    auto id = local_engine::CacheManager::instance().cacheParts(table_def, 
column_set);
+    return local_engine::charTojstring(env, id.c_str());
+    LOCAL_ENGINE_JNI_METHOD_END(env, nullptr);
 }
 
+/// JNI entry point behind CHNativeCacheManager.nativeGetCacheStatus.
+/// Converts the Java string `id` to a native string and returns whatever
+/// CacheManager::getCacheStatus produces for it. The JNI_METHOD_END macro is
+/// given nullptr as its return value.
+JNIEXPORT jobject Java_org_apache_gluten_execution_CHNativeCacheManager_nativeGetCacheStatus(JNIEnv * env, jobject, jstring id)
+{
+    LOCAL_ENGINE_JNI_METHOD_START
+    // Same evaluation order as the nested call: manager first, then the id.
+    auto && cache_manager = local_engine::CacheManager::instance();
+    const auto job_id = jstring2string(env, id);
+    return cache_manager.getCacheStatus(env, job_id);
+    LOCAL_ENGINE_JNI_METHOD_END(env, nullptr);
+}
 #ifdef __cplusplus
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to