This is an automated email from the ASF dual-hosted git repository.
liuneng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new b96ddb4a1 [GLUTEN-5580][CH]Fix cast to int exceed max (#5581)
b96ddb4a1 is described below
commit b96ddb4a1fd1e2bbb65473ea24c02d30f3e77fd9
Author: KevinyhZou <[email protected]>
AuthorDate: Tue May 7 16:27:29 2024 +0800
[GLUTEN-5580][CH]Fix cast to int exceed max (#5581)
What changes were proposed in this pull request?
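Saturate float-to-int casts: a value above the target integer type's maximum now yields that maximum, a value below its minimum yields that minimum, and NaN/Inf inputs still produce NULL.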
(Fixes: #5580)
How was this patch tested?
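Tested by unit tests (the "GLUTEN-3149/GLUTEN-5580: Fix convert float to int" case in GlutenClickHouseTPCHSaltNullParquetSuite).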
---
.../GlutenClickHouseTPCHSaltNullParquetSuite.scala | 12 ++++++----
.../Functions/SparkFunctionCastFloatToInt.cpp | 27 ++++++++++++----------
.../Functions/SparkFunctionCastFloatToInt.h | 12 +++++-----
3 files changed, 28 insertions(+), 23 deletions(-)
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
index cc2eebcab..866f0ffaa 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala
@@ -2094,13 +2094,15 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends
GlutenClickHouseTPCHAbstr
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}
- test("GLUTEN-3149: Fix convert exception of Inf to int") {
- val tbl_create_sql = "create table test_tbl_3149(a int, b int) using
parquet";
- val tbl_insert_sql = "insert into test_tbl_3149 values(1, 0)"
- val select_sql = "select cast(a * 1.0f/b as int) as x from test_tbl_3149
where a = 1"
+ test("GLUTEN-3149/GLUTEN-5580: Fix convert float to int") {
+ val tbl_create_sql = "create table test_tbl_3149(a int, b bigint) using
parquet";
+ val tbl_insert_sql = "insert into test_tbl_3149 values(1, 0), (2,
171396196666200)"
+ val select_sql_1 = "select cast(a * 1.0f/b as int) as x from test_tbl_3149
where a = 1"
+ val select_sql_2 = "select cast(b/100 as int) from test_tbl_3149 where a =
2"
spark.sql(tbl_create_sql)
spark.sql(tbl_insert_sql);
- compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
+ compareResultsAgainstVanillaSpark(select_sql_1, true, { _ => })
+ compareResultsAgainstVanillaSpark(select_sql_2, true, { _ => })
spark.sql("drop table test_tbl_3149")
}
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp
b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp
index 322f9c08a..c378f9fbf 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.cpp
@@ -15,7 +15,10 @@
* limitations under the License.
*/
+#include <limits.h>
#include <base/types.h>
+#include <base/wide_integer.h>
+#include <base/wide_integer_impl.h>
#include <Functions/SparkFunctionCastFloatToInt.h>
using namespace DB;
@@ -36,18 +39,18 @@ struct NameToInt64 { static constexpr auto name =
"sparkCastFloatToInt64"; };
struct NameToInt128 { static constexpr auto name = "sparkCastFloatToInt128"; };
struct NameToInt256 { static constexpr auto name = "sparkCastFloatToInt256"; };
-using SparkFunctionCastFloatToInt8 =
local_engine::SparkFunctionCastFloatToInt<Int8, NameToInt8>;
-using SparkFunctionCastFloatToInt16 =
local_engine::SparkFunctionCastFloatToInt<Int16, NameToInt16>;
-using SparkFunctionCastFloatToInt32 =
local_engine::SparkFunctionCastFloatToInt<Int32, NameToInt32>;
-using SparkFunctionCastFloatToInt64 =
local_engine::SparkFunctionCastFloatToInt<Int64, NameToInt64>;
-using SparkFunctionCastFloatToInt128 =
local_engine::SparkFunctionCastFloatToInt<Int128, NameToInt128>;
-using SparkFunctionCastFloatToInt256 =
local_engine::SparkFunctionCastFloatToInt<Int256, NameToInt256>;
-using SparkFunctionCastFloatToUInt8 =
local_engine::SparkFunctionCastFloatToInt<UInt8, NameToUInt8>;
-using SparkFunctionCastFloatToUInt16 =
local_engine::SparkFunctionCastFloatToInt<UInt16, NameToUInt16>;
-using SparkFunctionCastFloatToUInt32 =
local_engine::SparkFunctionCastFloatToInt<UInt32, NameToUInt32>;
-using SparkFunctionCastFloatToUInt64 =
local_engine::SparkFunctionCastFloatToInt<UInt64, NameToUInt64>;
-using SparkFunctionCastFloatToUInt128 =
local_engine::SparkFunctionCastFloatToInt<UInt128, NameToUInt128>;
-using SparkFunctionCastFloatToUInt256 =
local_engine::SparkFunctionCastFloatToInt<UInt256, NameToUInt256>;
+using SparkFunctionCastFloatToInt8 =
local_engine::SparkFunctionCastFloatToInt<Int8, NameToInt8, INT8_MAX, INT8_MIN>;
+using SparkFunctionCastFloatToInt16 =
local_engine::SparkFunctionCastFloatToInt<Int16, NameToInt16, INT16_MAX,
INT16_MIN>;
+using SparkFunctionCastFloatToInt32 =
local_engine::SparkFunctionCastFloatToInt<Int32, NameToInt32, INT32_MAX,
INT32_MIN>;
+using SparkFunctionCastFloatToInt64 =
local_engine::SparkFunctionCastFloatToInt<Int64, NameToInt64, INT64_MAX,
INT64_MIN>;
+using SparkFunctionCastFloatToInt128 =
local_engine::SparkFunctionCastFloatToInt<Int128, NameToInt128,
std::numeric_limits<Int128>::max(), std::numeric_limits<Int128>::min()>;
+using SparkFunctionCastFloatToInt256 =
local_engine::SparkFunctionCastFloatToInt<Int256, NameToInt256,
std::numeric_limits<Int256>::max(), std::numeric_limits<Int256>::min()>;
+using SparkFunctionCastFloatToUInt8 =
local_engine::SparkFunctionCastFloatToInt<UInt8, NameToUInt8, UINT8_MAX, 0>;
+using SparkFunctionCastFloatToUInt16 =
local_engine::SparkFunctionCastFloatToInt<UInt16, NameToUInt16, UINT16_MAX, 0>;
+using SparkFunctionCastFloatToUInt32 =
local_engine::SparkFunctionCastFloatToInt<UInt32, NameToUInt32, UINT32_MAX, 0>;
+using SparkFunctionCastFloatToUInt64 =
local_engine::SparkFunctionCastFloatToInt<UInt64, NameToUInt64, UINT64_MAX, 0>;
+using SparkFunctionCastFloatToUInt128 =
local_engine::SparkFunctionCastFloatToInt<UInt128, NameToUInt128,
std::numeric_limits<UInt128>::max(), 0>;
+using SparkFunctionCastFloatToUInt256 =
local_engine::SparkFunctionCastFloatToInt<UInt256, NameToUInt256,
std::numeric_limits<UInt256>::max(), 0>;
REGISTER_FUNCTION(SparkFunctionCastToInt)
{
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h
b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h
index 675db9e30..4522e0e7d 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionCastFloatToInt.h
@@ -40,7 +40,7 @@ namespace ErrorCodes
namespace local_engine
{
-template <typename T, typename Name>
+template <typename T, typename Name, T int_max_value, T int_min_value>
class SparkFunctionCastFloatToInt : public DB::IFunction
{
public:
@@ -74,7 +74,7 @@ public:
DB::ColumnPtr src_col = arguments[0].column;
size_t size = src_col->size();
- auto res_col = DB::ColumnVector<T>::create(size);
+ auto res_col = DB::ColumnVector<T>::create(size, 0);
auto null_map_col = DB::ColumnUInt8::create(size, 0);
switch(removeNullable(arguments[0].type)->getTypeId())
@@ -101,15 +101,15 @@ public:
{
F element = src_vec->getElement(i);
if (isNaN(element) || !isFinite(element))
- {
- data[i] = 0;
null_map_data[i] = 1;
- }
+ else if (element > int_max_value)
+ data[i] = int_max_value;
+ else if (element < int_min_value)
+ data[i] = int_min_value;
else
data[i] = static_cast<T>(element);
}
}
-
};
}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]