This is an automated email from the ASF dual-hosted git repository.
taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 05eac08b57 [GLUTEN-3289][CH]Fix cast float to string (#8092)
05eac08b57 is described below
commit 05eac08b57f0ef9286d8a9171be7b6cd2d1a0e5d
Author: kevinyhzou <[email protected]>
AuthorDate: Fri Mar 14 17:20:43 2025 +0800
[GLUTEN-3289][CH]Fix cast float to string (#8092)
* Fix cast float to string
* rebase
---
.../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 13 +++
.../Functions/SparkFunctionCastFloatToString.cpp | 120 +++++++++++++++++++++
cpp-ch/local-engine/Parser/ExpressionParser.cpp | 2 +
.../Parser/scalar_function_parser/getTimestamp.h | 2 +-
4 files changed, 136 insertions(+), 1 deletion(-)
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 02b1609032..c30577c339 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -2168,6 +2168,19 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends
GlutenClickHouseTPCHAbstr
spark.sql("drop table test_tbl_3149")
}
+ // Regression test for GLUTEN-3289: creates a parquet table with a float column,
+ // inserts the values 2.0, 2.1 and 2.2, and checks that casting them to string
+ // yields the same results as vanilla Spark.
+ test("GLUTEN-3289: Fix convert float to string") {
+ val tbl_create_sql = "create table test_tbl_3289(id bigint, a float) using
parquet"
+ val tbl_insert_sql = "insert into test_tbl_3289 values(1, 2.0), (2, 2.1),
(3, 2.2)"
+ // Query 1 casts the float column (and a float expression) directly to string.
+ val select_sql_1 = "select cast(a as string), cast(a * 1.0f as string)
from test_tbl_3289"
+ // Query 2 widens to double first, so the double -> string path is covered too.
+ val select_sql_2 =
+ "select cast(cast(a as double) as string), cast(cast(a * 1.0f as double)
as string) from test_tbl_3289"
+ spark.sql(tbl_create_sql)
+ spark.sql(tbl_insert_sql)
+ // The trailing `{ _ => }` is a no-op callback; only result equality is checked.
+ // NOTE(review): the meaning of the `true` flag is defined by
+ // compareResultsAgainstVanillaSpark, which is not visible here - confirm.
+ compareResultsAgainstVanillaSpark(select_sql_1, true, { _ => })
+ compareResultsAgainstVanillaSpark(select_sql_2, true, { _ => })
+ // NOTE(review): the table is only dropped when both comparisons succeed.
+ spark.sql("drop table test_tbl_3289")
+ }
+
test("test in-filter contains null value (bigint)") {
val sql = "select s_nationkey from supplier where s_nationkey in (null, 1,
2)"
compareResultsAgainstVanillaSpark(sql, true, { _ => })
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToString.cpp
b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToString.cpp
new file mode 100644
index 0000000000..2c564997a2
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToString.cpp
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <base/TypeList.h>
+#include <Columns/ColumnVector.h>
+#include <DataTypes/DataTypesNumber.h>
+#include <Functions/IFunction.h>
+#include <Functions/FunctionFactory.h>
+#include <Functions/castTypeToEither.h>
+#include <Functions/FunctionsConversion.h>
+
+using namespace DB;
+
+namespace DB
+{
+namespace ErrorCodes
+{
+ extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+ extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+}
+}
+
+namespace local_engine
+{
+
+/// Function `sparkCastFloatToString`: converts a Float32/Float64 column to String.
+///
+/// Each value is formatted with writeFloatText(); when the value is an integer
+/// (see writeFloatEnd()) the suffix ".0" is appended, so that e.g. 2 is rendered
+/// as "2.0". This is intended to match the output of Spark's cast(float as string).
+class SparkFunctionCastFloatToString : public IFunction
+{
+public:
+    static constexpr auto name = "sparkCastFloatToString";
+    static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared<SparkFunctionCastFloatToString>(); }
+
+    SparkFunctionCastFloatToString() = default;
+    ~SparkFunctionCastFloatToString() override = default;
+
+    size_t getNumberOfArguments() const override { return 1; }
+    String getName() const override { return name; }
+    bool useDefaultImplementationForConstants() const override { return true; }
+    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
+
+    /// The result is always String, regardless of the argument type.
+    DataTypePtr getReturnTypeImpl(const DataTypes &) const override
+    {
+        return std::make_shared<const DataTypeString>();
+    }
+
+    /// Appends ".0" to `buf` when `x` is reported as an integer by
+    /// DecomposedFloat{32,64}::isIntegerInRepresentableRange(); writes nothing otherwise.
+    /// NOTE(review): presumably compensates for writeFloatText() printing integral
+    /// values without a fractional part - confirm against writeFloatText().
+    template <typename F>
+    requires is_floating_point<F>
+    inline void writeFloatEnd(F x, WriteBuffer & buf) const
+    {
+        if constexpr (std::is_same_v<F, Float64>)
+        {
+            if (DecomposedFloat64(x).isIntegerInRepresentableRange())
+            {
+                writeChar('.', buf);
+                writeChar('0', buf);
+            }
+        }
+        else if constexpr (std::is_same_v<F, Float32>)
+        {
+            if (DecomposedFloat32(x).isIntegerInRepresentableRange())
+            {
+                writeChar('.', buf);
+                writeChar('0', buf);
+            }
+        }
+    }
+
+    /// Formats every row of the single float argument into a new ColumnString.
+    ///
+    /// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH when the argument count is not 1 and
+    /// ILLEGAL_TYPE_OF_ARGUMENT when the argument is not a Float32/Float64 column.
+    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t) const override
+    {
+        if (arguments.size() != 1)
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {}'s arguments number must be 1", name);
+
+        if (!isFloat(arguments[0].type))
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {}'s 1st argument must be float type", name);
+
+        auto res_col = ColumnString::create();
+        ColumnString::Chars & res_data = res_col->getChars();
+        ColumnString::Offsets & res_offsets = res_col->getOffsets();
+        const size_t size = arguments[0].column->size();
+        /// Initial guess only; the write buffer grows the vector on demand and
+        /// finalize() below trims it to the bytes actually written.
+        res_data.resize_exact(size * 3);
+        res_offsets.resize_exact(size);
+        using Types = TypeList<DataTypeFloat32, DataTypeFloat64>;
+        const bool converted = castTypeToEither(Types{}, arguments[0].type.get(), [&](const auto & arg_type)
+        {
+            using F = typename std::decay_t<decltype(arg_type)>::FieldType;
+            const ColumnVector<F> * src_col = checkAndGetColumn<ColumnVector<F>>(arguments[0].column.get());
+            if (!src_col)
+                throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {}'s 1st argument must be a float column", name);
+
+            WriteBufferFromVector<ColumnString::Chars> write_buffer(res_data);
+            for (size_t i = 0; i < size; ++i)
+            {
+                const F value = src_col->getElement(i);
+                writeFloatText(value, write_buffer);
+                writeFloatEnd<F>(value, write_buffer);
+                /// ColumnString stores a terminating zero byte after every value.
+                writeChar(0, write_buffer);
+                res_offsets[i] = write_buffer.count();
+            }
+            /// Shrinks res_data to the written size so that it agrees with the last
+            /// offset; without this the column keeps stale trailing bytes.
+            write_buffer.finalize();
+            return true;
+        });
+
+        /// isFloat() may accept a type that is neither Float32 nor Float64; in that
+        /// case nothing was written and the result must not be returned.
+        if (!converted)
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function {}'s 1st argument must be float type", name);
+
+        return std::move(res_col);
+    }
+};
+
+/// Registers SparkFunctionCastFloatToString with the function factory, so that it
+/// can be looked up by its name ("sparkCastFloatToString").
+REGISTER_FUNCTION(SparkFunctionCastFloatToString)
+{
+ factory.registerFunction<SparkFunctionCastFloatToString>();
+}
+
+}
diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
index fb78066e8d..4261241f25 100644
--- a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
+++ b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
@@ -337,6 +337,8 @@ ExpressionParser::NodeRawConstPtr
ExpressionParser::parseExpression(ActionsDAG &
String function_name = "sparkCastFloatTo" +
denull_output_type->getName();
result_node = toFunctionNode(actions_dag, function_name, args);
}
+ else if (isFloat(denull_input_type) &&
isString(denull_output_type))
+ result_node = toFunctionNode(actions_dag,
"sparkCastFloatToString", args);
else if ((isDecimal(denull_input_type) ||
isNativeNumber(denull_input_type)) && substrait_type.has_decimal())
{
int precision = substrait_type.decimal().precision();
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/getTimestamp.h
b/cpp-ch/local-engine/Parser/scalar_function_parser/getTimestamp.h
index 3596f65d94..5a90ba0cd7 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/getTimestamp.h
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/getTimestamp.h
@@ -73,7 +73,7 @@ public:
fmt = fmt_string_literal ? literal_fmt_expr.string() : "";
}
if (!fmt_string_literal)
- throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The
second of function {} must be const String.", name);
+ throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The
second argument of function {} must be const String.", name);
UInt32 s_count = std::count(fmt.begin(), fmt.end(), 'S');
String time_parser_policy =
getContext()->getSettingsRef().has(TIMER_PARSER_POLICY) ?
toString(getContext()->getSettingsRef().get(TIMER_PARSER_POLICY)) : "";
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]