This is an automated email from the ASF dual-hosted git repository.
taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 1b619a038 [GLUTEN-6819][CH] Refactor source from java iter && make
casting happen before materializing (#6830)
1b619a038 is described below
commit 1b619a0388417c03d4dff8df4c89175619ada5b0
Author: 李扬 <[email protected]>
AuthorDate: Wed Aug 14 21:53:08 2024 +0800
[GLUTEN-6819][CH] Refactor source from java iter && make casting happen
before materializing (#6830)
* refactor source from java iter
* remove unused includes
---
cpp-ch/local-engine/Common/QueryContext.cpp | 2 -
cpp-ch/local-engine/Parser/TypeParser.cpp | 1 -
.../local-engine/Storages/SourceFromJavaIter.cpp | 197 ++++++++-------------
cpp-ch/local-engine/Storages/SourceFromJavaIter.h | 13 +-
4 files changed, 76 insertions(+), 137 deletions(-)
diff --git a/cpp-ch/local-engine/Common/QueryContext.cpp
b/cpp-ch/local-engine/Common/QueryContext.cpp
index 2d5780a6e..e5f5dd5dc 100644
--- a/cpp-ch/local-engine/Common/QueryContext.cpp
+++ b/cpp-ch/local-engine/Common/QueryContext.cpp
@@ -16,8 +16,6 @@
*/
#include "QueryContext.h"
-#include <format>
-
#include <Interpreters/Context.h>
#include <Parser/SerializedPlanParser.h>
#include <Common/CurrentThread.h>
diff --git a/cpp-ch/local-engine/Parser/TypeParser.cpp
b/cpp-ch/local-engine/Parser/TypeParser.cpp
index 39d52131e..269f35747 100644
--- a/cpp-ch/local-engine/Parser/TypeParser.cpp
+++ b/cpp-ch/local-engine/Parser/TypeParser.cpp
@@ -14,7 +14,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#include <optional>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <Core/ColumnsWithTypeAndName.h>
#include <DataTypes/DataTypeAggregateFunction.h>
diff --git a/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp
b/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp
index 1c5902c8c..2191112cf 100644
--- a/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp
+++ b/cpp-ch/local-engine/Storages/SourceFromJavaIter.cpp
@@ -15,25 +15,20 @@
* limitations under the License.
*/
#include "SourceFromJavaIter.h"
-#include <Columns/ColumnConst.h>
-#include <Columns/ColumnNullable.h>
-#include <Columns/ColumnMap.h>
-#include <Columns/ColumnTuple.h>
-#include <Columns/IColumn.h>
-#include <Core/ColumnsWithTypeAndName.h>
-#include <DataTypes/DataTypesNumber.h>
+#include <Interpreters/castColumn.h>
#include <Processors/Transforms/AggregatingTransform.h>
#include <jni/jni_common.h>
-#include <Common/assert_cast.h>
#include <Common/CHUtil.h>
-#include <Common/DebugUtils.h>
#include <Common/Exception.h>
#include <Common/JNIUtils.h>
-#include <DataTypes/DataTypeArray.h>
-#include <DataTypes/DataTypeTuple.h>
-#include <DataTypes/DataTypeMap.h>
-#include <DataTypes/DataTypeNullable.h>
-#include <DataTypes/IDataType.h>
+
+namespace DB
+{
+namespace ErrorCodes
+{
+ extern const int LOGICAL_ERROR;
+}
+}
namespace local_engine
{
@@ -41,12 +36,38 @@ jclass
SourceFromJavaIter::serialized_record_batch_iterator_class = nullptr;
jmethodID SourceFromJavaIter::serialized_record_batch_iterator_hasNext =
nullptr;
jmethodID SourceFromJavaIter::serialized_record_batch_iterator_next = nullptr;
-
-static DB::Block getRealHeader(const DB::Block & header)
+static DB::Block getRealHeader(const DB::Block & header, const DB::Block *
first_block)
{
- if (header.columns())
+ if (!header)
+ return BlockUtil::buildRowCountHeader();
+
+ if (!first_block)
return header;
- return BlockUtil::buildRowCountHeader();
+
+ if (header.columns() != first_block->columns())
+ throw DB::Exception(
+ DB::ErrorCodes::LOGICAL_ERROR,
+ "Header first block have different number of columns, header:{}
first_block:{}",
+ header.dumpStructure(),
+ first_block->dumpStructure());
+
+ DB::Block result;
+ const size_t column_size = header.columns();
+ for (size_t i = 0; i < column_size; ++i)
+ {
+ const auto & header_column = header.getByPosition(i);
+ const auto & input_column = first_block->getByPosition(i);
+ chassert(header_column.name == input_column.name);
+
+ DB::WhichDataType input_which(input_column.type);
+ /// Some AggregateFunctions may have parameters, so we need to use the
exact type from the first block.
+ /// e.g. spark approx_percentile -> CH quantilesGK(accuracy, level1,
level2, ...), the intermediate result type
+ /// parsed from substrait plan is always AggregateFunction(10000,
1)(quantilesGK, arg_type), which maybe different
+ /// from the actual intermediate result type from input block. So we
need to use the exact type from the input block.
+ auto type = input_which.isAggregateFunction() ? input_column.type :
header_column.type;
+ result.insert(DB::ColumnWithTypeAndName(type, header_column.name));
+ }
+ return result;
}
@@ -62,8 +83,8 @@ DB::Block * SourceFromJavaIter::peekBlock(JNIEnv * env,
jobject java_iter)
SourceFromJavaIter::SourceFromJavaIter(
- DB::ContextPtr context_, DB::Block header, jobject java_iter_, bool
materialize_input_, DB::Block * first_block_)
- : DB::ISource(getRealHeader(header))
+ DB::ContextPtr context_, const DB::Block & header, jobject java_iter_,
bool materialize_input_, const DB::Block * first_block_)
+ : DB::ISource(getRealHeader(header, first_block_))
, context(context_)
, original_header(header)
, java_iter(java_iter_)
@@ -80,43 +101,50 @@ DB::Chunk SourceFromJavaIter::generate()
GET_JNIENV(env)
SCOPE_EXIT({CLEAN_JNIENV});
- DB::Chunk result;
- DB::Block * data = nullptr;
+ DB::Block * input_block = nullptr;
if (first_block) [[unlikely]]
{
- data = first_block;
+ input_block = const_cast<DB::Block *>(first_block);
first_block = nullptr;
}
else if (jboolean has_next = safeCallBooleanMethod(env, java_iter,
serialized_record_batch_iterator_hasNext))
{
jbyteArray block = static_cast<jbyteArray>(safeCallObjectMethod(env,
java_iter, serialized_record_batch_iterator_next));
- data = reinterpret_cast<DB::Block *>(byteArrayToLong(env, block));
+ input_block = reinterpret_cast<DB::Block *>(byteArrayToLong(env,
block));
}
else
return {};
- /// Post-processing
- if (materialize_input)
- materializeBlockInplace(*data);
-
- if (data->rows() > 0)
+ DB::Chunk result;
+ if (original_header)
{
- size_t rows = data->rows();
- if (original_header.columns())
+ const auto & header = getPort().getHeader();
+ chassert(header.columns() == input_block->columns());
+ /// Cast all input columns in data to expected data types in header
+ for (size_t i = 0; i < header.columns(); ++i)
{
- result.setColumns(data->mutateColumns(), rows);
- convertNullable(result);
- auto info = std::make_shared<DB::AggregatedChunkInfo>();
- info->is_overflows = data->info.is_overflows;
- info->bucket_num = data->info.bucket_num;
- result.getChunkInfos().add(std::move(info));
- }
- else
- {
- result = BlockUtil::buildRowCountChunk(rows);
- auto info = std::make_shared<DB::AggregatedChunkInfo>();
- result.getChunkInfos().add(std::move(info));
+ auto & input_column = input_block->getByPosition(i);
+ const auto & expected_type = header.getByPosition(i).type;
+ auto column = DB::castColumn(input_column, expected_type);
+ input_column.column = column;
+ input_column.type = expected_type;
}
+
+ /// Do materializing after casting is faster than materializing before
casting
+ if (materialize_input)
+ materializeBlockInplace(*input_block);
+
+ auto info = std::make_shared<DB::AggregatedChunkInfo>();
+ info->is_overflows = input_block->info.is_overflows;
+ info->bucket_num = input_block->info.bucket_num;
+ result.getChunkInfos().add(std::move(info));
+ result.setColumns(input_block->getColumns(), input_block->rows());
+ }
+ else
+ {
+ DB::Chunk result = BlockUtil::buildRowCountChunk(input_block->rows());
+ auto info = std::make_shared<DB::AggregatedChunkInfo>();
+ result.getChunkInfos().add(std::move(info));
}
return result;
}
@@ -140,87 +168,4 @@ Int64 SourceFromJavaIter::byteArrayToLong(JNIEnv * env,
jbyteArray arr)
return result;
}
-void SourceFromJavaIter::convertNullable(DB::Chunk & chunk)
-{
- auto output = this->getOutputs().front().getHeader();
- auto rows = chunk.getNumRows();
- auto columns = chunk.detachColumns();
- for (size_t i = 0; i < columns.size(); ++i)
- {
- const auto & column = columns.at(i);
- const auto & type = output.getByPosition(i).type;
- columns[i] = convertNestedNullable(column, type);
- }
- chunk.setColumns(columns, rows);
-}
-
-
-DB::ColumnPtr SourceFromJavaIter::convertNestedNullable(const DB::ColumnPtr &
column, const DB::DataTypePtr & target_type)
-{
- DB::WhichDataType column_type(column->getDataType());
- if (column_type.isAggregateFunction())
- return column;
-
- if (DB::isColumnConst(*column))
- {
- const auto & data_column = assert_cast<const DB::ColumnConst
&>(*column).getDataColumnPtr();
- const auto & result_column = convertNestedNullable(data_column,
target_type);
- return DB::ColumnConst::create(result_column, column->size());
- }
-
- // if target type is non-nullable, the column type must be also
non-nullable, recursively converting it's nested type
- // if target type is nullable, the column type may be nullable or
non-nullable, converting it and then recursively converting it's nested type
- DB::ColumnPtr new_column = column;
- if (!column_type.isNullable() && target_type->isNullable())
- new_column = DB::makeNullable(column);
-
- DB::ColumnPtr nested_column = new_column;
- DB::DataTypePtr nested_target_type = removeNullable(target_type);
- if (new_column->isNullable())
- {
- const auto & nullable_col = typeid_cast<const DB::ColumnNullable
*>(new_column->getPtr().get());
- nested_column = nullable_col->getNestedColumnPtr();
- const auto & result_column = convertNestedNullable(nested_column,
nested_target_type);
- return DB::ColumnNullable::create(result_column,
nullable_col->getNullMapColumnPtr());
- }
-
- DB::WhichDataType nested_column_type(nested_column->getDataType());
- if (nested_column_type.isMap())
- {
- // header: Map(String, Nullable(String))
- // chunk: Map(String, String)
- const auto & array_column = assert_cast<const DB::ColumnMap
&>(*nested_column).getNestedColumn();
- const auto & map_type = assert_cast<const DB::DataTypeMap
&>(*nested_target_type);
- auto tuple_columns = assert_cast<const DB::ColumnTuple
*>(array_column.getDataPtr().get())->getColumns();
- // only convert for value column as key is always non-nullable
- const auto & value_column = convertNestedNullable(tuple_columns[1],
map_type.getValueType());
- auto result_column =
DB::ColumnArray::create(DB::ColumnTuple::create(DB::Columns{tuple_columns[0],
value_column}), array_column.getOffsetsPtr());
- return DB::ColumnMap::create(std::move(result_column));
- }
-
- if (nested_column_type.isArray())
- {
- // header: Array(Nullable(String))
- // chunk: Array(String)
- const auto & list_column = assert_cast<const DB::ColumnArray
&>(*nested_column);
- auto nested_type = assert_cast<const DB::DataTypeArray
&>(*nested_target_type).getNestedType();
- const auto & result_column =
convertNestedNullable(list_column.getDataPtr(), nested_type);
- return DB::ColumnArray::create(result_column,
list_column.getOffsetsPtr());
- }
-
- if (nested_column_type.isTuple())
- {
- // header: Tuple(Nullable(String), Nullable(String))
- // chunk: Tuple(String, Nullable(String))
- const auto & tuple_column = assert_cast<const DB::ColumnTuple
&>(*nested_column);
- auto nested_types = assert_cast<const DB::DataTypeTuple
&>(*nested_target_type).getElements();
- DB::Columns columns;
- for (size_t i = 0; i != tuple_column.tupleSize(); ++i)
-
columns.push_back(convertNestedNullable(tuple_column.getColumnPtr(i),
nested_types[i]));
- return DB::ColumnTuple::create(std::move(columns));
- }
-
- return new_column;
-}
-
}
diff --git a/cpp-ch/local-engine/Storages/SourceFromJavaIter.h
b/cpp-ch/local-engine/Storages/SourceFromJavaIter.h
index 6ee02e748..80ac42b7a 100644
--- a/cpp-ch/local-engine/Storages/SourceFromJavaIter.h
+++ b/cpp-ch/local-engine/Storages/SourceFromJavaIter.h
@@ -16,10 +16,9 @@
*/
#pragma once
#include <jni.h>
-#include <Processors/ISource.h>
-#include <Interpreters/Context.h>
#include <Columns/IColumn.h>
-
+#include <Interpreters/Context.h>
+#include <Processors/ISource.h>
namespace local_engine
{
class SourceFromJavaIter : public DB::ISource
@@ -30,18 +29,16 @@ public:
static jmethodID serialized_record_batch_iterator_next;
static Int64 byteArrayToLong(JNIEnv * env, jbyteArray arr);
-
static DB::Block * peekBlock(JNIEnv * env, jobject java_iter);
- SourceFromJavaIter(DB::ContextPtr context_, DB::Block header, jobject
java_iter_, bool materialize_input_, DB::Block * peek_block_);
+ SourceFromJavaIter(
+ DB::ContextPtr context_, const DB::Block & header, jobject java_iter_,
bool materialize_input_, const DB::Block * peek_block_);
~SourceFromJavaIter() override;
String getName() const override { return "SourceFromJavaIter"; }
private:
DB::Chunk generate() override;
- void convertNullable(DB::Chunk & chunk);
- DB::ColumnPtr convertNestedNullable(const DB::ColumnPtr & column, const
DB::DataTypePtr & target_type);
DB::ContextPtr context;
DB::Block original_header;
@@ -49,7 +46,7 @@ private:
bool materialize_input;
/// The first block read from java iteration to decide exact types of
columns, especially for AggregateFunctions with parameters.
- DB::Block * first_block = nullptr;
+ const DB::Block * first_block = nullptr;
};
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]