You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/03/08 04:19:34 UTC

spark git commit: [SPARK-19843][SQL] UTF8String => (int / long) conversion expensive for invalid inputs

Repository: spark
Updated Branches:
  refs/heads/master 47b2f68a8 -> c96d14aba


[SPARK-19843][SQL] UTF8String => (int / long) conversion expensive for invalid inputs

## What changes were proposed in this pull request?

Jira : https://issues.apache.org/jira/browse/SPARK-19843

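Created wrapper classes (`IntWrapper`, `LongWrapper`) to hold the primitive result of parsing. The parsing methods (`toInt`, `toLong`, `toShort`, `toByte`) now return a boolean: `true` on success, with the parsed value stored in the wrapper, and `false` on invalid input instead of throwing a `NumberFormatException`.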

## How was this patch tested?

- Added new unit tests
- Ran a prod job which had conversion from string -> int and verified the outputs

## Performance

Negligible difference (within benchmark noise) when all strings are valid integers

```
conversion to int:       Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
--------------------------------------------------------------------------------
trunk                         502 /  522         33.4          29.9       1.0X
SPARK-19843                   493 /  503         34.0          29.4       1.0X
```

Huge gain when all strings are invalid integers
```
conversion to int:      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------
trunk                     33913 / 34219          0.5        2021.4       1.0X
SPARK-19843                  154 /  162        108.8           9.2     220.0X
```

Author: Tejas Patil <te...@fb.com>

Closes #17184 from tejasapatil/SPARK-19843_is_numeric_maybe.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c96d14ab
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c96d14ab
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c96d14ab

Branch: refs/heads/master
Commit: c96d14abae5962a7b15239319c2a151b95f7db94
Parents: 47b2f68
Author: Tejas Patil <te...@fb.com>
Authored: Tue Mar 7 20:19:30 2017 -0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Mar 7 20:19:30 2017 -0800

----------------------------------------------------------------------
 .../apache/spark/unsafe/types/UTF8String.java   | 120 ++++++++++-------
 .../spark/unsafe/types/UTF8StringSuite.java     | 128 ++++++++++++++++++-
 .../spark/sql/catalyst/expressions/Cast.scala   |  81 +++++++-----
 3 files changed, 247 insertions(+), 82 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c96d14ab/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 10a7cb1..7abe0fa 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -850,11 +850,8 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
     return fromString(sb.toString());
   }
 
-  private int getDigit(byte b) {
-    if (b >= '0' && b <= '9') {
-      return b - '0';
-    }
-    throw new NumberFormatException(toString());
+  public static class LongWrapper {
+    public long value = 0;
   }
 
   /**
@@ -862,14 +859,18 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
    *
    * Note that, in this method we accumulate the result in negative format, and convert it to
    * positive format at the end, if this string is not started with '-'. This is because min value
-   * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and
-   * Integer.MIN_VALUE is '-2147483648'.
+   * is bigger than max value in digits, e.g. Long.MAX_VALUE is '9223372036854775807' and
+   * Long.MIN_VALUE is '-9223372036854775808'.
    *
    * This code is mostly copied from LazyLong.parseLong in Hive.
+   *
+   * @param toLongResult If a valid `long` was parsed from this UTF8String, then its value would
+   *                     be set in `toLongResult`
+   * @return true if the parsing was successful else false
    */
-  public long toLong() {
+  public boolean toLong(LongWrapper toLongResult) {
     if (numBytes == 0) {
-      throw new NumberFormatException("Empty string");
+      return false;
     }
 
     byte b = getByte(0);
@@ -878,7 +879,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
     if (negative || b == '+') {
       offset++;
       if (numBytes == 1) {
-        throw new NumberFormatException(toString());
+        return false;
       }
     }
 
@@ -897,20 +898,25 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
         break;
       }
 
-      int digit = getDigit(b);
+      int digit;
+      if (b >= '0' && b <= '9') {
+        digit = b - '0';
+      } else {
+        return false;
+      }
+
       // We are going to process the new digit and accumulate the result. However, before doing
       // this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then
-      // result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
+      // result * 10 will definitely be smaller than minValue, and we can stop.
       if (result < stopValue) {
-        throw new NumberFormatException(toString());
+        return false;
       }
 
       result = result * radix - digit;
       // Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we
-      // can just use `result > 0` to check overflow. If result overflows, we should stop and throw
-      // exception.
+      // can just use `result > 0` to check overflow. If result overflows, we should stop.
       if (result > 0) {
-        throw new NumberFormatException(toString());
+        return false;
       }
     }
 
@@ -918,8 +924,9 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
     // part will not change the number, but we will verify that the fractional part
     // is well formed.
     while (offset < numBytes) {
-      if (getDigit(getByte(offset)) == -1) {
-        throw new NumberFormatException(toString());
+      byte currentByte = getByte(offset);
+      if (currentByte < '0' || currentByte > '9') {
+        return false;
       }
       offset++;
     }
@@ -927,11 +934,16 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
     if (!negative) {
       result = -result;
       if (result < 0) {
-        throw new NumberFormatException(toString());
+        return false;
       }
     }
 
-    return result;
+    toLongResult.value = result;
+    return true;
+  }
+
+  public static class IntWrapper {
+    public int value = 0;
   }
 
   /**
@@ -946,10 +958,14 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
    *
    * Note that, this method is almost same as `toLong`, but we leave it duplicated for performance
    * reasons, like Hive does.
+   *
+   * @param intWrapper If a valid `int` was parsed from this UTF8String, then its value would
+   *                    be set in `intWrapper`
+   * @return true if the parsing was successful else false
    */
-  public int toInt() {
+  public boolean toInt(IntWrapper intWrapper) {
     if (numBytes == 0) {
-      throw new NumberFormatException("Empty string");
+      return false;
     }
 
     byte b = getByte(0);
@@ -958,7 +974,7 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
     if (negative || b == '+') {
       offset++;
       if (numBytes == 1) {
-        throw new NumberFormatException(toString());
+        return false;
       }
     }
 
@@ -977,20 +993,25 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
         break;
       }
 
-      int digit = getDigit(b);
+      int digit;
+      if (b >= '0' && b <= '9') {
+        digit = b - '0';
+      } else {
+        return false;
+      }
+
       // We are going to process the new digit and accumulate the result. However, before doing
       // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then
-      // result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
+      // result * 10 will definitely be smaller than minValue, and we can stop
       if (result < stopValue) {
-        throw new NumberFormatException(toString());
+        return false;
       }
 
       result = result * radix - digit;
       // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix),
-      // we can just use `result > 0` to check overflow. If result overflows, we should stop and
-      // throw exception.
+      // we can just use `result > 0` to check overflow. If result overflows, we should stop
       if (result > 0) {
-        throw new NumberFormatException(toString());
+        return false;
       }
     }
 
@@ -998,8 +1019,9 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
     // part will not change the number, but we will verify that the fractional part
     // is well formed.
     while (offset < numBytes) {
-      if (getDigit(getByte(offset)) == -1) {
-        throw new NumberFormatException(toString());
+      byte currentByte = getByte(offset);
+      if (currentByte < '0' || currentByte > '9') {
+        return false;
       }
       offset++;
     }
@@ -1007,31 +1029,33 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
     if (!negative) {
       result = -result;
       if (result < 0) {
-        throw new NumberFormatException(toString());
+        return false;
       }
     }
-
-    return result;
+    intWrapper.value = result;
+    return true;
   }
 
-  public short toShort() {
-    int intValue = toInt();
-    short result = (short) intValue;
-    if (result != intValue) {
-      throw new NumberFormatException(toString());
+  public boolean toShort(IntWrapper intWrapper) {
+    if (toInt(intWrapper)) {
+      int intValue = intWrapper.value;
+      short result = (short) intValue;
+      if (result == intValue) {
+        return true;
+      }
     }
-
-    return result;
+    return false;
   }
 
-  public byte toByte() {
-    int intValue = toInt();
-    byte result = (byte) intValue;
-    if (result != intValue) {
-      throw new NumberFormatException(toString());
+  public boolean toByte(IntWrapper intWrapper) {
+    if (toInt(intWrapper)) {
+      int intValue = intWrapper.value;
+      byte result = (byte) intValue;
+      if (result == intValue) {
+        return true;
+      }
     }
-
-    return result;
+    return false;
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/spark/blob/c96d14ab/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 6f6e0ef..c376371 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -22,9 +22,7 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.nio.charset.StandardCharsets;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.HashSet;
+import java.util.*;
 
 import com.google.common.collect.ImmutableMap;
 import org.apache.spark.unsafe.Platform;
@@ -608,4 +606,128 @@ public class UTF8StringSuite {
         .writeTo(outputStream);
     assertEquals("\u5927\u5343\u4e16\u754c", outputStream.toString("UTF-8"));
   }
+
+  @Test
+  public void testToShort() throws IOException {
+    Map<String, Short> inputToExpectedOutput = new HashMap<>();
+    inputToExpectedOutput.put("1", (short) 1);
+    inputToExpectedOutput.put("+1", (short) 1);
+    inputToExpectedOutput.put("-1", (short) -1);
+    inputToExpectedOutput.put("0", (short) 0);
+    inputToExpectedOutput.put("1111.12345678901234567890", (short) 1111);
+    inputToExpectedOutput.put(String.valueOf(Short.MAX_VALUE), Short.MAX_VALUE);
+    inputToExpectedOutput.put(String.valueOf(Short.MIN_VALUE), Short.MIN_VALUE);
+
+    Random rand = new Random();
+    for (int i = 0; i < 10; i++) {
+      short value = (short) rand.nextInt();
+      inputToExpectedOutput.put(String.valueOf(value), value);
+    }
+
+    IntWrapper wrapper = new IntWrapper();
+    for (Map.Entry<String, Short> entry : inputToExpectedOutput.entrySet()) {
+      assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toShort(wrapper));
+      assertEquals((short) entry.getValue(), wrapper.value);
+    }
+
+    List<String> negativeInputs =
+      Arrays.asList("", "  ", "null", "NULL", "\n", "~1212121", "3276700");
+
+    for (String negativeInput : negativeInputs) {
+      assertFalse(negativeInput, UTF8String.fromString(negativeInput).toShort(wrapper));
+    }
+  }
+
+  @Test
+  public void testToByte() throws IOException {
+    Map<String, Byte> inputToExpectedOutput = new HashMap<>();
+    inputToExpectedOutput.put("1", (byte) 1);
+    inputToExpectedOutput.put("+1",(byte)  1);
+    inputToExpectedOutput.put("-1", (byte)  -1);
+    inputToExpectedOutput.put("0", (byte)  0);
+    inputToExpectedOutput.put("111.12345678901234567890", (byte) 111);
+    inputToExpectedOutput.put(String.valueOf(Byte.MAX_VALUE), Byte.MAX_VALUE);
+    inputToExpectedOutput.put(String.valueOf(Byte.MIN_VALUE), Byte.MIN_VALUE);
+
+    Random rand = new Random();
+    for (int i = 0; i < 10; i++) {
+      byte value = (byte) rand.nextInt();
+      inputToExpectedOutput.put(String.valueOf(value), value);
+    }
+
+    IntWrapper intWrapper = new IntWrapper();
+    for (Map.Entry<String, Byte> entry : inputToExpectedOutput.entrySet()) {
+      assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toByte(intWrapper));
+      assertEquals((byte) entry.getValue(), intWrapper.value);
+    }
+
+    List<String> negativeInputs =
+      Arrays.asList("", "  ", "null", "NULL", "\n", "~1212121", "12345678901234567890");
+
+    for (String negativeInput : negativeInputs) {
+      assertFalse(negativeInput, UTF8String.fromString(negativeInput).toByte(intWrapper));
+    }
+  }
+
+  @Test
+  public void testToInt() throws IOException {
+    Map<String, Integer> inputToExpectedOutput = new HashMap<>();
+    inputToExpectedOutput.put("1", 1);
+    inputToExpectedOutput.put("+1", 1);
+    inputToExpectedOutput.put("-1", -1);
+    inputToExpectedOutput.put("0", 0);
+    inputToExpectedOutput.put("11111.1234567", 11111);
+    inputToExpectedOutput.put(String.valueOf(Integer.MAX_VALUE), Integer.MAX_VALUE);
+    inputToExpectedOutput.put(String.valueOf(Integer.MIN_VALUE), Integer.MIN_VALUE);
+
+    Random rand = new Random();
+    for (int i = 0; i < 10; i++) {
+      int value = rand.nextInt();
+      inputToExpectedOutput.put(String.valueOf(value), value);
+    }
+
+    IntWrapper intWrapper = new IntWrapper();
+    for (Map.Entry<String, Integer> entry : inputToExpectedOutput.entrySet()) {
+      assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toInt(intWrapper));
+      assertEquals((int) entry.getValue(), intWrapper.value);
+    }
+
+    List<String> negativeInputs =
+      Arrays.asList("", "  ", "null", "NULL", "\n", "~1212121", "12345678901234567890");
+
+    for (String negativeInput : negativeInputs) {
+      assertFalse(negativeInput, UTF8String.fromString(negativeInput).toInt(intWrapper));
+    }
+  }
+
+  @Test
+  public void testToLong() throws IOException {
+    Map<String, Long> inputToExpectedOutput = new HashMap<>();
+    inputToExpectedOutput.put("1", 1L);
+    inputToExpectedOutput.put("+1", 1L);
+    inputToExpectedOutput.put("-1", -1L);
+    inputToExpectedOutput.put("0", 0L);
+    inputToExpectedOutput.put("1076753423.12345678901234567890", 1076753423L);
+    inputToExpectedOutput.put(String.valueOf(Long.MAX_VALUE), Long.MAX_VALUE);
+    inputToExpectedOutput.put(String.valueOf(Long.MIN_VALUE), Long.MIN_VALUE);
+
+    Random rand = new Random();
+    for (int i = 0; i < 10; i++) {
+      long value = rand.nextLong();
+      inputToExpectedOutput.put(String.valueOf(value), value);
+    }
+
+    LongWrapper wrapper = new LongWrapper();
+    for (Map.Entry<String, Long> entry : inputToExpectedOutput.entrySet()) {
+      assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toLong(wrapper));
+      assertEquals((long) entry.getValue(), wrapper.value);
+    }
+
+    List<String> negativeInputs = Arrays.asList("", "  ", "null", "NULL", "\n", "~1212121",
+        "1234567890123456789012345678901234");
+
+    for (String negativeInput : negativeInputs) {
+      assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper));
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c96d14ab/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a36d350..7c60f7d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-
+import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}
 
 object Cast {
 
@@ -277,9 +277,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   // LongConverter
   private[this] def castToLong(from: DataType): Any => Any = from match {
     case StringType =>
-      buildCast[UTF8String](_, s => try s.toLong catch {
-        case _: NumberFormatException => null
-      })
+      val result = new LongWrapper()
+      buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null)
     case BooleanType =>
       buildCast[Boolean](_, b => if (b) 1L else 0L)
     case DateType =>
@@ -293,9 +292,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   // IntConverter
   private[this] def castToInt(from: DataType): Any => Any = from match {
     case StringType =>
-      buildCast[UTF8String](_, s => try s.toInt catch {
-        case _: NumberFormatException => null
-      })
+      val result = new IntWrapper()
+      buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null)
     case BooleanType =>
       buildCast[Boolean](_, b => if (b) 1 else 0)
     case DateType =>
@@ -309,8 +307,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   // ShortConverter
   private[this] def castToShort(from: DataType): Any => Any = from match {
     case StringType =>
-      buildCast[UTF8String](_, s => try s.toShort catch {
-        case _: NumberFormatException => null
+      val result = new IntWrapper()
+      buildCast[UTF8String](_, s => if (s.toShort(result)) {
+        result.value.toShort
+      } else {
+        null
       })
     case BooleanType =>
       buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort)
@@ -325,8 +326,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   // ByteConverter
   private[this] def castToByte(from: DataType): Any => Any = from match {
     case StringType =>
-      buildCast[UTF8String](_, s => try s.toByte catch {
-        case _: NumberFormatException => null
+      val result = new IntWrapper()
+      buildCast[UTF8String](_, s => if (s.toByte(result)) {
+        result.value.toByte
+      } else {
+        null
       })
     case BooleanType =>
       buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte)
@@ -503,11 +507,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     case TimestampType => castToTimestampCode(from, ctx)
     case CalendarIntervalType => castToIntervalCode(from)
     case BooleanType => castToBooleanCode(from)
-    case ByteType => castToByteCode(from)
-    case ShortType => castToShortCode(from)
-    case IntegerType => castToIntCode(from)
+    case ByteType => castToByteCode(from, ctx)
+    case ShortType => castToShortCode(from, ctx)
+    case IntegerType => castToIntCode(from, ctx)
     case FloatType => castToFloatCode(from)
-    case LongType => castToLongCode(from)
+    case LongType => castToLongCode(from, ctx)
     case DoubleType => castToDoubleCode(from)
 
     case array: ArrayType =>
@@ -734,13 +738,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
       (c, evPrim, evNull) => s"$evPrim = $c != 0;"
   }
 
-  private[this] def castToByteCode(from: DataType): CastFunction = from match {
+  private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
     case StringType =>
+      val wrapper = ctx.freshName("wrapper")
+      ctx.addMutableState("UTF8String.IntWrapper", wrapper,
+        s"$wrapper = new UTF8String.IntWrapper();")
       (c, evPrim, evNull) =>
         s"""
-          try {
-            $evPrim = $c.toByte();
-          } catch (java.lang.NumberFormatException e) {
+          if ($c.toByte($wrapper)) {
+            $evPrim = (byte) $wrapper.value;
+          } else {
             $evNull = true;
           }
         """
@@ -756,13 +763,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
       (c, evPrim, evNull) => s"$evPrim = (byte) $c;"
   }
 
-  private[this] def castToShortCode(from: DataType): CastFunction = from match {
+  private[this] def castToShortCode(
+      from: DataType,
+      ctx: CodegenContext): CastFunction = from match {
     case StringType =>
+      val wrapper = ctx.freshName("wrapper")
+      ctx.addMutableState("UTF8String.IntWrapper", wrapper,
+        s"$wrapper = new UTF8String.IntWrapper();")
       (c, evPrim, evNull) =>
         s"""
-          try {
-            $evPrim = $c.toShort();
-          } catch (java.lang.NumberFormatException e) {
+          if ($c.toShort($wrapper)) {
+            $evPrim = (short) $wrapper.value;
+          } else {
             $evNull = true;
           }
         """
@@ -778,13 +790,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
       (c, evPrim, evNull) => s"$evPrim = (short) $c;"
   }
 
-  private[this] def castToIntCode(from: DataType): CastFunction = from match {
+  private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
     case StringType =>
+      val wrapper = ctx.freshName("wrapper")
+      ctx.addMutableState("UTF8String.IntWrapper", wrapper,
+        s"$wrapper = new UTF8String.IntWrapper();")
       (c, evPrim, evNull) =>
         s"""
-          try {
-            $evPrim = $c.toInt();
-          } catch (java.lang.NumberFormatException e) {
+          if ($c.toInt($wrapper)) {
+            $evPrim = $wrapper.value;
+          } else {
             $evNull = true;
           }
         """
@@ -800,13 +815,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
       (c, evPrim, evNull) => s"$evPrim = (int) $c;"
   }
 
-  private[this] def castToLongCode(from: DataType): CastFunction = from match {
+  private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
     case StringType =>
+      val wrapper = ctx.freshName("wrapper")
+      ctx.addMutableState("UTF8String.LongWrapper", wrapper,
+        s"$wrapper = new UTF8String.LongWrapper();")
+
       (c, evPrim, evNull) =>
         s"""
-          try {
-            $evPrim = $c.toLong();
-          } catch (java.lang.NumberFormatException e) {
+          if ($c.toLong($wrapper)) {
+            $evPrim = $wrapper.value;
+          } else {
             $evNull = true;
           }
         """


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org