This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 4867d60bcd [GLUTEN-7602][CH] Add spark cast array to string (#8392)
4867d60bcd is described below
commit 4867d60bcde6627abb6f6a0988772b069f1229be
Author: zhanglistar <[email protected]>
AuthorDate: Sun Jan 5 17:46:46 2025 +0800
[GLUTEN-7602][CH] Add spark cast array to string (#8392)
* add spark cast array to string
* fix style
* add spark cast array to string
* fix ut
---
.../GlutenClickhouseFunctionSuite.scala | 16 ++
.../Functions/SparkFunctionArrayToString.cpp | 28 ++++
.../Functions/SparkFunctionArrayToString.h | 165 +++++++++++++++++++++
cpp-ch/local-engine/Parser/ExpressionParser.cpp | 5 +
4 files changed, 214 insertions(+)
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
index 8f2658ef8f..974380f624 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
@@ -429,4 +429,20 @@ class GlutenClickhouseFunctionSuite extends
GlutenClickHouseTPCHAbstractSuite {
}
}
+ test("GLUTEN-7602: cast array to string") { // Checks cast(array as string) results against vanilla Spark.
+ withTable("test_7602") {
+ sql("create table if not exists test_7602 (v ARRAY<STRING>) using
parquet")
+ sql("insert into test_7602 values(array('1', '2a', 'foo'));")
+ compareResultsAgainstVanillaSpark( // Case 1: array column read from the table created above.
+ """
+ |select cast(v as string) from test_7602
+ """.stripMargin,
+ true,
+ { _ => }
+ )
+ val q = "select cast(a as string) from (select array('123',NULL) as a)" // Case 2: array literal containing a NULL element.
+ compareResultsAgainstVanillaSpark(q, true, { _ => }, false)
+ }
+ }
+
}
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArrayToString.cpp
b/cpp-ch/local-engine/Functions/SparkFunctionArrayToString.cpp
new file mode 100644
index 0000000000..1914c47c04
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionArrayToString.cpp
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <Functions/SparkFunctionArrayToString.h>
+
+namespace local_engine
+{
+
+REGISTER_FUNCTION(SparkFunctionArrayToString) /// Adds SparkFunctionArrayToString to the function factory.
+{ /// NOTE(review): the class is declared in namespace `local_eingine` (sic) while this file uses `local_engine`; looks like a typo - a rename must touch the header and this line together.
+ factory.registerFunction<local_eingine::SparkFunctionArrayToString>();
+}
+
+}
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArrayToString.h
b/cpp-ch/local-engine/Functions/SparkFunctionArrayToString.h
new file mode 100644
index 0000000000..fce339d294
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionArrayToString.h
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <memory>
+#include <Columns/ColumnArray.h>
+#include <Columns/ColumnNullable.h>
+#include <Columns/ColumnStringHelpers.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <DataTypes/DataTypeString.h>
+#include <DataTypes/DataTypeArray.h>
+#include <Formats/FormatFactory.h>
+#include <Functions/FunctionFactory.h>
+#include <Functions/FunctionHelpers.h>
+#include <IO/WriteHelpers.h>
+
+namespace DB
+{
+namespace ErrorCodes
+{
+ extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+ extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+}
+}
+
+namespace local_eingine
+{
+class SparkFunctionArrayToString : public DB::IFunction /// Formats each row of an array column as text of the form "[e1, e2, ...]".
+{
+public:
+ static constexpr auto name = "sparkCastArrayToString"; /// Returned by getName().
+ /// Factory: builds an instance that keeps the given context.
+ static DB::FunctionPtr create(DB::ContextPtr context) { return
std::make_shared<SparkFunctionArrayToString>(context); }
+ /// The stored context is read in executeImpl to obtain format settings; it may be null.
+ explicit SparkFunctionArrayToString(DB::ContextPtr context_) :
context(context_) {}
+
+ ~SparkFunctionArrayToString() override = default;
+
+ String getName() const override { return name; }
+ /// The function takes exactly one argument.
+ size_t getNumberOfArguments() const override { return 1; }
+
+ bool isSuitableForShortCircuitArgumentsExecution(const
DB::DataTypesWithConstInfo & /*arguments*/) const override { return true; }
+ /// Returns false: executeImpl deals with a Nullable argument itself by copying its null map.
+ bool useDefaultImplementationForNulls() const override { return false; }
+
+ bool useDefaultImplementationForLowCardinalityColumns() const override {
return false; }
+ /// Result type is String, made Nullable when the argument type is Nullable. Throws if there is not exactly one argument or if the argument (Nullable removed) is not an array.
+ DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName &
arguments) const override
+ {
+ if (arguments.size() != 1)
+ throw
DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}
argument size must be 1", name);
+
+ auto arg_type = DB::removeNullable(arguments[0].type);
+ if (!DB::WhichDataType(arg_type).isArray())
+ throw DB::Exception(
+ DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of
argument[0] of function {}",
+ arguments[0].type->getName(), getName());
+
+ if (arguments[0].type->isNullable())
+ return makeNullable(std::make_shared<DB::DataTypeString>());
+ else
+ return std::make_shared<DB::DataTypeString>();
+ }
+ /// Serializes every row of the array argument to a string column; for a Nullable argument the null map is copied onto the result.
+ DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments,
const DB::DataTypePtr & result_type, size_t input_rows_count) const override
+ {
+ DB::ColumnUInt8::MutablePtr null_map = nullptr; /// Stays null unless the argument column is a ColumnNullable.
+ if (const auto * col_nullable =
checkAndGetColumn<DB::ColumnNullable>(arguments[0].column.get()))
+ {
+ null_map = DB::ColumnUInt8::create();
+ null_map->insertRangeFrom(col_nullable->getNullMapColumn(), 0,
col_nullable->size());
+ }
+ /// NOTE(review): columnGetNested presumably strips the Nullable wrapper from the column and its type - confirm.
+ const auto & nested_col_with_type_and_name =
columnGetNested(arguments[0]);
+ /// Constant input: run once on the underlying data column, then wrap the result as a constant of input_rows_count rows.
+ if (const auto * col_const = typeid_cast<const DB::ColumnConst
*>(nested_col_with_type_and_name.column.get()))
+ {
+ DB::ColumnsWithTypeAndName new_arguments {1};
+ new_arguments[0] = {col_const->getDataColumnPtr(),
nested_col_with_type_and_name.type, nested_col_with_type_and_name.name};
+ auto col = executeImpl(new_arguments, result_type, 1);
+ return DB::ColumnConst::create(std::move(col), input_rows_count);
+ }
+
+ const DB::IColumn & col_from = *nested_col_with_type_and_name.column;
+ size_t size = col_from.size();
+ auto col_to = removeNullable(result_type)->createColumn();
+ /// Use the context's format settings when a context is present, default-constructed settings otherwise.
+ DB::FormatSettings format_settings = context ?
DB::getFormatSettings(context) : DB::FormatSettings{};
+ format_settings.pretty.charset =
DB::FormatSettings::Pretty::Charset::ASCII; /// Use ASCII for pretty output.
+
+ DB::ColumnStringHelpers::WriteHelper write_helper(
+ assert_cast<DB::ColumnString &>(*col_to),
+ size);
+
+ auto & write_buffer = write_helper.getWriteBuffer();
+ /// The element serializer is the default serialization of the array's nested type; a non-array type throws here.
+ const auto * array_type =
checkAndGetDataType<DB::DataTypeArray>(nested_col_with_type_and_name.type.get());
+ if (!array_type)
+ throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Argument #1 for function {} must be an array, not {}",
+ name, arguments[0].type->getName());
+
+ DB::DataTypePtr value_type = array_type->getNestedType();
+ auto value_serializer = value_type->getDefaultSerialization();
+ /// Write one output string per row of the array column.
+ for (size_t row = 0; row < size; ++row)
+ {
+ serializeInSparkStyle(col_from,row,write_buffer,format_settings,
value_serializer);
+ write_helper.rowWritten();
+ }
+
+ write_helper.finalize();
+ /// The null map is attached only when the result type is Nullable and a null map was copied above.
+ if (result_type->isNullable() && null_map)
+ return DB::ColumnNullable::create(std::move(col_to),
std::move(null_map));
+ return col_to;
+ }
+
+private:
+ DB::ContextPtr context;
+ /// Writes row row_num of an array column to ostr as '[', the elements joined by ", ", then ']' (so "[]" for an empty row).
+ void serializeInSparkStyle(
+ const DB::IColumn & column,
+ size_t row_num,
+ DB::WriteBuffer & ostr,
+ const DB::FormatSettings & settings,
+ const DB::SerializationPtr & value_serializer) const
+ {
+ const auto & column_array = assert_cast<const DB::ColumnArray
&>(column);
+
+ const auto & nested_column= column_array.getData();
+ const DB::ColumnArray::Offsets & offsets = column_array.getOffsets();
+ /// Elements of the row are nested_column[offset .. next_offset). NOTE(review): for row_num == 0 this reads offsets[-1]; presumably the offsets container is left-padded so that yields 0 - confirm.
+ size_t offset = offsets[row_num - 1];
+ size_t next_offset = offsets[row_num];
+
+ writeChar('[', ostr);
+ if (offset != next_offset)
+ {
+ value_serializer->serializeText(nested_column, offset, ostr,
settings);
+ for (size_t i = offset + 1; i < next_offset; ++i)
+ {
+ writeString(std::string_view(", "), ostr);
+ value_serializer->serializeText(nested_column, i, ostr,
settings);
+ }
+ }
+ writeChar(']', ostr);
+ }
+};
+
+}
diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
index 7d41c78a3d..400d8c28df 100644
--- a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
+++ b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
@@ -351,6 +351,11 @@ ExpressionParser::NodeRawConstPtr
ExpressionParser::parseExpression(ActionsDAG &
addConstColumn(actions_dag,
map_input_type->getValueType(), map_input_type->getValueType()->getDefault()));
result_node = toFunctionNode(actions_dag,
"sparkCastMapToString", args);
}
+ else if (isArray(denull_input_type) &&
isString(denull_output_type))
+ {
+ // ISSUE-7602: Spark cast(array to string) gives a different
result from CH cast(array to string)
+ result_node = toFunctionNode(actions_dag,
"sparkCastArrayToString", args);
+ }
else if (isString(denull_input_type) && substrait_type.has_bool_())
{
/// cast(string to boolean)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]