This is an automated email from the ASF dual-hosted git repository.

liuneng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new ab417e4935 [GLUTEN-7389] [CH] fix cast map to string diff with spark 
(#7393)
ab417e4935 is described below

commit ab417e4935c03d6328f284c84bcaadfccc827c7f
Author: shuai.xu <[email protected]>
AuthorDate: Mon Oct 14 10:45:28 2024 +0800

    [GLUTEN-7389] [CH] fix cast map to string diff with spark (#7393)
    
    What changes were proposed in this pull request?
    For cast(map("a", 1, "b", 2) as string), gluten outputs {"a":1,"b":2} while spark outputs {a -> 1, b -> 2}.
    
    (Fixes: #7389)
    
    How was this patch tested?
    This patch was tested by unit tests.
---
 .../GlutenClickhouseFunctionSuite.scala            |  14 ++
 .../Functions/SparkFunctionMapToString.cpp         |  28 ++++
 .../Functions/SparkFunctionMapToString.h           | 169 +++++++++++++++++++++
 .../local-engine/Parser/SerializedPlanParser.cpp   |   8 +
 4 files changed, 219 insertions(+)

diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
index f16c897671..ce8761469c 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala
@@ -255,4 +255,18 @@ class GlutenClickhouseFunctionSuite extends 
GlutenClickHouseTPCHAbstractSuite {
     }
   }
 
+  test("GLUTEN-7389: cast map to string diff with spark") {
+    withTable("test_7389") {
+      sql("create table test_7389(a map<string, int>) using parquet")
+      sql("insert into test_7389 values(map('a', 1, 'b', 2))")
+      compareResultsAgainstVanillaSpark(
+        """
+          |select cast(a as string) from test_7389
+          |""".stripMargin,
+        true,
+        { _ => }
+      )
+    }
+  }
+
 }
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionMapToString.cpp 
b/cpp-ch/local-engine/Functions/SparkFunctionMapToString.cpp
new file mode 100644
index 0000000000..499f64e48e
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionMapToString.cpp
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <Functions/SparkFunctionMapToString.h>
+
namespace local_engine
{

// Registers the Spark-compatible map-to-string cast so the plan parser can
// resolve it by name ("sparkCastMapToString", see SerializedPlanParser.cpp).
// NOTE(review): the class is declared in namespace `local_eingine` (sic, in
// SparkFunctionMapToString.h) while this file opens `local_engine`; the
// spelling mismatch compiles but the namespaces should be unified — confirm
// and fix both files together in a follow-up.
REGISTER_FUNCTION(SparkFunctionMapToString)
{
    factory.registerFunction<local_eingine::SparkFunctionMapToString>();
}

}
\ No newline at end of file
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionMapToString.h 
b/cpp-ch/local-engine/Functions/SparkFunctionMapToString.h
new file mode 100644
index 0000000000..444e4542c8
--- /dev/null
+++ b/cpp-ch/local-engine/Functions/SparkFunctionMapToString.h
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <Columns/ColumnNullable.h>
+#include <Columns/ColumnStringHelpers.h>
+#include <Columns/ColumnMap.h>
+#include <Columns/ColumnsNumber.h>
+#include <DataTypes/DataTypeNullable.h>
+#include <DataTypes/DataTypesNumber.h>
+#include <DataTypes/DataTypeString.h>
+#include <Formats/FormatFactory.h>
+#include <Functions/FunctionFactory.h>
+#include <IO/WriteHelpers.h>
+
namespace DB
{
namespace ErrorCodes
{
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
    /// Fix: getReturnTypeImpl below throws ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
    /// but the original block never declared it, so those uses do not compile.
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
}
+
+using namespace DB;
+
+namespace local_eingine
+{
+
+class SparkFunctionMapToString : public DB::IFunction
+{
+public:
+    static constexpr auto name = "sparkCastMapToString";
+    static FunctionPtr create(ContextPtr context) { return 
std::make_shared<SparkFunctionMapToString>(context); }
+    explicit SparkFunctionMapToString(ContextPtr context_) : context(context_) 
{}
+    ~SparkFunctionMapToString() override = default;
+    String getName() const override { return name; }
+    size_t getNumberOfArguments() const override { return 3; }
+    bool isSuitableForShortCircuitArgumentsExecution(const 
DB::DataTypesWithConstInfo & /*arguments*/) const override { return true; }
+    bool useDefaultImplementationForNulls() const override { return false; }
+    bool useDefaultImplementationForLowCardinalityColumns() const { return 
false; }
+
+    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) 
const override
+    {
+        if (arguments.size() != 3)
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, 
"Function {} argument size must be 3", name);
+        
+        auto arg_type = DB::removeNullable(arguments[0].type);
+        if (!DB::WhichDataType(arg_type).isMap())
+        {
+            throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
+                "Illegal type {} of argument[1] of function {}",
+                arguments[0].type->getName(), getName());
+        }
+
+        auto key_type = WhichDataType(removeNullable(arguments[1].type));
+        auto value_type = WhichDataType(removeNullable(arguments[2].type));
+        // Not support complex types in key or value
+       if (!key_type.isString() && !key_type.isNumber() && 
!value_type.isString() && !value_type.isNumber())
+        {
+            throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
+                "Cast MapToString not support {}, {} as key value",
+                arguments[1].type->getName(), arguments[2].type->getName());
+        }
+
+        return makeNullable(std::make_shared<DataTypeString>());
+    }
+
+    ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const 
DataTypePtr & result_type, size_t /*input_rows*/) const override
+    {
+        ColumnUInt8::MutablePtr null_map = nullptr;
+        if (const auto * col_nullable = 
checkAndGetColumn<ColumnNullable>(arguments[0].column.get()))
+        {
+            null_map = ColumnUInt8::create();
+            null_map->insertRangeFrom(col_nullable->getNullMapColumn(), 0, 
col_nullable->size());
+        }
+
+        const auto & col_with_type_and_name = columnGetNested(arguments[0]);
+        const IColumn & col_from = *col_with_type_and_name.column;
+
+        size_t size = col_from.size();
+        auto col_to = removeNullable(result_type)->createColumn();
+
+        {
+            FormatSettings format_settings = context ? 
getFormatSettings(context) : FormatSettings{};
+            ColumnStringHelpers::WriteHelper write_helper(
+                    assert_cast<ColumnString &>(*col_to),
+                    size);
+
+            auto & write_buffer = write_helper.getWriteBuffer();
+            auto key_serializer = arguments[1].type->getDefaultSerialization();
+            auto value_serializer = 
arguments[2].type->getDefaultSerialization();
+
+            for (size_t row = 0; row < size; ++row)
+            {
+                serializeInSparkStyle(
+                    col_from,
+                    row,
+                    write_buffer,
+                    format_settings,
+                    key_serializer,
+                    value_serializer);
+                write_helper.rowWritten();
+            }
+
+            write_helper.finalize();
+        }
+
+        if (result_type->isNullable() && null_map)
+            return ColumnNullable::create(std::move(col_to), 
std::move(null_map));
+        return col_to;
+    }
+
+private:
+    ContextPtr context;
+
+    void serializeInSparkStyle(
+        const IColumn & column,
+        size_t row_num, WriteBuffer & ostr,
+        const FormatSettings & settings,
+        const SerializationPtr& key_serializer,
+        const SerializationPtr& value_serializer) const
+    {
+
+        const auto & column_map = assert_cast<const ColumnMap &>(column);
+
+        const auto & nested_array = column_map.getNestedColumn();
+        const auto & nested_tuple = column_map.getNestedData();
+        const auto & offsets = nested_array.getOffsets();
+        const auto& key_column = nested_tuple.getColumn(0);
+        const auto& value_column = nested_tuple.getColumn(1);
+
+        size_t offset = offsets[row_num - 1];
+        size_t next_offset = offsets[row_num];
+
+        writeChar('{', ostr);
+        for (size_t i = offset; i < next_offset; ++i)
+        {
+            if (i != offset)
+            {
+                writeChar(',', ostr);
+                writeChar(' ', ostr);
+            }
+
+            key_serializer->serializeText(key_column, i, ostr, settings);
+            writeChar(' ', ostr);
+            writeChar('-', ostr);
+            writeChar('>', ostr);
+            writeChar(' ', ostr);
+            value_serializer->serializeText(value_column, i, ostr, settings);
+        }
+        writeChar('}', ostr);
+    }
+};
+
+}
diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp 
b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
index e8f1196a66..52b5709555 100644
--- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
+++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
@@ -1030,6 +1030,14 @@ const ActionsDAG::Node * 
SerializedPlanParser::parseExpression(ActionsDAG & acti
 
                 function_node = toFunctionNode(actions_dag, 
"checkDecimalOverflowSparkOrNull", args);
             }
+           else if (isMap(non_nullable_input_type) && 
isString(non_nullable_output_type))
+            {
+                // ISSUE-7389: spark cast(map to string) has different 
behavior with CH cast(map to string)
+                auto map_input_type = std::static_pointer_cast<const 
DataTypeMap>(non_nullable_input_type);
+                args.emplace_back(addColumn(actions_dag, 
map_input_type->getKeyType(), map_input_type->getKeyType()->getDefault()));
+                args.emplace_back(addColumn(actions_dag, 
map_input_type->getValueType(), map_input_type->getValueType()->getDefault()));
+                function_node = toFunctionNode(actions_dag, 
"sparkCastMapToString", args);
+            }
             else
             {
                 if (isString(non_nullable_input_type) && 
isInt(non_nullable_output_type))


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to