Repository: spark Updated Branches: refs/heads/branch-2.1 b2c9a2c8c -> 2c2ca8943
[SPARK-19178][SQL] convert string of large numbers to int should return null ## What changes were proposed in this pull request? When we convert a string to an integral type, we first convert that string to `decimal(20, 0)`, so that we can turn a string in decimal format into a truncated integral value, e.g. `CAST('1.2' AS int)` will return `1`. However, this causes problems when we convert a string holding a large number to an integral type, e.g. `CAST('1234567890123' AS int)` will return `1912276171`, while Hive returns null, as expected. This is a long-standing bug (it seems to have been there since the first day Spark SQL was created). This PR fixes it by adding native support for converting `UTF8String` to integral types. ## How was this patch tested? new regression tests Author: Wenchen Fan <wenc...@databricks.com> Closes #16550 from cloud-fan/string-to-int. (cherry picked from commit 6b34e745bb8bdcf5a8bb78359fa39bbe8c6563cc) Signed-off-by: Wenchen Fan <wenc...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2c2ca894 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2c2ca894 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2c2ca894 Branch: refs/heads/branch-2.1 Commit: 2c2ca8943c4355af491ec19fe6d13949182260ab Parents: b2c9a2c Author: Wenchen Fan <wenc...@databricks.com> Authored: Thu Jan 12 22:52:34 2017 -0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Fri Jan 13 18:44:45 2017 +0800 ---------------------------------------------------------------------- .../apache/spark/unsafe/types/UTF8String.java | 184 +++++++++++++++++++ .../sql/catalyst/analysis/TypeCoercion.scala | 16 -- .../spark/sql/catalyst/expressions/Cast.scala | 18 +- .../test/resources/sql-tests/inputs/cast.sql | 43 +++++ .../resources/sql-tests/results/cast.sql.out | 178 ++++++++++++++++++ 5 files changed, 414 insertions(+), 25 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/spark/blob/2c2ca894/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java ---------------------------------------------------------------------- diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e09a6b7..b03e718 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -816,6 +816,190 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, return fromString(sb.toString()); } + private int getDigit(byte b) { + if (b >= '0' && b <= '9') { + return b - '0'; + } + throw new NumberFormatException(toString()); + } + + /** + * Parses this UTF8String to long. + * + * Note that, in this method we accumulate the result in negative format, and convert it to + * positive format at the end, if this string is not started with '-'. This is because min value + * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and + * Integer.MIN_VALUE is '-2147483648'. + * + * This code is mostly copied from LazyLong.parseLong in Hive. + */ + public long toLong() { + if (numBytes == 0) { + throw new NumberFormatException("Empty string"); + } + + byte b = getByte(0); + final boolean negative = b == '-'; + int offset = 0; + if (negative || b == '+') { + offset++; + if (numBytes == 1) { + throw new NumberFormatException(toString()); + } + } + + final byte separator = '.'; + final int radix = 10; + final long stopValue = Long.MIN_VALUE / radix; + long result = 0; + + while (offset < numBytes) { + b = getByte(offset); + offset++; + if (b == separator) { + // We allow decimals and will return a truncated integral in that case. + // Therefore we won't throw an exception here (checking the fractional + // part happens below.) 
+ break; + } + + int digit = getDigit(b); + // We are going to process the new digit and accumulate the result. However, before doing + // this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then + // result * 10 will definitely be smaller than minValue, and we can stop and throw exception. + if (result < stopValue) { + throw new NumberFormatException(toString()); + } + + result = result * radix - digit; + // Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we + // can just use `result > 0` to check overflow. If result overflows, we should stop and throw + // exception. + if (result > 0) { + throw new NumberFormatException(toString()); + } + } + + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well formed. + while (offset < numBytes) { + if (getDigit(getByte(offset)) == -1) { + throw new NumberFormatException(toString()); + } + offset++; + } + + if (!negative) { + result = -result; + if (result < 0) { + throw new NumberFormatException(toString()); + } + } + + return result; + } + + /** + * Parses this UTF8String to int. + * + * Note that, in this method we accumulate the result in negative format, and convert it to + * positive format at the end, if this string is not started with '-'. This is because min value + * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and + * Integer.MIN_VALUE is '-2147483648'. + * + * This code is mostly copied from LazyInt.parseInt in Hive. + * + * Note that, this method is almost same as `toLong`, but we leave it duplicated for performance + * reasons, like Hive does. 
+ */ + public int toInt() { + if (numBytes == 0) { + throw new NumberFormatException("Empty string"); + } + + byte b = getByte(0); + final boolean negative = b == '-'; + int offset = 0; + if (negative || b == '+') { + offset++; + if (numBytes == 1) { + throw new NumberFormatException(toString()); + } + } + + final byte separator = '.'; + final int radix = 10; + final int stopValue = Integer.MIN_VALUE / radix; + int result = 0; + + while (offset < numBytes) { + b = getByte(offset); + offset++; + if (b == separator) { + // We allow decimals and will return a truncated integral in that case. + // Therefore we won't throw an exception here (checking the fractional + // part happens below.) + break; + } + + int digit = getDigit(b); + // We are going to process the new digit and accumulate the result. However, before doing + // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then + // result * 10 will definitely be smaller than minValue, and we can stop and throw exception. + if (result < stopValue) { + throw new NumberFormatException(toString()); + } + + result = result * radix - digit; + // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), + // we can just use `result > 0` to check overflow. If result overflows, we should stop and + // throw exception. + if (result > 0) { + throw new NumberFormatException(toString()); + } + } + + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well formed. 
+ while (offset < numBytes) { + if (getDigit(getByte(offset)) == -1) { + throw new NumberFormatException(toString()); + } + offset++; + } + + if (!negative) { + result = -result; + if (result < 0) { + throw new NumberFormatException(toString()); + } + } + + return result; + } + + public short toShort() { + int intValue = toInt(); + short result = (short) intValue; + if (result != intValue) { + throw new NumberFormatException(toString()); + } + + return result; + } + + public byte toByte() { + int intValue = toInt(); + byte result = (byte) intValue; + if (result != intValue) { + throw new NumberFormatException(toString()); + } + + return result; + } + @Override public String toString() { return new String(getBytes(), StandardCharsets.UTF_8); http://git-wip-us.apache.org/repos/asf/spark/blob/2c2ca894/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 6662a9e..6d9799f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -51,7 +51,6 @@ object TypeCoercion { PromoteStrings :: DecimalPrecision :: BooleanEquality :: - StringToIntegralCasts :: FunctionArgumentConversion :: CaseWhenCoercion :: IfCoercion :: @@ -429,21 +428,6 @@ object TypeCoercion { } /** - * When encountering a cast from a string representing a valid fractional number to an integral - * type the jvm will throw a `java.lang.NumberFormatException`. Hive, in contrast, returns the - * truncated version of this number. 
- */ - object StringToIntegralCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { - // Skip nodes who's children have not been resolved yet. - case e if !e.childrenResolved => e - - case Cast(e @ StringType(), t: IntegralType) => - Cast(Cast(e, DecimalType.forType(LongType)), t) - } - } - - /** * This ensure that the types for various functions are as expected. */ object FunctionArgumentConversion extends Rule[LogicalPlan] { http://git-wip-us.apache.org/repos/asf/spark/blob/2c2ca894/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 4db1ae6..f15ae32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -247,7 +247,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toString.toLong catch { + buildCast[UTF8String](_, s => try s.toLong catch { case _: NumberFormatException => null }) case BooleanType => @@ -263,7 +263,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toString.toInt catch { + buildCast[UTF8String](_, s => try s.toInt catch { case _: NumberFormatException => null }) case BooleanType => @@ -279,7 +279,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // ShortConverter private[this] def castToShort(from: 
DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toString.toShort catch { + buildCast[UTF8String](_, s => try s.toShort catch { case _: NumberFormatException => null }) case BooleanType => @@ -295,7 +295,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // ByteConverter private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toString.toByte catch { + buildCast[UTF8String](_, s => try s.toByte catch { case _: NumberFormatException => null }) case BooleanType => @@ -498,7 +498,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w s""" boolean $resultNull = $childNull; ${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)}; - if (!${childNull}) { + if (!$childNull) { ${cast(childPrim, resultPrim, resultNull)} } """ @@ -705,7 +705,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w (c, evPrim, evNull) => s""" try { - $evPrim = Byte.valueOf($c.toString()); + $evPrim = $c.toByte(); } catch (java.lang.NumberFormatException e) { $evNull = true; } @@ -727,7 +727,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w (c, evPrim, evNull) => s""" try { - $evPrim = Short.valueOf($c.toString()); + $evPrim = $c.toShort(); } catch (java.lang.NumberFormatException e) { $evNull = true; } @@ -749,7 +749,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w (c, evPrim, evNull) => s""" try { - $evPrim = Integer.valueOf($c.toString()); + $evPrim = $c.toInt(); } catch (java.lang.NumberFormatException e) { $evNull = true; } @@ -771,7 +771,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w (c, evPrim, evNull) => s""" try { - $evPrim = Long.valueOf($c.toString()); + $evPrim = $c.toLong(); } catch (java.lang.NumberFormatException e) { $evNull = true; } 
http://git-wip-us.apache.org/repos/asf/spark/blob/2c2ca894/sql/core/src/test/resources/sql-tests/inputs/cast.sql ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/inputs/cast.sql b/sql/core/src/test/resources/sql-tests/inputs/cast.sql new file mode 100644 index 0000000..5fae571 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/cast.sql @@ -0,0 +1,43 @@ +-- cast string representing a valid fractional number to integral should truncate the number +SELECT CAST('1.23' AS int); +SELECT CAST('1.23' AS long); +SELECT CAST('-4.56' AS int); +SELECT CAST('-4.56' AS long); + +-- cast string which are not numbers to integral should return null +SELECT CAST('abc' AS int); +SELECT CAST('abc' AS long); + +-- cast string representing a very large number to integral should return null +SELECT CAST('1234567890123' AS int); +SELECT CAST('12345678901234567890123' AS long); + +-- cast empty string to integral should return null +SELECT CAST('' AS int); +SELECT CAST('' AS long); + +-- cast null to integral should return null +SELECT CAST(NULL AS int); +SELECT CAST(NULL AS long); + +-- cast invalid decimal string to integral should return null +SELECT CAST('123.a' AS int); +SELECT CAST('123.a' AS long); + +-- '-2147483648' is the smallest int value +SELECT CAST('-2147483648' AS int); +SELECT CAST('-2147483649' AS int); + +-- '2147483647' is the largest int value +SELECT CAST('2147483647' AS int); +SELECT CAST('2147483648' AS int); + +-- '-9223372036854775808' is the smallest long value +SELECT CAST('-9223372036854775808' AS long); +SELECT CAST('-9223372036854775809' AS long); + +-- '9223372036854775807' is the largest long value +SELECT CAST('9223372036854775807' AS long); +SELECT CAST('9223372036854775808' AS long); + +-- TODO: migrate all cast tests here. 
http://git-wip-us.apache.org/repos/asf/spark/blob/2c2ca894/sql/core/src/test/resources/sql-tests/results/cast.sql.out ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/results/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/cast.sql.out new file mode 100644 index 0000000..bfa29d7 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/cast.sql.out @@ -0,0 +1,178 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 22 + + +-- !query 0 +SELECT CAST('1.23' AS int) +-- !query 0 schema +struct<CAST(1.23 AS INT):int> +-- !query 0 output +1 + + +-- !query 1 +SELECT CAST('1.23' AS long) +-- !query 1 schema +struct<CAST(1.23 AS BIGINT):bigint> +-- !query 1 output +1 + + +-- !query 2 +SELECT CAST('-4.56' AS int) +-- !query 2 schema +struct<CAST(-4.56 AS INT):int> +-- !query 2 output +-4 + + +-- !query 3 +SELECT CAST('-4.56' AS long) +-- !query 3 schema +struct<CAST(-4.56 AS BIGINT):bigint> +-- !query 3 output +-4 + + +-- !query 4 +SELECT CAST('abc' AS int) +-- !query 4 schema +struct<CAST(abc AS INT):int> +-- !query 4 output +NULL + + +-- !query 5 +SELECT CAST('abc' AS long) +-- !query 5 schema +struct<CAST(abc AS BIGINT):bigint> +-- !query 5 output +NULL + + +-- !query 6 +SELECT CAST('1234567890123' AS int) +-- !query 6 schema +struct<CAST(1234567890123 AS INT):int> +-- !query 6 output +NULL + + +-- !query 7 +SELECT CAST('12345678901234567890123' AS long) +-- !query 7 schema +struct<CAST(12345678901234567890123 AS BIGINT):bigint> +-- !query 7 output +NULL + + +-- !query 8 +SELECT CAST('' AS int) +-- !query 8 schema +struct<CAST( AS INT):int> +-- !query 8 output +NULL + + +-- !query 9 +SELECT CAST('' AS long) +-- !query 9 schema +struct<CAST( AS BIGINT):bigint> +-- !query 9 output +NULL + + +-- !query 10 +SELECT CAST(NULL AS int) +-- !query 10 schema +struct<CAST(NULL AS INT):int> +-- !query 10 output +NULL + + +-- !query 11 +SELECT CAST(NULL AS long) +-- 
!query 11 schema +struct<CAST(NULL AS BIGINT):bigint> +-- !query 11 output +NULL + + +-- !query 12 +SELECT CAST('123.a' AS int) +-- !query 12 schema +struct<CAST(123.a AS INT):int> +-- !query 12 output +NULL + + +-- !query 13 +SELECT CAST('123.a' AS long) +-- !query 13 schema +struct<CAST(123.a AS BIGINT):bigint> +-- !query 13 output +NULL + + +-- !query 14 +SELECT CAST('-2147483648' AS int) +-- !query 14 schema +struct<CAST(-2147483648 AS INT):int> +-- !query 14 output +-2147483648 + + +-- !query 15 +SELECT CAST('-2147483649' AS int) +-- !query 15 schema +struct<CAST(-2147483649 AS INT):int> +-- !query 15 output +NULL + + +-- !query 16 +SELECT CAST('2147483647' AS int) +-- !query 16 schema +struct<CAST(2147483647 AS INT):int> +-- !query 16 output +2147483647 + + +-- !query 17 +SELECT CAST('2147483648' AS int) +-- !query 17 schema +struct<CAST(2147483648 AS INT):int> +-- !query 17 output +NULL + + +-- !query 18 +SELECT CAST('-9223372036854775808' AS long) +-- !query 18 schema +struct<CAST(-9223372036854775808 AS BIGINT):bigint> +-- !query 18 output +-9223372036854775808 + + +-- !query 19 +SELECT CAST('-9223372036854775809' AS long) +-- !query 19 schema +struct<CAST(-9223372036854775809 AS BIGINT):bigint> +-- !query 19 output +NULL + + +-- !query 20 +SELECT CAST('9223372036854775807' AS long) +-- !query 20 schema +struct<CAST(9223372036854775807 AS BIGINT):bigint> +-- !query 20 output +9223372036854775807 + + +-- !query 21 +SELECT CAST('9223372036854775808' AS long) +-- !query 21 schema +struct<CAST(9223372036854775808 AS BIGINT):bigint> +-- !query 21 output +NULL --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org