You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/01/13 06:52:41 UTC
spark git commit: [SPARK-19178][SQL] convert string of large numbers
to int should return null
Repository: spark
Updated Branches:
refs/heads/master 7f24a0b6c -> 6b34e745b
[SPARK-19178][SQL] convert string of large numbers to int should return null
## What changes were proposed in this pull request?
When we cast a string to an integral type, we first convert the string to `decimal(20, 0)`, so that a string in decimal format is truncated to an integral value, e.g. `CAST('1.2' AS int)` returns `1`.
However, this causes problems when the string holds a number too large for the target type, e.g. `CAST('1234567890123' AS int)` returns `1912276171` (the value silently wraps around), while Hive returns null, which is the expected result.
This is a long-standing bug (it seems to have been there since the first day Spark SQL was created). This PR fixes it by adding native support for converting a `UTF8String` to integral types.
## How was this patch tested?
new regression tests
Author: Wenchen Fan <we...@databricks.com>
Closes #16550 from cloud-fan/string-to-int.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6b34e745
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6b34e745
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6b34e745
Branch: refs/heads/master
Commit: 6b34e745bb8bdcf5a8bb78359fa39bbe8c6563cc
Parents: 7f24a0b
Author: Wenchen Fan <we...@databricks.com>
Authored: Thu Jan 12 22:52:34 2017 -0800
Committer: gatorsmile <ga...@gmail.com>
Committed: Thu Jan 12 22:52:34 2017 -0800
----------------------------------------------------------------------
.../apache/spark/unsafe/types/UTF8String.java | 184 +++++++++++++++++++
.../sql/catalyst/analysis/TypeCoercion.scala | 16 --
.../spark/sql/catalyst/expressions/Cast.scala | 18 +-
.../test/resources/sql-tests/inputs/cast.sql | 43 +++++
.../resources/sql-tests/results/cast.sql.out | 178 ++++++++++++++++++
5 files changed, 414 insertions(+), 25 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/6b34e745/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 0255f53..3800d53 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -835,6 +835,190 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
return fromString(sb.toString());
}
+ /**
+  * Converts a single ASCII byte into the decimal digit it represents.
+  *
+  * @param b the byte to interpret
+  * @return the numeric value (0-9) of {@code b}
+  * @throws NumberFormatException if {@code b} is not in the range '0'..'9'; the message is
+  *         the string form of this UTF8String
+  */
+ private int getDigit(byte b) {
+   // Reject anything outside '0'..'9' up front so the happy path is a plain subtraction.
+   if (b < '0' || b > '9') {
+     throw new NumberFormatException(toString());
+   }
+   return b - '0';
+ }
+
+ /**
+  * Parses this UTF8String as a base-10 long.
+  *
+  * An optional leading '+' or '-' is accepted. A '.' ends the integral part: everything after it
+  * is ignored for the value (so "1.23" parses to 1 and "-4.56" to -4), but every byte after the
+  * separator must still be an ASCII digit. The integral part may be empty, in which case it
+  * counts as 0.
+  *
+  * Note that, in this method we accumulate the result in negative format, and convert it to
+  * positive format at the end, if this string is not started with '-'. This is because the min
+  * value has a larger magnitude than the max value, e.g. Integer.MAX_VALUE is '2147483647' and
+  * Integer.MIN_VALUE is '-2147483648'.
+  *
+  * This code is mostly copied from LazyLong.parseLong in Hive.
+  *
+  * @return the parsed value, with any fractional part dropped
+  * @throws NumberFormatException if this string is empty, consists only of a sign, contains a
+  *         byte that is neither an ASCII digit nor the first '.', or does not fit in a long
+  */
+ public long toLong() {
+   if (numBytes == 0) {
+     throw new NumberFormatException("Empty string");
+   }
+
+   byte b = getByte(0);
+   final boolean negative = b == '-';
+   int offset = 0;
+   if (negative || b == '+') {
+     offset++;
+     // A lone sign is not a number.
+     if (numBytes == 1) {
+       throw new NumberFormatException(toString());
+     }
+   }
+
+   final byte separator = '.';
+   final int radix = 10;
+   final long stopValue = Long.MIN_VALUE / radix;
+   long result = 0;
+
+   while (offset < numBytes) {
+     b = getByte(offset);
+     offset++;
+     if (b == separator) {
+       // We allow decimals and will return a truncated integral in that case.
+       // Therefore we won't throw an exception here (checking the fractional
+       // part happens below.)
+       break;
+     }
+
+     int digit = getDigit(b);
+     // We are going to process the new digit and accumulate the result. However, before doing
+     // this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then
+     // result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
+     if (result < stopValue) {
+       throw new NumberFormatException(toString());
+     }
+
+     result = result * radix - digit;
+     // The previous result was not smaller than stopValue(Long.MIN_VALUE / radix), so
+     // `result * radix` itself cannot overflow; only subtracting the digit can, and in that case
+     // the value wraps around to a positive number. Hence `result > 0` is a sufficient overflow
+     // check.
+     if (result > 0) {
+       throw new NumberFormatException(toString());
+     }
+   }
+
+   // This is the case when we've encountered a decimal separator. The fractional
+   // part will not change the number, but we will verify that the fractional part
+   // is well formed.
+   while (offset < numBytes) {
+     // getDigit throws NumberFormatException for any byte that is not an ASCII digit; it is
+     // called only for that validation and its return value is intentionally discarded.
+     getDigit(getByte(offset));
+     offset++;
+   }
+
+   if (!negative) {
+     result = -result;
+     // Negating Long.MIN_VALUE yields Long.MIN_VALUE again, i.e. the positive value does not fit.
+     if (result < 0) {
+       throw new NumberFormatException(toString());
+     }
+   }
+
+   return result;
+ }
+
+ /**
+  * Parses this UTF8String as a base-10 int.
+  *
+  * An optional leading '+' or '-' is accepted. A '.' ends the integral part: everything after it
+  * is ignored for the value (so "1.23" parses to 1 and "-4.56" to -4), but every byte after the
+  * separator must still be an ASCII digit. The integral part may be empty, in which case it
+  * counts as 0.
+  *
+  * Note that, in this method we accumulate the result in negative format, and convert it to
+  * positive format at the end, if this string is not started with '-'. This is because the min
+  * value has a larger magnitude than the max value, e.g. Integer.MAX_VALUE is '2147483647' and
+  * Integer.MIN_VALUE is '-2147483648'.
+  *
+  * This code is mostly copied from LazyInt.parseInt in Hive.
+  *
+  * Note that, this method is almost same as `toLong`, but we leave it duplicated for performance
+  * reasons, like Hive does.
+  *
+  * @return the parsed value, with any fractional part dropped
+  * @throws NumberFormatException if this string is empty, consists only of a sign, contains a
+  *         byte that is neither an ASCII digit nor the first '.', or does not fit in an int
+  */
+ public int toInt() {
+   if (numBytes == 0) {
+     throw new NumberFormatException("Empty string");
+   }
+
+   byte b = getByte(0);
+   final boolean negative = b == '-';
+   int offset = 0;
+   if (negative || b == '+') {
+     offset++;
+     // A lone sign is not a number.
+     if (numBytes == 1) {
+       throw new NumberFormatException(toString());
+     }
+   }
+
+   final byte separator = '.';
+   final int radix = 10;
+   final int stopValue = Integer.MIN_VALUE / radix;
+   int result = 0;
+
+   while (offset < numBytes) {
+     b = getByte(offset);
+     offset++;
+     if (b == separator) {
+       // We allow decimals and will return a truncated integral in that case.
+       // Therefore we won't throw an exception here (checking the fractional
+       // part happens below.)
+       break;
+     }
+
+     int digit = getDigit(b);
+     // We are going to process the new digit and accumulate the result. However, before doing
+     // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then
+     // result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
+     if (result < stopValue) {
+       throw new NumberFormatException(toString());
+     }
+
+     result = result * radix - digit;
+     // The previous result was not smaller than stopValue(Integer.MIN_VALUE / radix), so
+     // `result * radix` itself cannot overflow; only subtracting the digit can, and in that case
+     // the value wraps around to a positive number. Hence `result > 0` is a sufficient overflow
+     // check.
+     if (result > 0) {
+       throw new NumberFormatException(toString());
+     }
+   }
+
+   // This is the case when we've encountered a decimal separator. The fractional
+   // part will not change the number, but we will verify that the fractional part
+   // is well formed.
+   while (offset < numBytes) {
+     // getDigit throws NumberFormatException for any byte that is not an ASCII digit; it is
+     // called only for that validation and its return value is intentionally discarded.
+     getDigit(getByte(offset));
+     offset++;
+   }
+
+   if (!negative) {
+     result = -result;
+     // Negating Integer.MIN_VALUE yields Integer.MIN_VALUE again, i.e. the positive value does
+     // not fit.
+     if (result < 0) {
+       throw new NumberFormatException(toString());
+     }
+   }
+
+   return result;
+ }
+
+ /**
+  * Parses this UTF8String as a short by first parsing it with {@code toInt()} and then
+  * verifying that the value fits.
+  *
+  * @return the parsed value
+  * @throws NumberFormatException if {@code toInt()} rejects the string, or if the parsed value
+  *         lies outside the range of short
+  */
+ public short toShort() {
+   final int parsed = toInt();
+   // Narrowing would silently drop high bits, so reject out-of-range values explicitly.
+   if (parsed < Short.MIN_VALUE || parsed > Short.MAX_VALUE) {
+     throw new NumberFormatException(toString());
+   }
+   return (short) parsed;
+ }
+
+ /**
+  * Parses this UTF8String as a byte by first parsing it with {@code toInt()} and then
+  * verifying that the value fits.
+  *
+  * @return the parsed value
+  * @throws NumberFormatException if {@code toInt()} rejects the string, or if the parsed value
+  *         lies outside the range of byte
+  */
+ public byte toByte() {
+   final int parsed = toInt();
+   // Narrowing would silently drop high bits, so reject out-of-range values explicitly.
+   if (parsed < Byte.MIN_VALUE || parsed > Byte.MAX_VALUE) {
+     throw new NumberFormatException(toString());
+   }
+   return (byte) parsed;
+ }
+
@Override
public String toString() {
return new String(getBytes(), StandardCharsets.UTF_8);
http://git-wip-us.apache.org/repos/asf/spark/blob/6b34e745/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index cd73f9c..5f72fa8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -51,7 +51,6 @@ object TypeCoercion {
PromoteStrings ::
DecimalPrecision ::
BooleanEquality ::
- StringToIntegralCasts ::
FunctionArgumentConversion ::
CaseWhenCoercion ::
IfCoercion ::
@@ -429,21 +428,6 @@ object TypeCoercion {
}
/**
- * When encountering a cast from a string representing a valid fractional number to an integral
- * type the jvm will throw a `java.lang.NumberFormatException`. Hive, in contrast, returns the
- * truncated version of this number.
- */
- object StringToIntegralCasts extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
- // Skip nodes who's children have not been resolved yet.
- case e if !e.childrenResolved => e
-
- case Cast(e @ StringType(), t: IntegralType) =>
- Cast(Cast(e, DecimalType.forType(LongType)), t)
- }
- }
-
- /**
* This ensure that the types for various functions are as expected.
*/
object FunctionArgumentConversion extends Rule[LogicalPlan] {
http://git-wip-us.apache.org/repos/asf/spark/blob/6b34e745/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 741730e..14e275b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -247,7 +247,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toString.toLong catch {
+ buildCast[UTF8String](_, s => try s.toLong catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -263,7 +263,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// IntConverter
private[this] def castToInt(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toString.toInt catch {
+ buildCast[UTF8String](_, s => try s.toInt catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -279,7 +279,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// ShortConverter
private[this] def castToShort(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toString.toShort catch {
+ buildCast[UTF8String](_, s => try s.toShort catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -295,7 +295,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
// ByteConverter
private[this] def castToByte(from: DataType): Any => Any = from match {
case StringType =>
- buildCast[UTF8String](_, s => try s.toString.toByte catch {
+ buildCast[UTF8String](_, s => try s.toByte catch {
case _: NumberFormatException => null
})
case BooleanType =>
@@ -498,7 +498,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
s"""
boolean $resultNull = $childNull;
${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)};
- if (!${childNull}) {
+ if (!$childNull) {
${cast(childPrim, resultPrim, resultNull)}
}
"""
@@ -705,7 +705,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
(c, evPrim, evNull) =>
s"""
try {
- $evPrim = Byte.valueOf($c.toString());
+ $evPrim = $c.toByte();
} catch (java.lang.NumberFormatException e) {
$evNull = true;
}
@@ -727,7 +727,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
(c, evPrim, evNull) =>
s"""
try {
- $evPrim = Short.valueOf($c.toString());
+ $evPrim = $c.toShort();
} catch (java.lang.NumberFormatException e) {
$evNull = true;
}
@@ -749,7 +749,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
(c, evPrim, evNull) =>
s"""
try {
- $evPrim = Integer.valueOf($c.toString());
+ $evPrim = $c.toInt();
} catch (java.lang.NumberFormatException e) {
$evNull = true;
}
@@ -771,7 +771,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
(c, evPrim, evNull) =>
s"""
try {
- $evPrim = Long.valueOf($c.toString());
+ $evPrim = $c.toLong();
} catch (java.lang.NumberFormatException e) {
$evNull = true;
}
http://git-wip-us.apache.org/repos/asf/spark/blob/6b34e745/sql/core/src/test/resources/sql-tests/inputs/cast.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cast.sql b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
new file mode 100644
index 0000000..5fae571
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/cast.sql
@@ -0,0 +1,43 @@
+-- cast string representing a valid fractional number to integral should truncate the number
+SELECT CAST('1.23' AS int);
+SELECT CAST('1.23' AS long);
+SELECT CAST('-4.56' AS int);
+SELECT CAST('-4.56' AS long);
+
+-- cast string which are not numbers to integral should return null
+SELECT CAST('abc' AS int);
+SELECT CAST('abc' AS long);
+
+-- cast string representing a very large number to integral should return null
+SELECT CAST('1234567890123' AS int);
+SELECT CAST('12345678901234567890123' AS long);
+
+-- cast empty string to integral should return null
+SELECT CAST('' AS int);
+SELECT CAST('' AS long);
+
+-- cast null to integral should return null
+SELECT CAST(NULL AS int);
+SELECT CAST(NULL AS long);
+
+-- cast invalid decimal string to integral should return null
+SELECT CAST('123.a' AS int);
+SELECT CAST('123.a' AS long);
+
+-- '-2147483648' is the smallest int value
+SELECT CAST('-2147483648' AS int);
+SELECT CAST('-2147483649' AS int);
+
+-- '2147483647' is the largest int value
+SELECT CAST('2147483647' AS int);
+SELECT CAST('2147483648' AS int);
+
+-- '-9223372036854775808' is the smallest long value
+SELECT CAST('-9223372036854775808' AS long);
+SELECT CAST('-9223372036854775809' AS long);
+
+-- '9223372036854775807' is the largest long value
+SELECT CAST('9223372036854775807' AS long);
+SELECT CAST('9223372036854775808' AS long);
+
+-- TODO: migrate all cast tests here.
http://git-wip-us.apache.org/repos/asf/spark/blob/6b34e745/sql/core/src/test/resources/sql-tests/results/cast.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/cast.sql.out b/sql/core/src/test/resources/sql-tests/results/cast.sql.out
new file mode 100644
index 0000000..bfa29d7
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/cast.sql.out
@@ -0,0 +1,178 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 22
+
+
+-- !query 0
+SELECT CAST('1.23' AS int)
+-- !query 0 schema
+struct<CAST(1.23 AS INT):int>
+-- !query 0 output
+1
+
+
+-- !query 1
+SELECT CAST('1.23' AS long)
+-- !query 1 schema
+struct<CAST(1.23 AS BIGINT):bigint>
+-- !query 1 output
+1
+
+
+-- !query 2
+SELECT CAST('-4.56' AS int)
+-- !query 2 schema
+struct<CAST(-4.56 AS INT):int>
+-- !query 2 output
+-4
+
+
+-- !query 3
+SELECT CAST('-4.56' AS long)
+-- !query 3 schema
+struct<CAST(-4.56 AS BIGINT):bigint>
+-- !query 3 output
+-4
+
+
+-- !query 4
+SELECT CAST('abc' AS int)
+-- !query 4 schema
+struct<CAST(abc AS INT):int>
+-- !query 4 output
+NULL
+
+
+-- !query 5
+SELECT CAST('abc' AS long)
+-- !query 5 schema
+struct<CAST(abc AS BIGINT):bigint>
+-- !query 5 output
+NULL
+
+
+-- !query 6
+SELECT CAST('1234567890123' AS int)
+-- !query 6 schema
+struct<CAST(1234567890123 AS INT):int>
+-- !query 6 output
+NULL
+
+
+-- !query 7
+SELECT CAST('12345678901234567890123' AS long)
+-- !query 7 schema
+struct<CAST(12345678901234567890123 AS BIGINT):bigint>
+-- !query 7 output
+NULL
+
+
+-- !query 8
+SELECT CAST('' AS int)
+-- !query 8 schema
+struct<CAST( AS INT):int>
+-- !query 8 output
+NULL
+
+
+-- !query 9
+SELECT CAST('' AS long)
+-- !query 9 schema
+struct<CAST( AS BIGINT):bigint>
+-- !query 9 output
+NULL
+
+
+-- !query 10
+SELECT CAST(NULL AS int)
+-- !query 10 schema
+struct<CAST(NULL AS INT):int>
+-- !query 10 output
+NULL
+
+
+-- !query 11
+SELECT CAST(NULL AS long)
+-- !query 11 schema
+struct<CAST(NULL AS BIGINT):bigint>
+-- !query 11 output
+NULL
+
+
+-- !query 12
+SELECT CAST('123.a' AS int)
+-- !query 12 schema
+struct<CAST(123.a AS INT):int>
+-- !query 12 output
+NULL
+
+
+-- !query 13
+SELECT CAST('123.a' AS long)
+-- !query 13 schema
+struct<CAST(123.a AS BIGINT):bigint>
+-- !query 13 output
+NULL
+
+
+-- !query 14
+SELECT CAST('-2147483648' AS int)
+-- !query 14 schema
+struct<CAST(-2147483648 AS INT):int>
+-- !query 14 output
+-2147483648
+
+
+-- !query 15
+SELECT CAST('-2147483649' AS int)
+-- !query 15 schema
+struct<CAST(-2147483649 AS INT):int>
+-- !query 15 output
+NULL
+
+
+-- !query 16
+SELECT CAST('2147483647' AS int)
+-- !query 16 schema
+struct<CAST(2147483647 AS INT):int>
+-- !query 16 output
+2147483647
+
+
+-- !query 17
+SELECT CAST('2147483648' AS int)
+-- !query 17 schema
+struct<CAST(2147483648 AS INT):int>
+-- !query 17 output
+NULL
+
+
+-- !query 18
+SELECT CAST('-9223372036854775808' AS long)
+-- !query 18 schema
+struct<CAST(-9223372036854775808 AS BIGINT):bigint>
+-- !query 18 output
+-9223372036854775808
+
+
+-- !query 19
+SELECT CAST('-9223372036854775809' AS long)
+-- !query 19 schema
+struct<CAST(-9223372036854775809 AS BIGINT):bigint>
+-- !query 19 output
+NULL
+
+
+-- !query 20
+SELECT CAST('9223372036854775807' AS long)
+-- !query 20 schema
+struct<CAST(9223372036854775807 AS BIGINT):bigint>
+-- !query 20 output
+9223372036854775807
+
+
+-- !query 21
+SELECT CAST('9223372036854775808' AS long)
+-- !query 21 schema
+struct<CAST(9223372036854775808 AS BIGINT):bigint>
+-- !query 21 output
+NULL
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org