This is an automated email from the ASF dual-hosted git repository.

zhangzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 61bc50626 [CH] Fix some test cases too slow (#6659)
61bc50626 is described below

commit 61bc50626194da294c998e366b80c7c367dd3baf
Author: LiuNeng <[email protected]>
AuthorDate: Thu Aug 1 10:30:31 2024 +0800

    [CH] Fix some test cases too slow (#6659)
    
    fix ut slow , optimize lock in queryContextManager
    
    Co-authored-by: liuneng1994 <[email protected]>
---
 ...tenClickHouseMergeTreePathBasedWriteSuite.scala | 28 ++++----
 .../GlutenClickHouseMergeTreeWriteSuite.scala      | 75 +++++++++++-----------
 cpp-ch/local-engine/Common/ConcurrentMap.h         | 18 ++++--
 cpp-ch/local-engine/Common/QueryContext.cpp        | 17 ++---
 4 files changed, 69 insertions(+), 69 deletions(-)

diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreePathBasedWriteSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreePathBasedWriteSuite.scala
index ed6953b81..34ffecb45 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreePathBasedWriteSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreePathBasedWriteSuite.scala
@@ -749,8 +749,7 @@ class GlutenClickHouseMergeTreePathBasedWriteSuite
     }
   }
 
-  // FIXME: very slow after https://github.com/apache/incubator-gluten/pull/6558
-  ignore("test mergetree path based write with bucket table") {
+  test("test mergetree path based write with bucket table") {
     val dataPath = s"$basePath/lineitem_mergetree_bucket"
     clearDataPath(dataPath)
 
@@ -760,8 +759,8 @@ class GlutenClickHouseMergeTreePathBasedWriteSuite
 
     sourceDF.write
       .format("clickhouse")
-      .partitionBy("l_shipdate")
-      .option("clickhouse.orderByKey", "l_orderkey,l_returnflag")
+      .partitionBy("l_returnflag")
+      .option("clickhouse.orderByKey", "l_orderkey")
       .option("clickhouse.primaryKey", "l_orderkey")
       .option("clickhouse.numBuckets", "4")
       .option("clickhouse.bucketColumnNames", "l_partkey")
@@ -808,13 +807,13 @@ class GlutenClickHouseMergeTreePathBasedWriteSuite
         val buckets = ClickHouseTableV2.getTable(fileIndex.deltaLog).bucketOption
         assert(buckets.isDefined)
         assertResult(4)(buckets.get.numBuckets)
-        assertResult("l_orderkey,l_returnflag")(
+        assertResult("l_orderkey")(
           buckets.get.sortColumnNames
             .mkString(","))
         assertResult("l_partkey")(
           buckets.get.bucketColumnNames
             .mkString(","))
-        assertResult("l_orderkey,l_returnflag")(
+        assertResult("l_orderkey")(
           ClickHouseTableV2
             .getTable(fileIndex.deltaLog)
             .orderByKeyOption
@@ -827,20 +826,21 @@ class GlutenClickHouseMergeTreePathBasedWriteSuite
             .get
             .mkString(","))
         assertResult(1)(ClickHouseTableV2.getTable(fileIndex.deltaLog).partitionColumns.size)
-        assertResult("l_shipdate")(
+        assertResult("l_returnflag")(
           ClickHouseTableV2
             .getTable(fileIndex.deltaLog)
             .partitionColumns
             .head)
         val addFiles = fileIndex.matchingFiles(Nil, Nil).map(f => f.asInstanceOf[AddMergeTreeParts])
 
-        assertResult(10089)(addFiles.size)
+        assertResult(12)(addFiles.size)
         assertResult(600572)(addFiles.map(_.rows).sum)
-        assertResult(4)(addFiles.count(_.partitionValues("l_shipdate").equals("1992-06-01")))
-        assertResult(4)(addFiles.count(_.partitionValues("l_shipdate").equals("1993-01-01")))
-        assertResult(4)(addFiles.count(_.partitionValues("l_shipdate").equals("1995-01-21")))
-        assertResult(1)(addFiles.count(
-          f => f.partitionValues("l_shipdate").equals("1995-01-21") && f.bucketNum.equals("00000")))
+        assertResult(4)(addFiles.count(_.partitionValues("l_returnflag").equals("A")))
+        assertResult(4)(addFiles.count(_.partitionValues("l_returnflag").equals("N")))
+        assertResult(4)(addFiles.count(_.partitionValues("l_returnflag").equals("R")))
+        assertResult(1)(
+          addFiles.count(
+            f => f.partitionValues("l_returnflag").equals("A") && f.bucketNum.equals("00000")))
     }
     // check part pruning effect of filter on bucket column
     val df = spark.sql(s"""
@@ -855,7 +855,7 @@ class GlutenClickHouseMergeTreePathBasedWriteSuite
       .flatMap(partition => partition.asInstanceOf[GlutenMergeTreePartition].partList)
       .map(_.name)
       .distinct
-    assertResult(4)(touchedParts.size)
+    assertResult(12)(touchedParts.size)
 
     // test upsert on partitioned & bucketed table
     upsertSourceTableAndCheck(dataPath)
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreeWriteSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreeWriteSuite.scala
index 84218f26a..3b7606daa 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreeWriteSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreeWriteSuite.scala
@@ -801,39 +801,37 @@ class GlutenClickHouseMergeTreeWriteSuite
     }
   }
 
-  // FIXME: very slow after https://github.com/apache/incubator-gluten/pull/6558
-  ignore("test mergetree write with bucket table") {
+  test("test mergetree write with bucket table") {
     spark.sql(s"""
                  |DROP TABLE IF EXISTS lineitem_mergetree_bucket;
                  |""".stripMargin)
 
-    spark.sql(
-      s"""
-         |CREATE TABLE IF NOT EXISTS lineitem_mergetree_bucket
-         |(
-         | l_orderkey      bigint,
-         | l_partkey       bigint,
-         | l_suppkey       bigint,
-         | l_linenumber    bigint,
-         | l_quantity      double,
-         | l_extendedprice double,
-         | l_discount      double,
-         | l_tax           double,
-         | l_returnflag    string,
-         | l_linestatus    string,
-         | l_shipdate      date,
-         | l_commitdate    date,
-         | l_receiptdate   date,
-         | l_shipinstruct  string,
-         | l_shipmode      string,
-         | l_comment       string
-         |)
-         |USING clickhouse
-         |PARTITIONED BY (l_shipdate)
-         |CLUSTERED BY (l_partkey)
-         |${if (sparkVersion.equals("3.2")) "" else "SORTED BY (l_orderkey, l_returnflag)"} INTO 4 BUCKETS
-         |LOCATION '$basePath/lineitem_mergetree_bucket'
-         |""".stripMargin)
+    spark.sql(s"""
+                 |CREATE TABLE IF NOT EXISTS lineitem_mergetree_bucket
+                 |(
+                 | l_orderkey      bigint,
+                 | l_partkey       bigint,
+                 | l_suppkey       bigint,
+                 | l_linenumber    bigint,
+                 | l_quantity      double,
+                 | l_extendedprice double,
+                 | l_discount      double,
+                 | l_tax           double,
+                 | l_returnflag    string,
+                 | l_linestatus    string,
+                 | l_shipdate      date,
+                 | l_commitdate    date,
+                 | l_receiptdate   date,
+                 | l_shipinstruct  string,
+                 | l_shipmode      string,
+                 | l_comment       string
+                 |)
+                 |USING clickhouse
+                 |PARTITIONED BY (l_returnflag)
+                 |CLUSTERED BY (l_partkey)
+                 |${if (sparkVersion.equals("3.2")) "" else "SORTED BY (l_orderkey)"} INTO 4 BUCKETS
+                 |LOCATION '$basePath/lineitem_mergetree_bucket'
+                 |""".stripMargin)
 
     spark.sql(s"""
                  | insert into table lineitem_mergetree_bucket
@@ -881,7 +879,7 @@ class GlutenClickHouseMergeTreeWriteSuite
         if (sparkVersion.equals("3.2")) {
           assert(ClickHouseTableV2.getTable(fileIndex.deltaLog).orderByKeyOption.isEmpty)
         } else {
-          assertResult("l_orderkey,l_returnflag")(
+          assertResult("l_orderkey")(
             ClickHouseTableV2
               .getTable(fileIndex.deltaLog)
               .orderByKeyOption
@@ -890,20 +888,21 @@ class GlutenClickHouseMergeTreeWriteSuite
         }
         assert(ClickHouseTableV2.getTable(fileIndex.deltaLog).primaryKeyOption.isEmpty)
         assertResult(1)(ClickHouseTableV2.getTable(fileIndex.deltaLog).partitionColumns.size)
-        assertResult("l_shipdate")(
+        assertResult("l_returnflag")(
           ClickHouseTableV2
             .getTable(fileIndex.deltaLog)
             .partitionColumns
             .head)
         val addFiles = fileIndex.matchingFiles(Nil, Nil).map(f => f.asInstanceOf[AddMergeTreeParts])
 
-        assertResult(10089)(addFiles.size)
+        assertResult(12)(addFiles.size)
         assertResult(600572)(addFiles.map(_.rows).sum)
-        assertResult(4)(addFiles.count(_.partitionValues("l_shipdate").equals("1992-06-01")))
-        assertResult(4)(addFiles.count(_.partitionValues("l_shipdate").equals("1993-01-01")))
-        assertResult(4)(addFiles.count(_.partitionValues("l_shipdate").equals("1995-01-21")))
-        assertResult(1)(addFiles.count(
-          f => f.partitionValues("l_shipdate").equals("1995-01-21") && f.bucketNum.equals("00000")))
+        assertResult(4)(addFiles.count(_.partitionValues("l_returnflag").equals("A")))
+        assertResult(4)(addFiles.count(_.partitionValues("l_returnflag").equals("N")))
+        assertResult(4)(addFiles.count(_.partitionValues("l_returnflag").equals("R")))
+        assertResult(1)(
+          addFiles.count(
+            f => f.partitionValues("l_returnflag").equals("A") && f.bucketNum.equals("00000")))
     }
     // check part pruning effect of filter on bucket column
     val df = spark.sql(s"""
@@ -918,7 +917,7 @@ class GlutenClickHouseMergeTreeWriteSuite
       .flatMap(partition => partition.asInstanceOf[GlutenMergeTreePartition].partList)
       .map(_.name)
       .distinct
-    assertResult(4)(touchedParts.size)
+    assertResult(12)(touchedParts.size)
 
     // test upsert on partitioned & bucketed table
     upsertSourceTableAndCheck("lineitem_mergetree_bucket")
diff --git a/cpp-ch/local-engine/Common/ConcurrentMap.h b/cpp-ch/local-engine/Common/ConcurrentMap.h
index c56926ff5..1719d9b25 100644
--- a/cpp-ch/local-engine/Common/ConcurrentMap.h
+++ b/cpp-ch/local-engine/Common/ConcurrentMap.h
@@ -27,13 +27,13 @@ class ConcurrentMap
 public:
     void insert(const K & key, const V & value)
     {
-        std::lock_guard lock{mutex};
+        std::unique_lock lock{mutex};
         map.insert({key, value});
     }
 
     V get(const K & key)
     {
-        std::lock_guard lock{mutex};
+        std::shared_lock lock{mutex};
         auto it = map.find(key);
         if (it == map.end())
         {
@@ -44,24 +44,30 @@ public:
 
     void erase(const K & key)
     {
-        std::lock_guard lock{mutex};
+        std::unique_lock lock{mutex};
         map.erase(key);
     }
 
     void clear()
     {
-        std::lock_guard lock{mutex};
+        std::unique_lock lock{mutex};
         map.clear();
     }
 
+    bool contains(const K & key)
+    {
+        std::shared_lock lock{mutex};
+        return map.contains(key);
+    }
+
     size_t size() const
     {
-        std::lock_guard lock{mutex};
+        std::shared_lock lock{mutex};
         return map.size();
     }
 
 private:
     std::unordered_map<K, V> map;
-    mutable std::mutex mutex;
+    mutable std::shared_mutex mutex;
 };
 }
diff --git a/cpp-ch/local-engine/Common/QueryContext.cpp b/cpp-ch/local-engine/Common/QueryContext.cpp
index 68934adad..2d5780a6e 100644
--- a/cpp-ch/local-engine/Common/QueryContext.cpp
+++ b/cpp-ch/local-engine/Common/QueryContext.cpp
@@ -24,6 +24,7 @@
 #include <Common/ThreadStatus.h>
 #include <Common/CHUtil.h>
 #include <Common/GlutenConfig.h>
+#include <Common/ConcurrentMap.h>
 #include <base/unit.h>
 #include <sstream>
 #include <iomanip>
@@ -48,8 +49,7 @@ struct QueryContext
     ContextMutablePtr query_context;
 };
 
-std::unordered_map<int64_t, std::shared_ptr<QueryContext>> query_map;
-std::mutex query_map_mutex;
+ConcurrentMap<int64_t, std::shared_ptr<QueryContext>> query_map;
 
 int64_t QueryContextManager::initializeQuery()
 {
@@ -72,9 +72,8 @@ int64_t QueryContextManager::initializeQuery()
 
     query_context->thread_group->memory_tracker.setSoftLimit(memory_limit);
     query_context->thread_group->memory_tracker.setHardLimit(memory_limit + config.extra_memory_hard_limit);
-    std::lock_guard<std::mutex> lock_guard(query_map_mutex);
     int64_t id = reinterpret_cast<int64_t>(query_context->thread_group.get());
-    query_map.emplace(id, query_context);
+    query_map.insert(id, query_context);
     return id;
 }
 
@@ -84,9 +83,8 @@ DB::ContextMutablePtr QueryContextManager::currentQueryContext()
     {
         throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Thread group not found.");
     }
-    std::lock_guard lock_guard(query_map_mutex);
     int64_t id = reinterpret_cast<int64_t>(CurrentThread::getGroup().get());
-    return query_map[id]->query_context;
+    return query_map.get(id)->query_context;
 }
 
 void QueryContextManager::logCurrentPerformanceCounters(ProfileEvents::Counters & counters)
@@ -116,10 +114,9 @@ void QueryContextManager::logCurrentPerformanceCounters(ProfileEvents::Counters
 
 size_t QueryContextManager::currentPeakMemory(int64_t id)
 {
-    std::lock_guard lock_guard(query_map_mutex);
     if (!query_map.contains(id))
         throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "context released {}", id);
-    return query_map[id]->thread_group->memory_tracker.getPeak();
+    return query_map.get(id)->thread_group->memory_tracker.getPeak();
 }
 
 void QueryContextManager::finalizeQuery(int64_t id)
@@ -130,8 +127,7 @@ void QueryContextManager::finalizeQuery(int64_t id)
     }
     std::shared_ptr<QueryContext> context;
     {
-        std::lock_guard lock_guard(query_map_mutex);
-        context = query_map[id];
+        context = query_map.get(id);
     }
     auto query_context = context->thread_status->getQueryContext();
     if (!query_context)
@@ -152,7 +148,6 @@ void QueryContextManager::finalizeQuery(int64_t id)
     context->thread_status.reset();
     query_context.reset();
     {
-        std::lock_guard lock_guard(query_map_mutex);
         query_map.erase(id);
     }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to