Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/03 06:45:28 UTC

spark git commit: [SPARK-8801][SQL] Support TypeCollection in ExpectsInputTypes

Repository: spark
Updated Branches:
  refs/heads/master 20a4d7dbd -> a59d14f62


[SPARK-8801][SQL] Support TypeCollection in ExpectsInputTypes

This patch adds a new TypeCollection AbstractDataType that can be used by expressions to specify more than one expected input type.
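
For illustration, here is a minimal sketch of what this enables. The expression ToBlob is hypothetical (not part of this patch), and its eval/codegen are elided:

    // A hypothetical expression that accepts either a string or a binary child.
    // During analysis, ImplicitTypeCasts leaves a StringType or BinaryType child
    // untouched; for any other input type it tries the collection members in
    // order, so StringType is the preferred implicit-cast target here.
    case class ToBlob(child: Expression) extends UnaryExpression with ExpectsInputTypes {
      override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
      override def dataType: DataType = BinaryType
      // eval and code generation elided for brevity
    }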

Author: Reynold Xin <rx...@databricks.com>

Closes #7202 from rxin/type-collection and squashes the following commits:

c714ca1 [Reynold Xin] Fixed style.
a0c0d12 [Reynold Xin] Fixed bugs and unit tests.
d8b8ae7 [Reynold Xin] Added TypeCollection.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a59d14f6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a59d14f6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a59d14f6

Branch: refs/heads/master
Commit: a59d14f623633c7aef97991341b587c11ca42328
Parents: 20a4d7d
Author: Reynold Xin <rx...@databricks.com>
Authored: Thu Jul 2 21:45:25 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Jul 2 21:45:25 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/HiveTypeCoercion.scala    | 47 ++++++++++++++---
 .../spark/sql/types/AbstractDataType.scala      | 50 ++++++++++++++----
 .../org/apache/spark/sql/types/ArrayType.scala  |  6 ++-
 .../org/apache/spark/sql/types/DataType.scala   |  4 +-
 .../apache/spark/sql/types/DecimalType.scala    |  4 ++
 .../org/apache/spark/sql/types/MapType.scala    |  4 ++
 .../org/apache/spark/sql/types/StructType.scala |  8 ++-
 .../analysis/HiveTypeCoercionSuite.scala        | 55 +++++++++++++-------
 8 files changed, 140 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a59d14f6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 0bc8932..6006e7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import javax.annotation.Nullable
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -713,39 +715,68 @@ object HiveTypeCoercion {
 
       case e: ExpectsInputTypes =>
         val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
-          implicitCast(in, expected)
+          // If we cannot do the implicit cast, just use the original input.
+          implicitCast(in, expected).getOrElse(in)
         }
         e.withNewChildren(children)
     }
 
     /**
-     * If needed, cast the expression into the expected type.
-     * If the implicit cast is not allowed, return the expression itself.
+     * Given an expected data type, try to cast the expression and return the cast expression.
+     *
+     * If the expression already fits the input type, we simply return the expression itself.
+     * If the expression has an incompatible type that cannot be implicitly cast, return None.
      */
-    def implicitCast(e: Expression, expectedType: AbstractDataType): Expression = {
+    def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = {
       val inType = e.dataType
-      (inType, expectedType) match {
+
+      // Note that ret is nullable to avoid typing a lot of Some(...) in this local scope.
+      // We wrap it in an Option immediately after this.
+      @Nullable val ret: Expression = (inType, expectedType) match {
+
+        // If the expected type is already a parent of the input type, no need to cast.
+        case _ if expectedType.isParentOf(inType) => e
+
         // Cast null type (usually from null literals) into target types
-        case (NullType, target: DataType) => Cast(e, target.defaultConcreteType)
+        case (NullType, target) => Cast(e, target.defaultConcreteType)
 
         // Implicit cast among numeric types
+        // If input is decimal, and we expect a decimal type, just use the input.
+        case (_: DecimalType, DecimalType) => e
+        // If input is a numeric type but not decimal, and we expect a decimal type,
+        // cast the input to unlimited precision decimal.
+        case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] =>
+          Cast(e, DecimalType.Unlimited)
+        // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
         case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target)
+        case (_: NumericType, target: NumericType) => e
 
         // Implicit cast between date time types
         case (DateType, TimestampType) => Cast(e, TimestampType)
         case (TimestampType, DateType) => Cast(e, DateType)
 
         // Implicit cast from/to string
-        case (StringType, NumericType) => Cast(e, DoubleType)
+        case (StringType, DecimalType) => Cast(e, DecimalType.Unlimited)
         case (StringType, target: NumericType) => Cast(e, target)
         case (StringType, DateType) => Cast(e, DateType)
         case (StringType, TimestampType) => Cast(e, TimestampType)
         case (StringType, BinaryType) => Cast(e, BinaryType)
         case (any, StringType) if any != StringType => Cast(e, StringType)
 
+        // Type collection.
+        // First see if we can find our input type in the type collection. If we can, then just
+        // use the current expression; otherwise, cast to the first type we can implicitly cast to.
+        case (_, TypeCollection(types)) =>
+          if (types.exists(_.isParentOf(inType))) {
+            e
+          } else {
+            types.flatMap(implicitCast(e, _)).headOption.orNull
+          }
+
         // Else, just return the same input expression
-        case _ => e
+        case _ => null
       }
+      Option(ret)
     }
   }
 }
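
To make the new contract concrete, here is a short sketch of how the Option-returning implicitCast behaves. The calls mirror the updated test suite further below; the commented results are what the rules above imply:

    import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.ImplicitTypeCasts.implicitCast
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.types._

    // Input type already in the collection: the expression is returned as-is.
    implicitCast(Literal.create(null, BinaryType), TypeCollection(StringType, BinaryType))
      .map(_.dataType)    // Some(BinaryType)

    // Input not in the collection: cast to the first member that admits a cast.
    implicitCast(Literal.create(null, DateType), TypeCollection(StringType, BinaryType))
      .map(_.dataType)    // Some(StringType)

    // No rule applies: None instead of an invalid Cast.
    implicitCast(Literal.create(null, MapType(StringType, StringType)), IntegerType)    // None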

http://git-wip-us.apache.org/repos/asf/spark/blob/a59d14f6/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 43e2f8a..e5dc99f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -28,7 +28,45 @@ import org.apache.spark.util.Utils
  * A non-concrete data type, reserved for internal uses.
  */
 private[sql] abstract class AbstractDataType {
+  /**
+   * The default concrete type to use if we want to cast a null literal into this type.
+   */
   private[sql] def defaultConcreteType: DataType
+
+  /**
+   * Returns true if this data type is a parent of the `childCandidate`.
+   */
+  private[sql] def isParentOf(childCandidate: DataType): Boolean
+}
+
+
+/**
+ * A collection of types that can be used to specify type constraints. The sequence also specifies
+ * precedence: an earlier type takes precedence over a later one.
+ *
+ * {{{
+ *   TypeCollection(StringType, BinaryType)
+ * }}}
+ *
+ * This means that we prefer StringType over BinaryType if it is possible to cast to StringType.
+ */
+private[sql] class TypeCollection(private val types: Seq[DataType]) extends AbstractDataType {
+  require(types.nonEmpty, s"TypeCollection ($types) cannot be empty")
+
+  private[sql] override def defaultConcreteType: DataType = types.head
+
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = false
+}
+
+
+private[sql] object TypeCollection {
+
+  def apply(types: DataType*): TypeCollection = new TypeCollection(types)
+
+  def unapply(typ: AbstractDataType): Option[Seq[DataType]] = typ match {
+    case typ: TypeCollection => Some(typ.types)
+    case _ => None
+  }
 }
 
 
@@ -61,7 +99,7 @@ abstract class NumericType extends AtomicType {
 }
 
 
-private[sql] object NumericType extends AbstractDataType {
+private[sql] object NumericType {
   /**
    * Enables matching against NumericType for expressions:
    * {{{
@@ -70,12 +108,10 @@ private[sql] object NumericType extends AbstractDataType {
    * }}}
    */
   def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
-
-  private[sql] override def defaultConcreteType: DataType = IntegerType
 }
 
 
-private[sql] object IntegralType extends AbstractDataType {
+private[sql] object IntegralType {
   /**
    * Enables matching against IntegralType for expressions:
    * {{{
@@ -84,8 +120,6 @@ private[sql] object IntegralType extends AbstractDataType {
    * }}}
    */
   def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType]
-
-  private[sql] override def defaultConcreteType: DataType = IntegerType
 }
 
 
@@ -94,7 +128,7 @@ private[sql] abstract class IntegralType extends NumericType {
 }
 
 
-private[sql] object FractionalType extends AbstractDataType {
+private[sql] object FractionalType {
   /**
    * Enables matching against FractionalType for expressions:
    * {{{
@@ -103,8 +137,6 @@ private[sql] object FractionalType extends AbstractDataType {
    * }}}
    */
   def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[FractionalType]
-
-  private[sql] override def defaultConcreteType: DataType = DoubleType
 }
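
As a quick sketch of the TypeCollection API defined above (usable only from code in the sql package, since the class is private[sql]; the commented results follow from the definitions):

    val tc = TypeCollection(StringType, BinaryType)

    // The default concrete type is the first, highest-precedence member, so a
    // null literal constrained by this collection is cast to a string.
    tc.defaultConcreteType    // StringType

    // The extractor recovers the members, as pattern-matched in ImplicitTypeCasts.
    tc match {
      case TypeCollection(types) => types    // Seq(StringType, BinaryType)
    }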
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a59d14f6/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 81553e7..8ea6cb1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -26,7 +26,11 @@ object ArrayType extends AbstractDataType {
   /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
   def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true)
 
-  override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)
+  private[sql] override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)
+
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
+    childCandidate.isInstanceOf[ArrayType]
+  }
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a59d14f6/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index c333fa7..7d00047 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -75,7 +75,9 @@ abstract class DataType extends AbstractDataType {
    */
   private[spark] def asNullable: DataType
 
-  override def defaultConcreteType: DataType = this
+  private[sql] override def defaultConcreteType: DataType = this
+
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = this == childCandidate
 }
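
A small sketch of the isParentOf contract (again assuming code inside the sql package, since the method is private[sql]; the DecimalType line relies on its override shown further below):

    IntegerType.isParentOf(IntegerType)            // true: a concrete DataType is
                                                   // only a parent of itself
    IntegerType.isParentOf(LongType)               // false
    DecimalType.isParentOf(DecimalType.Unlimited)  // true: the abstract DecimalType
                                                   // accepts any concrete decimal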
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a59d14f6/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 06373a0..434fc03 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -86,6 +86,10 @@ object DecimalType extends AbstractDataType {
 
   private[sql] override def defaultConcreteType: DataType = Unlimited
 
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
+    childCandidate.isInstanceOf[DecimalType]
+  }
+
   val Unlimited: DecimalType = DecimalType(None)
 
   private[sql] object Fixed {

http://git-wip-us.apache.org/repos/asf/spark/blob/a59d14f6/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index 69c2119..2b25617 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -71,6 +71,10 @@ object MapType extends AbstractDataType {
 
   private[sql] override def defaultConcreteType: DataType = apply(NullType, NullType)
 
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
+    childCandidate.isInstanceOf[MapType]
+  }
+
   /**
    * Construct a [[MapType]] object with the given key type and value type.
    * The `valueContainsNull` is true.

http://git-wip-us.apache.org/repos/asf/spark/blob/a59d14f6/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 6fedeab..7e77b77 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -301,7 +301,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
 }
 
 
-object StructType {
+object StructType extends AbstractDataType {
+
+  private[sql] override def defaultConcreteType: DataType = new StructType
+
+  private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
+    childCandidate.isInstanceOf[StructType]
+  }
 
   def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a59d14f6/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 498fd86..60e727c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -27,28 +27,47 @@ import org.apache.spark.sql.types._
 class HiveTypeCoercionSuite extends PlanTest {
 
   test("implicit type cast") {
-    def shouldCast(from: DataType, to: AbstractDataType): Unit = {
+    def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
       val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
-      assert(got.dataType === to.defaultConcreteType)
+      assert(got.map(_.dataType) == Option(expected),
+        s"Failed to cast $from to $to")
     }
 
+    shouldCast(NullType, NullType, NullType)
+    shouldCast(NullType, IntegerType, IntegerType)
+    shouldCast(NullType, DecimalType, DecimalType.Unlimited)
+
     // TODO: write the entire implicit cast table out for test cases.
-    shouldCast(ByteType, IntegerType)
-    shouldCast(IntegerType, IntegerType)
-    shouldCast(IntegerType, LongType)
-    shouldCast(IntegerType, DecimalType.Unlimited)
-    shouldCast(LongType, IntegerType)
-    shouldCast(LongType, DecimalType.Unlimited)
-
-    shouldCast(DateType, TimestampType)
-    shouldCast(TimestampType, DateType)
-
-    shouldCast(StringType, IntegerType)
-    shouldCast(StringType, DateType)
-    shouldCast(StringType, TimestampType)
-    shouldCast(IntegerType, StringType)
-    shouldCast(DateType, StringType)
-    shouldCast(TimestampType, StringType)
+    shouldCast(ByteType, IntegerType, IntegerType)
+    shouldCast(IntegerType, IntegerType, IntegerType)
+    shouldCast(IntegerType, LongType, LongType)
+    shouldCast(IntegerType, DecimalType, DecimalType.Unlimited)
+    shouldCast(LongType, IntegerType, IntegerType)
+    shouldCast(LongType, DecimalType, DecimalType.Unlimited)
+
+    shouldCast(DateType, TimestampType, TimestampType)
+    shouldCast(TimestampType, DateType, DateType)
+
+    shouldCast(StringType, IntegerType, IntegerType)
+    shouldCast(StringType, DateType, DateType)
+    shouldCast(StringType, TimestampType, TimestampType)
+    shouldCast(IntegerType, StringType, StringType)
+    shouldCast(DateType, StringType, StringType)
+    shouldCast(TimestampType, StringType, StringType)
+
+    shouldCast(StringType, BinaryType, BinaryType)
+    shouldCast(BinaryType, StringType, StringType)
+
+    shouldCast(NullType, TypeCollection(StringType, BinaryType), StringType)
+
+    shouldCast(StringType, TypeCollection(StringType, BinaryType), StringType)
+    shouldCast(BinaryType, TypeCollection(StringType, BinaryType), BinaryType)
+    shouldCast(StringType, TypeCollection(BinaryType, StringType), StringType)
+
+    shouldCast(IntegerType, TypeCollection(IntegerType, BinaryType), IntegerType)
+    shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType)
+    shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType)
+    shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType)
   }
 
   test("tightest common bound for types") {

