This is an automated email from the ASF dual-hosted git repository.

taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 05eac08b57 [GLUTEN-3289][CH]Fix cast float to string (#8092)
05eac08b57 is described below

commit 05eac08b57f0ef9286d8a9171be7b6cd2d1a0e5d
Author: kevinyhzou <[email protected]>
AuthorDate: Fri Mar 14 17:20:43 2025 +0800

    [GLUTEN-3289][CH]Fix cast float to string (#8092)
    
    * Fix cast float to string
    
    * rebase
---
 .../GlutenClickHouseTPCHSaltNullParquetSuite.scala |  13 +++
 .../Functions/SparkFunctionCastFloatToString.cpp   | 120 +++++++++++++++++++++
 cpp-ch/local-engine/Parser/ExpressionParser.cpp    |   2 +
 .../Parser/scalar_function_parser/getTimestamp.h   |   2 +-
 4 files changed, 136 insertions(+), 1 deletion(-)

diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index 02b1609032..c30577c339 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/tpch/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -2168,6 +2168,19 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends 
GlutenClickHouseTPCHAbstr
     spark.sql("drop table test_tbl_3149")
   }
 
+  // GLUTEN-3289 regression test: casting a float column (and a double derived
+  // from it) to string must match vanilla Spark's formatting — integral values
+  // render as "2.0", not "2". Exercises both float->string and double->string.
+  test("GLUTEN-3289: Fix convert float to string") {
+    val tbl_create_sql = "create table test_tbl_3289(id bigint, a float) using 
parquet"
+    val tbl_insert_sql = "insert into test_tbl_3289 values(1, 2.0), (2, 2.1), 
(3, 2.2)"
+    val select_sql_1 = "select cast(a as string), cast(a * 1.0f as string) 
from test_tbl_3289"
+    val select_sql_2 =
+      "select cast(cast(a as double) as string), cast(cast(a * 1.0f as double) 
as string) from test_tbl_3289"
+    spark.sql(tbl_create_sql)
+    spark.sql(tbl_insert_sql)
+    compareResultsAgainstVanillaSpark(select_sql_1, true, { _ => })
+    compareResultsAgainstVanillaSpark(select_sql_2, true, { _ => })
+    spark.sql("drop table test_tbl_3289")
+  }
+
   test("test in-filter contains null value (bigint)") {
     val sql = "select s_nationkey from supplier where s_nationkey in (null, 1, 
2)"
     compareResultsAgainstVanillaSpark(sql, true, { _ => })
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToString.cpp 
b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToString.cpp
new file mode 100644
index 0000000000..2c564997a2
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToString.cpp
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <base/TypeList.h>
+#include <Columns/ColumnVector.h>
+#include <DataTypes/DataTypesNumber.h>
+#include <Functions/IFunction.h>
+#include <Functions/FunctionFactory.h>
+#include <Functions/castTypeToEither.h>
+#include <Functions/FunctionsConversion.h>
+
+using namespace DB;
+
+/// Error codes referenced by the function below; defined in ClickHouse core.
+namespace DB
+{
+namespace ErrorCodes
+{
+    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
+}
+}
+
+namespace local_engine
+{
+
+/// Spark-compatible cast of Float32/Float64 to String ("sparkCastFloatToString").
+/// Values that are integers within the type's representable range get a
+/// trailing ".0" appended (e.g. 2 -> "2.0"), matching Spark's output where
+/// ClickHouse's default float-to-string conversion would print just "2".
+class SparkFunctionCastFloatToString : public IFunction
+{
+public:
+    static constexpr auto name = "sparkCastFloatToString";
+    static DB::FunctionPtr create(DB::ContextPtr) { return 
std::make_shared<SparkFunctionCastFloatToString>(); }
+
+    SparkFunctionCastFloatToString() = default;
+    ~SparkFunctionCastFloatToString() override = default;
+
+    size_t getNumberOfArguments() const override { return 1; }
+    String getName() const override { return name; }
+    bool useDefaultImplementationForConstants() const override { return true; }
+    bool isSuitableForShortCircuitArgumentsExecution(const 
DataTypesWithConstInfo & /*arguments*/) const override { return true; }
+
+    /// Always returns String regardless of the (float) input type.
+    DataTypePtr getReturnTypeImpl(const DataTypes &) const override
+    {
+        return std::make_shared<const DataTypeString>();
+    }
+
+    /// Appends ".0" to the buffer when x is an integer in the representable
+    /// range of F, so integral floats render as "2.0" rather than "2".
+    /// Called after writeFloatText has already emitted the digits of x.
+    template <typename F>
+    requires is_floating_point<F>
+    inline void writeFloatEnd(F x, WriteBuffer & buf) const
+    {
+        if constexpr (std::is_same_v<F, Float64>)
+        {
+            if (DecomposedFloat64(x).isIntegerInRepresentableRange())
+            {
+                writeChar('.', buf);
+                writeChar('0', buf);
+            }
+        }
+        else if constexpr (std::is_same_v<F, Float32>)
+        {
+            if (DecomposedFloat32(x).isIntegerInRepresentableRange())
+            {
+                writeChar('.', buf);
+                writeChar('0', buf);
+            } 
+        }
+    }
+
+    /// Converts the single Float32/Float64 argument column into a ColumnString,
+    /// one zero-terminated entry per row.
+    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const 
DataTypePtr & result_type, size_t) const override
+    {
+        if (arguments.size() != 1)
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, 
"Function {}'s arguments number must be 1", name);
+
+        if (!isFloat(arguments[0].type))
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function 
{}'s 1st argument must be float type", name);
+
+        auto res_col = ColumnString::create();
+        ColumnString::Chars & res_data = res_col->getChars();
+        ColumnString::Offsets & res_offsets = res_col->getOffsets();
+        size_t size = arguments[0].column->size();
+        // NOTE(review): 3 bytes/value is only an initial guess; this relies on
+        // WriteBufferFromVector growing res_data as needed — confirm.
+        res_data.resize_exact(size * 3);
+        res_offsets.resize_exact(size);
+        using Types = TypeList<DataTypeFloat32, DataTypeFloat64>;
+        // Dispatch on the concrete float type so the inner loop works on the
+        // typed ColumnVector directly.
+        castTypeToEither(Types{}, arguments[0].type.get(), [&](const auto & 
arg_type)
+        {
+            using F = typename std::decay_t<decltype(arg_type)>::FieldType;
+            const ColumnVector<F> * src_col = 
checkAndGetColumn<ColumnVector<F>>(arguments[0].column.get());
+            WriteBufferFromVector<ColumnString::Chars> write_buffer(res_data);
+
+            for (size_t i = 0 ; i < src_col->size(); ++i)
+            {
+                writeFloatText(src_col->getElement(i), write_buffer);
+                writeFloatEnd<F>(src_col->getElement(i), write_buffer);
+                // ColumnString entries are terminated by a zero byte; the
+                // offset recorded for each row includes that terminator.
+                writeChar(0, write_buffer);
+                res_offsets[i] = write_buffer.count();
+            }
+            // NOTE(review): write_buffer is never finalize()d before res_data
+            // is returned inside res_col — presumably destructor finalization
+            // shrinks res_data to the bytes actually written; confirm against
+            // current WriteBufferFromVector semantics.
+            return true;
+        });
+        return std::move(res_col);
+    }
+};
+
+/// Registers "sparkCastFloatToString" with the ClickHouse function factory so
+/// the ExpressionParser can resolve it for float->string cast expressions.
+REGISTER_FUNCTION(SparkFunctionCastFloatToString)
+{
+    factory.registerFunction<SparkFunctionCastFloatToString>();
+}
+
+}
diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.cpp 
b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
index fb78066e8d..4261241f25 100644
--- a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
+++ b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
@@ -337,6 +337,8 @@ ExpressionParser::NodeRawConstPtr 
ExpressionParser::parseExpression(ActionsDAG &
                 String function_name = "sparkCastFloatTo" + 
denull_output_type->getName();
                 result_node = toFunctionNode(actions_dag, function_name, args);
             }
+            else if (isFloat(denull_input_type) && 
isString(denull_output_type))
+                result_node = toFunctionNode(actions_dag, 
"sparkCastFloatToString", args);
             else if ((isDecimal(denull_input_type) || 
isNativeNumber(denull_input_type)) && substrait_type.has_decimal())
             {
                 int precision = substrait_type.decimal().precision();
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/getTimestamp.h 
b/cpp-ch/local-engine/Parser/scalar_function_parser/getTimestamp.h
index 3596f65d94..5a90ba0cd7 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/getTimestamp.h
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/getTimestamp.h
@@ -73,7 +73,7 @@ public:
             fmt = fmt_string_literal ? literal_fmt_expr.string() : "";
         }
         if (!fmt_string_literal)
-            throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The 
second of function {} must be const String.", name);
+            throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The 
second argument of function {} must be const String.", name);
 
         UInt32 s_count = std::count(fmt.begin(), fmt.end(), 'S');
         String time_parser_policy = 
getContext()->getSettingsRef().has(TIMER_PARSER_POLICY) ? 
toString(getContext()->getSettingsRef().get(TIMER_PARSER_POLICY)) : "";


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to