This is an automated email from the ASF dual-hosted git repository.

liuneng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new b96ddb4a1 [GLUTEN-5580][CH]Fix cast to int exceed max (#5581)
b96ddb4a1 is described below

commit b96ddb4a1fd1e2bbb65473ea24c02d30f3e77fd9
Author: KevinyhZou <[email protected]>
AuthorDate: Tue May 7 16:27:29 2024 +0800

    [GLUTEN-5580][CH]Fix cast to int exceed max (#5581)
    
    What changes were proposed in this pull request?
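    When casting a floating-point value to an integer type in the ClickHouse
    backend, clamp values above the target type's maximum (or below its
    minimum) to that bound instead of passing them straight to static_cast.
    The result column is now also zero-initialized, so rows marked null for
    NaN/Inf inputs hold 0.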
    
    (Fixes: #5580)
    
    How was this patch tested?
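    Tested by unit test: the existing GLUTEN-3149 case in
    GlutenClickHouseTPCHSaltNullParquetSuite is extended with a bigint value
    whose quotient exceeds the int range.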
---
 .../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 12 ++++++----
 .../Functions/SparkFunctionCastFloatToInt.cpp      | 27 ++++++++++++----------
 .../Functions/SparkFunctionCastFloatToInt.h        | 12 +++++-----
 3 files changed, 28 insertions(+), 23 deletions(-)

diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index cc2eebcab..866f0ffaa 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -2094,13 +2094,15 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends 
GlutenClickHouseTPCHAbstr
     compareResultsAgainstVanillaSpark(sql, true, { _ => })
   }
 
-  test("GLUTEN-3149: Fix convert exception of Inf to int") {
-    val tbl_create_sql = "create table test_tbl_3149(a int, b int) using 
parquet";
-    val tbl_insert_sql = "insert into test_tbl_3149 values(1, 0)"
-    val select_sql = "select cast(a * 1.0f/b as int) as x from test_tbl_3149 
where a = 1"
+  test("GLUTEN-3149/GLUTEN-5580: Fix convert float to int") {
+    val tbl_create_sql = "create table test_tbl_3149(a int, b bigint) using 
parquet";
+    val tbl_insert_sql = "insert into test_tbl_3149 values(1, 0), (2, 
171396196666200)"
+    val select_sql_1 = "select cast(a * 1.0f/b as int) as x from test_tbl_3149 
where a = 1"
+    val select_sql_2 = "select cast(b/100 as int) from test_tbl_3149 where a = 
2"
     spark.sql(tbl_create_sql)
     spark.sql(tbl_insert_sql);
-    compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
+    compareResultsAgainstVanillaSpark(select_sql_1, true, { _ => })
+    compareResultsAgainstVanillaSpark(select_sql_2, true, { _ => })
     spark.sql("drop table test_tbl_3149")
   }
 
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp 
b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp
index 322f9c08a..c378f9fbf 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp
@@ -15,7 +15,10 @@
  * limitations under the License.
  */
 
+#include <limits.h>
 #include <base/types.h>
+#include <base/wide_integer.h>
+#include <base/wide_integer_impl.h>
 #include <Functions/SparkFunctionCastFloatToInt.h>
 
 using namespace DB;
@@ -36,18 +39,18 @@ struct NameToInt64 { static constexpr auto name = 
"sparkCastFloatToInt64"; };
 struct NameToInt128 { static constexpr auto name = "sparkCastFloatToInt128"; };
 struct NameToInt256 { static constexpr auto name = "sparkCastFloatToInt256"; };
 
-using SparkFunctionCastFloatToInt8 = 
local_engine::SparkFunctionCastFloatToInt<Int8, NameToInt8>;
-using SparkFunctionCastFloatToInt16 = 
local_engine::SparkFunctionCastFloatToInt<Int16, NameToInt16>;
-using SparkFunctionCastFloatToInt32 = 
local_engine::SparkFunctionCastFloatToInt<Int32, NameToInt32>;
-using SparkFunctionCastFloatToInt64 = 
local_engine::SparkFunctionCastFloatToInt<Int64, NameToInt64>;
-using SparkFunctionCastFloatToInt128 = 
local_engine::SparkFunctionCastFloatToInt<Int128, NameToInt128>;
-using SparkFunctionCastFloatToInt256 = 
local_engine::SparkFunctionCastFloatToInt<Int256, NameToInt256>;
-using SparkFunctionCastFloatToUInt8 = 
local_engine::SparkFunctionCastFloatToInt<UInt8, NameToUInt8>;
-using SparkFunctionCastFloatToUInt16 = 
local_engine::SparkFunctionCastFloatToInt<UInt16, NameToUInt16>;
-using SparkFunctionCastFloatToUInt32 = 
local_engine::SparkFunctionCastFloatToInt<UInt32, NameToUInt32>;
-using SparkFunctionCastFloatToUInt64 = 
local_engine::SparkFunctionCastFloatToInt<UInt64, NameToUInt64>;
-using SparkFunctionCastFloatToUInt128 = 
local_engine::SparkFunctionCastFloatToInt<UInt128, NameToUInt128>;
-using SparkFunctionCastFloatToUInt256 = 
local_engine::SparkFunctionCastFloatToInt<UInt256, NameToUInt256>;
+using SparkFunctionCastFloatToInt8 = 
local_engine::SparkFunctionCastFloatToInt<Int8, NameToInt8, INT8_MAX, INT8_MIN>;
+using SparkFunctionCastFloatToInt16 = 
local_engine::SparkFunctionCastFloatToInt<Int16, NameToInt16, INT16_MAX, 
INT16_MIN>;
+using SparkFunctionCastFloatToInt32 = 
local_engine::SparkFunctionCastFloatToInt<Int32, NameToInt32, INT32_MAX, 
INT32_MIN>;
+using SparkFunctionCastFloatToInt64 = 
local_engine::SparkFunctionCastFloatToInt<Int64, NameToInt64, INT64_MAX, 
INT64_MIN>;
+using SparkFunctionCastFloatToInt128 = 
local_engine::SparkFunctionCastFloatToInt<Int128, NameToInt128, 
std::numeric_limits<Int128>::max(), std::numeric_limits<Int128>::min()>;
+using SparkFunctionCastFloatToInt256 = 
local_engine::SparkFunctionCastFloatToInt<Int256, NameToInt256, 
std::numeric_limits<Int256>::max(), std::numeric_limits<Int256>::min()>;
+using SparkFunctionCastFloatToUInt8 = 
local_engine::SparkFunctionCastFloatToInt<UInt8, NameToUInt8, UINT8_MAX, 0>;
+using SparkFunctionCastFloatToUInt16 = 
local_engine::SparkFunctionCastFloatToInt<UInt16, NameToUInt16, UINT16_MAX, 0>;
+using SparkFunctionCastFloatToUInt32 = 
local_engine::SparkFunctionCastFloatToInt<UInt32, NameToUInt32, UINT32_MAX, 0>;
+using SparkFunctionCastFloatToUInt64 = 
local_engine::SparkFunctionCastFloatToInt<UInt64, NameToUInt64, UINT64_MAX, 0>;
+using SparkFunctionCastFloatToUInt128 = 
local_engine::SparkFunctionCastFloatToInt<UInt128, NameToUInt128, 
std::numeric_limits<UInt128>::max(), 0>;
+using SparkFunctionCastFloatToUInt256 = 
local_engine::SparkFunctionCastFloatToInt<UInt256, NameToUInt256, 
std::numeric_limits<UInt256>::max(), 0>;
 
 REGISTER_FUNCTION(SparkFunctionCastToInt)
 {
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h 
b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h
index 675db9e30..4522e0e7d 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h
@@ -40,7 +40,7 @@ namespace ErrorCodes
 namespace local_engine
 {
 
-template <typename T, typename Name>
+template <typename T, typename Name, T int_max_value, T int_min_value>
 class SparkFunctionCastFloatToInt : public DB::IFunction
 {
 public:
@@ -74,7 +74,7 @@ public:
         DB::ColumnPtr src_col = arguments[0].column;
         size_t size = src_col->size();
 
-        auto res_col = DB::ColumnVector<T>::create(size);
+        auto res_col = DB::ColumnVector<T>::create(size, 0);
         auto null_map_col = DB::ColumnUInt8::create(size, 0);
 
         switch(removeNullable(arguments[0].type)->getTypeId())
@@ -101,15 +101,15 @@ public:
         {
             F element = src_vec->getElement(i);
             if (isNaN(element) || !isFinite(element))
-            {
-                data[i] = 0;
                 null_map_data[i] = 1;
-            }
+            else if (element > int_max_value)
+                data[i] = int_max_value;
+            else if (element < int_min_value)
+                data[i] = int_min_value;
             else
                 data[i] = static_cast<T>(element);
         }
     }
-
 };
 
 }
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to