Posted to commits@spark.apache.org by yh...@apache.org on 2015/09/11 23:15:21 UTC

spark git commit: [SPARK-10442] [SQL] fix string to boolean cast

Repository: spark
Updated Branches:
  refs/heads/master c37386677 -> d5d647380


[SPARK-10442] [SQL] fix string to boolean cast

When we cast a string to boolean in Hive, it returns `true` whenever the length of the string is greater than 0, and Spark SQL follows this behavior.

However, this behavior is very different from that of other SQL systems:

1. [presto](https://github.com/facebook/presto/blob/master/presto-main/src/main/java/com/facebook/presto/type/VarcharOperators.java#L89-L118) returns `true` for 't', 'true', and '1'; `false` for 'f', 'false', and '0'; and throws an exception for anything else.
2. [redshift](http://docs.aws.amazon.com/redshift/latest/dg/r_Boolean_type.html) returns `true` for 't', 'true', 'y', 'yes', and '1'; `false` for 'f', 'false', 'n', 'no', and '0'; and null for anything else.
3. [postgresql](http://www.postgresql.org/docs/devel/static/datatype-boolean.html) returns `true` for 't', 'true', 'y', 'yes', 'on', and '1'; `false` for 'f', 'false', 'n', 'no', 'off', and '0'; and throws an exception for anything else.
4. [vertica](https://my.vertica.com/docs/5.0/HTML/Master/2983.htm) returns `true` for 't', 'true', 'y', 'yes', and '1'; `false` for 'f', 'false', 'n', 'no', and '0'; and null for anything else.
5. [impala](http://www.cloudera.com/content/cloudera/en/documentation/cloudera-impala/latest/topics/impala_boolean.html) throws an exception when trying to cast a string to boolean.
6. MySQL, Oracle, and SQL Server do not have a boolean type.

Whether we should change the cast behavior to match other SQL systems has not been decided yet; this PR is a test to see how many compatibility tests would fail if we did.
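
For reference, here is a minimal standalone Scala sketch of the semantics this patch adopts (this is not the Spark code itself: plain `String` and `Option` stand in for `UTF8String` and SQL NULL, and the object and method names are illustrative only):

```scala
// Sketch of the new string-to-boolean cast semantics: recognized "true" and
// "false" spellings map to booleans, everything else becomes NULL.
object StringToBooleanSketch {
  private val trueStrings = Set("t", "true", "y", "yes", "1")
  private val falseStrings = Set("f", "false", "n", "no", "0")

  // Some(true) / Some(false) for recognized values (case-insensitive), None otherwise.
  def castStringToBoolean(s: String): Option[Boolean] = {
    val lower = s.toLowerCase
    if (trueStrings.contains(lower)) Some(true)
    else if (falseStrings.contains(lower)) Some(false)
    else None
  }

  def main(args: Array[String]): Unit = {
    println(castStringToBoolean("tRUe")) // Some(true)
    println(castStringToBoolean("no"))   // Some(false)
    println(castStringToBoolean("abc"))  // None, i.e. NULL (was true under the old length-based rule)
    println(castStringToBoolean(""))     // None, i.e. NULL (was false under the old length-based rule)
  }
}
```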

Author: Wenchen Fan <cl...@outlook.com>

Closes #8698 from cloud-fan/string2boolean.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d5d64738
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d5d64738
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d5d64738

Branch: refs/heads/master
Commit: d5d647380f93f4773f9cb85ea6544892d409b5a1
Parents: c373866
Author: Wenchen Fan <cl...@outlook.com>
Authored: Fri Sep 11 14:15:16 2015 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Fri Sep 11 14:15:16 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Cast.scala   | 24 +++++++-
 .../spark/sql/catalyst/util/StringUtils.scala   |  8 +++
 .../sql/catalyst/expressions/CastSuite.scala    | 61 +++++++++++++-------
 .../sql/sources/hadoopFsRelationSuites.scala    | 13 +++++
 4 files changed, 82 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d5d64738/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 2db9542..f0bce38 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -22,7 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{StringUtils, DateTimeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
@@ -140,7 +140,15 @@ case class Cast(child: Expression, dataType: DataType)
   // UDFToBoolean
   private[this] def castToBoolean(from: DataType): Any => Any = from match {
     case StringType =>
-      buildCast[UTF8String](_, _.numBytes() != 0)
+      buildCast[UTF8String](_, s => {
+        if (StringUtils.isTrueString(s)) {
+          true
+        } else if (StringUtils.isFalseString(s)) {
+          false
+        } else {
+          null
+        }
+      })
     case TimestampType =>
       buildCast[Long](_, t => t != 0)
     case DateType =>
@@ -646,7 +654,17 @@ case class Cast(child: Expression, dataType: DataType)
 
   private[this] def castToBooleanCode(from: DataType): CastFunction = from match {
     case StringType =>
-      (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;"
+      val stringUtils = StringUtils.getClass.getName.stripSuffix("$")
+      (c, evPrim, evNull) =>
+        s"""
+          if ($stringUtils.isTrueString($c)) {
+            $evPrim = true;
+          } else if ($stringUtils.isFalseString($c)) {
+            $evPrim = false;
+          } else {
+            $evNull = true;
+          }
+        """
     case TimestampType =>
       (c, evPrim, evNull) => s"$evPrim = $c != 0;"
     case DateType =>

http://git-wip-us.apache.org/repos/asf/spark/blob/d5d64738/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
index 9ddfb3a..c2eeb3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util
 
 import java.util.regex.Pattern
 
+import org.apache.spark.unsafe.types.UTF8String
+
 object StringUtils {
 
   // replace the _ with .{1} exactly match 1 time of any character
@@ -44,4 +46,10 @@ object StringUtils {
       v
     }
   }
+
+  private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString)
+  private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString)
+
+  def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
+  def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d5d64738/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 1ad7073..f4db4da 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -503,9 +503,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("cast from array") {
-    val array = Literal.create(Seq("123", "abc", "", null),
+    val array = Literal.create(Seq("123", "true", "f", null),
       ArrayType(StringType, containsNull = true))
-    val array_notNull = Literal.create(Seq("123", "abc", ""),
+    val array_notNull = Literal.create(Seq("123", "true", "f"),
       ArrayType(StringType, containsNull = false))
 
     checkNullCast(ArrayType(StringType), ArrayType(IntegerType))
@@ -522,7 +522,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     {
       val ret = cast(array, ArrayType(BooleanType, containsNull = true))
       assert(ret.resolved === true)
-      checkEvaluation(ret, Seq(true, true, false, null))
+      checkEvaluation(ret, Seq(null, true, false, null))
     }
     {
       val ret = cast(array, ArrayType(BooleanType, containsNull = false))
@@ -541,12 +541,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     {
       val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true))
       assert(ret.resolved === true)
-      checkEvaluation(ret, Seq(true, true, false))
+      checkEvaluation(ret, Seq(null, true, false))
     }
     {
       val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false))
       assert(ret.resolved === true)
-      checkEvaluation(ret, Seq(true, true, false))
+      checkEvaluation(ret, Seq(null, true, false))
     }
 
     {
@@ -557,10 +557,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("cast from map") {
     val map = Literal.create(
-      Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null),
+      Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null),
       MapType(StringType, StringType, valueContainsNull = true))
     val map_notNull = Literal.create(
-      Map("a" -> "123", "b" -> "abc", "c" -> ""),
+      Map("a" -> "123", "b" -> "true", "c" -> "f"),
       MapType(StringType, StringType, valueContainsNull = false))
 
     checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType))
@@ -577,7 +577,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     {
       val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true))
       assert(ret.resolved === true)
-      checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null))
+      checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false, "d" -> null))
     }
     {
       val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false))
@@ -600,12 +600,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     {
       val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true))
       assert(ret.resolved === true)
-      checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false))
+      checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false))
     }
     {
       val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false))
       assert(ret.resolved === true)
-      checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false))
+      checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false))
     }
     {
       val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true))
@@ -630,8 +630,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     val struct = Literal.create(
       InternalRow(
         UTF8String.fromString("123"),
-        UTF8String.fromString("abc"),
-        UTF8String.fromString(""),
+        UTF8String.fromString("true"),
+        UTF8String.fromString("f"),
         null),
       StructType(Seq(
         StructField("a", StringType, nullable = true),
@@ -641,8 +641,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     val struct_notNull = Literal.create(
       InternalRow(
         UTF8String.fromString("123"),
-        UTF8String.fromString("abc"),
-        UTF8String.fromString("")),
+        UTF8String.fromString("true"),
+        UTF8String.fromString("f")),
       StructType(Seq(
         StructField("a", StringType, nullable = false),
         StructField("b", StringType, nullable = false),
@@ -672,7 +672,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
         StructField("c", BooleanType, nullable = true),
         StructField("d", BooleanType, nullable = true))))
       assert(ret.resolved === true)
-      checkEvaluation(ret, InternalRow(true, true, false, null))
+      checkEvaluation(ret, InternalRow(null, true, false, null))
     }
     {
       val ret = cast(struct, StructType(Seq(
@@ -704,7 +704,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
         StructField("b", BooleanType, nullable = true),
         StructField("c", BooleanType, nullable = true))))
       assert(ret.resolved === true)
-      checkEvaluation(ret, InternalRow(true, true, false))
+      checkEvaluation(ret, InternalRow(null, true, false))
     }
     {
       val ret = cast(struct_notNull, StructType(Seq(
@@ -712,7 +712,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
         StructField("b", BooleanType, nullable = true),
         StructField("c", BooleanType, nullable = false))))
       assert(ret.resolved === true)
-      checkEvaluation(ret, InternalRow(true, true, false))
+      checkEvaluation(ret, InternalRow(null, true, false))
     }
 
     {
@@ -731,8 +731,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("complex casting") {
     val complex = Literal.create(
       Row(
-        Seq("123", "abc", ""),
-        Map("a" ->"123", "b" -> "abc", "c" -> ""),
+        Seq("123", "true", "f"),
+        Map("a" ->"123", "b" -> "true", "c" -> "f"),
         Row(0)),
       StructType(Seq(
         StructField("a",
@@ -755,11 +755,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(ret.resolved === true)
     checkEvaluation(ret, Row(
       Seq(123, null, null),
-      Map("a" -> true, "b" -> true, "c" -> false),
+      Map("a" -> null, "b" -> true, "c" -> false),
       Row(0L)))
   }
 
-  test("case between string and interval") {
+  test("cast between string and interval") {
     import org.apache.spark.unsafe.types.CalendarInterval
 
     checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType),
@@ -769,4 +769,23 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
       StringType),
       "interval 1 years 3 months -3 days")
   }
+
+  test("cast string to boolean") {
+    checkCast("t", true)
+    checkCast("true", true)
+    checkCast("tRUe", true)
+    checkCast("y", true)
+    checkCast("yes", true)
+    checkCast("1", true)
+
+    checkCast("f", false)
+    checkCast("false", false)
+    checkCast("FAlsE", false)
+    checkCast("n", false)
+    checkCast("no", false)
+    checkCast("0", false)
+
+    checkEvaluation(cast("abc", BooleanType), null)
+    checkEvaluation(cast("", BooleanType), null)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d5d64738/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index 13223c6..8ffcef8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -375,6 +375,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
     }
   }
 
+  test("saveAsTable()/load() - partitioned table - boolean type") {
+    sqlContext.range(2)
+      .select('id, ('id % 2 === 0).as("b"))
+      .write.partitionBy("b").saveAsTable("t")
+
+    withTable("t") {
+      checkAnswer(
+        sqlContext.table("t").sort('id),
+        Row(0, true) :: Row(1, false) :: Nil
+      )
+    }
+  }
+
   test("saveAsTable()/load() - partitioned table - Overwrite") {
     partitionedTestDF.write
       .format(dataSourceName)

