Repository: spark
Updated Branches:
  refs/heads/branch-2.0 449231c65 -> f56819f9b


[SPARK-19178][SQL] convert string of large numbers to int should return null

## What changes were proposed in this pull request?

When we convert a string to an integral type, we first convert that string to
`decimal(20, 0)`, so that a string in decimal format can be turned into a
truncated integral value, e.g. `CAST('1.2' AS int)` returns `1`.

However, this causes problems when the string holds a number too large for the
target type, e.g. `CAST('1234567890123' AS int)` returns `1912276171`, whereas
Hive returns null, as expected.

This is a long-standing bug (it seems to have existed since the first day Spark
SQL was created). This PR fixes it by adding native support for converting
`UTF8String` to integral types.

## How was this patch tested?

New regression tests.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #16550 from cloud-fan/string-to-int.

(cherry picked from commit 6b34e745bb8bdcf5a8bb78359fa39bbe8c6563cc)
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f56819f9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f56819f9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f56819f9

Branch: refs/heads/branch-2.0
Commit: f56819f9bacb2c3e13f148dbd86589b7248352bb
Parents: 449231c
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Thu Jan 12 22:52:34 2017 -0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri Jan 13 18:48:21 2017 +0800

----------------------------------------------------------------------
 .../apache/spark/unsafe/types/UTF8String.java   | 184 +++++++++++++++++++
 .../sql/catalyst/analysis/TypeCoercion.scala    |  16 --
 .../spark/sql/catalyst/expressions/Cast.scala   |  18 +-
 .../test/resources/sql-tests/inputs/cast.sql    |  43 +++++
 .../resources/sql-tests/results/cast.sql.out    | 178 ++++++++++++++++++
 5 files changed, 414 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f56819f9/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index ce83516..d330c1d 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -816,6 +816,190 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
     return fromString(sb.toString());
   }
 
+  private int getDigit(byte b) {
+    if (b < '0' || b > '9') {
+      throw new NumberFormatException(toString());
+    }
+    return b - '0';
+  }
+
+  /**
+   * Parses this UTF8String to long; a fractional part after '.' is validated and truncated.
+   *
+   * Note that, in this method we accumulate the result in negative format, and convert it to
+   * positive format at the end, if this string is not started with '-'. This is because min value
+   * is bigger than max value in digits, e.g. Long.MAX_VALUE is '9223372036854775807' and
+   * Long.MIN_VALUE is '-9223372036854775808'. Mostly copied from LazyLong.parseLong in Hive.
+   *
+   * @throws NumberFormatException if empty, only a sign, malformed, or out of the long range.
+   */
+  public long toLong() {
+    if (numBytes == 0) {
+      throw new NumberFormatException("Empty string");
+    }
+
+    byte b = getByte(0);
+    final boolean negative = b == '-';
+    int offset = 0;
+    if (negative || b == '+') {
+      offset++;
+      if (numBytes == 1) {
+        throw new NumberFormatException(toString());
+      }
+    }
+
+    final byte separator = '.';
+    final int radix = 10;
+    final long stopValue = Long.MIN_VALUE / radix;
+    long result = 0;
+
+    while (offset < numBytes) {
+      b = getByte(offset);
+      offset++;
+      if (b == separator) {
+        // We allow decimals and will return a truncated integral in that case.
+        // Therefore we won't throw an exception here (checking the fractional
+        // part happens below.)
+        break;
+      }
+
+      int digit = getDigit(b);
+      // We are going to process the new digit and accumulate the result. However, before doing
+      // this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then
+      // result * 10 will definitely be smaller than Long.MIN_VALUE, so we stop and throw.
+      if (result < stopValue) {
+        throw new NumberFormatException(toString());
+      }
+
+      result = result * radix - digit;
+      // Since the previous result is greater than or equal to stopValue(Long.MIN_VALUE / radix),
+      // we can just use `result > 0` to check overflow: it can only wrap around to a positive
+      // value. If result overflows, we should stop and throw exception.
+      if (result > 0) {
+        throw new NumberFormatException(toString());
+      }
+    }
+
+    // This is the case when we've encountered a decimal separator. The fractional part will
+    // not change the number, but we still verify that it consists only of ASCII digits, so
+    // that inputs like '123.a' are rejected.
+    for (; offset < numBytes; offset++) {
+      b = getByte(offset);
+      if (b < '0' || b > '9') {
+        throw new NumberFormatException(toString());
+      }
+    }
+
+    if (!negative) {
+      result = -result;
+      if (result < 0) {
+        throw new NumberFormatException(toString());
+      }
+    }
+
+    return result;
+  }
+
+  /**
+   * Parses this UTF8String to int; a fractional part after '.' is validated and truncated.
+   *
+   * Note that, in this method we accumulate the result in negative format, and convert it to
+   * positive format at the end, if this string is not started with '-'. This is because min value
+   * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and
+   * Integer.MIN_VALUE is '-2147483648'.
+   *
+   * This code is mostly copied from LazyInt.parseInt in Hive. It is almost the same as `toLong`,
+   * but we leave it duplicated for performance reasons, like Hive does.
+   *
+   * @throws NumberFormatException if empty, only a sign, malformed, or out of the int range.
+   */
+  public int toInt() {
+    if (numBytes == 0) {
+      throw new NumberFormatException("Empty string");
+    }
+
+    byte b = getByte(0);
+    final boolean negative = b == '-';
+    int offset = 0;
+    if (negative || b == '+') {
+      offset++;
+      if (numBytes == 1) {
+        throw new NumberFormatException(toString());
+      }
+    }
+
+    final byte separator = '.';
+    final int radix = 10;
+    final int stopValue = Integer.MIN_VALUE / radix;
+    int result = 0;
+
+    while (offset < numBytes) {
+      b = getByte(offset);
+      offset++;
+      if (b == separator) {
+        // We allow decimals and will return a truncated integral in that case.
+        // Therefore we won't throw an exception here (checking the fractional
+        // part happens below.)
+        break;
+      }
+
+      int digit = getDigit(b);
+      // We are going to process the new digit and accumulate the result. However, before doing
+      // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then
+      // result * 10 will definitely be smaller than Integer.MIN_VALUE, so we stop and throw.
+      if (result < stopValue) {
+        throw new NumberFormatException(toString());
+      }
+
+      result = result * radix - digit;
+      // Since the previous result is greater than or equal to
+      // stopValue(Integer.MIN_VALUE / radix), we can just use `result > 0` to check overflow: it
+      // can only wrap around to a positive value. If result overflows, we should stop and throw.
+      if (result > 0) {
+        throw new NumberFormatException(toString());
+      }
+    }
+
+    // This is the case when we've encountered a decimal separator. The fractional part will
+    // not change the number, but we still verify that it consists only of ASCII digits, so
+    // that inputs like '123.a' are rejected.
+    for (; offset < numBytes; offset++) {
+      b = getByte(offset);
+      if (b < '0' || b > '9') {
+        throw new NumberFormatException(toString());
+      }
+    }
+
+    if (!negative) {
+      result = -result;
+      if (result < 0) {
+        throw new NumberFormatException(toString());
+      }
+    }
+
+    return result;
+  }
+
+  /** Parses this UTF8String to short via {@link #toInt()}, rejecting out-of-range values. */
+  public short toShort() {
+    final int parsed = toInt();
+    final short narrowed = (short) parsed;
+    if (narrowed == parsed) {
+      return narrowed;
+    }
+    throw new NumberFormatException(toString());
+  }
+
+  /** Parses this UTF8String to byte via {@link #toInt()}, rejecting out-of-range values. */
+  public byte toByte() {
+    final int parsed = toInt();
+    final byte narrowed = (byte) parsed;
+    if (narrowed == parsed) {
+      return narrowed;
+    }
+    throw new NumberFormatException(toString());
+  }
+
   @Override
   public String toString() {
     return new String(getBytes(), StandardCharsets.UTF_8);

http://git-wip-us.apache.org/repos/asf/spark/blob/f56819f9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 193c3ec..7e84dc8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -51,7 +51,6 @@ object TypeCoercion {
       PromoteStrings ::
       DecimalPrecision ::
       BooleanEquality ::
-      StringToIntegralCasts ::
       FunctionArgumentConversion ::
       CaseWhenCoercion ::
       IfCoercion ::
@@ -426,21 +425,6 @@ object TypeCoercion {
   }
 
   /**
-   * When encountering a cast from a string representing a valid fractional number to an integral
-   * type the jvm will throw a `java.lang.NumberFormatException`.  Hive, in contrast, returns the
-   * truncated version of this number.
-   */
-  object StringToIntegralCasts extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
-      // Skip nodes who's children have not been resolved yet.
-      case e if !e.childrenResolved => e
-
-      case Cast(e @ StringType(), t: IntegralType) =>
-        Cast(Cast(e, DecimalType.forType(LongType)), t)
-    }
-  }
-
-  /**
    * This ensure that the types for various functions are as expected.
    */
   object FunctionArgumentConversion extends Rule[LogicalPlan] {

http://git-wip-us.apache.org/repos/asf/spark/blob/f56819f9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a53468c..885b394 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -243,7 +243,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
   // LongConverter
   private[this] def castToLong(from: DataType): Any => Any = from match {
     case StringType =>
-      buildCast[UTF8String](_, s => try s.toString.toLong catch {
+      buildCast[UTF8String](_, s => try s.toLong catch {
         case _: NumberFormatException => null
       })
     case BooleanType =>
@@ -259,7 +259,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
   // IntConverter
   private[this] def castToInt(from: DataType): Any => Any = from match {
     case StringType =>
-      buildCast[UTF8String](_, s => try s.toString.toInt catch {
+      buildCast[UTF8String](_, s => try s.toInt catch {
         case _: NumberFormatException => null
       })
     case BooleanType =>
@@ -275,7 +275,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
   // ShortConverter
   private[this] def castToShort(from: DataType): Any => Any = from match {
     case StringType =>
-      buildCast[UTF8String](_, s => try s.toString.toShort catch {
+      buildCast[UTF8String](_, s => try s.toShort catch {
         case _: NumberFormatException => null
       })
     case BooleanType =>
@@ -291,7 +291,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
   // ByteConverter
   private[this] def castToByte(from: DataType): Any => Any = from match {
     case StringType =>
-      buildCast[UTF8String](_, s => try s.toString.toByte catch {
+      buildCast[UTF8String](_, s => try s.toByte catch {
         case _: NumberFormatException => null
       })
     case BooleanType =>
@@ -494,7 +494,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
     s"""
       boolean $resultNull = $childNull;
       ${ctx.javaType(resultType)} $resultPrim = 
${ctx.defaultValue(resultType)};
-      if (!${childNull}) {
+      if (!$childNull) {
         ${cast(childPrim, resultPrim, resultNull)}
       }
     """
@@ -701,7 +701,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       (c, evPrim, evNull) =>
         s"""
           try {
-            $evPrim = Byte.valueOf($c.toString());
+            $evPrim = $c.toByte();
           } catch (java.lang.NumberFormatException e) {
             $evNull = true;
           }
@@ -723,7 +723,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       (c, evPrim, evNull) =>
         s"""
           try {
-            $evPrim = Short.valueOf($c.toString());
+            $evPrim = $c.toShort();
           } catch (java.lang.NumberFormatException e) {
             $evNull = true;
           }
@@ -745,7 +745,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       (c, evPrim, evNull) =>
         s"""
           try {
-            $evPrim = Integer.valueOf($c.toString());
+            $evPrim = $c.toInt();
           } catch (java.lang.NumberFormatException e) {
             $evNull = true;
           }
@@ -767,7 +767,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       (c, evPrim, evNull) =>
         s"""
           try {
-            $evPrim = Long.valueOf($c.toString());
+            $evPrim = $c.toLong();
           } catch (java.lang.NumberFormatException e) {
             $evNull = true;
           }

http://git-wip-us.apache.org/repos/asf/spark/blob/f56819f9/sql/core/src/test/resources/sql-tests/inputs/cast.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cast.sql b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
new file mode 100644
index 0000000..5fae571
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
@@ -0,0 +1,43 @@
+-- casting a string representing a valid fractional number to integral should truncate it
+SELECT CAST('1.23' AS int);
+SELECT CAST('1.23' AS long);
+SELECT CAST('-4.56' AS int);
+SELECT CAST('-4.56' AS long);
+
+-- casting strings that are not numbers to integral should return null
+SELECT CAST('abc' AS int);
+SELECT CAST('abc' AS long);
+
+-- casting a string representing a number too large for the target type should return null
+SELECT CAST('1234567890123' AS int);
+SELECT CAST('12345678901234567890123' AS long);
+
+-- casting an empty string to integral should return null
+SELECT CAST('' AS int);
+SELECT CAST('' AS long);
+
+-- casting null to integral should return null
+SELECT CAST(NULL AS int);
+SELECT CAST(NULL AS long);
+
+-- casting an invalid decimal string (non-digit fractional part) to integral should return null
+SELECT CAST('123.a' AS int);
+SELECT CAST('123.a' AS long);
+
+-- '-2147483648' is the smallest int value, anything smaller should return null
+SELECT CAST('-2147483648' AS int);
+SELECT CAST('-2147483649' AS int);
+
+-- '2147483647' is the largest int value, anything larger should return null
+SELECT CAST('2147483647' AS int);
+SELECT CAST('2147483648' AS int);
+
+-- '-9223372036854775808' is the smallest long value, anything smaller should return null
+SELECT CAST('-9223372036854775808' AS long);
+SELECT CAST('-9223372036854775809' AS long);
+
+-- '9223372036854775807' is the largest long value, anything larger should return null
+SELECT CAST('9223372036854775807' AS long);
+SELECT CAST('9223372036854775808' AS long);
+
+-- TODO: migrate all cast tests here.

http://git-wip-us.apache.org/repos/asf/spark/blob/f56819f9/sql/core/src/test/resources/sql-tests/results/cast.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/cast.sql.out
new file mode 100644
index 0000000..bfa29d7
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/cast.sql.out
@@ -0,0 +1,178 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 22
+
+
+-- !query 0
+SELECT CAST('1.23' AS int)
+-- !query 0 schema
+struct<CAST(1.23 AS INT):int>
+-- !query 0 output
+1
+
+
+-- !query 1
+SELECT CAST('1.23' AS long)
+-- !query 1 schema
+struct<CAST(1.23 AS BIGINT):bigint>
+-- !query 1 output
+1
+
+
+-- !query 2
+SELECT CAST('-4.56' AS int)
+-- !query 2 schema
+struct<CAST(-4.56 AS INT):int>
+-- !query 2 output
+-4
+
+
+-- !query 3
+SELECT CAST('-4.56' AS long)
+-- !query 3 schema
+struct<CAST(-4.56 AS BIGINT):bigint>
+-- !query 3 output
+-4
+
+
+-- !query 4
+SELECT CAST('abc' AS int)
+-- !query 4 schema
+struct<CAST(abc AS INT):int>
+-- !query 4 output
+NULL
+
+
+-- !query 5
+SELECT CAST('abc' AS long)
+-- !query 5 schema
+struct<CAST(abc AS BIGINT):bigint>
+-- !query 5 output
+NULL
+
+
+-- !query 6
+SELECT CAST('1234567890123' AS int)
+-- !query 6 schema
+struct<CAST(1234567890123 AS INT):int>
+-- !query 6 output
+NULL
+
+
+-- !query 7
+SELECT CAST('12345678901234567890123' AS long)
+-- !query 7 schema
+struct<CAST(12345678901234567890123 AS BIGINT):bigint>
+-- !query 7 output
+NULL
+
+
+-- !query 8
+SELECT CAST('' AS int)
+-- !query 8 schema
+struct<CAST( AS INT):int>
+-- !query 8 output
+NULL
+
+
+-- !query 9
+SELECT CAST('' AS long)
+-- !query 9 schema
+struct<CAST( AS BIGINT):bigint>
+-- !query 9 output
+NULL
+
+
+-- !query 10
+SELECT CAST(NULL AS int)
+-- !query 10 schema
+struct<CAST(NULL AS INT):int>
+-- !query 10 output
+NULL
+
+
+-- !query 11
+SELECT CAST(NULL AS long)
+-- !query 11 schema
+struct<CAST(NULL AS BIGINT):bigint>
+-- !query 11 output
+NULL
+
+
+-- !query 12
+SELECT CAST('123.a' AS int)
+-- !query 12 schema
+struct<CAST(123.a AS INT):int>
+-- !query 12 output
+NULL
+
+
+-- !query 13
+SELECT CAST('123.a' AS long)
+-- !query 13 schema
+struct<CAST(123.a AS BIGINT):bigint>
+-- !query 13 output
+NULL
+
+
+-- !query 14
+SELECT CAST('-2147483648' AS int)
+-- !query 14 schema
+struct<CAST(-2147483648 AS INT):int>
+-- !query 14 output
+-2147483648
+
+
+-- !query 15
+SELECT CAST('-2147483649' AS int)
+-- !query 15 schema
+struct<CAST(-2147483649 AS INT):int>
+-- !query 15 output
+NULL
+
+
+-- !query 16
+SELECT CAST('2147483647' AS int)
+-- !query 16 schema
+struct<CAST(2147483647 AS INT):int>
+-- !query 16 output
+2147483647
+
+
+-- !query 17
+SELECT CAST('2147483648' AS int)
+-- !query 17 schema
+struct<CAST(2147483648 AS INT):int>
+-- !query 17 output
+NULL
+
+
+-- !query 18
+SELECT CAST('-9223372036854775808' AS long)
+-- !query 18 schema
+struct<CAST(-9223372036854775808 AS BIGINT):bigint>
+-- !query 18 output
+-9223372036854775808
+
+
+-- !query 19
+SELECT CAST('-9223372036854775809' AS long)
+-- !query 19 schema
+struct<CAST(-9223372036854775809 AS BIGINT):bigint>
+-- !query 19 output
+NULL
+
+
+-- !query 20
+SELECT CAST('9223372036854775807' AS long)
+-- !query 20 schema
+struct<CAST(9223372036854775807 AS BIGINT):bigint>
+-- !query 20 output
+9223372036854775807
+
+
+-- !query 21
+SELECT CAST('9223372036854775808' AS long)
+-- !query 21 schema
+struct<CAST(9223372036854775808 AS BIGINT):bigint>
+-- !query 21 output
+NULL


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to