This is an automated email from the ASF dual-hosted git repository.

changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 81b023eda [GLUTEN-4994][CH]Fix function conversions (#4995)
81b023eda is described below

commit 81b023eda4a3915da5ccb694a0aaed5f65a816eb
Author: KevinyhZou <[email protected]>
AuthorDate: Mon Mar 18 19:27:18 2024 +0800

    [GLUTEN-4994][CH]Fix function conversions (#4995)
    
    * Fix function conversions
    
    * benchmark modify
---
 ...mp.cpp => SparkFunctionDateToUnixTimestamp.cpp} |  6 ++--
 ...estamp.h => SparkFunctionDateToUnixTimestamp.h} | 36 +++++++++++++---------
 .../local-engine/Functions/SparkFunctionToDate.cpp | 19 ++++++++++--
 .../Functions/SparkFunctionToDateTime.h            | 28 +++++++++++++++--
 .../scalar_function_parser/unixTimestamp.cpp       |  2 +-
 .../tests/benchmark_to_datetime_function.cpp       |  1 -
 .../tests/benchmark_unix_timestamp_function.cpp    |  6 ++--
 7 files changed, 71 insertions(+), 27 deletions(-)

diff --git a/cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.cpp 
b/cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.cpp
similarity index 81%
rename from cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.cpp
rename to cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.cpp
index c3040c60d..e7ae17feb 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.cpp
@@ -15,14 +15,14 @@
  * limitations under the License.
  */
 
-#include <Functions/SparkFunctionUnixTimestamp.h>
+#include <Functions/SparkFunctionDateToUnixTimestamp.h>
 
 namespace local_engine
 {
 
-REGISTER_FUNCTION(SparkFunctionUnixTimestamp)
+REGISTER_FUNCTION(SparkFunctionDateToUnixTimestamp)
 {
-    factory.registerFunction<local_eingine::SparkFunctionUnixTimestamp>();
+    
factory.registerFunction<local_eingine::SparkFunctionDateToUnixTimestamp>();
 }
 
 }
\ No newline at end of file
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.h 
b/cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.h
similarity index 71%
rename from cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.h
rename to cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.h
index 356519afa..cdf0460e0 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.h
@@ -15,8 +15,13 @@
  * limitations under the License.
  */
 
+#include <Common/DateLUT.h>
 #include <Common/DateLUTImpl.h>
-#include <Functions/FunctionsConversion.h>
+#include <Common/LocalDateTime.h>
+#include <Columns/ColumnVector.h>
+#include <DataTypes/IDataType.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <DataTypes/DataTypesNumber.h>
 #include <Functions/FunctionFactory.h>
 
 namespace DB
@@ -32,12 +37,12 @@ using namespace DB;
 namespace local_eingine
 {
 
-class SparkFunctionUnixTimestamp : public FunctionToUnixTimestamp
+class SparkFunctionDateToUnixTimestamp : public IFunction
 {
 public:
-    static constexpr auto name = "sparkToUnixTimestamp";
-    static FunctionPtr create(ContextPtr) { return 
std::make_shared<SparkFunctionUnixTimestamp>(); }
-    SparkFunctionUnixTimestamp()
+    static constexpr auto name = "sparkDateToUnixTimestamp";
+    static FunctionPtr create(ContextPtr) { return 
std::make_shared<SparkFunctionDateToUnixTimestamp>(); }
+    SparkFunctionDateToUnixTimestamp()
     {
         const DateLUTImpl * date_lut = &DateLUT::instance("UTC");
         UInt32 utc_timestamp = static_cast<UInt32>(0);
@@ -45,21 +50,24 @@ public:
         UInt32 unix_timestamp = date_time.to_time_t();
         delta_timestamp_from_utc = unix_timestamp - utc_timestamp;
     }
-    ~SparkFunctionUnixTimestamp() override = default;
+    ~SparkFunctionDateToUnixTimestamp() override = default;
     String getName() const override { return name; }
+    bool isSuitableForShortCircuitArgumentsExecution(const 
DB::DataTypesWithConstInfo &) const override { return true; }
+    size_t getNumberOfArguments() const override { return 0; }
+    bool isVariadic() const override { return true; }
+    bool useDefaultImplementationForConstants() const override { return true; }
+    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const 
override
+    {
+        return std::make_shared<DataTypeUInt32>();
+    }
 
     ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const 
DataTypePtr & result_type, size_t input_rows) const override
     {
-        if (arguments.size() != 1 && arguments.size() != 2)
+       if (arguments.size() != 1 && arguments.size() != 2)
             throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, 
"Function {} argument size must be 1 or 2", name);
         
-        ColumnWithTypeAndName first_arg = arguments[0];
-        
-        if (!isDateOrDate32(first_arg.type))
-        {
-            return FunctionToUnixTimestamp::executeImpl(arguments, 
result_type, input_rows);
-        }
-        else if (isDate(first_arg.type))
+       ColumnWithTypeAndName first_arg = arguments[0];
+       if (isDate(first_arg.type))
             return executeInternal<UInt16>(first_arg.column, input_rows);
         else
             return executeInternal<Int32>(first_arg.column, input_rows);
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp 
b/cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp
index 0b963e769..3a25e383d 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp
@@ -17,8 +17,13 @@
 #include <Common/LocalDate.h>
 #include <Common/DateLUT.h>
 #include <Common/DateLUTImpl.h>
+#include <Columns/ColumnString.h>
+#include <Columns/ColumnVector.h>
+#include <Columns/ColumnsNumber.h>
+#include <Columns/ColumnNullable.h>
+#include <DataTypes/IDataType.h>
 #include <DataTypes/DataTypeDate32.h>
-#include <Functions/FunctionsConversion.h>
+#include <DataTypes/DataTypeNullable.h>
 #include <Functions/FunctionFactory.h>
 #include <IO/ReadBufferFromMemory.h>
 #include <IO/ReadHelpers.h>
@@ -35,14 +40,18 @@ namespace ErrorCodes
 
 namespace local_engine
 {
-class SparkFunctionConvertToDate : public DB::FunctionToDate32OrNull
+class SparkFunctionConvertToDate : public DB::IFunction
 {
 public:
     static constexpr auto name = "sparkToDate";
     static DB::FunctionPtr create(DB::ContextPtr) { return 
std::make_shared<SparkFunctionConvertToDate>(); }
     SparkFunctionConvertToDate() = default;
     ~SparkFunctionConvertToDate() override = default;
+    bool isSuitableForShortCircuitArgumentsExecution(const 
DB::DataTypesWithConstInfo &) const override { return true; }
+    size_t getNumberOfArguments() const override { return 0; }
     String getName() const override { return name; }
+    bool isVariadic() const override { return true; }
+    bool useDefaultImplementationForConstants() const override { return true; }
 
     bool checkAndGetDate32(DB::ReadBuffer & buf, DB::DataTypeDate32::FieldType 
&x, const DateLUTImpl & date_lut) const
     {
@@ -99,6 +108,12 @@ public:
         }
     }
 
+    DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName &) 
const override
+    {
+        DB::DataTypePtr date32_type = std::make_shared<DB::DataTypeDate32>();
+        return makeNullable(date32_type);
+    }
+
     DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, 
const DB::DataTypePtr & result_type, size_t) const override
     {
         if (arguments.size() != 1)
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h 
b/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h
index d185b850f..ae9ebc6e7 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h
@@ -18,9 +18,12 @@
 #include <Common/DateLUT.h>
 #include <Common/DateLUTImpl.h>
 #include <Columns/ColumnsDateTime.h>
+#include <Columns/ColumnNullable.h>
+#include <Columns/ColumnsNumber.h>
 #include <DataTypes/DataTypeDateTime64.h>
-#include <Functions/FunctionsConversion.h>
 #include <Functions/FunctionFactory.h>
+#include <Functions/FunctionHelpers.h>
+#include <Functions/extractTimeZoneFromFunctionArguments.h>
 #include <IO/ReadBufferFromMemory.h>
 #include <IO/parseDateTimeBestEffort.h>
 #include <IO/ReadHelpers.h>
@@ -40,13 +43,17 @@ namespace ErrorCodes
 
 namespace local_engine
 {
-class SparkFunctionConvertToDateTime : public DB::FunctionToDateTime64OrNull
+class SparkFunctionConvertToDateTime : public IFunction
 {
 public:
     static constexpr auto name = "sparkToDateTime";
     static DB::FunctionPtr create(DB::ContextPtr) { return 
std::make_shared<SparkFunctionConvertToDateTime>(); }
     SparkFunctionConvertToDateTime() = default;
     ~SparkFunctionConvertToDateTime() override = default;
+    bool isSuitableForShortCircuitArgumentsExecution(const 
DataTypesWithConstInfo &) const override { return true; }
+    size_t getNumberOfArguments() const override { return 0; }
+    bool isVariadic() const override { return true; }
+    bool useDefaultImplementationForConstants() const override { return true; }
     String getName() const override { return name; }
 
     bool checkDateTimeFormat(DB::ReadBuffer & buf, size_t buf_size, UInt8 & 
can_be_parsed) const
@@ -109,11 +116,26 @@ public:
         return true;
     }
 
+    inline UInt32 extractDecimalScale(const ColumnWithTypeAndName & 
named_column) const
+    {
+        const auto * arg_type = named_column.type.get();
+        bool ok = checkAndGetDataType<DataTypeUInt64>(arg_type)
+            || checkAndGetDataType<DataTypeUInt32>(arg_type)
+            || checkAndGetDataType<DataTypeUInt16>(arg_type)
+            || checkAndGetDataType<DataTypeUInt8>(arg_type);
+        if (!ok)
+            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal 
type of toDecimal() scale {}", named_column.type->getName());
+
+        Field field;
+        named_column.column->get(0, field);
+        return static_cast<UInt32>(field.get<UInt32>());
+    }
+
     DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & 
arguments) const override
     {
         UInt32 scale = 6;
         if (arguments.size() > 1)
-            scale = extractToDecimalScale(arguments[1]);
+            scale = extractDecimalScale(arguments[1]);
         const auto timezone = 
extractTimeZoneNameFromFunctionArguments(arguments, 2, 0, false);
         return makeNullable(std::make_shared<DataTypeDateTime64>(scale, 
timezone));
     }
diff --git 
a/cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp 
b/cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp
index 41b0864d5..84a7d394e 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp
@@ -67,7 +67,7 @@ public:
         if (isString(expr_type))
             result_node = toFunctionNode(actions_dag, 
"parseDateTimeInJodaSyntaxOrNull", {expr_arg, fmt_arg, time_zone_node});
         else if (isDateOrDate32(expr_type))
-            result_node = toFunctionNode(actions_dag, "sparkToUnixTimestamp", 
{expr_arg, time_zone_node});
+            result_node = toFunctionNode(actions_dag, 
"sparkDateToUnixTimestamp", {expr_arg, time_zone_node});
         else if (isDateTime(expr_type) || isDateTime64(expr_type))
             result_node = toFunctionNode(actions_dag, "toUnixTimestamp", 
{expr_arg, time_zone_node});
         else
diff --git a/cpp-ch/local-engine/tests/benchmark_to_datetime_function.cpp 
b/cpp-ch/local-engine/tests/benchmark_to_datetime_function.cpp
index 0d0cad804..01153c91e 100644
--- a/cpp-ch/local-engine/tests/benchmark_to_datetime_function.cpp
+++ b/cpp-ch/local-engine/tests/benchmark_to_datetime_function.cpp
@@ -20,7 +20,6 @@
 #include <DataTypes/IDataType.h>
 #include <DataTypes/DataTypeFactory.h>
 #include <Functions/FunctionFactory.h>
-#include <Functions/FunctionsConversion.h>
 #include <Functions/SparkFunctionToDateTime.h>
 #include <Parser/SerializedPlanParser.h>
 #include <Parser/FunctionParser.h>
diff --git a/cpp-ch/local-engine/tests/benchmark_unix_timestamp_function.cpp 
b/cpp-ch/local-engine/tests/benchmark_unix_timestamp_function.cpp
index bad233d27..13333f0c4 100644
--- a/cpp-ch/local-engine/tests/benchmark_unix_timestamp_function.cpp
+++ b/cpp-ch/local-engine/tests/benchmark_unix_timestamp_function.cpp
@@ -21,7 +21,7 @@
 #include <DataTypes/DataTypeFactory.h>
 #include <Functions/FunctionFactory.h>
 #include <Functions/FunctionsRound.h>
-#include <Functions/SparkFunctionUnixTimestamp.h>
+#include <Functions/SparkFunctionDateToUnixTimestamp.h>
 #include <Parser/SerializedPlanParser.h>
 #include <Parser/FunctionParser.h>
 #include <benchmark/benchmark.h>
@@ -74,7 +74,7 @@ static void BM_SparkUnixTimestamp_For_Date32(benchmark::State 
& state)
 {
     using namespace DB;
     auto & factory = FunctionFactory::instance();
-    auto function = factory.get("sparkToUnixTimestamp", 
local_engine::SerializedPlanParser::global_context);
+    auto function = factory.get("sparkDateToUnixTimestamp", 
local_engine::SerializedPlanParser::global_context);
     Block block = createDataBlock("Date32", 30000000);
     auto executable = function->build(block.getColumnsWithTypeAndName());
     for (auto _ : state)[[maybe_unused]]
@@ -85,7 +85,7 @@ static void BM_SparkUnixTimestamp_For_Date(benchmark::State & 
state)
 {
     using namespace DB;
     auto & factory = FunctionFactory::instance();
-    auto function = factory.get("sparkToUnixTimestamp", 
local_engine::SerializedPlanParser::global_context);
+    auto function = factory.get("sparkDateToUnixTimestamp", 
local_engine::SerializedPlanParser::global_context);
     Block block = createDataBlock("Date", 30000000);
     auto executable = function->build(block.getColumnsWithTypeAndName());
     for (auto _ : state)[[maybe_unused]]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to