This is an automated email from the ASF dual-hosted git repository.

lwz9103 pushed a commit to branch liquid
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git

commit d5d0ccac71acbd823c85401163787444987ad322
Author: loneylee <[email protected]>
AuthorDate: Fri Mar 29 15:55:42 2024 +0800

    fix kylintimestampadd (support a non-constant unit argument)
    
    fix timestampadd DataTypeNullable
    
    (cherry picked from commit 2e1808e5eaf815d666ff1f57e575e20190f6389b)
---
 .../Parser/scalar_function_parser/timestampAdd.cpp | 139 +++++++++++++++------
 1 file changed, 102 insertions(+), 37 deletions(-)

diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/timestampAdd.cpp 
b/cpp-ch/local-engine/Parser/scalar_function_parser/timestampAdd.cpp
index caf6777f88..14a2117774 100644
--- a/cpp-ch/local-engine/Parser/scalar_function_parser/timestampAdd.cpp
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/timestampAdd.cpp
@@ -15,7 +15,9 @@
  * limitations under the License.
  */
 
-#include <DataTypes/DataTypeString.h>
+#include <DataTypes/DataTypeDateTime64.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <DataTypes/DataTypesNumber.h>
 #include <Parser/FunctionParser.h>
 #include <Common/DateLUTImpl.h>
 
@@ -32,6 +34,29 @@ extern const int ILLEGAL_TYPE_OF_ARGUMENT;
 
 namespace local_engine
 {
+static const std::map<std::string, std::string> KE_UNIT_TO_CH_FUNCTION
+    = {{"FRAC_SECOND", "addMicroseconds"},
+       {"SQL_TSI_FRAC_SECOND", "addMicroseconds"},
+       {"MICROSECOND", "addMicroseconds"},
+       {"MILLISECOND", "addMilliseconds"},
+       {"SECOND", "addSeconds"},
+       {"SQL_TSI_SECOND", "addSeconds"},
+       {"MINUTE", "addMinutes"},
+       {"SQL_TSI_MINUTE", "addMinutes"},
+       {"HOUR", "addHours"},
+       {"SQL_TSI_HOUR", "addHours"},
+       {"DAY", "addDays"},
+       {"DAYOFYEAR", "addDays"},
+       {"SQL_TSI_DAY", "addDays"},
+       {"WEEK", "addWeeks"},
+       {"SQL_TSI_WEEK", "addWeeks"},
+       {"MONTH", "addMonths"},
+       {"SQL_TSI_MONTH", "addMonths"},
+       {"QUARTER", "addQuarters"},
+       {"SQL_TSI_QUARTER", "addQuarters"},
+       {"YEAR", "addYears"},
+       {"SQL_TSI_YEAR", "addYears"}};
+
 
 class FunctionParserTimestampAdd : public FunctionParser
 {
@@ -44,62 +69,102 @@ public:
     String getName() const override { return name; }
     String getCHFunctionName(const substrait::Expression_ScalarFunction &) 
const override { return "timestamp_add"; }
 
-    const DB::ActionsDAG::Node * parse(const 
substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & 
actions_dag) const override
+    const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction 
& substrait_func, ActionsDAG & actions_dag) const override
     {
         auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
         if (parsed_args.size() < 3)
-            throw 
DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} 
requires at least three arguments", getName());
-
-        const auto & unit_field = substrait_func.arguments().at(0);
-        if (!unit_field.value().has_literal() || 
!unit_field.value().literal().has_string())
-            throw DB::Exception(
-                DB::ErrorCodes::BAD_ARGUMENTS, "Unsupported unit argument, 
should be a string literal, but: {}", unit_field.DebugString());
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, 
"Function {} requires at least three arguments", getName());
 
         String timezone;
         if (parsed_args.size() == 4)
         {
             const auto & timezone_field = substrait_func.arguments().at(3);
             if (!timezone_field.value().has_literal() || 
!timezone_field.value().literal().has_string())
-            throw DB::Exception(
-                DB::ErrorCodes::BAD_ARGUMENTS,
+            throw Exception(
+                ErrorCodes::BAD_ARGUMENTS,
                 "Unsupported timezone_field argument, should be a string 
literal, but: {}",
                 timezone_field.DebugString());
             timezone = timezone_field.value().literal().string();
         }
 
+        const auto & unit_field = substrait_func.arguments().at(0);
+
+        return unit_field.value().has_literal() ? 
parseLiteralFunction(substrait_func, parsed_args, actions_dag, timezone)
+                                                : 
parseOtherFunction(substrait_func, parsed_args, actions_dag, timezone);
+    }
+
+    const ActionsDAG::Node * parseLiteralFunction(
+        const substrait::Expression_ScalarFunction & substrait_func,
+        const ActionsDAG::NodeRawConstPtrs & parsed_args,
+        ActionsDAGPtr & actions_dag,
+        const String & timezone) const
+    {
+        const auto & unit_field = substrait_func.arguments().at(0);
         const auto & unit = 
Poco::toUpper(unit_field.value().literal().string());
 
-        std::string ch_function_name;
-        if (unit == "MICROSECOND")
-            ch_function_name = "addMicroseconds";
-        else if (unit == "MILLISECOND")
-            ch_function_name = "addMilliseconds";
-        else if (unit == "SECOND")
-            ch_function_name = "addSeconds";
-        else if (unit == "MINUTE")
-            ch_function_name = "addMinutes";
-        else if (unit == "HOUR")
-            ch_function_name = "addHours";
-        else if (unit == "DAY" || unit == "DAYOFYEAR")
-            ch_function_name = "addDays";
-        else if (unit == "WEEK")
-            ch_function_name = "addWeeks";
-        else if (unit == "MONTH")
-            ch_function_name = "addMonths";
-        else if (unit == "QUARTER")
-            ch_function_name = "addQuarters";
-        else if (unit == "YEAR")
-            ch_function_name = "addYears";
-        else
-            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unsupported 
unit argument: {}", unit);
+        if (!KE_UNIT_TO_CH_FUNCTION.contains(unit))
+            throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unsupported unit 
argument: {}", unit);
 
+        const std::string & ch_function_name = KE_UNIT_TO_CH_FUNCTION.at(unit);
+        ActionsDAG::NodeRawConstPtrs args = {parsed_args[2], parsed_args[1]};
         if (timezone.empty())
-            timezone = DateLUT::instance().getTimeZone();
+        {
+            const ActionsDAG::Node * result_node = toFunctionNode(actions_dag, 
ch_function_name, args);
+            return convertNodeTypeIfNeeded(substrait_func, result_node, 
actions_dag);
+        }
+
+        const auto * time_zone_node = addColumnToActionsDAG(actions_dag, 
std::make_shared<DataTypeString>(), timezone);
+        if (isDateTimeOrDateTime64(parsed_args[2]->result_type))
+        {
+            args.emplace_back(time_zone_node);
+            const ActionsDAG::Node * result_node = toFunctionNode(actions_dag, 
ch_function_name, args);
+            return convertNodeTypeIfNeeded(substrait_func, result_node, 
actions_dag);
+        }
+
+        const ActionsDAG::Node * result_node = toFunctionNode(actions_dag, 
ch_function_name, args);
+        const auto * scale_node = addColumnToActionsDAG(actions_dag, 
std::make_shared<DataTypeUInt32>(), 6);
+        return convertNodeTypeIfNeeded(
+            substrait_func, toFunctionNode(actions_dag, "toDateTime64", 
{result_node, scale_node, time_zone_node}), actions_dag);
+    }
+
 
-        const auto * time_zone_node = addColumnToActionsDAG(actions_dag, 
std::make_shared<DB::DataTypeString>(), timezone);
-        const DB::ActionsDAG::Node * result_node
-            = toFunctionNode(actions_dag, ch_function_name, {parsed_args[2], 
parsed_args[1], time_zone_node});
+    const ActionsDAG::Node * parseOtherFunction(
+        const substrait::Expression_ScalarFunction & substrait_func,
+        const ActionsDAG::NodeRawConstPtrs & parsed_args,
+        ActionsDAGPtr & actions_dag,
+        const String & timezone) const
+    {
+        const DB::ActionsDAG::Node * timezone_node;
+        if (!timezone.empty())
+            timezone_node = addColumnToActionsDAG(actions_dag, 
std::make_shared<DataTypeString>(), timezone);
+
+        const auto * scale_node = addColumnToActionsDAG(actions_dag, 
std::make_shared<DataTypeUInt32>(), 6);
+        const auto * from_node = parsed_args[2];
+        if (!isDateTimeOrDateTime64(from_node->result_type))
+        {
+            ActionsDAG::NodeRawConstPtrs from_node_convert_args = {from_node, 
scale_node};
+            if (!timezone.empty())
+                from_node_convert_args.emplace_back(timezone_node);
+
+            from_node = toFunctionNode(actions_dag, "toDateTime64", 
from_node_convert_args);
+        }
+
+        ActionsDAG::NodeRawConstPtrs multi_if_args;
+        ActionsDAG::NodeRawConstPtrs timestamp_add_args = {from_node, 
parsed_args[1]};
+        if (!timezone.empty())
+            timestamp_add_args.emplace_back(timezone_node);
+
+        for (const auto & ke_unit_to_ch_function : KE_UNIT_TO_CH_FUNCTION)
+        {
+            const auto * unit_node = addColumnToActionsDAG(actions_dag, 
std::make_shared<DataTypeString>(), ke_unit_to_ch_function.first);
+            ActionsDAG::NodeRawConstPtrs mutiif_args = {parsed_args[0], 
unit_node};
+            multi_if_args.emplace_back(toFunctionNode(actions_dag, "equals", 
mutiif_args));
+            multi_if_args.emplace_back(toFunctionNode(actions_dag, 
ke_unit_to_ch_function.second, timestamp_add_args));
+        }
 
+        multi_if_args.emplace_back(
+            addColumnToActionsDAG(actions_dag, 
std::make_shared<DataTypeNullable>(std::make_shared<DataTypeDateTime64>(6)), 
Field{}));
+        const ActionsDAG::Node * result_node = toFunctionNode(actions_dag, 
"multiIf", multi_if_args);
         return convertNodeTypeIfNeeded(substrait_func, result_node, 
actions_dag);
     }
 };


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to