This is an automated email from the ASF dual-hosted git repository.

taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 1b619a038 [GLUTEN-6819][CH] Refactor source from java iter && make 
casting happen before materializing (#6830)
1b619a038 is described below

commit 1b619a0388417c03d4dff8df4c89175619ada5b0
Author: 李扬 <[email protected]>
AuthorDate: Wed Aug 14 21:53:08 2024 +0800

    [GLUTEN-6819][CH] Refactor source from java iter && make casting happen 
before materializing (#6830)
    
    * refactor source from java iter
    
    * remove unused includes
---
 cpp-ch/local-engine/Common/QueryContext.cpp        |   2 -
 cpp-ch/local-engine/Parser/TypeParser.cpp          |   1 -
 .../local-engine/Storages/SourceFromJavaIter.cpp   | 197 ++++++++-------------
 cpp-ch/local-engine/Storages/SourceFromJavaIter.h  |  13 +-
 4 files changed, 76 insertions(+), 137 deletions(-)

diff --git a/cpp-ch/local-engine/Common/QueryContext.cpp 
b/cpp-ch/local-engine/Common/QueryContext.cpp
index 2d5780a6e..e5f5dd5dc 100644
--- a/cpp-ch/local-engine/Common/QueryContext.cpp
+++ b/cpp-ch/local-engine/Common/QueryContext.cpp
@@ -16,8 +16,6 @@
  */
 #include "QueryContext.h"
 
-#include <format>
-
 #include <Interpreters/Context.h>
 #include <Parser/SerializedPlanParser.h>
 #include <Common/CurrentThread.h>
diff --git a/cpp-ch/local-engine/Parser/TypeParser.cpp 
b/cpp-ch/local-engine/Parser/TypeParser.cpp
index 39d52131e..269f35747 100644
--- a/cpp-ch/local-engine/Parser/TypeParser.cpp
+++ b/cpp-ch/local-engine/Parser/TypeParser.cpp
@@ -14,7 +14,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <optional>
 #include <AggregateFunctions/AggregateFunctionFactory.h>
 #include <Core/ColumnsWithTypeAndName.h>
 #include <DataTypes/DataTypeAggregateFunction.h>
diff --git a/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp 
b/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp
index 1c5902c8c..2191112cf 100644
--- a/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp
+++ b/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp
@@ -15,25 +15,20 @@
  * limitations under the License.
  */
 #include "SourceFromJavaIter.h"
-#include <Columns/ColumnConst.h>
-#include <Columns/ColumnNullable.h>
-#include <Columns/ColumnMap.h>
-#include <Columns/ColumnTuple.h>
-#include <Columns/IColumn.h>
-#include <Core/ColumnsWithTypeAndName.h>
-#include <DataTypes/DataTypesNumber.h>
+#include <Interpreters/castColumn.h>
 #include <Processors/Transforms/AggregatingTransform.h>
 #include <jni/jni_common.h>
-#include <Common/assert_cast.h>
 #include <Common/CHUtil.h>
-#include <Common/DebugUtils.h>
 #include <Common/Exception.h>
 #include <Common/JNIUtils.h>
-#include <DataTypes/DataTypeArray.h>
-#include <DataTypes/DataTypeTuple.h>
-#include <DataTypes/DataTypeMap.h>
-#include <DataTypes/DataTypeNullable.h>
-#include <DataTypes/IDataType.h>
+
+namespace DB
+{
+namespace ErrorCodes
+{
+    extern const int LOGICAL_ERROR;
+}
+}
 
 namespace local_engine
 {
@@ -41,12 +36,38 @@ jclass 
SourceFromJavaIter::serialized_record_batch_iterator_class = nullptr;
 jmethodID SourceFromJavaIter::serialized_record_batch_iterator_hasNext = 
nullptr;
 jmethodID SourceFromJavaIter::serialized_record_batch_iterator_next = nullptr;
 
-
-static DB::Block getRealHeader(const DB::Block & header)
+static DB::Block getRealHeader(const DB::Block & header, const DB::Block * 
first_block)
 {
-    if (header.columns())
+    if (!header)
+        return BlockUtil::buildRowCountHeader();
+
+    if (!first_block)
         return header;
-    return BlockUtil::buildRowCountHeader();
+
+    if (header.columns() != first_block->columns())
+        throw DB::Exception(
+            DB::ErrorCodes::LOGICAL_ERROR,
+            "Header first block have different number of columns, header:{} 
first_block:{}",
+            header.dumpStructure(),
+            first_block->dumpStructure());
+
+    DB::Block result;
+    const size_t column_size = header.columns();
+    for (size_t i = 0; i < column_size; ++i)
+    {
+        const auto & header_column = header.getByPosition(i);
+        const auto & input_column = first_block->getByPosition(i);
+        chassert(header_column.name == input_column.name);
+
+        DB::WhichDataType input_which(input_column.type);
+        /// Some AggregateFunctions may have parameters, so we need to use the 
exact type from the first block.
+        /// e.g. spark approx_percentile -> CH quantilesGK(accuracy, level1, 
level2, ...), the intermediate result type
+        /// parsed from substrait plan is always AggregateFunction(10000, 
1)(quantilesGK, arg_type), which maybe different
+        /// from the actual intermediate result type from input block. So we 
need to use the exact type from the input block.
+        auto type = input_which.isAggregateFunction() ? input_column.type : 
header_column.type;
+        result.insert(DB::ColumnWithTypeAndName(type, header_column.name));
+    }
+    return result;
 }
 
 
@@ -62,8 +83,8 @@ DB::Block * SourceFromJavaIter::peekBlock(JNIEnv * env, 
jobject java_iter)
 
 
 SourceFromJavaIter::SourceFromJavaIter(
-    DB::ContextPtr context_, DB::Block header, jobject java_iter_, bool 
materialize_input_, DB::Block * first_block_)
-    : DB::ISource(getRealHeader(header))
+    DB::ContextPtr context_, const DB::Block & header, jobject java_iter_, 
bool materialize_input_, const DB::Block * first_block_)
+    : DB::ISource(getRealHeader(header, first_block_))
     , context(context_)
     , original_header(header)
     , java_iter(java_iter_)
@@ -80,43 +101,50 @@ DB::Chunk SourceFromJavaIter::generate()
     GET_JNIENV(env)
     SCOPE_EXIT({CLEAN_JNIENV});
 
-    DB::Chunk result;
-    DB::Block * data = nullptr;
+    DB::Block * input_block = nullptr;
     if (first_block) [[unlikely]]
     {
-        data = first_block;
+        input_block = const_cast<DB::Block *>(first_block);
         first_block = nullptr;
     }
     else if (jboolean has_next = safeCallBooleanMethod(env, java_iter, 
serialized_record_batch_iterator_hasNext))
     {
         jbyteArray block = static_cast<jbyteArray>(safeCallObjectMethod(env, 
java_iter, serialized_record_batch_iterator_next));
-        data = reinterpret_cast<DB::Block *>(byteArrayToLong(env, block));
+        input_block = reinterpret_cast<DB::Block *>(byteArrayToLong(env, 
block));
     }
     else
         return {};
 
-    /// Post-processing
-    if (materialize_input)
-        materializeBlockInplace(*data);
-
-    if (data->rows() > 0)
+    DB::Chunk result;
+    if (original_header)
     {
-        size_t rows = data->rows();
-        if (original_header.columns())
+        const auto & header = getPort().getHeader();
+        chassert(header.columns() == input_block->columns());
+        /// Cast all input columns in data to expected data types in header
+        for (size_t i = 0; i < header.columns(); ++i)
         {
-            result.setColumns(data->mutateColumns(), rows);
-            convertNullable(result);
-            auto info = std::make_shared<DB::AggregatedChunkInfo>();
-            info->is_overflows = data->info.is_overflows;
-            info->bucket_num = data->info.bucket_num;
-            result.getChunkInfos().add(std::move(info));
-        }
-        else
-        {
-            result = BlockUtil::buildRowCountChunk(rows);
-            auto info = std::make_shared<DB::AggregatedChunkInfo>();
-            result.getChunkInfos().add(std::move(info));
+            auto & input_column = input_block->getByPosition(i);
+            const auto & expected_type = header.getByPosition(i).type;
+            auto column = DB::castColumn(input_column, expected_type);
+            input_column.column = column;
+            input_column.type = expected_type;
         }
+
+        /// Do materializing after casting is faster than materializing before 
casting
+        if (materialize_input)
+            materializeBlockInplace(*input_block);
+
+        auto info = std::make_shared<DB::AggregatedChunkInfo>();
+        info->is_overflows = input_block->info.is_overflows;
+        info->bucket_num = input_block->info.bucket_num;
+        result.getChunkInfos().add(std::move(info));
+        result.setColumns(input_block->getColumns(), input_block->rows());
+    }
+    else
+    {
+        result = BlockUtil::buildRowCountChunk(input_block->rows()); /// Assign the outer `result`; redeclaring it here would return an empty chunk.
+        auto info = std::make_shared<DB::AggregatedChunkInfo>();
+        result.getChunkInfos().add(std::move(info));
     }
     return result;
 }
@@ -140,87 +168,4 @@ Int64 SourceFromJavaIter::byteArrayToLong(JNIEnv * env, 
jbyteArray arr)
     return result;
 }
 
-void SourceFromJavaIter::convertNullable(DB::Chunk & chunk)
-{
-    auto output = this->getOutputs().front().getHeader();
-    auto rows = chunk.getNumRows();
-    auto columns = chunk.detachColumns();
-    for (size_t i = 0; i < columns.size(); ++i)
-    {
-        const auto & column = columns.at(i);
-        const auto & type = output.getByPosition(i).type;
-        columns[i] = convertNestedNullable(column, type);
-    }
-    chunk.setColumns(columns, rows);
-}
-
-
-DB::ColumnPtr SourceFromJavaIter::convertNestedNullable(const DB::ColumnPtr & 
column, const DB::DataTypePtr & target_type)
-{
-    DB::WhichDataType column_type(column->getDataType());
-    if (column_type.isAggregateFunction())
-        return column;
-
-    if (DB::isColumnConst(*column))
-    {
-        const auto & data_column = assert_cast<const DB::ColumnConst 
&>(*column).getDataColumnPtr();
-        const auto & result_column = convertNestedNullable(data_column, 
target_type);
-        return DB::ColumnConst::create(result_column, column->size());
-    }
-
-    // if target type is non-nullable, the column type must be also 
non-nullable, recursively converting it's nested type
-    // if target type is nullable, the column type may be nullable or 
non-nullable, converting it and then recursively converting it's nested type
-    DB::ColumnPtr new_column = column;
-    if (!column_type.isNullable() && target_type->isNullable())
-        new_column = DB::makeNullable(column);
-
-    DB::ColumnPtr nested_column = new_column;
-    DB::DataTypePtr nested_target_type = removeNullable(target_type);
-    if (new_column->isNullable())
-    {
-        const auto & nullable_col = typeid_cast<const DB::ColumnNullable 
*>(new_column->getPtr().get());
-        nested_column = nullable_col->getNestedColumnPtr();
-        const auto & result_column = convertNestedNullable(nested_column, 
nested_target_type);
-        return DB::ColumnNullable::create(result_column, 
nullable_col->getNullMapColumnPtr());
-    }
-
-    DB::WhichDataType nested_column_type(nested_column->getDataType());
-    if (nested_column_type.isMap())
-    {
-        // header: Map(String, Nullable(String))
-        // chunk:  Map(String, String)
-        const auto & array_column = assert_cast<const DB::ColumnMap 
&>(*nested_column).getNestedColumn();
-        const auto & map_type = assert_cast<const DB::DataTypeMap 
&>(*nested_target_type);
-        auto tuple_columns = assert_cast<const DB::ColumnTuple 
*>(array_column.getDataPtr().get())->getColumns();
-        // only convert for value column as key is always non-nullable
-        const auto & value_column = convertNestedNullable(tuple_columns[1],  
map_type.getValueType());
-        auto result_column = 
DB::ColumnArray::create(DB::ColumnTuple::create(DB::Columns{tuple_columns[0], 
value_column}), array_column.getOffsetsPtr());
-        return DB::ColumnMap::create(std::move(result_column));
-    }
-
-    if (nested_column_type.isArray())
-    {
-        // header: Array(Nullable(String))
-        // chunk:  Array(String)
-        const auto & list_column = assert_cast<const DB::ColumnArray 
&>(*nested_column);
-        auto nested_type = assert_cast<const DB::DataTypeArray 
&>(*nested_target_type).getNestedType();
-        const auto & result_column = 
convertNestedNullable(list_column.getDataPtr(), nested_type);
-        return DB::ColumnArray::create(result_column, 
list_column.getOffsetsPtr());
-    }
-
-    if (nested_column_type.isTuple())
-    {
-        // header: Tuple(Nullable(String), Nullable(String))
-        // chunk:  Tuple(String, Nullable(String))
-        const auto & tuple_column = assert_cast<const DB::ColumnTuple 
&>(*nested_column);
-        auto nested_types = assert_cast<const DB::DataTypeTuple 
&>(*nested_target_type).getElements();
-        DB::Columns columns;
-        for (size_t i = 0; i != tuple_column.tupleSize(); ++i)
-            
columns.push_back(convertNestedNullable(tuple_column.getColumnPtr(i), 
nested_types[i]));
-        return DB::ColumnTuple::create(std::move(columns));
-    }
-
-    return new_column;
-}
-
 }
diff --git a/cpp-ch/local-engine/Storages/SourceFromJavaIter.h 
b/cpp-ch/local-engine/Storages/SourceFromJavaIter.h
index 6ee02e748..80ac42b7a 100644
--- a/cpp-ch/local-engine/Storages/SourceFromJavaIter.h
+++ b/cpp-ch/local-engine/Storages/SourceFromJavaIter.h
@@ -16,10 +16,9 @@
  */
 #pragma once
 #include <jni.h>
-#include <Processors/ISource.h>
-#include <Interpreters/Context.h>
 #include <Columns/IColumn.h>
-
+#include <Interpreters/Context.h>
+#include <Processors/ISource.h>
 namespace local_engine
 {
 class SourceFromJavaIter : public DB::ISource
@@ -30,18 +29,16 @@ public:
     static jmethodID serialized_record_batch_iterator_next;
 
     static Int64 byteArrayToLong(JNIEnv * env, jbyteArray arr);
-
     static DB::Block * peekBlock(JNIEnv * env, jobject java_iter);
 
-    SourceFromJavaIter(DB::ContextPtr context_, DB::Block header, jobject 
java_iter_, bool materialize_input_, DB::Block * peek_block_);
+    SourceFromJavaIter(
+        DB::ContextPtr context_, const DB::Block & header, jobject java_iter_, 
bool materialize_input_, const DB::Block * peek_block_);
     ~SourceFromJavaIter() override;
 
     String getName() const override { return "SourceFromJavaIter"; }
 
 private:
     DB::Chunk generate() override;
-    void convertNullable(DB::Chunk & chunk);
-    DB::ColumnPtr convertNestedNullable(const DB::ColumnPtr & column, const 
DB::DataTypePtr & target_type);
 
     DB::ContextPtr context;
     DB::Block original_header;
@@ -49,7 +46,7 @@ private:
     bool materialize_input;
 
     /// The first block read from java iteration to decide exact types of 
columns, especially for AggregateFunctions with parameters.
-    DB::Block * first_block = nullptr;
+    const DB::Block * first_block = nullptr;
 };
 
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to