This is an automated email from the ASF dual-hosted git repository.
taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 92d3793273 [GLUTEN-6856][CH]Support arrays_overlap and fix array_join diff (#6857)
92d3793273 is described below
commit 92d3793273b6da051be6fb6bffa9329ecb439b30
Author: kevinyhzou <[email protected]>
AuthorDate: Wed Oct 9 10:25:46 2024 +0800
[GLUTEN-6856][CH]Support arrays_overlap and fix array_join diff (#6857)
* Fix orc timezone read
* on test
* fix ci
* fix ci
* ci build error
* support arrays_overlap
* remove useless code
* fix ci
* fix ci
* fix array_join diff
* remove useless code
* remove useless code
* solve conflict
* ci fix
* remove useless code
* use default impl for array join
* remove useless code
---
.../org/apache/gluten/utils/CHExpressionUtil.scala | 9 --
.../Functions/SparkFunctionArrayJoin.cpp | 135 ++++++--------------
.../Functions/SparkFunctionArraysOverlap.cpp | 139 +++++++++++++++++++++
.../CommonScalarFunctionParser.cpp | 1 +
4 files changed, 181 insertions(+), 103 deletions(-)
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
index e58db43b26..bb23612b13 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala
@@ -159,13 +159,6 @@ case class EncodeDecodeValidator() extends FunctionValidator {
}
}
-case class ArrayJoinValidator() extends FunctionValidator {
- override def doValidate(expr: Expression): Boolean = expr match {
- case t: ArrayJoin => !t.children.head.isInstanceOf[Literal]
- case _ => true
- }
-}
-
case class FormatStringValidator() extends FunctionValidator {
override def doValidate(expr: Expression): Boolean = {
val formatString = expr.asInstanceOf[FormatString]
@@ -181,13 +174,11 @@ object CHExpressionUtil {
)
final val CH_BLACKLIST_SCALAR_FUNCTION: Map[String, FunctionValidator] = Map(
- ARRAY_JOIN -> ArrayJoinValidator(),
SPLIT_PART -> DefaultValidator(),
TO_UNIX_TIMESTAMP -> UnixTimeStampValidator(),
UNIX_TIMESTAMP -> UnixTimeStampValidator(),
SEQUENCE -> SequenceValidator(),
GET_JSON_OBJECT -> GetJsonObjectValidator(),
- ARRAYS_OVERLAP -> DefaultValidator(),
SPLIT -> StringSplitValidator(),
SUBSTRING_INDEX -> SubstringIndexValidator(),
LPAD -> StringLPadValidator(),
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp
index ed99c09042..4c2847d9f9 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionArrayJoin.cpp
@@ -46,7 +46,7 @@ public:
size_t getNumberOfArguments() const override { return 0; }
String getName() const override { return name; }
bool isVariadic() const override { return true; }
- bool useDefaultImplementationForNulls() const override { return false; }
+ bool useDefaultImplementationForConstants() const override { return true; }
DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const
override
{
@@ -54,61 +54,20 @@ public:
return makeNullable(data_type);
}
- ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const
DataTypePtr &, size_t) const override
- {
+ ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const
DataTypePtr &, size_t input_rows_count) const override
+ {
if (arguments.size() != 2 && arguments.size() != 3)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} must have 2 or 3 arguments", getName());
-
- const auto * arg_null_col =
checkAndGetColumn<ColumnNullable>(arguments[0].column.get());
- const ColumnArray * array_col;
- if (!arg_null_col)
- array_col =
checkAndGetColumn<ColumnArray>(arguments[0].column.get());
- else
- array_col =
checkAndGetColumn<ColumnArray>(arg_null_col->getNestedColumnPtr().get());
- if (!array_col)
- throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {}
1st argument must be array type", getName());
-
auto res_col = ColumnString::create();
- auto null_col = ColumnUInt8::create(array_col->size(), 0);
+ auto null_col = ColumnUInt8::create(input_rows_count, 0);
PaddedPODArray<UInt8> & null_result = null_col->getData();
- std::pair<bool, StringRef> delim_p, null_replacement_p;
- bool return_result = false;
- auto checkAndGetConstString = [&](const ColumnPtr & col) ->
std::pair<bool, StringRef>
- {
- StringRef res;
- const auto * str_null_col =
checkAndGetColumnConstData<ColumnNullable>(col.get());
- if (str_null_col)
- {
- if (str_null_col->isNullAt(0))
- {
- for (size_t i = 0; i < array_col->size(); ++i)
- {
- res_col->insertDefault();
- null_result[i] = 1;
- }
- return_result = true;
- return std::pair<bool, StringRef>(false, res);
- }
- }
- else
- {
- const auto * string_col =
checkAndGetColumnConstData<ColumnString>(col.get());
- if (!string_col)
- return std::pair<bool, StringRef>(false, res);
- else
- return std::pair<bool, StringRef>(true,
string_col->getDataAt(0));
- }
- };
- delim_p = checkAndGetConstString(arguments[1].column);
- if (return_result)
+ if (input_rows_count == 0)
return ColumnNullable::create(std::move(res_col),
std::move(null_col));
- if (arguments.size() == 3)
- {
- null_replacement_p = checkAndGetConstString(arguments[2].column);
- if (return_result)
- return ColumnNullable::create(std::move(res_col),
std::move(null_col));
- }
+ const ColumnArray * array_col = array_col =
checkAndGetColumn<ColumnArray>(arguments[0].column.get());;
+ if (!array_col)
+ throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {}
1st argument must be array type", getName());
+
const ColumnNullable * array_nested_col =
checkAndGetColumn<ColumnNullable>(&array_col->getData());
const ColumnString * string_col;
if (array_nested_col)
@@ -118,57 +77,42 @@ public:
const ColumnArray::Offsets & array_offsets = array_col->getOffsets();
const ColumnString::Offsets & string_offsets =
string_col->getOffsets();
const ColumnString::Chars & string_data = string_col->getChars();
- const ColumnNullable * delim_col =
checkAndGetColumn<ColumnNullable>(arguments[1].column.get());
- const ColumnNullable * null_replacement_col = arguments.size() == 3 ?
checkAndGetColumn<ColumnNullable>(arguments[2].column.get()) : nullptr;
+
+ auto extractColumnString = [&](const ColumnPtr & col) -> const
ColumnString *
+ {
+ const ColumnString * res = nullptr;
+ if (col->isConst())
+ {
+ const ColumnConst * const_col =
checkAndGetColumn<ColumnConst>(col.get());
+ if (const_col)
+ res =
checkAndGetColumn<ColumnString>(const_col->getDataColumnPtr().get());
+ }
+ else
+ res = checkAndGetColumn<ColumnString>(col.get());
+ return res;
+ };
+ bool const_delim_col = arguments[1].column->isConst();
+ bool const_null_replacement_col = false;
+ const ColumnString * delim_col =
extractColumnString(arguments[1].column);
+ const ColumnString * null_replacement_col = nullptr;
+ if (arguments.size() == 3)
+ {
+ const_null_replacement_col = arguments[2].column->isConst();
+ null_replacement_col = extractColumnString(arguments[2].column);
+ }
size_t current_offset = 0, array_pos = 0;
for (size_t i = 0; i < array_col->size(); ++i)
{
String res;
- auto setResultNull = [&]() -> void
+ const StringRef delim = const_delim_col ? delim_col->getDataAt(0)
: delim_col->getDataAt(i);
+ StringRef null_replacement = StringRef(nullptr, 0);
+ if (null_replacement_col)
{
- res_col->insertDefault();
- null_result[i] = 1;
- current_offset = array_offsets[i];
- };
- auto getDelimiterOrNullReplacement = [&](const std::pair<bool,
StringRef> & s, const ColumnNullable * col) -> StringRef
- {
- if (s.first)
- return s.second;
- else
- {
- if (col->isNullAt(i))
- return StringRef(nullptr, 0);
- else
- {
- const ColumnString * col_string =
checkAndGetColumn<ColumnString>(col->getNestedColumnPtr().get());
- return col_string->getDataAt(i);
- }
- }
- };
- if (arg_null_col->isNullAt(i))
- {
- setResultNull();
- continue;
- }
- const StringRef delim = getDelimiterOrNullReplacement(delim_p,
delim_col);
- if (!delim.data)
- {
- setResultNull();
- continue;
+ null_replacement = const_null_replacement_col ?
null_replacement_col->getDataAt(0) : null_replacement_col->getDataAt(i);
}
- StringRef null_replacement;
- if (arguments.size() == 3)
- {
- null_replacement =
getDelimiterOrNullReplacement(null_replacement_p, null_replacement_col);
- if (!null_replacement.data)
- {
- setResultNull();
- continue;
- }
- }
-
size_t array_size = array_offsets[i] - current_offset;
size_t data_pos = array_pos == 0 ? 0 : string_offsets[array_pos -
1];
+ size_t last_not_null_pos = 0;
for (size_t j = 0; j < array_size; ++j)
{
if (array_nested_col && array_nested_col->isNullAt(j +
array_pos))
@@ -179,11 +123,14 @@ public:
if (j != array_size - 1)
res += delim.toString();
}
+ else if (j == array_size - 1)
+ res = res.substr(0, last_not_null_pos);
}
else
{
const StringRef s(&string_data[data_pos], string_offsets[j
+ array_pos] - data_pos - 1);
res += s.toString();
+ last_not_null_pos = res.size();
if (j != array_size - 1)
res += delim.toString();
}
@@ -194,7 +141,7 @@ public:
current_offset = array_offsets[i];
}
return ColumnNullable::create(std::move(res_col), std::move(null_col));
- }
+ }
};
REGISTER_FUNCTION(SparkArrayJoin)
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp
new file mode 100644
index 0000000000..e43b528231
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionArraysOverlap.cpp
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <Columns/ColumnString.h>
+#include <Columns/ColumnNullable.h>
+#include <Functions/IFunction.h>
+#include <Functions/FunctionFactory.h>
+#include <DataTypes/DataTypeString.h>
+#include <DataTypes/DataTypesNumber.h>
+
+using namespace DB;
+
+namespace DB
+{
+namespace ErrorCodes
+{
+ extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+ extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+}
+
+namespace local_engine
+{
+class SparkFunctionArraysOverlap : public IFunction
+{
+public:
+ static constexpr auto name = "sparkArraysOverlap";
+ static FunctionPtr create(ContextPtr) { return
std::make_shared<SparkFunctionArraysOverlap>(); }
+ SparkFunctionArraysOverlap() = default;
+ ~SparkFunctionArraysOverlap() override = default;
+ bool isSuitableForShortCircuitArgumentsExecution(const
DataTypesWithConstInfo &) const override { return true; }
+ size_t getNumberOfArguments() const override { return 2; }
+ String getName() const override { return name; }
+ bool useDefaultImplementationForConstants() const override { return true; }
+
+ DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const
override
+ {
+ auto data_type = std::make_shared<DataTypeUInt8>();
+ return makeNullable(data_type);
+ }
+
+ ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const
DataTypePtr &, size_t input_rows_count) const override
+ {
+ if (arguments.size() != 2)
+ throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} must have 2 arguments", getName());
+
+ auto res = ColumnUInt8::create(input_rows_count, 0);
+ auto null_map = ColumnUInt8::create(input_rows_count, 0);
+ PaddedPODArray<UInt8> & res_data = res->getData();
+ PaddedPODArray<UInt8> & null_map_data = null_map->getData();
+ if (input_rows_count == 0)
+ return ColumnNullable::create(std::move(res), std::move(null_map));
+
+ const ColumnArray * array_col_1 =
checkAndGetColumn<ColumnArray>(arguments[0].column.get());
+ const ColumnArray * array_col_2 =
checkAndGetColumn<ColumnArray>(arguments[1].column.get());
+ if (!array_col_1 || !array_col_2)
+ throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {}
1st/2nd argument must be array type", getName());
+
+ const ColumnArray::Offsets & array_offsets_1 =
array_col_1->getOffsets();
+ const ColumnArray::Offsets & array_offsets_2 =
array_col_2->getOffsets();
+
+ size_t current_offset_1 = 0, current_offset_2 = 0;
+ size_t array_pos_1 = 0, array_pos_2 = 0;
+ for (size_t i = 0; i < array_col_1->size(); ++i)
+ {
+ size_t array_size_1 = array_offsets_1[i] - current_offset_1;
+ size_t array_size_2 = array_offsets_2[i] - current_offset_2;
+ auto executeCompare = [&](const IColumn & col1, const IColumn &
col2, const ColumnUInt8 * null_map1, const ColumnUInt8 * null_map2) -> void
+ {
+ for (size_t j = 0; j < array_size_1 && !res_data[i]; ++j)
+ {
+ for (size_t k = 0; k < array_size_2; ++k)
+ {
+ if ((null_map1 && null_map1->getElement(j +
array_pos_1)) || (null_map2 && null_map2->getElement(k + array_pos_2)))
+ {
+ null_map_data[i] = 1;
+ }
+ else if (col1.compareAt(j + array_pos_1, k +
array_pos_2, col2, -1) == 0)
+ {
+ res_data[i] = 1;
+ null_map_data[i] = 0;
+ break;
+ }
+ }
+ }
+ };
+ if (array_col_1->getData().isNullable() ||
array_col_2->getData().isNullable())
+ {
+ if (array_col_1->getData().isNullable() &&
array_col_2->getData().isNullable())
+ {
+ const ColumnNullable * array_null_col_1 =
assert_cast<const ColumnNullable *>(&array_col_1->getData());
+ const ColumnNullable * array_null_col_2 =
assert_cast<const ColumnNullable *>(&array_col_2->getData());
+ executeCompare(array_null_col_1->getNestedColumn(),
array_null_col_2->getNestedColumn(),
+ &array_null_col_1->getNullMapColumn(),
&array_null_col_2->getNullMapColumn());
+ }
+ else if (array_col_1->getData().isNullable())
+ {
+ const ColumnNullable * array_null_col_1 =
assert_cast<const ColumnNullable *>(&array_col_1->getData());
+ executeCompare(array_null_col_1->getNestedColumn(),
array_col_2->getData(), &array_null_col_1->getNullMapColumn(), nullptr);
+ }
+ else if (array_col_2->getData().isNullable())
+ {
+ const ColumnNullable * array_null_col_2 =
assert_cast<const ColumnNullable *>(&array_col_2->getData());
+ executeCompare(array_col_1->getData(),
array_null_col_2->getNestedColumn(), nullptr,
&array_null_col_2->getNullMapColumn());
+ }
+ }
+ else if (array_col_1->getData().getDataType() ==
array_col_2->getData().getDataType())
+ {
+ executeCompare(array_col_1->getData(), array_col_2->getData(),
nullptr, nullptr);
+ }
+
+ current_offset_1 = array_offsets_1[i];
+ current_offset_2 = array_offsets_2[i];
+ array_pos_1 += array_size_1;
+ array_pos_2 += array_size_2;
+ }
+ return ColumnNullable::create(std::move(res), std::move(null_map));
+ }
+};
+
+REGISTER_FUNCTION(SparkArraysOverlap)
+{
+ factory.registerFunction<SparkFunctionArraysOverlap>();
+}
+
+}
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
index 88e5d7ea8a..695166a894 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp
@@ -166,6 +166,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Shuffle, shuffle, arrayShuffle);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Range, range, range);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Flatten, flatten, sparkArrayFlatten);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArrayJoin, array_join, sparkArrayJoin);
+REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArraysOverlap, arrays_overlap,
sparkArraysOverlap);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArraysZip, arrays_zip,
arrayZipUnaligned);
// map functions
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]