This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 81b023eda [GLUTEN-4994][CH]Fix function conversions (#4995)
81b023eda is described below
commit 81b023eda4a3915da5ccb694a0aaed5f65a816eb
Author: KevinyhZou <[email protected]>
AuthorDate: Mon Mar 18 19:27:18 2024 +0800
[GLUTEN-4994][CH]Fix function conversions (#4995)
* Fix function conversions
* Modify benchmarks
---
...mp.cpp => SparkFunctionDateToUnixTimestamp.cpp} | 6 ++--
...estamp.h => SparkFunctionDateToUnixTimestamp.h} | 36 +++++++++++++---------
.../local-engine/Functions/SparkFunctionToDate.cpp | 19 ++++++++++--
.../Functions/SparkFunctionToDateTime.h | 28 +++++++++++++++--
.../scalar_function_parser/unixTimestamp.cpp | 2 +-
.../tests/benchmark_to_datetime_function.cpp | 1 -
.../tests/benchmark_unix_timestamp_function.cpp | 6 ++--
7 files changed, 71 insertions(+), 27 deletions(-)
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.cpp
b/cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.cpp
similarity index 81%
rename from cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.cpp
rename to cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.cpp
index c3040c60d..e7ae17feb 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.cpp
@@ -15,14 +15,14 @@
* limitations under the License.
*/
-#include <Functions/SparkFunctionUnixTimestamp.h>
+#include <Functions/SparkFunctionDateToUnixTimestamp.h>
namespace local_engine
{
-REGISTER_FUNCTION(SparkFunctionUnixTimestamp)
+REGISTER_FUNCTION(SparkFunctionDateToUnixTimestamp)
{
- factory.registerFunction<local_eingine::SparkFunctionUnixTimestamp>();
+
factory.registerFunction<local_eingine::SparkFunctionDateToUnixTimestamp>();
}
}
\ No newline at end of file
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.h
b/cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.h
similarity index 71%
rename from cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.h
rename to cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.h
index 356519afa..cdf0460e0 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionUnixTimestamp.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDateToUnixTimestamp.h
@@ -15,8 +15,13 @@
* limitations under the License.
*/
+#include <Common/DateLUT.h>
#include <Common/DateLUTImpl.h>
-#include <Functions/FunctionsConversion.h>
+#include <Common/LocalDateTime.h>
+#include <Columns/ColumnVector.h>
+#include <DataTypes/IDataType.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
namespace DB
@@ -32,12 +37,12 @@ using namespace DB;
namespace local_eingine
{
-class SparkFunctionUnixTimestamp : public FunctionToUnixTimestamp
+class SparkFunctionDateToUnixTimestamp : public IFunction
{
public:
- static constexpr auto name = "sparkToUnixTimestamp";
- static FunctionPtr create(ContextPtr) { return
std::make_shared<SparkFunctionUnixTimestamp>(); }
- SparkFunctionUnixTimestamp()
+ static constexpr auto name = "sparkDateToUnixTimestamp";
+ static FunctionPtr create(ContextPtr) { return
std::make_shared<SparkFunctionDateToUnixTimestamp>(); }
+ SparkFunctionDateToUnixTimestamp()
{
const DateLUTImpl * date_lut = &DateLUT::instance("UTC");
UInt32 utc_timestamp = static_cast<UInt32>(0);
@@ -45,21 +50,24 @@ public:
UInt32 unix_timestamp = date_time.to_time_t();
delta_timestamp_from_utc = unix_timestamp - utc_timestamp;
}
- ~SparkFunctionUnixTimestamp() override = default;
+ ~SparkFunctionDateToUnixTimestamp() override = default;
String getName() const override { return name; }
+ bool isSuitableForShortCircuitArgumentsExecution(const
DB::DataTypesWithConstInfo &) const override { return true; }
+ size_t getNumberOfArguments() const override { return 0; }
+ bool isVariadic() const override { return true; }
+ bool useDefaultImplementationForConstants() const override { return true; }
+ DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &) const
override
+ {
+ return std::make_shared<DataTypeUInt32>();
+ }
ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const
DataTypePtr & result_type, size_t input_rows) const override
{
- if (arguments.size() != 1 && arguments.size() != 2)
+ if (arguments.size() != 1 && arguments.size() != 2)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} argument size must be 1 or 2", name);
- ColumnWithTypeAndName first_arg = arguments[0];
-
- if (!isDateOrDate32(first_arg.type))
- {
- return FunctionToUnixTimestamp::executeImpl(arguments,
result_type, input_rows);
- }
- else if (isDate(first_arg.type))
+ ColumnWithTypeAndName first_arg = arguments[0];
+ if (isDate(first_arg.type))
return executeInternal<UInt16>(first_arg.column, input_rows);
else
return executeInternal<Int32>(first_arg.column, input_rows);
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp
b/cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp
index 0b963e769..3a25e383d 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionToDate.cpp
@@ -17,8 +17,13 @@
#include <Common/LocalDate.h>
#include <Common/DateLUT.h>
#include <Common/DateLUTImpl.h>
+#include <Columns/ColumnString.h>
+#include <Columns/ColumnVector.h>
+#include <Columns/ColumnsNumber.h>
+#include <Columns/ColumnNullable.h>
+#include <DataTypes/IDataType.h>
#include <DataTypes/DataTypeDate32.h>
-#include <Functions/FunctionsConversion.h>
+#include <DataTypes/DataTypeNullable.h>
#include <Functions/FunctionFactory.h>
#include <IO/ReadBufferFromMemory.h>
#include <IO/ReadHelpers.h>
@@ -35,14 +40,18 @@ namespace ErrorCodes
namespace local_engine
{
-class SparkFunctionConvertToDate : public DB::FunctionToDate32OrNull
+class SparkFunctionConvertToDate : public DB::IFunction
{
public:
static constexpr auto name = "sparkToDate";
static DB::FunctionPtr create(DB::ContextPtr) { return
std::make_shared<SparkFunctionConvertToDate>(); }
SparkFunctionConvertToDate() = default;
~SparkFunctionConvertToDate() override = default;
+ bool isSuitableForShortCircuitArgumentsExecution(const
DB::DataTypesWithConstInfo &) const override { return true; }
+ size_t getNumberOfArguments() const override { return 0; }
String getName() const override { return name; }
+ bool isVariadic() const override { return true; }
+ bool useDefaultImplementationForConstants() const override { return true; }
bool checkAndGetDate32(DB::ReadBuffer & buf, DB::DataTypeDate32::FieldType
&x, const DateLUTImpl & date_lut) const
{
@@ -99,6 +108,12 @@ public:
}
}
+ DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName &)
const override
+ {
+ DB::DataTypePtr date32_type = std::make_shared<DB::DataTypeDate32>();
+ return makeNullable(date32_type);
+ }
+
DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments,
const DB::DataTypePtr & result_type, size_t) const override
{
if (arguments.size() != 1)
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h
b/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h
index d185b850f..ae9ebc6e7 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionToDateTime.h
@@ -18,9 +18,12 @@
#include <Common/DateLUT.h>
#include <Common/DateLUTImpl.h>
#include <Columns/ColumnsDateTime.h>
+#include <Columns/ColumnNullable.h>
+#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeDateTime64.h>
-#include <Functions/FunctionsConversion.h>
#include <Functions/FunctionFactory.h>
+#include <Functions/FunctionHelpers.h>
+#include <Functions/extractTimeZoneFromFunctionArguments.h>
#include <IO/ReadBufferFromMemory.h>
#include <IO/parseDateTimeBestEffort.h>
#include <IO/ReadHelpers.h>
@@ -40,13 +43,17 @@ namespace ErrorCodes
namespace local_engine
{
-class SparkFunctionConvertToDateTime : public DB::FunctionToDateTime64OrNull
+class SparkFunctionConvertToDateTime : public IFunction
{
public:
static constexpr auto name = "sparkToDateTime";
static DB::FunctionPtr create(DB::ContextPtr) { return
std::make_shared<SparkFunctionConvertToDateTime>(); }
SparkFunctionConvertToDateTime() = default;
~SparkFunctionConvertToDateTime() override = default;
+ bool isSuitableForShortCircuitArgumentsExecution(const
DataTypesWithConstInfo &) const override { return true; }
+ size_t getNumberOfArguments() const override { return 0; }
+ bool isVariadic() const override { return true; }
+ bool useDefaultImplementationForConstants() const override { return true; }
String getName() const override { return name; }
bool checkDateTimeFormat(DB::ReadBuffer & buf, size_t buf_size, UInt8 &
can_be_parsed) const
@@ -109,11 +116,26 @@ public:
return true;
}
+ inline UInt32 extractDecimalScale(const ColumnWithTypeAndName &
named_column) const
+ {
+ const auto * arg_type = named_column.type.get();
+ bool ok = checkAndGetDataType<DataTypeUInt64>(arg_type)
+ || checkAndGetDataType<DataTypeUInt32>(arg_type)
+ || checkAndGetDataType<DataTypeUInt16>(arg_type)
+ || checkAndGetDataType<DataTypeUInt8>(arg_type);
+ if (!ok)
+ throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal
type of toDecimal() scale {}", named_column.type->getName());
+
+ Field field;
+ named_column.column->get(0, field);
+ return static_cast<UInt32>(field.get<UInt32>());
+ }
+
DB::DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName &
arguments) const override
{
UInt32 scale = 6;
if (arguments.size() > 1)
- scale = extractToDecimalScale(arguments[1]);
+ scale = extractDecimalScale(arguments[1]);
const auto timezone =
extractTimeZoneNameFromFunctionArguments(arguments, 2, 0, false);
return makeNullable(std::make_shared<DataTypeDateTime64>(scale,
timezone));
}
diff --git
a/cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp
b/cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp
index 41b0864d5..84a7d394e 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/unixTimestamp.cpp
@@ -67,7 +67,7 @@ public:
if (isString(expr_type))
result_node = toFunctionNode(actions_dag,
"parseDateTimeInJodaSyntaxOrNull", {expr_arg, fmt_arg, time_zone_node});
else if (isDateOrDate32(expr_type))
- result_node = toFunctionNode(actions_dag, "sparkToUnixTimestamp",
{expr_arg, time_zone_node});
+ result_node = toFunctionNode(actions_dag,
"sparkDateToUnixTimestamp", {expr_arg, time_zone_node});
else if (isDateTime(expr_type) || isDateTime64(expr_type))
result_node = toFunctionNode(actions_dag, "toUnixTimestamp",
{expr_arg, time_zone_node});
else
diff --git a/cpp-ch/local-engine/tests/benchmark_to_datetime_function.cpp
b/cpp-ch/local-engine/tests/benchmark_to_datetime_function.cpp
index 0d0cad804..01153c91e 100644
--- a/cpp-ch/local-engine/tests/benchmark_to_datetime_function.cpp
+++ b/cpp-ch/local-engine/tests/benchmark_to_datetime_function.cpp
@@ -20,7 +20,6 @@
#include <DataTypes/IDataType.h>
#include <DataTypes/DataTypeFactory.h>
#include <Functions/FunctionFactory.h>
-#include <Functions/FunctionsConversion.h>
#include <Functions/SparkFunctionToDateTime.h>
#include <Parser/SerializedPlanParser.h>
#include <Parser/FunctionParser.h>
diff --git a/cpp-ch/local-engine/tests/benchmark_unix_timestamp_function.cpp
b/cpp-ch/local-engine/tests/benchmark_unix_timestamp_function.cpp
index bad233d27..13333f0c4 100644
--- a/cpp-ch/local-engine/tests/benchmark_unix_timestamp_function.cpp
+++ b/cpp-ch/local-engine/tests/benchmark_unix_timestamp_function.cpp
@@ -21,7 +21,7 @@
#include <DataTypes/DataTypeFactory.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionsRound.h>
-#include <Functions/SparkFunctionUnixTimestamp.h>
+#include <Functions/SparkFunctionDateToUnixTimestamp.h>
#include <Parser/SerializedPlanParser.h>
#include <Parser/FunctionParser.h>
#include <benchmark/benchmark.h>
@@ -74,7 +74,7 @@ static void BM_SparkUnixTimestamp_For_Date32(benchmark::State
& state)
{
using namespace DB;
auto & factory = FunctionFactory::instance();
- auto function = factory.get("sparkToUnixTimestamp",
local_engine::SerializedPlanParser::global_context);
+ auto function = factory.get("sparkDateToUnixTimestamp",
local_engine::SerializedPlanParser::global_context);
Block block = createDataBlock("Date32", 30000000);
auto executable = function->build(block.getColumnsWithTypeAndName());
for (auto _ : state)[[maybe_unused]]
@@ -85,7 +85,7 @@ static void BM_SparkUnixTimestamp_For_Date(benchmark::State &
state)
{
using namespace DB;
auto & factory = FunctionFactory::instance();
- auto function = factory.get("sparkToUnixTimestamp",
local_engine::SerializedPlanParser::global_context);
+ auto function = factory.get("sparkDateToUnixTimestamp",
local_engine::SerializedPlanParser::global_context);
Block block = createDataBlock("Date", 30000000);
auto executable = function->build(block.getColumnsWithTypeAndName());
for (auto _ : state)[[maybe_unused]]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]